Skip to content

Commit

Permalink
Updated fix for NH-847 moving responsibility to the driver
Browse files Browse the repository at this point in the history
SVN: branches/2.1.x@4805
  • Loading branch information
FlukeFan committed Oct 29, 2009
1 parent 578fa55 commit d7c398a
Show file tree
Hide file tree
Showing 19 changed files with 158 additions and 137 deletions.
43 changes: 34 additions & 9 deletions src/NHibernate.Test/EngineTest/CallableParserFixture.cs
Expand Up @@ -9,32 +9,57 @@ namespace NHibernate.Test.EngineTest
public class CallableParserFixture
{
[Test]
public void CanFindCallableFunctionName()
public void CanDetermineIsCallable()
{
string query = @"{ call myFunction(:name) }";

SqlString sqlFunction = CallableParser.Parse(query);
Assert.That(sqlFunction.ToString(), Is.EqualTo("myFunction"));
CallableParser.Detail detail = CallableParser.Parse(query);
Assert.That(detail.IsCallable, Is.True);
}

[Test]
public void CanDetermineIsNotCallable()
{
string query = @"SELECT id FROM mytable";

Assert.Throws<ParserException>(() =>
{
SqlString sqlFunction = CallableParser.Parse(query);
});
CallableParser.Detail detail = CallableParser.Parse(query);
Assert.That(detail.IsCallable, Is.False);
}

[Test]
public void CanFindCallableFunctionName()
{
string query = @"{ call myFunction(:name) }";

CallableParser.Detail detail = CallableParser.Parse(query);
Assert.That(detail.FunctionName, Is.EqualTo("myFunction"));
}

[Test]
public void CanFindCallableFunctionNameWithoutParameters()
{
string query = @"{ call myFunction }";

SqlString sqlFunction = CallableParser.Parse(query);
Assert.That(sqlFunction.ToString(), Is.EqualTo("myFunction"));
CallableParser.Detail detail = CallableParser.Parse(query);
Assert.That(detail.FunctionName, Is.EqualTo("myFunction"));
}

[Test]
public void CanDetermineHasReturn()
{
string query = @"{ ? = call myFunction(:name) }";

CallableParser.Detail detail = CallableParser.Parse(query);
Assert.That(detail.HasReturn, Is.True);
}

[Test]
public void CanDetermineHasNoReturn()
{
string query = @"{ call myFunction(:name) }";

CallableParser.Detail detail = CallableParser.Parse(query);
Assert.That(detail.HasReturn, Is.False);
}
}
}
17 changes: 0 additions & 17 deletions src/NHibernate.Test/EngineTest/ParameterParserFixture.cs
Expand Up @@ -59,22 +59,5 @@ FROM tablea
Assert.DoesNotThrow(() => p = recognizer.NamedParameterDescriptionMap["pizza"]);
}

[Test]
public void CanRecogniseNoReturnValueParameter()
{
string query = "{ call myFunction(?) }";
ParamLocationRecognizer recognizer = new ParamLocationRecognizer();
ParameterParser.Parse(query, recognizer);
Assert.That(recognizer.HasReturnValue, Is.False);
}

[Test]
public void CanRecogniseReturnValueParameter()
{
string query = "{ ? = call myFunction(?) }";
ParamLocationRecognizer recognizer = new ParamLocationRecognizer();
ParameterParser.Parse(query, recognizer);
Assert.That(recognizer.HasReturnValue, Is.True);
}
}
}
4 changes: 2 additions & 2 deletions src/NHibernate.Test/SqlTest/Custom/Oracle/Mappings.hbm.xml
Expand Up @@ -39,7 +39,7 @@
</id>
<property name="name" not-null="true"/>
<loader query-ref="person"/>
<sql-insert callable="true" check="none">{ call createPerson(?, ?) }</sql-insert>
<sql-insert check="none">{ call createPerson(?, ?) }</sql-insert>
<sql-update>UPDATE PERSON SET NAME=UPPER(?) WHERE PERID=?</sql-update>
<sql-delete>DELETE FROM PERSON WHERE PERID=?</sql-delete>
</class>
Expand Down Expand Up @@ -204,7 +204,7 @@

