In [1]:
import torch
import numpy as np
from PIL import Image
import os
import torch.nn.functional as F
from tqdm import tqdm
from scipy import interpolate
from pathlib import Path
import matplotlib.pyplot as plt
import cv2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### E* Metric:

In [2]:
import gqvr.model.warp_utils as warp_utils
from gqvr.model.core_raft.utils.utils import InputPadder
from gqvr.model.core_raft.raft import RAFT

class RAFT_args:
    mixed_precision = True
    small = False
    alternate_corr = True # Reduces VRAM significantly in forward pass
    dropout = False

raft_args = RAFT_args()
raft_model = RAFT(raft_args)
raft_things_dict = torch.load("./pretrained_ckpts/models/raft-things.pth")
corrected_state_dict = {}
for k, v in raft_things_dict.items():
    k2 = ".".join(k.split(".")[1:])
    corrected_state_dict[k2] = v

raft_model.load_state_dict(corrected_state_dict)
raft_model.eval().requires_grad_(False).to(device)

def compute_video_warping_error(video_path):
    frames = os.listdir(video_path)
    frames.sort()
    frames = [np.array(Image.open(os.path.join(video_path, f)).convert('RGB').resize((512,512), Image.LANCZOS)) for f in frames]

    warping_error = 0
    err = 0
    
    Num = len(frames)
    tensor_frames = torch.stack([torch.from_numpy(frame) for frame in frames])
    # for i in range(Num):
    N = len(tensor_frames)
    indices = torch.linspace(0, N - 1, Num).long()
    extracted_frames = torch.index_select(tensor_frames, 0, indices)
    with torch.no_grad():
        for i in range(Num - 1):
            frame1 = extracted_frames[i]
            frame2 = extracted_frames[i + 1]

            # Calculate optical flow using Farneback method
            img1 = frame1.permute(2,0,1).float().unsqueeze(0).to(device)/ 255.0
            img2 = frame2.permute(2,0,1).float().unsqueeze(0).to(device)/ 255.0
            # img1 = torch.tensor(img2tensor(frame1)).float().to(device)
            # img2 = torch.tensor(img2tensor(frame2)).float().to(device)

            # Downsample the images by a factor of 2
            img1 = F.interpolate(img1, scale_factor=0.5, mode='bilinear', align_corners=False)
            img2 = F.interpolate(img2, scale_factor=0.5, mode='bilinear', align_corners=False)

            padder = InputPadder(img1.shape)
            img1, img2 = padder.pad(img1, img2)

            ### compute fw flow
            
            _, fw_flow = raft_model(img1, img2, iters=20, test_mode=True) # with optical flow model: RAFT
            fw_flow = warp_utils.tensor2img(fw_flow)
            # Clear cache and temporary data
            torch.cuda.empty_cache()

            ### compute bw flow
            _, bw_flow = raft_model(img2, img1, iters=20, test_mode=True) # with optical flow model: RAFT
            bw_flow = warp_utils.tensor2img(bw_flow)
            torch.cuda.empty_cache()

            ### compute occlusion
            fw_occ, warp_img2 = warp_utils.detect_occlusion(bw_flow, fw_flow, img2)
            warp_img2 = torch.tensor(warp_img2).float().to(device)
            fw_occ = torch.tensor(fw_occ).float().to(device)

            ### load flow
            flow = fw_flow

            ### load occlusion mask
            occ_mask = fw_occ
            noc_mask = 1 - occ_mask

            # ipdb.set_trace()   
            diff = (warp_img2- img1) * noc_mask
            diff_squared = diff ** 2

            
            # Calculate the sum and mean
            N = torch.sum(noc_mask)
            if N == 0:
                N = diff_squared.numel()
            # ipdb.set_trace()
            err += torch.sum(diff_squared) / N

    warping_error = err / (len(extracted_frames) - 1)

    return warping_error

