In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm

import numpy as np

import matplotlib.pyplot as plt

import tqdm
import cv2


In [None]:
# Fetch the Synchronized-BatchNorm-PyTorch repository and move its
# `sync_batchnorm` package into the working directory so it can be imported.
# SynchronizedBatchNorm2d is the parameter-free normalisation used by SPADE below.
# NOTE(review): this clones the default branch unpinned — pin a commit for
# reproducibility; re-running the cell fails once the directory already exists.
!git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
!mv Synchronized-BatchNorm-PyTorch/sync_batchnorm .
from sync_batchnorm import SynchronizedBatchNorm2d

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

# Base model:

In [None]:
from torch.nn import init


class BaseNetwork(nn.Module):
    """Common base for the SPADE networks: parameter reporting and weight init."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for network-specific command-line options; this base adds none."""
        return parser

    def print_network(self):
        """Print the class name and the total parameter count in millions."""
        # `self` is always an nn.Module here, so the old list check was dead code.
        num_params = sum(param.numel() for param in self.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        """Re-initialise every BatchNorm2d, Conv* and Linear submodule in place.

        BatchNorm2d: weight ~ N(1, gain), bias = 0.
        Conv*/Linear: weight according to `init_type`, bias = 0:
          'normal'         N(0, gain)
          'xavier'         Xavier normal with `gain`
          'xavier_uniform' Xavier uniform with gain fixed at 1.0 (`gain` ignored)
          'kaiming'        Kaiming normal, fan_in mode (`gain` ignored)
          'orthogonal'     orthogonal with `gain`
          'none'           the layer's own reset_parameters()

        For spectrally-normalised layers the underlying `weight_orig`
        parameter is initialised; writing to the derived `weight` tensor would
        be discarded when it is recomputed on the next forward pass.

        Raises:
            NotImplementedError: if `init_type` is not one of the names above.
        """
        def init_func(m):
            classname = m.__class__.__name__
            if 'BatchNorm2d' in classname:
                if getattr(m, 'weight', None) is not None:
                    init.normal_(m.weight, 1.0, gain)
                if getattr(m, 'bias', None) is not None:
                    init.constant_(m.bias, 0.0)
            elif hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
                weight = getattr(m, 'weight_orig', m.weight)
                if init_type == 'normal':
                    init.normal_(weight, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(weight, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(weight, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(weight, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(weight, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if getattr(m, 'bias', None) is not None:
                    init.constant_(m.bias, 0.0)

        # `apply` already visits every descendant module, so children must not
        # be re-initialised a second time through their own init_weights.
        self.apply(init_func)

# Generator:

In [None]:
class SPADE(nn.Module):
    """Spatially-adaptive (de)normalisation.

    Normalises `x` with a parameter-free batch norm, then modulates it with a
    per-pixel scale (`gamma`) and shift (`beta`) predicted from `segmap`.

    Args:
        norm_nc: number of channels of the activation `x`.
        label_nc: number of channels of the conditioning map `segmap`.
        ks: kernel size of the convs that predict gamma / beta.
        nhidden: width of the shared embedding conv (default 128, the value
            that used to be hard-coded).
    """

    def __init__(self, norm_nc, label_nc, ks=3, nhidden=128):
        super().__init__()

        self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)

        pw = ks // 2  # keeps the spatial size for odd kernel sizes
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        # Part 1: parameter-free normalisation of the activations.
        normalized = self.param_free_norm(x)

        # Part 2: predict per-pixel scale and bias from the semantic map,
        # resized (nearest neighbour) to x's spatial size.
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # (1 + gamma): a zero gamma leaves the normalised activation unscaled.
        return normalized * (1 + gamma) + beta

# semantic_nc = 3

# spade = SPADE(3, semantic_nc)
# seg = torch.rand(1,semantic_nc,256,256)
# img = torch.rand(1,3,64,64)

# spade(img, seg).shape

In [None]:
class SPADEResnetBlock(nn.Module):
    """Residual block (SPADE -> LeakyReLU -> conv, twice) conditioned on a semantic map.

    Args:
        fin: input channels.
        fout: output channels. When it differs from `fin` the shortcut is a
            learned SPADE + 1x1 conv instead of the identity.
        semantic_nc: channels of the `seg` map passed to forward (default 125,
            the value that used to be hard-coded).
    """

    def __init__(self, fin, fout, semantic_nc=125):
        super().__init__()
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # Spectral normalisation is always applied to every conv of the block.
        self.conv_0 = spectral_norm(self.conv_0)
        self.conv_1 = spectral_norm(self.conv_1)
        if self.learned_shortcut:
            self.conv_s = spectral_norm(self.conv_s)

        # SPADE normalisation layers, all conditioned on the same semantic map.
        self.norm_0 = SPADE(fin, semantic_nc)
        self.norm_1 = SPADE(fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, semantic_nc)

    def forward(self, x, seg):
        """Return shortcut(x) + residual(x), both conditioned on `seg`."""
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))

        return x_s + dx

    def shortcut(self, x, seg):
        """Identity when fin == fout, otherwise SPADE + 1x1 conv to match channels."""
        if self.learned_shortcut:
            return self.conv_s(self.norm_s(x, seg))
        return x

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)


# spaderesnetblock = SPADEResnetBlock(3, 16)

# seg = torch.rand(1,3,256,256)
# img = torch.rand(1,3,64,64)

# spaderesnetblock(img, seg).shape

In [None]:
class SPADEGenerator(BaseNetwork):
    """SPADE generator producing a 3-channel image in [-1, 1] from a semantic map.

    A 256-d latent vector is projected by `fc` to a (16*nf, sh, sw) tensor with
    sh = sw = 256 // 2**num_up_layers. It is then upsampled `num_up_layers`
    (= 7) times through SPADE residual blocks, each conditioned on the
    semantic map, reaching 256 x 256.
    """

    def __init__(self):
        super(SPADEGenerator, self).__init__()
        self.num_up_layers = 7
        nf = 64
        self.nf = nf
        self.z_dim = 256

        self.sw, self.sh = self.compute_latent_vector_size(self.num_up_layers)

        # Projects the latent z (VAE-encoded or randomly sampled) to the
        # starting low-resolution feature map.
        self.fc = nn.Linear(self.z_dim, 16 * nf * self.sw * self.sh)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf)

        final_nc = nf

        # The 7-upsampling configuration needs one extra block (and a seventh
        # upsample in forward) to reach full resolution.
        if self.num_up_layers == 7:
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, num_up_layers):
        """Start-map side length so that `num_up_layers` x2 upsamplings reach 256."""
        sw = 256 // (2 ** num_up_layers)
        sh = sw  # square output
        return sw, sh

    def forward(self, input, z=None):
        """Generate an image from semantic map `input` and optional latent `z`.

        When `z` is None it is drawn from a unit normal on the same device as
        `input` (previously it was always drawn on the CPU, which failed in
        `self.fc` when the model ran on CUDA).
        """
        seg = input

        if z is None:
            z = torch.randn(input.size(0), self.z_dim,
                            dtype=torch.float32, device=input.device)

        x = self.fc(z)
        x = x.view(-1, 16 * self.nf, self.sh, self.sw)

        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.G_middle_0(x, seg)

        if self.num_up_layers > 5:
            x = self.up(x)

        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        x = self.up(x)
        x = self.up_3(x, seg)

        if self.num_up_layers == 7:
            x = self.up(x)
            x = self.up_4(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # torch.tanh replaces the deprecated F.tanh (identical result).
        return torch.tanh(x)

# spadegenerator = SPADEGenerator()

In [None]:
# seg = torch.rand(1,3,256,256)
# generated = spadegenerator(seg)
# plt.imshow(generated[0].detach().permute(1,2,0))

# Encoder:

In [None]:
class ConvEncoder(BaseNetwork):
    """VAE image encoder predicting (mu, logvar) of a 256-d latent.

    Six stride-2 3x3 convs, each followed by instance norm, reduce a 256x256
    RGB input to a (512, 4, 4) map that two linear heads read from. Inputs of
    any other size are first resized bilinearly to 256x256.
    """

    def __init__(self,):
        super(ConvEncoder, self).__init__()

        kernel = 3
        pad = int(np.ceil((kernel - 1.0) / 2))
        ndf = 64
        # Channel width entering / leaving each of the six conv stages.
        widths = [3, ndf, ndf * 2, ndf * 4, ndf * 8, ndf * 8, ndf * 8]

        # Register all convs first, then all norms: the registration order
        # fixes parameter order, which saved optimizer states depend on.
        for stage in range(1, 7):
            setattr(self, 'conv%d' % stage,
                    nn.Conv2d(widths[stage - 1], widths[stage], kernel, stride=2, padding=pad))
        for stage in range(1, 7):
            setattr(self, 'bn%d' % stage, nn.InstanceNorm2d(widths[stage]))

        self.so = s0 = 4
        flat_features = ndf * 8 * s0 * s0
        self.fc_mu = nn.Linear(flat_features, 256)
        self.fc_var = nn.Linear(flat_features, 256)

        self.actvn = nn.LeakyReLU(0.2, False)

    def forward(self, x):
        if x.size(2) != 256 or x.size(3) != 256:
            x = F.interpolate(x, size=(256, 256), mode='bilinear')

        for stage in range(1, 7):
            # The first conv sees the raw image; later ones see activated features.
            if stage > 1:
                x = self.actvn(x)
            x = getattr(self, 'bn%d' % stage)(getattr(self, 'conv%d' % stage)(x))
        x = self.actvn(x)

        flat = x.view(x.size(0), -1)
        return self.fc_mu(flat), self.fc_var(flat)

# encoder = ConvEncoder()
# seg = torch.rand(1,3,256,256)
# mu, logvar = encoder(seg)

# Discriminator:

In [None]:
class NLayerDiscriminator(BaseNetwork):
    """Patch discriminator that returns the output of every layer group.

    Args:
        input_nc: input channels (default 6: the caller concatenates a
            3-channel segmentation with a 3-channel image).
        ndf: width of the first conv; doubled per layer and capped at 512.
        n_layers: conv groups before the final 1-channel conv. All use
            stride 2 except the last of them, which uses stride 1.

    `forward` returns a list of the intermediate feature maps. Its last entry
    is the per-patch real/fake score map.
    """

    def __init__(self, input_nc=6, ndf=64, n_layers=4):
        super(NLayerDiscriminator, self).__init__()

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        nf = ndf

        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, False)]]

        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            stride = 1 if n == n_layers - 1 else 2
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw),
                nn.InstanceNorm2d(nf, affine=False),
                nn.LeakyReLU(0.2, False)
            ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        # Register each group separately so forward can return intermediate features.
        for n, layers in enumerate(sequence):
            self.add_module('model' + str(n), nn.Sequential(*layers))

    def forward(self, input):
        results = []
        x = input
        for submodel in self.children():
            x = submodel(x)
            results.append(x)
        return results

