In [1]:
import os
import torch
import numpy as np
from src.tools.easyparser import parser
from src.operators.transformer import make_model, subsequent_mask

settings = parser()
# print(settings.initial_args)
settings.initial_args.gpu = 0
settings.initial_args.I_size = 150
settings.initial_args.F_size = 150
settings.initial_args.batch_size = 3
settings.initial_args.max_epochs = 100
args = settings.get_args()
args.weight_decay = 0.2
torch.cuda.set_device(args.gpu)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

In [2]:
param_dict = torch.load(os.path.join(args.params_folder, 'params_100.pt'))['model_state_dict']
model = make_model(H=args.I_size, W=args.I_size, input_channel=1, d_channel=5, d_channel_ff=10, load_params=param_dict).to(device=args.device,dtype=args.value_dtype)

In [3]:
from src.dataseters.GRUs import TyDataset,ToTensor
from torchvision import transforms
transform = transforms.Compose([ToTensor()])

In [4]:
dataset = TyDataset(args=args,train=False,transform=transform)

In [5]:
src = dataset[0]['inputs'].to(device=args.device,dtype=args.value_dtype).unsqueeze(0)
tgt = dataset[0]['targets'].unsqueeze(0).unsqueeze(2).to(device=args.device,dtype=args.value_dtype)
src_mask = torch.ones(1, src.shape[1]).to(device=args.device,dtype=args.value_dtype)
tgt_mask = subsequent_mask(tgt.shape[1]).to(device=args.device,dtype=args.value_dtype)

In [7]:
model.train()
model(src,tgt,src_mask,tgt_mask)

tensor([[[[0.1796, 0.1750, 0.1820,  ..., 0.1789, 0.1920, 0.2138],
          [0.1665, 0.1632, 0.1600,  ..., 0.1541, 0.1268, 0.1661],
          [0.1655, 0.1665, 0.1643,  ..., 0.1575, 0.1222, 0.1662],
          ...,
          [0.1793, 0.1621, 0.1490,  ..., 0.1306, 0.1233, 0.1591],
          [0.1828, 0.1473, 0.1491,  ..., 0.1558, 0.1125, 0.1489],
          [0.1920, 0.1739, 0.1797,  ..., 0.1880, 0.1511, 0.2133]],

         [[0.1705, 0.1638, 0.2003,  ..., 0.1975, 0.1895, 0.2108],
          [0.1589, 0.1592, 0.1734,  ..., 0.1516, 0.1420, 0.1762],
          [0.1607, 0.1675, 0.1740,  ..., 0.1780, 0.1311, 0.1696],
          ...,
          [0.1637, 0.1589, 0.1701,  ..., 0.1239, 0.1267, 0.1683],
          [0.1800, 0.1769, 0.1831,  ..., 0.1623, 0.1345, 0.1647],
          [0.1970, 0.1987, 0.2024,  ..., 0.1999, 0.1727, 0.2271]],

         [[0.2110, 0.2227, 0.2353,  ..., 0.1935, 0.2222, 0.2571],
          [0.1991, 0.1966, 0.1904,  ..., 0.1728, 0.1952, 0.2255],
          [0.1845, 0.1925, 0.2007,  ..., 0

In [10]:
model.eval()
pred = torch.zeros_like(tgt)
for i in range(18):
    (print(pred))
    pred[:,i] = (model(src,pred,src_mask,tgt_mask)[:,i]).detach()

tensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],


         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],


         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],


         ...,


         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
     