# <center>Discriminatively Regularized Generative Model for CIFAR-10</center>

## Load Data

In [1]:
# !pip3 install --upgrade torch torchvision

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from IPython.display import Image
#from google.colab import files

# Seed PyTorch's global random number generator so that weight initialisation
# and latent sampling start from the same state on every fresh-kernel run.
torch.manual_seed(512)

<torch._C.Generator at 0x7f54534a5050>

In [2]:
class preTrainedModel(nn.Module):
    """Fixed VGG-16 feature extractor used to compare images in feature space.

    Only ``vgg16.features[0:4]`` is kept, i.e. the first convolutional block
    (conv-ReLU-conv-ReLU) that precedes the first max-pooling layer. The
    output therefore has 64 channels and the same spatial size as the input.

    The weights are frozen: the extractor is never trained in this notebook,
    so tracking gradients for its parameters would only waste memory and
    compute during ``loss.backward()``. Gradients still flow *through* the
    network to its input, which is all a feature-matching loss needs.

    To use deeper features, add further slices of ``vgg_model.features``
    (for example ``[4:9]`` and ``[9:16]``), upsample their outputs back to the
    input resolution and concatenate them along the channel dimension.

    NOTE(review): torchvision's pretrained VGG weights were trained on
    ImageNet-normalised inputs, while the callers in this notebook pass raw
    [0, 1] images - confirm that this is intended.
    """

    def __init__(self):
        super(preTrainedModel, self).__init__()

        # Newer torchvision releases prefer the `weights=` argument;
        # `pretrained=True` is kept for compatibility with the version this
        # notebook was written against.
        vgg_model = torchvision.models.vgg16(pretrained=True)
        self.Conv1 = nn.Sequential(*list(vgg_model.features.children())[0:4])

        # Freeze the extractor so autograd does not accumulate unused
        # parameter gradients on every training step.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the first-block VGG feature map for the image batch ``x``."""
        return self.Conv1(x)

In [3]:
# Instantiate the pretrained feature extractor, switch it to eval mode and move it to the GPU.
# NOTE: despite the variable name, preTrainedModel wraps torchvision's VGG-16
# (see the class definition above); the name `vgg19` is kept because later
# cells refer to it.
vgg19 = preTrainedModel().eval().cuda()

In [4]:
# Download the CIFAR-10 training split into ./data/cifar/ if it is not already there.
cifar = datasets.CIFAR10('./data/cifar/', train = True, download = True)

# Serve the training images in shuffled mini-batches of 32.
# ToTensor() scales pixel values to [0, 1]. The images are deliberately NOT
# re-normalized to [-1, 1]: the decoder ends in a Sigmoid and the
# reconstruction term is a binary cross-entropy, both of which need targets
# in [0, 1]. The Resize/Normalize lines below are left disabled for that reason.

