<a href="https://colab.research.google.com/github/Panperception/MML/blob/main/MML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Learning on Many Manifolds

## Initialize

In [None]:
!pip install tensorflow

## Demo Code 0

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

class SharedAutoencoder(nn.Module):
    """Convolutional autoencoder with one linear latent map per class.

    One encoder and one decoder are shared by every class. Between them sits
    a list of ``num_classes`` linear layers (``latent_dim -> latent_dim``);
    ``forward`` routes the encoded batch through the layer chosen by
    ``class_idx``.

    The encoder's final projection takes ``64*7*7`` features, which is what
    two stride-2 convolutions produce from a 1x28x28 input.
    """

    def __init__(self, latent_dim, num_classes):
        super().__init__()
        bottleneck = 64 * 7 * 7

        # Shared encoder: two stride-2 convolutions, then a linear projection.
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(bottleneck, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # One latent-space transform per class.
        per_class_maps = []
        for _ in range(num_classes):
            per_class_maps.append(nn.Linear(latent_dim, latent_dim))
        self.latent_spaces = nn.ModuleList(per_class_maps)

        # Shared decoder: mirrors the encoder and ends in a Sigmoid, so the
        # reconstruction lies in (0, 1).
        decoder_layers = [
            nn.Linear(latent_dim, bottleneck),
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, class_idx):
        """Reconstruct ``x`` through the latent map of a single class.

        Args:
            x: Image batch fed to the shared encoder.
            class_idx: Index into ``latent_spaces``; the same map is applied
                to every sample in ``x``.

        Returns:
            Tuple ``(reconstructed, z_class)``: the decoder output and the
            class-specific latent code it was decoded from.
        """
        class_map = self.latent_spaces[class_idx]
        z_class = class_map(self.encoder(x))
        return self.decoder(z_class), z_class

def train(model, dataloader, epochs=5, learning_rate=0.001):
    """Train ``model`` to reconstruct its inputs, routing samples by class.

    ``model(x, class_idx)`` takes a single integer class index per call, so
    each batch is split into one sub-batch per label present in it. This
    replaces one forward pass per image with at most one per class.

    Each sub-batch loss is weighted by its share of the batch. The weighted
    losses therefore sum to the mean reconstruction error over the whole
    batch, which is both the quantity optimised and the quantity logged.

    Args:
        model: Module whose call ``model(x, class_idx)`` returns
            ``(reconstruction, latent)``.
        dataloader: Iterable of ``(data, targets)`` batches supporting
            ``len()``; ``targets`` holds integer class labels.
        epochs: Number of passes over ``dataloader``.
        learning_rate: Adam learning rate.

    Returns:
        List with the average batch loss of each epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    # Run batches on whatever device the model already lives on.
    device = next(model.parameters()).device

    history = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for data, targets in dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            batch_size = data.size(0)
            batch_loss = 0.0
            for class_idx in targets.unique().tolist():
                members = data[targets == class_idx]
                reconstructed, _ = model(members, class_idx)
                # Weight by sub-batch share so gradients accumulate to the
                # gradient of the batch-mean loss.
                loss = criterion(reconstructed, members) * (members.size(0) / batch_size)
                loss.backward()
                batch_loss += loss.item()
            optimizer.step()
            total_loss += batch_loss

        epoch_loss = total_loss / len(dataloader)
        history.append(epoch_loss)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}')

    return history

# Data Preparation (Using MNIST for simplicity)
# Keep pixel values in [0, 1] (ToTensor only). The decoder of
# SharedAutoencoder ends in a Sigmoid, so targets normalised to [-1, 1]
# could never be reached by the reconstruction.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Model Initialization
latent_dim = 128
num_classes = 10  # 10 classes for MNIST
model = SharedAutoencoder(latent_dim, num_classes)

# Training the Model
train(model, train_loader, epochs=5)

# Visualization of Reconstruction
def visualize_reconstruction(model, dataloader, class_idx=0):
    """Plot originals and reconstructions for samples of one class.

    Training sends every image through the latent map of its own label, so
    only images labelled ``class_idx`` are shown here, decoded through that
    same map. Batches are read from ``dataloader`` until ten matching
    samples are found or the loader is exhausted; with fewer than ten, the
    remaining columns are left blank.

    Args:
        model: Module whose call ``model(x, class_idx)`` returns
            ``(reconstruction, latent)``.
        dataloader: Iterable of ``(data, targets)`` batches.
        class_idx: Label of the class to visualise.

    Raises:
        ValueError: If no sample in ``dataloader`` has label ``class_idx``.
    """
    model.eval()
    device = next(model.parameters()).device
    n_cols = 10

    # Gather up to n_cols samples of the requested class.
    picked = []
    found = 0
    for data, targets in dataloader:
        matches = data[targets == class_idx]
        if matches.size(0) > 0:
            picked.append(matches)
            found += matches.size(0)
        if found >= n_cols:
            break
    if not picked:
        raise ValueError(f"no samples with label {class_idx} in dataloader")
    originals = torch.cat(picked)[:n_cols]

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        reconstructed, _ = model(originals.to(device), class_idx)
    reconstructed = reconstructed.cpu()

    # Plot original and reconstructed images
    fig, axes = plt.subplots(2, n_cols, figsize=(15, 5))
    for ax in axes.ravel():
        ax.axis('off')
    for i in range(originals.size(0)):
        axes[0, i].imshow(originals[i].squeeze().numpy(), cmap='gray')
        axes[1, i].imshow(reconstructed[i].squeeze().numpy(), cmap='gray')

    plt.show()

# Visualize reconstruction for a specific class
visualize_reconstruction(model, train_loader, class_idx=0)  # For class 0 (e.g., '0' digit in MNIST)


## Demo Code 1

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Device setup: use the CUDA GPU when one is available, otherwise the CPU.
# The model and every batch below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------
# Model Components
# -----------------------

class SharedEncoder(nn.Module):
    """MLP trunk shared by all classes.

    Flattens the input to ``28 * 28`` features and maps it through two
    ReLU-activated linear layers (784 -> 512 -> 128).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 128-dimensional shared features for the batch ``x``."""
        features = self.shared(x)
        return features

class LatentBranch(nn.Module):
    """Single linear head projecting shared features to a latent code.

    Args:
        latent_dim: Size of the latent code produced by the head.
        in_features: Width of the incoming feature vector. Defaults to 128,
            the output width of ``SharedEncoder``, so existing callers are
            unaffected; pass another value to attach the head to a
            different trunk.
    """

    def __init__(self, latent_dim=16, in_features=128):
        super().__init__()
        self.fc = nn.Linear(in_features, latent_dim)

    def forward(self, x):
        """Return the latent code for the feature batch ``x``."""
        return self.fc(x)

class SharedDecoder(nn.Module):
    """MLP decoder shared by all classes.

    Expands a latent code through 128 and 512 hidden units to ``28 * 28``
    Sigmoid-activated outputs, then reshapes them to ``(-1, 1, 28, 28)``.
    """

    def __init__(self, latent_dim=16):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent codes ``z`` into a batch of 1x28x28 images."""
        flat_pixels = self.decoder(z)
        return flat_pixels.view(-1, 1, 28, 28)

class MultiLatentAutoencoder(nn.Module):
    """Autoencoder with a shared encoder/decoder and one latent head per class.

    Every latent head is evaluated for every sample; the head matching each
    sample's label is then selected before decoding.
    """

    def __init__(self, latent_dim=16, num_classes=10):
        super().__init__()
        self.encoder = SharedEncoder()
        heads = []
        for _ in range(num_classes):
            heads.append(LatentBranch(latent_dim))
        self.latent_branches = nn.ModuleList(heads)
        self.decoder = SharedDecoder(latent_dim)

    def forward(self, x, labels):
        """Reconstruct ``x`` using the latent head indexed by each label.

        Args:
            x: Input batch.
            labels: One class index per sample, used to pick that sample's
                latent head.

        Returns:
            The decoder output for the selected latent codes.
        """
        features = self.encoder(x)
        per_class = [head(features) for head in self.latent_branches]
        latents = torch.stack(per_class, dim=1)  # [B, C, latent_dim]
        rows = torch.arange(x.size(0))
        z = latents[rows, labels]  # one latent per sample, chosen by its label
        return self.decoder(z)

# -----------------------
# Dataset
# -----------------------

# ToTensor only: images become float tensors in [0, 1], the same range as the
# Sigmoid output of SharedDecoder.
transform = transforms.ToTensor()

# MNIST train/test splits, downloaded into ./data when missing.
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

# -----------------------
# Training Setup
# -----------------------

# Default model (latent_dim=16, num_classes=10), Adam optimiser and a
# mean-squared-error reconstruction loss.
model = MultiLatentAutoencoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# -----------------------
# Training Loop
# -----------------------

# Train for a fixed number of epochs, reconstructing each image through the
# latent head selected by its label.
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0  # sum of per-batch mean losses for this epoch
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, labels)
        # Reconstruction target is the input itself.
        loss = criterion(outputs, images)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Reported value is the mean of the per-batch losses.
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}")

# -----------------------
# Evaluation / Visualization
# -----------------------

# Reconstruct the first test batch, using the true labels to select the
# latent head of each sample.
model.eval()
images, labels = next(iter(test_loader))
images, labels = images.to(device), labels.to(device)

# No gradients are needed for inference.
with torch.no_grad():
    recon = model(images, labels)

# Show original and reconstructed
def show_images(original, reconstructed, n=10):
    """Draw the first ``n`` originals above their reconstructions.

    Args:
        original: Batch of input images (indexable tensors).
        reconstructed: Batch of reconstructions aligned with ``original``.
        n: Number of columns, i.e. image pairs, to display.
    """
    plt.figure(figsize=(15, 3))
    rows = (original, reconstructed)  # row 0: originals, row 1: reconstructions
    for col in range(n):
        for row, batch in enumerate(rows):
            # Subplot indices are 1-based and run row by row.
            plt.subplot(2, n, row * n + col + 1)
            plt.imshow(batch[col].cpu().squeeze(), cmap='gray')
            plt.axis('off')
    plt.suptitle("Top: Original | Bottom: Reconstructed", fontsize=14)
    plt.show()

show_images(images, recon)


## Demo Code 2


### Key Takeaways
* Each class's decoder models a distinct digit manifold.
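* `RandomAffine(degrees=30)` applies a random in-plane rotation to each training image (a bounded subset of the SO(2) group action; translation is not enabled in this demo), challenging the model to be invariant or equivariant.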
* Decoder selection requires label knowledge; removing this label makes inference NP-hard (you’d need to guess decoder + latent code jointly).
* This aligns with the theoretical proof sketch — combinatorial search across submanifolds under transformations.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# --- 1. Data Loading with Group Action (rotation) ---
# Training images are randomly rotated before conversion to tensors; test
# images are left untouched.
transform_train = transforms.Compose([
    transforms.RandomAffine(degrees=30),  # random rotation within +/-30 degrees; simulates group action SO(2)
    transforms.ToTensor()
])
transform_test = transforms.ToTensor()

# MNIST train/test splits, downloaded into ./data when missing.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

# Shuffle only the training batches.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False)

# --- 2. Shared Encoder (learns common latent structure) ---
class Encoder(nn.Module):
    """Convolutional encoder shared by all classes.

    Two stride-2 convolutions reduce a 1x28x28 image to 32x7x7 features,
    which a linear layer projects to ``latent_dim`` values.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        conv_layers = [
            nn.Conv2d(1, 16, 3, stride=2, padding=1),  # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 14 -> 7
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*conv_layers)
        self.fc = nn.Linear(32 * 7 * 7, latent_dim)

    def forward(self, x):
        """Return the latent code of each image in ``x``."""
        feature_maps = self.conv(x)
        # Flatten every sample's feature maps into a single vector.
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.fc(flat)

# --- 3. Class-Specific Decoder (models separate submanifolds) ---
class Decoder(nn.Module):
    """Convolutional decoder; the full model holds one instance per class.

    A linear layer expands the latent code to 32x7x7 features, which two
    stride-2 transposed convolutions upsample to a 1x28x28 image with
    Sigmoid-activated pixels.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 32 * 7 * 7)
        upsampling = [
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # 14 -> 28
            nn.Sigmoid(),
        ]
        self.deconv = nn.Sequential(*upsampling)

    def forward(self, z):
        """Decode latent codes ``z`` into a batch of images."""
        expanded = self.fc(z)
        feature_maps = expanded.view(-1, 32, 7, 7)
        return self.deconv(feature_maps)

# --- 4. Full Model with Shared Encoder + Per-Class Decoders ---
class ManifoldAutoencoder(nn.Module):
    """Shared encoder followed by one decoder per class.

    Each sample is encoded by the common encoder and reconstructed by the
    decoder that matches its label.
    """

    def __init__(self, latent_dim=32, num_classes=10):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoders = nn.ModuleList([Decoder(latent_dim) for _ in range(num_classes)])

    def forward(self, x, labels):
        """Reconstruct ``x`` with the decoder selected by each sample's label.

        Samples are grouped by label, so each decoder runs at most once per
        batch on all of its samples, rather than once per sample.

        Args:
            x: Input batch; the reconstruction has the same shape.
            labels: One class index per sample (tensor or sequence of ints).
                Samples whose label matches no decoder are left as zeros.

        Returns:
            Tensor shaped like ``x`` holding the reconstructions.
        """
        labels = torch.as_tensor(labels, device=x.device)
        z = self.encoder(x)
        recon = torch.zeros_like(x)
        for class_idx, decoder in enumerate(self.decoders):
            mask = labels == class_idx
            # Skip classes absent from this batch.
            if mask.any():
                recon[mask] = decoder(z[mask])
        return recon

# --- 5. Training the Model ---
# Use the CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ManifoldAutoencoder(latent_dim=32).to(device)
# Adam over all parameters (shared encoder and every per-class decoder),
# with a mean-squared-error reconstruction loss.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

def train(model, loader, epochs=5):
    """Run the reconstruction training loop and print the loss per epoch.

    Uses the module-level ``device``, ``optimizer`` and ``loss_fn``.

    Args:
        model: Module called as ``model(x, y)`` to reconstruct ``x``.
        loader: Iterable of ``(x, y)`` batches supporting ``len()``.
        epochs: Number of passes over ``loader``.
    """
    model.train()
    for epoch_index in range(epochs):
        running = 0  # sum of per-batch losses for this epoch
        for batch, targets in loader:
            batch = batch.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            # The reconstruction target is the (augmented) input itself.
            loss = loss_fn(model(batch, targets), batch)
            loss.backward()
            optimizer.step()
            running += loss.item()
        average = running / len(loader)
        print(f"Epoch {epoch_index + 1} | Avg Loss: {average:.4f}")

train(model, train_loader, epochs=5)

# --- 6. Visualize Reconstructions ---
def visualize_reconstruction(model, loader):
    """Plot ten images of the first batch above their reconstructions.

    Each image is reconstructed by the decoder chosen from its true label.
    Uses the module-level ``device``.

    Args:
        model: Module called as ``model(x, y)`` to reconstruct ``x``.
        loader: Iterable of ``(x, y)`` batches; only the first is used.
    """
    model.eval()
    batch, targets = next(iter(loader))
    batch, targets = batch.to(device), targets.to(device)
    with torch.no_grad():
        recon = model(batch, targets)
    # Move both batches to the CPU for plotting.
    batch, recon = batch.cpu(), recon.cpu()

    fig, axes = plt.subplots(2, 10, figsize=(12, 3))
    for col in range(10):
        for row, images in enumerate((batch, recon)):
            ax = axes[row, col]
            ax.imshow(images[col][0], cmap='gray')  # channel 0 of the image
            ax.axis('off')
    plt.suptitle("Top: Original | Bottom: Reconstructed via Class-Specific Decoder")
    plt.show()

visualize_reconstruction(model, test_loader)


## Demo Code 3

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Device setup: use the CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST data loader
# ToTensor only: images become float tensors in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# NOTE(review): no download=True here, so this relies on the files already
# being under ./data (presumably fetched by the train download above; verify).
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Shuffle only the training batches.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Shared Encoder
class Encoder(nn.Module):
    """MLP encoder shared by all classes.

    Flattens the input to ``28*28`` features and maps it through a
    256-unit ReLU layer to ``latent_dim`` values.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        hidden = 256
        layers = [
            nn.Flatten(),
            nn.Linear(28*28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, latent_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the latent code of each sample in ``x``."""
        latent = self.net(x)
        return latent

# Class-specific Decoder
class Decoder(nn.Module):
    """MLP decoder; the full model holds one instance per class.

    Maps a latent code through a 256-unit ReLU layer to ``28*28``
    Sigmoid-activated pixels, reshaped to ``(-1, 1, 28, 28)``.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 28*28),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent codes ``z`` into a batch of 1x28x28 images."""
        flat_pixels = self.net(z)
        return flat_pixels.view(-1, 1, 28, 28)

# Full Autoencoder
class MultiDecoderAutoencoder(nn.Module):
    """Shared encoder followed by one decoder per class."""

    def __init__(self, num_classes=10, latent_dim=32):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        per_class = []
        for _ in range(num_classes):
            per_class.append(Decoder(latent_dim))
        self.decoders = nn.ModuleList(per_class)

    def forward(self, x, label):
        """Reconstruct ``x`` with the decoder matching each sample's label.

        Args:
            x: Input batch; the output has the same shape.
            label: Tensor with one class index per sample. Samples whose
                label matches no decoder stay zero in the output.

        Returns:
            Tensor shaped like ``x`` holding the reconstructions.
        """
        latent = self.encoder(x)
        output = torch.zeros_like(x)
        for class_id, decoder in enumerate(self.decoders):
            selected = label == class_id
            if not selected.any():
                # No sample of this class in the batch.
                continue
            output[selected] = decoder(latent[selected])
        return output

# Default model (10 classes, latent_dim=32), Adam over all parameters and a
# mean-squared-error reconstruction loss.
model = MultiDecoderAutoencoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training
def train_epoch(model, loader, optimizer):
    """Run one training epoch and print the per-sample average loss.

    Uses the module-level ``device`` and ``criterion``.

    Args:
        model: Module called as ``model(x, y)`` to reconstruct ``x``.
        loader: DataLoader-like iterable of ``(x, y)`` batches exposing a
            ``dataset`` attribute with a length.
        optimizer: Optimiser updating the parameters of ``model``.
    """
    model.train()
    weighted_sum = 0  # batch losses weighted by batch size
    for batch, targets in loader:
        batch = batch.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(batch, targets), batch)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch is not over-counted.
        weighted_sum += loss.item() * batch.size(0)
    average = weighted_sum / len(loader.dataset)
    print(f"[Train Loss] {average:.4f}")

# Train the model for five epochs (numbered 1-5), printing the epoch number
# before each pass.
for epoch in range(1, 6):
    print(f"Epoch {epoch}")
    train_epoch(model, train_loader, optimizer)

# Unsupervised decoder inference
def infer_decoder_unsupervised(model, x_batch):
    """Predict labels by picking the decoder with the lowest reconstruction error.

    The batch is encoded once and decoded by every decoder in
    ``model.decoders``; each sample is assigned the index of the decoder
    whose reconstruction has the smallest mean squared error.

    The batch is moved to the device of the model's parameters rather than
    to a module-level ``device``, so the function depends only on its
    arguments. An empty batch yields an empty prediction tensor.

    Args:
        model: Module exposing ``encoder``, ``decoders`` and ``eval()``.
        x_batch: Batch of inputs with the samples along dimension 0.

    Returns:
        CPU tensor with one predicted decoder index per sample.
    """
    model.eval()
    # A model without parameters has no device preference; keep the batch as is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x_batch = x_batch.to(first_param.device)
    with torch.no_grad():
        z = model.encoder(x_batch)
        per_decoder = []
        for decoder in model.decoders:
            err = F.mse_loss(decoder(z), x_batch, reduction='none')
            # Mean error per sample over all non-batch dimensions.
            per_decoder.append(err.flatten(start_dim=1).mean(dim=1))
        all_errors = torch.stack(per_decoder, dim=1)  # [B, num_decoders]
        best_decoder = all_errors.argmin(dim=1)
    return best_decoder.cpu()

# Evaluate unsupervised accuracy
def evaluate_unsupervised(model, loader):
    """Print how often the lowest-error decoder matches the true label.

    Args:
        model: Model passed on to ``infer_decoder_unsupervised``.
        loader: Iterable of ``(x, y)`` batches; ``y`` holds the true labels.
    """
    hits, seen = 0, 0
    for batch, targets in loader:
        guesses = infer_decoder_unsupervised(model, batch)
        hits += (guesses == targets).sum().item()
        seen += batch.size(0)
    print(f"[Unsupervised Inference Accuracy] {hits}/{seen} = {hits/seen:.2%}")

evaluate_unsupervised(model, test_loader)
