In [2]:
import cv2
import numpy as np
import torch
from src.test.video_codec import process
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf


def create_model(config_path):
    """Instantiate the model described by a YAML config file.

    Args:
        config_path: path to an OmegaConf-loadable YAML file; its ``model``
            section is handed to ``instantiate_from_config``.

    Returns:
        The object ``instantiate_from_config`` builds from that section.
    """
    model_cfg = OmegaConf.load(config_path).model
    return instantiate_from_config(model_cfg)

model = create_model('experiments/global_ablation/uni_v15.yaml').cpu()

# Load checkpoint weights.
# The model is still on CPU at this point, so map the tensors to CPU as well;
# load_state_dict copies them into the parameters, and .cuda() moves everything afterwards.
# NOTE(review): torch.load unpickles arbitrary objects - only use trusted checkpoints.
ckpt = torch.load('experiments/global_ablation/uni_cap.ckpt', map_location='cpu')
# Lightning-style checkpoints nest the weights under 'state_dict'; passing the outer
# dict with strict=False would silently load nothing. A bare state dict is used as-is.
if isinstance(ckpt, dict) and 'state_dict' in ckpt:
    state_dict = ckpt['state_dict']
else:
    state_dict = ckpt
# strict=False is kept (the checkpoint may not cover every module), but mismatches are
# reported instead of being silently ignored.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"Checkpoint mismatch: {len(incompatible.missing_keys)} missing keys, "
          f"{len(incompatible.unexpected_keys)} unexpected keys")
model = model.cuda().eval()
target_resolution = (512, 512)

No module 'xformers'. Proceeding without it.
UniControlNet: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /data/maryam.sana/anaconda3/envs/unicontrolwrap/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /data/maryam.sana/anaconda3/envs/unicontrolwrap/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth


In [4]:
import os
import cv2
import torch
import numpy as np
from pathlib import Path
from torchvision import transforms
from lpips import LPIPS
from pytorch_msssim import ms_ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import pandas as pd
from PIL import Image
from src.test.video_codec import process
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf

# Per-video metric accumulators.
psnr_values, ms_ssim_values, lpips_values, video_names = [], [], [], []

# LPIPS perceptual metric with the AlexNet backbone, placed on the GPU.
lpips_model = LPIPS(net='alex').cuda()


def _scale_to_255(x):
    """Rescale a ToTensor output (values in [0, 1]) back to the [0, 255] range."""
    return x * 255


# PIL image -> float tensor of size 512x512 with pixel values in [0, 255].
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    _scale_to_255,
])

# Per-video text prompt and sub-directory (relative to base_dir) for each test sequence.
videos = {
    "Beauty": dict(
        prompt="A beautiful blonde girl smiling with pink lipstick with black background",
        path="Beauty",
    ),
    "Jockey": dict(
        prompt=("A man riding a brown horse, galloping through a green race track. "
                "The man is wearing a yellow and red shirt and also a yellow hat"),
        path="Jockey",
    ),
    "Bosphorus": dict(
        prompt=("A man and a woman sitting together on a boat sailing in water. "
                "They are both wearing ties. There is also a red flag at end of boat"),
        path="Bosphorus",
    ),
}

# Root directory of the UVG data, relative to the notebook's working directory.
base_dir = Path('../Ultra_Perceptual_Video_Compression/data/UVG')

def _load_rgb(image_path, size):
    """Read an image from disk and return it as an RGB array resized to ``size``.

    Args:
        image_path: path to the image file (str or Path).
        size: target size passed to ``cv2.resize``, i.e. (width, height).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
            returns None instead of raising).
    """
    img = cv2.imread(str(image_path))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV decodes to BGR; the model inputs here are built in RGB order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return cv2.resize(img, size)


# Loop-invariant settings, hoisted out of the per-video loop.
target_resolution = (512, 512)
pred_dir = Path('experiments/global_ablation')
pred_dir.mkdir(parents=True, exist_ok=True)

for vid, vid_info in videos.items():
    try:
        # Use the per-video 'path' entry rather than assuming it equals the dict key.
        vid_dir = base_dir / vid_info['path']
        # Ground-truth frame; only used by the (currently disabled) metrics below.
        original_image_path = vid_dir / 'images' / 'frame_0001.png'
        ref_image_path = vid_dir / 'intra_frames' / 'decoded_q4' / 'decoded_frame_0000.png'
        flow_image_path = vid_dir / 'optical_flow' / 'optical_flow_gop_8' / 'flow_0000_0001.png'

        ref_img = _load_rgb(ref_image_path, target_resolution)
        flow_img = _load_rgb(flow_image_path, target_resolution)

        # NOTE(review): order is [flow, reference] - confirm it matches the
        # order of local conditions that process()/the model expect.
        local_images = [flow_img, ref_img]

        # Run inference
        pred = process(model, local_images, vid_info['prompt'])

        # NOTE(review): pred[0][0] is assumed to be a single RGB uint8 image;
        # confirm against process()'s return structure.
        pred_save_path = pred_dir / f'{vid}.png'
        if not cv2.imwrite(str(pred_save_path), cv2.cvtColor(pred[0][0], cv2.COLOR_RGB2BGR)):
            # cv2.imwrite reports failure via its return value, not an exception.
            raise IOError(f"Failed to write {pred_save_path}")

        # # Metrics (disabled).
        # # NOTE(review): the saved image above uses pred[0][0] but this uses pred[0];
        # # confirm which one is the predicted image before re-enabling.
        # original_image = Image.open(original_image_path).convert("RGB")
        # pred_image = Image.fromarray(pred[0].astype(np.uint8))

        # original_tensor = transform(original_image).unsqueeze(0).to('cuda')
        # pred_tensor = transform(pred_image).unsqueeze(0).to('cuda')

        # # PSNR
        # psnr_value = psnr(original_tensor.squeeze().cpu().numpy(), pred_tensor.squeeze().cpu().numpy(), data_range=255)
        # if psnr_value > 1000:
        #     continue
        # psnr_values.append(psnr_value)

        # # MS-SSIM
        # ms_ssim_value = ms_ssim(original_tensor, pred_tensor, data_range=255, size_average=True).item()
        # ms_ssim_values.append(ms_ssim_value)

        # # LPIPS - expects inputs in [-1, 1]; normalize=True rescales the [0, 1] tensors.
        # lpips_value = lpips_model(original_tensor / 255.0, pred_tensor / 255.0, normalize=True).item()
        # lpips_values.append(lpips_value)

        # video_names.append(vid)

    except Exception as e:
        # Best-effort: report the failure and move on to the next video.
        print(f"Error processing {vid}: {e}")

