In [1]:
import torch
from models.network import *

In [2]:
# Configuration for a small shape-check run of MS2TAN.
# Fall back to CPU so the notebook still runs on machines without a GPU;
# a hard-coded "cuda" makes `.to(device)` raise when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_frame = 10     # number of frames (time steps) per time series
num_channel = 6    # channels per frame (used for both input and output)
img_size = 120     # spatial height == width, in pixels

# Seed torch's RNG so the random demo inputs (and presumably the weight
# initialisation done by init_weights) are reproducible across runs.
torch.manual_seed(0)

# Three scales, one entry per list. Per the summary printed by the constructor,
# num_patches == (img_size // patch) ** 2 and patch_dim == patch ** 2 * num_channel,
# so each entry of patch_list here divides img_size evenly.
model = MS2TAN(
    dim_list=[256, 192, 128],
    num_frame=num_frame,
    image_size=img_size,
    patch_list=[12, 10, 8],
    in_chans=num_channel,
    out_chans=num_channel,
    depth_list=[2, 2, 2],
    heads_list=[8, 6, 4],
    dim_head_list=[32, 32, 32],
).to(device)
init_weights(model)

num_patches=100,             num_positions=1000,             patch_dim_in=864,             patch_dim_out=864
num_patches=144,             num_positions=1440,             patch_dim_in=600,             patch_dim_out=600
num_patches=225,             num_positions=2250,             patch_dim_in=384,             patch_dim_out=384


In [3]:
# Model size: element count summed over every parameter tensor, reported in millions.
total = sum(map(torch.numel, model.parameters()))
print(f"Total params: {total / 1e6:.2f}M")

Total params: 4.59M


In [4]:
batch_size = 1

# Shared layouts for the demo tensors: (batch, frame, channel, height, width).
image_shape = (batch_size, num_frame, num_channel, img_size, img_size)
mask_shape = (batch_size, num_frame, 1, img_size, img_size)  # single channel

# Tensors are created directly on `device` rather than built on the CPU and
# then copied over with `.to(device)`, which avoids an extra host allocation + transfer.

# input and output time-series images
X = torch.randn(image_shape, device=device)
y = torch.randn(image_shape, device=device)

# artificial masked pixels in trainset
# NOTE(review): normal-distributed values are only placeholders for this shape
# check; a real mask is presumably binary — confirm the expected value range.
artificial = torch.randn(mask_shape, device=device)

# hint tensor for each missing pixel (both artificial and real); same placeholder caveat
hint_tensor = torch.randn(mask_shape, device=device)

In [5]:
# Forward pass in the model's 'val' mode. Only the output shapes are inspected,
# so run under torch.no_grad() to avoid building an autograd graph (saves memory).
with torch.no_grad():
    out = model(X, (hint_tensor, artificial), y, mode='val')

# each intermediate result stored under out['hist_list']
out_list = out['hist_list']
for idx, res in enumerate(out_list):
    print(f'Immediate result {idx}:', res.shape)

# final result after replacement (key 'replace_out'; presumably the prediction
# with known pixels substituted back in — confirm in models.network)
final_result = out['replace_out']
print('Final result:', final_result.shape)

Immediate result 0: torch.Size([1, 10, 6, 120, 120])
Immediate result 1: torch.Size([1, 10, 6, 120, 120])
Immediate result 2: torch.Size([1, 10, 6, 120, 120])
Final result: torch.Size([1, 10, 6, 120, 120])