use sdp attention as default
keep default attention mode


  from .autonotebook import tqdm as notebook_tqdm
  import pkg_resources


## PSNR/SSIM from piq

In [3]:
import piq
def compute_full_reference_metrics(gt_img, out_img):
    # PSNR / SSIM
    psnr = piq.psnr(out_img, gt_img, data_range=1., reduction='none')
    ssim = piq.ssim(out_img, gt_img, data_range=1.) 
    lpips = piq.LPIPS(reduction='none')(out_img, gt_img)
    return psnr.item(), ssim.item(), lpips.item()

## PATHs

In [4]:
SD21_burst_results = "/nobackup1/aryan/results/sd21_burst"
QUIVER_RESULTS = "/nobackup1/aryan/results/QUIVER"
QBP_RESULTS = "/nobackup1/aryan/results/QBP_BM3D"

scenes = ['xvfi_boat', 'xvfi_train', 
          'test_00014', 'test_00020', 'test_00021', 'test_00031',
          'small_jetengine', 'small_tank', 'small_explosion_0001', 'small_padlock', 'small_moreguns']

In [11]:
cumulative_psnr = []
cumulative_ssim = []
cumulative_lpips = []
cumulative_warping_error = []
# pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = SD21_burst_results + f"/{scene}_gt"
    gt_imgs = os.listdir(test_dir_gt)
    gt_imgs.sort()
    test_dir_out = SD21_burst_results + f"/{scene}_out"
    out_imgs = os.listdir(test_dir_out)
    out_imgs.sort()
    # test_dir_lq = SD21_burst_results + f"/{scene}_lq"
    psnr_list = []
    ssim_list = []
    lpips_list = []
    # warping_error = compute_video_warping_error(test_dir_out)
    for i in range(len(gt_imgs)):
        gt_img = Image.open(os.path.join(test_dir_gt, gt_imgs[i])).convert("RGB").resize((512,512), Image.LANCZOS)
        out_img = Image.open(os.path.join(test_dir_out, out_imgs[i])).convert("RGB").resize((512,512), Image.LANCZOS)
        gt_img = torch.tensor(np.array(gt_img)).permute(2,0,1).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_img)).permute(2,0,1).float().unsqueeze(0).to(device) / 255.0
        psnr, ssim, lpips = compute_full_reference_metrics(gt_img, out_img)
        # print( psnr, ssim)
        psnr_list.append(psnr)
        ssim_list.append(ssim)
        lpips_list.append(lpips)
    # pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | LPIPS: {np.mean(lpips_list):.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{np.mean(lpips_list):.4f}")
    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    cumulative_lpips.append(np.mean(lpips_list))
    # cumulative_warping_error.append(warping_error.cpu().item())
    # pbar.update(1)
# pbar.close()

print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average LPIPS: {np.mean(cumulative_lpips):.4f}")
# print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")



xvfi_boat: 27.5939/0.7163/0.4278
xvfi_train: 24.0497/0.7071/0.4400
test_00014: 30.2209/0.8299/0.3442
test_00020: 32.0894/0.9449/0.2055
test_00021: 32.0890/0.8921/0.3251
test_00031: 30.4578/0.8452/0.3102
small_jetengine: 30.1210/0.8531/0.3631
small_tank: 26.1159/0.8642/0.3162
small_explosion_0001: 28.9755/0.8798/0.3917
small_padlock: 34.6854/0.9637/0.2580
small_moreguns: 31.7584/0.9148/0.2486
Average PSNR: 29.8324
Average SSIM: 0.8556
Average LPIPS: 0.3300


