In [15]:
from utils.data_loader import load_MNISTdata, get_colored_mnist_dataloader

In [1]:
import torch.nn as nn

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),  # 28x28 → 28x28
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 → 14x14

            nn.Conv2d(32, 64, 3, padding=1),  # 14x14 → 14x14
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 → 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        feats = self.features(x)
        logits = self.classifier(feats)
        return logits, feats.view(x.size(0), -1)  # logits, feature_vector


In [8]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def train_mnist_classifier(n_epochs=5, save_path="mnist_classifier.pth"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=64)

    model = MNISTClassifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    print("📚 Training MNIST classifier...")
    for epoch in range(n_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            out, _ = model(x)
            loss = criterion(out, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{n_epochs} ✅ Loss: {loss.item():.4f}")

    # Evaluate
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out, _ = model(x)
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    acc = correct / total
    print(f"✅ Test Accuracy: {acc*100:.2f}%")

    torch.save(model.state_dict(), save_path)
    print(f"💾 Saved classifier to {save_path}")
    return model


In [11]:
classifier = train_mnist_classifier(n_epochs=5)

📚 Training MNIST classifier...
Epoch 1/5 ✅ Loss: 0.0033
Epoch 2/5 ✅ Loss: 0.1449
Epoch 3/5 ✅ Loss: 0.0001
Epoch 4/5 ✅ Loss: 0.0030
Epoch 5/5 ✅ Loss: 0.0004
✅ Test Accuracy: 99.19%
💾 Saved classifier to mnist_classifier.pth


In [13]:
import torch.nn as nn

class ColoredMNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),  # 28x28 → 28x28
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 → 14x14

            nn.Conv2d(32, 64, 3, padding=1),  # 14x14 → 14x14
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 → 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        feats = self.features(x)
        logits = self.classifier(feats)
        return logits, feats.view(x.size(0), -1)  # logits, feature_vector


In [1]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

def train_colored_mnist_classifier(train_loader, n_epochs=5, save_path="colored_mnist_classifier.pth"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    
    model = ColoredMNISTClassifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    print("Training Colored MNIST Classifier...")
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            out, _ = model(x)
            loss = criterion(out, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{n_epochs} Loss: {avg_loss:.4f}")

    # Optional: Save model
    torch.save(model.state_dict(), save_path)
    print(f"Saved classifier to {save_path}")
    return model


In [17]:
train_loader, _ = load_MNISTdata()  # Load MNIST using your function
colored_mnist_loader = get_colored_mnist_dataloader(train_loader.dataset, batch_size=128, minority_ratio=0.5)


In [None]:
classifier = train_colored_mnist_classifier(colored_mnist_loader, n_epochs=10)


Training Colored MNIST Classifier...
Epoch 1/10 Loss: 0.3019
Epoch 2/10 Loss: 0.0760
Epoch 3/10 Loss: 0.0529
Epoch 4/10 Loss: 0.0404
Epoch 5/10 Loss: 0.0336
Epoch 6/10 Loss: 0.0257
