diff --git a/Core/NLua/CheckType.cs b/Core/NLua/CheckType.cs index 8b1e7b6e..89136468 100755 --- a/Core/NLua/CheckType.cs +++ b/Core/NLua/CheckType.cs @@ -102,11 +102,34 @@ internal ExtractValue CheckLuaType (LuaState luaState, int stackPos, Type paramT paramType = paramType.GetElementType (); var underlyingType = Nullable.GetUnderlyingType (paramType); - - if (underlyingType != null) + + if (underlyingType != null) { paramType = underlyingType; // Silently convert nullable types to their non null requics + } var extractKey = GetExtractDictionaryKey (paramType); + + bool netParamIsNumeric = paramType == typeof (int) || + paramType == typeof (uint) || + paramType == typeof (long) || + paramType == typeof (ulong) || + paramType == typeof (short) || + paramType == typeof (ushort) || + paramType == typeof (float) || + paramType == typeof (double) || + paramType == typeof (decimal) || + paramType == typeof (byte); + + // If it is a nullable + if (underlyingType != null) { + // null can always be assigned to nullable + if (luatype == LuaTypes.Nil) { + // Return the correct extractor anyways + if (netParamIsNumeric || paramType == typeof (bool)) + return extractValues [extractKey]; + return extractNetObject; + } + } if (paramType.Equals (typeof(object))) return extractValues [extractKey]; @@ -127,15 +150,6 @@ internal ExtractValue CheckLuaType (LuaState luaState, int stackPos, Type paramT return extractValues [GetExtractDictionaryKey (typeof(double))]; } bool netParamIsString = paramType == typeof (string) || paramType == typeof (char []); - bool netParamIsNumeric = paramType == typeof (int) || - paramType == typeof (uint) || - paramType == typeof (long) || - paramType == typeof (ulong) || - paramType == typeof (short) || - paramType == typeof (float) || - paramType == typeof (double) || - paramType == typeof (decimal) || - paramType == typeof (byte); if (netParamIsNumeric) { if (LuaLib.LuaIsNumber (luaState, stackPos) && !netParamIsString) @@ -149,13 +163,13 @@ internal ExtractValue CheckLuaType (LuaState luaState, int stackPos, Type paramT else if (luatype == LuaTypes.Nil) return extractNetObject; // kevinh - silently convert nil to a null string pointer } else if (paramType == typeof(LuaTable)) { - if (luatype == LuaTypes.Table) + if (luatype == LuaTypes.Table || luatype == LuaTypes.Nil) return extractValues [extractKey]; } else if (paramType == typeof(LuaUserData)) { - if (luatype == LuaTypes.UserData) + if (luatype == LuaTypes.UserData || luatype == LuaTypes.Nil) return extractValues [extractKey]; } else if (paramType == typeof(LuaFunction)) { - if (luatype == LuaTypes.Function) + if (luatype == LuaTypes.Function || luatype == LuaTypes.Nil) return extractValues [extractKey]; } else if (typeof(Delegate).IsAssignableFrom (paramType) && luatype == LuaTypes.Function) return new ExtractValue (new DelegateGenerator (translator, paramType).ExtractGenerated); diff --git a/Core/NLua/ObjectTranslator.cs b/Core/NLua/ObjectTranslator.cs index c3c2123c..25dae263 100755 --- a/Core/NLua/ObjectTranslator.cs +++ b/Core/NLua/ObjectTranslator.cs @@ -829,8 +829,11 @@ internal object GetObject (LuaState luaState, int index) */ internal LuaTable GetTable (LuaState luaState, int index) { - LuaLib.LuaPushValue (luaState, index); - return new LuaTable (LuaLib.LuaRef (luaState, 1), interpreter); + LuaLib.LuaPushValue (luaState, index); + int reference = LuaLib.LuaRef (luaState, 1); + if (reference == -1) + return null; + return new LuaTable (reference, interpreter); } /* @@ -838,8 +841,11 @@ internal LuaTable GetTable (LuaState luaState, int index) */ internal LuaUserData GetUserData (LuaState luaState, int index) { - LuaLib.LuaPushValue (luaState, index); - return new LuaUserData (LuaLib.LuaRef (luaState, 1), interpreter); + LuaLib.LuaPushValue (luaState, index); + int reference = LuaLib.LuaRef (luaState, 1); + if (reference == -1) + return null; + return new LuaUserData(reference, interpreter); } /* @@ -847,8 +853,11 @@ internal LuaUserData GetUserData (LuaState luaState, int index) */ internal LuaFunction GetFunction (LuaState luaState, int index) { - LuaLib.LuaPushValue (luaState, index); - return new LuaFunction (LuaLib.LuaRef (luaState, 1), interpreter); + LuaLib.LuaPushValue (luaState, index); + int reference = LuaLib.LuaRef (luaState, 1); + if (reference == -1) + return null; + return new LuaFunction (reference, interpreter); } /* diff --git a/tests/LuaTests.cs b/tests/LuaTests.cs index 549cc6e7..58a7ce84 100755 --- a/tests/LuaTests.cs +++ b/tests/LuaTests.cs @@ -153,6 +153,26 @@ public void ThrowException () } } + /* + * Tests passing a LuaFunction + */ + [Test] + public void CallLuaFunction() + { + using (Lua lua = new Lua ()) { + lua.DoString ("function someFunc(v1,v2) return v1 + v2 end"); + lua ["funcObject"] = lua.GetFunction ("someFunc"); + + lua.DoString ("luanet.load_assembly('mscorlib')"); + lua.DoString ("luanet.load_assembly('NLuaTest')"); + lua.DoString ("TestClass=luanet.import_type('NLuaTest.Mock.TestClass')"); + lua.DoString ("b = TestClass():TestLuaFunction(funcObject)[0]"); + Assert.AreEqual (3, lua ["b"]); + lua.DoString ("a = TestClass():TestLuaFunction(nil)"); + Assert.AreEqual (null, lua ["a"]); + } + } + /* * Tests capturing an exception */ @@ -620,6 +640,29 @@ public void TestRegisterFunction () Assert.AreEqual (5.0f, Convert.ToSingle (vals1 [0])); } } + + /* + * Tests passing a null object as a parameter to a + * method that accepts a nullable. + */ + [Test] + public void TestNullableParameter () + { + using (Lua lua = new Lua ()) { + lua.DoString ("luanet.load_assembly('NLuaTest')"); + lua.DoString ("TestClass=luanet.import_type('NLuaTest.Mock.TestClass')"); + lua.DoString ("test=TestClass()"); + lua.DoString ("a = test:NullableMethod(nil)"); + Assert.AreEqual (null, lua ["a"]); + lua ["timeVal"] = TimeSpan.FromSeconds (5); + lua.DoString ("b = test:NullableMethod(timeVal)"); + Assert.AreEqual (TimeSpan.FromSeconds (5), lua ["b"]); + lua.DoString ("d = test:NullableMethod2(2)"); + Assert.AreEqual (2, lua ["d"]); + lua.DoString ("c = test:NullableMethod2(nil)"); + Assert.AreEqual (null, lua ["c"]); + } + } /* * Tests if DoString is correctly returning values diff --git a/tests/TestLua.cs b/tests/TestLua.cs index 2629ded8..13fe4b20 100755 --- a/tests/TestLua.cs +++ b/tests/TestLua.cs @@ -374,6 +374,24 @@ public static TestClass makeFromString (String str) set { } } + public TimeSpan? NullableMethod (TimeSpan? input) + { + return input; + } + + public int? NullableMethod2 (int? input) + { + return input; + } + + public object[] TestLuaFunction (LuaFunction func) + { + if (func != null) { + return func.Call (1, 2); + } + return null; + } + public int sum (int x, int y) { return x + y;