In [0]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.optim as optim
import numpy as np

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.optim import Adam

# from spectral import SpectralNorm

from torch.nn.utils import spectral_norm
from torch.nn.init import xavier_uniform_

In [0]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Preprocessing applied to every image: resize to 128, convert to a tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# STL10 'train+unlabeled' split, stored under ./data (downloaded if missing).
stl10_dtset = dsets.STL10(
    root="./data",
    split='train+unlabeled',
    transform=transform,
    download=True,
)


Files already downloaded and verified


In [0]:
# Number of images per mini-batch.
batch_size = 16

# Keyword arguments forwarded verbatim to the DataLoader below.
params_loader = dict(batch_size=batch_size, shuffle=True)

# Shuffled mini-batch iterator over the STL10 dataset.
train_loader = DataLoader(stl10_dtset, **params_loader)

In [0]:
class GenBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling followed by a conv stack.

    The stage first doubles the spatial resolution with a spectrally
    normalised transposed convolution (kernel 2, stride 2), then applies a
    series of 3x3 conv -> BatchNorm -> ReLU triples: one reducing
    ``in_channels`` to ``in_channels // 2``, ``nb_conv_layers - 2`` keeping
    that width, and a final one producing ``out_channels``.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the stage.
        nb_conv_layers: total number of conv triples requested. Values below
            2 still yield the two mandatory triples (first and last).
    """

    def __init__(self, in_channels, out_channels, nb_conv_layers):
        super(GenBlock, self).__init__()
        middle_channels = in_channels // 2

        def conv_bn_relu(ch_in, ch_out):
            # Fresh modules on every call: multiplying a list of modules
            # would register the *same* instances several times and silently
            # share their weights between stages.
            return [
                sn_conv2d(ch_in, ch_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]

        layers = [sn_convT2d(in_channels, in_channels, kernel_size=2, stride=2)]
        layers += conv_bn_relu(in_channels, middle_channels)
        for _ in range(nb_conv_layers - 2):
            layers += conv_bn_relu(middle_channels, middle_channels)
        layers += conv_bn_relu(middle_channels, out_channels)
        self.generate = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the upsampling + convolution stack."""
        return self.generate(x)

class GeneratorSeg(nn.Module):
    """Encoder/decoder generator with skip connections and self-attention.

    The encoder is the feature extractor of a pretrained ``vgg19_bn`` split
    into five consecutive slices. The decoder mirrors it with one plain
    upsampling stage (``gen5``) and four ``GenBlock`` stages; each decoder
    stage receives the concatenation of the matching encoder output and the
    previous decoder output. Self-attention is applied to the outputs of
    ``gen3`` and ``gen2``.

    Args:
        color_ch: number of channels of the generated output.
    """

    def __init__(self, color_ch=2):
        super(GeneratorSeg, self).__init__()
        # TODO: check nb channels in vgg
        # The decoder widths below assume the standard vgg19_bn layout
        # (64/128/256/512/512 channels per slice) -- TODO confirm.
        vgg = torchvision.models.vgg19_bn(pretrained=True)

        features = list(vgg.features.children())
        self.enc1 = nn.Sequential(*features[:7])
        self.enc2 = nn.Sequential(*features[7:14])
        self.enc3 = nn.Sequential(*features[14:27])
        self.enc4 = nn.Sequential(*features[27:40])
        self.enc5 = nn.Sequential(*features[40:])

        # Deepest decoder stage: upsample, then four independent conv
        # triples. padding=1 keeps the spatial size so the result can be
        # concatenated with enc4; a loop (not list multiplication) gives
        # every triple its own weights.
        gen5_layers = [sn_convT2d(512, 512, kernel_size=2, stride=2)]
        for _ in range(4):
            gen5_layers += [
                sn_conv2d(512, 512, kernel_size=3, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(True),
            ]
        self.gen5 = nn.Sequential(*gen5_layers)
        self.gen4 = GenBlock(1024, 256, 4)
        self.gen3 = GenBlock(512, 128, 4)
        self.gen2 = GenBlock(256, 64, 2)
        self.gen1 = GenBlock(128, color_ch, 2)

        self.attention1 = SelfAttention(128)
        self.attention2 = SelfAttention(64)

    def forward(self, x):
        """Encode ``x`` and decode it back through the skip-connected stages.

        Args:
            x: image batch fed to the VGG encoder (presumably 3-channel with
                height and width divisible by 32 -- TODO confirm).

        Returns:
            The output of the last decoder stage (``color_ch`` channels).
        """
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)

        gen5 = self.gen5(enc5)
        gen4 = self.gen4(torch.cat([enc4, gen5], 1))
        gen3 = self.gen3(torch.cat([enc3, gen4], 1))
        gen3 = self.attention1(gen3)
        # Maybe the attention layer should be applied on the encoder side:
        # enc2 = self.attention1(enc2)
        gen2 = self.gen2(torch.cat([enc2, gen3], 1))
        gen2 = self.attention2(gen2)
        gen1 = self.gen1(torch.cat([enc1, gen2], 1))
        return gen1

