# Code starts here


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transf
from data_feed import DataFeed, DataFeed_image_pos
from build_net import resnet50, NN_beam_pred, MultinomialLogisticRegression

# Image Feature Extractor
class ImageFeatureExtractor(nn.Module):
    def __init__(self, output_dim=128):
        super(ImageFeatureExtractor, self).__init__()
        base_model = resnet50(pretrained=True, num_classes=64)
        #base_model.fc = nn.Identity()  # Remove classification layer
        self.feature_extractor = base_model
        self.fc = nn.Linear(128, output_dim)  # Project to desired output dimension
        #self.bn = nn.BatchNorm1d(output_dim)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.fc(x)
        #x = self.bn(x)
        return x

# Position Feature Extractor
class PosFeatureExtractor(nn.Module):
    def __init__(self, input_dim=4, output_dim=128):
        super(PosFeatureExtractor, self).__init__()
        self.feature_extractor = NN_beam_pred(num_features=input_dim, num_output=output_dim)
        #self.bn = nn.BatchNorm1d(output_dim)

    def forward(self, x):
        x = self.feature_extractor(x)
        #x = self.bn(x)
        return x





In [24]:
class MultiModalNetwork(nn.Module):
    def __init__(self, input_size_audio=None, input_size_visual=None, hidden_size=256, output_size=64, z_dim=100):
        super(MultiModalNetwork, self).__init__()
        
        self.has_audio = input_size_audio is not None
        self.has_visual = input_size_visual is not None
        self.z_dim = z_dim  # Dimensionality of random noise for generator
        
        if self.has_audio:
            # Audio feature extractor
            self.audio_feature_extractor = PosFeatureExtractor(output_dim=hidden_size)
            
            # Common and Specific classifiers for audio
            self.audio_common_classifier = nn.Linear(hidden_size // 2, output_size)
            self.audio_specific_classifier = nn.Linear(hidden_size // 2, output_size)
        
        if self.has_visual:
            # Visual feature extractor
            self.visual_feature_extractor = ImageFeatureExtractor(output_dim=hidden_size)

            # Common and Specific classifiers for visual
            self.visual_common_classifier = nn.Linear(hidden_size // 2, output_size)
            self.visual_specific_classifier = nn.Linear(hidden_size // 2, output_size)
        
        # Common classifier shared by both modalities
        self.common_classifier = nn.Linear(hidden_size // 2, output_size)  # Common features

        # Generator Network for learning modality-common features
        self.generator = nn.Sequential(
            nn.Linear(z_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size // 2),  # Output common modality features
        )

    def forward(self, audio_input=None, visual_input=None, z=None):
        audio_features = self.audio_feature_extractor(audio_input) if self.has_audio and audio_input is not None else None
        visual_features = self.visual_feature_extractor(visual_input) if self.has_visual and visual_input is not None else None

        common_audio_features = None
        common_visual_features = None
        specific_audio_features = None
        specific_visual_features = None
        
        if self.has_audio and audio_features is not None:
            # Split audio features into common and specific parts
            common_audio_features = audio_features[:, :audio_features.size(1) // 2]
            specific_audio_features = audio_features[:, audio_features.size(1) // 2:]
        
        if self.has_visual and visual_features is not None:
            # Split visual features into common and specific parts
            common_visual_features = visual_features[:, :visual_features.size(1) // 2]
            specific_visual_features = visual_features[:, visual_features.size(1) // 2:]

        # Generate modality-common features if z is provided (for knowledge distillation)
        generated_common_features = None
        generated_common_pred = None

        if z is not None:
            generated_common_features = self.generator(z)  # Generated audio common features
            generated_common_pred = self.common_classifier(generated_common_features)  # Generated audio common features

        # Process each modality's common features with their classifiers
        final_audio_pred = None
        final_visual_pred = None
        mean_common_features = 0
        modality = 0

        if common_audio_features is not None:
            modality += 1
            mean_common_features += common_audio_features if mean_common_features is not None else 0
            common_audio_pred = self.audio_common_classifier(common_audio_features)
            specific_audio_pred = self.audio_specific_classifier(specific_audio_features)
            final_audio_pred = common_audio_pred + specific_audio_pred  # Combining both predictions

        if common_visual_features is not None:
            modality += 1
            mean_common_features += common_visual_features if mean_common_features is not None else 0
            common_visual_pred = self.visual_common_classifier(common_visual_features)
            specific_visual_pred = self.visual_specific_classifier(specific_visual_features)
            final_visual_pred = common_visual_pred + specific_visual_pred  # Combining both predictions

        # Normalize mean_common_features by the number of contributing modalities
        if modality > 0:
            mean_common_features = mean_common_features / modality  

        # Compute final prediction by averaging predictions of all classifiers
        if final_audio_pred is not None and final_visual_pred is not None:
            final_prediction = (final_audio_pred + final_visual_pred) 
        elif final_audio_pred is not None:
            final_prediction = final_audio_pred
        elif final_visual_pred is not None:
            final_prediction = final_visual_pred
        else:
            final_prediction = None  # No valid predictions

        

        return final_prediction, (final_audio_pred, final_visual_pred, common_audio_features, common_visual_features, specific_audio_features, specific_visual_features, generated_common_features, generated_common_pred, labels, mean_common_features)
    def compute_loss(self, audio_input, visual_input, labels, z, alpha1=1.0, alpha2=1.0, alpha_gen=1.0, beta=1.0, alpha_kd=1.0):
        # Forward pass
        final_prediction, (final_audio_pred, final_visual_pred, common_audio_features, common_visual_features, specific_audio_features, specific_visual_features, generated_common_features, generated_common_pred, labels, mean_common_features) = self(audio_input, visual_input, z)

        # Initialize losses to zero
        similarity_loss = 0.0
        auxiliary_loss = 0.0
        difference_loss = 0.0
        generation_loss = 0.0
        kd_loss = 0.0  # Knowledge Distillation Loss

        # 1) Knowledge Distillation Loss (using the local generator)
        if common_audio_features is not None:
            kd_loss += self.compute_knowledge_distillation_loss(common_audio_features, z)
        if common_visual_features is not None:
            kd_loss += self.compute_knowledge_distillation_loss(common_visual_features, z)

        # 2) Similarity Loss (F_sim_k)
        if common_audio_features is not None and common_visual_features is not None:
            kl_loss_audio = self.compute_kl_divergence(common_audio_features, common_visual_features)
            similarity_loss = kl_loss_audio / 2  # Normalized by the number of modalities

        # 3) Auxiliary Classification Loss (F_cls_k)
        if common_audio_features is not None:
            auxiliary_loss += self.compute_auxiliary_classification_loss(common_audio_features, labels)
        if common_visual_features is not None:
            auxiliary_loss += self.compute_auxiliary_classification_loss(common_visual_features, labels)

        # 4) Difference Loss (F_dif_k) - Orthogonality between common and specific features
        if common_audio_features is not None and specific_audio_features is not None:
            difference_loss += self.compute_difference_loss(common_audio_features, specific_audio_features)
        if common_visual_features is not None and specific_visual_features is not None:
            difference_loss += self.compute_difference_loss(common_visual_features, specific_visual_features)

        # 5) Generation Loss (F_gen_k)
        if generated_common_features is not None:
            generation_loss += self.compute_generation_loss(generated_common_features, generated_common_pred, mean_common_features, labels, beta) 

        # 6) Total Loss (F_dec_k)
        total_loss = alpha1 * similarity_loss + alpha2 * difference_loss + auxiliary_loss + alpha_gen * generation_loss + alpha_kd * kd_loss
        return total_loss, similarity_loss, auxiliary_loss, difference_loss, generation_loss, kd_loss

    def compute_knowledge_distillation_loss(self, common_features, z):
        # Generate modality-common features using the local generator (input noise z)
        generated_features = self.generator(z)  # Generate modality-common features from noise
    
        # Pass both real common features and generated features through the common classifier
        common_features_pred = self.common_classifier(common_features)
        generated_features_pred = self.common_classifier(generated_features)
    
        # Apply softmax to both the predicted features
        softmax_common_features = F.softmax(common_features_pred, dim=-1)
        softmax_generated_features = F.softmax(generated_features_pred, dim=-1)
    
        # Compute KL divergence between the softmax outputs of the real and generated features
        kd_loss = F.kl_div(softmax_common_features.log(), softmax_generated_features, reduction='batchmean')
    
        return kd_loss

    def compute_kl_divergence(self, common_audio_features, common_visual_features):
        # Apply softmax to features and compute KL divergence
        softmax_audio = F.softmax(common_audio_features, dim=-1)
        softmax_visual = F.softmax(common_visual_features, dim=-1)
        kl_divergence = F.kl_div(softmax_audio.log(), softmax_visual, reduction='batchmean')
        return kl_divergence

    def compute_auxiliary_classification_loss(self, common_features, labels):
        # Cross-entropy loss for auxiliary classification
        return F.cross_entropy(self.common_classifier(common_features), labels)

    def compute_difference_loss(self, common_features, specific_features):
        # Orthogonality loss to ensure modality-common and modality-specific features are distinct
        return torch.norm(torch.matmul(common_features.T, specific_features), p='fro')**2

    def compute_generation_loss(self, generated_common_features, generated_common_pred, mean_common_features, labels, beta):
        # Mean squared error loss to ensure the generated features align with the true features
        generated_common_pred = F.softmax(generated_common_pred, dim=-1)
        return F.cross_entropy(generated_common_pred, labels) + beta * F.mse_loss(generated_common_features, mean_common_features)





In [27]:
# Initialize Model Parameters
input_size_audio = 4  # Example size, modify according to your data
input_size_visual = 224  # Example size, modify according to your data
hidden_size = 256  # Hidden size of the network
output_size = 64  # Number of output classes
z_dim = 100  # Dimensionality of the generator's input noise

# Initialize the MultiModalNetwork model
model = MultiModalNetwork(input_size_audio=input_size_audio, input_size_visual=input_size_visual, hidden_size=hidden_size, output_size=output_size, z_dim=z_dim)

# Example input data
audio_input = torch.randn(10, input_size_audio)  # Batch of 10 audio samples
visual_input = torch.randn(10, 3, input_size_visual, input_size_visual)  # Batch of 10 visual samples
labels = torch.randint(0, output_size, (10,))  # Random labels for the batch
z = torch.randn(10, z_dim)  # Random noise for generator

x,_ = model(audio_input=audio_input, visual_input=visual_input, z=z)


Output layer dim = 64
<class 'build_net.Bottleneck'>


In [29]:
labels

tensor([62, 29, 36, 60, 25, 33, 33, 16, 31,  1])

In [31]:
# Compute the total loss with both modalities present
total_loss, similarity_loss, auxiliary_loss, difference_loss, generation_loss, kd_loss = model.compute_loss(audio_input, visual_input, labels, z)

print("Total Loss:", total_loss)
print("Similarity Loss:", similarity_loss)
print("Auxiliary Classification Loss:", auxiliary_loss)
print("Difference Loss:", difference_loss)
print("Generation Loss:", generation_loss)
print("Knowledge Distillation Loss:", kd_loss)

Total Loss: tensor(32704.4082, grad_fn=<AddBackward0>)
Similarity Loss: tensor(0.0440, grad_fn=<DivBackward0>)
Auxiliary Classification Loss: tensor(8.3068, grad_fn=<AddBackward0>)
Difference Loss: tensor(32691.8125, grad_fn=<AddBackward0>)
Generation Loss: tensor(4.2153, grad_fn=<AddBackward0>)
Knowledge Distillation Loss: tensor(0.0293, grad_fn=<AddBackward0>)


In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
# Initialize the optimizer (Adam in this case)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function
def train(model, audio_loader, visual_loader, labels_loader, z_loader, num_epochs=10, alpha1=1.0, alpha2=1.0, alpha_gen=1.0, alpha_kd=1.0):
    model.train()  # Set the model to training mode
    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        for audio_input, visual_input, labels, z in zip(audio_loader, visual_loader, labels_loader, z_loader):
            optimizer.zero_grad()  # Zero the gradients

            # Compute the loss
            loss_outputs = model.compute_loss(audio_input, visual_input, labels, z, alpha1, alpha2, alpha_gen, alpha_kd)
            total_loss, similarity_loss, auxiliary_loss, difference_loss, generation_loss, kd_loss = loss_outputs

            # Backpropagate the loss
            total_loss.backward()
            optimizer.step()  # Update the model parameters

            running_loss += total_loss.item()

        # Print the average loss after each epoch
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(audio_loader)}")

# Initialize Model Parameters
input_size_audio = 4  # Example size, modify according to your data
input_size_visual = 224  # Example size, modify according to your data
hidden_size = 256  # Hidden size of the network
output_size = 64  # Number of output classes
z_dim = 100  # Dimensionality of the generator's input noise

# Simulate 100 batches of random tensors for audio, visual, labels, and z
audio_loader = [torch.randn(10, input_size_audio) for _ in range(100)]
visual_loader = [torch.randn(10, 3, input_size_visual, input_size_visual) for _ in range(100)]
labels_loader = [torch.randint(0, output_size, (10,)) for _ in range(100)]
z_loader = [torch.randn(10, z_dim) for _ in range(100)]

# Train the model
train(model, audio_loader, visual_loader, labels_loader, z_loader, num_epochs=10)

# Evaluation function
def evaluate(model, audio_input, visual_input):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No need to calculate gradients during inference
        final_pred, (final_audio_pred, final_visual_pred, *_) = model(audio_input, visual_input)
        return final_audio_pred, final_visual_pred

# Example evaluation
audio_input = torch.randn(1, input_size_audio)  # Single audio sample
visual_input = torch.randn(1, input_size_visual)  # Single visual sample
audio_pred, visual_pred = evaluate(model, audio_input, visual_input)
print("Audio Prediction:", audio_pred)
print("Visual Prediction:", visual_pred)


 10%|█         | 1/10 [03:21<30:12, 201.40s/it]

Epoch [1/10], Loss: 342.80546870231626


 20%|██        | 2/10 [06:42<26:51, 201.41s/it]

Epoch [2/10], Loss: 10.78820011138916


 30%|███       | 3/10 [10:26<24:40, 211.54s/it]

Epoch [3/10], Loss: 9.52267520904541


 40%|████      | 4/10 [14:17<21:55, 219.25s/it]

Epoch [4/10], Loss: 8.88086745262146


 50%|█████     | 5/10 [17:37<17:41, 212.22s/it]

Epoch [5/10], Loss: 8.543712978363038


 60%|██████    | 6/10 [20:57<13:52, 208.11s/it]

Epoch [6/10], Loss: 8.472490768432618


 70%|███████   | 7/10 [24:16<10:15, 205.16s/it]

Epoch [7/10], Loss: 8.4392414188385


 80%|████████  | 8/10 [27:43<06:51, 205.90s/it]

Epoch [8/10], Loss: 8.430868740081786


 90%|█████████ | 9/10 [31:09<03:25, 205.64s/it]

Epoch [9/10], Loss: 8.410905723571778


100%|██████████| 10/10 [34:32<00:00, 207.29s/it]

Epoch [10/10], Loss: 8.397985458374023





RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 224]

In [None]:
# Needed Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import seaborn as sns
import networkx as nx
from torch.utils.data import DataLoader
import torchvision.transforms as transf
from data_feed import DataFeed, DataFeed_image_pos
from build_net import resnet50, NN_beam_pred, MultinomialLogisticRegression

In [None]:
# Fixing the seed for reproducibility
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)


if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# or full reproducibility
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Now you can use the `device` variable to move your model and data to the correct device
print(f"Using device: {device}")

In [None]:
# Directory containing the saved CSV files
output_dir = "./feature_nonIID/"

# Load one of the CSV files for EDA (e.g., user_0_outputs.csv)
df = pd.read_csv(output_dir + "user_0_pos_height_beam.csv")

# Quick overview of the data
print("Data Overview:")
print(df.head())
print("\nData Summary:")
print(df.describe())

# Check for missing values
print("\nMissing Values:")
print(df.isnull().sum())

In [None]:
########################################################################
########################### Data pre-processing ########################
########################################################################
no_users = 20
batch_size = 64
img_resize = transf.Resize((224, 224))
img_norm = transf.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
proc_pipe = transf.Compose(
    [transf.ToPILImage(),
     img_resize,
     transf.ToTensor(),
     img_norm]
)
dataset_dir = "feature_nonIID/"
train_loaders = []
test_loaders = []
val_loaders = []

for user_id in range(no_users):
    train_dir = dataset_dir + f'user_{user_id}_pos_height_beam_train.csv'
    val_dir = dataset_dir + f'user_{user_id}_pos_height_beam_val.csv'
    test_dir = dataset_dir + f'user_{user_id}_pos_height_beam_test.csv'
    
    train_dataset = DataFeed_image_pos(train_dir, transform=proc_pipe)
    val_dataset = DataFeed_image_pos(root_dir=val_dir, transform=proc_pipe)
    test_dataset = DataFeed_image_pos(root_dir=test_dir, transform=proc_pipe)
    
    
    train_loaders.append(DataLoader(train_dataset,
                              batch_size=batch_size,
                              #num_workers=8,
                              shuffle=True))
    val_loaders.append(DataLoader(val_dataset,
                            batch_size=batch_size,
                            #num_workers=8,
                            shuffle=False))
    test_loaders.append(DataLoader(test_dataset,
                            batch_size=batch_size,
                            #num_workers=8,
                            shuffle=False))
print("All loadred are loaded")

In [None]:
# Model Preperation#
all_models = []
available_modalities = ["pos_height", "images"]
modality_size = {"pos_height": 128, "images": 128}


In [None]:
# Configuration
import random
no_users = 20  # Example: Number of users
available_modalities = ["pos_height", "images"]
modality_size = {"pos_height": 128, "images": 128}
group_definitions = {
    1: ["pos_height"],        # Group 1: Only pos_height
    2: ["images"],            # Group 2: Only images
    3: ["pos_height", "images"]  # Group 3: Both modalities
}

# Assign each user to a group randomly
weights = [0.2, 0.3, 0.5]  # Probabilities for groups 1, 2, and 3

# Generate user_groups with weighted random choices
user_groups = random.choices([1, 2, 3], weights=weights, k=no_users)

# Assign modalities to users based on their group
user_modalities = [group_definitions[group] for group in user_groups]

# Compute output sizes for each user based on their modalities
output_sizes = [sum(modality_size[modality] for modality in user_modality) for user_modality in user_modalities]

# Store models (placeholders for actual models)
all_models = []

# Example output (for verification)
print(f"User Groups: {user_groups[:10]}")  # Show first 10 users' groups
print(f"User Modalities: {user_modalities[:10]}")  # Show first 10 users' modalities
print(f"Output Sizes: {output_sizes[:10]}")  # Show first 10 users' output sizes

In [None]:
def sinkhorn_knopp(matrix, tol=1e-9, max_iter=1000):
    """
    Converts a given matrix to a doubly stochastic matrix using the Sinkhorn-Knopp algorithm.
    
    Parameters:
        matrix (np.ndarray): The input matrix to be transformed.
        tol (float): The tolerance for convergence.
        max_iter (int): Maximum number of iterations for convergence.
    
    Returns:
        np.ndarray: A doubly stochastic matrix.
    """
    matrix = matrix.copy()
    for _ in range(max_iter):
        # Normalize rows
        row_sums = matrix.sum(axis=1, keepdims=True)
        matrix /= row_sums

        # Normalize columns
        col_sums = matrix.sum(axis=0, keepdims=True)
        matrix /= col_sums

        # Check for convergence
        if np.allclose(matrix.sum(axis=1), 1, atol=tol) and np.allclose(matrix.sum(axis=0), 1, atol=tol):
            break

    return matrix
    
def create_random_topology(num_users, similarity_matrix, edge_probability=0.3):
    """
    Creates a connected random topology using NetworkX.
    Returns the adjacency matrix.
    """
    while True:
        graph = nx.erdos_renyi_graph(num_users, edge_probability)
        adjacency_matrix = nx.to_numpy_array(graph)
        new_adj = np.multiply(adjacency_matrix, similarity_matrix)
        new_graph = nx.from_numpy_array(new_adj)
        if nx.is_connected(new_graph):
            break

    # Convert graph to adjacency matrix
    adjacency_matrix = nx.to_numpy_array(new_graph)
    return adjacency_matrix

def prepare_mixing_matrices(adjacency_matrix, similarity_matrices):
    """
    Computes a mixing matrix for each modality by multiplying the adjacency matrix 
    with the similarity matrix for that modality.
    Returns a dictionary of mixing matrices.
    """
    adjacency_matrices = {}
    mixing_matrices = {}
    for modality, similarity_matrix in similarity_matrices.items():
        # Element-wise multiplication of adjacency and similarity matrices
        combined_matrix = adjacency_matrix * similarity_matrix
        adjacency_matrices[modality] = combined_matrix
        
        # Normalize to create a doubly matrix
        mixing_matrix = sinkhorn_knopp(combined_matrix)
        
        
        mixing_matrices[modality] = mixing_matrix
    
    return mixing_matrices, adjacency_matrices




In [None]:
# Create random connected topology
#adjacency_matrix = create_random_topology(no_users, edge_probability=0.3)
# Initialize adjacency matrix
similarity_matrix = np.zeros((no_users, no_users), dtype=int)

# Construct the adjacency matrix
for i in range(no_users):
    for j in range(no_users):
        if i != j:  # No self-loops
            # Check if users i and j share any modalities
            if set(user_modalities[i]) & set(user_modalities[j]):
                similarity_matrix[i, j] = 1

# Display the adjacency matrix
print("Adjacency Matrix:")
print(similarity_matrix)

# Prepare mixing matrices for each modality
#mixing_matrices, adjacency_matrices = prepare_mixing_matrices(adjacency_matrix, similarity_matrices)
adjacency_matrix = create_random_topology(20, similarity_matrix, edge_probability=0.3)
print(adjacency_matrix)

In [None]:
# Draw the graph
# Define colors for the groups
group_colors = {1: 'red', 2: 'green', 3: 'blue'}
node_colors = [group_colors[group] for group in user_groups]
G = nx.from_numpy_array(similarity_matrix)
pos = nx.spring_layout(G)
nx.draw(G, pos, with_labels=True, edge_color='gray', node_size=1000, node_color=node_colors, font_size=20, font_color='black')

# Show the plot
plt.show()

In [None]:
# Draw the graph





G = nx.from_numpy_array(adjacency_matrix)
pos = nx.spring_layout(G)
nx.draw(G, pos, with_labels=True, edge_color='gray', node_size=1000, node_color=node_colors, font_size=20, font_color='black')

# Show the plot
plt.show()
print(user_groups)

In [None]:
# Similarity matrices
adj_per_modality = {}
for modality in available_modalities:
    adj = np.zeros((no_users, no_users))
    for node in range(no_users):
        for neighbor in G.neighbors(node):
            if modality in user_modalities[neighbor] and modality in user_modalities[node]:
                adj[node, neighbor] = 1.    
    adj_per_modality[modality] = adj


In [None]:
G_modality = nx.from_numpy_array(adj_per_modality["images"])
pos = nx.spring_layout(G_modality)
nx.draw(G_modality, pos, with_labels=True, edge_color='gray', node_size=1000, node_color=node_colors, font_size=20, font_color='black')

# Show the plot
plt.show()

In [None]:
def construct_mixing_matrix(Adj, method="metropolis"):
    n = Adj.shape[0]
    W = np.zeros((n, n))  # Initialize weight matrix

    for i in range(n):
        degree_i = np.sum(Adj[i, :])

        for j in range(n):
            if Adj[i, j] == 1.0:
                degree_j = np.sum(Adj[j, :])
    
                if method == "metropolis":
                    W[i, j] = 1 / (max(degree_i, degree_j) + 1)
                elif method == "uniform":
                    W[i, j] = 1 / degree_i

        # Diagonal weight
        W[i, i] = 1 - np.sum(W[i, :])

    return W

mixing_matrices = {}
for modality in available_modalities:
    mixing_matrices[modality] = construct_mixing_matrix(adj_per_modality[modality], method="metropolis")
    print(np.sum(mixing_matrices[modality], 0))
    print(np.sum(mixing_matrices[modality], 1))
    lamb = np.linalg.eigvals(mixing_matrices[modality])
    lamb.sort()
    print(lamb)

In [None]:
G_modality = nx.from_numpy_array(adj_per_modality["pos_height"])
pos = nx.spring_layout(G_modality)
largest_cc = max(nx.connected_components(G_modality), key=len)

# Convert to sorted list of indices
connected_nodes = sorted(largest_cc)

# Extract the submatrix corresponding to the connected subgraph
W_reduced = mixing_matrices["pos_height"][np.ix_(connected_nodes, connected_nodes)]
lamb = np.linalg.eigvals(W_reduced)
lamb.sort()
print(lamb)

In [None]:
# Image Feature Extractor
class ImageFeatureExtractor(nn.Module):
    def __init__(self, output_dim=128):
        super(ImageFeatureExtractor, self).__init__()
        base_model = resnet50(pretrained=True, num_classes=64)
        #base_model.fc = nn.Identity()  # Remove classification layer
        self.feature_extractor = base_model
        #self.fc = nn.Linear(128, output_dim)  # Project to desired output dimension
        self.bn = nn.BatchNorm1d(output_dim)

    def forward(self, x):
        x = self.feature_extractor(x)
        #x = self.fc(x)
        x = self.bn(x)
        return x

# Position Feature Extractor
class PosFeatureExtractor(nn.Module):
    def __init__(self, input_dim=4, output_dim=128):
        super(PosFeatureExtractor, self).__init__()
        self.feature_extractor = NN_beam_pred(num_features=input_dim, num_output=output_dim)
        self.bn = nn.BatchNorm1d(output_dim)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.bn(x)
        return x

# Classification Head
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes=64):
        super(ClassificationHead, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

# Main Model with Named Sub-Networks
class Classifier(nn.Module):
    def __init__(self, use_image=True, use_pos=True, feature_dim=128, num_classes=64):
        super(Classifier, self).__init__()
        
        # Store sub-networks in a dictionary
        self.sub_networks = nn.ModuleDict()

        if use_image:
            self.sub_networks["images"] = ImageFeatureExtractor(output_dim=feature_dim)
        if use_pos:
            self.sub_networks["pos_height"] = PosFeatureExtractor(output_dim=feature_dim)

        # Determine input size for classification head
        input_dim = (feature_dim if use_image else 0) + (feature_dim if use_pos else 0)
        self.classifier = ClassificationHead(input_dim, num_classes)

    def forward(self, images=None, pos_height=None):
        features = []

        if "images" in self.sub_networks and images is not None:
            features.append(self.sub_networks["images"](images))

        if "pos_height" in self.sub_networks and pos_height is not None:
            features.append(self.sub_networks["pos_height"](pos_height))

        if not features:
            raise ValueError("At least one modality (image or pos) must be used")

        x = torch.cat(features, dim=1) if len(features) > 1 else features[0]
        return self.classifier(x)


In [None]:
model = Classifier(use_image=True, use_pos=True)

# Extract the image feature extractor using its name
image_extractor = model.sub_networks["pos_height"]

classifier_head = model.classifier


In [None]:
lr = 1e-3
optimizers = []
all_models = []
classifier_optimizers = []
for user_id in range(no_users):
    if "images" in user_modalities[user_id] and "pos_height" in user_modalities[user_id]:
        user_model = Classifier(use_image=True, use_pos=True).to(device)
    elif "pos_height" in user_modalities[user_id]:
        user_model = Classifier(use_image=False, use_pos=True).to(device)
    elif "images" in user_modalities[user_id]:
        user_model = Classifier(use_image=True, use_pos=False).to(device)
    local_optimizer = optim.Adam(user_model.parameters(), lr=lr)
    class_optim = optim.Adam(user_model.classifier.parameters(), lr=lr)

    all_models.append(user_model)
    optimizers.append(local_optimizer)
    classifier_optimizers.append(class_optim)
base_models = Classifier(use_image=True, use_pos=True).to(device)

In [None]:


# Decentralized aggregation function
def per_modelaity_decentralized_aggregation(user_models, mixing_matrices, available_modalities, user_modalities, base_models):
    num_users = len(user_models)
    with torch.no_grad():
        for modality in available_modalities:
            # Get the mixing matrix for the current modality
            mixing_matrix = mixing_matrices[modality]
            
            # Convert user model parameters to vectors for aggregation
            aggregated_models = []
            aggregated_updates = []
            for user_id, user_model in enumerate(user_models):
                if modality in user_modalities[user_id]:
                    aggregated_models.append(torch.nn.utils.parameters_to_vector(user_model.sub_networks[modality].parameters()))
                    aggregated_updates.append(torch.zeros_like(aggregated_models[-1]))
                else:
                    aggregated_models.append(0)
                    aggregated_updates.append(0)
            
            
            # Perform model aggregation based on the mixing matrix for this modality
            for i in range(num_users):
                for j in range(num_users):
                    if mixing_matrix[i, j] > 0:
                        aggregated_updates[i] += mixing_matrix[i, j] * aggregated_models[j]
            
            # Update user models with aggregated parameters for the current modality
            for user_id in range(num_users):
                if modality in user_modalities[user_id]:
                    torch.nn.utils.vector_to_parameters(aggregated_updates[user_id], user_models[user_id].sub_networks[modality].parameters())

#per_modelaity_decentralized_aggregation(all_models, mixing_matrices, available_modalities, user_modalities, base_models)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

def train_local_model(local_modalities, model, train_loader, criterion, optimizer, epochs, device):
    """
    Trains a local multi-modal model.

    Args:
        local_modalities (list): Modalities to use (e.g., ['image', 'pos']).
        model (Classifier): Multi-modal classification model.
        train_loader (DataLoader): Training data loader.
        criterion (nn.CrossEntropyLoss): Loss function.
        optimizer (torch.optim.Optimizer): Optimizer.
        epochs (int): Number of training epochs.
        device (torch.device): Device (CPU/GPU).

    Returns:
        tuple: Minimum training loss, maximum training accuracy.
    """
    # Unfreeze the layers
    # freezing first layers 
    for mod in local_modalities:
        for param in model.sub_networks[mod].parameters():
            param.requires_grad = True  # Freezes the feature extractor
    model.to(device)
    model.train()

    training_losses = []
    training_accuracies = []

    for epoch in range(epochs):
        epoch_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        for batch in train_loader:
            inputs, labels = batch

            # Prepare input data for selected modalities
            modality_inputs = {mod: inputs[mod].to(device) for mod in local_modalities}
            labels = labels.to(device)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(**modality_inputs)

            # Compute loss
            loss = criterion(model, outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Update metrics
            epoch_loss += loss.item()
            _, predicted = torch.max(outputs, dim=1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)

        # Compute loss and accuracy
        avg_loss = epoch_loss / len(train_loader)
        accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0

        training_losses.append(avg_loss)
        training_accuracies.append(accuracy)

        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

    return min(training_losses), max(training_accuracies)


In [None]:
def validate_user_models(user_id, model, val_loader, criterion, local_modalities, device):
    """
    Validates a trained multi-modal model.

    Args:
        user_id (int): User identifier.
        model (Classifier): Multi-modal classification model.
        val_loader (DataLoader): Validation data loader.
        criterion (nn.CrossEntropyLoss): Loss function.
        local_modalities (list): Modalities to use (e.g., ['image', 'pos']).
        device (torch.device): Device (CPU/GPU).

    Returns:
        dict: Validation loss and accuracy.
    """
    model.to(device)
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch

            # Prepare input data for selected modalities
            modality_inputs = {mod: inputs[mod].to(device) for mod in local_modalities}
            labels = labels.to(device)

            # Forward pass
            outputs = model(**modality_inputs)
            loss = criterion(model, outputs, labels)

            # Accumulate metrics
            total_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, dim=1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

    # Compute loss and accuracy
    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    print(f"User {user_id + 1} - Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

    return {"loss": avg_loss, "accuracy": accuracy}