# # Create dataframe
# df = pd.DataFrame({
#     'Video': video_names,
#     'PSNR': psnr_values,
#     'MS-SSIM': ms_ssim_values,
#     'LPIPS': lpips_values
# })

# # Display dataframe (not written to disk)
# df

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /data/maryam.sana/anaconda3/envs/unicontrolwrap/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth


Global seed set to 42


Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]
Selected alphas for ddim sampler: a_t: tensor([0.9983, 0.9804, 0.9609, 0.9398, 0.9171, 0.8930, 0.8674, 0.8404, 0.8121,
        0.7827, 0.7521, 0.7207, 0.6885, 0.6557, 0.6224, 0.5888, 0.5551, 0.5215,
        0.4882, 0.4552, 0.4229, 0.3913, 0.3605, 0.3308, 0.3023, 0.2750, 0.2490,
        0.2245, 0.2014, 0.1799, 0.1598, 0.1413, 0.1243, 0.1087, 0.0946, 0.0819,
        0.0705, 0.0604, 0.0514, 0.0435, 0.0365, 0.0305, 0.0254, 0.0210, 0.0172,
        0.0140, 0.0113, 0.0091, 0.0073, 0.0058]); a_(t-1): [0.99914998 0.99829602 0.98038077 0.96087277 0.93978298 0.91713792
 0.89298052 0.86737001 0.84038192 0.81210774 0.78265446 0.75214338
 0.72070938 0.68849909 0.65566933 0.62238538 0.58881873 0.55514455
 0.52153981 0.4881804  0.45523876 0.42288151 0.39126703 0.36

DDIM Sampler: 100%|██████████████████████████████████████████████████████| 50/50 [00:12<00:00,  4.16it/s]
Global seed set to 42


Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]
Selected alphas for ddim sampler: a_t: tensor([0.9983, 0.9804, 0.9609, 0.9398, 0.9171, 0.8930, 0.8674, 0.8404, 0.8121,
        0.7827, 0.7521, 0.7207, 0.6885, 0.6557, 0.6224, 0.5888, 0.5551, 0.5215,
        0.4882, 0.4552, 0.4229, 0.3913, 0.3605, 0.3308, 0.3023, 0.2750, 0.2490,
        0.2245, 0.2014, 0.1799, 0.1598, 0.1413, 0.1243, 0.1087, 0.0946, 0.0819,
        0.0705, 0.0604, 0.0514, 0.0435, 0.0365, 0.0305, 0.0254, 0.0210, 0.0172,
        0.0140, 0.0113, 0.0091, 0.0073, 0.0058]); a_(t-1): [0.99914998 0.99829602 0.98038077 0.96087277 0.93978298 0.91713792
 0.89298052 0.86737001 0.84038192 0.81210774 0.78265446 0.75214338
 0.72070938 0.68849909 0.65566933 0.62238538 0.58881873 0.55514455
 0.52153981 0.4881804  0.45523876 0.42288151 0.39126703 0.36

DDIM Sampler: 100%|██████████████████████████████████████████████████████| 50/50 [00:12<00:00,  4.12it/s]
Global seed set to 42


Selected timesteps for ddim sampler: [  1  21  41  61  81 101 121 141 161 181 201 221 241 261 281 301 321 341
 361 381 401 421 441 461 481 501 521 541 561 581 601 621 641 661 681 701
 721 741 761 781 801 821 841 861 881 901 921 941 961 981]
Selected alphas for ddim sampler: a_t: tensor([0.9983, 0.9804, 0.9609, 0.9398, 0.9171, 0.8930, 0.8674, 0.8404, 0.8121,
        0.7827, 0.7521, 0.7207, 0.6885, 0.6557, 0.6224, 0.5888, 0.5551, 0.5215,
        0.4882, 0.4552, 0.4229, 0.3913, 0.3605, 0.3308, 0.3023, 0.2750, 0.2490,
        0.2245, 0.2014, 0.1799, 0.1598, 0.1413, 0.1243, 0.1087, 0.0946, 0.0819,
        0.0705, 0.0604, 0.0514, 0.0435, 0.0365, 0.0305, 0.0254, 0.0210, 0.0172,
        0.0140, 0.0113, 0.0091, 0.0073, 0.0058]); a_(t-1): [0.99914998 0.99829602 0.98038077 0.96087277 0.93978298 0.91713792
 0.89298052 0.86737001 0.84038192 0.81210774 0.78265446 0.75214338
 0.72070938 0.68849909 0.65566933 0.62238538 0.58881873 0.55514455
 0.52153981 0.4881804  0.45523876 0.42288151 0.39126703 0.36

DDIM Sampler: 100%|██████████████████████████████████████████████████████| 50/50 [00:11<00:00,  4.18it/s]
