In [1]:
import numpy as np
import torch
from torchvision.datasets import MNIST
from torchvision.transforms.functional import rotate
from torch.utils.data import Dataset, DataLoader



In [2]:
class RotatedMNIST(Dataset):
    def __init__(self, root, train=True, digits=(1, 2)):
        self.mnist = MNIST(root=root, train=train, download=True)
        self.indices = [i for i, (img, label) in enumerate(self.mnist) if label in digits]
        self.angles = list(range(0, 360, 30))  # [0, 30, 60, ..., 330]

    def __len__(self):
        return len(self.indices) * len(self.angles)

    def __getitem__(self, idx):
        img_idx = self.indices[idx // len(self.angles)]
        angle = self.angles[idx % len(self.angles)]
        img, label = self.mnist[img_idx]
        img = rotate(img, angle)
        return img, label, angle  # (image, digit label, rotation angle)

In [3]:
train_dataset = RotatedMNIST(root='./data', train=True)
test_dataset = RotatedMNIST(root='./data', train=False)

100%|██████████| 9.91M/9.91M [00:06<00:00, 1.50MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 83.5kB/s]
100%|██████████| 1.65M/1.65M [00:02<00:00, 759kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.30MB/s]


In [4]:
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [5]:
import torch.nn as nn
import torch.nn.functional as F

In [9]:
class VAE(nn.Module):
    def __init__(self, latent_dim=32):
        super(VAE, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*7*7, 256),
            nn.ReLU(),
        )
        
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_var = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64*7*7),
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Initialize model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ======================
# Training Loop
# ======================

print("Starting VAE Training...")
for epoch in range(20):
    model.train()
    train_loss = 0
    for batch_idx, (data, _, _) in enumerate(train_loader):
        data = data.unsqueeze(1).to(device)  # Add channel dimension
        optimizer.zero_grad()
        
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
    print(f'Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset):.4f}')

Starting VAE Training...


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>