In [1]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import defaultdict
import pandas as pd
import numpy as np
import torch.fft
import subprocess
import logging
import random
import shutil
import psutil
import scipy
import torch
import copy
import yaml
import time
import tqdm
import sys
import gc

import segmentation_models_pytorch as smp

from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from holodecml.data import PickleReader, UpsamplingReader
from holodecml.propagation import InferencePropagator
from holodecml.transforms import LoadTransformations
from holodecml.models import load_model
from holodecml.losses import load_loss

import os
import warnings
warnings.filterwarnings("ignore")

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [2]:
available_ncpus = len(psutil.Process().cpu_affinity())

In [3]:
# ### Set seeds for reproducibility
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

In [4]:
config = "../results/gan/model.yml"
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [5]:
# Set seeds for reproducibility
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

tile_size = int(conf["data"]["tile_size"])
step_size = int(conf["data"]["step_size"])
data_path = conf["data"]["output_path"]
data_path_raw = conf["data"]["output_path_raw"]

total_positive = int(conf["data"]["total_positive"])
total_negative = int(conf["data"]["total_negative"])
total_examples = int(conf["data"]["total_training"])

transform_mode = "None" if "transform_mode" not in conf["data"] else conf["data"]["transform_mode"]
config_ncpus = int(conf["data"]["cores"])
use_cached = False if "use_cached" not in conf["data"] else conf["data"]["use_cached"]

name_tag = f"{tile_size}_{step_size}_{total_positive}_{total_negative}_{total_examples}_{transform_mode}"
fn_train = f"{data_path}/training_{name_tag}.pkl"
fn_valid = f"{data_path}/validation_{name_tag}.pkl"
fn_train_raw = f"{data_path_raw}/training_{name_tag}.pkl"
fn_valid_raw = f"{data_path_raw}/validation_{name_tag}.pkl"

output_path = conf["data"]["output_path"]

train_batch_size = 16
valid_batch_size = 16

latent_dim = 128
img_shape = (1, tile_size, tile_size)

d_learning_rate = 0.0004
g_learning_rate = 0.0001

b1 = 0.0
b2 = 0.9

In [6]:
is_cuda = torch.cuda.is_available()
data_device = torch.device("cpu") if "device" not in conf["data"] else conf["data"]["device"]

In [7]:
# Load the preprocessing transforms
if "Normalize" in conf["transforms"]["training"]:
    conf["transforms"]["validation"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]
    conf["transforms"]["inference"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]

train_transforms = LoadTransformations(conf["transforms"]["training"])
valid_transforms = LoadTransformations(conf["transforms"]["validation"])

In [8]:
train_synthetic_dataset = PickleReader(
    fn_train,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=True
)

test_synthetic_dataset = PickleReader(
    fn_valid,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=False
)

In [9]:
train_synthetic_loader = torch.utils.data.DataLoader(
    train_synthetic_dataset,
    batch_size=train_batch_size,
    num_workers=0, #available_ncpus//2,
    pin_memory=True,
    shuffle=True)

test_synthetic_loader = torch.utils.data.DataLoader(
    test_synthetic_dataset,
    batch_size=valid_batch_size,
    num_workers=0,  # 0 = One worker with the main process
    pin_memory=True,
    shuffle=False)

In [10]:
# holo_conf = copy.deepcopy(conf)
# holo_conf["data"]["data_path"] = holo_conf["data"]["raw_data"]


# train_holodec_dataset = UpsamplingReader(
#     holo_conf,
#     transform=train_transforms,
#     device=data_device
# )

# holo_conf = copy.deepcopy(conf)
# holo_conf["data"]["raw_data"] = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/real_holograms_CSET_RF07_20150719_203600-203700.nc"
# holo_conf["data"]["data_path"] = holo_conf["data"]["raw_data"]

# test_holodec_dataset = UpsamplingReader(
#     holo_conf,
#     transform=valid_transforms,
#     device=data_device
# )

train_holodec_dataset = PickleReader(
    fn_train_raw,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=True
)

test_holodec_dataset = PickleReader(
    fn_valid_raw,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=False
)

