In [None]:
import sys
import os
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)
import cv2
import io
import imageio
from src.data import CS_VideoData 
from src.dino_f import Dino_f
import argparse
import torch
import pytorch_lightning as pl
import os
import matplotlib.cm as cm
import glob
import os.path as osp
from tqdm import tqdm
import numpy as np
from PIL import Image
import yaml
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as colors
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import einops
import math
import torch.nn as nn
IGNORE_LABEL = 255

In [None]:
def denormalize_images(tensor):
    """
    Undo ImageNet mean/std normalization on an image tensor.

    The per-channel statistics are shaped (3, 1, 1) so they broadcast over the
    trailing [C, H, W] axes. This keeps the output shape equal to the input
    shape for any number of leading dims (e.g. [C,H,W], [B,C,H,W], [B,T,C,H,W]).

    Args:
        tensor: image tensor whose last three dims are [C=3, H, W],
            e.g. shape [B,T,C,H,W]
    Returns:
        denormalized tensor of the same shape with values clamped to [0,1]
    """
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)
    # Denormalize: x = (norm_x * std) + mean
    denorm_images = tensor * std + mean
    # Clip to valid image range [0,1]
    return torch.clamp(denorm_images, 0, 1)

In [None]:
# Configuration namespace passed to the dataset (CS_VideoData) and the model (Dino_f).
# Attribute names are kept exactly as written (including spelling) because they are
# read by project code that is not visible in this notebook.
args = argparse.Namespace()
# Data Parameters
args.data_path = '/home/ubuntu/cityscapes/leftImg8bit_sequence'
args.dst_path = None
args.img_size = (448,896)
args.num_workers = 8
args.num_workers_val = None
args.sequence_length = 5
args.batch_size = 1
args.random_crop = True
args.random_horizontal_flip = True
args.random_time_flip = False
args.timestep_augm = None
args.no_timestep_augm = False
# Model / feature-extractor parameters
args.use_fc_bias = True
args.feature_extractor = 'dino'
args.dinov2_variant = 'vitb14_reg'
args.d_layers = [2,5,8,11]
args.hidden_dim = 1152
args.heads = 8
args.layers = 12
args.dropout = 0.1
args.loss_type = 'SmoothL1'
args.beta_smoothl1 = 0.1
args.attn_dropout = 0.3
args.step = 1
args.masking = 'simple_replace'
args.train_mask_mode = 'fullmask'
args.seperable_attention = True
args.seperable_window_size = 1
args.train_mask_frames = 1
args.output_activation = 'none'
args.use_first_last = False
args.down_up_sample = False
args.pca_ckpt = "/home/ubuntu/DinoFeatPred/pca/pca_448_l[2_5_8_11]_1152.pth"
args.crop_feats = False
args.sliding_window_inference = False
args.use_bn = True
args.use_cls = False
args.nfeats = 256
args.dpt_out_channels = [128,256,512,512]
# Training parameters
args.max_epochs = 800
args.seed = 123
args.single_step_sample_train = True
args.precision = '32-true'
args.ckpt = None
args.num_gpus = 1
args.accum_iter = 1
args.warmup_p = 0.0
args.lr_base = 1e-3
args.weight_decay = 0
args.scheduler = "cosine"
args.optimizer = "adam"
args.gclip = 1.0
# Evaluation / visualisation switches
args.evaluate = True
args.eval_midterm = False
args.eval_mode = True
args.use_val_to_train = False
args.use_train_to_val = False
args.vis_attn = True

# RGB colour for each of the 19 class indices used when rendering segmentation maps
# (values appear to be the standard Cityscapes train-ID palette).
obj_colors_dict = {
    0: [128, 64, 128],
    1: [244, 35, 232],
    2: [70, 70, 70],
    3: [102, 102, 156],
    4: [190, 153, 153],
    5: [153, 153, 153],
    6: [250, 170, 30],
    7: [220, 220, 0],
    8: [107, 142, 35],
    9: [152, 251, 152],
    10: [70, 130, 180],
    11: [220, 20, 60],
    12: [255, 0, 0],
    13: [0, 0, 142],
    14: [0, 0, 70],
    15: [0, 60, 100],
    16: [0, 80, 100],
    17: [0, 0, 230],
    18: [119, 11, 32]
}
# (19, 3) uint8 lookup table: class_colors_arr[class_index] -> RGB
args.class_colors_arr = np.array(list(obj_colors_dict.values()),dtype=np.uint8)