<database-object>
<create>
CREATE OR REPLACE PROCEDURE createPerson(p_name PERSON.NAME%TYPE, p_id PERSON.PERID%TYPE)
CREATE OR REPLACE PROCEDURE createPerson(unused OUT SYS_REFCURSOR, p_name PERSON.NAME%TYPE, p_id PERSON.PERID%TYPE)
AS
rowcount INTEGER;
BEGIN
Expand Down
Expand Up @@ -8,25 +8,25 @@
namespace="NHibernate.Test.SqlTest"
default-access="field">

<sql-query name="simpleScalar" callable="true">
<sql-query name="simpleScalar">
<return-scalar column="name" type="string"/>
<return-scalar column="value" type="long"/>
{ ? = call simpleScalar(:number) }
</sql-query>

<sql-query name="paramhandling" callable="true">
<sql-query name="paramhandling">
<return-scalar column="value" type="long"/>
<return-scalar column="value2" type="long"/>
{ ? = call testParamHandling(?, ?) }
</sql-query>

<sql-query name="paramhandling_mixed" callable="true">
<sql-query name="paramhandling_mixed">
<return-scalar column="value" type="long"/>
<return-scalar column="value2" type="long"/>
{ ? = call testParamHandling(?,:second) }
</sql-query>

<sql-query name="selectAllEmployments" callable="true">
<sql-query name="selectAllEmployments">
<return alias="emp" class="Employment">
<return-property name="employee" column="EMPLOYEE"/>
<return-property name="employer" column="EMPLOYER"/>
Expand All @@ -42,7 +42,7 @@
{ ? = call allEmployments }
</sql-query>

<sql-query name="selectEmploymentsForRegion" callable="true">
<sql-query name="selectEmploymentsForRegion">
<return alias="emp" class="Employment">
<return-property name="employee" column="EMPLOYEE"/>
<return-property name="employer" column="EMPLOYER"/>
Expand Down
16 changes: 11 additions & 5 deletions src/NHibernate/Driver/DriverBase.cs
Expand Up @@ -136,11 +136,6 @@ public virtual IDbCommand GenerateCommand(CommandType type, SqlString sqlString,
return cmd;
}

public virtual int RegisterResultSetOutParameter(IDbCommand command, int position, bool hasReturnValue)
{
throw new NotImplementedException(GetType().Name + " does not support resultsets via stored procedures");
}