In [11]:
train_holodec_loader = torch.utils.data.DataLoader(
    train_holodec_dataset,
    batch_size=train_batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=True)

test_holodec_loader = torch.utils.data.DataLoader(
    test_holodec_dataset,
    batch_size=valid_batch_size,
    num_workers=0,  # 0 = One worker with the main process
    pin_memory=True,
    shuffle=False)

In [12]:
generator = smp.Unet(
    encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=2,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset),
    activation = "tanh"
)


aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.2,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=1,                 # define number of output labels
)

discriminator = smp.Unet(
#     encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1, 
    aux_params=aux_params
)

#print(generator(torch.ones([1, 1, 64, 64])).shape)
#print(discriminator(torch.ones([1, 1, 64, 64]))[1].shape)

In [13]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

if is_cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

In [14]:
optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=g_learning_rate, betas=(b1, b2))
optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=d_learning_rate, betas=(b1, b2))

In [None]:
n_epochs = 500
batches_per_epoch = 200

Tensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor
adv_loss = 'wgan-gp'
lambda_gp = 10
train_gen_every = 1
train_disc_every = 1
threshold = 0.5

results = defaultdict(list)
for epoch in range(n_epochs):
    
    ### Train
    real_images = iter(train_holodec_loader)
    synthethic_images = iter(train_synthetic_loader)
    dual_iter = tqdm.tqdm(
        enumerate(zip(real_images, synthethic_images)),
        total = batches_per_epoch, 
        leave = True)
    
    train_results = defaultdict(list)
    for i, ((holo_img, holo_label), (synth_img, synth_label)) in dual_iter:
            
        if holo_img.shape[0] != synth_img.shape[0]:
            continue
                
        # Adversarial ground truths
        valid = Variable(Tensor(holo_img.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(holo_img.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(holo_img.type(Tensor))
        synthethic_imgs = Variable(synth_img.type(Tensor))
        
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 0.2, holo_img.shape)))
        # C-GAN-like input using the synthethic image as conditional input
        gen_input = torch.cat([z, synthethic_imgs], 1)
        # Generate a batch of images
        gen_noise = generator(gen_input)
        # Add to the synthetic images
        gen_imgs = 0.5 * (synthethic_imgs + gen_noise)
        # Discriminate the fake images
        _, verdict = discriminator(gen_imgs)
        
        # -----------------
        #  Train Generator
        # -----------------
        
        if (i + 1) % train_gen_every == 0:
            
            optimizer_G.zero_grad()
            requires_grad(generator, True)
            requires_grad(discriminator, False)
            
            # Loss measures generator's ability to fool the discriminator
            if adv_loss == 'wgan-gp':
                g_loss = -verdict.mean()
            elif adv_loss == 'hinge':
                g_loss = -verdict.mean()
            elif adv_loss == 'bce':
                g_loss = adversarial_loss(verdict, valid)
            g_loss.backward()
            optimizer_G.step()
            
            train_results["g_loss"].append(g_loss.item())
            
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        if (i + 1) % train_disc_every == 0:
        
            optimizer_D.zero_grad()
            requires_grad(generator, False)
            requires_grad(discriminator, True)

            # Measure discriminator's ability to classify real from generated samples
            _, disc_real = discriminator(real_imgs)
            _, disc_synth = discriminator(gen_imgs.detach())
            _, disc_synth_true = discriminator(synthethic_imgs)
            
            train_results["real_acc"].append(((disc_real > threshold) == valid).float().mean().item())
            train_results["syn_acc"].append(((disc_synth > threshold) == fake).float().mean().item())
            train_results["syn_true_acc"].append(((disc_synth_true > threshold) == fake).float().mean().item())
            
            if adv_loss == 'wgan-gp':
                real_loss = -torch.mean(disc_real)
                fake_loss = disc_synth.mean() 
                fake_loss += disc_synth_true.mean()
            elif adv_loss == 'hinge':
                real_loss = torch.nn.ReLU()(1.0 - disc_real).mean()
                fake_loss = torch.nn.ReLU()(1.0 + disc_synth).mean() 
                fake_loss += torch.nn.ReLU()(1.0 + disc_synth_true).mean()
            elif adv_loss == 'bce':
                real_loss = adversarial_loss(disc_real, valid)
                fake_loss = adversarial_loss(disc_synth, fake) 
                fake_loss += adversarial_loss(disc_synth_true, fake)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            train_results["d_loss"].append(d_loss.item())

            if adv_loss == 'wgan-gp':
                # Compute gradient penalty
                alpha = torch.rand(real_imgs.size(0), 1, 1, 1).cuda().expand_as(real_imgs)
                interpolated = Variable(alpha * real_imgs.data + (1 - alpha) * gen_imgs.data, requires_grad=True)
                out = discriminator(interpolated)[1]

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

                # Backward + Optimize
                d_loss_reg = lambda_gp * d_loss_gp

                optimizer_D.zero_grad()
                d_loss_reg.backward()
                optimizer_D.step()

                train_results["d_reg"].append(d_loss_reg.item())

        
        print_str =  f'Epoch {epoch}'
        print_str += f' D_loss {np.mean(train_results["d_loss"]):.8f}'
        print_str += f' D_reg {np.mean(train_results["d_reg"]):.8f}'
        print_str += f' G_loss {np.mean(train_results["g_loss"]):8f}'
        print_str += f' h_acc {np.mean(train_results["real_acc"]):.6f}'
        print_str += f' s_pred_acc {np.mean(train_results["syn_acc"]):.6f}'
        print_str += f' s_true_acc {np.mean(train_results["syn_true_acc"]):.6f}'
        dual_iter.set_description(print_str)
        dual_iter.refresh()
        
        if i == batches_per_epoch and i > 0:
            break
        
    # Epoch is over. Save some stuff.
    save_image(synthethic_imgs.data[:16], f'{conf["save_loc"]}/images/synth_{epoch}.png', nrow=4, normalize=True)
    save_image(real_imgs.data[:16], f'{conf["save_loc"]}/images/real_{epoch}.png', nrow=4, normalize=True)
    save_image(gen_imgs.data[:16], f'{conf["save_loc"]}/images/pred_{epoch}.png', nrow=4, normalize=True)
    #save_image(gen_noise.data[:16], f"../results/gan/images/noise_{epoch}.png", nrow=4, normalize=True)
            

    # Save the dataframe to disk
    results["epoch"].append(epoch)
    results["d_loss"].append(np.mean(train_results["d_loss"]))
    results["d_loss_reg"].append(np.mean(train_results["d_reg"]))
    results["g_loss"].append(np.mean(train_results["g_loss"]))
    results["real_acc"].append(np.mean(train_results["real_acc"]))
    results["pred_synth_acc"].append(np.mean(train_results["syn_acc"]))
    results["true_synth_acc"].append(np.mean(train_results["syn_true_acc"]))
    
    df = pd.DataFrame.from_dict(results).reset_index()
    df.to_csv(f'{conf["save_loc"]}/training_log.csv', index=False)
    
    # Save the model
    state_dict = {
        'epoch': epoch,
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'generator_state_dict': generator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
    }
    torch.save(state_dict, f'{conf["save_loc"]}/best.pt')

