# Import Dependecies

In [1]:
import os
import random
import torch
import torch.nn as nn

import numpy as np
import pandas as pd

from TenGAN.ggan.zoo import *

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from torchmetrics.image.fid import FrechetInceptionDistance

In [2]:
work_dir = "/home/spaka002/Tensor_GAN/"

### Helper Functions

In [3]:
def trim(word):
    
    w = word
    
    trim_characters = [' ']
    trim_characters = {t_c:True for t_c in trim_characters}
    
    if w[0] in trim_characters:
        return trim(w[1:])
    
    if w[-1] in trim_characters:
        return trim(w[:-1])
    
    return w

# Load Generator

In [4]:
config_id = int(input('config_id'))

hps = dict()

with open(f'{os.getcwd()}/configs.txt', 'r') as f:
    text = f.read()
    config_txt = text[text.find(f'config_id: {config_id}'):text.find(f'config_id: {config_id+1}')]
    del text
    
use_decomposition = '_NoDecomp' not in config_txt[config_txt.find("Generator Model")+17:config_txt.find("Discriminator Model")-1]

hps_text = config_txt.split('{')[-1].split('}')[0]
hp_text_list = hps_text.split('\n')

for hp_text in hp_text_list:

    if len(hp_text) < 2: continue
    hp_name = trim(hp_text.split(':')[0])
    hp_value = trim(hp_text.split(':')[1])

    if hp_name in ['gen_model', 'disc_model']: hp_value = str(hp_value)
    else:
        try: hp_value = eval(hp_value)
        except: hp_value = str(hp_value)

    hps[hp_name] = hp_value
    
generator_path = config_txt[config_txt.find('saved_generators/'):].split('\n')[0]

In [5]:
substr = 'Tensor Shape'
idx = config_txt.find(substr)
tensor_shape = eval(trim(config_txt[idx+len(substr)+1: config_txt.find('\n', idx)]))
del substr, idx

In [6]:
zoo = ModelZoo()
gen_class = zoo.get_model(hps['gen_model'])

if hps['gen_model'] == 'MyBaselineGenerator': generator = gen_class(generate_tensor_shape = tensor_shape[:-1], opt = hps)
elif hps['gen_model'] in ['3DGAN']:
    
    hps['batch_size'] = 64
    hps['lr'] = 1e-2
    hps['latent_size'] = 254

    hps['num_epochs'] = 10
    hps['layer_size'] = 128
    hps['latent_dim'] = 100

    hps['tensor_shape'] = [25, 51, 51]
    hps['hidden_tensor_shape'] = [8, 7, 7, 8]
    
    generator = gen_class(opt = hps)
else:
    generator = gen_class(latent_dim = hps['latent_dim'],                     # 1D noise vector input dim
                        layer_size = hps['gen_layer_size'],                 # hidden_layer size
                        num_nodes = hps['num_nodes'],                       # image height & width
                        rank = hps['decomp_rank'],                          # tensor decomposition rank
                        num_views = hps['num_views'],                       # num_channels for image
                        decomp_type = hps['tensor_model'],                  # decomposition type - 'CPD', 'tucker' only so far
                        num_tensor_modes = len(list(tensor_shape))-1,       # number of tensor_modes, exclude mode for batch size
                        tensor_decomposition = use_decomposition,
                        opt = hps
                        )

generator_state_dict = torch.load(generator_path)
generator.load_state_dict(generator_state_dict)

del zoo, gen_class, generator_path, generator_state_dict

Successfully initialized generator.


# Get Real & Generated Data

In [7]:
n_samples = 500

In [8]:
generator.change_device('cpu')
generated_images = generator.generate(n_samples).detach().numpy()

In [9]:
tensor_comparison = torch.load(f"{work_dir}TenGAN/my_data/sample_EleEscan_1_1-5.pt")
# tensor_comparison = torch.load(f"{work_dir}TenGAN/my_data/torch_GammaEscan_RandomAngle_1_1&2&5.pt")
# tensor_comparison = torch.from_numpy(tensor_comparison)

tensor_comparison = tensor_comparison.permute(0, 3, 1, 2).numpy()

In [10]:
# sample_indices = set()

# while len(sample_indices) < n_samples:
#     new_index = int(random.random()*real_data.shape[0])
#     sample_indices.add(new_index)
# sample_indices = list(sample_indices)

# tensor_comparison = real_data[sample_indices].permute(0, 3, 1, 2).numpy()
# del real_data, new_index, sample_indices