private void SetCommandTimeout(IDbCommand cmd, object envTimeout)
{
if (commandTimeout >= 0)
Expand Down Expand Up @@ -215,12 +210,23 @@ public IDbDataParameter GenerateParameter(IDbCommand command, string name, SqlTy

public void PrepareCommand(IDbCommand command)
{
OnBeforePrepare(command);

if (SupportsPreparingCommands && prepareSql)
{
command.Prepare();
}
}

/// <summary>
/// Override to make any adjustments to the IDbCommand object. (e.g., Oracle custom OUT parameter)
/// Parameters have been bound by this point, so their order can be adjusted too.
/// This is analagous to the RegisterResultSetOutParameter() function in Hibernate.
/// </summary>
protected virtual void OnBeforePrepare(IDbCommand command)
{
}

public IDbDataParameter GenerateOutputParameter(IDbCommand command)
{
IDbDataParameter param = GenerateParameter(command, "ReturnValue", SqlTypeFactory.Int32);
Expand Down
11 changes: 0 additions & 11 deletions src/NHibernate/Driver/IDriver.cs
Expand Up @@ -79,17 +79,6 @@ public interface IDriver
/// <returns>An IDbCommand with the CommandText and Parameters fully set.</returns>
IDbCommand GenerateCommand(CommandType type, SqlString sqlString, SqlType[] parameterTypes);

/// <summary>
/// Registers an OUT parameter which will be returing a
/// <see cref="IDataReader"/>. How this is accomplished varies greatly
/// from DB to DB, hence its inclusion here.
/// </summary>
/// <param name="command">The <see cref="IDbCommand"/> with CommandType.StoredProcedure.</param>
/// <param name="position">The bind position at which to register the OUT param.</param>
/// <param name="hasReturnValue">Whether the out parameter is a return value, or an out parameter.</param>
/// <returns>The number of (contiguous) bind positions used.</returns>
int RegisterResultSetOutParameter(IDbCommand command, int position, bool hasReturnValue);

/// <summary>
/// Prepare the <paramref name="command" /> by calling <see cref="IDbCommand.Prepare()" />.
/// May be a no-op if the driver does not support preparing commands, or for any other reason.
Expand Down
12 changes: 10 additions & 2 deletions src/NHibernate/Driver/OracleClientDriver.cs
@@ -1,5 +1,6 @@
using System.Data;
using System.Data.OracleClient;
using NHibernate.Engine.Query;
using NHibernate.SqlTypes;

namespace NHibernate.Driver
Expand Down Expand Up @@ -48,10 +49,17 @@ protected override void InitializeParameter(IDbDataParameter dbParam, string nam
}
}

public override int RegisterResultSetOutParameter(IDbCommand command, int position, bool hasReturnValue)
protected override void OnBeforePrepare(IDbCommand command)
{
base.OnBeforePrepare(command);

CallableParser.Detail detail = CallableParser.Parse(command.CommandText);

if (!detail.IsCallable)
return;

throw new System.NotImplementedException(GetType().Name +
" does not support resultsets via stored procedures." +
" does not support CallableStatement syntax (stored procedures)." +
" Consider using OracleDataClientDriver instead.");
}
}
Expand Down
20 changes: 14 additions & 6 deletions src/NHibernate/Driver/OracleDataClientDriver.cs
@@ -1,6 +1,7 @@
using System.Data;
using System.Reflection;
using NHibernate.AdoNet;
using NHibernate.Engine.Query;
using NHibernate.SqlTypes;
using NHibernate.Util;

Expand Down Expand Up @@ -82,17 +83,24 @@ protected override void InitializeParameter(IDbDataParameter dbParam, string nam
}
}

public override int RegisterResultSetOutParameter(IDbCommand command, int position, bool hasReturnValue)
protected override void OnBeforePrepare(IDbCommand command)
{
base.OnBeforePrepare(command);

CallableParser.Detail detail = CallableParser.Parse(command.CommandText);

if (!detail.IsCallable)
return;

command.CommandType = CommandType.StoredProcedure;
command.CommandText = detail.FunctionName;

IDbDataParameter outCursor = command.CreateParameter();
outCursor.ParameterName = "";
oracleDbType.SetValue(outCursor, oracleDbTypeRefCursor, null);

outCursor.Direction = hasReturnValue ? ParameterDirection.ReturnValue : ParameterDirection.Output;

command.Parameters.Insert(position, outCursor);
outCursor.Direction = detail.HasReturn ? ParameterDirection.ReturnValue : ParameterDirection.Output;

return 1;
command.Parameters.Insert(0, outCursor);
}

#region IEmbeddedBatcherFactoryProvider Members
Expand Down
28 changes: 27 additions & 1 deletion src/NHibernate/Driver/SqlStringFormatter.cs
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using NHibernate.SqlCommand;
using NHibernate.Engine.Query;

