In [26]:
from dataset import ICLEVRLoader
from model import Discriminator, Generator
from evaluator import evaluation_model
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
import pathlib
import torch
import torch.nn as nn
import numpy as np
import copy
import os 

def labels_to_conditions(labels, device, n_classes=24):
    """Stack per-sample label vectors into a single float32 condition tensor.

    Args:
        labels: Sized, integer-indexable collection (e.g. an ``ICLEVRLoader``)
            whose items are length-``n_classes`` label vectors — presumably
            multi-hot object indicators; TODO confirm against the dataset.
        device: Torch device (or device string) the result is moved to.
        n_classes: Length of each label vector.

    Returns:
        A ``(len(labels), n_classes)`` float32 tensor on ``device``.
    """
    conditions = np.zeros((len(labels), n_classes))
    # Index explicitly rather than iterating: a Dataset need not define
    # __iter__, and row assignment into a preallocated array accepts lists,
    # numpy arrays and tensors alike.
    for i in range(len(labels)):
        conditions[i] = labels[i]
    return torch.from_numpy(conditions).float().to(device)


if __name__ == '__main__':
    # Output directory for the generated image grids.
    os.makedirs("./eval_graph", exist_ok=True)

    # Generator hyper-parameters; presumably must match the ones used to
    # train the checkpoint, otherwise load_state_dict reports a mismatch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    z_dim = 100
    c_dim = 200

    # Build the generator and restore trained weights. map_location lets a
    # checkpoint saved on GPU be loaded when only the CPU is available.
    generator = Generator(z_dim, c_dim).to(device)
    path = 'ckpt/epoch128_score0.76.pt'
    generator.load_state_dict(torch.load(path, map_location=device))
    generator.eval()

    # Use a distinct name so the imported evaluation_model class is not shadowed.
    evaluator = evaluation_model()

    root_folder = pathlib.Path().absolute()
    test_labels = ICLEVRLoader(root_folder, mode="test")
    new_test_labels = ICLEVRLoader(root_folder, mode="new_test")
    test_conditions = labels_to_conditions(test_labels, device)
    new_test_conditions = labels_to_conditions(new_test_labels, device)

    # One shared noise batch sized for the larger split: both splits are drawn
    # from the same latent vectors, and the splits need not have equal length.
    n_noise = max(len(test_conditions), len(new_test_conditions))
    fixed_z = torch.randn((n_noise, z_dim)).to(device)

    # Inference only: no autograd graph for generation or scoring.
    with torch.no_grad():
        test_imgs = generator(fixed_z[:len(test_conditions)], test_conditions)
        new_test_imgs = generator(fixed_z[:len(new_test_conditions)], new_test_conditions)
        test_score = evaluator.eval(test_imgs, test_conditions)
        new_test_score = evaluator.eval(new_test_imgs, new_test_conditions)

    save_image(test_imgs, os.path.join('eval_graph', 'test.png'), nrow=8, normalize=True)
    save_image(new_test_imgs, os.path.join('eval_graph', 'new_test.png'), nrow=8, normalize=True)

    print(f'testing score: {test_score:.2f}')
    print(f'new_testing score: {new_test_score:.2f}')
    print('---------------------------------------------')
        

testing score: 0.76
new_testing score: 0.60
---------------------------------------------
