In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
from PIL import ImageFilter
import random
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from collections import defaultdict



device = torch.device("cuda" if torch.cuda.is_available() else "CPU")


# ==== 1. Multi-modal MNIST Dataset ====
class MultiModalMNIST(Dataset):
    def __init__(self, train=True):
        self.dataset = datasets.MNIST(root="./data", train=train, download=True)
        
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        
        # Modality 1: original
        m1 = TF.to_tensor(img)
        # Modality 2: edge-detected
        m2 = TF.to_tensor(img.filter(ImageFilter.FIND_EDGES))
        # Modality 3: inverted
        m3 = TF.to_tensor(TF.invert(img))
        
        return (m1.view(-1), m2.view(-1), m3.view(-1)), label  # flatten each modality
    
    def __len__(self):
        return len(self.dataset)

# ====== 1. Multi-modal MNIST dataset ======
class MultiModalMNIST(torch.utils.data.Dataset):
    def __init__(self, train=True):
        self.dataset = datasets.MNIST(root="./data", train=train, download=True)
        
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        
        # Modality 1: original
        m1 = TF.to_tensor(img)
        # Modality 2: edge-detected
        m2 = TF.to_tensor(img.filter(ImageFilter.FIND_EDGES))
        # Modality 3: inverted
        m3 = TF.to_tensor(TF.invert(img))
        
        return (m1, m2, m3), label
    def __len__(self):
        return len(self.dataset)





# --- Build training and testing sets ---
train_dataset = MultiModalMNIST(train=True)
test_dataset = MultiModalMNIST(train=False)

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ==== 2. Graph Setup ====
# We have 3 nodes (modality-specific learners)
num_nodes = 3
graph = {
    0: [1, 2],
    1: [0, 2],
    2: [0, 1]
}

# Node-specific DataLoaders (extracting single modality)
def modality_dataloader(dataset, modality_idx):
    class SingleModality(Dataset):
        def __init__(self, base_dataset, idx):
            self.base = base_dataset
            self.idx = idx
        def __len__(self):
            return len(self.base)
        def __getitem__(self, i):
            Xs, y = self.base[i]
            return Xs[self.idx], y
    return DataLoader(SingleModality(dataset, modality_idx), batch_size=batch_size, shuffle=True)

data_loaders = {i: modality_dataloader(train_dataset, i) for i in range(num_nodes)}

# ==== 3. Node Encoders and Maps ====
import torch.nn as nn

# --- Node Encoder ---
class NodeEncoder(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, embedding_dim)
        )
    def forward(self, x):
        return self.encoder(x)

class TinyEncoder(nn.Module):
    def __init__(self, embedding_dim=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(16*7*7, embedding_dim)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)




# --- Maps ---
class RestrictionMap(nn.Module):  # Node -> Edge
    def __init__(self, node_dim, edge_dim):
        super().__init__()
        self.map = nn.Linear(node_dim, edge_dim, bias=False)
    def forward(self, h):
        return self.map(h)

class TransportMap(nn.Module):  # Edge -> Node
    def __init__(self, edge_dim, node_dim):
        super().__init__()
        self.map = nn.Linear(edge_dim, node_dim, bias=False)
    def forward(self, z):
        return self.map(z)

# --- Initialize models ---
input_dim = 28*28      # flattened MNIST
embedding_dim = 64
edge_dim = 64

#encoders = {i: NodeEncoder(input_dim, embedding_dim) for i in range(num_nodes)}
encoders = {i: TinyEncoder(embedding_dim).to(device) for i in range(num_nodes)}


# Initialize P and Q maps
P_maps = {i: {} for i in range(num_nodes)}
Q_maps = {i: {} for i in range(num_nodes)}
for i in range(num_nodes):
    for j in graph[i]:
        P_maps[i][j] = RestrictionMap(embedding_dim, edge_dim).to(device)
        Q_maps[i][j] = TransportMap(edge_dim, embedding_dim).to(device)

# ==== 4. Optimizers ====
optimizer_dict = {
    i: torch.optim.Adam(list(encoders[i].parameters()) +
                        [p for P in P_maps[i].values() for p in P.parameters()] +
                        [p for Q in Q_maps[i].values() for p in Q.parameters()],
                        lr=1e-3)
    for i in range(num_nodes)
}

