In [None]:
def evaluation_pipeline(model, output_path, times=None, test_image=None,
                        upscale_factor=None, test_image_dir=None):
    """Evaluate `model` on one test image: average PSNR over repeated runs, then show and save the last output.

    The model is called `times` times (presumably its channel is stochastic, so runs
    can differ — TODO confirm). PSNR is measured between the denormalized model output
    and the denormalized input. Only the reconstruction from the final run is
    displayed and written to `output_path`.

    Args:
        model: callable applied to ``test_image`` under ``torch.no_grad()``.
        output_path: destination file for the reconstructed (and upscaled) image.
        times: number of runs to average PSNR over. Defaults to the notebook global
            ``times``.
        test_image: input tensor passed to ``model``. Defaults to the notebook global
            ``test_image``.
        upscale_factor: integer factor for nearest-neighbour upscaling before display
            and saving. Defaults to the notebook global ``upscale_factor``.
        test_image_dir: path of the test image, used only in the printed summary.
            Defaults to the notebook global ``test_image_dir``.

    Raises:
        ValueError: if ``times`` is smaller than 1 (the average would be undefined and
            no reconstruction would exist to save).
    """
    # Fall back to notebook-level configuration so existing two-argument calls keep working.
    namespace = globals()
    if times is None:
        times = namespace['times']
    if test_image is None:
        test_image = namespace['test_image']
    if upscale_factor is None:
        upscale_factor = namespace['upscale_factor']
    if test_image_dir is None:
        test_image_dir = namespace['test_image_dir']

    if times < 1:
        raise ValueError(f"times must be at least 1, got {times}")

    denormalize = image_normalization('denormalization')
    psnr_all = 0.0
    with torch.no_grad():
        # The reference image is identical for every run, so denormalize it once.
        gt = denormalize(test_image)
        for _ in range(times):
            demo_image = denormalize(model(test_image))
            psnr_all += get_psnr(demo_image, gt)

        # Prepare the last reconstruction for PIL: back to the normalized range, drop
        # singleton dims, move to host memory (no-op on CPU) and reorder
        # (C, H, W) -> (H, W, C).
        demo_image = image_normalization('normalization')(demo_image)
        demo_image = demo_image.squeeze().cpu().numpy().transpose(1, 2, 0)

    # The *255 scaling assumes a [0, 1] range; clip first so out-of-range values
    # saturate instead of wrapping around in the uint8 cast.
    pixels = (np.clip(demo_image, 0.0, 1.0) * 255).astype(np.uint8)
    pil_image = Image.fromarray(pixels)
    new_size = (pil_image.width * upscale_factor, pil_image.height * upscale_factor)
    pil_image = pil_image.resize(new_size, Image.NEAREST)  # NEAREST keeps pixel blocks crisp; BICUBIC smooths

    # Show the upscaled image; figsize in inches * dpi 100 gives about one screen pixel per image pixel.
    fig = plt.figure(figsize=(new_size[0] / 100, new_size[1] / 100), dpi=100)
    plt.imshow(pil_image)
    plt.axis('off')
    plt.show()
    plt.close(fig)  # release the figure; this function is called repeatedly in a loop

    pil_image.save(output_path)

    # float() accepts both a Python number and a one-element tensor, unlike .item().
    print("Average PSNR is {:.2f} over {} runs on {}".format(
        float(psnr_all) / times, times, os.path.basename(test_image_dir)))

import os

# Configuration shared by every SNR. `times`, `test_image`, `upscale_factor` and
# `test_image_dir` are deliberately module-level: `evaluation_pipeline` reads them
# from the notebook namespace.
channel_type = 'AWGN'
model_dir = os.path.join('alignment', 'models')
model1_fp = os.path.join(model_dir, 'upscaled_42.pkl')
model2_fp = os.path.join(model_dir, 'upscaled_43.pkl')
image_out_dir = os.path.join(model_dir, 'giugno', 'images')

