In [1]:
import os, random, math, time
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from transformers import BertTokenizer, BertModel

# --- Reproducibility / hardware -------------------------------------------
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Synthetic dataset -----------------------------------------------------
IMG_SIZE = 28                   # images are IMG_SIZE x IMG_SIZE pixels
DATA_ROOT = "shapes_data"       # one sub-folder per class is created here
CLASSES = ["circle", "square", "triangle", "star"]
SAMPLES_PER_CLASS = 2000

# --- Optimisation ----------------------------------------------------------
EPOCHS = 40
BATCH_SIZE = 128
LR = 2e-4
BETAS = (0.5, 0.999)            # passed to Adam as `betas`

# --- Model dimensions ------------------------------------------------------
NZ = 100                        # length of the generator's noise vector
EMBED_DIM = 768                 # length of the label-embedding vector

# --- Output ----------------------------------------------------------------
SAMPLE_DIR = "task5_samples"
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Seed both PyTorch and the stdlib RNG (the latter drives shape generation).
torch.manual_seed(SEED)
random.seed(SEED)

def build_shape_dataset(root=DATA_ROOT, classes=CLASSES, per_class=SAMPLES_PER_CLASS, size=IMG_SIZE):
    """Generate a synthetic grayscale shape dataset, one sub-folder per class.

    Images are ``size`` x ``size``, mode ``"L"``, with a white (255) shape on a
    black (0) background, saved as ``<root>/<label>/<label>_<index>.png``.

    A class is considered complete when its folder already holds at least
    ``per_class`` PNG files; only incomplete classes are (re)generated, and
    regenerated files overwrite existing files of the same name. Checking the
    file count (rather than just folder existence) matters because all class
    folders are created before any image is written, so an interrupted run
    would otherwise be mistaken for a finished dataset.

    Args:
        root: Directory that receives one sub-directory per class.
        classes: Class labels. ``"circle"``, ``"square"``, ``"triangle"`` and
            ``"star"`` are drawn; any other label yields blank (all-black)
            images because no drawing branch matches it.
        per_class: Number of images to generate for each class.
        size: Side length of each square image, in pixels.

    Returns:
        None. The dataset is written to disk.
    """
    def count_pngs(directory):
        # Number of PNG files directly inside `directory` (0 if it is missing).
        if not os.path.isdir(directory):
            return 0
        return sum(1 for name in os.listdir(directory) if name.lower().endswith(".png"))

    for c in classes:
        os.makedirs(os.path.join(root, c), exist_ok=True)

    # Classes whose folder is missing images (e.g. after an interrupted run).
    pending = [c for c in classes if count_pngs(os.path.join(root, c)) < per_class]
    if not pending:
        return

    def draw_shape(lbl, s=size):
        img = Image.new("L", (s, s), 0)
        d = ImageDraw.Draw(img)
        # Random margin so shapes vary in size. It is drawn for every label
        # (keeping the RNG sequence label-independent) even though the star
        # branch below does not use it.
        pad = random.randint(2, 5)
        if lbl == "circle":
            d.ellipse((pad, pad, s-pad, s-pad), fill=255)
        elif lbl == "square":
            d.rectangle((pad, pad, s-pad, s-pad), fill=255)
        elif lbl == "triangle":
            d.polygon([(s//2, pad), (pad, s-pad), (s-pad, s-pad)], fill=255)
        elif lbl == "star":
            cx, cy = s//2, s//2
            r1, r2 = s//2 - 2, s//4          # outer / inner vertex radius
            pts = []
            # Ten vertices, 36 degrees apart, alternating outer and inner
            # radius; the -90 degree offset puts the first point at the top.
            for i in range(10):
                r = r1 if i % 2 == 0 else r2
                ang = (i * 36 - 90) * math.pi / 180
                x = cx + int(r * math.cos(ang))
                y = cy + int(r * math.sin(ang))
                pts.append((x, y))
            d.polygon(pts, fill=255)
        return img

    print("-> Generating shape dataset...")
    for lbl in pending:
        cls_dir = os.path.join(root, lbl)
        for i in range(per_class):
            img = draw_shape(lbl)
            # Half of the images get a small random rotation (-10..10 degrees).
            if random.random() < 0.5:
                img = img.rotate(random.randint(-10, 10), fillcolor=0)
            img.save(os.path.join(cls_dir, f"{lbl}_{i:05d}.png"))
    print("-> Dataset ready:", root)

# Make sure the synthetic shape images exist on disk before loading them.
build_shape_dataset()

# Image pipeline: single channel, IMG_SIZE x IMG_SIZE, values mapped to [-1, 1].
tfm = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize(
            (IMG_SIZE, IMG_SIZE),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),                          # [0, 1]
        transforms.Normalize(mean=[0.5], std=[0.5]),    # [-1, 1]
    ]
)
dataset = datasets.ImageFolder(DATA_ROOT, transform=tfm)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Integer label -> class name, in the order ImageFolder assigned them.
idx2class = dict(enumerate(dataset.classes))

# Pretrained BERT, moved to DEVICE and switched to eval mode.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
bert = bert.to(DEVICE)
bert.eval()
@torch.no_grad()
def embed_labels(names):
    """Encode label strings with BERT and return one vector per string.

    Args:
        names: Strings to embed; they are tokenized together as one padded,
            truncated batch.

    Returns:
        The final-layer hidden state at token position 0 (the [CLS] slot) for
        each input, as a tensor on ``DEVICE`` computed without gradient
        tracking.
    """
    encoded = tokenizer(names, return_tensors="pt", padding=True, truncation=True)
    encoded = encoded.to(DEVICE)
    hidden = bert(**encoded).last_hidden_state
    return hidden[:, 0, :]  # CLS

# One embedding row per class, in dataset.classes order, so integer targets
# from the dataloader can index it directly. embed_labels is decorated with
# torch.no_grad(), so no extra context manager is needed here.
class_embeds = embed_labels(dataset.classes)

class Generator(nn.Module):
    """Conditional generator: noise vector + label embedding -> image.

    The label embedding is projected to 128 features, concatenated with the
    noise vector, expanded to a 128 x 7 x 7 feature map and upsampled by two
    stride-2 transposed convolutions (7 -> 14 -> 28) to a single-channel
    image whose values are squashed by ``tanh``.

    Args:
        z_dim: Length of the noise vector.
        e_dim: Length of the label-embedding vector.
    """

    def __init__(self, z_dim=NZ, e_dim=EMBED_DIM):
        super().__init__()
        cond_features = 128

        # Label embedding -> compact conditioning vector.
        self.proj = nn.Sequential(
            nn.Linear(e_dim, cond_features),
            nn.ReLU(inplace=True),
        )
        # [noise | conditioning] -> flattened 128 x 7 x 7 feature map.
        self.fc = nn.Sequential(
            nn.Linear(z_dim + cond_features, 128 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        # Two 2x upsampling stages ending in a tanh-bounded image.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z, e):
        """Generate images from noise ``z`` conditioned on embeddings ``e``.

        ``z`` and ``e`` are concatenated along dim 1, so they must share the
        same batch size.
        """
        cond = self.proj(e)
        hidden = self.fc(torch.cat((z, cond), dim=1))
        feature_map = hidden.view(-1, 128, 7, 7)
        return self.deconv(feature_map)

class Discriminator(nn.Module):
    """Conditional discriminator: (image, label embedding) -> probability.

    The image passes through two stride-2 convolutions (28 -> 14 -> 7 for the
    28 x 28 inputs implied by the 64*7*7 classifier width), the embedding is
    projected to 128 features, and the concatenation of both is classified by
    a small MLP ending in a sigmoid.

    Args:
        e_dim: Length of the label-embedding vector.
    """

    def __init__(self, e_dim=EMBED_DIM):
        super().__init__()
        # Image branch.
        self.img_disc = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Label-embedding branch.
        self.emb_proj = nn.Sequential(
            nn.Linear(e_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Joint classifier over [image features | embedding features].
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7 + 128, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, e):
        """Score images ``x`` paired with embeddings ``e``; one value per sample."""
        img_maps = self.img_disc(x)
        img_feats = img_maps.view(x.size(0), -1)
        emb_feats = self.emb_proj(e)
        joint = torch.cat((img_feats, emb_feats), dim=1)
        return self.fc(joint)

# Networks are built generator-first so parameter initialisation consumes the
# seeded RNG in a fixed order.
G, D = Generator().to(DEVICE), Discriminator().to(DEVICE)

# Binary cross-entropy on the discriminator's sigmoid output, with one Adam
# optimizer per network sharing the same hyper-parameters.
bce = nn.BCELoss()
optG, optD = (optim.Adam(net.parameters(), lr=LR, betas=BETAS) for net in (G, D))

# A single batch of 16 latent vectors drawn once at start-up.
fixed_noise = torch.randn((16, NZ), device=DEVICE)
def sample_and_save(epoch_tag="init"):
    """Render generator samples for every class and save them as PNG grids.

    For each class in ``dataset.classes`` the generator is run on the
    module-level ``fixed_noise`` batch, conditioned on that class's embedding,
    and the result is saved to ``<SAMPLE_DIR>/<class>_e<epoch_tag>.png``. All
    per-class grids are also stacked vertically into
    ``<SAMPLE_DIR>/_grid_e<epoch_tag>.png``.

    The same latent vectors are reused on every call, so grids saved at
    different epochs differ only because the generator changed, and the call
    draws no random numbers (it does not perturb the training RNG stream).

    Args:
        epoch_tag: String inserted into the output file names.

    Returns:
        None. Images are written to ``SAMPLE_DIR``.
    """
    n_samples = fixed_noise.size(0)
    G.eval()  # use BatchNorm running statistics while sampling
    try:
        with torch.no_grad():
            grids = []
            for cls_idx, cls_name in enumerate(dataset.classes):
                # Repeat this class's embedding once per latent vector.
                e = class_embeds[cls_idx].unsqueeze(0).repeat(n_samples, 1)
                imgs = G(fixed_noise, e).cpu()
                # Map the tanh range [-1, 1] to [0, 1] for saving.
                grid = make_grid(imgs, nrow=8, normalize=True, value_range=(-1,1))
                save_image(grid, os.path.join(SAMPLE_DIR, f"{cls_name}_e{epoch_tag}.png"))
                grids.append(grid)
            # Grids are (C, H, W); concatenating on dim 1 stacks them vertically.
            combined = torch.cat(grids, dim=1)
            save_image(combined, os.path.join(SAMPLE_DIR, f"_grid_e{epoch_tag}.png"))
    finally:
        # Return to training mode even if generation or saving failed.
        G.train()

print("-> Starting training...")
# The loader drops incomplete batches, so a dataset smaller than one batch
# yields no iterations at all; fail clearly instead of hitting a NameError on
# the loss variables in the logging line below.
if len(dataloader) == 0:
    raise RuntimeError(
        f"Dataloader produced no batches ({len(dataset)} samples, "
        f"batch size {BATCH_SIZE}); nothing to train on."
    )

for epoch in range(1, EPOCHS+1):
    for real, tgt in dataloader:
        real = real.to(DEVICE)
        # Look up the label embedding for each sample's integer class index.
        e = class_embeds[tgt.to(DEVICE)]

        bsz  = real.size(0)
        real_y = torch.ones(bsz, 1, device=DEVICE)
        fake_y = torch.zeros(bsz, 1, device=DEVICE)

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        z = torch.randn(bsz, NZ, device=DEVICE)
        # Generated under no_grad: the discriminator update must not
        # back-propagate into G, and the result carries no graph to detach.
        with torch.no_grad():
            fake = G(z, e)
        D.zero_grad()
        lossD = bce(D(real, e), real_y) + bce(D(fake, e), fake_y)
        lossD.backward()
        optD.step()

        # --- Generator step: fresh noise, try to make D output 1 ---
        G.zero_grad()
        z = torch.randn(bsz, NZ, device=DEVICE)
        gen = G(z, e)
        lossG = bce(D(gen, e), real_y)
        lossG.backward()
        optG.step()

    # Log (last-batch losses) and save sample grids on epoch 1 and every 5th.
    if epoch == 1 or epoch % 5 == 0:
        print(f"Epoch {epoch}/{EPOCHS}  |  D: {lossD.item():.4f}  G: {lossG.item():.4f}")
        sample_and_save(epoch_tag=str(epoch))

sample_and_save(epoch_tag=f"{EPOCHS}_final")
print(f"Done. Check samples in: {SAMPLE_DIR}/")


-> Generating shape dataset...
-> Dataset ready: shapes_data


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

-> Starting training...
Epoch 1/40  |  D: 0.6902  G: 3.1896
Epoch 5/40  |  D: 0.0872  G: 4.7652
Epoch 10/40  |  D: 0.0412  G: 4.9368
Epoch 15/40  |  D: 0.3202  G: 1.7731
Epoch 20/40  |  D: 0.0522  G: 4.5373
Epoch 25/40  |  D: 0.0465  G: 3.5917
Epoch 30/40  |  D: 0.0934  G: 6.2279
Epoch 35/40  |  D: 0.0730  G: 4.3288
Epoch 40/40  |  D: 0.0128  G: 5.6381
Done. Check samples in: task5_samples/
