In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import plotly.express as px

In [2]:
clifar10 = datasets.CIFAR10(root="data", train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0,1)]))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 28624538.21it/s]


Extracting data/cifar-10-python.tar.gz to data


In [3]:
class Discrimnator(nn.Module):
    def __init__(self,im_dim=(3,32,32)):
        super().__init__()
        self.im_dim = im_dim
        self.conv1 = nn.Conv2d(in_channels=self.im_dim[0],out_channels=64,kernel_size=4,stride=2,padding=1)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=1)
        self.conv3 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=1)
        self.lin = nn.Linear(in_features=256*4*4,out_features=1)
        self.sigmoid = nn.Sigmoid()
    def forward(self,x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = x.view(-1,256*4*4)
        x = self.lin(x)
        x = self.sigmoid(x)
        return x

In [4]:
dis = Discrimnator()
dis(clifar10[0][0])


tensor([[0.4990]], grad_fn=<SigmoidBackward0>)

In [12]:
class Generator(nn.Module):
    def __init__(self,letant_dim=100):
        super().__init__()
        self.lin = nn.Linear(in_features=letant_dim,out_features=256*4*4)
        self.relu = nn.LeakyReLU(0.2)
        self.conv1 = nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=4,stride=2,padding=1,output_padding=0)
        self.conv2 = nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=4,stride=2,padding=1,output_padding=0)
        self.conv3 = nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=4,stride=2,padding=1,output_padding=0)
        self.tanh = nn.Tanh()
    def forward(self,x):
        x = self.lin(x)
        x = self.relu(x)
        x = x.view(-1,256,4,4)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.tanh(x)
        return x


In [13]:
gen = Generator()
gen(torch.randn(1,100))

torch.Size([1, 100])
torch.Size([1, 256, 4, 4])
torch.Size([1, 128, 8, 8])
torch.Size([1, 64, 16, 16])
torch.Size([1, 3, 32, 32])


tensor([[[[-0.0830, -0.1002, -0.0856,  ..., -0.0945, -0.0746, -0.0997],
          [-0.1146, -0.1424, -0.1037,  ..., -0.1273, -0.1877, -0.1494],
          [-0.0684, -0.1336, -0.0728,  ..., -0.0765, -0.0754, -0.0741],
          ...,
          [-0.1105, -0.1292, -0.1220,  ..., -0.1244, -0.1105, -0.1279],
          [-0.0952, -0.0805, -0.0440,  ..., -0.1099, -0.0636, -0.0943],
          [-0.1064, -0.0926, -0.1178,  ..., -0.1043, -0.1061, -0.0898]],

         [[-0.1445, -0.1194, -0.1246,  ..., -0.1059, -0.1365, -0.1293],
          [-0.1068, -0.1354, -0.1335,  ..., -0.1743, -0.1698, -0.1413],
          [-0.1341, -0.1000, -0.0839,  ..., -0.0600, -0.1703, -0.1213],
          ...,
          [-0.1232, -0.1246, -0.1271,  ..., -0.1481, -0.1000, -0.1431],
          [-0.1181, -0.0807, -0.1180,  ..., -0.1128, -0.1119, -0.1155],
          [-0.1302, -0.1442, -0.1389,  ..., -0.1470, -0.1337, -0.1198]],

         [[-0.0637, -0.0304, -0.0523,  ..., -0.0280, -0.0702, -0.0304],
          [-0.0039, -0.0713, -