Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
V2 auth is not working for some servers with project.json restore
  • Loading branch information
pranavkm committed Jul 30, 2015
1 parent 0dffd4e commit 81ea724
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 72 deletions.
161 changes: 90 additions & 71 deletions src/NuGet.Protocol.Core.v3/RemoteRepositories/HttpSource.cs
Expand Up @@ -51,99 +51,124 @@ internal async Task<HttpSourceResult> GetAsync(string uri, string cacheKey, Time
}

Logger.LogVerbose(string.Format(CultureInfo.InvariantCulture, " {0} {1}.", "GET", uri));

ICredentials credentials = CredentialStore.Instance.GetCredentials(_baseUri);
retryWithAuthentication:

var messageHandler = await _messageHandlerFactory();
using (var client = new DataClient(messageHandler))
var retry = true;
while (retry)
{
if (credentials != null)
{
messageHandler.Credentials = credentials;
}

var request = new HttpRequestMessage(HttpMethod.Get, uri);
var response = await client.SendAsync(request, cancellationToken);
if (ignoreNotFounds && response.StatusCode == HttpStatusCode.NotFound)
var messageHandler = await _messageHandlerFactory();
using (var client = new DataClient(messageHandler))
{
Logger.LogInformation(string.Format(CultureInfo.InvariantCulture,
" {1} {0} {2}ms", uri, response.StatusCode.ToString(), sw.ElapsedMilliseconds.ToString()));
return new HttpSourceResult();
}
if (credentials != null)
{
messageHandler.Credentials = credentials;
}

if (response.StatusCode == HttpStatusCode.Unauthorized)
{
if (HttpHandlerResourceV3.PromptForCredentials != null)
var request = new HttpRequestMessage(HttpMethod.Get, uri);
STSAuthHelper.PrepareSTSRequest(_baseUri, CredentialStore.Instance, request);
var response = await client.SendAsync(request, cancellationToken);
if (ignoreNotFounds && response.StatusCode == HttpStatusCode.NotFound)
{
credentials = HttpHandlerResourceV3.PromptForCredentials(_baseUri);
Logger.LogInformation(string.Format(CultureInfo.InvariantCulture,
" {1} {0} {2}ms", uri, response.StatusCode.ToString(), sw.ElapsedMilliseconds.ToString()));
return new HttpSourceResult();
}

if (credentials == null)
if (response.StatusCode == HttpStatusCode.Unauthorized)
{
response.EnsureSuccessStatusCode();
if (STSAuthHelper.TryRetrieveSTSToken(_baseUri, CredentialStore.Instance, response))
{
continue;
}

if (HttpHandlerResourceV3.PromptForCredentials != null)
{
credentials = HttpHandlerResourceV3.PromptForCredentials(_baseUri);
}

if (credentials == null)
{
response.EnsureSuccessStatusCode();
}
else
{
continue;
}
}
else

retry = false;
response.EnsureSuccessStatusCode();

if (HttpHandlerResourceV3.CredentialsSuccessfullyUsed != null && credentials != null)
{
client.Dispose();
goto retryWithAuthentication;
HttpHandlerResourceV3.CredentialsSuccessfullyUsed(_baseUri, credentials);
}
}

response.EnsureSuccessStatusCode();
await CreateCacheFile(result, response, cacheAgeLimit, cancellationToken);

if (HttpHandlerResourceV3.CredentialsSuccessfullyUsed != null && credentials != null)
{
HttpHandlerResourceV3.CredentialsSuccessfullyUsed(_baseUri, credentials);
Logger.LogVerbose(string.Format(CultureInfo.InvariantCulture,
" {1} {0} {2}ms", uri, response.StatusCode.ToString(), sw.ElapsedMilliseconds.ToString()));

return result;
}
}

var newFile = result.CacheFileName + "-new";
return result;
}