# Device selection: prefer GPU 7, but fall back to GPU 0 when the machine exposes
# fewer GPUs (a non-existent index would only fail later, at the first .to()),
# and to the CPU when CUDA is unavailable.
_preferred_gpu_index = 7
if torch.cuda.is_available():
    _gpu_index = _preferred_gpu_index if _preferred_gpu_index < torch.cuda.device_count() else 0
    args.device = f'cuda:{_gpu_index}'
else:
    args.device = 'cpu'

## Visualizations

In [None]:
# Task-head settings per evaluation modality:
#   modality name -> (value for args.num_classes, task-head checkpoint path).
# Edit _selected_modality to switch between segmentation, depth and surface normals.
_head_settings = {
    "segm": (19, "/path/to/Checkpoints/head_pca1152.ckpt"),
    "depth": (256, "/path/to/Checkpoints/head_depth_pca1152.ckpt"),
    "surface_normals": (3, "/path/to/Checkpoints/head_normals_pca1152.ckpt"),
}
_selected_modality = "segm"

args.eval_modality = _selected_modality
args.num_classes, args.head_ckpt = _head_settings[_selected_modality]

# Build the validation split and its dataloader from the configured arguments.
dataset = CS_VideoData(
    arguments=args,
    subset="val",
    batch_size=args.batch_size,
)
val_dl = dataset.val_dataloader()

In [None]:
# Load the Dino_f checkpoint (strict=False, so key mismatches in the state dict are
# tolerated) and switch to evaluation mode.
# The model is moved to args.device rather than a hard-coded 'cuda:7', so the CPU
# fallback chosen in the configuration cell is honoured and the model ends up on the
# same device the later cells send their inputs to.
ckpt_path = "/path/to/dinof_highres.ckpt"
model = Dino_f.load_from_checkpoint(ckpt_path,args=args,strict=False).to(args.device)
model.eval()

### Scene

In [None]:
# Draw a single random integer in [0, 500) and show it
# (handy for picking a value for the sample index `s` used below).
n = torch.randint(0, 500, (1,))
print(n)

In [None]:
# Index of the validation batch to visualise; the trailing numbers are other
# indices noted by the author.
s = 2 # ==>> 341, 223, 99  # 24 turn  [13, 26, 266, 276,  333 ]
# Walk the dataloader until batch `s` is reached, printing the path of every batch
# seen on the way. Each batch unpacks into the input clip, the ground-truth image,
# the ground-truth modality map, a further-future image and the sample path.
for i, batch in enumerate(val_dl):
    data, gt_img, gt_modal, gt_future_img, gt_path = batch
    print(i, gt_path)
    if i == s:
        break
else:
    # The loop ended without reaching index `s`: fail loudly instead of silently
    # continuing with the last batch (or with undefined names for an empty loader).
    raise IndexError(f"val_dl yielded fewer than {s + 1} batches; cannot select batch {s}")
# data is unpacked as [batch, sequence length, channels, height, width]
B, sl, C, H, W = data.shape
print(data.shape)
print(gt_img.shape)
print(gt_modal.shape)

In [None]:
# Display every frame of the selected clip side by side.
# Frames are first mapped back from ImageNet-normalised values to [0,1].
denorm_frames = denormalize_images(data)
fig, axes = plt.subplots(nrows=1, ncols=args.sequence_length, figsize=(20, 5))
for i in range(args.sequence_length):
    frame_ax = axes[i]
    # first batch element, frame i, channels moved last for imshow
    frame_hwc = denorm_frames[0, i].permute(1, 2, 0).cpu().numpy()
    frame_ax.imshow(frame_hwc)
    frame_ax.set_title(f"Frame {i}")
    frame_ax.axis('off')
plt.tight_layout()
plt.show()


In [None]:
# Show the ground-truth map of the selected sample for the active modality.
if args.eval_modality == "segm":
    # Work on a NumPy integer label map so all indexing below is NumPy-on-NumPy.
    gt_labels = gt_modal.squeeze().cpu().numpy().astype(np.int64)
    # Pixels carrying IGNORE_LABEL stay black; every other pixel gets its class colour.
    gt_mask = gt_labels != IGNORE_LABEL
    # Canvas sized from the label map itself rather than a fixed 1024x2048.
    vis_gt_segm = np.zeros(gt_labels.shape + (3,), dtype=np.uint8)
    vis_gt_segm[gt_mask] = args.class_colors_arr[gt_labels[gt_mask]]
    plt.imshow(vis_gt_segm)
    plt.show()