Epoch 0 D_loss 0.32681587 D_reg 8.74229616 G_loss -0.340354 h_acc 0.385354 s_pred_acc 0.677894 s_true_acc 0.632610: 100%|██████████| 500/500 [17:02<00:00,  2.04s/it] 
Epoch 1 D_loss 0.07869994 D_reg 7.29737211 G_loss -0.130341 h_acc 0.268214 s_pred_acc 0.935254 s_true_acc 0.795160: 100%|██████████| 500/500 [17:49<00:00,  2.14s/it]
Epoch 2 D_loss -0.13861550 D_reg 5.16274157 G_loss -0.059109 h_acc 0.246881 s_pred_acc 0.994511 s_true_acc 0.972305: 100%|██████████| 500/500 [17:51<00:00,  2.14s/it]
Epoch 3 D_loss -0.27649639 D_reg 2.25340946 G_loss -0.015994 h_acc 0.261976 s_pred_acc 0.999875 s_true_acc 0.999875: 100%|██████████| 500/500 [17:50<00:00,  2.14s/it]
Epoch 4 D_loss -0.33729177 D_reg 1.89187824 G_loss -0.014369 h_acc 0.381861 s_pred_acc 0.999501 s_true_acc 0.999875: 100%|██████████| 500/500 [17:50<00:00,  2.14s/it]
Epoch 5 D_loss -0.41444625 D_reg 1.64227292 G_loss -0.031997 h_acc 0.561502 s_pred_acc 0.995509 s_true_acc 0.999875: 100%|██████████| 500/500 [17:50<00:00,  2.14s/it]

