In [1]:
# %% [markdown]
# ### Test the trained VPTR

# %%
!pip install lpips
!pip install matplotlib

# %%
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

from pathlib import Path
import random
from datetime import datetime
import time

from model import VPTREnc, VPTRDec, VPTRDisc, init_weights, VPTRFormerFAR, VPTRFormerNAR
from model import GDL, MSELoss, L1Loss, GANLoss, BiPatchNCE
from utils import VidCenterCrop, VidPad, VidResize, VidNormalize, VidReNormalize, VidCrop, VidRandomHorizontalFlip, VidRandomVerticalFlip, VidToTensor
from utils import visualize_batch_clips, save_ckpt, load_ckpt, set_seed, AverageMeters, init_loss_dict, write_summary, resume_training, write_code_files
from utils import set_seed, PSNR, SSIM, MSEScore, get_dataloader
import lpips
import numpy as np


from matplotlib import pyplot as plt
%matplotlib inline

set_seed(2021)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# %%
# Load the checkpoint file and inspect its contents
import torch

# Load the checkpoint
checkpoint = torch.load('C:\VPTR_jigsaws\jigsaws_suturing\VPTR_ckpts\JIGSAWS_ResnetAE_MSEGDLgan_ckpt\epoch_100.tar', map_location=torch.device('cpu'))

# Print out the keys in the checkpoint
print("Checkpoint keys:", checkpoint.keys())

# Check if 'modules' key is present and print its keys
if 'modules' in checkpoint:
    print("Modules keys:", checkpoint['modules'].keys())
else:
    print("'modules' key not found in the checkpoint")

# Print the optimizer keys if present
if 'optimizers' in checkpoint:
    print("Optimizers keys:", checkpoint['optimizers'].keys())
else:
    print("'optimizers' key not found in the checkpoint")

# Print additional information if needed
print("Checkpoint contains epoch information:", 'epoch' in checkpoint)
print("Checkpoint contains history information:", 'history' in checkpoint)

Checkpoint keys: dict_keys(['epoch', 'loss_dict', 'Module_state_dict', 'optimizer_state_dict', 'code'])
'modules' key not found in the checkpoint
'optimizers' key not found in the checkpoint
Checkpoint contains epoch information: True
Checkpoint contains history information: False


In [3]:
# %%
resume_ckpt = Path('C:\VPTR_jigsaws\jigsaws_suturing\VPTR_ckpts\JIGSAWS_FAR_ckpt\epoch_200.tar') #The trained Transformer checkpoint file
resume_AE_ckpt = Path('C:\VPTR_jigsaws\jigsaws_suturing\VPTR_ckpts\JIGSAWS_ResnetAE_MSEGDLgan_ckpt\epoch_100.tar') #The trained AutoEncoder checkpoint file
num_past_frames = 10
num_future_frames = 20
encH, encW, encC = 8, 8, 528
TSLMA_flag = False
rpe = True
model_flag = 'FAR' #'NAR' for NAR model, 'FAR' for FAR model

img_channels = 3 # 1 for KTH and MovingMNIST, 3 for BAIR
N = 2
device = torch.device('cuda:0')
loss_name_list = ['T_MSE', 'T_GDL', 'T_gan', 'T_total', 'Dtotal', 'Dfake', 'Dreal']

In [4]:
# %%
#Set the padding_type to be "zero" for BAIR dataset
VPTR_Enc = VPTREnc(img_channels, feat_dim = encC, n_downsampling = 3, padding_type = 'reflect').to(device) 

#Set the padding_type to be "zero" for BAIR dataset, set the out_layer to be 'Sigmoid' for MovingMNIST
VPTR_Dec = VPTRDec(img_channels, feat_dim = encC, n_downsampling = 3, out_layer = 'Tanh', padding_type = 'reflect').to(device) 
VPTR_Enc = VPTR_Enc.eval()
VPTR_Dec = VPTR_Dec.eval()

if model_flag == 'NAR':
    VPTR_Transformer = VPTRFormerNAR(num_past_frames, num_future_frames, encH=encH, encW = encW, d_model=encC, 
                                         nhead=8, num_encoder_layers=4, num_decoder_layers=8, dropout=0.1, 
                                         window_size=4, Spatial_FFN_hidden_ratio=4, TSLMA_flag = TSLMA_flag, rpe=rpe).to(device)
