In [1]:
import numpy as np
import torch
from torchvision.datasets import MNIST
from torchvision.transforms.functional import rotate
from torch.utils.data import Dataset, DataLoader



In [11]:
# ======================
# Corrected Dataset Class
# ======================

class RotatedMNIST(Dataset):
    def __init__(self, root, train=True, digits=(1, 2)):
        self.mnist = MNIST(root=root, train=train, download=True,
                          transform=ToTensor())  # Convert to tensor upfront
        self.indices = [i for i, (_, label) in enumerate(self.mnist) 
                      if label in digits]
        self.angles = list(range(0, 360, 30))

    def __len__(self):
        return len(self.indices) * len(self.angles)

    def __getitem__(self, idx):
        img_idx = self.indices[idx // len(self.angles)]
        angle = self.angles[idx % len(self.angles)]
        img, label = self.mnist[img_idx]
        
        # Rotate tensor directly (no PIL conversion)
        img_rotated = rotate(img, angle)
        return img_rotated, label, torch.tensor(angle/360.0, dtype=torch.float32)

In [17]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),  # Converts to [0, 1] range automatically
])


train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)


In [13]:
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [14]:
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class VAE(nn.Module):
    def __init__(self, latent_dim=32):
        super(VAE, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*7*7, 256),
            nn.ReLU(),
        )
        
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_var = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64*7*7),
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Initialize model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ======================
# Training Loop
# ======================
print("Starting VAE Training...")
for epoch in range(20):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)  # Already has channel dimension [B, 1, 28, 28]
        optimizer.zero_grad()
        
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
    print(f'Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset):.4f}')

Starting VAE Training...
Epoch 1, Loss: 154.2721
Epoch 2, Loss: 112.8621
Epoch 3, Loss: 107.0420
Epoch 4, Loss: 104.5923
Epoch 5, Loss: 103.0775
Epoch 6, Loss: 101.9875
Epoch 7, Loss: 101.1702
Epoch 8, Loss: 100.6196
Epoch 9, Loss: 100.1122
Epoch 10, Loss: 99.7227
Epoch 11, Loss: 99.4339
Epoch 12, Loss: 99.0939
Epoch 13, Loss: 98.8460
Epoch 14, Loss: 98.6016
Epoch 15, Loss: 98.3999
Epoch 16, Loss: 98.1911
Epoch 17, Loss: 98.0220
Epoch 18, Loss: 97.8141
Epoch 19, Loss: 97.7028
Epoch 20, Loss: 97.5630


In [19]:
torch.save(model.state_dict(), "vae_mnist.pth")


In [20]:
model.load_state_dict(torch.load("vae_mnist.pth"))
model.train()  # Ensure it's in training mode


VAE(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): Linear(in_features=3136, out_features=256, bias=True)
    (6): ReLU()
  )
  (fc_mu): Linear(in_features=256, out_features=32, bias=True)
  (fc_var): Linear(in_features=256, out_features=32, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=3136, bias=True)
    (3): ReLU()
    (4): Unflatten(dim=1, unflattened_size=(64, 7, 7))
    (5): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (6): ReLU()
    (7): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (8): Sigmoid()
  )
)

In [21]:
for epoch in range(21, 31):  # Continue from epoch 21 to 30
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    
    print(f'Epoch {epoch}, Loss: {train_loss/len(train_loader.dataset):.4f}')


Epoch 21, Loss: 97.4159
Epoch 22, Loss: 97.3002
Epoch 23, Loss: 97.1484
Epoch 24, Loss: 97.0593
Epoch 25, Loss: 96.8989
Epoch 26, Loss: 96.7997
Epoch 27, Loss: 96.7098
Epoch 28, Loss: 96.6413
Epoch 29, Loss: 96.5300
Epoch 30, Loss: 96.4831


