In [None]:
import torch
import torchvision
import argparse
import matplotlib.pyplot as plt
import os

# cd to the repo root so the project-local imports below (`cae`,
# `restore_dataset`) and relative paths such as the checkpoint resolve.
# Assumes the kernel was started one directory below the repo root.
# The target is computed only once per kernel: a bare os.chdir('..') would
# climb one more directory every time this cell is re-run.
if 'REPO_ROOT' not in globals():
    REPO_ROOT = os.path.abspath('..')
os.chdir(REPO_ROOT)

from cae import CAE
from restore_dataset import build_dataset, get_args_parser

In [None]:
# Reuse the project's argument parser so `args` carries the defaults declared
# by `restore_dataset.get_args_parser`, which `build_dataset` consumes below.
parser = argparse.ArgumentParser('Denoise Dataset', parents=[get_args_parser()])
# Parse an explicit empty argv: inside a Jupyter kernel `sys.argv` holds the
# kernel launcher's own flags (typically `-f <connection file>`), which must not
# leak into the experiment configuration. Override fields on `args` instead.
args, _ = parser.parse_known_args(args=[])
print(args)

In [None]:
# Specify your checkpoint location here (relative paths resolve against the
# repo root, the working directory set in the first cell).
checkpoint = 'checkpoint-cae.pth'
# Fail fast with a clear message rather than a confusing error from torch.load
# later. Pure-Python check instead of the `!ls` shell magic so it works on any OS.
if not os.path.isfile(checkpoint):
    raise FileNotFoundError(f"checkpoint not found: {os.path.abspath(checkpoint)}")
print(f"{checkpoint}: {os.path.getsize(checkpoint) / 2**20:.1f} MiB")

In [None]:
# Build the autoencoder and restore its trained weights for inference.
# NOTE(review): 512 is passed positionally to CAE; presumably a feature/latent
# width — confirm against cae.CAE's signature.
model = CAE(512)

# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# torch.load unpickles the file: only load checkpoints from a trusted source.
ckpt = torch.load(checkpoint, map_location=device)
# The weights are stored under the checkpoint dict's 'model' key.
state_dict = ckpt['model']
# NOTE(review): `model` itself is never moved to `device`; load_state_dict copies
# the loaded tensors into the model's existing parameters, so inference runs
# wherever CAE(...) allocated them (CPU unless CAE does otherwise) — confirm intent.
model.load_state_dict(state_dict)
model.eval()

In [None]:
# Create the dataset and a loader that yields one example at a time.
# The second value returned by build_dataset is stored as args.nb_classes
# (presumably the class count — confirm in restore_dataset.build_dataset).
dataset, args.nb_classes = build_dataset(args=args)

# Seeded shuffling so "Restart & Run All" shows the same sample every time;
# change SEED to look at a different random example.
SEED = 0
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# Draw a single batch (batch_size=1). Each element is an (input, original) pair:
# channel 0 of `input_img` is used as the noisy image and channel 1 as the
# "standard" image; `original` is presumably the clean reference image.
input_img, original = next(iter(data_loader))

# Run the model on whatever device its parameters live on, then bring the
# result back to the CPU for plotting (a no-op when the model is on the CPU).
model_device = next(model.parameters()).device
with torch.no_grad():
    output = model(input_img.to(model_device)).cpu()

# Take the first (only) sample of the batch. Slicing with 0:1 / 1:2 keeps a
# length-1 channel axis, so each panel is a (C, H, W) image for make_grid
# without relying on the batch axis doubling as the channel axis.
noisy = input_img[0, 0:1]
standard = input_img[0, 1:2]
original = original[0]
output = output[0]
print(f"noisy shape {noisy.shape}")
print(f"original shape {original.shape}")
print(f"output shape {output.shape}")

# make_grid stacks the panels, so their shapes must agree; fail here with a
# readable message instead of an opaque stack error in the plotting code.
assert noisy.shape == standard.shape == original.shape == output.shape, (
    f"panel shapes differ: noisy {tuple(noisy.shape)}, standard {tuple(standard.shape)}, "
    f"original {tuple(original.shape)}, output {tuple(output.shape)}"
)

# Show the four panels side by side: noisy | standard | original | output.
# make_grid expands single-channel panels to 3 channels, so imshow receives
# RGB data and the gray cmap has no effect. Without normalize=True, values
# outside [0, 1] are clipped by imshow — assumes images are in [0, 1].
grid_img = torchvision.utils.make_grid([noisy, standard, original, output], nrow=4, padding=4)
fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(grid_img.permute(1, 2, 0).numpy(), cmap='gray')
ax.set_title('noisy | standard | original | output')
ax.axis('off')
plt.show()