else:
    VPTR_Transformer = VPTRFormerFAR(num_past_frames, num_future_frames, encH=encH, encW = encW, d_model=encC, 
                                    nhead=8, num_encoder_layers=12, dropout=0.1, 
                                    window_size=4, Spatial_FFN_hidden_ratio=4, rpe=rpe).to(device)

VPTR_Transformer = VPTR_Transformer.eval()

#load the trained autoencoder, we initialize the discriminator from scratch, for a balanced training
loss_dict, start_epoch = resume_training({'VPTR_Enc': VPTR_Enc, 'VPTR_Dec': VPTR_Dec}, {}, resume_AE_ckpt, loss_name_list)
if resume_ckpt is not None:
    loss_dict, start_epoch = resume_training({'VPTR_Transformer': VPTR_Transformer}, 
                                             {}, resume_ckpt, loss_name_list)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
# %%
train_loader, test_loader, renorm_transform = get_dataloader('Suturing', N, 'C:\\VPTR_jigsaws\\jigsaws_suturing\\frames_split', test_past_frames = 20, test_future_frames = 20, ngpus = 1, num_workers = 1)

In [6]:
def plot_model_result(pred, fig_name, num_frames, n = 2):
    """
    Plot and save figure
    """
    fig, ax = plt.subplots(1, num_frames, figsize = (num_frames, 1))
    fig.subplots_adjust(wspace=0., hspace = 0.)

    for j in range(num_frames):
        ax[j].set_axis_off()
        
        img = pred[:, j, :, :, :].clone()
        img = renorm_transform(img)
        img = torch.clamp(img, min = 0., max = 1.)
        img = img[n, ...]

        img = transforms.ToPILImage()(img)
        ax[j].imshow(img, cmap = 'gray')
    fig.savefig(f'{fig_name}.pdf', bbox_inches = 'tight')
    
def FAR_RIL_test_single_iter(sample, VPTR_Enc, VPTR_Dec, VPTR_Transformer, num_pred, device):
    """
    recurrently inference over Latent space, get a worse result
    """
    past_frames, future_frames = sample
    past_frames = past_frames.to(device)
    future_frames = future_frames.to(device)
    assert num_pred == future_frames.shape[1], "Mismatch between ground truth future frames length and num_pred"

    past_gt_feats = VPTR_Enc(past_frames)
    pred_feats = VPTR_Transformer(past_gt_feats)
    
    pred_frames = VPTR_Dec(pred_feats[:, -1:, ...])
    for i in range(1, num_pred):
        if i == 1:
            input_feats = torch.cat([past_gt_feats, pred_feats[:, -1:, ...]], dim = 1)
        elif i < VPTR_Transformer.num_future_frames:
            input_feats = torch.cat([input_feats, pred_future_feat], dim = 1)
        else:
            input_feats = torch.cat([input_feats, pred_future_feat], dim = 1)
            input_feats = input_feats[:, 1:, ...]
        
        pred_feats = VPTR_Transformer(input_feats)
        pred_future_frame = VPTR_Dec(pred_feats[:, -1:, ...])
        pred_future_feat = pred_feats[:, -1:, ...]
        pred_frames = torch.cat([pred_frames, pred_future_frame], dim = 1)
    
    return pred_frames, future_frames

def FAR_RIP_test_single_iter(sample, VPTR_Enc, VPTR_Dec, VPTR_Transformer, num_pred, device):
    """
    Recursively inference over pixel space.
    """
    past_frames, future_frames = sample
    past_frames = past_frames.to(device)
    future_frames = future_frames.to(device)
    assert num_pred == future_frames.shape[1], "Mismatch between ground truth future frames length and num_pred"

    past_gt_feats = VPTR_Enc(past_frames)
    pred_feats = VPTR_Transformer(past_gt_feats)

    pred_frames = VPTR_Dec(pred_feats[:, -1:, ...])
    for i in range(1, num_pred):
        pred_future_frame = VPTR_Dec(pred_feats[:, -1:, ...])  # Decode the last predicted feature
        pred_future_feat = VPTR_Enc(pred_future_frame)  # Encode the predicted frame

        if i == 1:
            input_feats = torch.cat([past_gt_feats, pred_future_feat], dim=1)  # Concatenate for the first prediction
        else:
            input_feats = torch.cat([input_feats, pred_future_feat], dim=1)  # Concatenate with previous input features
            input_feats = input_feats[:, 1:, ...]  # Remove the oldest frame from input_feats

        pred_feats = VPTR_Transformer(input_feats)
        pred_frames = torch.cat([pred_frames, pred_future_frame], dim=1)

    return pred_frames, future_frames

