In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
from PIL import ImageFilter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== 1. Multi-modal MNIST dataset ======
class MultiModalMNIST(torch.utils.data.Dataset):
    def __init__(self, train=True):
        self.dataset = datasets.MNIST(root="./data", train=train, download=True)
        
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        
        # Modality 1: original
        m1 = TF.to_tensor(img)
        # Modality 2: edge-detected
        m2 = TF.to_tensor(img.filter(ImageFilter.FIND_EDGES))
        # Modality 3: inverted
        m3 = TF.to_tensor(TF.invert(img))
        
        return (m1, m2, m3), label
    
    def __len__(self):
        return len(self.dataset)

train_loader = torch.utils.data.DataLoader(
    MultiModalMNIST(train=True),
    batch_size=64, shuffle=True
)

# ====== 2. Tiny CNN encoders per modality ======
class TinyEncoder(nn.Module):
    def __init__(self, out_dim=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(16*7*7, out_dim)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# 3 nodes (modalities)
num_modalities = 3
embed_dim = 128
encoders = nn.ModuleList([TinyEncoder(embed_dim).to(device) for _ in range(num_modalities)])


# ====== 3. Restriction maps P_{i->e} and Dual maps Q_{e->i} ======
edges = [(0,1),(0,2),(1,2)]
edge_dim = 64  # shared comparison space

restrictions = nn.ParameterDict()  # P maps
duals = nn.ParameterDict()         # Q maps

for (i,j) in edges:
    # P: node -> edge
    restrictions[f"{i}->{i}-{j}"] = nn.Parameter(torch.randn(edge_dim, embed_dim) * 0.1)
    restrictions[f"{j}->{i}-{j}"] = nn.Parameter(torch.randn(edge_dim, embed_dim) * 0.1)

    # Q: edge -> node
    duals[f"{i}-{j}->{i}"] = nn.Parameter(torch.randn(embed_dim, edge_dim) * 0.1)
    duals[f"{i}-{j}->{j}"] = nn.Parameter(torch.randn(embed_dim, edge_dim) * 0.1)


# ====== 4. Loss functions ======
def cosine_sim(a, b):
    return F.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0), dim=-1)

def contrastive_loss(p_i, p_j, tau=0.1):
    # p_i, p_j: [B, D]
    sim_ij = cosine_sim(p_i, p_j) / tau
    labels = torch.arange(p_i.size(0)).to(device)
    # Symmetric InfoNCE
    loss_i = F.cross_entropy(sim_ij, labels)
    loss_j = F.cross_entropy(sim_ij.t(), labels)
    return 0.5 * (loss_i + loss_j)

def laplacian_loss(p_i, p_j):
    return ((p_i - p_j)**2).sum(dim=1).mean()

def reconstruction_loss(h_i, p_i, Q_ei):
    # Reconstruct node embedding from edge embedding
    # p_i: node->edge embedding [B, edge_dim]
    # Q_ei: edge->node map [embed_dim, edge_dim]
    recon = p_i @ Q_ei.t()  # [B, embed_dim]
    return F.mse_loss(recon, h_i)

# ====== 5. Training loop ======
optimizer = torch.optim.Adam(
    list(encoders.parameters()) + list(restrictions.parameters()) + list(duals.parameters()),
    lr=1e-4
)

lambda_lap = 1.0
beta_contrast = 10.0
gamma_recon = 0.1  # weight for reconstruction loss

for epoch in range(50):  # small demo
    for (mods, labels) in train_loader:
        mods = [m.to(device) for m in mods]
        batch_size = mods[0].size(0)

        # Local embeddings h_i
        h = [enc(mods[i]) for i, enc in enumerate(encoders)]  # list of [B, embed_dim]

        total_loss = 0.0

        # Loop over edges for sheaf contrastive + Laplacian + Reconstruction
        for (i,j) in edges:
            P_i = restrictions[f"{i}->{i}-{j}"]
            P_j = restrictions[f"{j}->{i}-{j}"]
            Q_ei = duals[f"{i}-{j}->{i}"]
            Q_ej = duals[f"{i}-{j}->{j}"]

            p_i = (h[i] @ P_i.t())   # [B, edge_dim]
            p_j = (h[j] @ P_j.t())   # [B, edge_dim]

            lap_loss = laplacian_loss(p_i, p_j)
            contrast_loss = contrastive_loss(p_i, p_j)

            # Reconstruction from edge back to nodes
            recon_loss_i = reconstruction_loss(h[i], p_i, Q_ei)
            recon_loss_j = reconstruction_loss(h[j], p_j, Q_ej)

            
            total_loss += (
                lambda_lap * lap_loss +
                beta_contrast * contrast_loss +
                gamma_recon * (recon_loss_i + recon_loss_j)
            )

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}")