namespace NHibernate.Driver
{
Expand All @@ -15,6 +16,9 @@ public class SqlStringFormatter : ISqlStringVisitor
private readonly Dictionary<int, int> queryIndexToNumberOfPreceedingParameters = new Dictionary<int, int>();
private readonly Dictionary<int, int> parameterIndexToQueryIndex = new Dictionary<int, int>();

private bool hasReturnParameter = false;
private bool foundReturnParameter = false;

public SqlStringFormatter(ISqlParameterFormatter formatter, string multipleQueriesSeparator)
{
this.formatter = formatter;
Expand All @@ -24,6 +28,7 @@ public SqlStringFormatter(ISqlParameterFormatter formatter, string multipleQueri
public void Format(SqlString text)
{
DetermineNumberOfPreceedingParametersForEachQuery(text);
foundReturnParameter = false;
text.Visit(this);
}

Expand All @@ -44,6 +49,13 @@ void ISqlStringVisitor.String(SqlString sqlString)

void ISqlStringVisitor.Parameter(Parameter parameter)
{
if (hasReturnParameter && !foundReturnParameter)
{
result.Append(parameter);
foundReturnParameter = true;
return;
}

string name;

if (queryIndexToNumberOfPreceedingParameters.Count == 0)
Expand Down Expand Up @@ -80,6 +92,13 @@ private void DetermineNumberOfPreceedingParametersForEachQuery(SqlString text)
int currentParameterIndex = 0;
int currentQueryParameterCount = 0;
int currentQueryIndex = 0;
hasReturnParameter = false;
foundReturnParameter = false;

CallableParser.Detail callableDetail = CallableParser.Parse(text.ToString());

if (callableDetail.IsCallable && callableDetail.HasReturn)
hasReturnParameter = true;

foreach (object part in text.Parts)
{
Expand All @@ -95,7 +114,14 @@ private void DetermineNumberOfPreceedingParametersForEachQuery(SqlString text)

if (parameter != null)
{
parameterIndexToQueryIndex[currentParameterIndex] = currentQueryIndex;
if (hasReturnParameter && !foundReturnParameter)
{
foundReturnParameter = true;
}
else
{
parameterIndexToQueryIndex[currentParameterIndex] = currentQueryIndex;
}
currentQueryParameterCount++;
currentParameterIndex++;
}
Expand Down
32 changes: 24 additions & 8 deletions src/NHibernate/Engine/Query/CallableParser.cs
Expand Up @@ -8,26 +8,42 @@ namespace NHibernate.Engine.Query
{
public static class CallableParser
{

public class Detail
{
public bool IsCallable;
public bool HasReturn;
public string FunctionName;
}

private static readonly Regex functionNameFinder = new Regex(@"\{[\S\s]*call[\s]+([\w]+)[^\w]");
private static readonly int NewLineLength = Environment.NewLine.Length;

public static SqlString Parse(string sqlString)
public static Detail Parse(string sqlString)
{
bool isCallableSyntax = sqlString.IndexOf("{") == 0 &&
sqlString.IndexOf("}") == (sqlString.Length - 1) &&
sqlString.IndexOf("call") > 0;
Detail callableDetail = new Detail();

if (!isCallableSyntax)
throw new ParserException("Expected callable syntax {? = call procedure_name[(?, ?, ...)]} but got: " + sqlString);
callableDetail.IsCallable = sqlString.IndexOf("{") == 0 &&
sqlString.IndexOf("}") == (sqlString.Length - 1) &&
sqlString.IndexOf("call") > 0;

if (!callableDetail.IsCallable)
return callableDetail;

Match functionMatch = functionNameFinder.Match(sqlString);

if ((!functionMatch.Success) || (functionMatch.Groups.Count < 2))
throw new HibernateException("Could not determine function name for callable SQL: " + sqlString);

string function = functionMatch.Groups[1].Value;
return new SqlString(function);
callableDetail.FunctionName = functionMatch.Groups[1].Value;

callableDetail.HasReturn = sqlString.IndexOf("call") > 0 &&
sqlString.IndexOf("?") > 0 &&
sqlString.IndexOf("=") > 0 &&
sqlString.IndexOf("?") < sqlString.IndexOf("call") &&
sqlString.IndexOf("=") < sqlString.IndexOf("call");

return callableDetail;
}
}
}

0 comments on commit d7c398a

Please sign in to comment.