Skip to content

Commit

Permalink
Add CustomValidationFilter to AuthProvider and auto validate OAuth Pr…
Browse files Browse the repository at this point in the history
…oviders have unique emails
  • Loading branch information
mythz committed Aug 2, 2014
1 parent 0ca2adc commit 095da0c
Showing 1 changed file with 82 additions and 33 deletions.
115 changes: 82 additions & 33 deletions src/ServiceStack/Auth/AuthProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ namespace ServiceStack.Auth
public abstract class AuthProvider : IAuthProvider
{
protected static readonly ILog Log = LogManager.GetLogger(typeof(AuthProvider));
public static bool ValidateUniqueEmails = true; //Temporary, remove later when no issues.

public TimeSpan SessionExpiry { get; set; }
public string AuthRealm { get; set; }
public string Provider { get; set; }
public string CallbackUrl { get; set; }
public string RedirectUrl { get; set; }

public Action<AuthUserSession, IAuthTokens, Dictionary<string, string>> LoadUserAuthFilter { get; set; }
public Action<AuthUserSession, IAuthTokens, Dictionary<string, string>> LoadUserAuthFilter { get; set; }

public Func<AuthContext, IHttpResult> CustomValidationFilter { get; set; }

protected AuthProvider()
{
Expand Down Expand Up @@ -139,47 +142,72 @@ public virtual IHttpResult OnAuthenticated(IServiceBase authService, IAuthSessio
authInfo.ForEach((x, y) => tokens.Items[x] = y);
}

var authRepo = authService.TryResolve<IAuthRepository>();
if (authRepo != null)
try
{
var failed = ValidateAccount(authService, authRepo, session, tokens);
if (failed != null)
return failed;
var authRepo = authService.TryResolve<IAuthRepository>();

if (hasTokens)
if (CustomValidationFilter != null)
{
session.UserAuthId = authRepo.CreateOrMergeAuthSession(session, tokens);
var ctx = new AuthContext
{
Service = authService,
AuthProvider = this,
Session = session,
AuthTokens = tokens,
AuthInfo = authInfo,
AuthRepository = authRepo,
};
var response = CustomValidationFilter(ctx);
if (response != null)
{
session.IsAuthenticated = false;
authService.SaveSession(session, SessionExpiry);
return response;
}
}

authRepo.LoadUserAuth(session, tokens);

foreach (var oAuthToken in session.ProviderOAuthAccess)
if (authRepo != null)
{
var authProvider = AuthenticateService.GetAuthProvider(oAuthToken.Provider);
if (authProvider == null) continue;
var userAuthProvider = authProvider as OAuthProvider;
if (userAuthProvider != null)
var failed = ValidateAccount(authService, authRepo, session, tokens);
if (failed != null)
{
userAuthProvider.LoadUserOAuthProvider(session, oAuthToken);
session.IsAuthenticated = false;
authService.SaveSession(session, SessionExpiry);
return failed;
}
}

var httpRes = authService.Request.Response as IHttpResponse;
if (session.UserAuthId != null && httpRes != null)
{
httpRes.Cookies.AddPermanentCookie(HttpHeaders.XUserAuthId, session.UserAuthId);
if (hasTokens)
{
session.UserAuthId = authRepo.CreateOrMergeAuthSession(session, tokens);
}

authRepo.LoadUserAuth(session, tokens);

foreach (var oAuthToken in session.ProviderOAuthAccess)
{
var authProvider = AuthenticateService.GetAuthProvider(oAuthToken.Provider);
if (authProvider == null) continue;
var userAuthProvider = authProvider as OAuthProvider;
if (userAuthProvider != null)
{
userAuthProvider.LoadUserOAuthProvider(session, oAuthToken);
}
}

var httpRes = authService.Request.Response as IHttpResponse;
if (session.UserAuthId != null && httpRes != null)
{
httpRes.Cookies.AddPermanentCookie(HttpHeaders.XUserAuthId, session.UserAuthId);
}
}
}
else
{
if (hasTokens)
else
{
session.UserAuthId = CreateOrMergeAuthSession(session, tokens);
if (hasTokens)
{
session.UserAuthId = CreateOrMergeAuthSession(session, tokens);
}
}
}

try
{
session.IsAuthenticated = true;
session.OnAuthenticated(authService, session, tokens, authInfo);
}
Expand Down Expand Up @@ -222,7 +250,7 @@ public virtual string CreateOrMergeAuthSession(IAuthSession session, IAuthTokens
}

var key = tokens.Provider + ":" + (tokens.UserId ?? tokens.UserName);
return transientUserIdsMap.GetOrAdd(key,
return transientUserIdsMap.GetOrAdd(key,
k => Interlocked.Increment(ref transientUserAuthId)).ToString(CultureInfo.InvariantCulture);
}

Expand Down Expand Up @@ -282,12 +310,23 @@ protected virtual void AssertNotLocked(IUserAuth userAuth)
protected virtual IHttpResult ValidateAccount(IServiceBase authService, IAuthRepository authRepo, IAuthSession session, IAuthTokens tokens)
{
var userAuth = authRepo.GetUserAuth(session, tokens);
var isLocked = userAuth != null && userAuth.LockedDate != null;

if (ValidateUniqueEmails && tokens != null && tokens.Email != null)
{
var userWithEmail = authRepo.GetUserAuthByUserName(tokens.Email);
if (userWithEmail == null) return null;

var isAnotherUser = userAuth == null || (userAuth.Id != userWithEmail.Id);
if (isAnotherUser)
{
return authService.Redirect(session.ReferrerUrl.AddHashParam("f", "EmailAlreadyExists"));
}
}

if (userAuth == null) return null;
var isLocked = userAuth.LockedDate != null;
if (isLocked)
{
session.IsAuthenticated = false;
authService.SaveSession(session, SessionExpiry);
return authService.Redirect(session.ReferrerUrl.AddHashParam("f", "AccountLocked"));
}

Expand All @@ -312,6 +351,16 @@ protected virtual string GetReferrerUrl(IServiceBase authService, IAuthSession s
}
}

public class AuthContext
{
public IServiceBase Service { get; set; }
public AuthProvider AuthProvider { get; set; }
public IAuthSession Session { get; set; }
public IAuthTokens AuthTokens { get; set; }
public Dictionary<string, string> AuthInfo { get; set; }
public IAuthRepository AuthRepository { get; set; }
}

public static class AuthExtensions
{
public static bool IsAuthorizedSafe(this IAuthProvider authProvider, IAuthSession session, IAuthTokens tokens)
Expand Down

0 comments on commit 095da0c

Please sign in to comment.