// Zero value of TTL means we always download the latest package
// So we write to a temp file instead of cache
if (cacheAgeLimit.Equals(TimeSpan.Zero))
{
result.CacheFileName = Path.GetTempFileName();
newFile = Path.GetTempFileName();
}
private static Task CreateCacheFile(
HttpSourceResult result,
HttpResponseMessage response,
TimeSpan cacheAgeLimit,
CancellationToken cancellationToken)
{
var newFile = result.CacheFileName + "-new";

// Zero value of TTL means we always download the latest package
// So we write to a temp file instead of cache
if (cacheAgeLimit.Equals(TimeSpan.Zero))
{
result.CacheFileName = Path.GetTempFileName();
newFile = Path.GetTempFileName();
}

// The update of a cached file is divided into two steps:
// 1) Delete the old file. 2) Create a new file with the same name.
// To prevent race condition among multiple processes, here we use a lock to make the update atomic.
await ConcurrencyUtilities.ExecuteWithFileLocked(result.CacheFileName,
action: async token =>
// The update of a cached file is divided into two steps:
// 1) Delete the old file. 2) Create a new file with the same name.
// To prevent race condition among multiple processes, here we use a lock to make the update atomic.
return ConcurrencyUtilities.ExecuteWithFileLocked(result.CacheFileName,
action: async token =>
{
using (var stream = new FileStream(
newFile,
FileMode.Create,
FileAccess.ReadWrite,
FileShare.ReadWrite | FileShare.Delete,
BufferSize,
useAsync: true))
{
using (var stream = new FileStream(
newFile,
FileMode.Create,
FileAccess.ReadWrite,
FileShare.ReadWrite | FileShare.Delete,
BufferSize,
useAsync: true))
{
await response.Content.CopyToAsync(stream);
await stream.FlushAsync(cancellationToken);
}
await response.Content.CopyToAsync(stream);
await stream.FlushAsync(cancellationToken);
}
if (File.Exists(result.CacheFileName))
{
if (File.Exists(result.CacheFileName))
{
// Process B can perform deletion on an opened file if the file is opened by process A
// with FileShare.Delete flag. However, the file won't be actually deleted until A close it.
// This special feature can cause race condition, so we never delete an opened file.
if (!IsFileAlreadyOpen(result.CacheFileName))
{
File.Delete(result.CacheFileName);
}
{
File.Delete(result.CacheFileName);
}
}
// If the destination file doesn't exist, we can safely perform moving operation.
// Otherwise, moving operation will fail.
if (!File.Exists(result.CacheFileName))
{
File.Move(
{
File.Move(
newFile,
result.CacheFileName);
}
}
// Even the file deletion operation above succeeds but the file is not actually deleted,
// we can still safely read it because it means that some other process just updated it
Expand All @@ -156,15 +181,9 @@ internal async Task<HttpSourceResult> GetAsync(string uri, string cacheKey, Time
BufferSize,
useAsync: true);
return 0;
},
token: cancellationToken);

Logger.LogVerbose(string.Format(CultureInfo.InvariantCulture,
" {1} {0} {2}ms", uri, response.StatusCode.ToString(), sw.ElapsedMilliseconds.ToString()));

return result;
}
return 0;
},
token: cancellationToken);
}

