In [None]:
%cd ..

import torch
import os
import matplotlib.pyplot as plt
import numpy as np
import pickle
import copy

from PIL import Image
from torchvision import transforms
from utils import get_psnr, image_normalization
from alignment.alignment_utils import load_deep_jscc
from alignment.alignment_model import *
from alignment.alignment_validation import *

os.getcwd()
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [None]:
# Checkpoints of the two autoencoders being compared (seeds 42 and 43, per the
# file names). Paths are relative to the working directory set by `%cd ..`.
model1_fp = r'alignment/models/autoencoders/rayleigh_snr_-10_seed_42.pkl'
# model1_fp = r'alignment/models/autoencoders/snr_0_seed_42.pkl'

model2_fp = r'alignment/models/autoencoders/rayleigh_snr_-10_seed_43.pkl'
# model2_fp = r'alignment/models/autoencoders/snr_0_seed_43.pkl'

# Aligner checkpoint. Relative like the paths above (instead of a user-specific
# absolute path) so the notebook runs from any clone of the repository.
# aligner_fp = r'alignment/models/plots/psnr_vs_pilots_5/aligner_mlp_10000.pkl'
aligner_fp = r'alignment/models/plots/psnr_vs_snr/aligner_twoconv_ae_-10_snr_0_seed_42.pth'

snr = 0  # passed to prepare_models; presumably in dB — confirm
channel = "Rayleigh"
image_path = r'demo/kodim23.png'
# image_path = r'demo/0002.jpg'
times = 10  # NOTE(review): presumably the number of runs visualization_pipeline averages over — confirm
resolution = None  # NOTE(review): None presumably means "keep the native image resolution" — confirm
upscale_factor = 1
c = 8  # NOTE(review): presumably the bottleneck channel count the checkpoints were trained with — confirm

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

no_mismatch_model, unaligned_model, aligned_model = prepare_models(model1_fp, model2_fp, aligner_fp, snr, c, resolution, channel, device)

In [None]:
# Run the same visualization on one image for three pipelines, in this order:
#   1. no_mismatch_model - without semantic mismatch
#   2. unaligned_model   - with semantic mismatch, without aligning
#   3. aligned_model     - with semantic mismatch, with aligning
# Everything except the model is shared, so collect it once.
shared_args = (image_path, resolution, times, upscale_factor)

visualization_pipeline(no_mismatch_model, *shared_args)
visualization_pipeline(unaligned_model, *shared_args)
# kept as the cell's final expression so Jupyter displays its result as before
visualization_pipeline(aligned_model, *shared_args)