print("Dataset and graph setup complete.")
print(f"Graph: {graph}")
print(f"DataLoaders: {len(data_loaders)} nodes ready.")
print(f"Device: {device}")




Dataset and graph setup complete.
Graph: {0: [1, 2], 1: [0, 2], 2: [0, 1]}
DataLoaders: 3 nodes ready.
Device: cuda


# implementation 1

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict

# --- Loss Functions ---
def contrastive_loss(local_emb, transported_emb, temperature=0.1):
    local_norm = F.normalize(local_emb, dim=-1)
    transported_norm = F.normalize(transported_emb, dim=-1)
    
    logits = torch.matmul(local_norm, transported_norm.T) / temperature
    labels = torch.arange(local_emb.size(0), device=local_emb.device)
    return F.cross_entropy(logits, labels)

def sheaf_laplacian_loss(h_i, h_j, P_ij, P_ji):
    """
    ||P_ij h_i - P_ji h_j||^2
    """
    z_i = P_ij(h_i)
    z_j = P_ji(h_j)
    return F.mse_loss(z_i, z_j)


# --- One Training Step ---
def decentralized_training_step(node_id, batch_x, 
                                neighbors_data, 
                                encoder, P_maps, Q_maps, 
                                lambda_lap=10.0, beta_contrast=1.0):
    """
    One node-centric sheaf training step
    """
    # --- Local embedding ---
    h_i = encoder(batch_x)  # [B, d_node]

    lap_loss, contrast_loss = 0.0, 0.0
    for j, (x_j, enc_j) in neighbors_data.items():
        h_j = enc_j(x_j)

        # Laplacian term
        lap_loss += sheaf_laplacian_loss(
            h_i, h_j, P_maps[node_id][j], P_maps[j][node_id]
        )

        # Transport embedding from j->i
        transported = Q_maps[node_id][j](
            P_maps[j][node_id](h_j)
        )
        
        # Contrastive term
        contrast_loss += contrastive_loss(h_i, transported)
        #print(lap_loss, contrast_loss, lambda_lap, beta_contrast)
    
    return lambda_lap * lap_loss + beta_contrast * contrast_loss


# --- Main Training Loop ---
def train_sheaf_decentralized(graph, data_loaders, encoders, P_maps, Q_maps,
                              optimizer_dict, epochs=10, device=device):
    """
    graph: adjacency dict {i: [j1, j2, ...]}
    data_loaders: dict {i: DataLoader}
    encoders: dict {i: NodeEncoder}
    P_maps, Q_maps: dict {i: {j: map}}
    optimizer_dict: dict {i: torch.optim.Optimizer}
    """
    for i in encoders:
        encoders[i].to(device)

    for epoch in range(epochs):
        total_losses = defaultdict(float)

        # zip(*...) aligns batches across all nodes
        for batch_nodes in zip(*data_loaders.values()):
            # Compute local losses for each node
            for i, batch in enumerate(batch_nodes):
                x_i = batch[0].to(device)  # Assume (data, label)

                # Collect neighbors
                neighbors_data = {}
                for j in graph[i]:
                    x_j = batch_nodes[j][0].to(device)
                    neighbors_data[j] = (x_j, encoders[j])

                # Compute loss
                loss = decentralized_training_step(
                    i, x_i, neighbors_data, 
                    encoders[i], P_maps, Q_maps
                )

                optimizer_dict[i].zero_grad()
                loss.backward()
                optimizer_dict[i].step()
                
                total_losses[i] += loss.item()

        avg_loss = sum(total_losses.values()) / len(total_losses)
        print(f"Epoch {epoch+1}/{epochs}: avg loss {avg_loss:.4f}")



@torch.no_grad()
def encode_node(node_id, data_loader, encoder, device="cuda"):
    encoder.eval()
    all_embs, all_labels = [], []
    for x, y in data_loader:
        x = x.to(device)
        h = encoder(x)  # [B, d]
        all_embs.append(h)
        all_labels.append(y.to(device))
    return torch.cat(all_embs, dim=0), torch.cat(all_labels, dim=0)


