Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 61 additions & 55 deletions Shared/Middleware/CorrelationIdMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,87 +16,91 @@

namespace Shared.Middleware;

public class TenantMiddleware
{
#region Fields
public class TenantMiddleware
{
#region Fields

private readonly RequestDelegate Next;
private readonly RequestDelegate Next;

#endregion
#endregion

#region Constructors
#region Constructors

public TenantMiddleware(RequestDelegate next) {
this.Next = next;
}
public TenantMiddleware(RequestDelegate next)
{
this.Next = next;
}

public const String KeyNameCorrelationId = "correlationId";
public const String KeyNameCorrelationId = "correlationId";

#endregion
#endregion

public async Task InvokeAsync(HttpContext context, TenantContext tenantContext)
{
Stopwatch watch = Stopwatch.StartNew();
public async Task InvokeAsync(HttpContext context, TenantContext tenantContext)
{
Stopwatch watch = Stopwatch.StartNew();

// Detect the tenant from the incoming request
TenantIdentifiers tenantIdentifiers = await this.GetIdentifiersFromContext(context);
// Detect the tenant from the incoming request
TenantIdentifiers tenantIdentifiers = await this.GetIdentifiersFromContext(context);

Boolean.TryParse(ConfigurationReader.GetValueOrDefault("AppSettings","LogsPerTenantEnabled", "false"), out Boolean logPerTenantEnabled);
Boolean.TryParse(ConfigurationReader.GetValueOrDefault("AppSettings", "LogsPerTenantEnabled", "false"), out Boolean logPerTenantEnabled);

// Check the headers for a correlationId
context.Request.Headers.TryGetValue(KeyNameCorrelationId, out StringValues correlationIdHeader);
Guid.TryParse(correlationIdHeader, out Guid correlationId);
// Check the headers for a correlationId
context.Request.Headers.TryGetValue(KeyNameCorrelationId, out StringValues correlationIdHeader);
Guid.TryParse(correlationIdHeader, out Guid correlationId);

if (correlationId != Guid.Empty)
{
tenantContext.SetCorrelationId(correlationId);
context.Items[KeyNameCorrelationId] = correlationId.ToString(); // make it accessible to HttpClient handlers
}

tenantContext.Initialise(tenantIdentifiers, logPerTenantEnabled);
if (correlationId != Guid.Empty)
{
tenantContext.SetCorrelationId(correlationId);
context.Items[KeyNameCorrelationId] = correlationId.ToString(); // make it accessible to HttpClient handlers
}

// Set the current tenant in the TenantContext
TenantContext.CurrentTenant = tenantContext;
tenantContext.Initialise(tenantIdentifiers, logPerTenantEnabled);

String clientIp = context.Connection.RemoteIpAddress?.ToString();
// Set the current tenant in the TenantContext
TenantContext.CurrentTenant = tenantContext;

//Makes sense to start our correlation audit trace here
String logMessage = $"Receiving from {clientIp} => {context.Request.Method} {context.Request.Host}{context.Request.Path}";
String clientIp = context.Connection.RemoteIpAddress?.ToString();

Logger.Logger.LogInformation(logMessage);
//Makes sense to start our correlation audit trace here
String logMessage = $"Receiving from {clientIp} => {context.Request.Method} {context.Request.Host}{context.Request.Path}";

// Call the next middleware
await this.Next(context);
Logger.Logger.LogInformation(logMessage);

watch.Stop();
String afterMessage = $"{context.Response.StatusCode} {logMessage} Duration: {watch.ElapsedMilliseconds}ms";
Logger.Logger.LogInformation(afterMessage);
}
// Call the next middleware
await this.Next(context);

private async Task<TenantIdentifiers> GetIdentifiersFromContext(HttpContext context) =>
context switch
{
_ when context.GetIdentifiersFromToken() is var identifiersFromToken && identifiersFromToken != TenantIdentifiers.Default() => identifiersFromToken,
_ when context.GetIdentifiersFromHeaders() is var identifiersFromHeaders && identifiersFromHeaders != TenantIdentifiers.Default() => identifiersFromHeaders,
_ when context.GetIdentifiersFromRoute() is var identifiersFromHeaders && identifiersFromHeaders != TenantIdentifiers.Default() => identifiersFromHeaders,
//_ when await context.GetIdentifiersFromPayload() is var identifiersFromPayload && identifiersFromPayload != TenantIdentifiers.Default() =>
//identifiersFromPayload,
_ => TenantIdentifiers.Default(),
};
watch.Stop();
String afterMessage = $"{context.Response.StatusCode} {logMessage} Duration: {watch.ElapsedMilliseconds}ms";
Logger.Logger.LogInformation(afterMessage);
}

private async Task<TenantIdentifiers> GetIdentifiersFromContext(HttpContext context) =>
context switch
{
_ when context.GetIdentifiersFromToken() is var identifiersFromToken && identifiersFromToken != TenantIdentifiers.Default() => identifiersFromToken,
_ when context.GetIdentifiersFromHeaders() is var identifiersFromHeaders && identifiersFromHeaders != TenantIdentifiers.Default() => identifiersFromHeaders,
_ when context.GetIdentifiersFromRoute() is var identifiersFromHeaders && identifiersFromHeaders != TenantIdentifiers.Default() => identifiersFromHeaders,
//_ when await context.GetIdentifiersFromPayload() is var identifiersFromPayload && identifiersFromPayload != TenantIdentifiers.Default() =>
//identifiersFromPayload,
_ => TenantIdentifiers.Default(),
};
}


public static class ClaimsPrincipalExtensions
{
public static Boolean IsAuthenticated(this ClaimsPrincipal principal)
{
return principal?.Identity?.IsAuthenticated ?? false;
return principal?.Identity?.IsAuthenticated == true;
}
}

public static class HttpContextExtensionMethods {
public static TenantIdentifiers GetIdentifiersFromToken(this HttpContext context) {
if (!context.User.IsAuthenticated()) {
public static class HttpContextExtensionMethods
{
public static TenantIdentifiers GetIdentifiersFromToken(this HttpContext context)
{
if (!context.User.IsAuthenticated())
{
return TenantIdentifiers.Default();
}

Expand All @@ -109,7 +113,8 @@ public static TenantIdentifiers GetIdentifiersFromToken(this HttpContext context
return estateId == Guid.Empty ? TenantIdentifiers.Default() : new TenantIdentifiers(estateId, merchantId);
}

public static TenantIdentifiers GetIdentifiersFromHeaders(this HttpContext context) {
public static TenantIdentifiers GetIdentifiersFromHeaders(this HttpContext context)
{
// Get the org Id
context.Request.Headers.TryGetValue("estateId", out StringValues estateIdHeader);
Guid.TryParse(estateIdHeader, out Guid estateId);
Expand All @@ -121,7 +126,8 @@ public static TenantIdentifiers GetIdentifiersFromHeaders(this HttpContext conte
return estateId == Guid.Empty ? TenantIdentifiers.Default() : new TenantIdentifiers(estateId, merchantId);
}

public static TenantIdentifiers GetIdentifiersFromRoute(this HttpContext context) {
public static TenantIdentifiers GetIdentifiersFromRoute(this HttpContext context)
{
// Get the org Id

context.Request.RouteValues.TryGetValue("estateId", out object estateIdRouteValue);
Expand Down
Loading