Skip to content

Commit

Permalink
Update MSI Helper Service to use a cert instead of secret (#4777)
Browse files Browse the repository at this point in the history
initial

Co-authored-by: Gladwin Johnson <gljohns@microsoft.com>
  • Loading branch information
gladjohn and GladwinJohnson committed May 21, 2024
1 parent 1fdd371 commit 732b1fd
Showing 1 changed file with 81 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.Text;
using System.Text.Json.Serialization;
using System.Web;
using System.Security.Cryptography.X509Certificates;

namespace MSIHelperService.Helper
{
Expand All @@ -38,6 +39,7 @@ internal static class MSIHelper
internal static readonly string? s_azureArcWebhookLocation = Environment.GetEnvironmentVariable("AzureArcWebHookLocation");
internal static readonly string? s_oMSAdminClientID = Environment.GetEnvironmentVariable("OMSAdminClientID");
internal static readonly string? s_oMSAdminClientSecret = Environment.GetEnvironmentVariable("OMSAdminClientSecret");
internal static readonly string? s_webAppCertThumbprint = Environment.GetEnvironmentVariable("WebAppCertThumbprint");

//Microsoft authority endpoint
internal const string Authority = "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47";
Expand Down Expand Up @@ -94,7 +96,7 @@ internal enum AzureResource

string[] scopes = new string[] { "https://request.msidlab.com/.default" };

string? token = await GetMSALToken(s_requestAppID, s_requestAppSecret, scopes, logger)
string? token = await GetMSALToken(s_requestAppID, null, GetWebAppCertificate(logger), scopes, logger)
.ConfigureAwait(false);

//clear the default request header for each call
Expand All @@ -105,8 +107,8 @@ internal enum AzureResource

//send the request
HttpResponseMessage result = await httpClient
.GetAsync(s_functionAppUri +
"GetEnvironmentVariables?code="
.GetAsync(s_functionAppUri +
"GetEnvironmentVariables?code="
+ s_functionAppEnvCode)
.ConfigureAwait(false);

Expand Down Expand Up @@ -218,7 +220,7 @@ internal enum AzureResource
//Scopes
string[] scopes = new string[] { "https://request.msidlab.com/.default" };

string? token = await GetMSALToken(s_requestAppID, s_requestAppSecret, scopes, logger)
string? token = await GetMSALToken(s_requestAppID, null, GetWebAppCertificate(logger), scopes, logger)
.ConfigureAwait(false);

//clear the default request header for each call
Expand All @@ -231,7 +233,7 @@ internal enum AzureResource
var encodedUri = HttpUtility.UrlEncode(uri);

HttpResponseMessage result = await httpClient.GetAsync(s_functionAppUri + "getToken?code=" +
s_functionAppMSICode + "&uri=" + encodedUri + "&header=" +identityHeader)
s_functionAppMSICode + "&uri=" + encodedUri + "&header=" + identityHeader)
.ConfigureAwait(false);

string body = await result.Content.ReadAsStringAsync()
Expand Down Expand Up @@ -269,6 +271,7 @@ internal enum AzureResource
string? token = await GetMSALToken(
s_oMSAdminClientID,
s_oMSAdminClientSecret,
null,
scopes,
logger).ConfigureAwait(false);

Expand Down Expand Up @@ -316,7 +319,7 @@ internal enum AzureResource
var errorResponse = ex.Message;

logger.LogError("GetVirtualMachineMSIToken call failed.");

return GetContentResult(errorResponse, "application/json", (int)responseMessage.StatusCode);
}
}
Expand Down Expand Up @@ -348,6 +351,7 @@ internal enum AzureResource
string? token = await GetMSALToken(
s_oMSAdminClientID,
s_oMSAdminClientSecret,
null,
scopes,
logger).ConfigureAwait(false);

Expand Down Expand Up @@ -442,7 +446,7 @@ internal enum AzureResource
return false;
}

logger.LogInformation($"Current Job Status is - { currentJobStatus }.");
logger.LogInformation($"Current Job Status is - {currentJobStatus}.");
}
while (currentJobStatus != "Completed");

