Skip to content

Commit

Permalink
Merge pull request #133 from isaacbrodsky/master
Browse files Browse the repository at this point in the history
Fix calling methods with nullable parameters (null could not be passed)
  • Loading branch information
viniciusjarina committed Dec 30, 2014
2 parents fb177cd + c162d48 commit 84a593f
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 20 deletions.
42 changes: 28 additions & 14 deletions Core/NLua/CheckType.cs
Expand Up @@ -102,11 +102,34 @@ internal ExtractValue CheckLuaType (LuaState luaState, int stackPos, Type paramT
paramType = paramType.GetElementType ();

var underlyingType = Nullable.GetUnderlyingType (paramType);

if (underlyingType != null)
if (underlyingType != null) {
paramType = underlyingType; // Silently convert nullable types to their non null requics
}

var extractKey = GetExtractDictionaryKey (paramType);

bool netParamIsNumeric = paramType == typeof (int) ||
paramType == typeof (uint) ||
paramType == typeof (long) ||
paramType == typeof (ulong) ||
paramType == typeof (short) ||
paramType == typeof (ushort) ||
paramType == typeof (float) ||
paramType == typeof (double) ||
paramType == typeof (decimal) ||
paramType == typeof (byte);

// If it is a nullable
if (underlyingType != null) {
// null can always be assigned to nullable
if (luatype == LuaTypes.Nil) {
// Return the correct extractor anyways
if (netParamIsNumeric || paramType == typeof (bool))
return extractValues [extractKey];
return extractNetObject;
}
}

if (paramType.Equals (typeof(object)))
return extractValues [extractKey];
Expand All @@ -127,15 +150,6 @@ internal ExtractValue CheckLuaType (LuaState luaState, int stackPos, Type paramT
return extractValues [GetExtractDictionaryKey (typeof(double))];
}
bool netParamIsString = paramType == typeof (string) || paramType == typeof (char []);
bool netParamIsNumeric = paramType == typeof (int) ||
paramType == typeof (uint) ||
paramType == typeof (long) ||
paramType == typeof (ulong) ||
paramType == typeof (short) ||
paramType == typeof (float) ||
paramType == typeof (double) ||
paramType == typeof (decimal) ||
paramType == typeof (byte);

if (netParamIsNumeric) {
if (LuaLib.LuaIsNumber (luaState, stackPos) && !netParamIsString)
Expand All @@ -149,13 +163,13 @@ internal ExtractValue CheckLuaType (LuaState luaState, int stackPos, Type paramT
else if (luatype == LuaTypes.Nil)
return extractNetObject; // kevinh - silently convert nil to a null string pointer
} else if (paramType == typeof(LuaTable)) {
if (luatype == LuaTypes.Table)
if (luatype == LuaTypes.Table || luatype == LuaTypes.Nil)
return extractValues [extractKey];
} else if (paramType == typeof(LuaUserData)) {
if (luatype == LuaTypes.UserData)
if (luatype == LuaTypes.UserData || luatype == LuaTypes.Nil)
return extractValues [extractKey];
} else if (paramType == typeof(LuaFunction)) {
if (luatype == LuaTypes.Function)
if (luatype == LuaTypes.Function || luatype == LuaTypes.Nil)
return extractValues [extractKey];
} else if (typeof(Delegate).IsAssignableFrom (paramType) && luatype == LuaTypes.Function)
return new ExtractValue (new DelegateGenerator (translator, paramType).ExtractGenerated);
Expand Down
21 changes: 15 additions & 6 deletions Core/NLua/ObjectTranslator.cs
Expand Up @@ -829,26 +829,35 @@ internal object GetObject (LuaState luaState, int index)
*/
internal LuaTable GetTable (LuaState luaState, int index)
{
LuaLib.LuaPushValue (luaState, index);
return new LuaTable (LuaLib.LuaRef (luaState, 1), interpreter);
LuaLib.LuaPushValue (luaState, index);
int reference = LuaLib.LuaRef (luaState, 1);
if (reference == -1)
return null;
return new LuaTable (reference, interpreter);
}

/*
* Gets the userdata in the index positon of the Lua stack.
*/
internal LuaUserData GetUserData (LuaState luaState, int index)
{
LuaLib.LuaPushValue (luaState, index);
return new LuaUserData (LuaLib.LuaRef (luaState, 1), interpreter);
LuaLib.LuaPushValue (luaState, index);
int reference = LuaLib.LuaRef (luaState, 1);
if (reference == -1)
return null;
return new LuaUserData(reference, interpreter);
}

/*
* Gets the function in the index positon of the Lua stack.
*/
internal LuaFunction GetFunction (LuaState luaState, int index)
{
LuaLib.LuaPushValue (luaState, index);
return new LuaFunction (LuaLib.LuaRef (luaState, 1), interpreter);
LuaLib.LuaPushValue (luaState, index);
int reference = LuaLib.LuaRef (luaState, 1);
if (reference == -1)
return null;
return new LuaFunction (reference, interpreter);
}

/*
Expand Down
43 changes: 43 additions & 0 deletions tests/LuaTests.cs
Expand Up @@ -153,6 +153,26 @@ public void ThrowException ()
}
}

/*
* Tests passing a LuaFunction
*/
[Test]
public void CallLuaFunction()
{
using (Lua lua = new Lua ()) {
lua.DoString ("function someFunc(v1,v2) return v1 + v2 end");
lua ["funcObject"] = lua.GetFunction ("someFunc");

lua.DoString ("luanet.load_assembly('mscorlib')");
lua.DoString ("luanet.load_assembly('NLuaTest')");
lua.DoString ("TestClass=luanet.import_type('NLuaTest.Mock.TestClass')");
lua.DoString ("b = TestClass():TestLuaFunction(funcObject)[0]");
Assert.AreEqual (3, lua ["b"]);
lua.DoString ("a = TestClass():TestLuaFunction(nil)");
Assert.AreEqual (null, lua ["a"]);
}
}

/*
* Tests capturing an exception
*/
Expand Down Expand Up @@ -620,6 +640,29 @@ public void TestRegisterFunction ()
Assert.AreEqual (5.0f, Convert.ToSingle (vals1 [0]));
}
}

