In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

latent_dim = 100
num_classes = 10  
img_size = 28 * 28
batch_size = 64
lr = 0.0002
epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 37.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.08MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.46MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.92MB/s]


In [3]:
embedding=nn.Embedding(5,10)
embedding(torch.tensor([0,0,1,1]))



tensor([[ 0.1600, -0.2225,  1.2101, -0.3174, -1.0013, -1.8289,  1.9442, -0.3616,
          0.7237, -0.5085],
        [ 0.1600, -0.2225,  1.2101, -0.3174, -1.0013, -1.8289,  1.9442, -0.3616,
          0.7237, -0.5085],
        [-0.8028, -0.3788, -0.2459, -1.0448,  1.2136, -2.1445, -0.4011,  0.5549,
         -1.1913,  1.4045],
        [-0.8028, -0.3788, -0.2459, -1.0448,  1.2136, -2.1445, -0.4011,  0.5549,
         -1.1913,  1.4045]], grad_fn=<EmbeddingBackward0>)

In [4]:
num_class=10
embed_len=20
latent_dim=100
img_size=28*28
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_embedding=nn.Embedding(num_class,embed_len)
        self.model=nn.Sequential(
            nn.Linear(latent_dim+embed_len, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.Linear(512, img_size),
            nn.Tanh()
            
        )
    def forward(self, z, label):
        # print(label.shape)
        c=self.label_embedding(label)
        print(c.shape)
        print(z.shape)
        x=torch.cat([z,c],dim=1)
        print(x.shape)
        return self.model(x)
        
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_Gen=Generator().to(device)
model_Gen



Generator(
  (label_embedding): Embedding(10, 20)
  (model): Sequential(
    (0): Linear(in_features=120, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=256, out_features=512, bias=True)
    (5): Linear(in_features=512, out_features=784, bias=True)
    (6): Tanh()
  )
)

In [5]:
batch=32
z=torch.randn(batch, latent_dim)
z.shape
z=torch.tensor(z)
label=torch.randint(1,num_class , (batch,))
label=torch.tensor(label)

model_Gen(z.to(device), label.to(device))


  z=torch.tensor(z)
  label=torch.tensor(label)


torch.Size([32, 20])
torch.Size([32, 100])
torch.Size([32, 120])


tensor([[ 0.0143,  0.0015, -0.0002,  ..., -0.0479, -0.0835, -0.0947],
        [ 0.0408, -0.0328, -0.0309,  ..., -0.0341,  0.0052, -0.1011],
        [-0.0098,  0.0299, -0.0009,  ..., -0.0648, -0.0211, -0.0954],
        ...,
        [-0.0117, -0.1247,  0.0605,  ..., -0.1003,  0.1082,  0.0528],
        [ 0.0345,  0.0115, -0.0060,  ..., -0.0202,  0.0302, -0.1175],
        [-0.0246, -0.0441, -0.0371,  ..., -0.0260, -0.0021, -0.0737]],
       grad_fn=<TanhBackward0>)

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embd=nn.Embedding(num_class, embed_len)
        self.model=nn.Sequential(
            nn.Linear(img_size+embed_len, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512,1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024,1),
            nn.Sigmoid()
            
        )
    def forward(self, img, label):
        c=self.label_embd(label)
        x=torch.cat([img,c], dim=1)
        return self.model(x)


model_Dis=Discriminator().to(device)
model_Dis

Discriminator(
  (label_embd): Embedding(10, 20)
  (model): Sequential(
    (0): Linear(in_features=804, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Linear(in_features=1024, out_features=1, bias=True)
    (7): Sigmoid()
  )
)

In [7]:
img=torch.rand(batch,28,28)
img.shape
# img=nn.Flatten(img)
type(img)

img.shape

img_gen=img.view(img.size(0),-1).to(device)
img_gen.shape

label=torch.randint(0,num_class,(batch,))

p=model_Dis(img_gen, label)
p.shape

torch.Size([32, 1])

In [8]:
loss_fn=nn.BCELoss()

optimizer_G=torch.optim.Adam(params=model_Gen.parameters(),lr=lr)

optimizer_D=torch.optim.Adam(params=model_Dis.parameters(),lr=lr)

fixed_noise=torch.randn(16, latent_dim)
fixed_noise.shape
fixed_labels = torch.arange(0, 8).repeat(2).to(device)

fixed_labels.shape
os.makedirs("cgan_outputs", exist_ok=True)

In [9]:
epoch=1
# for epochs in range(epoch)
for i, (img, label) in enumerate(train_loader):
    if i>1:
        break
    print(label)
    print(img.shape)
    

    batch_current=img.size(0)
    real_label=torch.full((batch_current, 1), 0.9, device=device)
    fake_label=torch.zeros(batch_current,1).to(device)
    real_img=img.view(batch_current,-1)
    real_img=real_img.to(device)
    #Discriminator
    optimizer_D.zero_grad()
    real_out=model_Dis(real_img, label)
    real_loss=loss_fn(real_out,real_label)



    z=torch.randn(batch_current, latent_dim).to(device)
    fake_img=model_Gen(z,label)
    fake_img=fake_img.view(batch_current,-1)
    fake_out=model_Dis(fake_img.detach(),label)
    print(fake_out.shape)
    fake_loss=loss_fn(fake_out, fake_label)

    total_loss=real_loss + fake_loss
    total_loss.backward()
    optimizer_D.step()
    #Generator Training


    optimizer_G.zero_grad()
    gen_valid=model_Dis(fake_img, label)
    gen_loss=loss_fn(gen_valid, real_label)
    gen_loss.backward()
    optimizer_G.step()
    if i % 1 == 0:
            # print(f"[Epoch {epoch+1}/{epochs}] [Batch {i}/{len(train_loader)}] "
            #       f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
        print(f"[D loss: {total_loss.item():.4f}] [G loss: {gen_loss.item():.4f}]")
with torch.no_grad():
    generated = model_Gen(fixed_noise, fixed_labels).view(-1, 1, 28, 28)
    save_image(generated, f"cgan_outputs/sample_epoch_{epoch+1}.png", normalize=True)
    
    

    
    
    
    
    
    
    
    

tensor([2, 9, 9, 9, 0, 2, 2, 6, 6, 0, 0, 6, 5, 9, 1, 2, 6, 8, 9, 9, 1, 9, 3, 6,
        6, 2, 8, 2, 6, 7, 1, 6, 3, 4, 2, 0, 7, 9, 7, 4, 8, 0, 7, 6, 9, 4, 0, 1,
        8, 0, 6, 3, 7, 8, 4, 7, 6, 7, 4, 5, 1, 3, 8, 1])
torch.Size([64, 1, 28, 28])
torch.Size([64, 20])
torch.Size([64, 100])
torch.Size([64, 120])
torch.Size([64, 1])
[D loss: 1.3940] [G loss: 0.6916]
tensor([2, 9, 0, 4, 2, 8, 0, 1, 1, 3, 8, 6, 9, 9, 4, 7, 3, 8, 9, 6, 2, 6, 7, 3,
        9, 9, 4, 9, 4, 0, 1, 9, 6, 3, 5, 0, 6, 0, 9, 5, 8, 3, 6, 9, 9, 0, 4, 8,
        9, 1, 8, 0, 2, 9, 1, 8, 3, 3, 0, 8, 3, 0, 8, 5])
torch.Size([64, 1, 28, 28])
torch.Size([64, 20])
torch.Size([64, 100])
torch.Size([64, 120])
torch.Size([64, 1])
[D loss: 1.3199] [G loss: 0.6903]
torch.Size([16, 20])
torch.Size([16, 100])
torch.Size([16, 120])