# FID

### Setup

In [11]:
#First we have to normalize the data, the frechet inception distance from pytorch only takes uint8 data, 
# so we first normalize so our data is between 0 and 1

tensor_comparison_norm = (tensor_comparison[:n_samples] - tensor_comparison[:n_samples].min()) / (tensor_comparison[:n_samples].max() - tensor_comparison[:n_samples].min())
generated_images_norm = (generated_images - generated_images.min()) / (generated_images.max() - generated_images.min())

# Now multiply by 255 to [0, 255] and convert to uint8
tensor_comparison_uint8 = (tensor_comparison_norm * 255).astype(np.uint8)
generated_images_uint8 = (generated_images_norm * 255).astype(np.uint8)

tensor_comparison_torch = torch.tensor(tensor_comparison_uint8, dtype=torch.uint8)
generated_images_torch = torch.tensor(generated_images_uint8, dtype=torch.uint8)

print("tensor_comparison_torch min/max:", tensor_comparison_torch.min().item(), tensor_comparison_torch.max().item())
print("generated_images_torch min/max:", generated_images_torch.min().item(), generated_images_torch.max().item())
print("tensor_comparison_torch shape:", tensor_comparison_torch.shape)
print("generated_images_torch shape:", generated_images_torch.shape)

tensor_comparison_torch min/max: 0 255
generated_images_torch min/max: 0 255
tensor_comparison_torch shape: torch.Size([500, 25, 51, 51])
generated_images_torch shape: torch.Size([500, 25, 51, 51])


In [12]:
#Since this and the last score are the same, I'm assuming these are the correct way to do it.

real_dataset = tensor_comparison_torch.unsqueeze(1).repeat(1, 3, 1, 1, 1).view(-1, 3, 51, 51)
generated_dataset = generated_images_torch.unsqueeze(1).repeat(1, 3, 1, 1, 1).view(-1, 3, 51, 51)
random_dataset = torch.tensor(torch.randint(low = 0, high = 255+1, size = (real_dataset.shape[0], 3, 51, 51)), dtype = torch.uint8)

del tensor_comparison_norm, generated_images_norm, tensor_comparison_uint8, generated_images_uint8
del tensor_comparison_torch, generated_images_torch, tensor_comparison, generated_images

In [13]:
def get_fid(real_dataset, fake_dataset, device = 'cpu'):

    fid = FrechetInceptionDistance(feature = 2048).to(device)
    fid.update(real_dataset.to(device), real=True)
    fid.update(fake_dataset.to(device), real=False)
    
    fid_score_dataset = fid.compute().item()

    return float(fid_score_dataset)

### Display FID Scores

In [14]:
device = 'cpu'
print(f"Config ID: {config_id}")
print(f"Generated Data FID: {get_fid(real_dataset, generated_dataset, device = device)}")
if False: print(f"Random Data FID: {get_fid(real_dataset, random_dataset, device = device)}")

Config ID: 39
Generated Data FID: 240.0154266357422


In [None]:
print(f"{' '*5}Hyperparameters:\n\n" + '\n'.join([f"{hp} = {hps[hp]}" for hp in list(hps)]))

     Hyperparameters:

do_cnn_slices = False
n_epochs = 1
batch_size = 64
gen_lr = 1e-05
disc_lr = 0.001
latent_dim = 100
num_nodes = 51
num_views = 25
num_time_steps = 9
val_gen_size = 100
tensor_model = CPD
decomp_rank = None
training_size = 10000
epoch_print_every = 1
batch_print_every = 5
critic_iterations = 1
generator_iterations = 1
gen_model = 3DGAN
gen_layer_size = 128
gen_hidden_channels = 64
gen_num_hidden_convs = 1
gen_inconsistency_lambda = 0
gen_epoch_start_inc = 1000
gen_add_noise = ['mulitply', [0.5, 2]]
gen_channel_smooth = 0
gen_channel_smooth_window = 3
gen_smooth_modes = []
gen_epoch_start_smooth = 0
disc_model = Discriminator3d
num_slices = 25
disc_sig_out = False
disc_add_noise = None
disc_conv_out_channels = 16
disc_conv_hidden_channels = 32
disc_num_conv_layers = 2
disc_linear_hidden_dim = 64
disc_num_linear_layers = 1
max_eval_rank = 40
tensor_eval_samples = 50
rank_lambda = 0
penalty_type = fro
rank_penalty_method = A
n_graph_sample_batches = 10
eval_method = m