In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, z):
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)
        distances = (
            torch.sum(z_flat ** 2, dim=1, keepdim=True)
            - 2 * torch.matmul(z_flat, self.embedding.weight.t())
            + torch.sum(self.embedding.weight ** 2, dim=1)
        )
        indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(indices.size(0), self.num_embeddings, device=z.device)
        encodings.scatter_(1, indices, 1)

        quantized = torch.matmul(encodings, self.embedding.weight).view(
            z.shape[0], z.shape[2], z.shape[3], self.embedding_dim
        ).permute(0, 3, 1, 2)

        e_latent_loss = F.mse_loss(quantized.detach(), z)
        q_latent_loss = F.mse_loss(quantized, z.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        quantized = z + (quantized - z).detach()
        indices = indices.view(z.shape[0], z.shape[2], z.shape[3])
        return quantized, loss, indices

class VQVAE2(nn.Module):
    def __init__(self, num_embeddings=512, embedding_dim=64):
        super().__init__()
        # Encoder bottom -> z_b
        self.enc_b = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(128, embedding_dim, 3, 1, 1)
        )
        # Encoder top -> z_t
        self.enc_t = nn.Sequential(
            nn.Conv2d(embedding_dim, 128, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(128, embedding_dim, 3, 1, 1)
        )

        self.quant_t = VectorQuantizer(num_embeddings, embedding_dim, 0.25)
        self.quant_b = VectorQuantizer(num_embeddings, embedding_dim, 0.25)

        self.dec_t = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 128, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(128, embedding_dim, 3, 1, 1)
        )

        self.dec = nn.Sequential(
            nn.ConvTranspose2d(2 * embedding_dim, 128, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 3, 1), nn.Sigmoid()
        )

    def forward(self, x):
        z_b = self.enc_b(x)
        z_t = self.enc_t(z_b)

        quant_t, loss_t, idx_t = self.quant_t(z_t)
        dec_t = self.dec_t(quant_t)

        z_b_combined = z_b + dec_t
        quant_b, loss_b, idx_b = self.quant_b(z_b_combined)

        x_recon = self.dec(torch.cat([quant_b, dec_t], dim=1))
        return x_recon, loss_t + loss_b

    def encode_indices(self, x):
        z_b = self.enc_b(x)
        z_t = self.enc_t(z_b)
        _, _, idx_t = self.quant_t(z_t)
        dec_t = self.dec_t(self.quant_t(z_t)[0])
        z_b_combined = z_b + dec_t
        _, _, idx_b = self.quant_b(z_b_combined)
        return idx_t, idx_b

    def decode_indices(self, idx_t, idx_b):
        emb_t = self.quant_t.embedding(idx_t).permute(0, 3, 1, 2)
        dec_t = self.dec_t(emb_t)
        emb_b = self.quant_b.embedding(idx_b).permute(0, 3, 1, 2)
        recon = self.dec(torch.cat([emb_b, dec_t], dim=1))
        return recon


In [2]:
import torch
import torch.nn as nn

class PixelSNAIL(nn.Module):
    def __init__(self, num_embeddings, hidden_dim, size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, hidden_dim, 7, padding=3),
            nn.ReLU(),
            *[nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), nn.ReLU()) for _ in range(8)],
            nn.Conv2d(hidden_dim, num_embeddings, 1)
        )
        self.size = size

    def forward(self, x):
        return self.net(x)

    def sample(self, device, num_samples):
        samples = torch.zeros((num_samples, 1, self.size, self.size), dtype=torch.long).to(device)
        for i in range(self.size):
            for j in range(self.size):
                with torch.no_grad():
                    logits = self(samples.float())
                    probs = torch.softmax(logits[:, :, i, j], dim=1)
                    samples[:, 0, i, j] = torch.multinomial(probs, 1).squeeze()
        return samples


In [9]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformacje danych
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
])
dataset = ImageFolder("data", transform=transform)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Model
vqvae = VQVAE2().to(device)
optimizer = torch.optim.Adam(vqvae.parameters(), lr=2e-4)