private async Task<HttpSourceResult> TryCache(string uri,
Expand Down
150 changes: 150 additions & 0 deletions src/NuGet.Protocol.Core.v3/RemoteRepositories/STSAuthHelper.cs
@@ -0,0 +1,150 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
#if !DNXCORE50
using System.IdentityModel.Protocols.WSTrust;
using System.IdentityModel.Tokens;
using System.Linq;
#endif
using System.Net;
using System.Net.Http;
#if !DNXCORE50
using System.ServiceModel;
using System.ServiceModel.Security;
#endif
using System.Text;
using NuGet.Configuration;

namespace NuGet.Protocol.Core.v3
{
public class STSAuthHelper
{
/// <summary>
/// Response header that specifies the WSTrust13 Windows Transport endpoint.
/// </summary>
/// <remarks>
/// TODO: Is there a way to discover this \ negotiate this endpoint?
/// </remarks>
private const string STSEndPointHeader = "X-NuGet-STS-EndPoint";

/// <summary>
/// Response header that specifies the realm to authenticate for. In most cases this would be the gallery we are going up against.
/// </summary>
private const string STSRealmHeader = "X-NuGet-STS-Realm";

/// <summary>
/// Request header that contains the SAML token.
/// </summary>
private const string STSTokenHeader = "X-NuGet-STS-Token";

/// <summary>
/// Adds the SAML token as a header to the request if it is already cached for this host.
/// </summary>
public static void PrepareSTSRequest(
Uri feedUri,
CredentialStore credentialStore,
HttpRequestMessage request)
{
#if !DNXCORE50
var credentials = credentialStore.GetCredentials(feedUri) as STSCredentials;

if (credentials != null)
{
request.Headers.TryAddWithoutValidation(STSTokenHeader, credentials.STSToken);
}
#endif
}

/// <summary>
/// Attempts to retrieve a SAML token if the response indicates that server requires STS-based auth.
/// </summary>
public static bool TryRetrieveSTSToken(
Uri feedUri,
CredentialStore credentialStore,
HttpResponseMessage response)
{
#if DNXCORE50
return false;
#else
if (response.StatusCode != HttpStatusCode.Unauthorized)
{
// We only care to do STS auth if the server returned a 401
return false;
}

var endPoint = GetHeader(response, STSEndPointHeader);
var realm = GetHeader(response, STSRealmHeader);
if (string.IsNullOrEmpty(endPoint) || string.IsNullOrEmpty(realm))
{
// The server does not conform to our STS-auth requirements.
return false;
}

var credentials = credentialStore.GetCredentials(feedUri) as STSCredentials;

if (credentials == null)
{
var stsToken = GetSTSToken(feedUri, endPoint, realm);
if (stsToken != null)
{
stsToken = Convert.ToBase64String(Encoding.UTF8.GetBytes(stsToken));
credentials = new STSCredentials(stsToken);
credentialStore.Add(feedUri, credentials);
return true;
}
}

return false;
}

private static string GetSTSToken(Uri requestUri, string endPoint, string realm)
{
var binding = new WS2007HttpBinding(SecurityMode.Transport);
var factory = new WSTrustChannelFactory(binding, endPoint)
{
TrustVersion = TrustVersion.WSTrust13
};

var endPointReference = new EndpointReference(realm);
var requestToken = new RequestSecurityToken
{
RequestType = RequestTypes.Issue,
KeyType = KeyTypes.Bearer,
AppliesTo = endPointReference
};

var channel = factory.CreateChannel();
var responseToken = channel.Issue(requestToken) as GenericXmlSecurityToken;
return responseToken?.TokenXml.OuterXml;
}

private static string GetHeader(HttpResponseMessage response, string header)
{
IEnumerable<string> values;
if (response.Headers.TryGetValues(header, out values))
{
return values.FirstOrDefault();
}

return null;
}

private class STSCredentials : ICredentials
{
public STSCredentials(string stsToken)
{
STSToken = stsToken;
}

public string STSToken { get; }

public NetworkCredential GetCredential(Uri uri, string authType)
{
throw new NotSupportedException();
}
#endif
}
}
}
4 changes: 3 additions & 1 deletion src/NuGet.Protocol.Core.v3/project.json
Expand Up @@ -20,8 +20,10 @@
"net45": {
"frameworkAssemblies": {
"System.Collections.Concurrent": "",
"System.IdentityModel": "",
"System.Net.Http": "",
"System.Net.Http.WebRequest": ""
"System.Net.Http.WebRequest": "",
"System.ServiceModel": ""
}
},
"dnxcore50": {
Expand Down

0 comments on commit 81ea724

Please sign in to comment.