train_images = torch.utils.data.DataLoader ( datasets.CIFAR10('./data/cifar/', train = True, download=False,
                               transform=transforms.Compose([
                               #transforms.Resize(64), 
                               #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               transforms.ToTensor(),])) , 
                               batch_size = 32, shuffle = True)

Files already downloaded and verified


In [5]:
# Sanity check: push a single mini-batch through the VGG feature extractor and
# compare the size of the feature map with the size of the raw input.
upsampling = nn.Upsample(size=256)  # resizes the 32x32 images to 256x256 (default mode: nearest)
data, _ = next(iter(train_images))
with torch.no_grad():  # inspection only - do not build an autograd graph for the activations
    out = vgg19(upsampling(data.cuda()))
print(out.size())
print(data.size())

torch.Size([32, 64, 256, 256])
torch.Size([32, 3, 32, 32])


## Model

We will use the architecture suggested by [Radford et al.](https://arxiv.org/abs/1511.06434) for both the encoder and the decoder, with convolutional layers in the encoder and fractionally-strided convolutions in the decoder. In each convolutional layer of the encoder we double the number of filters present in the previous layer and use a convolutional stride of 2. In each fractionally-strided convolutional layer of the decoder we use a stride of 2 and halve the number of filters of the previous layer.

In [6]:
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

class VAE(nn.Module):
    """Convolutional variational auto-encoder for 3x32x32 images.

    Encoder: four 4x4 convolutions (3 -> 32 -> 64 -> 128 -> 256 channels) that
    reduce a 32x32 input to a 256x1x1 code, followed by ``fc1`` and the two
    heads ``fc21`` (mean) and ``fc22`` (log-variance) of the approximate
    posterior.

    Decoder: ``fc3``/``fc4`` map a latent vector back to 256 features, which
    four transposed convolutions (256 -> 128 -> 64 -> 32 -> 3 channels) expand
    to a 3x32x32 image with values in [0, 1] (final Sigmoid).

    Args:
        image_size: Side length of the input images. Stored on the instance
            but not used to size any layer; the fixed ``256``-feature
            bottleneck means the layers only fit 32x32 inputs.
        hidden_dim: Width of the fully-connected hidden layers.
        encoding_dim: Dimensionality of the latent variable ``z``.
    """

    def __init__(self, image_size, hidden_dim, encoding_dim):
        super(VAE, self).__init__()

        self.encoding_dim = encoding_dim
        self.image_size = image_size
        self.hidden_dim = hidden_dim

        # The sub-modules are created in the same order as before (decoder,
        # encoder, linear layers) so that seeded weight initialisation and the
        # state_dict key order remain unchanged.

        # Decoder - fractionally-strided (transposed) convolutional layers.
        # Spatial size: 1 -> 4 -> 8 -> 16 -> 32.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Sigmoid(),  # outputs in [0, 1], as required by the BCE reconstruction loss
        )

        # Encoder - strided convolutional layers.
        # Spatial size for a 32x32 input: 32 -> 16 -> 8 -> 4 -> 1.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 0, bias=False),
            nn.Sigmoid(),
        )

        # Fully-connected layers around the latent space.
        self.fc1 = nn.Linear(256, self.hidden_dim)
        self.fc21 = nn.Linear(self.hidden_dim, self.encoding_dim)  # posterior mean
        self.fc22 = nn.Linear(self.hidden_dim, self.encoding_dim)  # posterior log-variance
        self.fc3 = nn.Linear(self.encoding_dim, self.hidden_dim)
        self.fc4 = nn.Linear(self.hidden_dim, 256)

    def encode(self, x):
        """Map an image batch to the posterior parameters ``(mu, logvar)``."""
        features = self.encoder(x).view(x.size(0), -1)
        hidden = F.relu(self.fc1(features))
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar`` (reparametrization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Map latent vectors of shape (batch, encoding_dim) to images in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        h4 = torch.sigmoid(self.fc4(h3))
        # Reshape the flat 256 features to (batch, 256, 1, 1) for the transposed convolutions.
        return self.decoder(h4.view(z.size(0), -1, 1, 1))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(decoded, mu, logvar)``: the reconstruction and the
            posterior mean and log-variance used by the KL term of the loss.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z)
        return decoded, mu, logvar


# Helpers used by the feature-matching term of the loss.
upsampling = nn.Upsample(size=256)  # resize images to 256x256 before the VGG extractor
sigmoid = nn.Sigmoid()  # squashes the VGG activations into (0, 1) before comparing them


def loss_function(recon_x, x, mu, logvar):
    """VAE loss with an additional VGG feature-matching regulariser.

    All three terms are summed over every element and over the batch:

    * ``BCE``: binary cross-entropy between reconstruction and input pixels.
    * ``KLD``: closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    * ``L1``: squared error (despite the name) between the sigmoid-squashed
      VGG features of the reconstruction and of the input.

    Args:
        recon_x: Reconstructed images with values in [0, 1].
        x: Input images with values in [0, 1]; treated as a constant target.
        mu: Posterior means.
        logvar: Posterior log-variances.

    Returns:
        Scalar tensor ``BCE + KLD + L1``.
    """
    # reduction='sum' is the current spelling of the deprecated size_average=False.
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    d1_recon_x = sigmoid(vgg19(upsampling(recon_x)))
    # The input-side features are a fixed target: no autograd graph is needed
    # for them, which saves the memory of a second set of VGG activations.
    with torch.no_grad():
        d1_x = sigmoid(vgg19(upsampling(x)))

    L1 = F.mse_loss(d1_recon_x, d1_x, reduction='sum')

    return BCE + KLD + L1

In [16]:
# from google.colab import files
# uploaded = files.upload()

# Build the VAE (image_size=32, hidden_dim=100, encoding_dim=20) on the GPU and
# load previously saved weights - presumably the epoch-100 checkpoint, judging
# by the file name.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model = VAE( 32, 100, 20 ).cuda()
model.load_state_dict(torch.load('../models_save_cifar_checkpoint_epoch_100_bs32.pth'))
# A fresh Adam optimizer is created here; only the model weights are restored
# from the checkpoint, not any optimizer state.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

#Train model
def train(epoch):
    """Run one pass over ``train_images`` and update ``model`` in place.

    Uses the module-level ``model``, ``optimizer``, ``train_images`` and
    ``loss_function``. Progress is printed every 50 mini-batches and a
    summary line is printed at the end of the epoch.

    Args:
        epoch: Epoch number, used only for the log messages.

    Returns:
        The summed loss of the epoch divided by the number of training
        samples (the value shown in the summary line).
    """
    # Make sure BatchNorm uses batch statistics even if the model was switched
    # to eval mode earlier (e.g. for sampling).
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_images):
        # Plain tensors carry autograd state themselves; the deprecated
        # Variable wrapper is not needed.
        data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 50 == 0:
            # The loss is summed over the batch, so divide by the batch size
            # to report a per-sample value.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_images.dataset),
                100. * batch_idx / len(train_images),
                loss.item() / len(data)))

    average_loss = train_loss / len(train_images.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, average_loss))
    return average_loss

