In [1]:
import os
import cv2
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, utils

In [2]:
train_data = datasets.CIFAR10('CIFAR10', train = True, 
                         transform = transforms.Compose([transforms.ToTensor()]),
                         download = True)

train_batch = torch.utils.data.DataLoader(train_data, batch_size = 100, shuffle = True)

Files already downloaded and verified


In [11]:
def latent_space_vectors(size): #size is the number of samples in a batch
    return torch.randn(size, 100, 1, 1)#.to(device)

def real_data_target(size):
    return (torch.ones(size))#.to(device)

def fake_data_target(size):
    return (torch.zeros(size))#.to(device)

In [12]:
#generator will input a random noise of 100 (as usual) and the conditioned on the grayscale image (100, 1, 32, 32)
#y is random noise and x in condition

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        #UPSAMPLE RANDOM NOISE
        
        self.rand1 = nn.ConvTranspose2d(100, 128, kernel_size = 4, stride = 1, padding = 0)#4x4
        self.rand2 = nn.ConvTranspose2d(128, 64, kernel_size = 4, stride = 2, padding = 1)#8x8
        self.rand3 = nn.ConvTranspose2d(64, 64, kernel_size = 4, stride = 2, padding = 1)#16x16
        self.rand4 = nn.ConvTranspose2d(64, 3, kernel_size = 4, stride = 2, padding = 1)#32x32
        
        #ENCODER
        
        self.convE1 = nn.Conv2d(4, 64, kernel_size = (3, 3), stride = 1, padding = 1)#32x32
        self.convE2 = nn.Conv2d(64, 64, kernel_size = (3, 3), stride = 1, padding = 1)#32x32
        self.convE3 = nn.Conv2d(64, 128, kernel_size = (3, 3), stride = 1, padding = 1)#16x16
        self.convE4 = nn.Conv2d(128, 128, kernel_size = (3, 3), stride = 1, padding = 1)#16x16
        self.convE5 = nn.Conv2d(128, 256, kernel_size = (3, 3), stride = 1, padding = 1)#8x8
        self.convE6 = nn.Conv2d(256, 256, kernel_size = (3, 3), stride = 1, padding = 1)#8x8
        self.convE7 = nn.Conv2d(256, 512, kernel_size = (3, 3), stride = 1, padding = 1)#4x4
        self.convE8 = nn.Conv2d(512, 512, kernel_size = (3, 3), stride = 1, padding = 1)#4x4
        self.convE9 = nn.Conv2d(512, 1024, kernel_size = (3, 3), stride = 1, padding = 1)#2x2
        self.convE10 = nn.Conv2d(1024, 1024, kernel_size = (3, 3), stride = 1, padding = 1)#2x2
        
        #DECODER
        
        self.convT1 = nn.ConvTranspose2d(1024, 512, kernel_size = (4, 4), stride = 2, padding = 1)#4x4
        #concat
        self.convD1 = nn.Conv2d(1024, 512, kernel_size = (3, 3), stride = 1, padding = 1)#4x4
        self.convD2 = nn.Conv2d(512, 512, kernel_size = (3, 3), stride = 1, padding = 1)#4x4
        
        self.convT2 = nn.ConvTranspose2d(512, 256, kernel_size = (4, 4), stride = 2, padding = 1)#8x8
        #concat
        self.convD3 = nn.Conv2d(512, 256, kernel_size = (3, 3), stride = 1, padding = 1)#8x8
        self.convD4 = nn.Conv2d(256, 256, kernel_size = (3, 3), stride = 1, padding = 1)#8x8
        
        self.convT3 = nn.ConvTranspose2d(256, 128, kernel_size = (4, 4), stride = 2, padding = 1)#16x16
        #concat
        self.convD5 = nn.Conv2d(256, 128, kernel_size = (3, 3), stride = 1, padding = 1)#16x16
        self.convD6 = nn.Conv2d(128, 128, kernel_size = (3, 3), stride = 1, padding = 1)#16x16
        
        self.convT4 = nn.ConvTranspose2d(128, 64, kernel_size = (4, 4), stride = 2, padding = 1)#32x32
        #concat
        self.convD7 = nn.Conv2d(128, 64, kernel_size = (3, 3), stride = 1, padding = 1)#32x32
        self.convD8 = nn.Conv2d(64, 64, kernel_size = (3, 3), stride = 1, padding = 1)#32x32
        self.convD9 = nn.Conv2d(64, 2, kernel_size = (3, 3), stride = 1, padding = 1)#32x32
        
        
    def forward(self, x, y):
        
        #UPSAMPLE RANDOM NOISE
        
        x0 = F.leaky_relu(self.rand1(y))
        x0 = F.leaky_relu(self.rand2(x0))
        x0 = F.leaky_relu(self.rand3(x0))
        x0 = F.leaky_relu(self.rand4(x0))
        print(np.shape(x), np.shape(x0))
        x1 = torch.cat((x0, x), 1)
        
        #ENCODER
        
        x1 = F.leaky_relu(self.convE1(x1))
        x1_concat = F.leaky_relu(self.convE2(x1))
        
        x2 = F.max_pool2d(x1_concat, kernel_size = (2, 2), stride = 2, padding = 0)#16x16
        x2 = F.leaky_relu(self.convE3(x2))
        x2_concat = F.leaky_relu(self.convE4(x2))
        
        x3 = F.max_pool2d(x2_concat, kernel_size = (2, 2), stride = 2, padding = 0)#8x8
        x3 = F.leaky_relu(self.convE5(x3))
        x3_concat = F.leaky_relu(self.convE6(x3))
        
        x4 = F.max_pool2d(x3_concat, kernel_size = (2, 2), stride = 2, padding = 0)#4x4
        x4 = F.leaky_relu(self.convE7(x4))
        x4_concat = F.leaky_relu(self.convE8(x4))
        
        x5 = F.max_pool2d(x4_concat, kernel_size = (2, 2), stride = 2, padding = 0)#2x2
        x5 = F.leaky_relu(self.convE9(x5))
        x5 = F.leaky_relu(self.convE10(x5))
        
        #DECODER
        
        x5 = self.convT1(x5)#4x4
        x5 = torch.cat((x5, x4_concat), 1)
        x5 = F.leaky_relu(self.convD1(x5))
        x5 = F.leaky_relu(self.convD2(x5))
        
        x5 = self.convT2(x5)#8x8
        x5 = torch.cat((x5, x3_concat), 1)
        x5 = F.leaky_relu(self.convD3(x5))
        x5 = F.leaky_relu(self.convD4(x5))
        
        x5 = self.convT3(x5)#16x16
        x5 = torch.cat((x5, x2_concat), 1)
        x5 = F.leaky_relu(self.convD5(x5))
        x5 = F.leaky_relu(self.convD6(x5))
        
        x5 = self.convT4(x5)#32x32
        x5 = torch.cat((x5, x1_concat), 1)
        x5 = F.leaky_relu(self.convD7(x5))
        x5 = F.leaky_relu(self.convD8(x5))
        x5 = torch.sigmoid(self.convD9(x5))
        
        return x5

