In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import plotly.express as px
import pandas as pd

In [2]:
# CIFAR-10 training split.
# ToTensor() yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
# [-1, 1], the same range as the Generator's final tanh. (Normalize(0, 1) would be
# a no-op, leaving real images in [0, 1] so the discriminator could tell real from
# fake by pixel range alone.)
clifar10 = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:17<00:00, 9476174.51it/s]


Extracting data/cifar-10-python.tar.gz to data


In [3]:
class Discrimnator(nn.Module):
    """DCGAN-style discriminator: image -> probability that the image is real.

    Three stride-2 convolutions (kernel 4, padding 1) each halve the spatial size,
    so an input of shape ``im_dim = (C, H, W)`` becomes a ``256 x H//8 x W//8``
    feature map. A linear layer plus sigmoid then reduces it to one score in (0, 1).

    Args:
        im_dim: ``(channels, height, width)`` of the input images. The default
            ``(3, 32, 32)`` matches CIFAR-10.

    Raises:
        ValueError: if height or width is smaller than 8, since the third
            convolution would then produce an empty feature map.
    """

    def __init__(self, im_dim=(3, 32, 32)):
        super().__init__()
        self.im_dim = im_dim
        height, width = im_dim[1], im_dim[2]
        if height < 8 or width < 8:
            raise ValueError(f"im_dim spatial size must be at least 8x8, got {height}x{width}")
        self.conv1 = nn.Conv2d(in_channels=self.im_dim[0], out_channels=64, kernel_size=4, stride=2, padding=1)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        # Derive the flattened size from im_dim rather than hard-coding 256*4*4,
        # which is only correct for 32x32 inputs.
        self.flat_features = 256 * (height // 8) * (width // 8)
        self.lin = nn.Linear(in_features=self.flat_features, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of real-vs-fake probabilities.

        ``x`` is expected to be ``(batch, C, H, W)``. A single unbatched
        ``(C, H, W)`` image is also accepted and scored as a batch of one.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        # view(-1, ...) turns the leading batch dim (or none, for an unbatched
        # image) into rows of flat_features each.
        x = x.view(-1, self.flat_features)
        x = self.lin(x)
        return self.sigmoid(x)

In [4]:
# Sanity check: score a single (unbatched) training image with an untrained discriminator.
dis = Discrimnator()
sample_image, _sample_label = clifar10[0]
dis(sample_image)


tensor([[0.4980]], grad_fn=<SigmoidBackward0>)

In [5]:
class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> 3-channel 32x32 image in [-1, 1].

    A linear layer projects the latent vector onto a 256-channel 4x4 map. Three
    transposed convolutions (kernel 4, stride 2, padding 1) then each double the
    spatial size, 4 -> 8 -> 16 -> 32, while reducing channels 256 -> 128 -> 64 -> 3.

    Args:
        letant_dim: length of the input latent vector (default 100).
    """

    def __init__(self, letant_dim=100):
        super().__init__()
        # Registration order (lin, relu, conv1, conv2, conv3, tanh) fixes both
        # parameter order and weight-initialisation RNG consumption.
        self.lin = nn.Linear(letant_dim, 256 * 4 * 4)
        self.relu = nn.LeakyReLU(0.2)
        self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.conv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map ``(batch, letant_dim)`` noise to ``(batch, 3, 32, 32)`` images."""
        h = self.relu(self.lin(x)).view(-1, 256, 4, 4)
        # The two hidden upsampling stages share the same LeakyReLU activation.
        for deconv in (self.conv1, self.conv2):
            h = self.relu(deconv(h))
        # The final stage maps to RGB and is squashed into [-1, 1] by tanh.
        return self.tanh(self.conv3(h))


In [6]:
# Sanity check: decode one random latent vector with an untrained generator.
gen = Generator()
z = torch.randn(1, 100)
gen(z)

tensor([[[[-0.0465, -0.0229, -0.0615,  ..., -0.0306, -0.0788, -0.0488],
          [-0.0984, -0.0531, -0.0607,  ..., -0.0484, -0.0313, -0.0898],
          [-0.0568, -0.0393, -0.0889,  ...,  0.0017, -0.0681, -0.0434],
          ...,
          [-0.0971, -0.0781, -0.1480,  ..., -0.0555, -0.0798, -0.1103],
          [-0.0448, -0.0809, -0.0222,  ..., -0.0423, -0.0212, -0.0384],
          [-0.0930, -0.0778, -0.0968,  ..., -0.0813, -0.0815, -0.0971]],

         [[-0.0130, -0.0053,  0.0088,  ...,  0.0002, -0.0157,  0.0019],
          [ 0.0117,  0.0226, -0.0090,  ...,  0.0221,  0.0042,  0.0057],
          [-0.0357, -0.0512, -0.0443,  ...,  0.0074, -0.0600,  0.0106],
          ...,
          [-0.0241, -0.0052, -0.0422,  ...,  0.0329, -0.0165,  0.0275],
          [-0.0272, -0.0042, -0.0371,  ...,  0.0168, -0.0833, -0.0136],
          [-0.0180,  0.0352, -0.0086,  ...,  0.0146,  0.0038,  0.0092]],

         [[ 0.0220,  0.0206, -0.0050,  ...,  0.0339,  0.0152, -0.0044],
          [ 0.0227,  0.0746,  

In [7]:
# Seed the global torch RNG (all devices) so weight init, latent noise and
# DataLoader shuffling repeat across runs (some GPU kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
discriminator = Discrimnator().to(device)
# Adam with beta1 = 0.5 (instead of the default 0.9) is the DCGAN recommendation
# for stabler adversarial training.
optim_gen = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_dis = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# The discriminator already ends in a Sigmoid, so BCELoss on probabilities matches it.
loss_f = nn.BCELoss()

In [14]:
from math import ceil

# Training hyper-parameters.
epocs = 1
batch_size = 128

# Pre-allocate one row per optimisation step. The DataLoader keeps the final
# partial batch, so an epoch has ceil(N / batch_size) steps.
run_data = pd.DataFrame(
    index=range(epocs * ceil(len(clifar10) / batch_size)),
    columns=["epoch", "batch", "loss_gen", "loss_dis"],
)

In [None]:
# GAN training loop. Each step:
#   1. update D on a real batch (target 1) and a generated batch (target 0);
#   2. update G so that D scores fresh fakes as real (non-saturating G loss).
loader = DataLoader(clifar10, batch_size=batch_size, shuffle=True)
# Keep the noise size in sync with the Generator's input layer instead of hard-coding 100.
latent_dim = generator.lin.in_features
for i in range(epocs):
    for j, (x, _labels) in enumerate(loader):
        x = x.to(device)
        # The last batch is smaller than batch_size whenever len(clifar10) is not a
        # multiple of it, so size every noise/label tensor from the actual batch.
        n_batch = x.size(0)
        real_targets = torch.ones(n_batch, device=device)
        fake_targets = torch.zeros(n_batch, device=device)

        # --- Discriminator step ---
        y_hat = discriminator(x).flatten()
        loss_dis_1 = loss_f(y_hat, real_targets)

        noise = torch.randn(n_batch, latent_dim, device=device)
        # detach(): D's loss should not backpropagate through the generator.
        y_hat = discriminator(generator(noise).detach()).flatten()
        loss_dis_2 = loss_f(y_hat, fake_targets)
        # Mean of the real and fake terms (parenthesised so both are halved).
        loss_dis = (loss_dis_1 + loss_dis_2) / 2

        optim_dis.zero_grad()
        loss_dis.backward()
        optim_dis.step()

        # --- Generator step ---
        noise = torch.randn(n_batch, latent_dim, device=device)
        y_hat = discriminator(generator(noise)).flatten()
        loss_gen = loss_f(y_hat, real_targets)
        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        # Global step = epoch * steps_per_epoch + batch; stays within run_data's
        # pre-allocated range(len(loader) * epocs) index for any number of epochs.
        step = i * len(loader) + j
        run_data.loc[step] = [i, step, loss_gen.item(), loss_dis.item()]