In [2]:
import torch
import numpy as np
from PIL import Image
import os
import torch.nn.functional as F
from tqdm import tqdm
from scipy import interpolate
from pathlib import Path
import matplotlib.pyplot as plt
import cv2

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### E* Metric: warping error between consecutive frames (RAFT optical flow)

In [3]:
import gqvr.model.warp_utils as warp_utils
from gqvr.model.core_raft.utils.utils import InputPadder
from gqvr.model.core_raft.raft import RAFT

class RAFT_args:
    """Static option namespace that is instantiated and passed to ``RAFT(...)``."""

    # Architecture / regularisation switches.
    small = False
    dropout = False

    # Inference-time switches.
    mixed_precision = True
    # Reduces VRAM significantly in forward pass
    alternate_corr = True

# Build the RAFT optical-flow network and load its pretrained weights.
raft_args = RAFT_args()
raft_model = RAFT(raft_args)

# map_location="cpu" lets a checkpoint that was saved from CUDA tensors load on a
# CPU-only machine (``device`` falls back to CPU) and avoids staging a second copy
# of the weights on the GPU; the model is moved to ``device`` below.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
raft_things_dict = torch.load(
    "./pretrained_ckpts/models/raft-things.pth", map_location="cpu"
)

# Drop the first dot-separated component of every key so the names match the bare
# RAFT module (presumably a "module." wrapper prefix -- confirm against the checkpoint).
corrected_state_dict = {
    key.partition(".")[2]: value for key, value in raft_things_dict.items()
}

raft_model.load_state_dict(corrected_state_dict)
# Inference only: eval mode, no gradients, resident on the selected device.
raft_model.eval().requires_grad_(False).to(device)

def compute_video_warping_error(video_path):
    """Average occlusion-masked warping error over consecutive frames of a video.

    Every entry of ``video_path`` is loaded in sorted filename order, converted to
    RGB and resized to 512x512. For each consecutive frame pair, forward and
    backward optical flow are estimated with the module-level ``raft_model`` on
    half-resolution inputs; ``warp_utils.detect_occlusion`` then supplies an
    occlusion mask and a warped version of the second frame. The per-pair error is
    the sum of squared differences on non-occluded positions divided by the sum of
    the non-occlusion mask, and the result is the mean of that value over all pairs.

    Args:
        video_path: Directory whose entries are the frames of one video.

    Returns:
        A scalar ``torch.Tensor`` holding the mean warping error.

    Raises:
        ValueError: If ``video_path`` contains fewer than two entries, since no
            frame pair exists to compare.
    """
    frame_names = sorted(os.listdir(video_path))
    if len(frame_names) < 2:
        raise ValueError(
            f"Need at least two frames to compute a warping error, "
            f"found {len(frame_names)} in {video_path!r}"
        )

    frames = []
    for name in frame_names:
        # The context manager releases the file handle once the pixels are copied out.
        with Image.open(os.path.join(video_path, name)) as img:
            frames.append(np.array(img.convert('RGB').resize((512, 512), Image.LANCZOS)))
    tensor_frames = torch.stack([torch.from_numpy(frame) for frame in frames])
    num_pairs = tensor_frames.shape[0] - 1

    err = 0
    with torch.no_grad():
        for i in range(num_pairs):
            # HWC uint8 -> 1xCxHxW float in [0, 1]
            img1 = tensor_frames[i].permute(2, 0, 1).float().unsqueeze(0).to(device) / 255.0
            img2 = tensor_frames[i + 1].permute(2, 0, 1).float().unsqueeze(0).to(device) / 255.0

            # Downsample the images by a factor of 2 before estimating flow
            img1 = F.interpolate(img1, scale_factor=0.5, mode='bilinear', align_corners=False)
            img2 = F.interpolate(img2, scale_factor=0.5, mode='bilinear', align_corners=False)

            # presumably pads the pair to a spatial size RAFT accepts -- confirm in InputPadder
            padder = InputPadder(img1.shape)
            img1, img2 = padder.pad(img1, img2)

            # Forward flow (frame i -> frame i + 1) with the optical flow model: RAFT
            _, fw_flow = raft_model(img1, img2, iters=20, test_mode=True)
            fw_flow = warp_utils.tensor2img(fw_flow)
            # Clear cache and temporary data
            torch.cuda.empty_cache()

            # Backward flow (frame i + 1 -> frame i)
            _, bw_flow = raft_model(img2, img1, iters=20, test_mode=True)
            bw_flow = warp_utils.tensor2img(bw_flow)
            torch.cuda.empty_cache()

            # Occlusion mask and warped second frame
            fw_occ, warp_img2 = warp_utils.detect_occlusion(bw_flow, fw_flow, img2)
            # as_tensor accepts arrays and tensors alike and avoids the copy-construct
            # UserWarning that torch.tensor() emits when handed an existing tensor.
            warp_img2 = torch.as_tensor(warp_img2).float().to(device)
            occ_mask = torch.as_tensor(fw_occ).float().to(device)
            noc_mask = 1 - occ_mask

            diff_squared = ((warp_img2 - img1) * noc_mask) ** 2

            # Normalise by the non-occluded count; if everything is occluded fall
            # back to the element count so the division stays finite.
            norm = torch.sum(noc_mask)
            if norm == 0:
                norm = diff_squared.numel()
            err += torch.sum(diff_squared) / norm

    return err / num_pairs