In [10]:
cumulative_psnr = []
cumulative_ssim = []
cumulative_lpips = []
# cumulative_warping_error = []
# pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = QUIVER_RESULTS + f"/{scene}_gt"
    gt_imgs = os.listdir(test_dir_gt)
    gt_imgs.sort()
    test_dir_out = QUIVER_RESULTS + f"/{scene}_out"
    out_imgs = os.listdir(test_dir_out)
    out_imgs.sort()
    # test_dir_lq = SD21_burst_results + f"/{scene}_lq"
    psnr_list = []
    ssim_list = []
    lpips_list = []
    # warping_error = compute_video_warping_error(test_dir_out)
    for i in range(len(gt_imgs)):
        gt_img = Image.open(os.path.join(test_dir_gt, gt_imgs[i])).convert("RGB").resize((512,512), Image.LANCZOS)
        out_img = Image.open(os.path.join(test_dir_out, out_imgs[i])).convert("RGB").resize((512,512), Image.LANCZOS)
        gt_img = torch.tensor(np.array(gt_img)).permute(2,0,1).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_img)).permute(2,0,1).float().unsqueeze(0).to(device) / 255.0
        psnr, ssim, lpips = compute_full_reference_metrics(gt_img, out_img)
        # print( psnr, ssim)
        psnr_list.append(psnr)
        ssim_list.append(ssim)
        lpips_list.append(lpips)
    # pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | LPIPS: {np.mean(lpips_list):.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{np.mean(lpips_list):.4f}")
    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    cumulative_lpips.append(np.mean(lpips_list))
    # cumulative_warping_error.append(warping_error.cpu().item())
#     pbar.update(1)
# pbar.close()

print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average LPIPS: {np.mean(cumulative_lpips):.4f}")
# print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")



xvfi_boat: 20.8117/0.6389/0.6878
xvfi_train: 25.1905/0.8624/0.4623
test_00014: 26.0922/0.8566/0.4293
test_00020: 23.3382/0.9169/0.2798
test_00021: 23.4950/0.8674/0.3236
test_00031: 27.3181/0.8581/0.4303
small_jetengine: 20.5879/0.8611/0.3755
small_tank: 22.3909/0.8533/0.4166
small_explosion_0001: 17.4564/0.5725/0.4900
small_padlock: 24.8264/0.9455/0.3319
small_moreguns: 15.2187/0.7180/0.4932
Average PSNR: 22.4296
Average SSIM: 0.8137
Average LPIPS: 0.4291


In [13]:
cumulative_psnr = []
cumulative_ssim = []
cumulative_lpips = []
cumulative_warping_error = []
# pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = SD21_burst_results + f"/{scene}_gt"
    gt_imgs = os.listdir(test_dir_gt)
    gt_imgs.sort()
    test_dir_out = QBP_RESULTS + f"/{scene}"
    out_imgs = os.listdir(test_dir_out)
    out_imgs.sort()
    # test_dir_lq = SD21_burst_results + f"/{scene}_lq"
    psnr_list = []
    ssim_list = []
    lpips_list = []
    # warping_error = compute_video_warping_error(test_dir_out)
    for i in range(len(gt_imgs)):
        gt_img = Image.open(os.path.join(test_dir_gt, gt_imgs[i])).convert('L').resize((512,512), Image.LANCZOS)
        out_img = Image.open(os.path.join(test_dir_out, out_imgs[i])).convert('L').resize((512,512), Image.LANCZOS)
        gt_img = torch.tensor(np.array(gt_img)).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_img)).float().unsqueeze(0).to(device) / 255.0
        psnr, ssim, thisLpips = compute_full_reference_metrics(gt_img.unsqueeze(0), out_img.unsqueeze(0))
        # print( psnr, ssim)
        psnr_list.append(psnr)
        ssim_list.append(ssim)
        lpips_list.append(thisLpips)
    # pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | LPIPS: {np.mean(lpips):.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{np.mean(lpips_list):.4f}")
    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    cumulative_lpips.append(np.mean(lpips_list))
    # cumulative_warping_error.append(warping_error.detach().cpu().item())
    # pbar.update(1)
# pbar.close()

print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average LPIPS: {np.mean(cumulative_lpips):.4f}")
# print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")



