In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
import os
import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm
# --- Configuration -----------------------------------------------------------
num_classes = 3    # one label per shape drawn by generate_shape_image (0, 1, 2)
latent_dim = 100   # length of the noise vector fed to the generator
image_size = 64    # images are image_size x image_size, single channel
batch_size = 32
epochs = 10
seed = 42          # fixed so weight init, shuffling and sampled noise are repeatable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's random number generators before any model or data is created.
torch.manual_seed(seed)


In [2]:
def generate_shape_image(label):
    """Draw one filled white shape on a black, square, 8-bit grayscale canvas.

    Args:
        label: Shape selector. 0 draws an ellipse inscribed in a centred
            square (a circle), 1 draws that centred square, 2 draws a triangle.

    Returns:
        A PIL image in mode "L" with size (image_size, image_size).

    Raises:
        ValueError: If ``label`` is not 0, 1 or 2.
    """
    img = Image.new("L", (image_size, image_size), "black")
    draw = ImageDraw.Draw(img)
    color = 255

    # Coordinates are derived from image_size so the shapes scale with the
    # canvas. For image_size == 64 they evaluate to the same pixels as the
    # former literals (16/48 for the box, 8/32/56 for the triangle).
    quarter = image_size // 4
    eighth = image_size // 8
    box = (quarter, quarter, image_size - quarter, image_size - quarter)

    if label == 0:
        draw.ellipse(box, fill=color)
    elif label == 1:
        draw.rectangle(box, fill=color)
    elif label == 2:
        apex = (image_size // 2, eighth)
        bottom_left = (eighth, image_size - eighth)
        bottom_right = (image_size - eighth, image_size - eighth)
        draw.polygon([apex, bottom_left, bottom_right], fill=color)
    else:
        # An unknown label used to fall through and yield an all-black image,
        # which would silently poison the dataset.
        raise ValueError(f"Unknown shape label {label!r}; expected 0, 1 or 2")
    return img
# Number of copies of each shape placed in the training set.
samples_per_class = 200

# ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output, so real and generated images
# presented to the discriminator share one value range.
to_model_range = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

image_list = []
label_list = []
for label in range(num_classes):
    # generate_shape_image is deterministic, so render each shape once and
    # repeat the resulting tensor instead of redrawing it for every sample.
    shape_tensor = to_model_range(generate_shape_image(label))
    image_list.extend([shape_tensor] * samples_per_class)
    label_list.extend([label] * samples_per_class)

data_images = torch.stack(image_list)  # (num_classes * samples_per_class, 1, H, W)
data_labels = torch.tensor(label_list)
dataset = torch.utils.data.TensorDataset(data_images, data_labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [3]:
class Generator(nn.Module):
    """Fully connected conditional generator.

    A class label is looked up in a learned embedding of width ``num_classes``,
    placed in front of the noise vector, and mapped by an MLP ending in
    ``Tanh`` to ``image_size * image_size`` values that are reshaped into a
    single-channel image.
    """

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # First layer has no batch norm; the next two hidden layers do.
        stack = [nn.Linear(latent_dim + num_classes, 128), nn.ReLU(inplace=True)]
        for fan_in, fan_out in ((128, 256), (256, 512)):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(inplace=True),
            ]
        stack += [nn.Linear(512, image_size * image_size), nn.Tanh()]
        self.model = nn.Sequential(*stack)

    def forward(self, noise, labels):
        """Return images of shape (batch, 1, image_size, image_size).

        Args:
            noise: Tensor whose last dimension is ``latent_dim``.
            labels: Integer class indices, one per row of ``noise``.
        """
        label_vec = self.label_emb(labels)
        flat = self.model(torch.cat([label_vec, noise], dim=-1))
        return flat.view(flat.size(0), 1, image_size, image_size)

class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    The image is flattened, the label's learned embedding (width
    ``num_classes``) is appended, and an MLP ending in ``Sigmoid`` produces
    one score per sample.
    """

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        widths = (num_classes + image_size * image_size, 512, 256)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.append(nn.Linear(widths[-1], 1))
        stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Return a (batch, 1) tensor of sigmoid scores.

        Args:
            img: Batch of images; everything after the batch dimension is flattened.
            labels: Integer class indices, one per image.
        """
        label_vec = self.label_emb(labels)
        pixels = img.view(img.size(0), -1)
        return self.model(torch.cat([pixels, label_vec], dim=-1))


In [4]:
# Build both networks (generator first) and place them on the chosen device.
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002) for net in (generator, discriminator)
)

# Folder for the per-epoch sample grids written by the training loop.
os.makedirs("cgan_samples", exist_ok=True)


