Skip to content

Commit

Permalink
Exploration into better support for custom type handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
mgravell committed Jun 25, 2014
1 parent 8876605 commit e26ee0a
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 14 deletions.
176 changes: 162 additions & 14 deletions Dapper NET40/SqlMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,64 @@ public interface ICustomQueryParameter
void AddParameter(IDbCommand command, string name);
}

/// <summary>
/// Implement this interface to perform custom type-based parameter handling and value parsing
/// </summary>
public interface ITypeHandler
{
/// <summary>
/// Assign the value of a parameter before a command executes
/// </summary>
/// <param name="parameter">The parameter to configure</param>
/// <param name="value">Parameter value</param>
void SetValue(IDbDataParameter parameter, object value);

/// <summary>
/// Parse a database value back to a typed value
/// </summary>
/// <param name="value">The value from the database</param>
/// <param name="destinationType">The type to parse to</param>
/// <returns>The typed value</returns>
object Parse(Type destinationType, object value);
}

/// <summary>
/// Base-class for simple type-handlers
/// </summary>
public abstract class TypeHandler<T> : ITypeHandler
{
/// <summary>
/// Assign the value of a parameter before a command executes
/// </summary>
/// <param name="parameter">The parameter to configure</param>
/// <param name="value">Parameter value</param>
public abstract void SetValue(IDbDataParameter parameter, T value);

/// <summary>
/// Parse a database value back to a typed value
/// </summary>
/// <param name="value">The value from the database</param>
/// <returns>The typed value</returns>
public abstract T Parse(object value);

void ITypeHandler.SetValue(IDbDataParameter parameter, object value)
{
if (value is DBNull)
{
parameter.Value = value;
}
else
{
SetValue(parameter, (T)value);
}
}

object ITypeHandler.Parse(Type destinationType, object value)
{
return Parse(value);
}
}

/// <summary>
/// Implement this interface to change default mapping of reader columns to type memebers
/// </summary>
Expand Down Expand Up @@ -568,10 +626,63 @@ public static void AddTypeMap(Type type, DbType dbType)
typeMap[type] = dbType;
}

/// <summary>
/// Configire the specified type to be processed by a custom handler
/// </summary>
public static void AddTypeHandler(Type type, ITypeHandler handler)
{
if (type == null) throw new ArgumentNullException("type");
#pragma warning disable 618
typeof(TypeHandlerCache<>).MakeGenericType(type).GetMethod("SetHandler", BindingFlags.Static | BindingFlags.NonPublic).Invoke(null, new object[] { handler });
#pragma warning restore 618
if (handler == null) typeHandlers.Remove(type);
else typeHandlers[type] = handler;

}

/// <summary>
/// Not intended for direct usage
/// </summary>
[Obsolete("Not intended for direct usage", false)]
[Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
public static class TypeHandlerCache<T>
{
/// <summary>
/// Not intended for direct usage
/// </summary>
[Obsolete("Not intended for direct usage", true)]
public static T Parse(object value)
{
return (T)handler.Parse(typeof(T), value);

}

/// <summary>
/// Not intended for direct usage
/// </summary>
[Obsolete("Not intended for direct usage", true)]
public static void SetValue(IDbDataParameter parameter, object value)
{
handler.SetValue(parameter, value);
}

internal static void SetHandler(ITypeHandler handler)
{
#pragma warning disable 618
TypeHandlerCache<T>.handler = handler;
#pragma warning restore 618
}

private static ITypeHandler handler;
}

private static readonly Dictionary<Type, ITypeHandler> typeHandlers = new Dictionary<Type, ITypeHandler>();

internal const string LinqBinary = "System.Data.Linq.Binary";
internal static DbType LookupDbType(Type type, string name)
internal static DbType LookupDbType(Type type, string name, out ITypeHandler handler)
{
DbType dbType;
handler = null;
var nullUnderlyingType = Nullable.GetUnderlyingType(type);
if (nullUnderlyingType != null) type = nullUnderlyingType;
if (type.IsEnum && !typeMap.ContainsKey(type))
Expand All @@ -591,8 +702,11 @@ internal static DbType LookupDbType(Type type, string name)
return DynamicParameters.EnumerableMultiParameter;
}


throw new NotSupportedException(string.Format("The member {0} of type {1} cannot be used as a parameter value", name, type));
if (typeHandlers.TryGetValue(type, out handler))
{
return DbType.Object;
}
throw new NotSupportedException(string.Format("The member {0} of type {1} cannot be used as a parameter value", name, type));
}