In [None]:
# discriminator = NLayerDiscriminator()
# seg = torch.rand(1,6,256,256)
# res = discriminator(seg)

# Loss functions:

# Training:

In [None]:
# VGG19 feature extractor, used for the perceptual loss (pretrained VGG network)
import torchvision
class VGG19(torch.nn.Module):
    """Pretrained torchvision VGG19 feature trunk split into five slices.

    The slices cover layers [0, 2), [2, 7), [7, 12), [12, 21) and [21, 30) of
    `vgg19(pretrained=True).features`. `forward` returns the output of each
    slice in order. Parameters are frozen unless `requires_grad` is True.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        features = torchvision.models.vgg19(pretrained=True).features
        bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for slice_no in range(1, 6):
            setattr(self, 'slice%d' % slice_no, torch.nn.Sequential())
        for slice_no, (start, stop) in enumerate(bounds, start=1):
            target = getattr(self, 'slice%d' % slice_no)
            for layer_idx in range(start, stop):
                target.add_module(str(layer_idx), features[layer_idx])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        activations = []
        h = X
        for slice_no in range(1, 6):
            h = getattr(self, 'slice%d' % slice_no)(h)
            activations.append(h)
        return activations

# vgg19model = VGG19()

In [None]:
# /home/thesun/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth

In [None]:
class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 between VGG19 features of `x` and `y`.

    Deeper feature maps get larger weights (1/32, 1/16, 1/8, 1/4, 1).
    Features of `y` are detached, so gradients flow only through `x`.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().to(device)
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        # One VGG pass over the stacked batch, then split each map back in two.
        features = self.vgg(torch.cat((x, y), dim=0))
        loss = 0
        for level, feat in enumerate(features):
            half = feat.size(0) // 2
            x_part, y_part = feat[:half], feat[half:]
            loss += self.weights[level] * self.criterion(x_part, y_part.detach())
        return loss

# criterionVGG = VGGLoss()

# loss = criterionVGG(fake_image, real_image)
# print(loss)

---

In [None]:
class SpaceGANModel:
    """SPADE-VAE training logic: holds the three networks and computes their losses.

    Args:
        encoder: image encoder returning (mu, logvar) of the latent.
        spadegenerator: generator called as ``generator(input_semantics, z=z)``.
        discriminator: returns a list of intermediate feature maps whose last
            entry is the real/fake score map.
        use_vggloss: add the VGG perceptual loss to the generator objective.
    """

    def __init__(
        self,
        encoder,
        spadegenerator,
        discriminator,
        use_vggloss=True,
    ):
        self.use_vggloss = use_vggloss
        self.lambda_kld = 0.1
        self.lambda_feat = 10
        self.lambda_vgg = 10

        self.encoder = encoder.to(device)
        self.generator = spadegenerator.to(device)
        self.discriminator = discriminator.to(device)

        self.featcriterion = torch.nn.L1Loss()
        # Only build (and download the weights of) VGG19 when the perceptual
        # loss is actually used.
        self.VGGCriterion = VGGLoss() if use_vggloss else None

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(logvar / 2) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std) + mu

    def KLDlossfunction(self, mu, logvar):
        """KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements."""
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    def criterionGAN(self, x, target_is_real, for_discriminator=True):
        """Hinge GAN loss on the score map `x`.

        Discriminator: mean(relu(1 - x)) for real targets, mean(relu(1 + x))
        for fake targets. Generator: -mean(x); only valid with target_is_real.
        """
        if for_discriminator:
            # zeros_like already matches x's device and dtype; no transfer to
            # the global `device` is needed.
            zero = torch.zeros_like(x)
            if target_is_real:
                return -torch.mean(torch.min(x - 1, zero))
            return -torch.mean(torch.min(-x - 1, zero))
        assert target_is_real, "The generator's hinge loss must be aiming for real"
        return -torch.mean(x)

    def FeatCriterion(self, pred_fake, pred_real):
        """Feature-matching loss.

        Sums the L1 distance between fake features and detached real features
        over all discriminator layers except the final score map.
        """
        loss = 0
        for fake_feat, real_feat in zip(pred_fake[:-1], pred_real[:-1]):
            loss += self.featcriterion(fake_feat, real_feat.detach())
        return loss

    def discriminate(self, input_semantics, fake_image, real_image):
        """Run D once on [fake; real] (each conditioned on `input_semantics`).

        Returns:
            (pred_fake, pred_real): per-layer outputs split back into the
            fake half and the real half of the batch.
        """
        fake_concat = torch.cat([input_semantics, fake_image], dim=1)
        real_concat = torch.cat([input_semantics, real_image], dim=1)

        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.discriminator(fake_and_real)

        pred_fake = []
        pred_real = []
        for p in discriminator_out:
            pred_fake.append(p[:p.size(0) // 2])
            pred_real.append(p[p.size(0) // 2:])

        return pred_fake, pred_real

    def generator_training_step(self, segmentation, input_semantics, real_image):
        """Compute the encoder + generator losses for one batch.

        The generator is conditioned on `input_semantics`. The discriminator
        is conditioned on `segmentation`.

        Returns:
            (losses, fake_image): `losses` maps names to weighted scalar tensors.
        """
        G_losses_dict = {}

        mu, logvar = self.encoder(real_image)
        z = self.reparameterize(mu, logvar)

        G_losses_dict['KLD_loss'] = self.KLDlossfunction(mu, logvar) * self.lambda_kld

        fake_image = self.generator(input_semantics, z=z)

        pred_fake, pred_real = self.discriminate(segmentation, fake_image, real_image)

        G_losses_dict['GAN_loss'] = self.criterionGAN(
            pred_fake[-1], True, for_discriminator=False
        )

        FEAT_loss = self.FeatCriterion(pred_fake, pred_real)
        G_losses_dict['FEAT_loss'] = FEAT_loss * self.lambda_feat

        if self.use_vggloss:
            VGG_loss = self.VGGCriterion(fake_image, real_image)
            G_losses_dict['VGG_loss'] = VGG_loss * self.lambda_vgg

        return G_losses_dict, fake_image

    def discriminator_training_step(self, segmentation, input_semantics, real_image):
        """Compute the hinge losses of D on a detached fake batch and the real batch."""
        D_losses = {}
        with torch.no_grad():
            mu, logvar = self.encoder(real_image)
            z = self.reparameterize(mu, logvar)
            fake_image = self.generator(input_semantics, z=z)
            fake_image = fake_image.detach()
            fake_image.requires_grad_()

        pred_fake, pred_real = self.discriminate(segmentation, fake_image, real_image)

        D_losses['D_Fake'] = self.criterionGAN(pred_fake[-1], False,
                                               for_discriminator=True)
        D_losses['D_real'] = self.criterionGAN(pred_real[-1], True,
                                               for_discriminator=True)

        return D_losses

In [None]:
# train_master = SpaceGANModel(
#     encoder=encoder,
#     spadegenerator=spadegenerator,
#     discriminator=discriminator
# )

In [None]:
# G_losses_dict, fake_image = train_master.generator_training_step(
#     input_semantics=torch.rand(1,3,256,256),
#     real_image=torch.rand(1,3,256,256)
# )

In [None]:
# D_losses_dict = train_master.discriminator_training_step(
#     input_semantics=torch.rand(1,3,256,256),
#     real_image=torch.rand(1,3,256,256)
# )

In [None]:
class SpaceGANModelTrainer:
    """Runs alternating generator / discriminator Adam updates for a SpaceGANModel.

    Args:
        SpaceGanModel: model exposing `generator`, `encoder`, `discriminator`
            modules and `generator_training_step` / `discriminator_training_step`.
        G_lr: learning rate for generator + encoder (default 1e-4).
        D_lr: learning rate for the discriminator (default 4e-4).
        beta1, beta2: Adam betas shared by both optimizers (default 0.0, 0.9).

    Each step's losses are appended as dicts of floats to `g_losses` /
    `d_losses`. The latest ones are also kept in `last_g_losses` /
    `last_d_losses`, and the latest generated batch and mask are kept
    (on the CPU) in `last_generated` / `last_mask`.
    """

    def __init__(self, SpaceGanModel, G_lr=0.0001, D_lr=0.0004, beta1=0.0, beta2=0.9):
        self.SpaceGanModel = SpaceGanModel
        self.G_lr = G_lr
        self.D_lr = D_lr
        self.betas = (beta1, beta2)

        self.g_losses = []
        self.d_losses = []
        self.last_g_losses = None
        self.last_d_losses = None
        # Initialised here so the attribute exists before the first generator step.
        self.last_generated = None
        self.last_mask = None

        self.optimizer_G, self.optimizer_D = self.create_optimizers()

    def create_optimizers(self):
        """Build Adam for generator + encoder and Adam for the discriminator."""
        model = self.SpaceGanModel
        # Keep generator-then-encoder order: saved optimizer state dicts refer
        # to parameters by their position in this list.
        G_params = list(model.generator.parameters()) + list(model.encoder.parameters())
        D_params = list(model.discriminator.parameters())

        optimizer_G = torch.optim.Adam(G_params, lr=self.G_lr, betas=self.betas)
        optimizer_D = torch.optim.Adam(D_params, lr=self.D_lr, betas=self.betas)

        return optimizer_G, optimizer_D

    def run_generator_one_step(self, segmentation, input_semantics, real_image):
        """One optimisation step of generator + encoder on the summed G losses."""
        self.optimizer_G.zero_grad()
        g_losses, generated = self.SpaceGanModel.generator_training_step(segmentation, input_semantics, real_image)
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        g_losses = {k: v.item() for k, v in g_losses.items()}
        self.g_losses.append(g_losses)
        self.last_g_losses = g_losses
        self.last_generated = generated.detach().cpu()
        self.last_mask = input_semantics.detach().cpu()

    def run_discriminator_one_step(self, segmentation, input_semantics, real_image):
        """One optimisation step of the discriminator on the summed D losses."""
        self.optimizer_D.zero_grad()
        d_losses = self.SpaceGanModel.discriminator_training_step(segmentation, input_semantics, real_image)
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        d_losses = {k: v.item() for k, v in d_losses.items()}
        self.last_d_losses = d_losses
        self.d_losses.append(d_losses)
        
    

# Data:

In [None]:
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms as T
import os

In [None]:
# Photometric augmentation: with probability 0.6 apply a random colour jitter
# (brightness 0.5, contrast 0.5, saturation 0.5, hue 0.4). The former
# RandomChoice wrapper around this single option added nothing.
COLOR_JITTER = T.ColorJitter(0.5, 0.5, 0.5, 0.4)
color_transform = T.RandomApply([COLOR_JITTER], p=0.6)

# Geometric augmentation: random flips, then a random square crop covering
# 8-100% of the image area, resized to 256x256.
primary_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomResizedCrop(size=256, ratio=(1, 1), scale=(0.08, 1)),
])

In [None]:
def get_segmentation(img, K, blur_kernel=(50,50)):
    """Colour-quantise a box-blurred copy of `img` into K flat colour regions.

    Pixels of the blurred image are clustered with k-means (k-means++ init,
    3 attempts, stopping after 10 iterations or eps 1.0). Every pixel is then
    replaced by its cluster centre.

    Args:
        img: image array whose pixels have 3 channels (reshaped to (-1, 3)).
        K: number of colour clusters.
        blur_kernel: box-blur kernel size passed to cv2.blur.

    Returns:
        uint8 array with the same shape as `img`.
    """
    pixels = np.float32(cv2.blur(img, blur_kernel).reshape((-1, 3)))
    stop_criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(pixels, K, None, stop_criteria, 3, cv2.KMEANS_PP_CENTERS)
    palette = np.uint8(centers)
    return palette[labels.flatten()].reshape((img.shape))

In [None]:
class SpaceImageSegDataset(Dataset):
    """In-memory dataset of space images with on-the-fly k-means segmentations.

    Every colour image larger than 400x400 in `path` is loaded at construction.
    `__getitem__` augments one image and derives a 7-colour k-means
    segmentation from it. It quantises the segmentation into N**3 colour bins
    (N = 5) and returns `(seg, sem, img)`, each with a leading dim of 1:
      seg: (1, 3, H, W) quantised segmentation colours scaled by /127.5 - 1
      sem: (1, 125, H, W) one-hot semantic map
      img: (1, 3, H, W) augmented image scaled to [-1, 1]
    """

    def __init__(self, path):
        super(SpaceImageSegDataset, self).__init__()
        loaded_images = []
        for img_path in tqdm.tqdm(sorted(os.listdir(path))):
            img = plt.imread(os.path.join(path, img_path))
            # The pipeline needs at least 3 colour channels; skip grayscale images.
            if img.ndim != 3 or img.shape[2] < 3:
                continue
            if img.shape[0] > 400 and img.shape[1] > 400:
                loaded_images.append(self._to_rgb_uint8(img))

        self.images = loaded_images
        self.path = path

    @staticmethod
    def _to_rgb_uint8(img):
        """Drop any alpha channel and bring float images to uint8 0..255.

        plt.imread returns PNGs as floats in [0, 1], but the code below
        expects 0..255 values (k-means via np.uint8, then / 127.5 - 1).
        """
        img = np.ascontiguousarray(img[..., :3])
        if np.issubdtype(img.dtype, np.floating):
            img = np.clip(np.round(img * 255.0), 0, 255).astype(np.uint8)
        return img

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = torch.tensor(self.images[idx]).permute(2, 0, 1)  # HWC -> CHW for torchvision

        img = primary_transform(img)
        img = color_transform(img).permute(1, 2, 0)  # back to HWC for OpenCV

        seg = torch.tensor(get_segmentation(np.array(img), 7, blur_kernel=(20, 20)))

        # Quantise each channel to N levels {0, 51, ..., 204} and encode the
        # colour as a single class id in [0, N**3).
        N = 5
        seg = torch.floor(torch.floor((seg * (N / 256))) * (255 / N)).int()
        qnt = ((seg[..., 0] + seg[..., 1] * N + seg[..., 2] * N * N) * N // 255).long()
        sem = torch.nn.functional.one_hot(qnt, N ** 3).permute(2, 0, 1).float()

        img = img.permute(2, 0, 1) / 127.5 - 1
        seg = seg.permute(2, 0, 1) / 127.5 - 1

        return seg.unsqueeze(0), sem.unsqueeze(0), img.unsqueeze(0)

In [None]:
def collate_fn(batch):
    """Concatenate per-sample (seg, sem, img) tensors along their leading dim.

    Each dataset item already carries a batch dimension of 1, so the items
    are joined with torch.cat instead of being stacked.
    """
    columns = ([], [], [])
    for seg, sem, img in batch:
        for column, tensor in zip(columns, (seg, sem, img)):
            column.append(tensor)
    return tuple(torch.cat(column, dim=0) for column in columns)

In [None]:
# Training images (the dataset keeps only those larger than 400x400).
DATA_DIR = '../input/spacegan/clean_images/clean_images'
dataset = SpaceImageSegDataset(DATA_DIR)

# batch_size=1: every sample already has its own leading batch dim, which
# collate_fn concatenates along.
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True,
                        collate_fn=collate_fn, prefetch_factor=2, num_workers=2)

In [None]:
num_images = len(dataset)
print(num_images)

In [None]:
# for i in range(125):
#     plt.title(i)
#     plt.imshow(sem[0][i].cpu(), cmap='gray')
#     plt.show()

# Training:

In [None]:
# Build encoder, generator and discriminator (in that order) and wrap them
# with the SPADE-VAE loss logic.
GANModel = SpaceGANModel(
    ConvEncoder(),
    SPADEGenerator(),
    NLayerDiscriminator(),
    use_vggloss=True,
)

In [None]:
trainer_master = SpaceGANModelTrainer(SpaceGanModel=GANModel)

In [None]:
# Resume from an earlier run. map_location lets a checkpoint saved on a GPU
# load on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
old_chkpt = torch.load('../input/spacegan-cnt/spacegan4.pt', map_location=device)

In [None]:
# Restore every network and optimizer from its entry in the checkpoint dict.
for target, key in (
    (trainer_master.SpaceGanModel.encoder, 'encoder4.pt'),
    (trainer_master.SpaceGanModel.generator, 'generator4.pt'),
    (trainer_master.SpaceGanModel.discriminator, 'discriminator4.pt'),
    (trainer_master.optimizer_D, 'optimizer_D4.pt'),
    (trainer_master.optimizer_G, 'optimizer_G4.pt'),
):
    target.load_state_dict(old_chkpt[key])

In [None]:
def moving_average(L, ss=200):
    """Moving average of the sequence `L` over every full window.

    The window is `ss` values, capped at half the sequence length (and at
    least 1) so short histories still produce a curve. Previously `ss` was
    ignored (the window was always len(L) // 2) and the last window was
    dropped.

    Args:
        L: sequence of numbers.
        ss: desired window size.

    Returns:
        list of len(L) - window + 1 floats; empty when `L` is empty.
    """
    if len(L) == 0:
        return []
    window = max(1, min(ss, len(L) // 2))
    # Prefix sums give every window mean in O(n) instead of O(n * window).
    csum = np.cumsum(np.concatenate(([0.0], np.asarray(L, dtype=float))))
    return ((csum[window:] - csum[:-window]) / window).tolist()

def plot_losses(all_losses, ss=200):
    """Plot the moving average of every loss component in one row of axes.

    Args:
        all_losses: list of dicts mapping loss name -> float, one per step;
            the keys of the first dict select the components.
        ss: moving-average window passed to `moving_average`.

    Nothing is drawn when `all_losses` is empty.
    """
    if not all_losses:
        return
    keys = list(all_losses[0].keys())
    clean_losses = {k: moving_average([step[k] for step in all_losses], ss) for k in keys}

    height = 3 if len(keys) == 4 else 5
    # squeeze=False keeps `axs` 2-D, so a single loss key no longer crashes.
    fig, axs = plt.subplots(1, len(keys), figsize=(10, height), squeeze=False)
    for ax, k in zip(axs[0], keys):
        ax.plot(clean_losses[k])
        ax.set_title(k)
    # Show now so plots requested inside a long training loop appear immediately.
    plt.show()
        axs[i].set_title(k)

In [None]:
plot_losses(all_losses=trainer_master.d_losses, ss=5000)

In [None]:
# Resumes training at START_EPOCH (after the checkpoint loaded above).
EPOCHS = 300
START_EPOCH = 200
plot_per = 1        # show a generated sample every `plot_per` epochs
generator_per = 1   # generator step every `generator_per` batches; D steps every batch
loss_plot_per = 10  # plot smoothed loss curves every `loss_plot_per` epochs

for epoch in range(START_EPOCH, EPOCHS + 1):
    for i, sample in enumerate(tqdm.tqdm(dataloader)):
        segmentation_image, input_semantics, real_image = (t.to(device) for t in sample)

        if i % generator_per == 0:
            trainer_master.run_generator_one_step(segmentation_image, input_semantics, real_image)

        trainer_master.run_discriminator_one_step(segmentation_image, input_semantics, real_image)

    # This used to test `0 % plot_per`, which is always true. It now honours
    # `plot_per`; the value 1 keeps the old every-epoch behaviour.
    if epoch % plot_per == 0:
        fig, axs = plt.subplots(1, 3, figsize=(15, 15))
        axs[0].imshow(((trainer_master.last_generated[0].permute(1, 2, 0).detach().cpu() + 1) / 2))
        axs[0].set_title(f'Generated epoch:{epoch}')
        axs[1].imshow(((segmentation_image[0].permute(1, 2, 0).detach().cpu() + 1) / 2))
        axs[1].set_title(f'Segmentation epoch:{epoch}')
        axs[2].imshow(((real_image[0].permute(1, 2, 0).detach().cpu() + 1) / 2))
        axs[2].set_title(f'Real Image epoch:{epoch}')
        plt.show()

    if epoch % loss_plot_per == 0:
        plot_losses(trainer_master.d_losses, 500)
        plot_losses(trainer_master.g_losses, 500)
    
    
        

In [None]:
encoder_params = trainer_master.SpaceGanModel.encoder.parameters()
for param in encoder_params:
    print(param.shape)

In [None]:
# Save each network / optimizer state dict to its own file tagged with run number n.
# NOTE(review): the resume cell loads a single dict keyed 'encoder4.pt', ...;
# these separate files must be packed that way before they can be reloaded there.
n = 5
checkpoint_parts = {
    'encoder': trainer_master.SpaceGanModel.encoder,
    'generator': trainer_master.SpaceGanModel.generator,
    'discriminator': trainer_master.SpaceGanModel.discriminator,
    'optimizer_D': trainer_master.optimizer_D,
    'optimizer_G': trainer_master.optimizer_G,
}
for part_name, part in checkpoint_parts.items():
    torch.save(part.state_dict(), f'{part_name}{n}.pt')

In [None]:
# Show just the optimizer hyper-parameters; the full state_dict also prints
# every Adam moment tensor and floods the output.
trainer_master.optimizer_D.state_dict()['param_groups']

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 15))
panels = (
    (trainer_master.last_generated[0], 'Generated'),
    (segmentation_image[0], 'Segmentation'),
    (real_image[0], 'Real Image'),
)
for ax, (picture, title) in zip(axs, panels):
    # Tensors are CHW in [-1, 1]; imshow wants HWC in [0, 1].
    ax.imshow((picture.permute(1, 2, 0).detach().cpu() + 1) / 2)
    ax.set_title(title)
plt.show()

In [None]:
# One figure per one-hot semantic channel of the first sample; the channel
# count comes from the tensor instead of a hard-coded 125.
for i in range(input_semantics.shape[1]):
    plt.title(i)
    plt.imshow(input_semantics[0][i].cpu(), cmap='gray')
    plt.show()

In [None]:
input_semantics[0, 2, 0]

In [None]:
# (Removed a stray `plt` expression that only echoed the matplotlib module repr.)

In [None]:
class SpaceImageSegDataset(Dataset):
    """Dataset of pre-computed (image crop, segmentation mask) pairs on disk.

    NOTE(review): this re-uses the name of the in-memory k-means dataset
    defined earlier in the notebook and silently shadows it. Re-running the
    earlier DataLoader cell after this one would pick up this class, whose
    4-tuple items do not fit `collate_fn`. Consider renaming.

    `path` holds files named `<prefix>crop.jpg` / `<prefix>mask.jpg`.
    Each item is `(qnt, seg, sem, img)`:
      qnt: (H, W) long colour-bin ids in [0, N**3), N = 5
      seg: (1, 3, H, W) quantised mask colours scaled by /127.5 - 1
      sem: (1, N**3, H, W) one-hot semantic map
      img: (1, 3, H, W) image scaled by /127.5 - 1
    """

    def __init__(self, path):
        super(SpaceImageSegDataset, self).__init__()
        self.images = sorted(os.listdir(path))
        self.path = path
        # Pair files by prefix instead of assuming the sorted listing strictly
        # alternates crop/mask (one stray file used to misalign every pair).
        suffix = 'crop.jpg'
        self.prefixes = [name[:-len(suffix)] for name in self.images if name.endswith(suffix)]

    def __len__(self):
        return len(self.prefixes)

    def __getitem__(self, idx):
        prefix = self.prefixes[idx]
        img = torch.tensor(plt.imread(os.path.join(self.path, prefix + 'crop.jpg')))
        seg = torch.tensor(plt.imread(os.path.join(self.path, prefix + 'mask.jpg')))

        img = img.permute(2, 0, 1)
        seg = seg.permute(2, 0, 1)
        W = img.shape[2]

        # Jitter image and mask side by side so both get the same colour transform.
        # NOTE(review): this also shifts the mask colours and therefore the
        # class ids below — confirm that is intended.
        img_seg = color_transform(torch.cat([img.to(seg.dtype), seg], dim=2))

        N = 5
        seg = img_seg[..., W:].permute(1, 2, 0)  # HWC for per-channel quantisation
        # Quantise each channel to N levels {0, 51, ..., 204} and encode the
        # colour as one class id in [0, N**3).
        seg = torch.floor(torch.floor((seg * (N / 256))) * (255 / N)).int()
        qnt = ((seg[..., 0] + seg[..., 1] * N + seg[..., 2] * N * N) * N // 255).long()
        sem = torch.nn.functional.one_hot(qnt, N ** 3).permute(2, 0, 1).float()

        img = img_seg[..., :W] / 127.5 - 1
        # Channel-first like `img` (it used to be returned HWC).
        seg = seg.permute(2, 0, 1) / 127.5 - 1

        return qnt, seg.unsqueeze(0), sem.unsqueeze(0), img.unsqueeze(0)


dataset = SpaceImageSegDataset('../input/spacegan/crops_masks')

In [None]:
pair_1 = dataset[1]
qnt, seg, sem, img = pair_1

In [None]:
# Sanity check of the colour quantisation on the brightest value (255).
N = 5
x = torch.tensor(255)
level = torch.floor(x * (N / 256))
print(level)
torch.floor(level * (256 / N)).int()

In [None]:
torch.unique(qnt)

In [None]:
# seg is scaled to [-1, 1] and imshow clips floats outside [0, 1], so map it
# back to [0, 1] as the other display cells do. Assumes seg[0] is channel-first.
plt.imshow((seg[0].permute(1, 2, 0) + 1) / 2)

In [None]:
plt.imshow(X=qnt)

In [None]:
qnt.unique(sorted=True)

In [None]:
# One figure per one-hot semantic channel; the channel count comes from the
# tensor instead of a hard-coded 125.
for i in range(sem.shape[1]):
    plt.title(i)
    plt.imshow(sem[0][i].cpu(), cmap='gray')
    plt.show()

In [None]:
# (Removed a stray `train` expression: no such name exists, so it raised NameError on Run All.)