elif args.eval_modality == "depth":
    plt.imshow(gt_modal.squeeze(),cmap='turbo')
    plt.show()
elif args.eval_modality == "surface_normals":
    # Normalise along dim 0 of the squeezed map (assumed to be the channel axis of a
    # [3,H,W] map - TODO confirm), then move that axis last for imshow.
    plt.imshow(F.normalize(gt_modal.squeeze(),p=2,dim=0).cpu().numpy().transpose(1,2,0))
    # plt.imshow(gt_modal.squeeze().cpu().numpy().transpose(1,2,0))
    plt.show()

In [None]:
# One forward pass of the model on the selected clip, without gradient tracking.
with torch.no_grad():
    x = model.preprocess(data.to(args.device))
    print("PCA_feats shape",x.shape)
    # mode="full_mask" with mask_frames=1: presumably masks one whole frame
    # (the last one) - confirm against Dino_f.get_mask_tokens.
    masked_soft_tokens, raw_mask = model.get_mask_tokens(x, mode="full_mask",mask_frames=1)
    mask = raw_mask.to(x.device)
    forward_out = model.forward(x, masked_soft_tokens, mask)
    # The forward pass returns three values when attention visualisation is enabled
    # and two otherwise.
    if model.args.vis_attn:
        _, final_tokens, attn_weights = forward_out
    else:
        loss, final_tokens = forward_out
        print(loss)
    prediction = model.postprocess(final_tokens)
print(prediction.shape)

In [None]:
# Element-wise SmoothL1 between the model output and the input features at the last
# time index, averaged over the final (feature) axis and shown as an image.
smoothl1 = nn.SmoothL1Loss(beta=0.1, reduction='none')
output_last = final_tokens[:, -1]
input_last = x[:, -1]
loss_tokens = smoothl1(output_last, input_last).mean(dim=-1)
loss_map = loss_tokens.squeeze().cpu().numpy()
plt.imshow(loss_map)