In [0]:
def flatten(x):
    """Merge the last two axes of a 4-D tensor into a single axis.

    Args:
        x: tensor with exactly four dimensions.

    Returns:
        A view of ``x`` with shape ``(x.shape[0], -1, x.shape[2] * x.shape[3])``.
    """
    batch, _, dim2, dim3 = x.size()
    return x.view(batch, -1, dim2 * dim3)

def sn_conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0):
    """Build an ``nn.Conv2d`` and wrap it with spectral normalisation.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size (default 3).
        stride: convolution stride (default 1).
        padding: zero padding added on each side (default 0).

    Returns:
        The convolution module returned by ``spectral_norm``.
    """
    conv = nn.Conv2d(in_channels,
                     out_channels,
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding)
    return spectral_norm(conv)

def sn_convT2d(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build an ``nn.ConvTranspose2d`` and wrap it with spectral normalisation.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size: kernel size (required).
        stride: stride (default 1).
        padding: padding (default 0).

    Returns:
        The transposed-convolution module returned by ``spectral_norm``.
    """
    deconv = nn.ConvTranspose2d(in_channels,
                                out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding)
    return spectral_norm(deconv)

class SelfAttention(nn.Module):
    """Self-attention layer over the spatial positions of a feature map.

    Query and key projections are 1x1 spectrally normalised convolutions
    reducing the channels by ``sq_fact``; the value projection keeps the
    channel count. The attended features are scaled by the learnable
    ``gamma`` (initialised to zero) and added back to the input.

    Args:
        ch_in: number of channels of the input feature map.
        sq_fact: channel reduction factor for the query/key projections.
    """

    def __init__(self, ch_in, sq_fact=8):
        super().__init__()

        self.ch_in = ch_in
        reduced_ch = self.ch_in // sq_fact
        self.query = sn_conv2d(self.ch_in, reduced_ch, 1)
        self.key = sn_conv2d(self.ch_in, reduced_ch, 1)
        self.value = sn_conv2d(self.ch_in, self.ch_in, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        # (batch, positions, reduced channels)
        queries = flatten(self.query(x)).transpose(1, 2)
        # (batch, reduced channels, positions)
        keys = flatten(self.key(x))
        # (batch, channels, positions)
        values = flatten(self.value(x))

        # Row i holds the softmax-normalised scores of position i against
        # every position.
        weights = self.softmax(queries.bmm(keys))

        # Weighted sum of the values for every position, back in image layout.
        attended = values.bmm(weights.transpose(1, 2)).view(x.shape)

        return self.gamma * attended + x
        

In [0]:


class Discriminator(nn.Module):
    """Convolutional discriminator with two self-attention layers.

    Five spectrally normalised stride-2 convolutions (kernel 4, padding 1)
    each halve the spatial size; self-attention is applied after the fourth
    and fifth. A final unnormalised 4x4 convolution maps to one channel, so
    a 128x128 input yields a single score per image.

    Args:
        in_dim: number of channels of the input images.
        img_size: stored on the instance; not used by the layers built here.
        conv_dim: base channel width; layers use multiples of it.
    """

    def __init__(self, in_dim=3, img_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.in_dim = in_dim
        self.img_size = img_size
        self.conv_dim = conv_dim

        self.layers = nn.Sequential(
            sn_conv2d(self.in_dim, self.conv_dim, 4, 2, 1),
            nn.LeakyReLU(0.1),
            sn_conv2d(self.conv_dim, self.conv_dim*2, 4, 2, 1),
            nn.LeakyReLU(0.1),
            sn_conv2d(self.conv_dim*2, self.conv_dim*4, 4, 2, 1),
            nn.LeakyReLU(0.1),
            sn_conv2d(self.conv_dim*4, self.conv_dim*4, 4, 2, 1),
            nn.LeakyReLU(0.1),
        )
        self.layer4 = nn.Sequential(
            sn_conv2d(self.conv_dim*4, self.conv_dim*8, 4, 2, 1),
            nn.LeakyReLU(0.1)
        )
        self.last = nn.Sequential(
            nn.Conv2d(self.conv_dim*8, 1, 4)
        )

        # Attention widths follow conv_dim so they always match the feature
        # maps they are applied to (256 and 512 for the default conv_dim=64).
        self.attention1 = SelfAttention(self.conv_dim*4)
        self.attention2 = SelfAttention(self.conv_dim*8)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image batch of shape ``(N, in_dim, H, W)``.

        Returns:
            The discriminator scores with singleton channel/spatial axes
            removed; shape ``(N,)`` when the final feature map is 1x1.
        """
        out = self.layers(x)
        out = self.attention1(out)
        out = self.layer4(out)
        out = self.attention2(out)
        out = self.last(out)

        # Squeeze only the non-batch axes: a bare squeeze() would also drop
        # the batch axis for a batch of one image.
        return out.squeeze(3).squeeze(2).squeeze(1)
        

In [0]:
class Generator(nn.Module):
    """Transposed-convolution generator with two self-attention layers.

    A latent of spatial size 1x1 is expanded to 4x4 by the first transposed
    convolution and doubled by each of the following five, giving a
    128x128 output squashed by ``tanh``. Self-attention is applied before
    and after ``layer4``.

    Args:
        z_dim: number of channels of the latent input.
        conv_dim: base channel width; layers use multiples of it.
        dim_out: number of channels of the generated images.
    """

    def __init__(self, z_dim=100, conv_dim=64, dim_out=3):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.conv_dim = conv_dim
        self.dim_out = dim_out

        self.layers = nn.Sequential(
            sn_convT2d(z_dim, conv_dim*8, 4),
            nn.BatchNorm2d(conv_dim*8),
            nn.ReLU(),
            sn_convT2d(conv_dim*8, conv_dim*4, 4, 2, 1),
            nn.BatchNorm2d(conv_dim*4),
            nn.ReLU(),
            sn_convT2d(conv_dim*4, conv_dim*2, 4, 2, 1),
            nn.BatchNorm2d(conv_dim*2),
            nn.ReLU(),
            sn_convT2d(conv_dim*2, conv_dim*2, 4, 2, 1),
            nn.BatchNorm2d(conv_dim*2),
            nn.ReLU(),
        )
        self.layer4 = nn.Sequential(
            sn_convT2d(conv_dim*2, conv_dim, 4, 2, 1),
            nn.BatchNorm2d(conv_dim),
            nn.ReLU()
        )
        self.last = nn.Sequential(
            nn.ConvTranspose2d(conv_dim, dim_out, 4, 2, 1),
            nn.Tanh()
        )
        # Attention widths follow conv_dim so they always match the feature
        # maps they are applied to (128 and 64 for the default conv_dim=64).
        self.attention1 = SelfAttention(conv_dim*2)
        self.attention2 = SelfAttention(conv_dim)

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: latent batch of shape ``(N, z_dim, 1, 1)``; a 2-D
                ``(N, z_dim)`` batch is also accepted and reshaped.

        Returns:
            Generated images of shape ``(N, dim_out, 128, 128)`` for a 1x1
            latent, with values in the range of ``tanh``.
        """
        if z.dim() == 2:
            # Give flat latent vectors the 1x1 spatial axes the first
            # transposed convolution expects.
            z = z.view(z.size(0), z.size(1), 1, 1)

        out = self.layers(z)
        out = self.attention1(out)
        out = self.layer4(out)
        out = self.attention2(out)
        out = self.last(out)

        return out

In [0]:
def init_weights(m):
    """Xavier-initialise the weights and zero the bias of a layer.

    Intended for ``module.apply(init_weights)``. Only ``nn.Linear`` and
    ``nn.Conv2d`` modules are touched; every other module is left as is.

    Args:
        m: the module visited by ``apply``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return

    # A layer wrapped by torch.nn.utils.spectral_norm keeps its trainable
    # tensor in `weight_orig`; `weight` is recomputed from it on every
    # forward pass, so initialising `weight` would not persist.
    weight = getattr(m, "weight_orig", None)
    if weight is None:
        weight = m.weight
    xavier_uniform_(weight)

    # Layers built with bias=False have no bias tensor to reset.
    if m.bias is not None:
        m.bias.data.fill_(0.)


In [0]:
# Size of the latent vector fed to the generator.
z_dim = 128

# Build both networks on the selected device, generator first.
netG = Generator(z_dim=z_dim).to(device)
netD = Discriminator().to(device)

# Apply the custom weight initialisation to every sub-module.
netG.apply(init_weights)
netD.apply(init_weights)

# parameters given in the original paper
lr_g, lr_d = 1e-4, 4e-4
betas = (0.0, 0.9)

# One Adam optimizer per network, each with its own learning rate.
optimizer_g = Adam(netG.parameters(), lr=lr_g, betas=betas)
optimizer_d = Adam(netD.parameters(), lr=lr_d, betas=betas)

# Latent batch kept constant so saved preview images are comparable.
fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)