# === Trening VQ-VAE-2 ===
for epoch in range(20):
    vqvae.train()
    total_loss = 0.0
    for x, _ in tqdm(loader, desc=f"VQ-VAE2 Epoch {epoch+1}/20"):
        x = x.to(device)
        x_recon, loss = vqvae(x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}: avg loss = {avg_loss:.4f}")

    # Zapis przykładowych rekonstrukcji co 5 epok
    if (epoch + 1) % 5 == 0:
        vqvae.eval()
        with torch.no_grad():
            sample = next(iter(loader))[0][:8].to(device)
            recon, _ = vqvae(sample)
            save_image(torch.cat([sample, recon], dim=0), f"recon_epoch{epoch+1}.png", nrow=8)

# === Zbieranie kodów latentnych ===
vqvae.eval()
idx_t_all, idx_b_all = [], []

with torch.no_grad():
    for x, _ in tqdm(loader, desc="Encoding latents"):
        x = x.to(device)
        idx_t, idx_b = vqvae.encode_indices(x)
        idx_t_all.append(idx_t)
        idx_b_all.append(idx_b)

idx_t_all = torch.cat(idx_t_all)
idx_b_all = torch.cat(idx_b_all)

# === Trening PixelSNAIL ===
pixelsnail_t = PixelSNAIL(num_embeddings=512, hidden_dim=64, size=idx_t_all.shape[1]).to(device)
pixelsnail_b = PixelSNAIL(num_embeddings=512, hidden_dim=64, size=idx_b_all.shape[1]).to(device)

opt_t = torch.optim.Adam(pixelsnail_t.parameters(), lr=2e-4)
opt_b = torch.optim.Adam(pixelsnail_b.parameters(), lr=2e-4)

# Pomocnicze dataloadery z kodami indeksów
from torch.utils.data import TensorDataset

loader_t = DataLoader(TensorDataset(idx_t_all), batch_size=64, shuffle=True)
loader_b = DataLoader(TensorDataset(idx_b_all), batch_size=64, shuffle=True)

# === Trening PixelSNAIL Top ===
for epoch in range(10):
    pixelsnail_t.train()
    for (x_t,) in tqdm(loader_t, desc=f"PixelSNAIL Top Epoch {epoch+1}/10"):
        x_t = x_t.to(device).unsqueeze(1).float()  # [B, 1, H, W]
        logits = pixelsnail_t(x_t)
        target = x_t.squeeze(1).long()
        loss = torch.nn.functional.cross_entropy(logits, target)

        opt_t.zero_grad()
        loss.backward()
        opt_t.step()

# === Trening PixelSNAIL Bottom ===
for epoch in range(10):
    pixelsnail_b.train()
    for (x_b,) in tqdm(loader_b, desc=f"PixelSNAIL Bottom Epoch {epoch+1}/10"):
        x_b = x_b.to(device).unsqueeze(1).float()
        logits = pixelsnail_b(x_b)
        target = x_b.squeeze(1).long()
        loss = torch.nn.functional.cross_entropy(logits, target)

        opt_b.zero_grad()
        loss.backward()
        opt_b.step()

# === Generowanie ===
print("Sampling latents...")
pixelsnail_t.eval()
pixelsnail_b.eval()

idx_t_sample = pixelsnail_t.sample(device, 16).squeeze(1).long()
idx_b_sample = pixelsnail_b.sample(device, 16).squeeze(1).long()

vqvae.eval()
with torch.no_grad():
    recon = vqvae.decode_indices(idx_t_sample, idx_b_sample)

os.makedirs("generated", exist_ok=True)
save_image(recon, "generated/generated_cats_vqvae2.png", nrow=4)
print("Saved: generated/generated_cats_vqvae2.png")


VQ-VAE2 Epoch 1/20: 100%|██████████| 467/467 [00:14<00:00, 32.97it/s]


Epoch 1: avg loss = 0.0000


VQ-VAE2 Epoch 2/20: 100%|██████████| 467/467 [00:13<00:00, 35.28it/s]


Epoch 2: avg loss = 0.0000


VQ-VAE2 Epoch 3/20: 100%|██████████| 467/467 [00:12<00:00, 36.70it/s]


Epoch 3: avg loss = 0.0000


VQ-VAE2 Epoch 4/20: 100%|██████████| 467/467 [00:12<00:00, 36.51it/s]


Epoch 4: avg loss = 0.0000


VQ-VAE2 Epoch 5/20: 100%|██████████| 467/467 [00:12<00:00, 36.56it/s]


