In [None]:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from sklearn.metrics.pairwise import pairwise_distances
import random

# Seed every RNG used in the notebook (torch, numpy, stdlib random) for reproducibility.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(42)

DATA_DIR = "./data/train"  # folder with training images (e.g. the Kaggle dogs-vs-cats train set)
IMG_SIZE = 64              # square side length; keep it a multiple of 4 (encoder halves twice, decoder doubles twice)
BATCH_SIZE = 32
EPOCHS = 20
LR = 2e-4                  # Adam learning rate

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


In [None]:

class SimpleImageDataset(Dataset):
    """Flat-folder image dataset returning transformed RGB images (no labels).

    A directory entry is kept when it is a regular file, its name contains at
    least one of ``keywords`` (case-sensitive substring match, e.g. Kaggle's
    ``cat.0.jpg`` / ``dog.0.jpg``), and its lower-cased name ends with one of
    ``extensions``. Paths are sorted so the dataset order (and therefore the
    seeded shuffle) does not depend on the filesystem's listing order.

    Args:
        folder: directory scanned non-recursively for images.
        transform: callable applied to each loaded image.
        keywords: filename substrings to keep (a single string is also accepted);
            defaults to both cats and dogs.
        extensions: accepted file suffixes (compared case-insensitively).

    Raises:
        ValueError: if no matching image is found. Failing here is clearer than a
            later failure inside DataLoader or a division by zero in training.
    """

    DEFAULT_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

    def __init__(self, folder, transform, keywords=("cat", "dog"), extensions=DEFAULT_EXTENSIONS):
        if isinstance(keywords, str):
            # Avoid iterating a bare string character by character.
            keywords = (keywords,)
        keywords = tuple(keywords)
        exts = tuple(ext.lower() for ext in extensions)

        self.paths = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if any(key in name for key in keywords)
            and name.lower().endswith(exts)
            and os.path.isfile(os.path.join(folder, name))
        )
        if not self.paths:
            raise ValueError(
                f"No images matching {keywords} with extensions {exts} found in {folder!r}"
            )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Load image ``idx``, force 3-channel RGB and apply ``self.transform``."""
        # Explicit RGB conversion so grayscale / RGBA files still yield 3 channels.
        img = default_loader(self.paths[idx]).convert("RGB")
        return self.transform(img)

# Preprocessing: resize to a fixed square, convert to a [0, 1] tensor, then shift/scale
# each of the 3 channels with mean 0.5 and std 0.5, giving values in [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

dataset = SimpleImageDataset(folder=DATA_DIR, transform=transform)
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)


In [None]:

class Encoder(nn.Module):
    """Convolutional encoder mapping images to a continuous latent grid.

    Two 4x4 stride-2 convolutions (padding 1) each halve H and W (for even sizes),
    followed by a 1x1 convolution that projects to ``z_dim`` channels, i.e.
    (B, in_channels, H, W) -> (B, z_dim, H/4, W/4) when H and W are multiples of 4.
    """

    def __init__(self, in_channels=3, hidden_dim=128, z_dim=64):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, z_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        latent = self.net(x)
        return latent

class Decoder(nn.Module):
    """Convolutional decoder mirroring ``Encoder``.

    Two 4x4 stride-2 transposed convolutions (padding 1) each double H and W, then a
    3x3 transposed convolution (stride 1, padding 1) keeps the size and maps to
    ``out_channels``: (B, z_dim, h, w) -> (B, out_channels, 4h, 4w). No activation
    follows the last layer, so output values are unbounded.
    """

    def __init__(self, z_dim=64, hidden_dim=128, out_channels=3):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(z_dim, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, out_channels, kernel_size=3, stride=1, padding=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        image = self.net(x)
        return image

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through gradient (VQ-VAE).

    Args:
        num_embeddings: codebook size K.
        embedding_dim: size D of each code vector; must equal the channel count of
            the tensor passed to ``forward``.
        beta: weight of the commitment loss (the VQ-VAE paper uses 0.25).
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Codebook initialised uniformly in [-1/K, 1/K].
        nn.init.uniform_(self.embedding.weight, -1 / num_embeddings, 1 / num_embeddings)
        self.beta = beta

    def forward(self, z):
        """Quantize ``z`` of shape (B, D, H, W).

        Returns:
            quantized: tensor shaped like ``z``. Its forward values are the nearest
                codebook vectors; its gradients pass straight through to ``z``.
            loss: codebook loss + beta * commitment loss.
        """
        # (B, D, H, W) -> (B*H*W, D): one row per spatial position.
        z_perm = z.permute(0, 2, 3, 1).contiguous()
        flat_z = z_perm.view(-1, z.shape[1])

        codebook = self.embedding.weight
        # Squared Euclidean distance ||z||^2 - 2 z.e + ||e||^2 to every code.
        dists = (
            flat_z.pow(2).sum(1, keepdim=True)
            - 2 * flat_z @ codebook.t()
            + codebook.pow(2).sum(1)
        )
        indices = dists.argmin(1)
        quantized = self.embedding(indices).view(z_perm.shape).permute(0, 3, 1, 2)

        # Codebook loss (weight 1) moves the selected codes towards the frozen encoder output.
        # Commitment loss (weight beta) keeps the encoder output close to its frozen codes.
        codebook_loss = F.mse_loss(quantized, z.detach())
        commitment_loss = F.mse_loss(z, quantized.detach())
        loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: forward uses the codes, backward copies grads to z.
        quantized = z + (quantized - z).detach()
        return quantized, loss

class VQVAE(nn.Module):
    """VQ-VAE: ``Encoder`` -> ``VectorQuantizer`` -> ``Decoder``.

    Every argument defaults to the value that used to be hard-coded, so ``VQVAE()``
    builds the same model as before. ``z_dim`` is passed to the encoder output, the
    codebook vectors (``embedding_dim``) and the decoder input, which keeps the three
    modules consistent when experimenting with hyperparameters.

    Args:
        num_embeddings: codebook size.
        z_dim: latent channel count / codebook vector size.
        hidden_dim: channel width of the encoder and decoder hidden layers.
        beta: weight forwarded to ``VectorQuantizer``.
        in_channels: image channels (also the decoder's output channels).
    """

    def __init__(self, num_embeddings=512, z_dim=64, hidden_dim=128, beta=0.25, in_channels=3):
        super().__init__()
        # Construction order (encoder, vq, decoder) is unchanged, so seeded init is the same.
        self.encoder = Encoder(in_channels=in_channels, hidden_dim=hidden_dim, z_dim=z_dim)
        self.vq = VectorQuantizer(num_embeddings=num_embeddings, embedding_dim=z_dim, beta=beta)
        self.decoder = Decoder(z_dim=z_dim, hidden_dim=hidden_dim, out_channels=in_channels)

    def forward(self, x):
        """Return ``(recon, loss_vq)``: the reconstruction of ``x`` and the quantizer loss."""
        z = self.encoder(x)
        quantized, loss_vq = self.vq(z)
        recon = self.decoder(quantized)
        return recon, loss_vq


In [None]:

model = VQVAE()
model.to(DEVICE)  # nn.Module.to moves parameters in place
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# Per-batch objective: pixel-space MSE reconstruction loss + the quantizer's loss.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for x in tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        optimizer.zero_grad()  # clear gradients from the previous step before backward
        x = x.to(DEVICE)
        recon, loss_vq = model(x)
        recon_loss = F.mse_loss(recon, x)
        loss = recon_loss + loss_vq
        loss.backward()
        optimizer.step()
        total_loss = total_loss + loss.item()

    # Reported value is the mean of the per-batch losses of this epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(loader):.4f}")


In [None]:

def denorm(x):
    """Undo Normalize(mean=0.5, std=0.5): map [-1, 1] back to [0, 1] and clamp for display."""
    return (x * 0.5 + 0.5).clamp(0, 1)


N_SHOW = 6  # number of original/reconstruction pairs to display

model.eval()
with torch.no_grad():
    # Slice before the forward pass: only the displayed images need reconstructing.
    imgs = next(iter(loader))[:N_SHOW].to(DEVICE)
    recon, _ = model(imgs)
    # nrow = actual image count, so originals fill the top row and reconstructions the
    # bottom row even if fewer than N_SHOW images are available.
    grid = make_grid(torch.cat([denorm(imgs), denorm(recon)]), nrow=imgs.size(0))

fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(grid.permute(1, 2, 0).cpu())
ax.axis("off")
ax.set_title("Original (top) vs Reconstruction (bottom)")
plt.show()


In [None]:

def interpolate_latents(z1, z2, steps=8):
    """Linearly interpolate from ``z1`` to ``z2``.

    Returns a list of ``steps + 2`` values, ``z1 * (1 - a) + z2 * a``, for ``a`` in
    ``np.linspace(0, 1, steps + 2)``. The first entry is ``z1`` and the last is ``z2``.
    """
    points = []
    for alpha in np.linspace(0, 1, num=steps + 2):
        points.append(z1 * (1 - alpha) + z2 * alpha)
    return points

model.eval()
with torch.no_grad():
    # Only two images are needed; slice before copying the batch to the device.
    imgs = next(iter(loader))[:2].to(DEVICE)
    z1 = model.encoder(imgs[0].unsqueeze(0))
    z2 = model.encoder(imgs[1].unsqueeze(0))
    interpolated = interpolate_latents(z1, z2)
    # The decoder only ever receives codebook vectors (model.forward quantizes first),
    # so snap each interpolated latent to its nearest codes before decoding. The
    # endpoints then coincide with the model's regular reconstructions.
    decoded = torch.cat([model.decoder(model.vq(z)[0]) for z in interpolated])
    decoded = (decoded * 0.5 + 0.5).clamp(0, 1)
    # One row holding every interpolation step, whatever `steps` is.
    grid = make_grid(decoded, nrow=len(interpolated))

fig, ax = plt.subplots(figsize=(20, 4))
ax.imshow(grid.permute(1, 2, 0).cpu())
ax.axis("off")
ax.set_title("Interpolation between two latent codes")
plt.show()


In [None]:

# NOTE: despite its name this is neither FID nor a feature-space distance: it is the plain
# mean squared error between the two inputs (pixel space when called on images).
def simple_feature_distance(x, y):
    """Return the mean of ``(x - y) ** 2`` over all elements as a Python float."""
    squared_error = (x - y) ** 2
    return squared_error.mean().item()

# One shuffled batch is a noisy estimate, so average over several batches.
# Note: `loader` iterates the same images the model was trained on (DATA_DIR).
N_EVAL_BATCHES = 10

model.eval()
with torch.no_grad():
    batch_dists = []
    batch_sizes = []
    for batch_idx, x in enumerate(loader):
        if batch_idx >= N_EVAL_BATCHES:
            break
        x = x.to(DEVICE)
        recon, _ = model(x)
        batch_dists.append(simple_feature_distance(x, recon))
        batch_sizes.append(x.size(0))
    # Weight each batch by its size so a smaller final batch does not skew the mean
    # (all images have the same number of elements).
    dist = sum(d * n for d, n in zip(batch_dists, batch_sizes)) / sum(batch_sizes)
    print(f"MSE rekonstrukcji (piksele, po normalizacji): {dist:.4f}")



## Eksperymenty z hiperparametrami

Można eksperymentować ze zmianą:
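- `z_dim` – wymiar przestrzeni latentnej (np. 32, 128); ta sama wartość musi trafić do `Encoder`, do `embedding_dim` w `VectorQuantizer` i do `Decoder`,
- `beta` w `VectorQuantizer` – waga w stracie kwantyzacji,
- `num_embeddings` – liczba wektorów w słowniku kodów (codebook),
- `LR` – szybkość uczenia,
- `IMG_SIZE` – rozdzielczość; musi być wielokrotnością 4, bo enkoder dwukrotnie zmniejsza obraz o połowę, a dekoder dwukrotnie go powiększa.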

Następnie porównać:
- jakość wizualną obrazów,
- prostą metrykę różnicy (np. MSE),
- ewentualnie dokładniejszy FID z `torch-fidelity`.



## Dodatkowe zadanie: psy i koty

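Warunek w `SimpleImageDataset` (pokazany poniżej) już teraz ładuje zarówno psy, jak i koty z Kaggle. Zmień go tak, aby ładował tylko jedną klasę (np. `'cat' in f`), wytrenuj model ponownie i porównaj rekonstrukcje oraz interpolacje z wersją trenowaną na obu klasach, aby sprawdzić, czy model rozróżnia typy zwierząt.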

```python
if 'cat' in f or 'dog' in f
```