In [None]:
# Decode two feature sets with the task head and display the results:
#   * the model's predicted features for the last time index, and
#   * "oracle" features extracted directly from the ground-truth image.
# Head inference runs under torch.no_grad(): model.eval() does not switch autograd
# off, so without it a graph is built for nothing and, if the head's parameters
# require grad, the `.numpy()` call in the surface_normals branch would raise.
with torch.no_grad():
    # PREDICTED IMAGE
    pred_feats = prediction[:,-1]
    # Split the last axis into one chunk of model.feature_dim per layer ...
    pred_feats_list = [pred_feats[:,:,:,k*model.feature_dim:(k+1)*model.feature_dim] for k in range(model.d_num_layers)]
    # ... and flatten the spatial grid into a token axis for the head.
    pred_feats_list = [einops.rearrange(f, 'b h w c -> b (h w) c',h=H//model.patch_size, w=W//model.patch_size) for f in pred_feats_list]
    pred_modal = model.head(pred_feats_list,model.patch_h,model.patch_w)
    pred_modal = F.interpolate(pred_modal, size=(1024,2048), mode='bilinear', align_corners=False)
    # ORACLE IMAGE
    oracle_feats = model.extract_features(gt_img.to(args.device))
    oracle_feats_list = [oracle_feats[:,:,k*model.feature_dim:(k+1)*model.feature_dim] for k in range(model.d_num_layers)]
    oracle_pred = model.head(oracle_feats_list,model.patch_h,model.patch_w)
    oracle_pred = F.interpolate(oracle_pred, size=(1024,2048), mode='bilinear', align_corners=False)
if args.eval_modality == "segm":
    # argmax over the class axis, then colourise with the class palette
    pred_modal = torch.argmax(pred_modal, dim=1)
    pred_modal_rgb = args.class_colors_arr[pred_modal.squeeze().cpu().numpy()]
    plt.figure()
    plt.imshow(T.ToPILImage(mode='RGB')(pred_modal_rgb))
    oracle_pred = torch.argmax(oracle_pred, dim=1)
    oracle_pred_rgb = args.class_colors_arr[oracle_pred.squeeze().cpu().numpy()]
    plt.figure()
    plt.imshow(T.ToPILImage(mode='RGB')(oracle_pred_rgb))
elif args.eval_modality == "depth":
    # argmax over the output axis gives an integer index per pixel, rendered with turbo
    pred_modal = torch.argmax(pred_modal, dim=1)
    pred_modal_rgb = cm.turbo(pred_modal.squeeze().cpu().numpy())[...,:3]
    plt.figure()
    plt.imshow(pred_modal_rgb)
    oracle_pred = torch.argmax(oracle_pred, dim=1)
    oracle_pred_rgb = cm.turbo(oracle_pred.squeeze().cpu().numpy())[...,:3]
    plt.figure()
    plt.imshow(oracle_pred_rgb)
elif args.eval_modality == "surface_normals":
    # unit-normalise along the channel axis and move channels last for imshow
    pred_modal_rgb = F.normalize(pred_modal,p=2,dim=1).cpu().numpy().transpose(0,2,3,1)
    plt.figure()
    plt.imshow(pred_modal_rgb.squeeze())
    oracle_pred_rgb = F.normalize(oracle_pred,p=2,dim=1).cpu().numpy().transpose(0,2,3,1)
    plt.figure()
    plt.imshow(oracle_pred_rgb.squeeze())

In [None]:
# Autoregressively unroll the model for `unroll_steps` steps, decode the final
# prediction with the task head, and display + save it next to a reference decoded
# from real image features (gt_img for 1 step, gt_future_img for 3 steps).
# Output files are written to the current working directory.
name = osp.basename(gt_path[0])
unroll_steps = 1 # [1 or 3]
assert unroll_steps in [1,3]


def _split_layers(feats):
    """Split the last axis of `feats` into one chunk of model.feature_dim per layer."""
    return [feats[..., k*model.feature_dim:(k+1)*model.feature_dim] for k in range(model.d_num_layers)]


def _run_head(feats_list):
    """Apply the task head to per-layer token features and upsample to 1024x2048."""
    head_out = model.head(feats_list, model.patch_h, model.patch_w)
    return F.interpolate(head_out, size=(1024,2048), mode='bilinear', align_corners=False)


def _to_rgb_uint8(head_out):
    """Render a head output as an HxWx3 uint8 RGB array for the active modality.

    Returning uint8 for every modality makes the saved file independent of how
    ToPILImage treats floating-point arrays.
    """
    if args.eval_modality == "segm":
        labels = torch.argmax(head_out, dim=1).squeeze().cpu().numpy()
        return args.class_colors_arr[labels]
    if args.eval_modality == "depth":
        bins = torch.argmax(head_out, dim=1).squeeze().cpu().numpy()
        # bytes=True makes the colormap return uint8 RGBA; drop the alpha channel
        return np.ascontiguousarray(cm.turbo(bins, bytes=True)[..., :3])
    if args.eval_modality == "surface_normals":
        normals = F.normalize(head_out, p=2, dim=1).cpu().numpy().transpose(0,2,3,1).squeeze()
        # Clip to [0,1] (what imshow does for float RGB input) before scaling, so
        # negative components do not wrap around when cast to uint8.
        return (np.clip(normals, 0.0, 1.0) * 255).astype(np.uint8)
    raise ValueError(f"Unsupported eval_modality: {args.eval_modality!r}")


def _show_and_save(head_out, save_name):
    """Display the rendered head output in a new figure and save it as an image."""
    rgb = _to_rgb_uint8(head_out)
    plt.figure()
    plt.imshow(rgb)
    T.ToPILImage(mode='RGB')(rgb).save(save_name)


# All inference runs under no_grad: model.eval() alone does not disable autograd.
with torch.no_grad():
    x = model.preprocess(data.to(args.device))
    print(x.shape)
    for _step in range(unroll_steps):
        masked_soft_tokens, mask = model.get_mask_tokens(x, mode="full_mask",mask_frames=1)
        mask = mask.to(x.device)
        if model.args.vis_attn:
            _, final_tokens, attn_weights = model.forward(x, masked_soft_tokens, mask)
        else:
            loss, final_tokens = model.forward(x, masked_soft_tokens, mask)
        # Write the new prediction into the last slot, then shift the window one
        # step left so the prediction becomes the most recent context frame.
        x[:,-1] = final_tokens[:,-1]
        x[:,0:-1] = x[:,1:].clone()
    prediction = model.postprocess(x)

    # PREDICTED IMAGE
    pred_feats = prediction[:,-1]
    pred_feats_list = [
        einops.rearrange(f, 'b h w c -> b (h w) c',h=H//model.patch_size, w=W//model.patch_size)
        for f in _split_layers(pred_feats)
    ]
    pred_modal = _run_head(pred_feats_list)

    # ORACLE / FUTURE IMAGE: decode features extracted from the real frame
    reference_img = gt_img if unroll_steps == 1 else gt_future_img
    reference_feats = model.extract_features(reference_img.to(args.device))
    reference_modal = _run_head(_split_layers(reference_feats))

# File names: the "leftImg8bit" part of the sample name is replaced by a tag holding
# the modality and the horizon (3 per unroll step).
horizon_tag = '(t+'+str(unroll_steps*3)+')'
modal_name =  name.replace("leftImg8bit","pred_"+args.eval_modality+horizon_tag)
oracle_modal_name = name.replace("leftImg8bit","oracle_"+args.eval_modality+horizon_tag)

_show_and_save(pred_modal, modal_name)
_show_and_save(reference_modal, oracle_modal_name)


In [None]:
# Long autoregressive roll-out: save the context frames, then for every unroll step
# decode the newest prediction with the task head and save it as an image.
name = osp.basename(gt_path[0])
f_name = name.split('_leftImg8bit')[0]
unroll_steps = 16 # number of autoregressive prediction steps
context_length = data.shape[1] - 1
denorm_data = denormalize_images(data)

# One output folder per sample; created once instead of on every iteration.
out_dir = osp.join("/home/ubuntu/Dino_Predictions_Unroll", f_name)
os.makedirs(out_dir, exist_ok=True)


def _colorize_prediction(head_out):
    """Render a head output as an HxWx3 uint8 RGB array for the active modality.

    The file name already carries args.eval_modality, so the decoding follows it
    too instead of always treating the output as segmentation logits.
    """
    if args.eval_modality == "segm":
        labels = torch.argmax(head_out, dim=1).squeeze().cpu().numpy()
        return args.class_colors_arr[labels]
    if args.eval_modality == "depth":
        bins = torch.argmax(head_out, dim=1).squeeze().cpu().numpy()
        # bytes=True makes the colormap return uint8 RGBA; drop the alpha channel
        return np.ascontiguousarray(cm.turbo(bins, bytes=True)[..., :3])
    if args.eval_modality == "surface_normals":
        normals = F.normalize(head_out, p=2, dim=1).cpu().numpy().transpose(0,2,3,1).squeeze()
        # Clip to [0,1] before scaling so negative components do not wrap in uint8.
        return (np.clip(normals, 0.0, 1.0) * 255).astype(np.uint8)
    raise ValueError(f"Unsupported eval_modality: {args.eval_modality!r}")


for ctx_idx in range(context_length):
    # Save context frames
    # Offset label counts back from the last context frame in steps of 3
    # (equals the original 9-i*3 when there are 4 context frames).
    time = '(t-'+str(3*(context_length-1-ctx_idx))+')'
    context_frame = denorm_data[0, ctx_idx] # Assuming batch size 1
    ctx_name = name.replace("leftImg8bit","context"+time)
    ctx_path = osp.join(out_dir, ctx_name)
    T.ToPILImage()(context_frame).save(ctx_path)

with torch.no_grad():
    x = model.preprocess(data.to(args.device))
    print(x.shape)
    for step in range(unroll_steps):
        masked_soft_tokens, mask = model.get_mask_tokens(x, mode="full_mask",mask_frames=1)
        mask = mask.to(x.device)
        if model.args.vis_attn:
            _, final_tokens, attn_weights = model.forward(x, masked_soft_tokens, mask)
        else:
            loss, final_tokens = model.forward(x, masked_soft_tokens, mask)
        # Write the new prediction into the last slot, then shift the window one
        # step left so the prediction becomes the most recent context frame.
        x[:,-1] = final_tokens[:,-1]
        x[:,0:-1] = x[:,1:].clone()
        prediction = model.postprocess(x)
        time = '(t+'+str(3+step*3)+')'
        modal_name =  name.replace("leftImg8bit","pred_"+args.eval_modality+time)
        modal_path = osp.join(out_dir, modal_name)
        # PREDICTED IMAGE
        pred_feats = prediction[:,-1]
        pred_feats_list = [pred_feats[:,:,:,k*model.feature_dim:(k+1)*model.feature_dim] for k in range(model.d_num_layers)]
        pred_feats_list = [einops.rearrange(f, 'b h w c -> b (h w) c',h=H//model.patch_size, w=W//model.patch_size) for f in pred_feats_list]
        pred_modal = model.head(pred_feats_list,model.patch_h,model.patch_w)
        pred_modal = F.interpolate(pred_modal, size=(1024,2048), mode='bilinear', align_corners=False)
        pred_modal_rgb = _colorize_prediction(pred_modal)
        T.ToPILImage(mode='RGB')(pred_modal_rgb).save(modal_path)