@torch.no_grad()
def sheaf_node_eval(graph, data_loaders, encoders, P_maps, Q_maps, device="cuda"):
    """
    Evaluates:
    1. Zero-shot node-level classification
    2. Cross-modal retrieval between nodes
    using transported embeddings.
    """
    # --- Step 1: Encode all nodes ---
    node_embs, node_labels = {}, {}
    for i in graph.keys():
        h, y = encode_node(i, data_loaders[i], encoders[i].to(device), device)
        node_embs[i], node_labels[i] = h, y

    # --- Step 2: Zero-shot classification ---
    print("\n[Zero-Shot Classification]")
    zs_accs = []
    for target in graph.keys():
        best_acc = 0.0

        for source in graph.keys():
            if source == target:
                continue

            # Transport source embeddings into target space
            transported = Q_maps[target][source](
                P_maps[source][target](node_embs[source])
            )

            # Nearest neighbor in embedding space
            sim = torch.matmul(
                F.normalize(transported, dim=-1),
                F.normalize(node_embs[target], dim=-1).T
            )  # [N_source, N_target]

            preds = node_labels[target][sim.argmax(dim=-1)]
            acc = accuracy_score(node_labels[source].cpu(), preds.cpu())
            best_acc = max(best_acc, acc)

        zs_accs.append(best_acc)
        print(f"Node {target}: best zero-shot acc {best_acc:.4f}")

    print(f"Average Zero-Shot Accuracy: {sum(zs_accs)/len(zs_accs):.4f}")

    # --- Step 3: Cross-modal retrieval ---
    print("\n[Cross-Modal Retrieval]")
    retrieval_scores = []
    for i in graph.keys():
        for j in graph[i]:
            if i < j:  # avoid double-counting
                # Transport j embeddings to i space
                h_j2i = Q_maps[i][j](P_maps[j][i](node_embs[j]))

                # Cosine similarities
                sim = F.cosine_similarity(
                    F.normalize(node_embs[i], dim=-1)[:, None, :],
                    F.normalize(h_j2i, dim=-1)[None, :, :],
                    dim=-1
                )  # [N_i, N_j]

                avg_sim = sim.max(dim=1)[0].mean().item()
                retrieval_scores.append(avg_sim)
                print(f"Node {i}<->{j}: avg retrieval sim {avg_sim:.4f}")

    print(f"Average Retrieval Score: {sum(retrieval_scores)/len(retrieval_scores):.4f}")
    return zs_accs, retrieval_scores

In [3]:
train_sheaf_decentralized(
    graph, data_loaders, encoders, P_maps, Q_maps,
    optimizer_dict, epochs=50, device=device
)




Epoch 1/50: avg loss 1716472630.9152
Epoch 2/50: avg loss 88858493.9486
Epoch 3/50: avg loss 26924558.3197
Epoch 4/50: avg loss 11809284.0068
Epoch 5/50: avg loss 5737292.7255
Epoch 6/50: avg loss 2847134.6883
Epoch 7/50: avg loss 1407418.8653
Epoch 8/50: avg loss 669632.2762
Epoch 9/50: avg loss 323574.0210
Epoch 10/50: avg loss 167954.9782
Epoch 11/50: avg loss 91574.4375
Epoch 12/50: avg loss 971555652148.8871
Epoch 13/50: avg loss 56850947120.3333
Epoch 14/50: avg loss 26597788294.2500
Epoch 15/50: avg loss 14189575567.7917
Epoch 16/50: avg loss 7149780888.8125
Epoch 17/50: avg loss 3333478882.3333
Epoch 18/50: avg loss 1405959273.1094
Epoch 19/50: avg loss 649293600.5208
Epoch 20/50: avg loss 300215410.8932
Epoch 21/50: avg loss 155136897.3477
Epoch 22/50: avg loss 83462028.1169
Epoch 23/50: avg loss 46180112.6489
Epoch 24/50: avg loss 26081503.4200
Epoch 25/50: avg loss 15078183.7037
Epoch 26/50: avg loss 8548683.0741
Epoch 27/50: avg loss 14416760382303.3047
Epoch 28/50: avg los

In [4]:
sheaf_node_eval(
    graph, data_loaders, encoders, P_maps, Q_maps, device=device
)


[Zero-Shot Classification]
Node 0: best zero-shot acc 0.1143
Node 1: best zero-shot acc 0.1062
Node 2: best zero-shot acc 0.0992
Average Zero-Shot Accuracy: 0.1066

[Cross-Modal Retrieval]


