In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
from torchvision.utils import make_grid,save_image
import os 
from tqdm import tqdm


In [4]:
from datasets import load_dataset
# Fetch the Pokemon image dataset from the Hugging Face Hub (served from the local cache on re-runs).
ds = load_dataset(path="huggan/pokemon")

README.md:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


dataset_infos.json:   0%|          | 0.00/645 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/131M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7357 [00:00<?, ? examples/s]

In [5]:
# Preprocessing applied to every training image: square 256x256 RGB tensor in [-1, 1].
transform = transforms.Compose([
    transforms.Resize(256),          # Resizes the shorter side to 256 pixels (aspect ratio preserved)
    transforms.CenterCrop(256),      # Crops a 256x256 square from the center
    transforms.ToTensor(),           # Converts image (PIL or NumPy) to a float tensor (C x H x W) with values in [0, 1]
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # (x - 0.5) / 0.5: maps [0, 1] to [-1, 1] per channel, matching the generator's tanh range
])


In [6]:
from torch.utils.data import Dataset
class POKEMON_DS(Dataset):
    """Map-style PyTorch dataset over one split of a Hugging Face dataset dict.

    Args:
        hf_dataset: object indexable by split name (e.g. the result of
            ``load_dataset("huggan/pokemon")``). The selected split must expose
            ``num_rows`` and return, for an integer index, a dict with an
            ``"image"`` entry.
        transform: optional callable applied to each image. When ``None`` the
            raw image object is returned unchanged.
        split: name of the split to read; defaults to ``"train"``.
    """

    def __init__(self, hf_dataset, transform=None, split="train"):
        self.data = hf_dataset
        self.transform = transform
        self.split = split

    def __len__(self):
        return self.data[self.split].num_rows

    def __getitem__(self, idx):
        item = self.data[self.split][idx]
        # Bind `img` unconditionally so transform=None returns the raw image
        # instead of raising UnboundLocalError.
        img = item["image"]
        if self.transform is not None:
            img = self.transform(img)
        return img
        

In [7]:
from torch.utils.data import DataLoader

# Wrap the Hugging Face split in the PyTorch dataset, then batch and shuffle it for training.
train_dataset = POKEMON_DS(hf_dataset=ds, transform=transform)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
)



In [8]:
# Hyper-parameters for the WGAN-GP run (keep in sync with the wandb config logged later).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

lr = 1e-4            # Adam learning rate; matches the wandb config value the optimizers are ultimately built with
z_dim = 100          # latent channels; noise is sampled as (N, z_dim, 1, 1)
image_size = 256     # square training resolution; matches Resize/CenterCrop(256) in the transform
channels_img = 3     # RGB images
batch_size = 128     # matches the DataLoader batch size
n_critic = 5         # critic updates per generator update
lambda_gp = 10       # gradient-penalty weight in the WGAN-GP critic loss
weight_clip = 0.01   # NOTE(review): unused by the visible cells; training uses a gradient penalty, not weight clipping
epochs = 100


