In [2]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

# ---- Run configuration ------------------------------------------------------
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Image geometry (combined into the image-shape tuple when the models are built).
channels = 1
img_size = 28

# Conditioning / latent space sizes handed to the generator and discriminator.
num_classes = 10
z_dim = 100

# Optimisation settings: `lr` and `beta1` are fed to both Adam optimizers.
lr = 0.0002
beta1 = 0.5
batch_size = 128  # presumably the DataLoader mini-batch size -- confirm at the loader
epochs = 50       # presumably the number of training passes -- confirm in the training loop

# Make sure the output folder exists (no error if it is already there).
os.makedirs("cgan_generated", exist_ok=True)

In [3]:
# Convert each PIL image to a float tensor in [0, 1], then shift/scale it to
# [-1, 1] so real images live in the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# The transform must be attached to the dataset: without it MNIST yields raw
# PIL images, which the DataLoader's default collate function cannot batch.
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)

# Use the configured mini-batch size (the DataLoader default is 1) and
# reshuffle the training samples every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
class Generator(nn.Module):
    """Label-conditioned MLP generator.

    A noise vector is concatenated with a learned embedding of the class
    label, pushed through three Linear -> BatchNorm1d -> ReLU stages
    (256, 512, 1024 units) and projected to one value per pixel, squashed
    by Tanh.

    Args:
        z_dim: Width of the noise vector given to ``forward``.
        num_classes: Number of distinct labels; also the embedding width.
        img_shape: Shape of one generated image, without the batch dimension.
    """

    def __init__(self, z_dim , num_classes , img_shape):
        super().__init__()
        # One learned row of width `num_classes` per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.img_shape = img_shape

        # Total number of values in one image (product of the shape entries).
        n_pixels = int(torch.prod(torch.tensor(img_shape)))

        # Hidden stack: each consecutive pair of widths becomes one stage.
        widths = [z_dim + num_classes, 256, 512, 1024]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.BatchNorm1d(fan_out))
            stages.append(nn.ReLU(True))

        # Output head: one Tanh-bounded value per pixel.
        stages.append(nn.Linear(widths[-1], n_pixels))
        stages.append(nn.Tanh())

        self.model = nn.Sequential(*stages)

    def forward(self , noise , labels):
        """Generate a batch of images for the given noise and labels.

        Args:
            noise: Tensor whose second dimension is the noise width.
            labels: Integer class indices, one per row of ``noise``.

        Returns:
            Tensor of shape ``(batch, *img_shape)``.
        """
        label_vec = self.label_emb(labels)
        # Condition on the label by appending its embedding to the noise.
        conditioned = torch.cat((noise, label_vec), dim=1)
        flat_img = self.model(conditioned)
        return flat_img.view(conditioned.size(0), *self.img_shape)

In [10]:
class Discriminator(nn.Module):
    """Label-conditioned MLP discriminator.

    The flattened image is concatenated with a learned embedding of the
    class label and scored by a 512 -> 256 -> 1 MLP with LeakyReLU(0.2)
    activations; a final Sigmoid maps the score into (0, 1).

    Args:
        num_classes: Number of distinct labels; also the embedding width.
        img_shape: Shape of one input image, without the batch dimension.
    """

    def __init__(self, num_classes, img_shape):
        super().__init__()
        # One learned row of width `num_classes` per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Flattened image size plus the label embedding width.
        n_pixels = int(torch.prod(torch.tensor(img_shape)))
        widths = [n_pixels + num_classes, 512, 256]

        # Hidden stages: Linear followed by an in-place LeakyReLU.
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        # Output head: a single Sigmoid-bounded score per sample.
        stages.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])

        self.model = nn.Sequential(*stages)

    def forward(self, img, labels):
        """Score a batch of images given their labels.

        Args:
            img: Image batch; everything after the first dimension is flattened.
            labels: Integer class indices, one per image.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in (0, 1).
        """
        batch = img.size(0)
        flat_img = img.view(batch, -1)
        label_vec = self.label_emb(labels)
        # Condition on the label by appending its embedding to the pixels.
        return self.model(torch.cat((flat_img, label_vec), dim=1))

In [11]:
# Shape of a single image, without the batch dimension.
img_shape = (channels, img_size, img_size)

# Build both networks and move them to the selected device.
generator = Generator(z_dim, num_classes, img_shape).to(device)
discriminator = Discriminator(num_classes, img_shape).to(device)

# Binary cross-entropy; the discriminator ends in a Sigmoid, so its output
# is already a probability as BCELoss expects.
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing the learning rate and first beta.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

In [None]:

# NOTE(review): the purpose of `k` and `p` is not visible in this part of the
# notebook -- confirm what they control and consider more descriptive names.
k, p = 3, 1