In [33]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from IPython.display import clear_output

In [2]:
# CIFAR-10 training split.
# ToTensor() scales pixels to [0, 1]. Normalize(mean=0.5, std=0.5) then maps them to [-1, 1],
# the same range as the generator's Tanh output, so the discriminator compares real and fake
# images on equal footing. (mean=0, std=1 would leave the data unchanged.)
clifar10 = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:20<00:00, 8459189.95it/s] 


Extracting data/cifar-10-python.tar.gz to data


In [3]:
class Discrimnator(nn.Module):
    """DCGAN-style discriminator: three stride-2 convolutions followed by a linear head.

    (The class name's spelling is kept so existing references keep working.)

    Args:
        im_dim: (channels, height, width) of the input images. Defaults to CIFAR-10's (3, 32, 32).

    forward(x) returns a (batch, 1) tensor of sigmoid probabilities in (0, 1).
    A single unbatched (C, H, W) image is treated as a batch of one.
    """

    def __init__(self, im_dim=(3, 32, 32)):
        super().__init__()
        self.im_dim = im_dim
        channels, height, width = im_dim
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        # Derive the flattened size from im_dim instead of assuming 32x32 inputs
        # (for 32x32 this is still 256 * 4 * 4).
        self.flat_features = 256 * self._downsampled(height) * self._downsampled(width)
        self.lin = nn.Linear(in_features=self.flat_features, out_features=1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _downsampled(size, n_layers=3):
        """Spatial size after `n_layers` convs with kernel_size=4, stride=2, padding=1."""
        for _ in range(n_layers):
            # Conv output length: floor((size + 2*padding - kernel_size) / stride) + 1
            size = (size - 2) // 2 + 1
        return size

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        # view(-1, ...) also folds a single unbatched image into a batch of one.
        x = x.view(-1, self.flat_features)
        return self.sigmoid(self.lin(x))

In [4]:
# Smoke test: push the first training image (the dataset yields (image, label) pairs)
# through an untrained discriminator.
dis = Discrimnator()
first_image = clifar10[0][0]
dis(first_image)


tensor([[0.5027]], grad_fn=<SigmoidBackward0>)

In [5]:
class Generator(nn.Module):
    """Maps latent noise vectors to 3-channel 32x32 images with values in [-1, 1] (Tanh output).

    Args:
        letant_dim: length of each input noise vector (default 100).
    """

    def __init__(self, letant_dim=100):
        super().__init__()
        # Attribute names and creation order are unchanged, so state_dict keys and
        # parameter initialisation (which consumes the RNG in this order) are identical.
        self.lin = nn.Linear(letant_dim, 256 * 4 * 4)
        self.relu = nn.LeakyReLU(0.2)
        upsample = dict(kernel_size=4, stride=2, padding=1, output_padding=0)
        self.conv1 = nn.ConvTranspose2d(256, 128, **upsample)
        self.conv2 = nn.ConvTranspose2d(128, 64, **upsample)
        self.conv3 = nn.ConvTranspose2d(64, 3, **upsample)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Project noise to a 256x4x4 feature map, then upsample 4 -> 8 -> 16 -> 32."""
        features = self.relu(self.lin(x)).view(-1, 256, 4, 4)
        for deconv in (self.conv1, self.conv2):
            features = self.relu(deconv(features))
        return self.tanh(self.conv3(features))


In [6]:
# Smoke test: decode one random latent vector (default latent size 100) with an untrained generator.
gen = Generator()
latent = torch.randn(1, 100)
gen(latent)

tensor([[[[-0.1308, -0.0704, -0.1378,  ..., -0.0901, -0.1152, -0.1172],
          [-0.0864, -0.1081, -0.0883,  ..., -0.1261, -0.1067, -0.1102],
          [-0.0766, -0.1182, -0.1174,  ..., -0.1203, -0.1090, -0.1313],
          ...,
          [-0.1186, -0.0897, -0.1561,  ..., -0.0960, -0.1465, -0.0883],
          [-0.1004, -0.1451, -0.0559,  ..., -0.1281, -0.1293, -0.1188],
          [-0.1161, -0.1169, -0.1654,  ..., -0.1199, -0.1197, -0.0977]],

         [[-0.0586, -0.0437, -0.0675,  ..., -0.0064, -0.0590, -0.0324],
          [-0.0486, -0.0676, -0.0838,  ..., -0.0685, -0.0704, -0.0568],
          [-0.0904, -0.0356, -0.0498,  ..., -0.0035, -0.0690, -0.0358],
          ...,
          [-0.0634, -0.0973, -0.0647,  ..., -0.0568, -0.0521, -0.0695],
          [-0.0734, -0.0141, -0.0483,  ...,  0.0079, -0.0997, -0.0467],
          [-0.0553, -0.0958, -0.0444,  ..., -0.0696, -0.0512, -0.0583]],

         [[ 0.0229,  0.0191,  0.0522,  ...,  0.0279,  0.0230,  0.0384],
          [ 0.0125, -0.0166,  

In [45]:
# Seed torch's RNGs so weight init, noise and DataLoader shuffling are reproducible
# (cuDNN kernels may still introduce nondeterminism on GPU).
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
discriminator = Discrimnator().to(device)
# DCGAN recipe: lr=2e-4 with beta1=0.5 (Adam's default beta1=0.9 tends to destabilise GAN training).
optim_gen = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_dis = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# The discriminator already ends in a Sigmoid, so BCE on probabilities is the matching loss.
loss_f = nn.BCELoss()

In [46]:
from math import ceil

epocs = 20
batch_size = 128
# One preallocated row per optimisation step: ceil(N / batch_size) batches per epoch
# (the last batch of each epoch may be partial).
run_data = pd.DataFrame(
    index=range(ceil(len(clifar10) / batch_size) * epocs),
    columns=["epoch", "batch", "loss_gen", "loss_dis"],
)

In [47]:
# Alternating GAN training: per batch, one discriminator update, then one generator update.
latent_dim = generator.lin.in_features  # noise length the generator was built with
# Built once; shuffle=True still reshuffles on every pass over the loader.
loader = DataLoader(clifar10, batch_size=batch_size, shuffle=True)
run_ind = 0
for i in range(epocs):
    for j, (x, _labels) in enumerate(loader):
        x = x.to(device)
        batch_len = len(x)
        real_labels = torch.ones(batch_len, device=device)
        fake_labels = torch.zeros(batch_len, device=device)

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        loss_dis_1 = loss_f(discriminator(x).flatten(), real_labels)
        noise = torch.randn(batch_len, latent_dim, device=device)
        # detach(): this step only updates the discriminator, so skip backprop into the generator.
        fake = generator(noise).detach()
        loss_dis_2 = loss_f(discriminator(fake).flatten(), fake_labels)
        # Parenthesised so both terms are averaged (without them only the fake term is halved).
        loss_dis = (loss_dis_1 + loss_dis_2) / 2

        optim_dis.zero_grad()
        loss_dis.backward()
        optim_dis.step()

        # --- Generator step: push the discriminator to label fresh fakes as real ---
        noise = torch.randn(batch_len, latent_dim, device=device)
        loss_gen = loss_f(discriminator(generator(noise)).flatten(), real_labels)
        # zero_grad() right before backward() also clears any gradients left from the other step.
        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        run_data.loc[run_ind] = [i, j, loss_gen.item(), loss_dis.item()]
        run_ind += 1

    # Per-epoch progress: loss curves plus a handful of generated samples.
    clear_output(wait=True)
    px.line(
        run_data,
        x=run_data.index,
        y=["loss_gen", "loss_dis"],
        title=f"GAN losses (epoch {i + 1}/{epocs})",
    ).show()
    with torch.no_grad():
        samples = generator(torch.randn(5, latent_dim, device=device))
    # Tanh output lies in [-1, 1]; px.imshow's default RGB range starts at 0 and would clip
    # the negative half, so rescale to [0, 1] and move channels last (N, H, W, C) for plotting.
    samples = ((samples + 1) / 2).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
    px.imshow(samples, facet_col=0, facet_col_wrap=5).show()


