Skip to content

Commit

Permalink
Add option to switch FaceAPI prediction mode
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w committed Aug 14, 2018
1 parent 1b97335 commit 36fb41b
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 11 deletions.
31 changes: 25 additions & 6 deletions algorithms/FaceApi/Cli.cs
Expand Up @@ -15,13 +15,13 @@ static void Main(string[] args)
static async Task MainAsync(string[] args)
{
var settings = new Settings(args);
if (!settings.TryParse(out string apiKey, out string apiEndpoint, out double matchThreshold))
if (!settings.TryParse(out string apiKey, out string apiEndpoint, out double matchThreshold, out PredictionMode predictionMode))
{
await Console.Error.WriteLineAsync("Missing api-key and api-endpoint settings");
return;
}

var faceIdentifier = new FaceIdentifier(apiKey, apiEndpoint);
var faceIdentifier = new FaceIdentifier(apiKey, apiEndpoint, predictionMode);

if (settings.TryParseForTraining(out string trainSetRoot))
{
Expand Down Expand Up @@ -63,26 +63,31 @@ static async Task MainAsync(string[] args)

class Settings
{
private static readonly PredictionMode DefaultPredictionMode = PredictionMode.Identify;
private static readonly double DefaultMatchThreshold = 0.6;

private string[] Args { get; }

public Settings(string[] args)
{
Args = args;
}

public bool TryParse(out string apiKey, out string apiEndpoint, out double matchThreshold)
public bool TryParse(out string apiKey, out string apiEndpoint, out double matchThreshold, out PredictionMode predictionMode)
{
if (ApiKey == null || ApiEndpoint == null)
{
apiKey = null;
apiEndpoint = null;
matchThreshold = -1;
matchThreshold = DefaultMatchThreshold;
predictionMode = DefaultPredictionMode;
return false;
}

apiKey = ApiKey;
apiEndpoint = ApiEndpoint;
matchThreshold = MatchThreshold;
predictionMode = PredictionMode;
return true;
}

Expand Down Expand Up @@ -164,18 +169,32 @@ private double MatchThreshold
{
get
{
const double defaultMatchThreshold = 0.6;
var matchThreshold = Environment.GetEnvironmentVariable("FACE_API_MATCH_THRESHOLD");

if (matchThreshold == null || !double.TryParse(matchThreshold, out double parsedMatchThreshold))
{
return defaultMatchThreshold;
return DefaultMatchThreshold;
}

return parsedMatchThreshold;
}
}

private PredictionMode PredictionMode
{
get
{
var predictionMode = Environment.GetEnvironmentVariable("FACE_API_PREDICTION_MODE");

if (predictionMode == null || !Enum.TryParse(typeof(PredictionMode), predictionMode, out object parsedPredictionMode))
{
return DefaultPredictionMode;
}

return (PredictionMode)parsedPredictionMode;
}
}

private string GroupId
{
get => Environment.GetEnvironmentVariable("FACE_API_GROUP_ID");
Expand Down
41 changes: 36 additions & 5 deletions algorithms/FaceApi/FaceIdentifier.cs
Expand Up @@ -17,15 +17,18 @@ public class FaceIdentifier

private TimeLimiter RateLimit { get; }
private FaceClient Client { get; }
private PredictionMode PredictionMode { get; }

public FaceIdentifier(string apiKey, string apiEndpoint)
public FaceIdentifier(string apiKey, string apiEndpoint, PredictionMode predictionMode)
{
Client = new FaceClient(new ApiKeyServiceClientCredentials(apiKey))
{
Endpoint = apiEndpoint
};

RateLimit = TimeLimiter.GetFromMaxCountByInterval(RateLimitRequests, RateLimitInterval);

PredictionMode = predictionMode;
}

public async Task<bool> Predict(string groupId, double matchThreshold, string imagePath1, string imagePath2)
Expand All @@ -39,11 +42,17 @@ public async Task<bool> Predict(string groupId, double matchThreshold, string im
return false;
}

var allPeople = await IdentifyPeople(faces1.Concat(faces2), groupId, matchThreshold);
var people1 = allPeople.Where(candidates => faces1.Contains(candidates.FaceId)).SelectMany(candidates => candidates.Candidates).Select(person => person.PersonId).ToHashSet();
var people2 = allPeople.Where(candidates => faces2.Contains(candidates.FaceId)).SelectMany(candidates => candidates.Candidates).Select(person => person.PersonId).ToHashSet();
switch (PredictionMode)
{
case PredictionMode.Identify:
return await PredictWithIdentify(groupId, matchThreshold, faces1, faces2);

return people1.Any(person => people2.Contains(person));
case PredictionMode.Verify:
return await PredictWithVerify(matchThreshold, faces1, faces2);

default:
throw new NotImplementedException($"{PredictionMode}");
}
}

public async Task<string> Train(string trainSetRoot)
Expand Down Expand Up @@ -171,5 +180,27 @@ private async Task<PersistedFace> AddFace(string groupId, Guid personId, string
});
}
}

private async Task<bool> PredictWithIdentify(string groupId, double matchThreshold, IList<Guid> faces1, IList<Guid> faces2)
{
var allPeople = await IdentifyPeople(faces1.Concat(faces2), groupId, matchThreshold);
var people1 = allPeople.Where(candidates => faces1.Contains(candidates.FaceId)).SelectMany(candidates => candidates.Candidates).Select(person => person.PersonId).ToHashSet();
var people2 = allPeople.Where(candidates => faces2.Contains(candidates.FaceId)).SelectMany(candidates => candidates.Candidates).Select(person => person.PersonId).ToHashSet();

return people1.Any(person => people2.Contains(person));
}

private async Task<bool> PredictWithVerify(double matchThreshold, IList<Guid> faces1, IList<Guid> faces2)
{
var verifications = await Task.WhenAll(faces1.SelectMany(face1 => faces2.Select(face2 => Client.Face.VerifyFaceToFaceAsync(face1, face2))));

return verifications.Any(result => result.Confidence >= matchThreshold);
}
}

public enum PredictionMode
{
Identify,
Verify
}
}

0 comments on commit 36fb41b

Please sign in to comment.