In [None]:
!pip install opendatasets

In [None]:
import os

import matplotlib.pyplot as plt
import opendatasets as od
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from tqdm import tqdm

# Fetch the sprite dataset from Kaggle (presumably into ./pokemon-sprites-images/,
# the root the later cells read from) and load the label table.
# NOTE(review): pokemon_labels.csv is read from the working directory, not from
# the downloaded folder — confirm where this file comes from.
od.download("https://www.kaggle.com/datasets/yehongjiang/pokemon-sprites-images")
pokemonData = pd.read_csv("pokemon_labels.csv")


In [None]:
class PokemonDataset(Dataset):
    """
    Pokémon sprite images paired with float32 label vectors.

    ``root_dir`` is expected to hold one folder per Pokémon whose name starts
    with ``"{id}-"``, with PNG sprites under ``front/normal``. Each item is
    ``(image, label)``: the first PNG (in sorted order) converted to RGB and
    passed through ``transform``, and every dataframe column except ``"id"``
    and ``"name"`` as a float32 tensor.

    Rows whose id has no matching folder are dropped at construction time and
    a summary is printed.

    Args:
        root_dir: directory containing the per-Pokémon sprite folders.
        dataframe: table with an ``"id"`` column, a ``"name"`` column and
            numeric label columns.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, root_dir, dataframe, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # List root_dir once (instead of once per row here and once per item in
        # __getitem__). Sorting makes the pick deterministic when several
        # folders share the same "{id}-" prefix.
        entries = sorted(os.listdir(root_dir))

        valid_indices = []
        folders = []
        invalid_ids = []
        for idx in range(len(dataframe)):
            poke_id = str(dataframe.iloc[idx]["id"])
            prefix = f"{poke_id}-"
            folder = next((e for e in entries if e.startswith(prefix)), None)
            if folder is None:
                invalid_ids.append(poke_id)
            else:
                valid_indices.append(idx)
                folders.append(os.path.join(root_dir, folder))

        self.df = dataframe.iloc[valid_indices].reset_index(drop=True)
        # Folder path for each row of self.df, resolved once.
        self._folders = folders
        # Label vector per row: every column except "id" and "name", as float32.
        self._labels = torch.tensor(
            self.df.drop(columns=["id", "name"]).values.astype("float32")
        )

        print(f"Original dataset size: {len(dataframe)}")
        print(f"Filtered dataset size: {len(self.df)}")
        print(f"Number of entries filtered out: {len(invalid_ids)}")
        if invalid_ids:
            print(f"First 10 filtered IDs: {invalid_ids[:10]}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Return ``(image, label)`` for row ``idx``.

        Raises:
            FileNotFoundError: if the row's ``front/normal`` folder is missing
                or contains no ``.png`` files.
        """
        sprite_dir = os.path.join(self._folders[idx], "front", "normal")
        names = os.listdir(sprite_dir) if os.path.isdir(sprite_dir) else []
        # Sorted so the same sprite is chosen on every filesystem.
        files = sorted(os.path.join(sprite_dir, fn) for fn in names if fn.endswith(".png"))
        if not files:
            raise FileNotFoundError(f"Brak sprite'ów w {sprite_dir}")

        # Context manager releases the file handle as soon as the pixels are copied.
        with Image.open(files[0]) as sprite:
            img = sprite.convert("RGB")
        if self.transform:
            img = self.transform(img)

        # Clone so callers cannot mutate the cached label table in place.
        label = self._labels[idx].clone()
        return img, label

In [None]:
# Quick sanity check of the sprite layout: show the folder count and a short
# sorted sample instead of dumping up to 1000 names.
sprite_root = "./pokemon-sprites-images/pokemon_images/sprites"
sprite_folders = sorted(os.listdir(sprite_root))
print(f"{len(sprite_folders)} sprite folders, e.g.: {sprite_folders[:20]}")