def NAR_test_single_iter(sample, VPTR_Enc, VPTR_Dec, VPTR_Transformer, num_pred, device):
    """
    NAR model inference function, for the case num_pred is divisible for the num_future_frames of training. e.g. num_pred = 20, num_future_frames = 10
    """
    past_frames, future_frames = sample
    past_frames = past_frames.to(device)
    future_frames = future_frames.to(device)
    assert num_pred == future_frames.shape[1], "Mismatch between ground truth future frames length and num_pred"
    assert num_pred % VPTR_Transformer.num_future_frames == 0, "Mismatch of num_pred and trained Transformer"
    
    past_gt_feats = VPTR_Enc(past_frames)
    
    for i in range(0, num_pred//VPTR_Transformer.num_future_frames):
        pred_future_feats = VPTR_Transformer(past_gt_feats)
        if i == 0:
            pred_frames = VPTR_Dec(pred_future_feats)
        else:
            pred_frames = torch.cat([pred_frames, VPTR_Dec(pred_future_feats)], dim = 1)
        past_gt_feats = pred_future_feats
        
    
    return pred_frames, future_frames


def NAR_BAIR_2_to_28_test_single_iter(sample, VPTR_Enc, VPTR_Dec, VPTR_Transformer, num_pred, device):
    """
    Specifically for BAIR dataset, 2 -> 28 prediction.
    """
    past_frames, future_frames = sample
    past_frames = past_frames.to(device)
    future_frames = future_frames.to(device)
    assert num_pred == future_frames.shape[1], "Mismatch between ground truth future frames length and num_pred"
    #assert num_pred % VPTR_Transformer.num_future_frames == 0, "Mismatch of num_pred and trained Transformer"
    
    pred = []
    #prediction 1
    past_gt_feats = VPTR_Enc(past_frames)
    pred_future_feats = VPTR_Transformer(past_gt_feats)
    pred_frames = VPTR_Dec(pred_future_feats)
    pred.append(pred_frames)
    #prediction 2
    past_frames = pred_frames[:, -2:, ...]
    past_gt_feats = VPTR_Enc(past_frames)
    pred_future_feats = VPTR_Transformer(past_gt_feats)
    pred_frames = VPTR_Dec(pred_future_feats)
    pred.append(pred_frames)
    
    #prediction 3
    past_frames = pred_frames[:, -2:, ...]
    past_gt_feats = VPTR_Enc(past_frames)
    pred_future_feats = VPTR_Transformer(past_gt_feats)
    pred_frames = VPTR_Dec(pred_future_feats)
    pred.append(pred_frames[:, 0:-2, ...])
    pred_frames = torch.cat(pred, dim = 1)
    
    return pred_frames, future_frames         

In [7]:
# %%
# Example usage of inference functions (adjust num_pred as needed)
sample = next(iter(test_loader))
num_pred = 10 

with torch.no_grad():
    if model_flag == 'FAR':
        pred_frames, gt_frames = FAR_RIP_test_single_iter(sample, VPTR_Enc, VPTR_Dec, VPTR_Transformer, num_pred, device)
    elif model_flag == 'NAR':
        pred_frames, gt_frames = NAR_test_single_iter(sample, VPTR_Enc, VPTR_Dec, VPTR_Transformer, num_pred, device)


input_feat_shape: torch.Size([2, 20, 528, 16, 16])
temporal_pos_embed_shape: torch.Size([30, 528])
Input x shape: torch.Size([2, 20, 16, 16, 528])
local_window_pos_embed shape: torch.Size([4, 4, 528])
temporal_pos_embed shape: torch.Size([20, 528])
Resized x shape: torch.Size([2, 20, 8, 8, 528])
x1 shape after norm3: torch.Size([20, 128, 528])
attn_mask shape: torch.Size([20, 20])
x1: torch.Size([20, 128, 528])
temporal_pos_embed[:T, None, :]: torch.Size([20, 1, 528])
Output x shape: torch.Size([2, 20, 8, 8, 528])
Input x shape: torch.Size([2, 20, 8, 8, 528])
local_window_pos_embed shape: torch.Size([4, 4, 528])
temporal_pos_embed shape: torch.Size([20, 528])
Resized x shape: torch.Size([2, 20, 8, 8, 528])
x1 shape after norm3: torch.Size([20, 128, 528])
attn_mask shape: torch.Size([20, 20])
x1: torch.Size([20, 128, 528])
temporal_pos_embed[:T, None, :]: torch.Size([20, 1, 528])
Output x shape: torch.Size([2, 20, 8, 8, 528])
Input x shape: torch.Size([2, 20, 8, 8, 528])
local_window_po

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 8 for tensor number 1 in the list.

In [None]:
# %%
# Visualize predictions
plot_model_result(pred_frames, 'pred_frames', num_pred, n=1)

In [None]:
plot_model_result(gt_frames, 'gt_frames', num_pred, n=1)

In [12]:
import numpy as np
from utils import PSNR, SSIM
import lpips

def pred_ave_metrics(model, data_loader, metric_func, renorm_transform, num_future_frames, num_past_frames, device = 'cuda:0', use_lpips = False, gray_scale = True):
    model = model.to(device)
    model = model.eval()
    ave_metric = np.zeros(num_future_frames)
    sample_num = 0

    with torch.no_grad():
        for idx, sample in enumerate(data_loader, 0):
            past_frames, future_frames = sample
            past_frames = past_frames.to(device)
            future_frames = future_frames.to(device)
            pred = model(past_frames)
            for i in range(0, num_future_frames):
                pred_t = pred[:, i, ...]
                future_frames_t = future_frames[:, i, ...]
                
                if not use_lpips:
                    pred_t = renorm_transform(pred_t)
                    future_frames_t = renorm_transform(future_frames_t)
                elif use_lpips and gray_scale:
                    pred_t = pred_t.repeat(1, 3, 1, 1)
                    future_frames_t = future_frames_t.repeat(1, 3, 1, 1)
                    
                m = metric_func(pred_t, future_frames_t)*pred_t.shape[0]
                try:
                    ave_metric[i] += m.mean()
                except AttributeError:
                    ave_metric[i] += m
                
            sample_num += pred.shape[0]

    ave_metric = ave_metric / sample_num
    return ave_metric

In [13]:
# %%
# Initialize metric functions
loss_fn_alex = lpips.LPIPS(net='alex').to(device)
ssim = SSIM()

# %%
# Create the complete model
model = nn.Sequential(VPTR_Enc, VPTR_Transformer, VPTR_Dec)

# %%
# Calculate metrics
psnr_list = pred_ave_metrics(model, test_loader, PSNR, renorm_transform, num_future_frames, num_past_frames, device, use_lpips=False)
ssim_list = pred_ave_metrics(model, test_loader, ssim, renorm_transform, num_future_frames, num_past_frames, device, use_lpips=False)
lpips_list = pred_ave_metrics(model, test_loader, loss_fn_alex, renorm_transform, num_future_frames, num_past_frames, device, use_lpips=True)

# %%
# Print the calculated metrics
print("PSNR:", psnr_list)
print("SSIM:", ssim_list)
print("LPIPS:", lpips_list)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to C:\Users\sc23gd/.cache\torch\hub\checkpoints\alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:09<00:00, 25.3MB/s] 