In [16]:
class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent tensor to a 256x256 image.

    Args:
        z_dim: number of latent channels; input is shaped (N, z_dim, 1, 1).
        channels_img: number of channels in the generated image.

    The final ``Tanh`` bounds outputs to [-1, 1], the same range the real
    images are normalised to.
    """

    def __init__(self, z_dim, channels_img):
        super().__init__()
        # (in_channels, out_channels, kernel, stride, padding) per upsampling stage.
        # Spatial size: 1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128.
        stage_specs = [
            (z_dim, 512, 4, 1, 0),
            (512, 256, 4, 2, 1),
            (256, 128, 4, 2, 1),
            (128, 64, 4, 2, 1),
            (64, 32, 4, 2, 1),
            (32, 16, 4, 2, 1),
        ]
        layers = [self._block(*spec) for spec in stage_specs]
        layers.append(nn.ConvTranspose2d(16, channels_img, 4, 2, 1))  # 128 -> 256
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Transposed conv (no bias, BatchNorm follows) -> BatchNorm -> in-place ReLU."""
        layers = [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        images = self.net(x)
        return images

class Critic(nn.Module):
    """Convolutional critic that maps a 256x256 image batch to one score per image.

    Args:
        channels_img: number of channels in the input images.

    Six stride-2 stages halve the resolution 256 -> 4, then a 4x4 conv reduces
    the map to 1x1. ``forward`` flattens the result to shape (N,).
    """

    def __init__(self, channels_img):
        super().__init__()
        widths = [channels_img, 16, 32, 64, 128, 256, 512]
        downsampling = [
            self._block(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.net = nn.Sequential(*downsampling, nn.Conv2d(512, 1, 4, 1, 0))

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Bias-free conv followed by LeakyReLU(0.2); no normalisation layer."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x):
        scores = self.net(x)
        return scores.view(-1)


In [23]:
def weights_init(m, std=0.02):
    """Initialise one module in place following the DCGAN recipe.

    Intended for ``model.apply(weights_init)``, which calls it on every submodule.

    - Conv2d / ConvTranspose2d: weights ~ N(0, std).
    - BatchNorm2d (affine): weights ~ N(1, std), biases = 0.
    - All other module types are left untouched.

    Args:
        m: the module to initialise.
        std: standard deviation of the normal initialisation (default 0.02).
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # nn.init functions already run under no_grad, so `.data` is unnecessary.
        nn.init.normal_(m.weight, 0.0, std)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        # Centre the BatchNorm scale at 1 so each layer starts close to the identity.
        nn.init.normal_(m.weight, 1.0, std)
        nn.init.constant_(m.bias, 0.0)



In [26]:
# Build the generator and critic on the target device and apply the DCGAN-style init.
gen = Generator(z_dim, channels_img).to(device)
critic = Critic(channels_img).to(device)
gen.apply(weights_init)
critic.apply(weights_init)

# Adam with betas=(0.0, 0.9), the WGAN-GP paper's setting.
# Uses `lr` from the hyper-parameter cell: `config` (wandb) is only created in a
# later cell, so referencing `config.lr` here fails with NameError on a fresh kernel.
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=lr, betas=(0.0, 0.9))

# Fixed latent batch reused every epoch so saved sample grids are comparable over time.
fixed_noise = torch.randn(16, z_dim, 1, 1, device=device)


In [11]:
import os

import wandb

# Authenticate with Weights & Biases without embedding a secret in the notebook.
# The key is taken from the WANDB_API_KEY environment variable; when it is unset,
# key=None lets wandb.login use its own credential lookup (netrc / interactive prompt).
# NOTE(review): the API key previously hard-coded here was committed in plain text
# and should be revoked/rotated in the wandb account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33m777bhavya[0m ([33mmv-anmol-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
import os

# -----------------------------------
# WANDB INIT
# -----------------------------------
# Run configuration logged to wandb; the training loop reads lr, epochs,
# n_critic and lambda_gp back from `config`.
run_config = {
    "batch_size": 128,
    "epochs": 100,
    "lr": 1e-4,
    "n_critic": 5,
    "lambda_gp": 10,
    "image_size": 256,
    "z_dim": 100,
}
wandb.init(project="wgan-pokemon", config=run_config)
config = wandb.config

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Fresh optimizers built from the logged learning rate; betas follow WGAN-GP.
adam_betas = (0.0, 0.9)
opt_gen = optim.Adam(gen.parameters(), lr=config.lr, betas=adam_betas)
opt_critic = optim.Adam(critic.parameters(), lr=config.lr, betas=adam_betas)

os.makedirs("generated_images", exist_ok=True)

# -----------------------------------
# GRADIENT PENALTY
# -----------------------------------
def compute_gradient_penalty(critic, real, fake, device):
    """WGAN-GP gradient penalty: mean over the batch of (||grad critic(x_hat)||_2 - 1)^2.

    ``x_hat`` is a random per-sample linear interpolation between ``real`` and ``fake``.

    Args:
        critic: callable mapping a batch to one score per sample.
        real: batch of real samples, shape (N, ...).
        fake: batch of generated samples with the same shape as ``real``.
        device: device on which the interpolation coefficients are created.

    Returns:
        A scalar tensor. It is built with ``create_graph=True`` so it can be
        back-propagated into the critic's parameters.
    """
    batch_size = real.size(0)
    # One coefficient per sample, broadcast over every remaining dim. This works for
    # any input rank, not just 4-D image batches. dtype follows `real` so mixed
    # precision stays consistent.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device, dtype=real.dtype)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interpolated = critic(interpolated)

    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]

    # reshape rather than view: the returned gradient is not guaranteed to be contiguous.
    gradient_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()

# -----------------------------------
# TRAIN LOOP
# -----------------------------------
for epoch in range(config.epochs):
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
    for i, real in enumerate(loop):
        real = real.to(device)
        # Deliberately not named `batch_size`: assigning to it here would overwrite
        # the global hyper-parameter with the size of the last (possibly partial)
        # batch, which then leaks into any cell re-run afterwards.
        cur_batch_size = real.size(0)

        # ---- Critic: config.n_critic updates per generator update ----
        for _ in range(config.n_critic):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            fake = gen(noise).detach()  # the generator is not updated in this step
            real_score = critic(real)
            fake_score = critic(fake)

            gp = compute_gradient_penalty(critic, real, fake, device)
            # Minimise -(E[D(real)] - E[D(fake)]) plus the weighted gradient penalty.
            loss_critic = -(real_score.mean() - fake_score.mean()) + config.lambda_gp * gp

            opt_critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # ---- Generator: maximise E[D(G(z))] ----
        noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(noise)
        loss_gen = -critic(fake).mean()

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Logging: read each scalar once (every .item() is a host-device sync).
        d_loss = loss_critic.item()
        g_loss = loss_gen.item()
        loop.set_postfix({
            "D Loss": round(d_loss, 2),
            "G Loss": round(g_loss, 2)
        })

        wandb.log({
            "critic_loss": d_loss,
            "generator_loss": g_loss,
        })

    # Save a grid generated from the fixed noise so progress is comparable across epochs.
    with torch.no_grad():
        fake_images = gen(fixed_noise)
        fake_images = (fake_images + 1) / 2  # tanh output [-1, 1] -> [0, 1] for saving
        grid = make_grid(fake_images, nrow=4)
        save_image(grid, f"generated_images/epoch_{epoch+1:03d}.png")
        wandb.log({"generated_grid": [wandb.Image(grid, caption=f"Epoch {epoch+1}")]})


Epoch 1/100: 100%|██████████| 58/58 [02:54<00:00,  3.01s/it, D Loss=-1772.43, G Loss=-1427.35]
Epoch 2/100: 100%|██████████| 58/58 [02:54<00:00,  3.01s/it, D Loss=-1174.82, G Loss=-1476.32]
Epoch 3/100: 100%|██████████| 58/58 [02:54<00:00,  3.02s/it, D Loss=-850, G Loss=-1386.01]    
Epoch 4/100: 100%|██████████| 58/58 [02:54<00:00,  3.01s/it, D Loss=-576, G Loss=-1188.99]
Epoch 5/100: 100%|██████████| 58/58 [02:54<00:00,  3.02s/it, D Loss=-444, G Loss=-725]    
Epoch 6/100: 100%|██████████| 58/58 [02:55<00:00,  3.02s/it, D Loss=-452, G Loss=-275]
Epoch 7/100: 100%|██████████| 58/58 [02:54<00:00,  3.01s/it, D Loss=-351, G Loss=-101] 
Epoch 8/100: 100%|██████████| 58/58 [02:54<00:00,  3.00s/it, D Loss=-323, G Loss=-108] 
Epoch 9/100: 100%|██████████| 58/58 [02:54<00:00,  3.00s/it, D Loss=-315, G Loss=-100] 
Epoch 10/100: 100%|██████████| 58/58 [02:54<00:00,  3.00s/it, D Loss=-340, G Loss=-136] 
Epoch 11/100: 100%|██████████| 58/58 [02:53<00:00,  3.00s/it, D Loss=-521, G Loss=-196]
Epoch