In [16]:
import cv2
#import mediapipe as mp
import numpy as np
import time
from shapely import Point, Polygon
import pandas as pd
from PIL import Image, ImageChops
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

In [17]:
# Select the execution device: use the GPU when CUDA is available, otherwise
# fall back to the CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [18]:
class Discriminator(nn.Module):
    """Label-conditioned convolutional discriminator for 64x64 images.

    The class label is embedded, projected by a linear layer to an
    image-shaped map with ``nc`` channels, concatenated with the input image
    along the channel axis, fused back to ``nc`` channels by a 1x1 convolution
    and then scored by a stack of strided convolutions ending in a sigmoid.

    Args:
        embeddings: module mapping integer class ids to embedding vectors
            (an ``nn.Embedding`` in this notebook). Its ``embedding_dim``
            attribute sizes the label projection; 100 is assumed if the
            attribute is missing.
        nc: number of image channels.
        ndf: base number of discriminator feature maps.
    """

    def __init__(self, embeddings, nc=3, ndf=64):
        super(Discriminator, self).__init__()

        # Remember the channel count so forward() agrees with the layers
        # built below instead of assuming RGB.
        self.nc = nc
        self.embeddings = embeddings

        # Width of the label embedding; fall back to the historical 100.
        embed_dim = getattr(embeddings, 'embedding_dim', 100)
        # Projects a label embedding to one value per image pixel/channel.
        self.label_to_image = nn.Linear(embed_dim, 64 * 64 * nc)
        # 1x1 convolution merging image + label map (2*nc) back to nc channels.
        self.conv1 = nn.Conv2d(nc * 2, nc, 1, 1, 0, bias=False)

        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x, label_embed):
        """Score a batch of images given their class ids.

        Args:
            x: image batch that can be viewed as ``(-1, nc, 64, 64)``.
            label_embed: class ids passed to ``self.embeddings``; one id per
                image.

        Returns:
            Sigmoid scores of shape ``(batch, 1, 1, 1)``.
        """
        x = x.view(-1, self.nc, 64, 64)
        label_embed = self.embeddings(label_embed)

        # Turn each label embedding into an image-shaped conditioning map.
        label_map = self.label_to_image(label_embed)
        label_map = label_map.view(-1, self.nc, 64, 64)

        # Stack image and label map channel-wise, then fuse to nc channels.
        x = torch.cat([x, label_map], dim=1)

        out = self.conv1(x)
        output = self.main(out)

        return output

In [40]:
class Generator(nn.Module):
    """Label-conditioned transposed-convolution generator producing 64x64 images.

    The noise vector is concatenated with the embedding of the class label,
    projected back to ``nz`` features by a linear layer, reshaped to a
    ``nz x 1 x 1`` map and upsampled by a stack of transposed convolutions
    ending in ``Tanh``.

    Args:
        embeddings: module mapping integer class ids to embedding vectors
            (an ``nn.Embedding`` in this notebook). Its ``embedding_dim``
            attribute sizes the fusion layer; 100 is assumed if the attribute
            is missing.
        nc: number of output image channels.
        nz: length of the latent noise vector.
        ngf: base number of generator feature maps.
    """

    def __init__(self, embeddings, nc=3, nz=100, ngf=64):
        super(Generator, self).__init__()

        # Keep the latent size on the instance so forward() does not depend
        # on a module-level ``nz`` variable being defined elsewhere.
        self.nz = nz
        self.embeddings = embeddings

        # Width of the label embedding; fall back to the historical 100.
        embed_dim = getattr(embeddings, 'embedding_dim', 100)
        # Fuses [noise, label embedding] back down to nz features.
        self.linear = nn.Linear(nz + embed_dim, nz)

        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, x, label_embed):
        """Generate images from noise and class ids.

        Args:
            x: noise batch that can be viewed as ``(-1, nz)``, e.g.
                ``(batch, nz, 1, 1)``.
            label_embed: class ids passed to ``self.embeddings``; one id per
                noise vector.

        Returns:
            Image batch of shape ``(batch, nc, 64, 64)`` with values in
            ``[-1, 1]`` (Tanh output).
        """
        noise = x.view(-1, self.nz)
        label_vec = self.embeddings(label_embed)

        # Condition the noise on the label, then project back to nz features.
        fused = self.linear(torch.cat([noise, label_vec], dim=1))

        # Reshape to (batch, nz, 1, 1) for the transposed-convolution stack.
        fused = fused.unsqueeze(2).unsqueeze(3)

        return self.main(fused)

In [20]:
# Training hyper-parameters: samples per mini-batch and number of passes
# over the training set.
batchsize = 200
epochs = 500