Expand Down Expand Up @@ -489,7 +493,7 @@ internal enum AzureResource
if (!string.IsNullOrEmpty(jobId))
{
logger.LogInformation("Job ID retrieved from the Azure Runbook.");
logger.LogInformation($"Job Id is - { jobId }.");
logger.LogInformation($"Job Id is - {jobId}.");
}
else
{
Expand Down Expand Up @@ -526,41 +530,93 @@ internal enum AzureResource
/// </summary>
/// <param name="clientID"></param>
/// <param name="secret"></param>
/// <param name="x509Certificate2"></param>
/// <param name="scopes"></param>
/// <param name="logger"></param>
/// <returns></returns>
private static async Task<string?> GetMSALToken(
string? clientID,
string? secret,
string[] scopes,
string? clientID,
string? secret,
X509Certificate2? x509Certificate2,
string[] scopes,
ILogger logger)
{
logger.LogInformation("GetMSALToken Function called.");

//Confidential Client Application Builder
IConfidentialClientApplication app = ConfidentialClientApplicationBuilder.Create(clientID)
.WithClientSecret(secret)
.WithAuthority(new Uri(Authority))
.WithCacheOptions(CacheOptions.EnableSharedCacheOptions)
.Build();
ConfidentialClientApplicationBuilder builder = ConfidentialClientApplicationBuilder.Create(clientID)
.WithAuthority(new Uri(Authority))
.WithCacheOptions(CacheOptions.EnableSharedCacheOptions);

//Acquire Token For Client using MSAL
// Configure either a client secret or a certificate
if (!string.IsNullOrEmpty(secret))
{
builder.WithClientSecret(secret);
}
else if (x509Certificate2 != null)
{
builder.WithCertificate(x509Certificate2);
}
else
{
logger.LogError("No valid authentication method provided (neither secret nor certificate).");
return null;
}

IConfidentialClientApplication app = builder.Build();

// Acquire Token For Client using MSAL
try
{
AuthenticationResult result = await app.AcquireTokenForClient(scopes)
.ExecuteAsync()
.ConfigureAwait(false);

logger.LogInformation("MSAL Token acquired successfully.");
logger.LogInformation($"MSAL Token source is : { result.AuthenticationResultMetadata.TokenSource }");
logger.LogInformation($"MSAL Token source is: {result.AuthenticationResultMetadata.TokenSource}");

return result.AccessToken;
}
catch (MsalException ex)
{
logger.LogError(ex.Message);
return ex.Message;
logger.LogError($"Failed to acquire token: {ex.Message}");
return null;
}
}

/// <summary>
/// Gets the Web App Certificate
/// </summary>
/// <param name="logger"></param>
/// <returns></returns>
/// <exception cref="Exception"></exception>
private static X509Certificate2 GetWebAppCertificate(ILogger logger)
{
// The thumbprint of the certificate you want to load
string? thumbprint = s_webAppCertThumbprint;

if (string.IsNullOrEmpty(thumbprint))
{
logger.LogError("Thumbprint not found in the environment variables!");
throw new Exception("Unable to load Web App Certificate due to missing thumbprint!");
}

using (X509Store store = new X509Store(StoreName.My, StoreLocation.CurrentUser))
{
store.Open(OpenFlags.ReadOnly);

X509Certificate2Collection certCollection = store.Certificates;

X509Certificate2Collection currentCerts = certCollection.Find(
X509FindType.FindByThumbprint, thumbprint, false);

if (currentCerts.Count > 0)
{
return currentCerts[0];
}
}

logger.LogError("Certificate not found in the Web App Cert Store!");
throw new Exception("Unable to load Web App Certificate!");
}

/// <summary>
Expand All @@ -570,7 +626,7 @@ internal enum AzureResource
/// <param name="httpClient"></param>
/// <returns></returns>
private static void ClearDefaultRequestHeaders(
ILogger logger,
ILogger logger,
HttpClient httpClient)
{
logger.LogInformation("ClearDefaultRequestHeaders Function called.");
Expand Down Expand Up @@ -608,8 +664,8 @@ internal enum AzureResource
/// <param name="statusCode"></param>
/// <returns></returns>
private static ContentResult GetContentResult(
string content,
string contentEncoding,
string content,
string contentEncoding,
int statusCode)
{
return new ContentResult
Expand Down

0 comments on commit 732b1fd

Please sign in to comment.