In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
import numpy as np
from models.vision import LeNetMnist, weights_init, LeNet
from utils import label_to_onehot, cross_entropy_for_onehot
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
import lpips
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Change Checkpoints Path

checkpoint_name = 'LTI-checkpoints/CIFAR10_LeNet_MLP-10000_batch-16_0.0001_5000_256_best.pt'

In [3]:
checkpoint = torch.load(checkpoint_name)

In [8]:
gt = checkpoint['gt_data']
rec = checkpoint['reconstructed_imgs']

In [9]:
print(gt.shape)
print(rec.shape)

torch.Size([50000, 3, 64, 64])
torch.Size([50000, 3, 64, 64])


In [22]:
idxs = []

for i in range(64):
    idxs.append(48 + i)

In [23]:
rec_64 = rec[idxs].cpu()
gt_64 = gt[idxs].cpu()

In [24]:
mean = [0.4914672374725342, 0.4822617471218109, 0.4467701315879822]
std = [0.24703224003314972, 0.24348513782024384, 0.26158785820007324]

s = torch.tensor(std).view(3, 1, 1)

dm = torch.as_tensor(mean)[:, None, None]
ds = torch.as_tensor(std)[:, None, None]

In [25]:
rec_64_norm = (rec_64 - dm) / ds
gt_64_norm = (gt_64 - dm) / ds

In [26]:
import lpips
import metrics

lpips_scorer = lpips.LPIPS(net="alex")

ssim_score = metrics.cw_ssim(rec_64_norm, gt_64_norm, scales=5)
psnr_score = metrics.psnr(rec_64_norm, gt_64_norm, factor=1 / s)
lpips_score = lpips_scorer(rec_64_norm, gt_64_norm, normalize=True)

print(f'ssim:{ssim_score}')
print(f'psnr:{psnr_score}')
print(f'lpips:{lpips_score.mean()}')

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/u3637153/.local/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth
ssim:0.4115622341632843
psnr:16.58763313293457
lpips:0.48088008165359497


In [3]:
def grid_plot(tensor, file_name):
    tensor = tensor.clone().detach()
    tensor.clamp_(0, 1)
    num_images = 64
    if num_images == 1:
        fig, axes = plt.subplots(1, 1, figsize=(1, 1))
    elif num_images in [4, 8]:
        fig, axes = plt.subplots(1, num_images, figsize=(num_images, num_images))
    else:
        fig, axes = plt.subplots(num_images // 8, 8, figsize=(12, num_images // 16 * 3))
    axes = np.reshape(axes, -1)
    for im, ax in zip(tensor, axes):
        ax.imshow(im.permute(1, 2, 0).cpu())
        ax.axis('off')
    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    os.makedirs('saved_img') if not os.path.exists('saved_img') else None
    
    plt.savefig(file_name)

In [None]:
grid_plot(gt_64, 'gt.pdf')

In [None]:
grid_plot(rec_64, 'rec.pdf')