# Use the CIFAR10 dataset
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data import random_split
from torchvision import transforms

# Preprocessing pipeline: resize and centre-crop to 64x64, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Download (if necessary) into data/ and keep the human-readable class names.
train_data = CIFAR10(root='data/', download=True, transform=transform)
txt_label = train_data.classes
print(txt_label)


Files already downloaded and verified
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [21]:
# Shuffled mini-batch iterator over the training set. drop_last=True discards
# a trailing partial batch so every batch holds exactly `batchsize` samples.
train_data_loader = DataLoader(
    train_data,
    batch_size=batchsize,
    shuffle=True,
    drop_last=True,
)



In [82]:
# Class-label embeddings shared by both networks. The weights are frozen, so
# they stay at their random initial values for the whole run.
embeddings = nn.Embedding(10,100).to(device)
embeddings.weight.requires_grad = False

# Length of the latent noise vector fed to the generator.
nz = 100

netD = Discriminator(embeddings).to(device)
netG = Generator(embeddings).to(device)

optimizerD = optim.Adam(netD.parameters(),lr=0.0002,betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(),lr=0.0002,betas=(0.5, 0.999))

netD.train()
netG.train()

criterion = nn.BCELoss()

# Target tensors for "real" (1) and "fake" (0); created directly on the device.
real_label = torch.ones(batchsize, 1, dtype=torch.float, device=device)
fake_label = torch.zeros(batchsize, 1, dtype=torch.float, device=device)

# Number of classes to sample from when conditioning the generator.
n_classes = embeddings.num_embeddings
# Print the losses every `log_every` batches instead of on every batch.
log_every = 50

for epoch in range(epochs):
    for i, (input_sequence, label) in enumerate(train_data_loader):

        fixed_noise = torch.randn(batchsize, nz, 1, 1, device=device)

        input_sequence = input_sequence.to(device)
        label_embed = label.to(device)

        # ---- Update D network: maximize log(D(x)) + log(1 - D(G(z))) ----
        D_real_result = netD(input_sequence, label_embed)
        D_real_loss = criterion(D_real_result.view(batchsize,-1), real_label)

        G_result = netG(fixed_noise,label_embed)

        # Detach the fakes: the discriminator update needs no gradient with
        # respect to the generator, so skip that part of the backward pass.
        D_fake_result = netD(G_result.detach(),label_embed)
        D_fake_loss = criterion(D_fake_result.view(batchsize,-1), fake_label)

        D_train_loss = (D_real_loss + D_fake_loss) / 2

        netD.zero_grad()
        D_train_loss.backward()
        optimizerD.step()

        # ---- Update G network: maximize log(D(G(z))) ----
        # Draw one random class id per sample to condition the generator on.
        new_embed = torch.randint(0, n_classes, (batchsize,), device=device)

        G_result = netG(fixed_noise, new_embed)

        D_fake_result = netD(G_result, new_embed)
        G_train_loss = criterion(D_fake_result.view(batchsize,-1), real_label)

        # Clear stale gradients on both networks; only G is stepped here.
        netD.zero_grad()
        netG.zero_grad()
        G_train_loss.backward()
        optimizerG.step()

        if i % log_every == 0:
            print("D_loss:%f\tG_loss:%f" % (D_train_loss.item(), G_train_loss.item()))

    # Show one generated image on every 25th epoch (once per epoch, not once
    # per batch), using the noise and labels of the epoch's last batch.
    if epoch % 25 == 0:
        with torch.no_grad():
            preview = netG(fixed_noise, new_embed)
        # Generator output is Tanh-scaled to [-1, 1]; imshow expects float RGB
        # in [0, 1], so rescale before plotting.
        image = (preview[0].cpu().numpy().transpose(1,2,0) + 1) / 2
        plt.imshow(image.clip(0, 1))
        plt.title("epoch %d: %s" % (epoch, txt_label[new_embed[0].item()]))
        plt.axis('off')
        plt.show()
        plt.close()


Output hidden; open in https://colab.research.google.com to view.

In [83]:
# Save the generator, discriminator and shared embedding weights, then load
# them back so training can be resumed from these checkpoints.
torch.save(netG.state_dict(), 'netG.pth')
torch.save(netD.state_dict(), 'netD.pth')
torch.save(embeddings.state_dict(), 'embeddings.pth')

# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine.
netG.load_state_dict(torch.load('netG.pth', map_location=device))
netD.load_state_dict(torch.load('netD.pth', map_location=device))
embeddings.load_state_dict(torch.load('embeddings.pth', map_location=device))
















<All keys matched successfully>