In [37]:
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, latent_dim=32):  
        super(VAE, self).__init__()

        # **Encoder (CNN)**
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # (B, 1, 28, 28) -> (B, 32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # (B, 32, 14, 14) -> (B, 64, 7, 7)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # (B, 64, 7, 7) -> (B, 128, 4, 4)
            nn.ReLU(),
            nn.Flatten()  # Flatten for FC layers
        )

        self.fc_mu = nn.Linear(128 * 4 * 4, latent_dim)  
        self.fc_var = nn.Linear(128 * 4 * 4, latent_dim)

        # **Decoder**
        self.decoder_input = nn.Linear(latent_dim, 128 * 4 * 4)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        z = self.decoder_input(z).view(-1, 128, 4, 4)  
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [38]:
vae = VAE(latent_dim=32).to(device)
vae.load_state_dict(torch.load("vae_mnist.pth", map_location=device))  # Load checkpoint
vae.eval()  # Set to evaluation mode


RuntimeError: Error(s) in loading state_dict for VAE:
	Missing key(s) in state_dict: "encoder.4.weight", "encoder.4.bias", "decoder_input.weight", "decoder_input.bias", "decoder.4.weight", "decoder.4.bias". 
	Unexpected key(s) in state_dict: "encoder.5.weight", "encoder.5.bias", "decoder.7.weight", "decoder.7.bias", "decoder.5.weight", "decoder.5.bias". 
	size mismatch for fc_mu.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([32, 2048]).
	size mismatch for fc_var.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([32, 2048]).
	size mismatch for decoder.0.weight: copying a param with shape torch.Size([256, 32]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
	size mismatch for decoder.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder.2.weight: copying a param with shape torch.Size([3136, 256]) from checkpoint, the shape in current model is torch.Size([64, 32, 3, 3]).
	size mismatch for decoder.2.bias: copying a param with shape torch.Size([3136]) from checkpoint, the shape in current model is torch.Size([32]).

In [35]:
checkpoint = torch.load("vae_mnist.pth", map_location=device)
print(checkpoint.keys())  # Print layer names


odict_keys(['encoder.0.weight', 'encoder.0.bias', 'encoder.2.weight', 'encoder.2.bias', 'encoder.5.weight', 'encoder.5.bias', 'fc_mu.weight', 'fc_mu.bias', 'fc_var.weight', 'fc_var.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.2.weight', 'decoder.2.bias', 'decoder.5.weight', 'decoder.5.bias', 'decoder.7.weight', 'decoder.7.bias'])


In [36]:
print(checkpoint["encoder.2.weight"].shape)  # Check the shape of the Conv layers


torch.Size([64, 32, 3, 3])


In [39]:
vae.load_state_dict(torch.load("vae_mnist.pth", map_location=device), strict=False)


RuntimeError: Error(s) in loading state_dict for VAE:
	size mismatch for fc_mu.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([32, 2048]).
	size mismatch for fc_var.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([32, 2048]).
	size mismatch for decoder.0.weight: copying a param with shape torch.Size([256, 32]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
	size mismatch for decoder.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder.2.weight: copying a param with shape torch.Size([3136, 256]) from checkpoint, the shape in current model is torch.Size([64, 32, 3, 3]).
	size mismatch for decoder.2.bias: copying a param with shape torch.Size([3136]) from checkpoint, the shape in current model is torch.Size([32]).

In [40]:
checkpoint = torch.load("vae_mnist.pth", map_location=device)
for key, value in checkpoint.items():
    print(key, value.shape)


encoder.0.weight torch.Size([32, 1, 3, 3])
encoder.0.bias torch.Size([32])
encoder.2.weight torch.Size([64, 32, 3, 3])
encoder.2.bias torch.Size([64])
encoder.5.weight torch.Size([256, 3136])
encoder.5.bias torch.Size([256])
fc_mu.weight torch.Size([32, 256])
fc_mu.bias torch.Size([32])
fc_var.weight torch.Size([32, 256])
fc_var.bias torch.Size([32])
decoder.0.weight torch.Size([256, 32])
decoder.0.bias torch.Size([256])
decoder.2.weight torch.Size([3136, 256])
decoder.2.bias torch.Size([3136])
decoder.5.weight torch.Size([64, 32, 3, 3])
decoder.5.bias torch.Size([32])
decoder.7.weight torch.Size([32, 1, 3, 3])
decoder.7.bias torch.Size([1])