In [None]:
import time

# Continue training up to and including epoch `num_epochs`, printing the
# wall-clock seconds each epoch took.
# NOTE(review): the loaded checkpoint's file name suggests epoch 100 was
# already completed, so starting the range at 100 presumably repeats that
# epoch number - confirm whether the loop should start at 101.
# `epoch` keeps its final value after the loop and is reused by the sampling
# cell below to build the output file name.
num_epochs = 200
for epoch in range(100,num_epochs+1):
    start = time.time()
    train(epoch)    
    end = time.time()
    print(f'Time elapsed: {end - start:.2f}')

====> Epoch: 100 Average loss: 7480.5012
Time elapsed: 316.49
====> Epoch: 101 Average loss: 7476.4171
Time elapsed: 320.55


In [11]:
# Persist the trained weights as a CPU state_dict.
# NOTE: model.cpu() moves the model itself to the CPU (it does not return a
# copy), so call model.cuda() again before running train() after this cell.
torch.save(model.cpu().state_dict(), "../models_save_cifar_checkpoint_epoch_200_bs32.pth")
# files.download("save_checkpoint_epoch_70.pth")
               
               

In [13]:
# Draw 64 latent vectors from the N(0, I) prior, decode them and save the
# resulting image grid. `epoch` is the last epoch number left by the training loop.
was_training = model.training
# BatchNorm must use its running statistics while sampling; in train mode it
# would normalise with (and update its running averages from) the sampled batch.
model.eval()
with torch.no_grad():
    # Create the noise on whichever device the model currently lives on, so
    # this cell does not depend on the model having been moved to the CPU.
    device = next(model.parameters()).device
    sample = torch.randn(64, model.encoding_dim, device=device)
    sample = model.decode(sample)
    torchvision.utils.save_image(sample.view(64, 3, 32, 32),'../results/sample_cifar_' + str(epoch) + '_bs32.png')
# Restore the previous train/eval mode so that later cells behave as before.
model.train(was_training)

In [None]:
# Store the decoded samples as a CPU tensor file.
torch.save(sample.cpu(), "sample_70.pth")
# The Colab `files` helper only exists when its import (commented out in the
# first cell) has been enabled; skip the browser download otherwise instead of
# failing with a NameError.
if "files" in globals():
    files.download("sample_70.pth")

In [18]:
# Shape check: expect 64 decoded images of size 3x32x32.
sample.size()

torch.Size([64, 3, 32, 32])