In [None]:
# import matplotlib.pyplot as plt, pandas as pd, numpy as np

# with open(f'{conf["save_loc"]}/training_log.csv', "r") as fid:
#     lines = np.array([[float(g) for g in f.strip("\n").split(",")] for f in fid.readlines()])
# batch_updates = range(len(lines[:, 0]))

# plt.plot(batch_updates, lines[:, 2])
# plt.plot(batch_updates, lines[:, 3])
# plt.plot(batch_updates, lines[:, 4])
# plt.legend(["d-loss", "d-reg-loss", "g-loss"])

### Other custom examples of architectures

In [None]:
from torch.optim.optimizer import Optimizer, required
from torch import Tensor
from torch.nn import Parameter

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False


    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self, in_dim, activation):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        
        self.query_conv = SpectralNorm(nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1))
        self.key_conv = SpectralNorm(nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1))
        self.value_conv = SpectralNorm(nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1))
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1) #
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)
        
        out = self.gamma*out + x
        return out, attention
    
# def weights_init_normal(m):
#     classname = m.__class__.__name__
#     if classname.find("Conv") != -1:
#         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
#     elif classname.find("BatchNorm2d") != -1:
#         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
#         torch.nn.init.constant_(m.bias.data, 0.0)
        
# def weights_init_normal(m):
#     if isinstance(m, nn.Linear):
#         torch.nn.init.xavier_normal_(m.weight)
#         m.bias.data.zero_()
#     elif isinstance(m, nn.Conv2d):
#         pass
#     elif isinstance(m, torch.nn.GRU or torch.nn.LSTM):
#         for name, param in m.named_parameters():
#             if 'bias' in name:
#                 torch.nn.init.constant_(param, 0.0)
#             elif 'weight_ih' in name:
#                 torch.nn.init.kaiming_normal_(param)
#             elif 'weight_hh' in name:
#                 torch.nn.init.orthogonal_(param)
#     elif isinstance(m, torch.nn.BatchNorm2d or torch.nn.BatchNorm1d):
#         m.weight.data.fill_(1)
#         m.bias.data.zero_()


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = tile_size // 4
        self.l1 = nn.Sequential(SpectralNorm(nn.Linear(latent_dim, 128 * self.init_size ** 2)))
        #self.attn1 = Self_Attn( 128 * 8, 'relu')

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            SpectralNorm(nn.Conv2d(128, 128, 3, stride=1, padding=1)),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            SpectralNorm(nn.Conv2d(128, 64, 3, stride=1, padding=1)),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            SpectralNorm(nn.Conv2d(64, 1, 3, stride=1, padding=1)),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
    
    
class ConditionalGenerator(nn.Module):
    def __init__(self):
        super(ConditionalGenerator, self).__init__()
        
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [SpectralNorm(nn.Conv2d(in_filters, out_filters, 3, 2, 1)), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.init_size = tile_size // 4
        self.l1 = nn.Sequential(SpectralNorm(nn.Linear(latent_dim, 128 * self.init_size ** 2)))
        #self.attn1 = Self_Attn( 128 * 8, 'relu')
    
        self.downsample = nn.Sequential(
            *discriminator_block(1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        
        # The height and width of downsampled image
        ds_size = tile_size // 2 ** 4
        self.down_layer = nn.Sequential(SpectralNorm(nn.Linear(128 * ds_size ** 2, latent_dim)))
        self.up_layer = nn.Sequential(SpectralNorm(nn.Linear(2 * latent_dim, 128 * self.init_size ** 2)))

        self.upsample = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            SpectralNorm(nn.Conv2d(128, 128, 3, stride=1, padding=1)),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            SpectralNorm(nn.Conv2d(128, 64, 3, stride=1, padding=1)),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            SpectralNorm(nn.Conv2d(64, 1, 3, stride=1, padding=1)),
            nn.Tanh(),
        )

    def forward(self, im, z):
        out = self.downsample(im)
        out = out.view(out.shape[0], -1)
        out = self.down_layer(out)
        out = torch.cat([out, z], 1)
        out = self.up_layer(out)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.upsample(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [SpectralNorm(nn.Conv2d(in_filters, out_filters, 3, 2, 1)), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = tile_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
    
    
    
# class Generator(nn.Module):
#     """Generator."""

#     def __init__(self, batch_size = 32, image_size=512, z_dim=128, conv_dim=64):
#         super(Generator, self).__init__()
#         self.imsize = image_size
#         layer1 = []
#         layer2 = []
#         layer3 = []
#         layer4 = []
#         layer5 = []
#         layer6 = []
#         layer7 = []
#         last = []

#         repeat_num = int(np.log2(self.imsize)) - 3
#         mult = 2 ** repeat_num # 8
#         layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
#         layer1.append(nn.BatchNorm2d(conv_dim * mult))
#         layer1.append(nn.LeakyReLU(0.2, inplace=True))
#         #layer1.append(nn.Dropout2d(0.25))

#         curr_dim = conv_dim * mult

#         layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
#         layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
#         layer2.append(nn.LeakyReLU(0.2, inplace=True))
#         #layer2.append(nn.Dropout2d(0.25))

#         curr_dim = int(curr_dim / 2)

#         layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
#         layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
#         layer3.append(nn.LeakyReLU(0.2, inplace=True))
#         #layer3.append(nn.Dropout2d(0.25))
        
#         curr_dim = int(curr_dim / 2)

#         layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
#         layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
#         layer4.append(nn.LeakyReLU(0.2, inplace=True))
#         #layer4.append(nn.Dropout2d(0.25))
        
#         curr_dim = int(curr_dim / 2)

#         layer5.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
#         layer5.append(nn.BatchNorm2d(int(curr_dim / 2)))
#         layer5.append(nn.LeakyReLU(0.2, inplace=True))
#         #layer5.append(nn.Dropout2d(0.25))
        
#         curr_dim = int(curr_dim / 2)

#         layer6.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
#         layer6.append(nn.BatchNorm2d(int(curr_dim / 2)))
#         layer6.append(nn.LeakyReLU(0.2, inplace=True))
#         #layer6.append(nn.Dropout2d(0.25))
        
#         curr_dim = int(curr_dim / 2)
        
#         layer7.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
#         layer7.append(nn.BatchNorm2d(int(curr_dim / 2)))
#         layer7.append(nn.LeakyReLU(0.2, inplace=True))
#         #layer7.append(nn.Dropout2d(0.25))
        
#         curr_dim = int(curr_dim / 2)

#         self.l1 = nn.Sequential(*layer1)
#         self.l2 = nn.Sequential(*layer2)
#         self.l3 = nn.Sequential(*layer3)
#         self.l4 = nn.Sequential(*layer4)
#         self.l5 = nn.Sequential(*layer5)
#         self.l6 = nn.Sequential(*layer6)
#         self.l7 = nn.Sequential(*layer7)

#         last.append(nn.ConvTranspose2d(curr_dim, 1, 4, 2, 1))
#         last.append(nn.Tanh())
#         self.last = nn.Sequential(*last)

#         self.attn1 = Self_Attn( 512 * 8, 'relu')
#         self.attn2 = Self_Attn( 256 * 8, 'relu')
#         self.attn3 = Self_Attn( 128 * 8, 'relu')
#         self.attn4 = Self_Attn( 64 * 8,  'relu')
#         #self.attn4 = Self_Attn( 32 * 8, 'relu')
#         #self.attn5 = Self_Attn( 16 * 8, 'relu')

#     def forward(self, z):
#         z = z.view(z.size(0), z.size(1), 1, 1)
#         out=self.l1(z)
#         #out, _ = self.attn1(out)
#         out=self.l2(out)
#         #out, _ = self.attn2(out)
#         out=self.l3(out)
#         #out, _ = self.attn3(out)
#         out=self.l4(out)
#         #out, _ = self.attn4(out)
#         out = self.l5(out)
#         #out,p5 = self.attn5(out)
#         out = self.l6(out)
#         out = self.l7(out)
#         out=self.last(out)
#         return out


# class Discriminator(nn.Module):
#     """Discriminator, Auxiliary Classifier."""

#     def __init__(self, color_dim = 1, batch_size=32, image_size=512, conv_dim=64):
#         super(Discriminator, self).__init__()
#         self.imsize = image_size
#         layer1 = []
#         layer2 = []
#         layer3 = []
#         layer4 = []
#         layer5 = []
#         layer6 = []
#         layer7 = []
#         last = []

#         layer1.append(SpectralNorm(nn.Conv2d(color_dim, conv_dim, 4, 2, 1)))
#         layer1.append(nn.LeakyReLU(0.1))
#         layer1.append(nn.Dropout2d(0.25))

#         curr_dim = conv_dim

#         layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
#         layer2.append(nn.LeakyReLU(0.1))
#         layer2.append(nn.Dropout2d(0.25))
#         curr_dim = curr_dim * 2

#         layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
#         layer3.append(nn.LeakyReLU(0.1))
#         layer3.append(nn.Dropout2d(0.25))
#         curr_dim = curr_dim * 2
        
#         layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
#         layer4.append(nn.LeakyReLU(0.1))
#         layer4.append(nn.Dropout2d(0.25))
#         curr_dim = curr_dim * 2
        
#         layer5.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
#         layer5.append(nn.LeakyReLU(0.1))
#         layer5.append(nn.Dropout2d(0.25))
#         curr_dim = curr_dim * 2
        
#         layer6.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
#         layer6.append(nn.LeakyReLU(0.1))
#         layer6.append(nn.Dropout2d(0.25))
#         curr_dim = curr_dim * 2
        
#         layer7.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
#         layer7.append(nn.LeakyReLU(0.1))
#         layer7.append(nn.Dropout2d(0.25))
#         curr_dim = curr_dim * 2
            
#         self.l1 = nn.Sequential(*layer1)
#         self.l2 = nn.Sequential(*layer2)
#         self.l3 = nn.Sequential(*layer3)
#         self.l4 = nn.Sequential(*layer4)
#         self.l5 = nn.Sequential(*layer5)
#         self.l6 = nn.Sequential(*layer6)
#         self.l7 = nn.Sequential(*layer7)

#         last.append(nn.Conv2d(curr_dim, 1, 4))
#         self.last = nn.Sequential(*last)

#         #self.attn1 = Self_Attn(128, 'relu')
#         #self.attn2 = Self_Attn(256, 'relu')
#         self.attn3 = Self_Attn(512, 'relu')
#         self.attn4 = Self_Attn(1024, 'relu')
#         self.attn5 = Self_Attn(2048, 'relu')
#         self.attn6 = Self_Attn(4096, 'relu')

#     def forward(self, x):
#         out = self.l1(x)
#         out = self.l2(out)
#         #out, p1 = self.attn1(out)
#         out = self.l3(out)
#         #out, p2 = self.attn2(out)
#         out = self.l4(out)
#         #out, _ = self.attn3(out)
#         out = self.l5(out)
#         #out, _ = self.attn4(out)
#         out = self.l6(out)
#         #out, _ = self.attn5(out)
#         out = self.l7(out)
#         #out, _ = self.attn6(out)
#         out = self.last(out)

#         return out.squeeze()
    
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag