# IMPORTS

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
import os


In [None]:
import sys  
sys.path.insert(0, os.path.join(os.getcwd(),"../src/"))# this should vary when using Colab

from dcgan import DCGAN
from config import *
from dataset import WaterBodyGeneratorDataset


# MODELS TO RUN INFERENCE ON

In [None]:
# Override any of these to choose what to run; as written they keep the
# values already set on `inference_cfg`.
inference_cfg.models = inference_cfg.models    # model names, e.g. ["model name"]
inference_cfg.configs = inference_cfg.configs  # one DCGAN config per model, e.g. [config_name]
inference_cfg.epoch = inference_cfg.epoch      # epochs passed to model.load, e.g. [epoch]
inference_cfg.gt = inference_cfg.gt            # True or False: also show a ground-truth row

# models[i] is joined onto root_path and loaded into the DCGAN built from
# configs[i]. Fail fast here instead of with an IndexError partway through
# generation, after earlier models have already been loaded and sampled.
if len(inference_cfg.models) < len(inference_cfg.configs):
    raise ValueError(
        f"inference_cfg.models has {len(inference_cfg.models)} entries but "
        f"inference_cfg.configs has {len(inference_cfg.configs)}; "
        "each config needs a matching model name"
    )

# CREATING MODELS

In [None]:
# Fixed random noise of shape (32, 100, 1, 1). Fall back to CPU so this cell
# does not crash on machines without CUDA.
fixed_noise = torch.randn(
    32, 100, 1, 1, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Rows of images to display; each entry is a list of (label, image) pairs.
content = []

# CREATE MODELS
# One DCGAN per config; models[i] is later paired with inference_cfg.models[i].
models = [DCGAN(config) for config in inference_cfg.configs]

# GENERATE INFERENCE BATCHES

In [None]:
# Generate one batch per (model, epoch) pair. Rows are appended model-major
# (every epoch of models[0], then every epoch of models[1], ...), which is
# the order the display cell assumes when labelling rows.
with torch.no_grad():  # inference only: don't build autograd graphs
    for i, model in enumerate(models):
        model_path = os.path.join(inference_cfg.root_path, inference_cfg.models[i])
        for epoch in inference_cfg.epoch:
            model.load(model_path, epoch)
            model.eval()
            batch = model.generate_batch(
                batch_size=inference_cfg.batch_size,
                threeshold=inference_cfg.threeshold,  # keyword spelled as generate_batch expects
                batch_gen_size=inference_cfg.batch_gen_size,
            )
            # Keep each item's first element as its title; move axis 0 of the
            # image to the end (presumably CHW -> HWC for imshow).
            content.append([(x[0], np.transpose(x[1], (1, 2, 0))) for x in batch])


# GENERATE GROUND TRUTH FOR COMPARISON

In [None]:
# GENERATE GROUND TRUTH
if inference_cfg.gt:
    # Resize/center-crop to 64x64, convert to a tensor and normalize each
    # channel with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1], assuming
    # ToTensor yields values in [0, 1]).
    tf = transforms.Compose(
        [
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]),
        ]
    )

    # Dataset location can be overridden with `inference_cfg.gt_path`.
    path = getattr(
        inference_cfg, "gt_path", "D:/AI_ML/Kaggle/Water Bodies Dataset_pruned_more/"
    )
    dataset = WaterBodyGeneratorDataset(path, transform=tf)
    # Draw exactly one display row's worth of random images so the GT row is
    # never shorter than the generated rows (the old fixed 32 was).
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=inference_cfg.batch_size, shuffle=True, num_workers=0
    )
    # DataLoader iterators have no `.next()` method in current PyTorch;
    # use the builtin `next()`.
    images = next(iter(dataloader))
    gt = [("gt", np.transpose(img, (1, 2, 0))) for img in images]
    content.append(gt)


# DISPLAY IMAGES

In [None]:
def _to_unit_range(img):
    """Return *img* as a float array linearly rescaled to [0, 1] for imshow.

    The minimum maps to 0 and the maximum to 1. A constant image
    (max == min) becomes all zeros instead of NaNs from a division by zero.
    """
    img = np.asarray(img, dtype=float)
    lo, hi = img.min(), img.max()
    if hi == lo:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


# One row per entry of `content`, one column per image in a batch.
# squeeze=False keeps `axs` 2-D even with a single row or column.
n_rows = len(content)
fig, axs = plt.subplots(
    n_rows, inference_cfg.batch_size, figsize=(20, 20), squeeze=False
)
n_epochs = len(inference_cfg.epoch)
for i, (row, row_axes) in enumerate(zip(content, axs)):
    if inference_cfg.gt and i == n_rows - 1:
        label = "ground truth"
    else:
        # Generated rows were appended model-major, so row i belongs to
        # model i // n_epochs at epoch i % n_epochs.
        m, e = divmod(i, n_epochs)
        label = f'{inference_cfg.models[m]}_{inference_cfg.epoch[e]}'
    row_axes[0].set_ylabel(label)

    for j, ax in enumerate(row_axes):
        if j < len(row):  # a row may hold fewer images than batch_size
            ax.set_title(str(row[j][0]))
            ax.imshow(_to_unit_range(row[j][1]))
        # Hide ticks on every tile but keep the axes themselves:
        # axis("off") would also hide the row label on column 0.
        ax.set_xticks([])
        ax.set_yticks([])