In [5]:
import time
import cv2
import os
import random
import wandb
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
from glob import glob
from contrastive_attention import ContrastiveAttention
from PIL import Image
from torch.utils.data import Dataset
from visualizer import Visualizer
import torch.nn as nn

In [2]:
def tensor2im(input_image, imtype=np.uint8):
    """"Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) --  the input image tensor array
        imtype (type)        --  the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)



In [6]:
transform = transforms.ToTensor()

my_model_checkpoint ='/kuacc/users/edincer16/Comp541_fall22/course_project/attentionGANwContrastiveLoss/model_results/ContrastiveAttention_weights/ContrastiveAttention_400_net_G.pth'   
my_model = ContrastiveAttention(input_dim=3,output_dim=3,n_epochs=5,norm_layer=nn.BatchNorm2d,lr_decay_iters=5,
    n_epochs_decay=5,batch_size=4,num_patches=256,
    lambda_GAN=4.0, lambda_NCE=4.0,gan_mode='lsgan').cuda()
my_model.netG.load_state_dict(torch.load(my_model_checkpoint))
my_model.netG.eval()

Generator(
  (first_reflectionPad): ReflectionPad2d((3, 3, 3, 3))
  (first_conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
  (first_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (first_relu): ReLU(inplace=True)
  (second_conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (second_norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (second_relu): ReLU(inplace=True)
  (third_conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (third_norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (third_relu): ReLU(inplace=True)
  (resnet_blocks): Sequential(
    (0): ResnetBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      (conv1_norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3,

In [8]:
input_path = '/kuacc/users/edincer16/Comp541_fall22/course_project/attentionGANwContrastiveLoss/horse2zebra/testA/*'
output_path = '/kuacc/users/edincer16/Comp541_fall22/course_project/attentionGANwContrastiveLoss/model_results/baseline_outputs_for_fid/'
for img_path in tqdm(glob(input_path)):
    img_name = img_path.split('/')[-1]
    img_init = cv2.imread(img_path)
    img = cv2.cvtColor(img_init, cv2.COLOR_BGR2RGB)
    img = transform(img).unsqueeze(0)
    my_generated_image, _, _, _, _, _, _, _, _, _, _, \
            _, _, _, _, _, _, _, _, _, _, \
            _, _, _, _, _, _, _, _, _  = my_model.netG(img.cuda())
    gen_img = tensor2im(my_generated_image)
    cv2.imwrite(output_path+img_name,
                cv2.cvtColor(gen_img, cv2.COLOR_RGB2BGR))

100%|█████████████████████████████████████████████████████████████████████████| 120/120 [00:07<00:00, 15.31it/s]


In [9]:
#FID calculation
!python -m pytorch_fid /kuacc/users/edincer16/Comp541_fall22/course_project/attentionGANwContrastiveLoss/horse2zebra/testA /kuacc/users/edincer16/Comp541_fall22/course_project/attentionGANwContrastiveLoss/model_results/baseline_outputs_for_fid

100%|█████████████████████████████████████████████| 3/3 [00:02<00:00,  1.19it/s]
100%|█████████████████████████████████████████████| 3/3 [00:01<00:00,  2.07it/s]
FID:  213.40127871659777