Loading model from: c:\Users\sc23gd\.conda\envs\VPTR\Lib\site-packages\lpips\weights\v0.1\alex.pth
input_feat_shape: torch.Size([2, 10, 528, 16, 16])
temporal_pos_embed_shape: torch.Size([30, 528])
Input x shape: torch.Size([2, 10, 16, 16, 528])
local_window_pos_embed shape: torch.Size([4, 4, 528])
temporal_pos_embed shape: torch.Size([10, 528])
Resized x shape: torch.Size([2, 10, 8, 8, 528])
x1 shape after norm3: torch.Size([10, 128, 528])
attn_mask shape: torch.Size([10, 10])
x1: torch.Size([10, 128, 528])
temporal_pos_embed[:T, None, :]: torch.Size([10, 1, 528])
Output x shape: torch.Size([2, 10, 8, 8, 528])
Input x shape: torch.Size([2, 10, 8, 8, 528])
local_window_pos_embed shape: torch.Size([4, 4, 528])
temporal_pos_embed shape: torch.Size([10, 528])
Resized x shape: torch.Size([2, 10, 8, 8, 528])
x1 shape after norm3: torch.Size([10, 128, 528])
attn_mask shape: torch.Size([10, 10])
x1: torch.Size([10, 128, 528])
temporal_pos_embed[:T, None, :]: torch.Size([10, 1, 528])
Output x 

RuntimeError: The size of tensor a (64) must match the size of tensor b (128) at non-singleton dimension 3