test_image_dir = os.path.join('demo', 'kodim23.png')
times = 10
resolution = None  # set to an int to resize the test image to (resolution, resolution)
upscale_factor = 1

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # NOTE(review): not used in this cell

c = 8  # passed to load_deep_jscc; presumably the latent channel count — TODO confirm

transform_steps = [transforms.ToTensor()]
if resolution is not None:
    transform_steps.append(transforms.Resize((resolution, resolution)))
transform = transforms.Compose(transform_steps)

# The test image does not depend on the SNR, so load it once; the context manager
# closes the file as soon as the tensor has been produced.
with Image.open(test_image_dir) as raw_image:
    test_image = transform(raw_image)

for snr in [-20, -10, 0, 10, 20, 30]:
    aligner_fp = os.path.join(model_dir, 'giugno', f'aligner_upscaled_noisy_{snr}.pkl')

    # Both base models are reloaded per SNR because load_deep_jscc takes the SNR.
    model1 = load_deep_jscc(model1_fp, snr, c, channel_type)
    model2 = load_deep_jscc(model2_fp, snr, c, channel_type)

    # Mismatched pipeline: encoder from model 1, decoder from model 2.
    encoder = copy.deepcopy(model1.encoder)
    decoder = copy.deepcopy(model2.decoder)

    # NOTE(review): pickle.load can execute arbitrary code; only load aligner files you trust.
    with open(aligner_fp, 'rb') as f:
        aligner = pickle.load(f)

    aligned_model = AlignedDeepJSCC(encoder, decoder, aligner, snr, channel_type)
    unaligned_model = AlignedDeepJSCC(encoder, decoder, None, snr, channel_type)

    print(f"FOR SNR: {snr}")

    # (header printed, output file prefix, model under test)
    for header, prefix, eval_model in (
        ("NO MISMATCH", 'no_mismatch', model1),
        ("UNALIGNED MODEL", 'unaligned', unaligned_model),
        ("ALIGNED MODEL", 'aligned', aligned_model),
    ):
        print(header)
        evaluation_pipeline(eval_model, os.path.join(image_out_dir, f'{prefix}_{snr}.png'))

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

import os

# Grid of saved reconstructions: one row per model variant, one column per channel SNR.
# The *_path lists must match the file names written by the evaluation loop.
row_titles = ["No mismatch", "Unaligned", "Aligned"]
row_path = ["no_mismatch", "unaligned", "aligned"]
col_titles = ["SNR -20dB", "SNR -10dB", "SNR 0dB", "SNR 10dB", "SNR 20dB", "SNR 30dB"]
col_path = ["-20", "-10", "0", "10", "20", "30"]

num_rows = len(row_titles)
num_cols = len(col_titles)
assert len(row_path) == num_rows, "row_titles and row_path must have the same length"
assert len(col_path) == num_cols, "col_titles and col_path must have the same length"

image_dir = os.path.join('alignment', 'models', 'giugno', 'images')

fig, axes = plt.subplots(num_rows, num_cols, figsize=(17, 6))

for i in range(num_rows):
    for j in range(num_cols):
        image_path = os.path.join(image_dir, f"{row_path[i]}_{col_path[j]}.png")
        # convert() returns an independent copy, so the file can be closed right away.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert('RGB')

        ax = axes[i, j]
        ax.imshow(img)
        ax.axis('off')

# Column titles are attached to the top-row axes so they stay centred over their
# column whatever the figure size or tight_layout does.
for j, title in enumerate(col_titles):
    axes[0, j].set_title(title, fontsize=10, fontweight='bold')

# axis('off') hides y-labels, so draw each row title as text just left of the
# first-column axes, positioned in that axes' own coordinate system.
for i, title in enumerate(row_titles):
    axes[i, 0].text(
        -0.05, 0.5,  # x, y in axes coords (left of the axes, vertically centred)
        title,
        transform=axes[i, 0].transAxes,
        ha='right',
        va='center',
        fontsize=10,
        fontweight='bold',
        rotation='vertical'
    )

plt.tight_layout()
plt.show()