use sdp attention as default
keep default attention mode


  from .autonotebook import tqdm as notebook_tqdm
  import pkg_resources


## PSNR/SSIM from piq

In [4]:
import piq
def compute_full_reference_metrics(gt_img, out_img):
    """Return ``(psnr, ssim)`` of ``out_img`` measured against ``gt_img``.

    Both metrics are computed with ``piq`` using ``data_range=1.`` and are reduced
    with ``'mean'``, so the function also works for batches of more than one image
    (the result is then the batch average). For a single-image batch the values
    are identical to the unreduced ones.

    Args:
        gt_img: Reference image tensor (callers pass values scaled to [0, 1]).
        out_img: Image tensor to score, same layout as ``gt_img``.

    Returns:
        Tuple ``(psnr, ssim)`` of Python floats.
    """
    # Use the same reduction for both metrics so .item() is always valid.
    psnr = piq.psnr(out_img, gt_img, data_range=1., reduction='mean')
    ssim = piq.ssim(out_img, gt_img, data_range=1., reduction='mean')
    return psnr.item(), ssim.item()

## Paths to result directories and list of test scenes

In [None]:
# Root directory that holds one sub-directory of results per method. The default
# is the original machine-specific location; set GQVR_RESULTS_ROOT to run the
# notebook elsewhere without editing the code.
RESULTS_ROOT = os.environ.get("GQVR_RESULTS_ROOT", "/nobackup1/aryan/results")

SD21_burst_results = f"{RESULTS_ROOT}/sd21_burst"
QUIVER_RESULTS = f"{RESULTS_ROOT}/QUIVER"
QBP_RESULTS = f"{RESULTS_ROOT}/QBP_BM3D"

# Scene identifiers; the evaluation cells append suffixes such as "_gt" / "_out"
# to these names to build per-scene directory paths.
scenes = ['xvfi_boat', 'xvfi_train',
          'test_00014', 'test_00020', 'test_00021', 'test_00031',
          'small_jetengine', 'small_tank', 'small_explosion_0001', 'small_padlock', 'small_moreguns']