Expand Down Expand Up @@ -2412,13 +2526,14 @@ internal static IList<LiteralToken> GetLiteralTokens(string sql)
if (typeof(ICustomQueryParameter).IsAssignableFrom(prop.PropertyType))
{
il.Emit(OpCodes.Ldloc_0); // stack is now [parameters] [typed-param]
il.Emit(OpCodes.Callvirt, prop.GetGetMethod()); // stack is [parameters] [dbstring]
il.Emit(OpCodes.Ldarg_0); // stack is now [parameters] [dbstring] [command]
il.Emit(OpCodes.Ldstr, prop.Name); // stack is now [parameters] [dbstring] [command] [name]
il.Emit(OpCodes.Callvirt, prop.GetGetMethod()); // stack is [parameters] [custom]
il.Emit(OpCodes.Ldarg_0); // stack is now [parameters] [custom] [command]
il.Emit(OpCodes.Ldstr, prop.Name); // stack is now [parameters] [custom] [command] [name]
il.EmitCall(OpCodes.Callvirt, prop.PropertyType.GetMethod("AddParameter"), null); // stack is now [parameters]
continue;
}
DbType dbType = LookupDbType(prop.PropertyType, prop.Name);
ITypeHandler handler;
DbType dbType = LookupDbType(prop.PropertyType, prop.Name, out handler);
if (dbType == DynamicParameters.EnumerableMultiParameter)
{
// this actually represents special handling for list types;
Expand Down Expand Up @@ -2452,7 +2567,7 @@ internal static IList<LiteralToken> GetLiteralTokens(string sql)
il.Emit(OpCodes.Ldstr, prop.Name); // stack is now [parameters] [parameters] [parameter] [parameter] [name]
il.EmitCall(OpCodes.Callvirt, typeof(IDataParameter).GetProperty("ParameterName").GetSetMethod(), null);// stack is now [parameters] [parameters] [parameter]
}
if (dbType != DbType.Time) // https://connect.microsoft.com/VisualStudio/feedback/details/381934/sqlparameter-dbtype-dbtype-time-sets-the-parameter-to-sqldbtype-datetime-instead-of-sqldbtype-time
if (dbType != DbType.Time && handler == null) // https://connect.microsoft.com/VisualStudio/feedback/details/381934/sqlparameter-dbtype-dbtype-time-sets-the-parameter-to-sqldbtype-datetime-instead-of-sqldbtype-time
{
il.Emit(OpCodes.Dup);// stack is now [parameters] [[parameters]] [parameter] [parameter]
EmitInt32(il, (int)dbType);// stack is now [parameters] [[parameters]] [parameter] [parameter] [db-type]
Expand Down Expand Up @@ -2520,7 +2635,17 @@ internal static IList<LiteralToken> GetLiteralTokens(string sql)
if (allDone != null) il.MarkLabel(allDone.Value);
// relative stack [boxed value or DBNull]
}
il.EmitCall(OpCodes.Callvirt, typeof(IDataParameter).GetProperty("Value").GetSetMethod(), null);// stack is now [parameters] [[parameters]] [parameter]

if (handler != null)
{
#pragma warning disable 618
il.Emit(OpCodes.Call, typeof(TypeHandlerCache<>).MakeGenericType(prop.PropertyType).GetMethod("SetValue")); // stack is now [parameters] [[parameters]] [parameter]
#pragma warning restore 618
}
else
{
il.EmitCall(OpCodes.Callvirt, typeof(IDataParameter).GetProperty("Value").GetSetMethod(), null);// stack is now [parameters] [[parameters]] [parameter]
}