xvfi_boat: 8.3300/0.3637/0.6899
xvfi_train: 12.5565/0.1992/0.6039
test_00014: 12.7536/0.4321/0.5022
test_00020: 13.3873/0.6029/0.3731
test_00021: 18.2547/0.4847/0.4233
test_00031: 11.3371/0.3526/0.5730
small_jetengine: 14.4696/0.1396/0.6288
small_tank: 8.8035/0.4691/0.3818
small_explosion_0001: 17.9341/0.1378/0.4911
small_padlock: 10.4272/0.4686/0.3911
small_moreguns: 9.2464/0.6936/0.3986
Average PSNR: 12.5000
Average SSIM: 0.3949
Average LPIPS: 0.4961


In [5]:
SD21_Stage2_results = "/nobackup1/aryan/results/sd21/burst_testset_evaluation_s2"
cumulative_psnr = []
cumulative_ssim = []
cumulative_warping_error = []
pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = SD21_burst_results + f"/{scene}_gt"
    gt_imgs = os.listdir(test_dir_gt)
    gt_imgs.sort()
    test_dir_out = SD21_Stage2_results + f"/{scene}"
    out_imgs = os.listdir(test_dir_out)
    out_imgs.sort()
    # test_dir_lq = SD21_burst_results + f"/{scene}_lq"
    psnr_list = []
    ssim_list = []
    warping_error = compute_video_warping_error(test_dir_out)
    for i in range(len(gt_imgs)):
        gt_img = Image.open(os.path.join(test_dir_gt, gt_imgs[i])).convert('L').resize((512,512), Image.LANCZOS)
        out_img = Image.open(os.path.join(test_dir_out, out_imgs[i])).convert('L').resize((512,512), Image.LANCZOS)
        gt_img = torch.tensor(np.array(gt_img)).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_img)).float().unsqueeze(0).to(device) / 255.0
        psnr, ssim = compute_full_reference_metrics(gt_img.unsqueeze(0), out_img.unsqueeze(0))
        # print( psnr, ssim)
        psnr_list.append(psnr)
        ssim_list.append(ssim)
    pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | Warping Error: {warping_error:.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{warping_error:.4f}")
    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    cumulative_warping_error.append(warping_error.detach().cpu().item())
    pbar.update(1)
pbar.close()

print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  warp_img2 = torch.tensor(warp_img2).float().to(device)
Scene: xvfi_boat | PSNR: 27.4399 | SSIM: 0.7111 | Warping Error: 0.0018:   9%|▉         | 1/11 [00:07<01:15,  7.52s/it]

xvfi_boat: 27.4399/0.7111/0.0018


Scene: xvfi_train | PSNR: 24.2301 | SSIM: 0.7153 | Warping Error: 0.0036:  18%|█▊        | 2/11 [00:15<01:10,  7.82s/it]

xvfi_train: 24.2301/0.7153/0.0036


Scene: test_00014 | PSNR: 30.1029 | SSIM: 0.8293 | Warping Error: 0.0020:  27%|██▋       | 3/11 [01:06<03:38, 27.29s/it]

test_00014: 30.1029/0.8293/0.0020


Scene: test_00020 | PSNR: 32.3760 | SSIM: 0.9463 | Warping Error: 0.0012:  36%|███▋      | 4/11 [01:52<04:05, 35.05s/it]

test_00020: 32.3760/0.9463/0.0012


Scene: test_00021 | PSNR: 32.0519 | SSIM: 0.8960 | Warping Error: 0.0012:  45%|████▌     | 5/11 [02:45<04:07, 41.24s/it]

test_00021: 32.0519/0.8960/0.0012


Scene: test_00031 | PSNR: 30.0710 | SSIM: 0.8439 | Warping Error: 0.0019:  55%|█████▍    | 6/11 [06:14<08:11, 98.26s/it]

test_00031: 30.0710/0.8439/0.0019


Scene: small_jetengine | PSNR: 30.1096 | SSIM: 0.8549 | Warping Error: 0.0029:  64%|██████▎   | 7/11 [06:37<04:55, 73.85s/it]

small_jetengine: 30.1096/0.8549/0.0029


Scene: small_tank | PSNR: 26.2065 | SSIM: 0.8655 | Warping Error: 0.0030:  73%|███████▎  | 8/11 [06:45<02:38, 52.77s/it]     

small_tank: 26.2065/0.8655/0.0030


Scene: small_explosion_0001 | PSNR: 29.1489 | SSIM: 0.8912 | Warping Error: 0.0049:  82%|████████▏ | 9/11 [07:01<01:22, 41.36s/it]

small_explosion_0001: 29.1489/0.8912/0.0049


Scene: small_padlock | PSNR: 34.8471 | SSIM: 0.9663 | Warping Error: 0.0014:  91%|█████████ | 10/11 [07:09<00:30, 30.92s/it]      

small_padlock: 34.8471/0.9663/0.0014


Scene: small_moreguns | PSNR: 31.5905 | SSIM: 0.9074 | Warping Error: 0.0011: 100%|██████████| 11/11 [07:57<00:00, 43.43s/it]

small_moreguns: 31.5905/0.9074/0.0011
Average PSNR: 29.8340
Average SSIM: 0.8570
Average Warping Error: 0.0023





In [None]:
SD21_Stage1_results = "/nobackup1/aryan/results/sd21/burst_testset_evaluation_s1"
cumulative_psnr = []
cumulative_ssim = []
cumulative_warping_error = []
pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = QUIVER_RESULTS + f"/{scene}_gt"
    gt_imgs = os.listdir(test_dir_gt)
    gt_imgs.sort()
    test_dir_out = SD21_Stage1_results + f"/{scene}"
    out_imgs = os.listdir(test_dir_out)
    out_imgs.sort()
    # test_dir_lq = SD21_burst_results + f"/{scene}_lq"
    psnr_list = []
    ssim_list = []
    warping_error = compute_video_warping_error(test_dir_out)
    for i in range(len(gt_imgs)):
        gt_img = Image.open(os.path.join(test_dir_gt, gt_imgs[i])).convert('L').resize((1024,1024), Image.LANCZOS)
        out_img = Image.open(os.path.join(test_dir_out, out_imgs[i])).convert('L').resize((1024,1024), Image.LANCZOS)
        gt_img = torch.tensor(np.array(gt_img)).float().unsqueeze(0).to(device) / 255.0
        out_img = torch.tensor(np.array(out_img)).float().unsqueeze(0).to(device) / 255.0
        psnr, ssim = compute_full_reference_metrics(gt_img.unsqueeze(0), out_img.unsqueeze(0))
        # print( psnr, ssim)
        psnr_list.append(psnr)
        ssim_list.append(ssim)
    pbar.set_description(f"Scene: {scene} | PSNR: {np.mean(psnr_list):.4f} | SSIM: {np.mean(ssim_list):.4f} | Warping Error: {warping_error:.4f}")
    print(f"{scene}: {np.mean(psnr_list):.4f}/{np.mean(ssim_list):.4f}/{warping_error:.4f}")
    cumulative_psnr.append(np.mean(psnr_list))
    cumulative_ssim.append(np.mean(ssim_list))
    cumulative_warping_error.append(warping_error.detach().cpu().item())
    pbar.update(1)
pbar.close()

print(f"Average PSNR: {np.mean(cumulative_psnr):.4f}")
print(f"Average SSIM: {np.mean(cumulative_ssim):.4f}")
print(f"Average Warping Error: {np.mean(cumulative_warping_error):.4f}")

Scene: xvfi_train | PSNR: 27.3034 | SSIM: 0.7572 | Warping Error: 0.0032:  18%|█▊        | 2/11 [00:32<02:26, 16.29s/it]
  warp_img2 = torch.tensor(warp_img2).float().to(device)
Scene: xvfi_boat | PSNR: 16.9676 | SSIM: 0.7116 | Warping Error: 0.0013:   9%|▉         | 1/11 [00:07<01:11,  7.10s/it]

xvfi_boat: 16.9676/0.7116/0.0013


Scene: xvfi_train | PSNR: 24.0845 | SSIM: 0.7626 | Warping Error: 0.0032:  18%|█▊        | 2/11 [00:14<01:04,  7.13s/it]

xvfi_train: 24.0845/0.7626/0.0032


Scene: test_00014 | PSNR: 22.8988 | SSIM: 0.8112 | Warping Error: 0.0013:  27%|██▋       | 3/11 [00:58<03:14, 24.30s/it]

test_00014: 22.8988/0.8112/0.0013


Scene: test_00020 | PSNR: 28.8180 | SSIM: 0.9455 | Warping Error: 0.0010:  36%|███▋      | 4/11 [01:43<03:45, 32.29s/it]

test_00020: 28.8180/0.9455/0.0010


Scene: test_00021 | PSNR: 15.2004 | SSIM: 0.6951 | Warping Error: 0.0010:  45%|████▌     | 5/11 [02:28<03:40, 36.79s/it]

test_00021: 15.2004/0.6951/0.0010


Scene: test_00031 | PSNR: 23.6604 | SSIM: 0.8440 | Warping Error: 0.0016:  55%|█████▍    | 6/11 [05:32<07:14, 86.96s/it]

test_00031: 23.6604/0.8440/0.0016


Scene: small_jetengine | PSNR: 30.9402 | SSIM: 0.8831 | Warping Error: 0.0026:  64%|██████▎   | 7/11 [05:52<04:20, 65.07s/it]

small_jetengine: 30.9402/0.8831/0.0026


Scene: small_tank | PSNR: 23.4911 | SSIM: 0.8795 | Warping Error: 0.0032:  73%|███████▎  | 8/11 [05:59<02:19, 46.41s/it]     

small_tank: 23.4911/0.8795/0.0032


Scene: small_explosion_0001 | PSNR: 29.4817 | SSIM: 0.9326 | Warping Error: 0.0048:  82%|████████▏ | 9/11 [06:12<01:12, 36.22s/it]

small_explosion_0001: 29.4817/0.9326/0.0048


Scene: small_padlock | PSNR: 35.4416 | SSIM: 0.9707 | Warping Error: 0.0014:  91%|█████████ | 10/11 [06:19<00:27, 27.03s/it]      

small_padlock: 35.4416/0.9707/0.0014


Scene: small_moreguns | PSNR: 30.9350 | SSIM: 0.9499 | Warping Error: 0.0013: 100%|██████████| 11/11 [07:00<00:00, 38.26s/it]

small_moreguns: 30.9350/0.9499/0.0013
Average PSNR: 25.6290
Average SSIM: 0.8532
Average Warping Error: 0.0021





: 

In [19]:
EVAL_DIR = "/media/agarg54/Extreme SSD/code/gQVR/eval_s1_vs_s2_vs_s3"
scenes = ['test_00014', 'test_00021', 'test_00031', 'full_tank', 'small_moreguns']

In [21]:
def normalize(tens):
    return (tens - tens.min()) / (tens.max() - tens.min())

In [26]:
cumulative_psnr_s3 = []
cumulative_ssim_s3 = []
cumulative_warping_error_s3 = []

cumulative_psnr_s2 = []
cumulative_ssim_s2 = []
cumulative_warping_error_s2 = []

cumulative_psnr_s1 = []
cumulative_ssim_s1 = []
cumulative_warping_error_s1 = []

pbar = tqdm(total=len(scenes))
for scene in scenes:
    test_dir_gt = EVAL_DIR + f"/{scene}/gt"
    gt_imgs = os.listdir(test_dir_gt)
    gt_imgs.sort()
    test_dir_out_s3 = EVAL_DIR + f"/{scene}/out_s3"
    test_dir_out_s2 = EVAL_DIR + f"/{scene}/out_s2"
    test_dir_out_s1 = EVAL_DIR + f"/{scene}/out_s1"

    out_imgs_s3 = os.listdir(test_dir_out_s3)
    out_imgs_s3.sort()
    out_imgs_s2 = os.listdir(test_dir_out_s2)
    out_imgs_s2.sort()
    out_imgs_s1 = os.listdir(test_dir_out_s1)
    out_imgs_s1.sort()
    # test_dir_lq = SD21_burst_results + f"/{scene}_lq"
    psnr_list_s3 = []
    ssim_list_s3 = []
    psnr_list_s2 = []
    ssim_list_s2 = []
    psnr_list_s1 = []
    ssim_list_s1 = []

    if len(out_imgs_s3) > 1:
        warping_error_s3 = compute_video_warping_error(test_dir_out_s3)
        cumulative_warping_error_s3.append(warping_error_s3.detach().cpu().item())
        warping_error_s2 = compute_video_warping_error(test_dir_out_s2)
        cumulative_warping_error_s2.append(warping_error_s2.detach().cpu().item())
        warping_error_s1 = compute_video_warping_error(test_dir_out_s1)
        cumulative_warping_error_s1.append(warping_error_s1.detach().cpu().item())
    else:
        cumulative_warping_error_s1.append(0.)
        cumulative_warping_error_s2.append(0.)
        cumulative_warping_error_s3.append(0.)

    for i in range(len(gt_imgs)):
        gt_img = Image.open(os.path.join(test_dir_gt, gt_imgs[i])).convert('L').resize((512,512), Image.LANCZOS)
        out_img_s3 = Image.open(os.path.join(test_dir_out_s3, out_imgs_s3[i])).convert('L').resize((512,512), Image.LANCZOS)
        out_img_s2 = Image.open(os.path.join(test_dir_out_s2, out_imgs_s2[i])).convert('L').resize((512,512), Image.LANCZOS)
        out_img_s1 = Image.open(os.path.join(test_dir_out_s1, out_imgs_s1[i])).convert('L').resize((512,512), Image.LANCZOS)

        gt_img = torch.tensor(np.array(gt_img)).float().unsqueeze(0).to(device) / 255.0
        out_img_s3 = torch.tensor(np.array(out_img_s3)).float().unsqueeze(0).to(device) / 255.0
        out_img_s2 = torch.tensor(np.array(out_img_s2)).float().unsqueeze(0).to(device) / 255.0
        out_img_s1 = torch.tensor(np.array(out_img_s1)).float().unsqueeze(0).to(device) / 255.0

        gt_img, out_img_s2, out_img_s3, out_img_s1 = normalize(gt_img), normalize(out_img_s2), normalize(out_img_s3), normalize(out_img_s1)

        psnr_s3, ssim_s3 = compute_full_reference_metrics(gt_img.unsqueeze(0), out_img_s3.unsqueeze(0))
        psnr_s2, ssim_s2 = compute_full_reference_metrics(gt_img.unsqueeze(0), out_img_s2.unsqueeze(0))
        psnr_s1, ssim_s1 = compute_full_reference_metrics(gt_img.unsqueeze(0), out_img_s1.unsqueeze(0))
        # print( psnr, ssim)
        psnr_list_s3.append(psnr_s3)
        ssim_list_s3.append(ssim_s3)

        psnr_list_s2.append(psnr_s2)
        ssim_list_s2.append(ssim_s2)

        psnr_list_s1.append(psnr_s1)
        ssim_list_s1.append(ssim_s1)
    # pbar.set_description(f"Scene: {scene}")
    print(f"{scene} S3: {np.mean(psnr_list_s3):.4f}/{np.mean(ssim_list_s3):.4f}/{warping_error_s3:.4f}")
    print(f"{scene} S2: {np.mean(psnr_list_s2):.4f}/{np.mean(ssim_list_s2):.4f}/{warping_error_s2:.4f}")
    print(f"{scene} S1: {np.mean(psnr_list_s1):.4f}/{np.mean(ssim_list_s1):.4f}/{warping_error_s1:.4f}")
    cumulative_psnr_s3.append(np.mean(psnr_list_s3))
    cumulative_ssim_s3.append(np.mean(ssim_list_s3))
    cumulative_psnr_s2.append(np.mean(psnr_list_s2))
    cumulative_ssim_s2.append(np.mean(ssim_list_s2))
    cumulative_psnr_s1.append(np.mean(psnr_list_s1))
    cumulative_ssim_s1.append(np.mean(ssim_list_s1))
    pbar.update(1)
pbar.close()
print("-----------------------------------")
print("Ablation: S1 vs S2 vs S3 - On realistic sim bursts (PSNR/SSIM/E*)")
print("-----------------------------------")
print(f"Average PSNR S1: {np.mean(cumulative_psnr_s1):.4f}")
print(f"Average SSIM S1: {np.mean(cumulative_ssim_s1):.4f}")
print(f"Average Warping Error S1: {np.mean(cumulative_warping_error_s1)}")
print("-----------------------------------")
print(f"Average PSNR S2: {np.mean(cumulative_psnr_s2):.4f}")
print(f"Average SSIM S2: {np.mean(cumulative_ssim_s2):.4f}")
print(f"Average Warping Error S2: {np.mean(cumulative_warping_error_s2)}")
print("-----------------------------------")
print(f"Average PSNR S3: {np.mean(cumulative_psnr_s3):.4f}")
print(f"Average SSIM S3: {np.mean(cumulative_ssim_s3):.4f}")
print(f"Average Warping Error S3: {np.mean(cumulative_warping_error_s3)}")
# print(cumulative_warping_error_s2, cumulative_warping_error_s3)


  warp_img2 = torch.tensor(warp_img2).float().to(device)
 20%|██        | 1/5 [00:03<00:15,  3.98s/it]

test_00014 S3: 26.9673/0.8088/0.0044
test_00014 S2: 24.5517/0.7345/0.0043
test_00014 S1: 24.4701/0.7606/0.0042


 40%|████      | 2/5 [00:07<00:11,  3.94s/it]

test_00021 S3: 24.9667/0.9157/0.0034
test_00021 S2: 27.0705/0.9225/0.0046
test_00021 S1: 22.1235/0.9088/0.0044


 60%|██████    | 3/5 [00:35<00:29, 14.57s/it]

test_00031 S3: 29.4883/0.8344/0.0059
test_00031 S2: 28.3928/0.8133/0.0070
test_00031 S1: 12.6355/0.3451/0.0101
full_tank S3: 26.0583/0.8620/0.0059
full_tank S2: 19.2660/0.8164/0.0070
full_tank S1: 19.4205/0.8245/0.0101


100%|██████████| 5/5 [00:37<00:00,  7.43s/it]

small_moreguns S3: 30.6675/0.9262/0.0263
small_moreguns S2: 21.2880/0.9436/0.0267
small_moreguns S1: 21.5380/0.9553/0.0268
-----------------------------------
Ablation: S1 vs S2 vs S3 - On realistic sim bursts (PSNR/SSIM/E*)
-----------------------------------
Average PSNR S1: 20.0375
Average SSIM S1: 0.7589
Average Warping Error S1: 0.009088441077619792
-----------------------------------
Average PSNR S2: 24.1138
Average SSIM S2: 0.8461
Average Warping Error S2: 0.008507703337818385
-----------------------------------
Average PSNR S3: 27.6296
Average SSIM S3: 0.8694
Average Warping Error S3: 0.008005498722195626