In [30]:
# Evaluate the SD2.1 burst results: per-scene mean PSNR / SSIM of the output frames
# against the ground-truth frames (RGB, 512x512) plus the warping error of the
# output sequence, followed by the averages over all scenes.
cumulative_psnr, cumulative_ssim, cumulative_warping_error = [], [], []
pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = SD21_burst_results + f"/{scene}_gt"
    gt_imgs = sorted(os.listdir(test_dir_gt))
    test_dir_out = SD21_burst_results + f"/{scene}_out"
    out_imgs = sorted(os.listdir(test_dir_out))

    # Temporal metric first: it only needs the output directory.
    warping_error = compute_video_warping_error(test_dir_out)

    # Frames are paired by their position in the sorted listings.
    psnr_list, ssim_list = [], []
    for i, gt_name in enumerate(gt_imgs):
        gt_pil = Image.open(os.path.join(test_dir_gt, gt_name))
        gt_pil = gt_pil.convert("RGB").resize((512,512), Image.LANCZOS)
        out_pil = Image.open(os.path.join(test_dir_out, out_imgs[i]))
        out_pil = out_pil.convert("RGB").resize((512,512), Image.LANCZOS)

        # HWC uint8 -> 1xCxHxW float in [0, 1]
        gt_img = torch.tensor(np.array(gt_pil)).permute(2,0,1).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_pil)).permute(2,0,1).float().unsqueeze(0).to(device) / 255.0

        psnr, ssim = compute_full_reference_metrics(gt_img, out_img)
        psnr_list.append(psnr)
        ssim_list.append(ssim)

    pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | Warping Error: {warping_error:.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{warping_error:.4f}")

    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    cumulative_warping_error.append(warping_error.cpu().item())
    pbar.update(1)
pbar.close()

print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")

  warp_img2 = torch.tensor(warp_img2).float().to(device)
Scene: xvfi_boat | PSNR: 27.5939 | SSIM: 0.7163 | Warping Error: 0.0017:   9%|▉         | 1/11 [00:06<01:09,  6.93s/it]

xvfi_boat: 27.5939/0.7163/0.0017


Scene: xvfi_train | PSNR: 24.0497 | SSIM: 0.7071 | Warping Error: 0.0034:  18%|█▊        | 2/11 [00:13<01:02,  6.96s/it]

xvfi_train: 24.0497/0.7071/0.0034


Scene: test_00014 | PSNR: 30.2209 | SSIM: 0.8299 | Warping Error: 0.0018:  27%|██▋       | 3/11 [00:57<03:10, 23.85s/it]

test_00014: 30.2209/0.8299/0.0018


Scene: test_00020 | PSNR: 32.0894 | SSIM: 0.9449 | Warping Error: 0.0012:  36%|███▋      | 4/11 [01:41<03:42, 31.72s/it]

test_00020: 32.0894/0.9449/0.0012


Scene: test_00021 | PSNR: 32.0890 | SSIM: 0.8921 | Warping Error: 0.0012:  45%|████▌     | 5/11 [02:25<03:36, 36.15s/it]

test_00021: 32.0890/0.8921/0.0012


Scene: test_00031 | PSNR: 30.4578 | SSIM: 0.8452 | Warping Error: 0.0018:  55%|█████▍    | 6/11 [05:26<07:07, 85.45s/it]

test_00031: 30.4578/0.8452/0.0018


Scene: small_jetengine | PSNR: 30.1210 | SSIM: 0.8531 | Warping Error: 0.0025:  64%|██████▎   | 7/11 [05:46<04:15, 63.92s/it]

small_jetengine: 30.1210/0.8531/0.0025


Scene: small_tank | PSNR: 26.1159 | SSIM: 0.8642 | Warping Error: 0.0030:  73%|███████▎  | 8/11 [05:52<02:16, 45.59s/it]     

small_tank: 26.1159/0.8642/0.0030


Scene: small_explosion_0001 | PSNR: 28.9755 | SSIM: 0.8798 | Warping Error: 0.0044:  82%|████████▏ | 9/11 [06:06<01:11, 35.58s/it]

small_explosion_0001: 28.9755/0.8798/0.0044


Scene: small_padlock | PSNR: 34.6854 | SSIM: 0.9637 | Warping Error: 0.0014:  91%|█████████ | 10/11 [06:12<00:26, 26.54s/it]      

small_padlock: 34.6854/0.9637/0.0014


Scene: small_moreguns | PSNR: 31.7584 | SSIM: 0.9148 | Warping Error: 0.0011: 100%|██████████| 11/11 [06:53<00:00, 37.57s/it]

small_moreguns: 31.7584/0.9148/0.0011
Average PSNR: 29.8324
Average SSIM: 0.8556
Average Warping Error: 0.0021





In [None]:
# Evaluate the QUIVER results: per-scene mean PSNR / SSIM of the output frames
# against the ground-truth frames (RGB, 512x512) plus the warping error of the
# output sequence, followed by the averages over all scenes.
cumulative_psnr, cumulative_ssim, cumulative_warping_error = [], [], []
pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = QUIVER_RESULTS + f"/{scene}_gt"
    gt_imgs = sorted(os.listdir(test_dir_gt))
    test_dir_out = QUIVER_RESULTS + f"/{scene}_out"
    out_imgs = sorted(os.listdir(test_dir_out))

    # Temporal metric first: it only needs the output directory.
    warping_error = compute_video_warping_error(test_dir_out)

    # Frames are paired by their position in the sorted listings.
    psnr_list, ssim_list = [], []
    for i, gt_name in enumerate(gt_imgs):
        gt_pil = Image.open(os.path.join(test_dir_gt, gt_name))
        gt_pil = gt_pil.convert("RGB").resize((512,512), Image.LANCZOS)
        out_pil = Image.open(os.path.join(test_dir_out, out_imgs[i]))
        out_pil = out_pil.convert("RGB").resize((512,512), Image.LANCZOS)

        # HWC uint8 -> 1xCxHxW float in [0, 1]
        gt_img = torch.tensor(np.array(gt_pil)).permute(2,0,1).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_pil)).permute(2,0,1).float().unsqueeze(0).to(device) / 255.0

        psnr, ssim = compute_full_reference_metrics(gt_img, out_img)
        psnr_list.append(psnr)
        ssim_list.append(ssim)

    pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | Warping Error: {warping_error:.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{warping_error:.4f}")

    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    cumulative_warping_error.append(warping_error.cpu().item())
    pbar.update(1)
pbar.close()

print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")

  0%|          | 0/11 [01:24<?, ?it/s]
  warp_img2 = torch.tensor(warp_img2).float().to(device)
Scene: xvfi_boat | PSNR: 20.8117 | SSIM: 0.6389 | Warping Error: 0.0009:   9%|▉         | 1/11 [00:07<01:13,  7.30s/it]

xvfi_boat: 20.8117/0.6389/0.0009


Scene: xvfi_train | PSNR: 25.1905 | SSIM: 0.8624 | Warping Error: 0.0018:  18%|█▊        | 2/11 [00:14<01:05,  7.30s/it]

xvfi_train: 25.1905/0.8624/0.0018


Scene: test_00014 | PSNR: 26.0922 | SSIM: 0.8566 | Warping Error: 0.0015:  27%|██▋       | 3/11 [01:00<03:20, 25.01s/it]

test_00014: 26.0922/0.8566/0.0015


Scene: test_00020 | PSNR: 23.3382 | SSIM: 0.9169 | Warping Error: 0.0028:  36%|███▋      | 4/11 [01:46<03:52, 33.22s/it]

test_00020: 23.3382/0.9169/0.0028


Scene: test_00021 | PSNR: 23.4950 | SSIM: 0.8674 | Warping Error: 0.0011:  45%|████▌     | 5/11 [02:32<03:47, 37.90s/it]

test_00021: 23.4950/0.8674/0.0011


Scene: test_00031 | PSNR: 27.3181 | SSIM: 0.8581 | Warping Error: 0.0015:  55%|█████▍    | 6/11 [05:42<07:27, 89.53s/it]

test_00031: 27.3181/0.8581/0.0015


Scene: small_jetengine | PSNR: 20.5879 | SSIM: 0.8611 | Warping Error: 0.0017:  64%|██████▎   | 7/11 [06:02<04:27, 66.96s/it]

small_jetengine: 20.5879/0.8611/0.0017


Scene: small_tank | PSNR: 22.3909 | SSIM: 0.8533 | Warping Error: 0.0022:  73%|███████▎  | 8/11 [06:09<02:23, 47.76s/it]     

small_tank: 22.3909/0.8533/0.0022


Scene: small_explosion_0001 | PSNR: 17.4564 | SSIM: 0.5725 | Warping Error: 0.0021:  82%|████████▏ | 9/11 [06:23<01:14, 37.27s/it]

small_explosion_0001: 17.4564/0.5725/0.0021


Scene: small_padlock | PSNR: 24.8264 | SSIM: 0.9455 | Warping Error: 0.0016:  91%|█████████ | 10/11 [06:30<00:27, 27.80s/it]      

small_padlock: 24.8264/0.9455/0.0016


Scene: small_moreguns | PSNR: 15.2187 | SSIM: 0.7180 | Warping Error: 0.0019: 100%|██████████| 11/11 [07:13<00:00, 39.37s/it]

small_moreguns: 15.2187/0.7180/0.0019
Average PSNR: 22.4296
Average SSIM: 0.8137
Average Warping Error: 0.0017





In [None]:
# Evaluate the QBP results: outputs are read from QBP_RESULTS/<scene> and scored
# against the ground-truth frames stored next to the SD2.1 burst results. PSNR /
# SSIM are computed on single-channel ('L') 512x512 images; the warping error is
# computed on the output sequence. Averages over all scenes are printed last.
cumulative_psnr, cumulative_ssim, cumulative_warping_error = [], [], []
pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = SD21_burst_results + f"/{scene}_gt"
    gt_imgs = sorted(os.listdir(test_dir_gt))
    test_dir_out = QBP_RESULTS + f"/{scene}"
    out_imgs = sorted(os.listdir(test_dir_out))

    # Temporal metric first: it only needs the output directory.
    warping_error = compute_video_warping_error(test_dir_out)

    # Frames are paired by their position in the sorted listings.
    psnr_list, ssim_list = [], []
    for i, gt_name in enumerate(gt_imgs):
        gt_pil = Image.open(os.path.join(test_dir_gt, gt_name))
        gt_pil = gt_pil.convert('L').resize((512,512), Image.LANCZOS)
        out_pil = Image.open(os.path.join(test_dir_out, out_imgs[i]))
        out_pil = out_pil.convert('L').resize((512,512), Image.LANCZOS)

        # HxW uint8 -> 1xHxW float in [0, 1]; the batch axis is added at the call.
        gt_img = torch.tensor(np.array(gt_pil)).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_pil)).float().unsqueeze(0).to(device) / 255.0

        psnr, ssim = compute_full_reference_metrics(gt_img.unsqueeze(0), out_img.unsqueeze(0))
        psnr_list.append(psnr)
        ssim_list.append(ssim)

    pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | Warping Error: {warping_error:.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{warping_error:.4f}")

    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    cumulative_warping_error.append(warping_error.detach().cpu().item())
    pbar.update(1)
pbar.close()

print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")

Scene: xvfi_train | PSNR: 12.8230 | SSIM: 0.2287 | Warping Error: 0.0003:  18%|█▊        | 2/11 [00:45<03:24, 22.69s/it]
  warp_img2 = torch.tensor(warp_img2).float().to(device)


xvfi_boat: 10.9781/0.4913/0.0007




xvfi_train: 13.0492/0.2487/0.0003




test_00014: 13.2860/0.5317/0.0006




test_00020: 18.6312/0.7070/0.0019




test_00021: 13.3375/0.4337/0.0007




test_00031: 14.0013/0.4696/0.0007




small_jetengine: 14.7061/0.1623/0.0002




small_tank: 8.8698/0.4778/0.0023




small_explosion_0001: 19.7825/0.1607/0.0004




small_padlock: 11.2782/0.5458/0.0005


Scene: small_moreguns | PSNR: 9.2624 | SSIM: 0.6958 | Warping Error: 0.0024: 100%|██████████| 11/11 [12:52<00:00, 70.22s/it]

small_moreguns: 9.2624/0.6958/0.0024
Average PSNR: 13.3802
Average SSIM: 0.4477
Average Warping Error: 0.0010





In [None]:
# Per-scene PSNR / SSIM / warping-error evaluation followed by averages over scenes.
# NOTE(review): SD21_Stage2_results is assigned here but never read in this cell;
# the outputs are still taken from QBP_RESULTS and converted to grayscale, so as
# written this cell repeats the QBP evaluation. Presumably test_dir_out (and maybe
# the colour mode) should point at the stage-2 results -- the expected directory
# layout is not visible here, so confirm before changing.
SD21_Stage2_results = "/nobackup1/aryan/results/sd21/burst_testset_evaluation_s2"
cumulative_psnr = []
cumulative_ssim = []
cumulative_warping_error = []
pbar = tqdm(total=len(scenes))
for scene in scenes:
    # Ground truth comes from the SD2.1 burst results directory.
    test_dir_gt = SD21_burst_results + f"/{scene}_gt"
    gt_imgs = os.listdir(test_dir_gt)
    gt_imgs.sort()
    # NOTE(review): reads QBP outputs, not SD21_Stage2_results (see header note).
    test_dir_out = QBP_RESULTS + f"/{scene}"
    out_imgs = os.listdir(test_dir_out)
    out_imgs.sort()
    # test_dir_lq = SD21_burst_results + f"/{scene}_lq"
    psnr_list = []
    ssim_list = []
    # Temporal metric is computed once per scene on the output directory.
    warping_error = compute_video_warping_error(test_dir_out)
    # Frames are paired by index in the sorted listings; out_imgs must have at
    # least as many entries as gt_imgs or indexing fails.
    for i in range(len(gt_imgs)):
        gt_img = Image.open(os.path.join(test_dir_gt, gt_imgs[i])).convert('L').resize((512,512), Image.LANCZOS)
        out_img = Image.open(os.path.join(test_dir_out, out_imgs[i])).convert('L').resize((512,512), Image.LANCZOS)
        # HxW uint8 -> 1xHxW float in [0, 1]; the batch axis is added at the call below.
        gt_img = torch.tensor(np.array(gt_img)).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_img)).float().unsqueeze(0).to(device) / 255.0
        psnr, ssim = compute_full_reference_metrics(gt_img.unsqueeze(0), out_img.unsqueeze(0))
        # print( psnr, ssim)
        psnr_list.append(psnr)
        ssim_list.append(ssim)
    pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | Warping Error: {warping_error:.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{warping_error:.4f}")
    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    # Store the warping error as a plain Python float.
    cumulative_warping_error.append(warping_error.detach().cpu().item())
    pbar.update(1)
pbar.close()

# Averages of the per-scene means over all scenes.
print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")