Skip to content

Commit

Permalink
fixing null casting
Browse files Browse the repository at this point in the history
  • Loading branch information
abdelaziz-mahdy committed Sep 21, 2023
1 parent 8a4c42d commit 9a75e76
Showing 1 changed file with 43 additions and 40 deletions.
83 changes: 43 additions & 40 deletions lib/pytorch_lite.dart
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import 'package:path_provider/path_provider.dart';
import 'package:pytorch_lite/enums/model_type.dart';
import 'package:pytorch_lite/image_utils_isolate.dart';
import 'package:pytorch_lite/pigeon.dart';
import 'package:collection/collection.dart';

export 'enums/dtype.dart';
export 'package:pytorch_lite/pigeon.dart';
Expand Down Expand Up @@ -180,40 +181,43 @@ class ClassificationModel {
}

///predicts image but returns the raw net output
Future<List<double>?> getImagePredictionList(Uint8List imageAsBytes,
Future<List<double>> getImagePredictionList(Uint8List imageAsBytes,
{List<double> mean = torchVisionNormMeanRGB,
List<double> std = torchVisionNormSTDRGB}) async {
// Assert mean std
assert(mean.length == 3, "Mean should have size of 3");
assert(std.length == 3, "STD should have size of 3");
final List<double>? prediction = (await ModelApi().getImagePredictionList(
_index, imageAsBytes, null, null, null, mean, std)) as List<double>?;
final List<double> prediction = (await ModelApi().getImagePredictionList(
_index, imageAsBytes, null, null, null, mean, std))
.whereNotNull()
.toList();
return prediction;
}

///predicts image but returns the output as probabilities
///[image] takes the File of the image
Future<List<double>?> getImagePredictionListProbabilities(
Future<List<double>> getImagePredictionListProbabilities(
Uint8List imageAsBytes,
{List<double> mean = torchVisionNormMeanRGB,
List<double> std = torchVisionNormSTDRGB}) async {
List<double>? prediction =
List<double> prediction =
await getImagePredictionList(imageAsBytes, mean: mean, std: std);

return getProbabilities(prediction!);
return getProbabilities(prediction);
}

///predicts image but returns the raw net output
Future<List<double>?> getImagePredictionListFromBytesList(
Future<List<double>> getImagePredictionListFromBytesList(
List<Uint8List> imageAsBytesList, int imageWidth, int imageHeight,
{List<double> mean = torchVisionNormMeanRGB,
List<double> std = torchVisionNormSTDRGB}) async {
// Assert mean std
assert(mean.length == 3, "Mean should have size of 3");
assert(std.length == 3, "STD should have size of 3");
final List<double>? prediction = await ModelApi().getImagePredictionList(
_index, null, imageAsBytesList, imageWidth, imageHeight, mean, std)
as List<double>?;
final List<double> prediction = (await ModelApi().getImagePredictionList(
_index, null, imageAsBytesList, imageWidth, imageHeight, mean, std))
.whereNotNull()
.toList();
return prediction;
}

Expand All @@ -222,29 +226,29 @@ class ClassificationModel {
List<Uint8List> imageAsBytesList, int imageWidth, int imageHeight,
{List<double> mean = torchVisionNormMeanRGB,
List<double> std = torchVisionNormSTDRGB}) async {
final List<double>? prediction = await getImagePredictionListFromBytesList(
final List<double> prediction = await getImagePredictionListFromBytesList(
imageAsBytesList, imageWidth, imageHeight,
mean: mean, std: std);

int maxScoreIndex = softMax(prediction!);
int maxScoreIndex = softMax(prediction);
return labels[maxScoreIndex];
}

///predicts image but returns the output as probabilities
///[image] takes the File of the image
Future<List<double>?> getImagePredictionListProbabilitiesFromBytesList(
Future<List<double>> getImagePredictionListProbabilitiesFromBytesList(
List<Uint8List> imageAsBytesList, int imageWidth, int imageHeight,
{List<double> mean = torchVisionNormMeanRGB,
List<double> std = torchVisionNormSTDRGB}) async {
final List<double>? prediction = await getImagePredictionListFromBytesList(
final List<double> prediction = await getImagePredictionListFromBytesList(
imageAsBytesList, imageWidth, imageHeight,
mean: mean, std: std);

return getProbabilities(prediction!);
return getProbabilities(prediction);
}

///predicts image but returns the raw net output
Future<List<double>?> getCameraImagePredictionList(
Future<List<double>> getCameraImagePredictionList(
CameraImage cameraImage,
int rotation, {
List<double> mean = torchVisionNormMeanRGB,
Expand Down Expand Up @@ -275,11 +279,11 @@ class ClassificationModel {
List<double> std = torchVisionNormSTDRGB,
PreProcessingMethod preProcessingMethod = PreProcessingMethod.imageLib,
}) async {
final List<double>? prediction = await getCameraImagePredictionList(
final List<double> prediction = await getCameraImagePredictionList(
cameraImage, rotation,
mean: mean, std: std, preProcessingMethod: preProcessingMethod);

int maxScoreIndex = softMax(prediction!);
int maxScoreIndex = softMax(prediction);
return labels[maxScoreIndex];
}

Expand All @@ -291,11 +295,11 @@ class ClassificationModel {
List<double> std = torchVisionNormSTDRGB,
PreProcessingMethod preProcessingMethod = PreProcessingMethod.imageLib,
}) async {
final List<double>? prediction = await getCameraImagePredictionList(
final List<double> prediction = await getCameraImagePredictionList(
cameraImage, rotation,
mean: mean, std: std, preProcessingMethod: preProcessingMethod);

return getProbabilities(prediction!);
return getProbabilities(prediction);
}
}

Expand Down Expand Up @@ -352,16 +356,12 @@ class ModelObjectDetection {
{double minimumScore = 0.5,
double iOUThreshold = 0.5,
int boxesLimit = 10}) async {
final List<ResultObjectDetection> prediction = await ModelApi()
.getImagePredictionListObjectDetection(
_index,
imageAsBytes,
null,
null,
null,
minimumScore,
iOUThreshold,
boxesLimit) as List<ResultObjectDetection>;
final List<ResultObjectDetection> prediction = (await ModelApi()
.getImagePredictionListObjectDetection(_index, imageAsBytes, null,
null, null, minimumScore, iOUThreshold, boxesLimit))
.whereNotNull()
.toList();
;
return prediction;
}

Expand All @@ -371,16 +371,19 @@ class ModelObjectDetection {
{double minimumScore = 0.5,
double iOUThreshold = 0.5,
int boxesLimit = 10}) async {
final List<ResultObjectDetection> prediction = await ModelApi()
.getImagePredictionListObjectDetection(
_index,
null,
imageAsBytesList,
imageWidth,
imageHeight,
minimumScore,
iOUThreshold,
boxesLimit) as List<ResultObjectDetection>;
final List<ResultObjectDetection> prediction = (await ModelApi()
.getImagePredictionListObjectDetection(
_index,
null,
imageAsBytesList,
imageWidth,
imageHeight,
minimumScore,
iOUThreshold,
boxesLimit))
.whereNotNull()
.toList();
;
return prediction;
}

Expand Down

0 comments on commit 9a75e76

Please sign in to comment.