generator = Generator()
device = torch.device('cpu')
generator.to(device)

Generator(
  (rand1): ConvTranspose2d(100, 128, kernel_size=(4, 4), stride=(1, 1))
  (rand2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (rand3): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (rand4): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (convE1): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE7): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE8): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE9): Conv2d(512, 1

In [13]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        #ENCODER
        
        self.conv1 = nn.Conv2d(3, 3, kernel_size = 3, stride = 1, padding = 1)#32x32
        
        self.conv1_labels = nn.Conv2d(1, 3, kernel_size = 3, stride = 1, padding = 1)#32x32
        
        self.convE1 = nn.Conv2d(6, 32, kernel_size = (3, 3), stride = 1, padding = 1)#32x32 #6to3
        self.convE2 = nn.Conv2d(32, 64, kernel_size = (3, 3), stride = 1, padding = 1)#32x32
        self.convE3 = nn.Conv2d(64, 128, kernel_size = (3, 3), stride = 1, padding = 1)#16x16
        self.convE4 = nn.Conv2d(128, 128, kernel_size = (3, 3), stride = 1, padding = 1)#16x16
        self.convE5 = nn.Conv2d(128, 256, kernel_size = (3, 3), stride = 1, padding = 1)#8x8
        self.convE6 = nn.Conv2d(256, 256, kernel_size = (3, 3), stride = 1, padding = 1)#8x8
        self.convE7 = nn.Conv2d(256, 512, kernel_size = (3, 3), stride = 1, padding = 1)#4x4
        self.convE8 = nn.Conv2d(512, 512, kernel_size = (3, 3), stride = 1, padding = 1)#4x4
        self.convE9 = nn.Conv2d(512, 1024, kernel_size = (3, 3), stride = 1, padding = 1)#2x2
        self.convE10 = nn.Conv2d(1024, 1024, kernel_size = (3, 3), stride = 1, padding = 1)#2x2
        self.convE11 = nn.Conv2d(1024, 1, kernel_size = (4, 4), stride = 2, padding = 1)#1x1
        
    def forward(self, x, y):
        
        x = self.conv1(x)
        x = F.leaky_relu(x, 0.2)
        y = self.conv1_labels(y)
        y = F.leaky_relu(y, 0.2)
        
        x = torch.cat((x, y), 1)
        
        x1 = F.leaky_relu(self.convE1(x))
        x1_concat = F.leaky_relu(self.convE2(x1))
        
        x2 = F.max_pool2d(x1_concat, kernel_size = (2, 2), stride = 2, padding = 0)#16x16
        x2 = F.leaky_relu(self.convE3(x2))
        x2_concat = F.leaky_relu(self.convE4(x2))
        
        x3 = F.max_pool2d(x2_concat, kernel_size = (2, 2), stride = 2, padding = 0)#8x8
        x3 = F.leaky_relu(self.convE5(x3))
        x3_concat = F.leaky_relu(self.convE6(x3))
        
        x4 = F.max_pool2d(x3_concat, kernel_size = (2, 2), stride = 2, padding = 0)#4x4
        x4 = F.leaky_relu(self.convE7(x4))
        x4_concat = F.leaky_relu(self.convE8(x4))
        
        x5 = F.max_pool2d(x4_concat, kernel_size = (2, 2), stride = 2, padding = 0)#2x2
        x5 = F.leaky_relu(self.convE9(x5))
        x5 = F.leaky_relu(self.convE10(x5))
        
        x6 = torch.sigmoid(self.convE11(x5))
        return x6

discriminator = Discriminator()
discriminator.to(device)

Discriminator(
  (conv1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_labels): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE1): Conv2d(6, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE7): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE8): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE9): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE10): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convE11): Conv2d(1024, 1, k

In [14]:
loss_function = torch.nn.BCELoss()
optimizer_generator = optim.Adam(generator.parameters(), lr = 0.0002)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr = 0.0002)

