Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 12 additions & 49 deletions src/stream-net/StreamClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,31 @@ public class StreamClient : IStreamClient

readonly RestClient _client;
readonly StreamClientOptions _options;
readonly string _apiSecret;
readonly IStreamClientToken _streamClientToken;
readonly string _apiKey;

public StreamClient(string apiKey, string apiSecret, StreamClientOptions options = null)
public StreamClient(string apiKey, string apiSecretOrToken, StreamClientOptions options = null)
{
if (string.IsNullOrWhiteSpace(apiKey))
throw new ArgumentNullException("apiKey", "Must have an apiKey");
if (string.IsNullOrWhiteSpace(apiSecret))
throw new ArgumentNullException("apiSecret", "Must have an apiSecret");
if (string.IsNullOrWhiteSpace(apiSecretOrToken))
throw new ArgumentNullException("apiSecret", "Must have an apiSecret or user session token");

_apiKey = apiKey;
_apiSecret = apiSecret;
_streamClientToken = StreamClientToken.For(apiSecretOrToken);
_options = options ?? StreamClientOptions.Default;
_client = new RestClient(GetBaseUrl(_options.Location), TimeSpan.FromMilliseconds(_options.Timeout));
}

private StreamClient(string apiKey, string apiSecret, RestClient client, StreamClientOptions options = null)
private StreamClient(string apiKey, IStreamClientToken streamClientToken, RestClient client, StreamClientOptions options = null)
{
if (string.IsNullOrWhiteSpace(apiKey))
throw new ArgumentNullException("apiKey", "Must have an apiKey");
if (string.IsNullOrWhiteSpace(apiSecret))
throw new ArgumentNullException("apiSecret", "Must have an apiSecret");
if (streamClientToken is null)
throw new ArgumentNullException("streamClientToken", "Must have a streamClientToken");

_apiKey = apiKey;
_apiSecret = apiSecret;
_streamClientToken = streamClientToken;
_options = options ?? StreamClientOptions.Default;
_client = client;
}
Expand Down Expand Up @@ -95,15 +95,7 @@ public async Task ActivityPartialUpdate(string id = null, ForeignIDTime foreignI

public string CreateUserSessionToken(string userId, IDictionary<string, object> extraData = null)
{
var payload = new Dictionary<string, object>
{
{"user_id", userId}
};
if (extraData != null)
{
extraData.ForEach(x => payload[x.Key] = x.Value);
}
return this.JWToken(payload);
return _streamClientToken.CreateUserSessionToken(userId, extraData);
}

/// <summary>
Expand Down Expand Up @@ -146,7 +138,7 @@ public Personalization Personalization
get
{
var _personalization = new RestClient(GetBasePersonalizationUrl(_options.PersonalizationLocation), TimeSpan.FromMilliseconds(_options.PersonalizationTimeout));
return new Personalization(new StreamClient(_apiKey, _apiSecret, _personalization, _options));
return new Personalization(new StreamClient(_apiKey, _streamClientToken, _personalization, _options));
}
}

Expand Down Expand Up @@ -217,14 +209,6 @@ internal Task<RestResponse> MakeRequest(RestRequest request)
return _client.Execute(request);
}

private static string Base64UrlEncode(byte[] input)
{
return Convert.ToBase64String(input)
.Replace('+', '-')
.Replace('/', '_')
.Trim('=');
}

internal string JWToken(string feedId, string userID = null)
{
var payload = new Dictionary<string, string>()
Expand All @@ -237,28 +221,7 @@ internal string JWToken(string feedId, string userID = null)
{
payload["user_id"] = userID;
}
return this.JWToken(payload);
}

internal string JWToken(object payload)
{
var segments = new List<string>();

byte[] headerBytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(StreamClient.JWTHeader));
byte[] payloadBytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(payload));

segments.Add(Base64UrlEncode(headerBytes));
segments.Add(Base64UrlEncode(payloadBytes));

var stringToSign = string.Join(".", segments.ToArray());
var bytesToSign = Encoding.UTF8.GetBytes(stringToSign);

using (var sha = new HMACSHA256(Encoding.UTF8.GetBytes(_apiSecret)))
{
byte[] signature = sha.ComputeHash(bytesToSign);
segments.Add(Base64UrlEncode(signature));
}
return string.Join(".", segments.ToArray());
return _streamClientToken.For(payload);
}
}
}
97 changes: 97 additions & 0 deletions src/stream-net/StreamClientToken.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Text;

namespace Stream
{
public interface IStreamClientToken
{
string CreateUserSessionToken(string userId, IDictionary<string, object> extraData = null);

string For(object payload);
}

public static class StreamClientToken
{
public static IStreamClientToken For(string apiSecretOrToken)
{
return apiSecretOrToken.Contains(".")
? (IStreamClientToken) new StreamApiSessionToken(apiSecretOrToken)
: (IStreamClientToken) new StreamApiSecret(apiSecretOrToken);
}
}

public class StreamApiSessionToken : IStreamClientToken
{
private readonly string _sessionToken;

public StreamApiSessionToken(string sessionToken)
{
_sessionToken = sessionToken;
}

public string CreateUserSessionToken(string userId, IDictionary<string, object> extraData = null)
{
throw new InvalidOperationException("Clients connecting using a user session token cannot create additional user session tokens");
}

public string For(object payload)
{
return _sessionToken;
}
}

public class StreamApiSecret : IStreamClientToken
{
private readonly string _apiSecret;

public StreamApiSecret(string apiSecret)
{
_apiSecret = apiSecret;
}

private static string Base64UrlEncode(byte[] input)
{
return Convert.ToBase64String(input)
.Replace('+', '-')
.Replace('/', '_')
.Trim('=');
}

public string CreateUserSessionToken(string userId, IDictionary<string, object> extraData = null)
{
var payload = new Dictionary<string, object>
{
{"user_id", userId}
};
if (extraData != null)
{
extraData.ForEach(x => payload[x.Key] = x.Value);
}
return For(payload);
}

public string For(object payload)
{
var segments = new List<string>();

byte[] headerBytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(StreamClient.JWTHeader));
byte[] payloadBytes = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(payload));

segments.Add(Base64UrlEncode(headerBytes));
segments.Add(Base64UrlEncode(payloadBytes));

var stringToSign = string.Join(".", segments.ToArray());
var bytesToSign = Encoding.UTF8.GetBytes(stringToSign);

using (var sha = new HMACSHA256(Encoding.UTF8.GetBytes(_apiSecret)))
{
byte[] signature = sha.ComputeHash(bytesToSign);
segments.Add(Base64UrlEncode(signature));
}
return string.Join(".", segments.ToArray());
}
}
}