In [5]:
# Alternating GAN training: one generator update followed by one discriminator
# update per batch; a 3x3 grid of generated samples is written after each epoch.
for epoch in range(epochs):
    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for imgs, labels in progress:
        batch_size_curr = imgs.size(0)

        # Create targets directly on the device instead of allocating on the
        # CPU and copying them over on every step.
        real = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        imgs, labels = imgs.to(device), labels.to(device)

        # ---- Generator step: make the discriminator call fakes "real" ----
        optimizer_G.zero_grad()
        z = torch.randn(batch_size_curr, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)
        gen_imgs = generator(z, gen_labels)
        g_loss = criterion(discriminator(gen_imgs, gen_labels), real)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # zero_grad also discards the discriminator gradients that
        # g_loss.backward() accumulated above.
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(imgs, labels), real)
        # detach() keeps this loss from back-propagating into the generator.
        fake_loss = criterion(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Show both losses on the progress bar so training can be monitored.
        progress.set_postfix(g_loss=f"{g_loss.item():.4f}", d_loss=f"{d_loss.item():.4f}")

    # Samples come from the last batch of the epoch.
    save_image(gen_imgs[:9], f"cgan_samples/epoch_{epoch+1}.png", nrow=3, normalize=True)


Epoch 1/10: 100%|████████████████████████████| 19/19 [00:01<00:00,  9.65it/s]
Epoch 2/10: 100%|████████████████████████████| 19/19 [00:01<00:00, 10.92it/s]
Epoch 3/10: 100%|████████████████████████████| 19/19 [00:01<00:00,  9.88it/s]
Epoch 4/10: 100%|████████████████████████████| 19/19 [00:01<00:00, 10.71it/s]
Epoch 5/10: 100%|████████████████████████████| 19/19 [00:01<00:00, 10.18it/s]
Epoch 6/10: 100%|████████████████████████████| 19/19 [00:01<00:00,  9.69it/s]
Epoch 7/10: 100%|████████████████████████████| 19/19 [00:01<00:00,  9.99it/s]
Epoch 8/10: 100%|████████████████████████████| 19/19 [00:02<00:00,  9.49it/s]
Epoch 9/10: 100%|████████████████████████████| 19/19 [00:03<00:00,  5.70it/s]
Epoch 10/10: 100%|███████████████████████████| 19/19 [00:02<00:00,  8.30it/s]


In [6]:
# Persist only the generator's parameters and buffers; sampling does not need
# the discriminator.
generator_weights = generator.state_dict()
torch.save(generator_weights, "generator_shapes.pth")
print("Training complete! Model saved as generator_shapes.pth")


Training complete! Model saved as generator_shapes.pth


In [9]:
import torch
from torchvision.utils import save_image
import os

# Ensure the folder that receives the sampled image grid exists; exist_ok
# makes re-running this cell harmless.
os.makedirs("generated_shapes", exist_ok=True)

class Generator(torch.nn.Module):
    """Inference-time copy of the conditional generator.

    The attribute names (``label_emb``, ``model``) and the order of layers
    inside ``model`` determine the state_dict keys, so they must stay in sync
    with the network whose weights are loaded into this class.
    """

    def __init__(self):
        super().__init__()
        layers = torch.nn
        self.label_emb = layers.Embedding(num_classes, num_classes)

        # First layer has no batch norm; the next two hidden layers do.
        stack = [layers.Linear(latent_dim + num_classes, 128), layers.ReLU(inplace=True)]
        for fan_in, fan_out in ((128, 256), (256, 512)):
            stack += [
                layers.Linear(fan_in, fan_out),
                layers.BatchNorm1d(fan_out),
                layers.ReLU(inplace=True),
            ]
        stack += [layers.Linear(512, image_size * image_size), layers.Tanh()]
        self.model = layers.Sequential(*stack)

    def forward(self, noise, labels):
        """Return images of shape (batch, 1, image_size, image_size).

        Args:
            noise: Tensor whose last dimension is ``latent_dim``.
            labels: Integer class indices, one per row of ``noise``.
        """
        label_vec = self.label_emb(labels)
        flat = self.model(torch.cat([label_vec, noise], dim=-1))
        return flat.view(flat.size(0), 1, image_size, image_size)

# Rebuild the generator and restore the trained weights.
generator = Generator().to(device)
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict holds, instead of running arbitrary pickle code.
state_dict = torch.load("generator_shapes.pth", map_location=device, weights_only=True)
generator.load_state_dict(state_dict)
generator.eval()  # BatchNorm layers use their running statistics

# Sampling needs no gradients; no_grad avoids building the autograd graph.
with torch.no_grad():
    z = torch.randn(9, latent_dim, device=device)  # random noise
    labels = torch.randint(0, num_classes, (9,), device=device)  # random shape labels
    gen_imgs = generator(z, labels)

save_image(gen_imgs, "generated_shapes/generated.png", nrow=3, normalize=True)
print("Images saved in generated_shapes/")


Images saved in generated_shapes/