/*
* Tests passing a null object as a parameter to a
* method that accepts a nullable.
*/
[Test]
public void TestNullableParameter ()
{
using (Lua lua = new Lua ()) {
lua.DoString ("luanet.load_assembly('NLuaTest')");
lua.DoString ("TestClass=luanet.import_type('NLuaTest.Mock.TestClass')");
lua.DoString ("test=TestClass()");
lua.DoString ("a = test:NullableMethod(nil)");
Assert.AreEqual (null, lua ["a"]);
lua ["timeVal"] = TimeSpan.FromSeconds (5);
lua.DoString ("b = test:NullableMethod(timeVal)");
Assert.AreEqual (TimeSpan.FromSeconds (5), lua ["b"]);
lua.DoString ("d = test:NullableMethod2(2)");
Assert.AreEqual (2, lua ["d"]);
lua.DoString ("c = test:NullableMethod2(nil)");
Assert.AreEqual (null, lua ["c"]);
}
}

/*
* Tests if DoString is correctly returning values
Expand Down
18 changes: 18 additions & 0 deletions tests/TestLua.cs
Expand Up @@ -374,6 +374,24 @@ public static TestClass makeFromString (String str)
set { }
}

public TimeSpan? NullableMethod (TimeSpan? input)
{
return input;
}

public int? NullableMethod2 (int? input)
{
return input;
}

public object[] TestLuaFunction (LuaFunction func)
{
if (func != null) {
return func.Call (1, 2);
}
return null;
}

public int sum (int x, int y)
{
return x + y;
Expand Down

0 comments on commit 84a593f

Please sign in to comment.