Epoch 1, Loss: 2.7384
Epoch 2, Loss: 1.1660
Epoch 3, Loss: 1.2381
Epoch 4, Loss: 0.8045
Epoch 5, Loss: 0.9656
Epoch 6, Loss: 0.9429
Epoch 7, Loss: 0.4589
Epoch 8, Loss: 0.6917
Epoch 9, Loss: 0.5351
Epoch 10, Loss: 0.5625
Epoch 11, Loss: 0.4575
Epoch 12, Loss: 0.6223
Epoch 13, Loss: 0.4560
Epoch 14, Loss: 0.5552
Epoch 15, Loss: 0.3097
Epoch 16, Loss: 0.2188
Epoch 17, Loss: 0.2892
Epoch 18, Loss: 0.3468
Epoch 19, Loss: 0.2762
Epoch 20, Loss: 0.2042
Epoch 21, Loss: 0.3284
Epoch 22, Loss: 0.5363
Epoch 23, Loss: 0.2427
Epoch 24, Loss: 0.2566
Epoch 25, Loss: 0.2504
Epoch 26, Loss: 0.2343
Epoch 27, Loss: 0.2034
Epoch 28, Loss: 0.1608
Epoch 29, Loss: 0.1951
Epoch 30, Loss: 0.2055
Epoch 31, Loss: 0.2039
Epoch 32, Loss: 0.1841
Epoch 33, Loss: 0.2830
Epoch 34, Loss: 0.1670
Epoch 35, Loss: 0.1931
Epoch 36, Loss: 0.1272
Epoch 37, Loss: 0.5061
Epoch 38, Loss: 0.1659
Epoch 39, Loss: 0.1814
Epoch 40, Loss: 0.2031
Epoch 41, Loss: 0.2190
Epoch 42, Loss: 0.2357
Epoch 43, Loss: 0.1466
Epoch 44, Loss: 0.13

In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
from PIL import ImageFilter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== 0. Load your trained components ======
# encoders, restrictions, duals, edges should be loaded from your training
# Example:
# from train_multimodal_mnist import encoders, restrictions, duals, edges
embed_dim = 128
edge_dim = 64

# ====== 1. Multi-modal MNIST test dataset ======
class MultiModalMNIST(torch.utils.data.Dataset):
    def __init__(self, train=False):
        self.dataset = datasets.MNIST(root="./data", train=train, download=True)
        
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        m1 = TF.to_tensor(img)
        m2 = TF.to_tensor(img.filter(ImageFilter.FIND_EDGES))
        m3 = TF.to_tensor(TF.invert(img))
        return (m1, m2, m3), label
    
    def __len__(self):
        return len(self.dataset)

test_loader = DataLoader(MultiModalMNIST(train=False), batch_size=256, shuffle=False)

# ====== 2. Encode + Optionally Reconstruct via Dual Maps ======
def encode_modality(modality_idx, loader, use_reconstruction=False):
    enc = encoders[modality_idx].eval()
    embs, labels = [], []
    with torch.no_grad():
        for mods, y in loader:
            x = mods[modality_idx].to(device)
            h = enc(x)  # [B, embed_dim]

            if use_reconstruction:
                # Aggregate reconstructions from all edges touching this node
                reconstructions = []
                for (i, j) in edges:
                    if modality_idx in (i, j):
                        edge_key = f"{i}-{j}"
                        node_key = f"{modality_idx}->{edge_key}"
                        dual_key = f"{edge_key}->{modality_idx}"
                        P = restrictions[node_key]    # [edge_dim, embed_dim]
                        Q = duals[dual_key]           # [embed_dim, edge_dim]

                        edge_emb = h @ P.t()          # project to edge
                        node_recon = edge_emb @ Q.t() # reconstruct node
                        reconstructions.append(node_recon)

                if len(reconstructions) > 0:
                    h = sum(reconstructions) / len(reconstructions)

            h = F.normalize(h, dim=-1)
            embs.append(h)
            labels.append(y)
    return torch.cat(embs), torch.cat(labels)


# ====== 3. Zero-Shot Learning Evaluation ======
def zero_shot_eval(mod_idx=0, use_reconstruction=False):
    print(f"\n[Zero-Shot Eval] Using modality {mod_idx} embeddings "
          f"{'(reconstructed)' if use_reconstruction else '(raw)'}")
    # Encode all test data
    test_embs, test_labels = encode_modality(mod_idx, test_loader, use_reconstruction)

    # Compute class centroids
    num_classes = 10
    centroids = []
    for c in range(num_classes):
        class_embs = test_embs[test_labels == c]
        centroids.append(class_embs.mean(dim=0))
    centroids = F.normalize(torch.stack(centroids), dim=-1)  # [10, D]

    # Classify by cosine similarity
    sims = test_embs @ centroids.t()  # [N, 10]
    preds = sims.argmax(dim=1)
    acc = (preds == test_labels.to(device)).float().mean().item()
    print(f"Zero-Shot Accuracy (modality {mod_idx}): {acc*100:.2f}%")
    return acc


