@@ -85,9 +85,9 @@ std::string Arg::type_name() const {
8585 default :
8686 return " None" ;
8787 }
88-
88+
8989}
90-
90+
9191const torch::jit::IValue* Arg::IValue () const {
9292 if (type_ == Type::kIValue ) {
9393 return ptr_.ivalue ;
@@ -150,7 +150,7 @@ double Arg::unwrapToDouble(double default_val) {
150150
151151double Arg::unwrapToDouble () {
152152 return this ->unwrapTo <double >();
153- }
153+ }
154154
155155bool Arg::unwrapToBool (bool default_val) {
156156 return this ->unwrapTo <bool >(default_val);
@@ -194,26 +194,41 @@ c10::List<bool> Arg::unwrapToBoolList() {
194194
195195template <typename T>
196196T Arg::unwrapTo (T default_val) {
197- if (isIValue ()) {
198- // TODO: implement Tag Checking
199- return ptr_.ivalue ->to <T>();
197+ try {
198+ return this ->unwrapTo <T>();
199+ } catch (trtorch::Error& e) {
200+ LOG_DEBUG (" In arg unwrapping, returning default value provided (" << e.what () << " )" );
201+ return default_val;
200202 }
201- LOG_DEBUG (" In arg unwrapping, returning default value provided" );
202- return default_val;
203203}
204204
205-
206205template <typename T>
207206T Arg::unwrapTo () {
208- if (isIValue ()) {
209- // TODO: Implement Tag checking
210- return ptr_.ivalue ->to <T>();
211- // TODO: Exception
212- // LOG_INTERNAL_ERROR("Requested unwrapping of arg IValue assuming it was " << typeid(T).name() << " however type is " << ptr_.ivalue->type());
213-
207+ TRTORCH_CHECK (isIValue (), " Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
208+ auto ivalue = ptr_.ivalue ;
209+ bool correct_type = false ;
210+ if (typeid (T) == typeid (double )) {
211+ correct_type = ivalue->isDouble ();
212+ } else if (typeid (T) == typeid (bool )) {
213+ correct_type = ivalue->isBool ();
214+ } else if (typeid (T) == typeid (int64_t )) {
215+ correct_type = ivalue->isInt ();
216+ } else if (typeid (T) == typeid (at::Tensor)) {
217+ correct_type = ivalue->isTensor ();
218+ } else if (typeid (T) == typeid (c10::Scalar)) {
219+ correct_type = ivalue->isScalar ();
220+ } else if (typeid (T) == typeid (c10::List<int64_t >)) {
221+ correct_type = ivalue->isIntList ();
222+ } else if (typeid (T) == typeid (c10::List<double >)) {
223+ correct_type = ivalue->isDoubleList ();
224+ } else if (typeid (T) == typeid (c10::List<bool >)) {
225+ correct_type = ivalue->isBoolList ();
226+ } else {
227+ TRTORCH_THROW_ERROR (" Requested unwrapping of arg to an unsupported type: " << typeid (T).name ());
214228 }
215- TRTORCH_THROW_ERROR (" Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
216- return T ();
229+
230+ TRTORCH_CHECK (correct_type, " Requested unwrapping of arg IValue assuming it was " << typeid (T).name () << " however type is " << *(ptr_.ivalue ->type ()));
231+ return ptr_.ivalue ->to <T>();
217232}
218233
219234
0 commit comments