In [0]:
# Adversarial training with the hinge loss: each iteration performs one
# discriminator update followed by one generator update, and every 100
# iterations logs the losses and saves a preview grid from `fixed_noise`.
nb_epochs = 40

# Nothing below switches the networks to eval mode, so setting training
# mode once before the loop is enough.
netD.train()
netG.train()

for epoch in range(nb_epochs):
    for idx, (images, _) in enumerate(train_loader):

        # The last batch of an epoch can be smaller than batch_size; skip it
        if images.size(0) != batch_size:
            continue

        images = images.to(device)

        #######################
        # Train Discriminator #
        #######################

        # Scores on real images
        d_out_real = netD(images)

        # Scores on generated images. The generator is not updated in this
        # step, so its forward pass runs without autograd: this avoids
        # building and back-propagating through a graph whose gradients
        # would be discarded.
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        with torch.no_grad():
            fakes = netG(noise)
        d_out_fake = netD(fakes)

        # Adversarial hinge loss
        d_loss_real = F.relu(1.0 - d_out_real).mean()
        d_loss_fake = F.relu(1.0 + d_out_fake).mean()
        d_loss = d_loss_real + d_loss_fake

        # Backward and optimize
        netD.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        ###################
        # Train Generator #
        ###################
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fakes = netG(noise)
        g_out_fake = netD(fakes)

        # Adversarial hinge loss
        g_loss = -g_out_fake.mean()

        netG.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        if idx % 100 == 0:
            # NOTE: despite their labels, the two values logged here are the
            # real-image hinge loss and the generator loss.
            print(f"Epoch [{epoch}/{nb_epochs}], "
                  f"iter[{idx}/{len(train_loader)}], "
                  f"d_out_real: {d_loss_real.item():.4f}, "
                  f"g_out_fake: {g_loss.item():.4f}")

            # Preview images need no gradients. The file name only depends
            # on the epoch, so later previews in an epoch overwrite earlier
            # ones.
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake,
                              f'./x__{epoch}_epoch.png',
                              normalize=True)
                
        
        