In [15]:
def train_generator(fake_image, gray_scale):
    generator.zero_grad()
    
    fake_image = torch.cat((fake_image, gray_scale), 1)
    prediction_fake_image = discriminator.forward(fake_image, gray_scale).view(100)
    loss_fake_image = loss_function(prediction_fake_image, real_data_target(prediction_fake_image.size(0)))
    loss_fake_image.backward()
    optimizer_generator.step()
    
    return loss_fake_image

In [16]:
def train_discriminator(real_image, fake_image, gray_scale):
    discriminator.zero_grad()
    
    #train on real images
    prediction_real_image = discriminator.forward(real_image, gray_scale).view(100)
    loss_real_image = loss_function(prediction_real_image, real_data_target(prediction_real_image.size(0)))
    loss_real_image.backward()
    
    #train on fake images
    fake_image = torch.cat((fake_image, gray_scale), 1)
    prediction_fake_image = discriminator.forward(fake_image, gray_scale).view(100)
    loss_fake_image = loss_function(prediction_fake_image, fake_data_target(prediction_fake_image.size(0)))
    loss_fake_image.backward()
    
    optimizer_discriminator.step()
    
    return loss_real_image + loss_fake_image, prediction_real_image, prediction_fake_image

In [17]:
EPOCHS = 1

#losses_D = np.zeros([100, 500])
#losses_G = np.zeros([100, 500])

for epoch in range(EPOCHS):
    index = 0
    for images, condition in tqdm(zip(LAB_images_batch, L_images_batch)):
        
        real_image = torch.tensor(images, dtype=torch.float)#.to(device)
        gray_scale = torch.tensor(condition, dtype=torch.float)#.to(device)
        fake_image = generator.forward(gray_scale, latent_space_vectors(100))
        d_error, d_pred_real, d_pred_fake = train_discriminator(real_image, fake_image, gray_scale)
        
        fake_image = generator.forward(gray_scale, latent_space_vectors(100))
        g_error = train_generator(fake_image, gray_scale)
        
        #losses_D[epoch][index] = d_error
        #losses_G[epoch][index] = g_error
        #index += 1
        
        print('Epoch = ' + str(epoch) + " Discriminator loss = " + str(d_error.data.cpu().numpy()) + " Generator loss = " + str(g_error.data.cpu().numpy()))
        break

0it [00:00, ?it/s]

torch.Size([100, 1, 32, 32]) torch.Size([100, 3, 32, 32])
torch.Size([100, 1, 32, 32]) torch.Size([100, 3, 32, 32])
Epoch = 0 Discriminator loss = 1.386109 Generator loss = 0.68956476


In [3]:
i=0
count = 0
LAB_images_batch = np.zeros([500,100,3,32,32])
L_images_batch = np.zeros([500,100,1,32,32])
for m,v in train_batch:#each batch
    for j in range(len(m)):#each sample
        utils.save_image(torch.tensor(m[j]), os.path.join('pyTorch_cifar10' ,str(count)+'.png'))
        temp = cv2.imread(os.path.join('pyTorch_cifar10' ,str(count)+'.png'))
        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(temp)
        L_images_batch = np.zeros([500,100,1,32,32])
        L_images_batch[i][j] = l
        temp = np.transpose(temp, (2,0,1)) 
        LAB_images_batch[i][j] = temp
        count+=1
    i+=1

  import sys