Epoch 5: avg loss = 0.0000


VQ-VAE2 Epoch 6/20: 100%|██████████| 467/467 [00:12<00:00, 36.62it/s]


Epoch 6: avg loss = 0.0000


VQ-VAE2 Epoch 7/20: 100%|██████████| 467/467 [00:12<00:00, 36.77it/s]


Epoch 7: avg loss = 0.0000


VQ-VAE2 Epoch 8/20: 100%|██████████| 467/467 [00:12<00:00, 36.79it/s]


Epoch 8: avg loss = 0.0000


VQ-VAE2 Epoch 9/20: 100%|██████████| 467/467 [00:12<00:00, 36.84it/s]


Epoch 9: avg loss = 0.0000


VQ-VAE2 Epoch 10/20: 100%|██████████| 467/467 [00:12<00:00, 36.76it/s]


Epoch 10: avg loss = 0.0000


VQ-VAE2 Epoch 11/20: 100%|██████████| 467/467 [00:12<00:00, 36.56it/s]


Epoch 11: avg loss = 0.0000


VQ-VAE2 Epoch 12/20: 100%|██████████| 467/467 [00:12<00:00, 36.83it/s]


Epoch 12: avg loss = 0.0000


VQ-VAE2 Epoch 13/20: 100%|██████████| 467/467 [00:12<00:00, 36.96it/s]


Epoch 13: avg loss = 0.0000


VQ-VAE2 Epoch 14/20: 100%|██████████| 467/467 [00:12<00:00, 36.80it/s]


Epoch 14: avg loss = 0.0000


VQ-VAE2 Epoch 15/20: 100%|██████████| 467/467 [00:12<00:00, 36.79it/s]


Epoch 15: avg loss = 0.0000


VQ-VAE2 Epoch 16/20: 100%|██████████| 467/467 [00:12<00:00, 36.84it/s]


Epoch 16: avg loss = 0.0000


VQ-VAE2 Epoch 17/20: 100%|██████████| 467/467 [00:12<00:00, 36.81it/s]


Epoch 17: avg loss = 0.0000


VQ-VAE2 Epoch 18/20: 100%|██████████| 467/467 [00:12<00:00, 37.01it/s]


Epoch 18: avg loss = 0.0000


VQ-VAE2 Epoch 19/20: 100%|██████████| 467/467 [00:12<00:00, 36.79it/s]


Epoch 19: avg loss = 0.0000


VQ-VAE2 Epoch 20/20: 100%|██████████| 467/467 [00:12<00:00, 36.88it/s]


Epoch 20: avg loss = 0.0000


Encoding latents: 100%|██████████| 467/467 [00:10<00:00, 43.21it/s]
PixelSNAIL Top Epoch 1/10: 100%|██████████| 467/467 [00:01<00:00, 257.66it/s]
PixelSNAIL Top Epoch 2/10: 100%|██████████| 467/467 [00:01<00:00, 260.38it/s]
PixelSNAIL Top Epoch 3/10: 100%|██████████| 467/467 [00:01<00:00, 257.12it/s]
PixelSNAIL Top Epoch 4/10: 100%|██████████| 467/467 [00:01<00:00, 265.25it/s]
PixelSNAIL Top Epoch 5/10: 100%|██████████| 467/467 [00:01<00:00, 264.77it/s]
PixelSNAIL Top Epoch 6/10: 100%|██████████| 467/467 [00:01<00:00, 276.20it/s]
PixelSNAIL Top Epoch 7/10: 100%|██████████| 467/467 [00:01<00:00, 272.10it/s]
PixelSNAIL Top Epoch 8/10: 100%|██████████| 467/467 [00:01<00:00, 263.23it/s]
PixelSNAIL Top Epoch 9/10: 100%|██████████| 467/467 [00:01<00:00, 259.35it/s]
PixelSNAIL Top Epoch 10/10: 100%|██████████| 467/467 [00:01<00:00, 257.27it/s]
PixelSNAIL Bottom Epoch 1/10: 100%|██████████| 467/467 [00:01<00:00, 251.39it/s]
PixelSNAIL Bottom Epoch 2/10: 100%|██████████| 467/467 [00:01<00:00, 2

Sampling latents...
Saved: generated/generated_cats_vqvae2.png



