In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
from PIL import ImageFilter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== 1. Multi-modal MNIST dataset ======
class MultiModalMNIST(torch.utils.data.Dataset):
    def __init__(self, train=True):
        self.dataset = datasets.MNIST(root="./data", train=train, download=True)
        
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        
        # Modality 1: original
        m1 = TF.to_tensor(img)
        # Modality 2: edge-detected
        m2 = TF.to_tensor(img.filter(ImageFilter.FIND_EDGES))
        # Modality 3: inverted
        m3 = TF.to_tensor(TF.invert(img))
        
        return (m1, m2, m3), label
    
    def __len__(self):
        return len(self.dataset)

train_loader = torch.utils.data.DataLoader(
    MultiModalMNIST(train=True),
    batch_size=64, shuffle=True
)

# ====== 2. Tiny CNN encoders per modality ======
class TinyEncoder(nn.Module):
    def __init__(self, out_dim=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(16*7*7, out_dim)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# 3 nodes (modalities)
num_modalities = 3
embed_dim = 128
encoders = nn.ModuleList([TinyEncoder(embed_dim).to(device) for _ in range(num_modalities)])



In [2]:
# ====== 3. Restriction maps P_{i->e} ======
edges = [(0,1),(0,2),(1,2)]
edge_dim = 64  # shared comparison space
restrictions = nn.ParameterDict()
for (i,j) in edges:
    restrictions[f"{i}->{i}-{j}"] = nn.Parameter(torch.randn(edge_dim, embed_dim) * 0.1)
    restrictions[f"{j}->{i}-{j}"] = nn.Parameter(torch.randn(edge_dim, embed_dim) * 0.1)

# ====== 4. Loss functions ======
def cosine_sim(a, b):
    return F.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0), dim=-1)

def contrastive_loss(p_i, p_j, tau=0.1):
    # p_i, p_j: [B, D]
    sim_ij = cosine_sim(p_i, p_j) / tau
    labels = torch.arange(p_i.size(0)).to(device)
    # Symmetric InfoNCE
    loss_i = F.cross_entropy(sim_ij, labels)
    loss_j = F.cross_entropy(sim_ij.t(), labels)
    return 0.5 * (loss_i + loss_j)

def laplacian_loss(p_i, p_j):
    return ((p_i - p_j)**2).sum(dim=1).mean()

# ====== 5. Training loop ======
optimizer = torch.optim.Adam(list(encoders.parameters()) + list(restrictions.parameters()), lr=1e-4)
lambda_lap = 1.0
beta_contrast = 1.0



In [3]:
for epoch in range(10):  # small demo
    for (mods, labels) in train_loader:
        mods = [m.to(device) for m in mods]
        batch_size = mods[0].size(0)

        # Local embeddings h_i
        h = [enc(mods[i]) for i, enc in enumerate(encoders)]  # list of [B, embed_dim]

        total_loss = 0.0

        # Loop over edges for sheaf contrastive + Laplacian loss
        for (i,j) in edges:
            P_i = restrictions[f"{i}->{i}-{j}"]
            P_j = restrictions[f"{j}->{i}-{j}"]

            p_i = (h[i] @ P_i.t())   # [B, edge_dim]
            p_j = (h[j] @ P_j.t())   # [B, edge_dim]

            lap_loss = laplacian_loss(p_i, p_j)
            contrast_loss = contrastive_loss(p_i, p_j)

            total_loss += lambda_lap * lap_loss + beta_contrast * contrast_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}")


Epoch 1, Loss: 0.2781
Epoch 2, Loss: 0.1960
Epoch 3, Loss: 0.1265
Epoch 4, Loss: 0.1224
Epoch 5, Loss: 0.1026
Epoch 6, Loss: 0.0593
Epoch 7, Loss: 0.0816
Epoch 8, Loss: 0.0674
Epoch 9, Loss: 0.0497
Epoch 10, Loss: 0.0517


In [4]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
from PIL import ImageFilter

# ====== 0. Load your trained components ======
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# These should be the trained encoders and restriction maps from your training script
#from train_multimodal_mnist import encoders, restrictions, edges  # <-- import your trained model parts
embed_dim = 32
edge_dim = 16