if (prop.PropertyType == typeof(string))
{
Expand Down Expand Up @@ -2739,6 +2864,15 @@ private static IDataReader ExecuteReaderImpl(IDbConnection cnn, ref CommandDefin
return val is DBNull ? null : Enum.ToObject(effectiveType, val);
};
}
ITypeHandler handler;
if(typeHandlers.TryGetValue(type, out handler))
{
return r =>
{
var val = r.GetValue(index);
return val is DBNull ? null : handler.Parse(type, val);
};
}
return r =>
{
var val = r.GetValue(index);
Expand Down Expand Up @@ -2980,7 +3114,16 @@ public static void SetTypeMap(Type type, ITypeMap map)
TypeCode dataTypeCode = Type.GetTypeCode(dataType), unboxTypeCode = Type.GetTypeCode(unboxType);
if (dataType == unboxType || dataTypeCode == unboxTypeCode || dataTypeCode == Type.GetTypeCode(nullUnderlyingType))
{
il.Emit(OpCodes.Unbox_Any, unboxType); // stack is now [target][target][typed-value]
if (typeHandlers.ContainsKey(unboxType))
{
#pragma warning disable 618
il.EmitCall(OpCodes.Call, typeof(TypeHandlerCache<>).MakeGenericType(unboxType).GetMethod("Parse"), null); // stack is now [target][target][typed-value]
#pragma warning restore 618
}
else
{
il.Emit(OpCodes.Unbox_Any, unboxType); // stack is now [target][target][typed-value]
}
}
else
{
Expand Down Expand Up @@ -3745,8 +3888,8 @@ protected void AddParameters(IDbCommand command, SqlMapper.Identity identity)
string name = Clean(param.Name);
var isCustomQueryParameter = val is SqlMapper.ICustomQueryParameter;

if (dbType == null && val != null && !isCustomQueryParameter) dbType = SqlMapper.LookupDbType(val.GetType(), name);

SqlMapper.ITypeHandler handler = null;
if (dbType == null && val != null && !isCustomQueryParameter) dbType = SqlMapper.LookupDbType(val.GetType(), name, out handler);
if (dbType == DynamicParameters.EnumerableMultiParameter)
{
#pragma warning disable 612, 618
Expand All @@ -3771,8 +3914,13 @@ protected void AddParameters(IDbCommand command, SqlMapper.Identity identity)
{
p = (IDbDataParameter)command.Parameters[name];
}

p.Value = val ?? DBNull.Value;
if (handler == null)
{
p.Value = val ?? DBNull.Value;
} else
{
handler.SetValue(p, val ?? DBNull.Value);
}
p.Direction = param.ParameterDirection;
var s = val as string;
if (s != null)
Expand Down
4 changes: 4 additions & 0 deletions Tests/DapperTests NET40.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
<Reference Include="LinFu.DynamicProxy">
<HintPath>NHibernate\LinFu.DynamicProxy.dll</HintPath>
</Reference>
<Reference Include="Microsoft.SqlServer.Types, Version=11.0.0.0, Culture=neutral, PublicKeyToken=89845dcd8080cc91, processorArchitecture=MSIL">
<SpecificVersion>False</SpecificVersion>
<HintPath>C:\Program Files (x86)\Microsoft SQL Server\100\SDK\Assemblies\Microsoft.SqlServer.Types.dll</HintPath>
</Reference>
<Reference Include="Mono.Security">
<HintPath>..\packages\Npgsql.2.0.11\lib\Net40\Mono.Security.dll</HintPath>
</Reference>
Expand Down
64 changes: 64 additions & 0 deletions Tests/Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
using System.Data.Common;
using System.Globalization;
using System.Threading;
using System.Data.Entity.Spatial;
using Microsoft.SqlServer.Types;
#if POSTGRESQL
using Npgsql;
#endif
Expand Down Expand Up @@ -2885,6 +2887,68 @@ public void SupportInit()
obj.Flags.Equals(31);
}

public void GuidIn_SO_24177902()
{
// invent and populate
Guid a = Guid.NewGuid(), b = Guid.NewGuid(), c = Guid.NewGuid(), d = Guid.NewGuid();
connection.Execute("create table #foo (i int, g uniqueidentifier)");
connection.Execute("insert #foo(i,g) values(@i,@g)",
new[] { new { i = 1, g = a }, new { i = 2, g = b },
new { i = 3, g = c },new { i = 4, g = d }});

// check that rows 2&3 yield guids b&c
var guids = connection.Query<Guid>("select g from #foo where i in (2,3)").ToArray();
guids.Length.Equals(2);
guids.Contains(a).Equals(false);
guids.Contains(b).Equals(true);
guids.Contains(c).Equals(true);
guids.Contains(d).Equals(false);

// in query on the guids
var rows = connection.Query("select * from #foo where g in @guids order by i", new { guids })
.Select(row => new { i = (int)row.i, g = (Guid)row.g }).ToArray();
rows.Length.Equals(2);
rows[0].i.Equals(2);
rows[0].g.Equals(b);
rows[1].i.Equals(3);
rows[1].g.Equals(c);
}

class HazGeo
{
public int Id { get;set; }
public DbGeography Geo { get; set; }
}
public void DBGeography_SO24405645_SO24402424()
{
global::Dapper.SqlMapper.AddTypeHandler(typeof(DbGeography), new GeographyMapper());
connection.Execute("create table #Geo (id int, geo geography)");

var obj = new HazGeo
{
Id = 1,
Geo = DbGeography.LineFromText("LINESTRING(-122.360 47.656, -122.343 47.656 )", 4326)
};
connection.Execute("insert #Geo(id, geo) values (@Id, @Geo)", obj);
var row = connection.Query<HazGeo>("select * from #Geo where id=1").SingleOrDefault();
row.IsNotNull();
row.Id.IsEqualTo(1);
row.Geo.IsNotNull();
}

class GeographyMapper : Dapper.SqlMapper.TypeHandler<DbGeography>
{
public override void SetValue(IDbDataParameter parameter, DbGeography value)
{
parameter.Value = value == null ? (object)DBNull.Value : (object)SqlGeography.Parse(value.AsText());
((SqlParameter)parameter).UdtTypeName = "GEOGRAPHY";
}
public override DbGeography Parse(object value)
{
return (value == null || value is DBNull) ? null : DbGeography.FromText(value.ToString());
}
}

class WithInit : ISupportInitialize
{
public string Value { get; set; }
Expand Down

0 comments on commit e26ee0a

Please sign in to comment.