In [None]:
def sinusoidal_embedding(timesteps: torch.LongTensor, dim: int, max_period: float = 10000.0):
    """
    Sinusoidal timestep embedding: ``[sin(t * f_k), cos(t * f_k)]``.

    timesteps: (B,) integer tensor of diffusion steps.
    dim: output width. The sin half comes first, then the cos half; an odd
        ``dim`` gets one trailing zero column.
    max_period: sets the lowest frequency, ``1 / max_period`` (reached when dim >= 4).
    returns: (B, dim) float tensor on the same device as ``timesteps``.
    """
    assert len(timesteps.shape) == 1
    device = timesteps.device
    half = dim // 2
    # Frequencies are spaced geometrically from 1 down to 1/max_period.
    # Guard the denominator: with half == 1 (dim 2 or 3), (half - 1) is 0
    # and produced NaN embeddings.
    denom = max(half - 1, 1)
    freqs = torch.exp(
        -torch.log(torch.tensor(max_period, device=device)) * torch.arange(half, device=device) / denom
    )
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb

In [None]:
# Nearest-neighbour resizing avoids blending neighbouring pixels; Normalize maps
# ToTensor's [0, 1] range to [-1, 1], the range the diffusion model works in.
transform = T.Compose([
    T.Resize((96, 96), interpolation=T.InterpolationMode.NEAREST),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = PokemonDataset(
    root_dir="./pokemon-sprites-images/pokemon_images/sprites",
    dataframe=pokemonData,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

In [None]:
class ConvBlock(nn.Module):
    """
    Pre-activation conv block: GroupNorm -> SiLU -> 3x3 conv, plus a
    per-channel bias derived from the time embedding and an optional residual.

    Args:
        in_ch: input channels (must be divisible by 8 for GroupNorm(8, in_ch)).
        out_ch: output channels.
        time_emb_dim: width of the embedding passed to ``forward``.
        is_res: if True (default), add the input to the output. The input is
            projected with a 1x1 conv when ``in_ch != out_ch``. If False, the
            residual is skipped.
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, is_res=True):
        super().__init__()
        self.is_res = is_res
        self.main_path = nn.Sequential(
            nn.GroupNorm(8, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
        )
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_ch),
        )
        # Always created so state_dicts stay compatible regardless of is_res.
        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, temb):
        """x: (B, in_ch, H, W); temb: (B, time_emb_dim). Returns (B, out_ch, H, W)."""
        h = self.main_path(x)
        # Append two singleton dims so the embedding broadcasts over H and W.
        time_emb = self.time_mlp(temb).unsqueeze(-1).unsqueeze(-1)
        h = h + time_emb
        if self.is_res:
            h = h + self.res_conv(x)
        return h

In [None]:
class AttentionBlock(nn.Module):
    """
    Multi-head self-attention over the H*W positions of a 2D feature map,
    wrapped in a residual connection: ``out = x + proj(attn(norm(x)))``.

    Args:
        channels: feature channels C; must be divisible by 8 (GroupNorm groups)
            and by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``channels`` is not divisible by ``num_heads``.
    """
    def __init__(self, channels, num_heads=4):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(f"channels ({channels}) must be divisible by num_heads ({num_heads})")
        self.channels = channels
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1, bias=False)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """x: (B, C, H, W). Returns a tensor of the same shape."""
        B, C, H, W = x.shape
        head_dim = C // self.num_heads
        # Normalise into a separate tensor: the residual below must add the
        # original input, not its normalised version.
        h = self.norm(x)
        # (B, 3C, H, W) -> (B, 3, heads, head_dim, H*W), then split into q, k, v.
        q, k, v = self.qkv(h).reshape(B, 3, self.num_heads, head_dim, H * W).unbind(1)

        # Scaled dot-product attention between all pairs of spatial positions.
        attn = torch.einsum('b h c i, b h c j -> b h i j', q, k) * (head_dim ** -0.5)
        attn = F.softmax(attn, dim=-1)

        out = torch.einsum('b h i j, b h c j -> b h c i', attn, v)
        out = out.reshape(B, C, H, W)
        return x + self.proj(out)

In [None]:
class UNet(nn.Module):
    """
    Conditional U-Net noise predictor.

    The timestep is encoded with ``sinusoidal_embedding`` followed by an MLP,
    and the condition vector with a second MLP. Their sum is injected into
    every ``ConvBlock``. Each level halves the spatial size with a strided
    conv, so H and W must be divisible by ``2 ** len(ch_mults)`` (16 with the
    default).

    Args:
        img_channels: channels of the input and predicted-noise images.
        base_channels: width after the stem conv. Every level width
            (``base_channels * m``) must be divisible by 8 for GroupNorm.
        time_emb_dim: width of the timestep/condition embedding.
        cond_dim: length of the conditioning vector.
        ch_mults: per-level channel multipliers. Attention is applied on
            levels with index >= 2.
    """
    def __init__(self, img_channels=3, base_channels=128, time_emb_dim=256, cond_dim=47,
                 ch_mults=(1, 2, 4, 8)):
        super().__init__()
        self.time_emb_dim = time_emb_dim

        # MLP on top of the sinusoidal timestep embedding.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # MLP projecting the condition (feature) vector to the embedding width.
        self.cond_mlp = nn.Sequential(
            nn.Linear(cond_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        channels = [base_channels] + [base_channels * m for m in ch_mults]

        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        self.down_blocks = nn.ModuleList()
        for i in range(len(ch_mults)):
            in_ch, out_ch = channels[i], channels[i + 1]
            self.down_blocks.append(nn.ModuleList([
                ConvBlock(in_ch, out_ch, time_emb_dim),
                AttentionBlock(out_ch) if i >= 2 else nn.Identity(),
                nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1),  # halves H and W
            ]))

        mid_ch = channels[-1]
        self.mid_block1 = ConvBlock(mid_ch, mid_ch, time_emb_dim)
        self.mid_attn = AttentionBlock(mid_ch)
        self.mid_block2 = ConvBlock(mid_ch, mid_ch, time_emb_dim)

        self.up_blocks = nn.ModuleList()
        for i in reversed(range(len(ch_mults))):
            in_ch, out_ch = channels[i + 1], channels[i]
            self.up_blocks.append(nn.ModuleList([
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1),  # doubles H and W
                # Input = upsampled map (out_ch) concatenated with the skip
                # from the matching down level (in_ch channels).
                ConvBlock(out_ch + in_ch, out_ch, time_emb_dim),
                AttentionBlock(out_ch) if i >= 2 else nn.Identity(),
            ]))

        self.final_norm = nn.GroupNorm(8, base_channels)
        self.final_act = nn.SiLU()
        self.final_conv = nn.Conv2d(base_channels, img_channels, 1)

    def forward(self, x, t, cond):
        """
        Predict the noise contained in ``x``.

        x: (B, img_channels, H, W) noisy images.
        t: (B,) integer diffusion timesteps.
        cond: conditioning vectors whose last dimension is ``cond_dim``.
        Returns a tensor with the same shape as ``x``.
        """
        # Use the configured width so the embedding always matches the first
        # Linear of time_mlp (a hard-coded 256 broke other time_emb_dim values).
        temb = self.time_mlp(sinusoidal_embedding(t, self.time_emb_dim))
        temb = temb + self.cond_mlp(cond)

        x = self.init_conv(x)

        skips = []
        for block, attn, downsample in self.down_blocks:
            x = attn(block(x, temb))
            skips.append(x)  # pre-downsample features of this level
            x = downsample(x)

        x = self.mid_block1(x, temb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, temb)

        # Skips are consumed last-in-first-out, i.e. deepest level first.
        for upsample, block, attn in self.up_blocks:
            x = torch.cat([upsample(x), skips.pop()], dim=1)
            x = attn(block(x, temb))

        return self.final_conv(self.final_act(self.final_norm(x)))

In [None]:
class Diffusion:
    """
    DDPM noise schedule, training loss and ancestral sampler.

    Uses a linear beta schedule from 1e-4 to 0.02 over ``timesteps`` steps.

    Args:
        model: noise-prediction network called as ``model(x_t, t, cond)``.
        img_size: height and width of generated images.
        device: device for the schedule tensors and for sampled noise.
        timesteps: number of diffusion steps T.
        img_channels: channels of generated images (default 3).
    """
    def __init__(self, model: nn.Module, img_size=96, device="cuda", timesteps=256, img_channels=3):
        self.model = model
        self.img_size = img_size
        self.img_channels = img_channels
        self.device = device
        self.timesteps = timesteps

        betas = torch.linspace(1e-4, 0.02, timesteps, device=device)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        # abar_{t-1}, with abar_{-1} := 1 so that step t = 0 is well defined.
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=device), alphas_cumprod[:-1]])

        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

    def q_sample(self, x0, t, noise=None):
        """
        Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I).

        x0: (B, C, H, W) clean images.
        t: (B,) long tensor of timesteps.
        noise: optional pre-drawn standard-normal noise shaped like x0.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_alpha_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise

    def training_step(self, x0, cond):
        """
        Noise-prediction loss for one batch.

        Draws a uniform random t per sample, noises x0 to x_t, and returns
        MSE(model(x_t, t, cond), noise) as a scalar tensor.
        """
        b = x0.size(0)
        t = torch.randint(0, self.timesteps, (b,), device=self.device, dtype=torch.long)
        noise = torch.randn_like(x0)
        x_noisy = self.q_sample(x0, t, noise)
        noise_pred = self.model(x_noisy, t, cond)
        loss = F.mse_loss(noise_pred, noise)
        return loss

    def p_mean_variance(self, x_t, t, cond):
        """
        Parameters of the posterior q(x_{t-1} | x_t, x0_pred) used by the sampler.

        x0_pred is recovered from the model's noise prediction and clamped to [-1, 1].

        x_t: (B, C, H, W) current sample.
        t: Python int timestep shared by the whole batch.
        cond: conditioning passed through to the model.
        Returns (posterior_mean, posterior_variance, posterior_log_variance):
            posterior_mean has shape (B, C, H, W); posterior_variance is a 0-d
            tensor; posterior_log_variance has shape (1, 1, 1, 1) and is
            computed from the variance clamped to >= 1e-20 (the variance is 0
            at t = 0).
        """
        B = x_t.shape[0]
        device = x_t.device
        t_tensor = torch.full((B,), t, device=device, dtype=torch.long)
        eps_theta = self.model(x_t, t_tensor, cond)

        sqrt_alpha_hat_t = self.sqrt_alphas_cumprod[t]
        sqrt_one_minus_alpha_hat_t = self.sqrt_one_minus_alphas_cumprod[t]
        x0_pred = (x_t - sqrt_one_minus_alpha_hat_t * eps_theta) / sqrt_alpha_hat_t
        x0_pred = torch.clamp(x0_pred, -1.0, 1.0)

        alpha_t = self.alphas[t]
        alpha_hat_t = self.alphas_cumprod[t]
        alpha_hat_prev = self.alphas_cumprod_prev[t]
        beta_t = self.betas[t]

        coef_x0 = (torch.sqrt(alpha_hat_prev) * beta_t) / (1.0 - alpha_hat_t)
        coef_xt = (torch.sqrt(alpha_t) * (1.0 - alpha_hat_prev)) / (1.0 - alpha_hat_t)

        coef_x0 = coef_x0.view(1, 1, 1, 1)
        coef_xt = coef_xt.view(1, 1, 1, 1)

        posterior_mean = coef_x0 * x0_pred + coef_xt * x_t
        posterior_variance = beta_t * (1.0 - alpha_hat_prev) / (1.0 - alpha_hat_t)
        posterior_log_variance = torch.log(torch.clamp(posterior_variance, min=1e-20)).view(1, 1, 1, 1)

        return posterior_mean, posterior_variance, posterior_log_variance

    def p_sample(self, x_t, t, cond):
        """
        Draw x_{t-1} from p(x_{t-1} | x_t). At t = 0 the posterior mean is
        returned without added noise.
        """
        mean, var, log_var = self.p_mean_variance(x_t, t, cond)
        if t == 0:
            return mean
        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(var).view(1, 1, 1, 1) * noise

    def sample(self, cond, batch_size=8):
        """
        Ancestral sampling: start from x_T ~ N(0, I) and apply ``p_sample`` for
        t = T-1, ..., 0.

        Runs under ``torch.no_grad()``: otherwise the T model calls would chain
        into one autograd graph whose memory grows with every step.

        cond: conditioning for the model; its batch must broadcast with ``batch_size``.
        Returns a (batch_size, img_channels, img_size, img_size) tensor.
        """
        with torch.no_grad():
            x = torch.randn(batch_size, self.img_channels, self.img_size, self.img_size, device=self.device)
            for t in reversed(range(self.timesteps)):
                x = self.p_sample(x, t, cond)
        return x

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

IMG_SIZE = 96          # must match the Resize size in the transform cell
TIMESTEPS = 1000       # number of diffusion steps T
LEARNING_RATE = 2e-4

# Seed the global torch RNG so weight init, shuffling and noise draws are reproducible.
torch.manual_seed(0)
model = UNet(img_channels=3, base_channels=128, time_emb_dim=256).to(device)
diffusion = Diffusion(model, img_size=IMG_SIZE, device=device, timesteps=TIMESTEPS)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
preview_imgs, _ = next(iter(dataloader))
# Inputs are normalised to [-1, 1]; report the range rather than every unique value.
print(f"pixel range: [{preview_imgs.min().item():.3f}, {preview_imgs.max().item():.3f}]")
# Undo Normalize(0.5, 0.5) so imshow receives values in [0, 1] instead of clipping.
plt.imshow((preview_imgs[0] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
plt.axis("off")
plt.show()

In [None]:
# Re-create the training loader with the same settings used where `dataset` is built.
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)

In [None]:
num_epochs = 3500
sample_every = 100          # epochs between preview grids / checkpoints
n_preview = 8               # max images per preview grid
checkpoint_path = "unet_pokemon_diffusion.pth"

iters = 0
for epoch in range(num_epochs):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for imgs, cond in pbar:
        imgs = imgs.to(device)
        cond = cond.to(device)
        loss = diffusion.training_step(imgs, cond)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iters += 1
        pbar.set_postfix({"loss": loss.item(), "iters": iters})

    if (epoch + 1) % sample_every == 0:
        # Checkpoint periodically so an interrupted run keeps its latest weights.
        torch.save(model.state_dict(), checkpoint_path)

        model.eval()
        with torch.no_grad():
            _, cond_batch = next(iter(dataloader))
            cond_batch = cond_batch[:n_preview].to(device)
            # A batch can be smaller than n_preview on a tiny dataset.
            n_shown = cond_batch.shape[0]
            samples = diffusion.sample(cond=cond_batch, batch_size=n_shown)
            samples = (samples * 0.5 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        model.train()

        fig, axes = plt.subplots(1, n_shown, figsize=(2 * n_shown, 2), squeeze=False)
        for ax, sample in zip(axes[0], samples):
            ax.imshow(sample.permute(1, 2, 0).cpu().numpy())
            ax.axis("off")
        fig.suptitle(f"Epoch {epoch + 1}")
        plt.show()
        plt.close(fig)  # release the figure; this loop runs for many epochs

In [None]:
# 1. legendary, 2. mega_evolution, 3. alolan_form, 4. galarian_form, 5. gigantamax, 6. Bug, 7. Dark, 8. Dragon, 9. Electric, 10. Fairy, 11. Fighting, 12. Fire, 13. Flying, 14. Ghost, 15. Grass, 16. Ground, 17. Ice, 18. Normal, 19. Poison, 20. Psychic,
# 21. Rock, 22. Steel, 23. Water, 24. shape_Armor, 25. shape_Arms, 26. shape_Ball, 27. shape_Blob, 28. shape_Bug-Wings, 29. shape_Fish, 30. shape_Heads, 31. shape_Humanoid, 32. shape_Legs, 33. shape_Quadruped, 34. shape_Squiggle, 35. shape_Tentacles, 36. shape_Upright, 37. shape_Wings, 38. color_Black, 39. color_Blue, 40. color_Brown,
# 41. color_Gray, 42. color_Green, 43. color_Pink, 44. color_Purple, 45. color_Red, 46. color_White, 47. color_Yellow
# NOTE: this list is numbered from 1, but Python indices into `feature_names` and
# the condition vectors are 0-based (index = number - 1; e.g. index 14 is Grass).

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Conditioning features in label-vector order (0-based positions).
# NOTE(review): presumably mirrors the label columns of pokemon_labels.csv
# (everything except "id" and "name"); confirm against pokemonData.columns.
feature_names = ["legendary", "mega_evolution", "alolan_form", "galarian_form", "gigantamax",
                 "Bug", "Dark", "Dragon", "Electric", "Fairy",
                 "Fighting", "Fire", "Flying", "Ghost", "Grass",
                 "Ground", "Ice", "Normal", "Poison", "Psychic",
                 "Rock", "Steel", "Water", "shape_Armor", "shape_Arms",
                 "shape_Ball", "shape_Blob", "shape_Bug-Wings", "shape_Fish", "shape_Heads",
                 "shape_Humanoid", "shape_Legs", "shape_Quadruped", "shape_Squiggle", "shape_Tentacles",
                 "shape_Upright", "shape_Wings", "color_Black", "color_Blue", "color_Brown",
                 "color_Gray", "color_Green", "color_Pink", "color_Purple", "color_Red",
                 "color_White", "color_Yellow"]
cond_dim = len(feature_names)
feature_index = {name: i for i, name in enumerate(feature_names)}


def make_condition(active_features):
    """Return a float32 multi-hot vector of length ``cond_dim`` with 1.0 at each named feature.

    Raises KeyError for a name that is not in ``feature_names``.
    """
    vec = torch.zeros(cond_dim)
    for name in active_features:
        vec[feature_index[name]] = 1.0
    return vec


# One prompt per generated image, written by name instead of magic column indices.
prompts = [
    ["Grass", "Poison", "shape_Quadruped", "color_Green"],    # indices 14, 18, 32, 41
    ["legendary", "Flying", "shape_Armor", "color_Yellow"],   # indices 0, 12, 23, 46
    ["legendary"],                                             # index 0
    ["Bug", "Electric", "Ghost", "shape_Legs"],               # indices 5, 8, 13, 31
]
batch_size = len(prompts)

cond_batch = torch.stack([make_condition(p) for p in prompts]).to(device)

In [None]:
# The training loop ends in train mode; switch to eval for inference.
model.eval()
with torch.no_grad():
    samples = diffusion.sample(cond=cond_batch, batch_size=batch_size)
    # Map from the [-1, 1] training range back to [0, 1] for display.
    samples = (samples * 0.5 + 0.5).clamp(0, 1)

In [None]:
# squeeze=False keeps `axes` 2-D so indexing also works when batch_size == 1.
fig, axes = plt.subplots(1, batch_size, figsize=(16, 2), squeeze=False)
for ax, sample in zip(axes[0], samples):
    ax.imshow(sample.permute(1, 2, 0).cpu().numpy())
    ax.axis("off")
plt.show()

In [None]:
model.eval()
with torch.no_grad():
    _, cond_batch = next(iter(dataloader))
    cond_batch = cond_batch[:8].to(device)
    n_shown = cond_batch.shape[0]  # may be < 8 on a tiny dataset
    samples = diffusion.sample(cond=cond_batch, batch_size=n_shown)
    # Undo Normalize(0.5, 0.5); without this imshow clips everything below 0.
    samples = (samples * 0.5 + 0.5).clamp(0, 1)

fig, axes = plt.subplots(1, n_shown, figsize=(2 * n_shown, 2), squeeze=False)
for ax, sample in zip(axes[0], samples):
    ax.imshow(sample.permute(1, 2, 0).cpu().numpy())
    ax.axis("off")
plt.show()

In [None]:
# Persist only the learned parameters of the U-Net.
weights = model.state_dict()
torch.save(weights, "unet_pokemon_diffusion.pth")