# ====== 4. Cross-Modal Retrieval ======
def cross_modal_retrieval(query_mod=0, gallery_mod=1, top_k=(1,5,10), use_reconstruction=False):
    print(f"\n[Cross-Modal Retrieval] Query: {query_mod} -> Gallery: {gallery_mod} "
          f"{'(reconstructed)' if use_reconstruction else '(raw)'}")

    q_embs, q_labels = encode_modality(query_mod, test_loader, use_reconstruction)
    g_embs, g_labels = encode_modality(gallery_mod, test_loader, use_reconstruction)

    # Compute cosine similarity
    sim = q_embs @ g_embs.t()
    ranks = sim.argsort(dim=1, descending=True)

    recalls = {k: 0 for k in top_k}
    for i, label in enumerate(q_labels):
        retrieved_labels = g_labels[ranks[i]]
        for k in top_k:
            if label in retrieved_labels[:k]:
                recalls[k] += 1

    for k in top_k:
        recalls[k] /= len(q_labels)
        print(f"Recall@{k}: {recalls[k]*100:.2f}%")
    return recalls


# ====== 5. Run evaluations ======
if __name__ == "__main__":
    # Zero-shot on raw and reconstructed embeddings
    zero_shot_eval(mod_idx=0, use_reconstruction=False)
    zero_shot_eval(mod_idx=0, use_reconstruction=True)

    # Cross-modal retrieval
    cross_modal_retrieval(query_mod=0, gallery_mod=2, use_reconstruction=False)
    cross_modal_retrieval(query_mod=0, gallery_mod=2, use_reconstruction=True)



[Zero-Shot Eval] Using modality 0 embeddings (raw)
Zero-Shot Accuracy (modality 0): 72.86%

[Zero-Shot Eval] Using modality 0 embeddings (reconstructed)
Zero-Shot Accuracy (modality 0): 57.18%

[Cross-Modal Retrieval] Query: 0 -> Gallery: 2 (raw)
Recall@1: 10.20%
Recall@5: 30.55%
Recall@10: 44.53%

[Cross-Modal Retrieval] Query: 0 -> Gallery: 2 (reconstructed)
Recall@1: 10.50%
Recall@5: 31.18%
Recall@10: 46.85%


In [3]:
# Embeddings 128 
'''
[Zero-Shot Eval] Using modality 0 embeddings (raw)
Zero-Shot Accuracy (modality 0): 67.30%

[Zero-Shot Eval] Using modality 0 embeddings (reconstructed)
Zero-Shot Accuracy (modality 0): 64.59%

[Cross-Modal Retrieval] Query: 0 -> Gallery: 2 (raw)
Recall@1: 11.04%
Recall@5: 31.41%
Recall@10: 45.51%

[Cross-Modal Retrieval] Query: 0 -> Gallery: 2 (reconstructed)
Recall@1: 10.28%
Recall@5: 30.39%
Recall@10: 44.75%
'''



# Embeddings 64

'''
[Zero-Shot Eval] Using modality 0 embeddings (raw)
Zero-Shot Accuracy (modality 0): 75.59%

[Zero-Shot Eval] Using modality 0 embeddings (reconstructed)
Zero-Shot Accuracy (modality 0): 54.00%

[Cross-Modal Retrieval] Query: 0 -> Gallery: 2 (raw)
Recall@1: 8.23%
Recall@5: 25.64%
Recall@10: 38.85%

[Cross-Modal Retrieval] Query: 0 -> Gallery: 2 (reconstructed)
Recall@1: 9.04%
Recall@5: 28.85%
Recall@10: 43.70%

'''

'\n[Zero-Shot Eval] Using modality 0 embeddings (raw)\nZero-Shot Accuracy (modality 0): 75.59%\n\n[Zero-Shot Eval] Using modality 0 embeddings (reconstructed)\nZero-Shot Accuracy (modality 0): 54.00%\n\n[Cross-Modal Retrieval] Query: 0 -> Gallery: 2 (raw)\nRecall@1: 8.23%\nRecall@5: 25.64%\nRecall@10: 38.85%\n\n[Cross-Modal Retrieval] Query: 0 -> Gallery: 2 (reconstructed)\nRecall@1: 9.04%\nRecall@5: 28.85%\nRecall@10: 43.70%\n\n'