In [None]:
# Kaggle Notebook: Mizan-based Embedding Training on MNIST

!pip install mizanvector -q  # once you publish to PyPI; for now you can copy losses.py into the notebook

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# If mizanvector is not installed yet, paste MizanContrastiveLoss here:
import torch.nn.functional as F

def mizan_similarity_torch(v1, v2, p: float = 2.0, eps: float = 1e-8):
    """Return the Mizan similarity of ``v1`` and ``v2`` along the last axis.

    The score is ``1 - ||v1 - v2||**p / (||v1||**p + ||v2||**p + eps)``,
    where every norm is the Euclidean (L2) norm taken over ``dim=-1``.

    Args:
        v1: First tensor; its last dimension is reduced.
        v2: Second tensor, combined with ``v1`` by elementwise subtraction.
        p: Exponent applied to each norm. It does not select the norm
            order, which is always 2.
        eps: Constant added to the denominator so it is not zero when both
            inputs have zero norm.

    Returns:
        Tensor of similarity scores with the last dimension reduced away.
    """
    def _l2_pow(t):
        # L2 norm over the trailing axis, raised to the caller's exponent.
        return torch.norm(t, p=2, dim=-1) ** p

    gap = _l2_pow(v1 - v2)
    scale = _l2_pow(v1) + _l2_pow(v2) + eps
    return 1.0 - gap / scale

class MizanContrastiveLoss(nn.Module):
    """Contrastive loss over embedding pairs, scored by Mizan similarity.

    With ``sim`` the Mizan similarity of a pair and ``label`` its target,
    the per-pair loss is
    ``label * (1 - sim) + (1 - label) * relu(sim - margin)``;
    ``forward`` returns the mean over all pairs.

    Args:
        margin: Similarity level above which a ``label == 0`` pair is
            penalised.
        p: Exponent forwarded to ``mizan_similarity_torch``.
        eps: Denominator stabiliser forwarded to ``mizan_similarity_torch``.
    """

    def __init__(self, margin: float = 0.5, p: float = 2.0, eps: float = 1e-8):
        super().__init__()
        self.margin, self.p, self.eps = margin, p, eps

    def forward(self, emb1, emb2, labels):
        """Return the mean contrastive loss for the given embedding pairs."""
        targets = labels.float()
        similarity = mizan_similarity_torch(emb1, emb2, p=self.p, eps=self.eps)

        # Pairs with target 1 are pulled towards similarity 1.
        attract = targets * (1.0 - similarity)
        # Pairs with target 0 only contribute once similarity exceeds the margin.
        repel = (1.0 - targets) * F.relu(similarity - self.margin)

        return (attract + repel).mean()


# --------------------------
# Dataset: MNIST pairs
# --------------------------

class MNISTPairs(Dataset):
    """MNIST wrapped as a dataset of random image pairs.

    Item ``idx`` pairs the sample at ``idx`` with a second sample whose index
    is drawn uniformly at random from the whole dataset (so it may be the
    same sample). Both images are flattened to 1-D, and the target is ``1.0``
    when the two labels are equal and ``0.0`` otherwise.

    Args:
        train: Passed to ``datasets.MNIST`` to select the train or test split.
    """

    def __init__(self, train=True):
        to_tensor = transforms.Compose([transforms.ToTensor()])
        self.data = datasets.MNIST(
            root="./data",
            train=train,
            download=True,
            transform=to_tensor,
        )

    def __len__(self):
        """Return the number of underlying samples (one pair per sample)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(flat_image_1, flat_image_2, target)`` for sample ``idx``."""
        anchor_img, anchor_digit = self.data[idx]

        # A fresh partner is sampled on every access, so pairs change across epochs.
        partner_idx = torch.randint(low=0, high=len(self.data), size=(1,)).item()
        partner_img, partner_digit = self.data[partner_idx]

        target = torch.tensor(
            1 if anchor_digit == partner_digit else 0,
            dtype=torch.float,
        )
        return anchor_img.view(-1), partner_img.view(-1), target


# --------------------------
# Tiny Encoder Model
# --------------------------

class TinyEncoder(nn.Module):
    """Two-layer MLP that maps input vectors to a small embedding.

    The layer sizes are configurable; the defaults reproduce the original
    784 -> 128 -> 32 network for flattened 28x28 images.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        embed_dim: Size of the produced embedding.
    """

    def __init__(
        self,
        in_dim: int = 28 * 28,
        hidden_dim: int = 128,
        embed_dim: int = 32,
    ):
        super().__init__()
        # Attribute name and layer positions are kept so state-dict keys
        # ("fc.0.*", "fc.2.*") stay the same.
        self.fc = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        """Return the embedding of ``x``; its last dimension must be ``in_dim``."""
        return self.fc(x)


# --------------------------
# Training Setup
# --------------------------

# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Encoder and loss module are moved to the chosen device; Adam optimises the
# encoder's parameters.
model = TinyEncoder().to(device=device)
criterion = MizanContrastiveLoss(margin=0.4).to(device=device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Training split, served in shuffled mini-batches of 64 pairs.
dataset = MNISTPairs(train=True)
loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

# --------------------------
# Training Loop
# --------------------------

for epoch in range(3):
    model.train()
    running_loss = 0.0

    for batch in loader:
        # Each batch is (first images, second images, pair labels); move all
        # three to the training device.
        left, right, target = (t.to(device) for t in batch)

        # Clearing gradients before the forward pass has the same effect as
        # clearing them right before backward().
        optimizer.zero_grad()
        loss = criterion(model(left), model(right), target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses seen in this epoch.
    avg_loss = running_loss / len(loader)
    print(f"Epoch {epoch+1}, Loss={avg_loss:.4f}")

print("Training complete!")