In [None]:
def evaluation_pipeline(model, output_path, image=None, n_runs=None, scale=None, image_name=None):
    """Run ``model`` repeatedly on a test image, print the mean PSNR and save the last output.

    The model is applied ``n_runs`` times to the same input and the PSNR of each
    reconstruction against the input is averaged (averaging suggests the model
    is stochastic, e.g. channel noise — not verified here). The reconstruction
    from the final run is upscaled with nearest-neighbour resampling, shown with
    matplotlib and written to ``output_path``.

    Args:
        model: callable mapping the input tensor to a reconstructed image tensor.
        output_path: file path the final (upscaled) reconstruction is saved to.
        image: input tensor; defaults to the notebook-global ``test_image``.
        n_runs: number of forward passes; defaults to the global ``times``.
        scale: integer upscale factor for display/saving; defaults to the
            global ``upscale_factor``.
        image_name: path whose basename is shown in the printed summary;
            defaults to the global ``test_image_dir``.

    Raises:
        ValueError: if ``n_runs`` is not positive (there would be no
            reconstruction to save and no PSNR to average).
    """
    # Fall back to the notebook-level globals so existing two-argument calls keep
    # working; the conditional expressions only touch a global when it is needed.
    image = test_image if image is None else image
    n_runs = times if n_runs is None else n_runs
    scale = upscale_factor if scale is None else scale
    image_name = test_image_dir if image_name is None else image_name

    if n_runs <= 0:
        raise ValueError(f"n_runs must be positive, got {n_runs}")

    denormalize = image_normalization('denormalization')
    psnr_total = 0.0
    with torch.no_grad():
        # The ground truth does not change between runs: denormalize it once.
        gt = denormalize(image)
        for _ in range(n_runs):
            reconstruction = denormalize(model(image))
            psnr_total += get_psnr(reconstruction, gt)

        # Bring the last reconstruction back to the normalized range for display.
        display_image = image_normalization('normalization')(reconstruction)
        # .detach().cpu() is required before .numpy() when the tensor lives on a GPU.
        display_image = display_image.squeeze().detach().cpu().numpy()  # (C, H, W)
        display_image = display_image.transpose(1, 2, 0)  # (H, W, C) for PIL

    # Clip before the uint8 cast: out-of-range floats would otherwise wrap around.
    pixels = (np.clip(display_image, 0.0, 1.0) * 255).astype(np.uint8)
    pil_image = Image.fromarray(pixels)
    new_size = (pil_image.width * scale, pil_image.height * scale)
    pil_image = pil_image.resize(new_size, Image.NEAREST)  # use NEAREST or BICUBIC

    # Figure sized so one image pixel maps to one screen pixel at dpi=100.
    _, ax = plt.subplots(figsize=(new_size[0] / 100, new_size[1] / 100), dpi=100)
    ax.imshow(pil_image)
    ax.axis('off')
    # With the inline backend, show() renders and releases the figure, so repeated
    # calls (e.g. the SNR sweep) do not accumulate open figures.
    plt.show()

    pil_image.save(output_path)

    # float() accepts both a 0-d/1-element tensor and a plain number.
    print("Average PSNR is {:.2f} over {} runs on {}".format(float(psnr_total) / n_runs, n_runs, os.path.basename(image_name)))

# SNR sweep on the high-resolution ("upscaled") AWGN models: for each SNR, load the
# matching aligner and save one reconstruction per pipeline.
#
# Nothing below depends on the SNR, so it is set up once instead of per iteration.
# These stay module-level names on purpose: evaluation_pipeline reads
# test_image, times, upscale_factor and test_image_dir from the notebook namespace.
channel_type = 'AWGN'
model1_fp = 'alignment/models/autoencoders/upscaled_42.pkl'
model2_fp = 'alignment/models/autoencoders/upscaled_43.pkl'

test_image_dir = 'demo/kodim23.png'
times = 10
resolution = None
upscale_factor = 1

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
c = 8

output_dir = 'alignment/models/plots/high_res/images'

if resolution is None:
    transform = transforms.Compose([transforms.ToTensor(), ])
else:
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((resolution, resolution))])

# `with` closes the file handle; convert('RGB') keeps RGBA/greyscale inputs from
# producing a 4- or 1-channel tensor (a no-op for RGB images like kodim23).
# NOTE(review): the tensor stays on the CPU, as before; this assumes the models
# returned by prepare_models accept CPU input — confirm when running on CUDA.
with Image.open(test_image_dir) as raw_image:
    test_image = transform(raw_image.convert('RGB'))

# PIL's save() does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

for snr in [-20, -10, 0, 10, 20, 30]:
    aligner_fp = f'alignment/models/plots/high_res/aligner_twoconv_{snr}.pth'

    no_mismatch_model, unaligned_model, aligned_model = prepare_models(model1_fp, model2_fp, aligner_fp, snr, c, resolution, channel_type, device)

    for pipeline_label, pipeline_model in (
        ('no_mismatch', no_mismatch_model),
        ('unaligned', unaligned_model),
        ('aligned', aligned_model),
    ):
        evaluation_pipeline(pipeline_model, f'{output_dir}/{pipeline_label}_{snr}.png')