In [2]:
import os

import numpy as np
import torch

from NeuralNet import ResidualNetwork as Resnet
from NeuralNet import PolicyNetwork as PolHead
from NeuralNet import ValueNetwork as ValHead

In [9]:
# Architecture settings handed to the network constructors further down.
Filters, Layers, HistoryDepth = 128, 13, 8

# Device the models and the example inputs are moved to before tracing.
TargetDevice = "mps"

# Checkpoint name, plus the directory the weights are read from and the
# directory the traced modules are written to.
model_name = "test1"
general_path = '../../Models/Human/'
target_path = '../../Models/scripted/'

# One state-dict file per network component under general_path.
# general_path already ends with '/', so each path contains a doubled separator.
resnetPath, policyPath, valuePath = (
    f'{general_path}/{component}/{model_name}.pt'
    for component in ('ResNet', 'PolHead', 'ValHead')
)

def _load_for_export(model, checkpoint_path, device):
    """Load saved weights into ``model`` and prepare it for tracing.

    Args:
        model: Newly constructed network exposing ``load_state_dict``, ``eval``
            and ``to`` (presumably a ``torch.nn.Module`` - confirm in NeuralNet).
        checkpoint_path: Path to the file holding the network's state dict.
        device: Device the network is moved to once the weights are loaded.

    Returns:
        The value returned by ``model.to(device)``, with weights loaded and
        eval mode set.
    """
    # Deserialize onto CPU first so loading does not depend on the device the
    # checkpoint was saved from. weights_only=True restricts unpickling to
    # tensors and plain containers, which is all a state dict needs, instead
    # of executing arbitrary pickled objects from the file.
    state_dict = torch.load(
        checkpoint_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    # Switch to eval mode before tracing.
    model.eval()
    return model.to(device)


# The trunk is built with HistoryDepth + 1 input channels, the same channel
# count used for its example input below; both heads are built from Filters.
resnetModel = _load_for_export(Resnet(Filters, Layers, HistoryDepth + 1), resnetPath, TargetDevice)
policyModel = _load_for_export(PolHead(Filters), policyPath, TargetDevice)
valueModel = _load_for_export(ValHead(Filters), valuePath, TargetDevice)

# All-zero float32 example inputs, created directly on TargetDevice; they are
# only used as the sample arguments for torch.jit.trace below.
# NOTE(review): 15x15 is presumably the spatial size the networks expect - confirm.
resnetExample = torch.zeros((1, HistoryDepth + 1, 15, 15), dtype=torch.float32, device=TargetDevice)

# Both heads are traced with the same Filters-channel example tensor.
headsExample = torch.zeros((1, Filters, 15, 15), dtype=torch.float32, device=TargetDevice)

def _save_traced(traced_module, component):
    """Write a traced module to ``{target_path}/{component}/{model_name}.pt``.

    Args:
        traced_module: Object returned by ``torch.jit.trace``.
        component: Sub-directory name under ``target_path`` ('ResNet',
            'PolHead' or 'ValHead').
    """
    out_path = f'{target_path}/{component}/{model_name}.pt'
    # save() does not create missing parent directories, so make sure the
    # destination exists instead of failing after the tracing work is done.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    traced_module.save(out_path)


# Trace each network with its example input.
traced_resnet = torch.jit.trace(resnetModel, resnetExample)
traced_polhead = torch.jit.trace(policyModel, headsExample)
traced_valhead = torch.jit.trace(valueModel, headsExample)

# Persist the traced modules side by side under target_path.
_save_traced(traced_resnet, 'ResNet')
_save_traced(traced_polhead, 'PolHead')
_save_traced(traced_valhead, 'ValHead')