Epoch [0/40], iter[0/6563], d_out_real: 1.0201, g_out_fake: -0.1283
Epoch [0/40], iter[100/6563], d_out_real: 0.0000, g_out_fake: 1.6002
Epoch [0/40], iter[200/6563], d_out_real: 0.3008, g_out_fake: 1.7927
Epoch [0/40], iter[300/6563], d_out_real: 0.0606, g_out_fake: 1.3476
Epoch [0/40], iter[400/6563], d_out_real: 0.0168, g_out_fake: 1.4809
Epoch [0/40], iter[500/6563], d_out_real: 0.6709, g_out_fake: 0.5638
Epoch [0/40], iter[600/6563], d_out_real: 0.8023, g_out_fake: 0.6284
Epoch [0/40], iter[700/6563], d_out_real: 0.8316, g_out_fake: 0.6061
Epoch [0/40], iter[800/6563], d_out_real: 0.4789, g_out_fake: 0.1009
Epoch [0/40], iter[900/6563], d_out_real: 0.6670, g_out_fake: 1.6214
Epoch [0/40], iter[1000/6563], d_out_real: 0.6209, g_out_fake: 1.1829
Epoch [0/40], iter[1100/6563], d_out_real: 0.4360, g_out_fake: 1.2441
Epoch [0/40], iter[1200/6563], d_out_real: 0.2373, g_out_fake: 1.8430
Epoch [0/40], iter[1300/6563], d_out_real: 0.8807, g_out_fake: 0.8769
Epoch [0/40], iter[1400/6563], 

In [0]:
# Report how many samples the dataset behind train_loader contains.
print(len(train_loader.dataset))