# ====== 1. Multi-modal MNIST test dataset ======
class MultiModalMNIST(torch.utils.data.Dataset):
    def __init__(self, train=False):
        self.dataset = datasets.MNIST(root="./data", train=train, download=True)
        
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        m1 = TF.to_tensor(img)
        m2 = TF.to_tensor(img.filter(ImageFilter.FIND_EDGES))
        m3 = TF.to_tensor(TF.invert(img))
        return (m1, m2, m3), label
    
    def __len__(self):
        return len(self.dataset)

test_loader = DataLoader(MultiModalMNIST(train=False), batch_size=256, shuffle=False)

# Helper to normalize embeddings
def encode_modality(modality_idx, loader):
    enc = encoders[modality_idx].eval()
    embs, labels = [], []
    with torch.no_grad():
        for mods, y in loader:
            x = mods[modality_idx].to(device)
            h = enc(x)
            h = F.normalize(h, dim=-1)
            embs.append(h)
            labels.append(y)
    return torch.cat(embs), torch.cat(labels)





[Zero-Shot Eval] Using modality 0 embeddings
Zero-Shot Accuracy (modality 0): 81.67%

[Cross-Modal Retrieval] Query: 0 -> Gallery: 1
Recall@1: 11.37%
Recall@5: 28.67%
Recall@10: 39.94%

[Cross-Modal Retrieval] Query: 1 -> Gallery: 0
Recall@1: 11.01%
Recall@5: 28.93%
Recall@10: 40.15%


In [5]:
# ====== 2. Zero-Shot Learning Evaluation ======
def zero_shot_eval(mod_idx=0):
    """Use the embedding space of modality `mod_idx` for classification via nearest centroid."""
    print(f"\n[Zero-Shot Eval] Using modality {mod_idx} embeddings")
    # Encode all test data
    test_embs, test_labels = encode_modality(mod_idx, test_loader)

    # Compute class centroids
    num_classes = 10
    centroids = []
    for c in range(num_classes):
        class_embs = test_embs[test_labels == c]
        centroids.append(class_embs.mean(dim=0))
    centroids = F.normalize(torch.stack(centroids), dim=-1)  # [10, D]

    # Classify by cosine similarity
    sims = test_embs @ centroids.t()  # [N, 10]
    preds = sims.argmax(dim=1)
    acc = (preds == test_labels.to(device)).float().mean().item()
    print(f"Zero-Shot Accuracy (modality {mod_idx}): {acc*100:.2f}%")
    return acc

# ====== 3. Cross-Modal Retrieval ======
def cross_modal_retrieval(query_mod=0, gallery_mod=1, top_k=(1,5,10)):
    print(f"\n[Cross-Modal Retrieval] Query: {query_mod} -> Gallery: {gallery_mod}")
    q_embs, q_labels = encode_modality(query_mod, test_loader)
    g_embs, g_labels = encode_modality(gallery_mod, test_loader)

    # Compute cosine similarity
    sim = q_embs @ g_embs.t()
    ranks = sim.argsort(dim=1, descending=True)

    recalls = {k: 0 for k in top_k}
    for i, label in enumerate(q_labels):
        retrieved_labels = g_labels[ranks[i]]
        for k in top_k:
            if label in retrieved_labels[:k]:d
                recalls[k] += 1

    for k in top_k:
        recalls[k] /= len(q_labels)
        print(f"Recall@{k}: {recalls[k]*100:.2f}%")
    return recalls

# ====== 4. Run evaluations ======
if __name__ == "__main__":
    # Example: evaluate ZSL for modality 0
    zero_shot_eval(mod_idx=0)

    # Example: cross-modal retrieval (0 -> 1 and 1 -> 0)
    cross_modal_retrieval(query_mod=0, gallery_mod=2)
    cross_modal_retrieval(query_mod=2, gallery_mod=0)


[Zero-Shot Eval] Using modality 0 embeddings
Zero-Shot Accuracy (modality 0): 81.67%

[Cross-Modal Retrieval] Query: 0 -> Gallery: 2
Recall@1: 9.92%
Recall@5: 24.89%
Recall@10: 34.46%

[Cross-Modal Retrieval] Query: 2 -> Gallery: 0
Recall@1: 9.61%
Recall@5: 25.87%
Recall@10: 36.45%