OutOfMemoryError: CUDA out of memory. Tried to allocate 13.41 GiB. GPU 0 has a total capacity of 31.74 GiB of which 4.37 GiB is free. Including non-PyTorch memory, this process has 27.37 GiB memory in use. Of the allocated memory 13.53 GiB is allocated by PyTorch, and 13.44 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Implementation 2

In [None]:
'''


import torch
from collections import defaultdict
import torch.nn.functional as F


def decentralized_training_step(node_id, batch_x, 
                                neighbors_data, 
                                encoder, P_maps, Q_maps, 
                                lambda_lap=1.0, beta_contrast=1.0):
    """
    node_id: int
    batch_x: [B, input_dim] for local node
    neighbors_data: dict {j: (batch_x_j, encoder_j)}
    P_maps, Q_maps: dict {(i,j): RestrictionMap or TransportMap}
    """
    B = batch_x.size(0)
    
    # --- Local embedding ---
    h_i = encoder(batch_x)  # [B, d_node]
    
    # --- Compute outgoing messages ---
    z_outgoing = {}
    for j, P_ij in P_maps[node_id].items():
        z_outgoing[j] = P_ij(h_i)  # [B, d_edge]

    # --- Receive neighbor messages and compute losses ---
    lap_loss, contrast_loss = 0.0, 0.0
    for j, (x_j, enc_j) in neighbors_data.items():
        h_j = enc_j(x_j)

        # Laplacian term
        z_i_to_j = P_maps[node_id][j](h_i)
        z_j_to_i = P_maps[j][node_id](h_j)
        lap_loss += F.mse_loss(z_i_to_j, z_j_to_i)

        # Transported embedding: Q_ij(P_ji(h_j))
        transported = Q_maps[node_id][j](z_j_to_i)
        contrast_loss += contrastive_loss(h_i, transported)

    total_loss = lambda_lap * lap_loss + beta_contrast * contrast_loss
    return total_loss


def contrastive_loss(local_emb, transported_emb, temperature=0.1):
    local_norm = F.normalize(local_emb, dim=-1)
    transported_norm = F.normalize(transported_emb, dim=-1)
    
    logits = torch.matmul(local_norm, transported_norm.T) / temperature
    labels = torch.arange(local_emb.size(0), device=local_emb.device)
    return F.cross_entropy(logits, labels)


def train_sheaf_decentralized(graph, data_loaders, encoders, P_maps, Q_maps,
                              optimizer_dict, epochs=10, 
                              lambda_lap=1.0, beta_contrast=1.0,
                              device=device):
    """
    graph: adjacency dict {i: [j1, j2, ...]}
    data_loaders: dict {i: DataLoader}
    encoders: dict {i: NodeEncoder}
    P_maps, Q_maps: dict {i: {j: map}}
    optimizer_dict: dict {i: torch.optim.Optimizer}
    """
    for epoch in range(epochs):
        total_losses = defaultdict(float)
        batch_count = 0

        for batch_nodes in zip(*data_loaders.values()):
            # Each batch_nodes is a tuple: (x_i, y_i), ..., one per node
            batch_count += 1

            for i, batch in enumerate(batch_nodes):
                batch_x = batch[0].to(device)  # assuming (x, y)

                # Get neighbor batches
                neighbors_data = {}
                for j in graph[i]:
                    x_j = batch_nodes[j][0].to(device)
                    neighbors_data[j] = (x_j, encoders[j].to(device))

                enc_i = encoders[i].to(device)
                enc_i.train()

                # Compute loss
                loss = decentralized_training_step(
                    i, batch_x, neighbors_data, enc_i, 
                    P_maps, Q_maps, 
                    lambda_lap, beta_contrast
                )

                optimizer_dict[i].zero_grad()
                loss.backward()
                optimizer_dict[i].step()

                total_losses[i] += loss.item()

        avg_epoch_loss = sum(total_losses.values()) / len(total_losses)
        print(f"[Epoch {epoch+1}] Avg loss: {avg_epoch_loss:.4f}")


'''

In [None]:
'''

train_sheaf_decentralized(
    graph=your_graph_dict,
    data_loaders=your_dataloaders,
    encoders=your_node_encoders,
    P_maps=your_restriction_maps,
    Q_maps=your_transport_maps,
    optimizer_dict=your_optimizer_dict,
    epochs=20,
    lambda_lap=1.0,
    beta_contrast=1.0,
    device=device
)
'''