# Code starts here


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transf
from data_feed import DataFeed, DataFeed_image_pos
from build_net import resnet50, NN_beam_pred, MultinomialLogisticRegression, ImageFeatureExtractor, PosFeatureExtractor



In [None]:
class MultiModalNetwork(nn.Module):
    """Multi-modal classifier with shared ("common") and modality-specific heads.

    Two optional branches are supported: an "audio" branch built on
    ``PosFeatureExtractor`` and a "visual" branch built on
    ``ImageFeatureExtractor`` (elsewhere in this notebook the "audio" branch is
    fed the position/height modality).  Each extractor is built with
    ``output_dim=hidden_size`` and its output is split in half along dim 1:
    the first half ("common" features) is scored by ``common_classifier``,
    which is shared by every branch and by the generator; the second half
    ("specific" features) is scored by the branch's own classifier.  A small
    MLP ``generator`` maps noise ``z`` (width ``z_dim``) to synthetic common
    features of width ``hidden_size // 2``.

    Args:
        input_size_audio: Only compared against None -- a non-None value
            enables the audio branch; the value itself is not used.
        input_size_visual: Same, for the visual branch.
        hidden_size: Extractor output width; each half feeds a
            ``hidden_size // 2 -> output_size`` linear head (assumes an even
            ``hidden_size`` and extractors that really emit this width).
        output_size: Number of classes.
        z_dim: Width of the generator's input noise.
    """

    def __init__(self, input_size_audio=None, input_size_visual=None, hidden_size=256, output_size=64, z_dim=100):
        super(MultiModalNetwork, self).__init__()

        self.has_audio = input_size_audio is not None
        self.has_visual = input_size_visual is not None
        self.z_dim = z_dim  # Dimensionality of random noise for generator

        # Sub-modules are created in a fixed order (audio, visual, common,
        # generator); this order defines self.parameters() (and the order of
        # the random initialisation draws), so keep it stable.
        if self.has_audio:
            self.audio_feature_extractor = PosFeatureExtractor(output_dim=hidden_size)
            self.audio_specific_classifier = nn.Linear(hidden_size // 2, output_size)

        if self.has_visual:
            self.visual_feature_extractor = ImageFeatureExtractor(output_dim=hidden_size)
            self.visual_specific_classifier = nn.Linear(hidden_size // 2, output_size)

        # Common classifier shared by every modality and by the generator.
        self.common_classifier = nn.Linear(hidden_size // 2, output_size)

        # Generator: noise z -> synthetic modality-common features.
        self.generator = nn.Sequential(
            nn.Linear(z_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size // 2),
        )

        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _split_features(features):
        """Split (batch, width) features into (common, specific) halves along dim 1."""
        half = features.size(1) // 2
        return features[:, :half], features[:, half:]

    def forward(self, audio_input=None, visual_input=None, z=None):
        """Run the available branches, plus the generator when ``z`` is given.

        A branch runs only if it was built and its input is not None.

        Returns:
            ``(final_prediction, extras)``.  ``final_prediction`` is the SUM
            (not the mean) of the per-branch logits that were computed, or
            None if no branch ran.  ``extras`` is the tuple
            ``(final_audio_pred, final_visual_pred, common_audio_features,
            common_visual_features, specific_audio_features,
            specific_visual_features, generated_common_features,
            generated_common_pred, mean_common_features)``; entries of a branch
            that did not run (or of the generator when ``z`` is None) are None.
            ``mean_common_features`` is the average of the available common
            halves, or the integer 0 when no branch ran.
        """
        common_audio_features = specific_audio_features = None
        common_visual_features = specific_visual_features = None

        # Extractors run audio first, then visual.
        if self.has_audio and audio_input is not None:
            common_audio_features, specific_audio_features = self._split_features(
                self.audio_feature_extractor(audio_input))
        if self.has_visual and visual_input is not None:
            common_visual_features, specific_visual_features = self._split_features(
                self.visual_feature_extractor(visual_input))

        # Synthetic common features from noise (used for knowledge distillation).
        generated_common_features = None
        generated_common_pred = None
        if z is not None:
            generated_common_features = self.generator(z)
            generated_common_pred = self.common_classifier(generated_common_features)

        # Per-branch logits = common-head logits + specific-head logits.
        final_audio_pred = None
        final_visual_pred = None
        common_parts = []
        if common_audio_features is not None:
            common_parts.append(common_audio_features)
            final_audio_pred = (self.common_classifier(common_audio_features)
                                + self.audio_specific_classifier(specific_audio_features))
        if common_visual_features is not None:
            common_parts.append(common_visual_features)
            final_visual_pred = (self.common_classifier(common_visual_features)
                                 + self.visual_specific_classifier(specific_visual_features))

        # Mean of the available common halves (0 when no branch ran).
        mean_common_features = 0
        if common_parts:
            mean_common_features = sum(common_parts[1:], common_parts[0]) / len(common_parts)

        # The branch logits are summed, not averaged.
        if final_audio_pred is not None and final_visual_pred is not None:
            final_prediction = final_audio_pred + final_visual_pred
        else:
            final_prediction = final_audio_pred if final_audio_pred is not None else final_visual_pred

        return final_prediction, (final_audio_pred, final_visual_pred, common_audio_features, common_visual_features, specific_audio_features, specific_visual_features, generated_common_features, generated_common_pred, mean_common_features)

    def compute_loss(self, audio_input, visual_input, labels, z, alpha0=1.0, alpha1=1.0, alpha2=1.0, alpha3=1.0, alpha_gen=1.0, alpha_kd=1.0):
        """Forward pass plus the weighted sum of all training losses.

        total = alpha0 * classification + alpha1 * similarity
                + alpha2 * difference + alpha3 * auxiliary
                + alpha_gen * generation + alpha_kd * knowledge_distillation

        Terms whose inputs are missing (a branch that did not run, or
        ``z is None`` for the generator-based terms) stay 0.0.

        Returns:
            ``(total_loss, classification_loss, similarity_loss,
            auxiliary_loss, difference_loss, generation_loss, kd_loss)``.

        Raises:
            ValueError: If neither modality produced a prediction.
        """
        final_prediction, (final_audio_pred, final_visual_pred, common_audio_features, common_visual_features, specific_audio_features, specific_visual_features, generated_common_features, generated_common_pred, mean_common_features) = self(audio_input, visual_input, z)

        if final_prediction is None:
            raise ValueError("compute_loss needs at least one available modality input")

        similarity_loss = 0.0
        auxiliary_loss = 0.0
        difference_loss = 0.0
        generation_loss = 0.0
        kd_loss = 0.0  # Knowledge Distillation Loss

        present_common = [f for f in (common_audio_features, common_visual_features) if f is not None]

        # 1) Knowledge distillation against the local generator; it needs noise.
        if z is not None:
            for common in present_common:
                kd_loss += self.compute_knowledge_distillation_loss(common, z)

        # 2) Similarity loss between the two common halves (both must exist).
        if common_audio_features is not None and common_visual_features is not None:
            similarity_loss = self.compute_kl_divergence(common_audio_features, common_visual_features) / 2

        # 3) Auxiliary classification of each common half with the shared head.
        for common in present_common:
            auxiliary_loss += self.compute_auxiliary_classification_loss(common, labels)

        # 4) Difference (orthogonality) loss between common and specific halves.
        for common, specific in ((common_audio_features, specific_audio_features),
                                 (common_visual_features, specific_visual_features)):
            if common is not None and specific is not None:
                difference_loss += self.compute_difference_loss(common, specific)

        # 5) Generation loss: generated features should classify correctly and
        #    match the mean real common features.
        if generated_common_features is not None:
            generation_loss += self.compute_generation_loss(generated_common_features, generated_common_pred, mean_common_features, labels)

        classification_loss = self.compute_classification_loss(final_prediction, labels)

        # 6) Total loss
        total_loss = alpha0 * classification_loss + alpha1 * similarity_loss + alpha2 * difference_loss + alpha3 * auxiliary_loss + alpha_gen * generation_loss + alpha_kd * kd_loss
        return total_loss, classification_loss, similarity_loss, auxiliary_loss, difference_loss, generation_loss, kd_loss

    def compute_knowledge_distillation_loss(self, common_features, z):
        """KL(P_generated || P_common) between the shared head's class distributions.

        ``P_common`` comes from the real common features and ``P_generated``
        from ``self.generator(z)``; ``reduction='batchmean'``.
        """
        generated_features = self.generator(z)
        common_logits = self.common_classifier(common_features)
        generated_logits = self.common_classifier(generated_features)
        # kl_div expects log-probabilities as input; log_softmax avoids the
        # log(0) = -inf that softmax(...).log() produces on saturated logits.
        return F.kl_div(F.log_softmax(common_logits, dim=-1),
                        F.softmax(generated_logits, dim=-1),
                        reduction='batchmean')

    def compute_kl_divergence(self, common_audio_features, common_visual_features):
        """KL(softmax(visual) || softmax(audio)) over the feature dimension, batch-averaged."""
        return F.kl_div(F.log_softmax(common_audio_features, dim=-1),
                        F.softmax(common_visual_features, dim=-1),
                        reduction='batchmean')

    def compute_auxiliary_classification_loss(self, common_features, labels):
        """Cross-entropy of the shared head on one modality's common features."""
        return F.cross_entropy(self.common_classifier(common_features), labels)

    def compute_difference_loss(self, common_features, specific_features):
        """Squared Frobenius norm of common^T @ specific (0 when the halves are orthogonal)."""
        return torch.norm(torch.matmul(common_features.T, specific_features), p='fro')**2

    def compute_generation_loss(self, generated_common_features, generated_common_pred, mean_common_features, labels, beta=1.0):
        """Cross-entropy of the generated logits plus ``beta`` * MSE to the mean real common features.

        ``generated_common_pred`` must be raw logits: ``F.cross_entropy``
        applies log-softmax itself, so an extra softmax here would be applied
        twice and flatten the loss/gradients.
        """
        return F.cross_entropy(generated_common_pred, labels) + beta * F.mse_loss(generated_common_features, mean_common_features)

    def compute_classification_loss(self, preds, label):
        """Cross-entropy of the final (summed) prediction."""
        return self.criterion(preds, label)





In [None]:
# Smoke test: build a two-branch MultiModalNetwork and run one forward pass on
# random data.  NOTE: the model's constructor only checks the two input sizes
# against None (to decide which branches exist); input_size_visual is also
# used below as the height/width of the random image batch.
# Initialize Model Parameters
input_size_audio = 4  # Example size, modify according to your data
input_size_visual = 224  # Example size, modify according to your data
hidden_size = 256  # Hidden size of the network
output_size = 64  # Number of output classes
z_dim = 100  # Dimensionality of the generator's input noise

# Initialize the MultiModalNetwork model
model = MultiModalNetwork(input_size_audio=input_size_audio, input_size_visual=input_size_visual, hidden_size=hidden_size, output_size=output_size, z_dim=z_dim)

# Example input data (batch of 10; the globals below are reused by later cells)
audio_input = torch.randn(10, input_size_audio)  # Batch of 10 audio samples
visual_input = torch.randn(10, 3, input_size_visual, input_size_visual)  # Batch of 10 visual samples
labels = torch.randint(0, output_size, (10,))  # Random labels for the batch
z = torch.randn(10, z_dim)  # Random noise for generator

# x = (final_prediction, tuple of per-branch outputs/features) as returned by forward()
x = model(audio_input=audio_input, visual_input=visual_input, z=z)


In [None]:
labels  # Notebook display of the example label batch; has no effect when run as a script

In [None]:
# Evaluate compute_loss once on the example batch with both modalities.
# The main classification term is weighted 1e1; every auxiliary term 1e-6.
alpha = 1e-6
num_epochs = 1
alpha0 = 1e1
alpha1 = alpha2 = alpha3 = alpha_gen = alpha_kd = alpha
loss_weights = (alpha0, alpha1, alpha2, alpha3, alpha_gen, alpha_kd)
(total_loss, classification_loss, similarity_loss, auxiliary_loss,
 difference_loss, generation_loss, kd_loss) = model.compute_loss(audio_input, visual_input, labels, z, *loss_weights)

# Report each component.
print(f"Total Loss: {total_loss}")
print(f"Classification Loss: {classification_loss}")
print(f"Similarity Loss: {similarity_loss}")
print(f"Auxiliary Loss: {auxiliary_loss}")
print(f"Difference Loss: {difference_loss}")
print(f"Generation Loss: {generation_loss}")
print(f"Knowledge Distillation Loss: {kd_loss}")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
# Initialize the optimizer (Adam in this case); train() below steps this global.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function
def train(model, audio_loader, visual_loader, labels_loader, z_loader, num_epochs=10, alpha0=1.0, alpha1=1.0, alpha2=1.0, alpha_gen=1.0, alpha_kd=1.0, alpha3=1.0):
    """Train ``model`` on batches drawn in lockstep from the four loaders.

    Each step zips one batch from every loader, calls
    ``model.compute_loss`` and steps the module-level ``optimizer``.  Loss
    weights are forwarded by keyword so each one reaches the matching
    ``compute_loss`` parameter.

    Args:
        model: Network exposing ``compute_loss`` (e.g. MultiModalNetwork).
        audio_loader, visual_loader, labels_loader, z_loader: Iterables of
            batches; iteration stops at the shortest one (``zip``).
        num_epochs: Number of passes over the loaders.
        alpha0, alpha1, alpha2, alpha_gen, alpha_kd: Loss weights.
        alpha3: Weight of the auxiliary classification loss (appended last to
            keep existing positional calls valid).

    Prints the mean total loss per epoch (0.0 if no batch was seen).
    """
    model.train()  # Set the model to training mode
    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        num_batches = 0
        for audio_input, visual_input, labels, z in zip(audio_loader, visual_loader, labels_loader, z_loader):
            optimizer.zero_grad()

            total_loss, *_ = model.compute_loss(
                audio_input, visual_input, labels, z,
                alpha0=alpha0, alpha1=alpha1, alpha2=alpha2, alpha3=alpha3,
                alpha_gen=alpha_gen, alpha_kd=alpha_kd,
            )

            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            num_batches += 1

        # Average over the batches actually consumed (zip may stop early).
        mean_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss}")

# Initialize Model Parameters (same values as the example cell above)
input_size_audio = 4  # Example size, modify according to your data
input_size_visual = 224  # Example size, modify according to your data
hidden_size = 256  # Hidden size of the network
output_size = 64  # Number of output classes
z_dim = 100  # Dimensionality of the generator's input noise

# Simulate 100 batches of random tensors for audio, visual, labels, and z.
# Plain lists stand in for DataLoaders; train() zips them, so keep the lengths equal.
# NOTE(review): visual_loader is materialised eagerly: 100 * 10 * 3 * 224 * 224
# values (~600 MB assuming the default float32) -- shrink for quick experiments.
audio_loader = [torch.randn(10, input_size_audio) for _ in range(100)]
visual_loader = [torch.randn(10, 3, input_size_visual, input_size_visual) for _ in range(100)]
labels_loader = [torch.randint(0, output_size, (10,)) for _ in range(100)]
z_loader = [torch.randn(10, z_dim) for _ in range(100)]

# Train the model
#train(model, audio_loader, visual_loader, labels_loader, z_loader, num_epochs=10)

# Evaluation function
def evaluate(model, audio_input, visual_input):
    """Run ``model`` without gradients and return its per-branch predictions.

    The model is switched to eval mode (and left there).  Returns the first
    two entries of the model's auxiliary output tuple, i.e.
    ``(final_audio_pred, final_visual_pred)`` for MultiModalNetwork, where a
    branch that did not run yields None.
    """
    model.eval()
    with torch.no_grad():
        _, branch_outputs = model(audio_input, visual_input)
        audio_pred, visual_pred = branch_outputs[0], branch_outputs[1]
    return audio_pred, visual_pred

# Example evaluation (the visual branch is fed (batch, 3, H, W) image tensors
# everywhere else in this notebook, so the example uses that shape too)
#audio_input = torch.randn(1, input_size_audio)  # Single audio sample
#visual_input = torch.randn(1, 3, input_size_visual, input_size_visual)  # Single visual sample
#audio_pred, visual_pred = evaluate(model, audio_input, visual_input)
#print("Audio Prediction:", audio_pred)
#print("Visual Prediction:", visual_pred)


In [None]:
# Needed Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import seaborn as sns
import networkx as nx
from torch.utils.data import DataLoader
import torchvision.transforms as transf
from data_feed import DataFeed, DataFeed_image_pos
from build_net import resnet50, NN_beam_pred, MultinomialLogisticRegression, ImageFeatureExtractor, PosFeatureExtractor


In [None]:
# Seed every RNG this notebook draws from (NumPy, Python, PyTorch) so runs repeat.
seed = 42
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    _seed_fn(seed)

if torch.cuda.is_available():
    # Current GPU first, then every visible GPU.
    for _seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(seed)

# For full reproducibility, also pin cuDNN to deterministic kernels:
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False

In [None]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU;
# later cells move models and batches onto `device`.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"Using device: {device}")

In [None]:
# Quick exploratory look at user 0's position/height/beam CSV.
output_dir = "./feature_IID/"
df = pd.read_csv(f"{output_dir}user_0_pos_height_beam.csv")

# First rows, summary statistics, then per-column missing-value counts.
for _heading, _render in (
    ("Data Overview:", df.head),
    ("\nData Summary:", df.describe),
    ("\nMissing Values:", lambda: df.isnull().sum()),
):
    print(_heading)
    print(_render())

In [None]:
########################################################################
########################### Data pre-processing ########################
########################################################################
no_users = 20
batch_size = 64
# Resize to 224x224, then normalise per channel (the widely used ImageNet mean/std).
img_resize = transf.Resize((224, 224))
img_norm = transf.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
proc_pipe = transf.Compose([transf.ToPILImage(), img_resize, transf.ToTensor(), img_norm])
dataset_dir = "feature_IID/"
train_loaders, test_loaders, val_loaders = [], [], []

# One train/val/test CSV triple per user; only the training loader shuffles.
for user_id in range(no_users):
    train_dir, val_dir, test_dir = (
        f"{dataset_dir}user_{user_id}_pos_height_beam_{split}.csv"
        for split in ("train", "val", "test")
    )

    train_dataset = DataFeed_image_pos(train_dir, transform=proc_pipe)
    val_dataset = DataFeed_image_pos(root_dir=val_dir, transform=proc_pipe)
    test_dataset = DataFeed_image_pos(root_dir=test_dir, transform=proc_pipe)

    for loaders, dataset, shuffle in (
        (train_loaders, train_dataset, True),
        (val_loaders, val_dataset, False),
        (test_loaders, test_dataset, False),
    ):
        # num_workers=8 is intentionally left disabled.
        loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=shuffle))
print("All loadred are loaded")

In [None]:
# Model preparation: empty per-user model list plus modality metadata.
all_models = []
available_modalities = ["pos_height", "images"]
modality_size = {modality: 128 for modality in available_modalities}


In [None]:
# Configuration
import random
no_users = 20  # Number of users
available_modalities = ["pos_height", "images"]
modality_size = {"pos_height": 128, "images": 128}
# Group id -> modalities held by every user in that group.
group_definitions = {
    1: ["pos_height"],
    2: ["images"],
    3: ["pos_height", "images"],
}

# Sampling probabilities for groups 1, 2 and 3 (in that order).
weights = [0.2, 0.3, 0.5]
user_groups = random.choices(list(group_definitions), weights=weights, k=no_users)

# Each user's modality list is the (shared) list of its group.
user_modalities = list(map(group_definitions.__getitem__, user_groups))

# Per-user output width = sum of the sizes of the modalities it holds.
output_sizes = [sum(map(modality_size.__getitem__, mods)) for mods in user_modalities]

# Placeholder for the per-user models built later.
all_models = []

print(f"User Groups: {user_groups[:10]}")  # Show first 10 users' groups
print(f"User Modalities: {user_modalities[:10]}")  # Show first 10 users' modalities
print(f"Output Sizes: {output_sizes[:10]}")  # Show first 10 users' output sizes

In [None]:
def sinkhorn_knopp(matrix, tol=1e-9, max_iter=1000):
    """
    Converts a given matrix to a doubly stochastic matrix using the Sinkhorn-Knopp algorithm.

    Rows and then columns are rescaled to sum to 1, repeatedly, until both
    sets of sums are within ``tol`` of 1 or ``max_iter`` sweeps have run.
    Rows/columns whose sum is zero cannot be rescaled: they are left
    unchanged (instead of dividing by 0, which would make the entire result
    NaN), and the convergence test ignores them.

    Parameters:
        matrix (array-like): Non-negative input matrix; it is not modified.
        tol (float): Absolute tolerance on the row/column sums.
        max_iter (int): Maximum number of row+column normalisation sweeps.

    Returns:
        np.ndarray: A float copy of ``matrix`` whose non-zero rows and columns
        sum to 1 (doubly stochastic when no row/column is all zero).  If
        ``max_iter`` is reached first, the last iterate is returned.
    """
    # Copy as float: in-place true division raises on integer arrays.
    matrix = np.array(matrix, dtype=float)
    if matrix.size == 0:
        return matrix
    # Positive rescaling never changes which rows/columns sum to zero.
    live_rows = matrix.sum(axis=1) != 0
    live_cols = matrix.sum(axis=0) != 0
    for _ in range(max_iter):
        # Normalize rows (zero rows divided by 1, i.e. left as is)
        row_sums = matrix.sum(axis=1, keepdims=True)
        matrix /= np.where(row_sums != 0, row_sums, 1.0)

        # Normalize columns
        col_sums = matrix.sum(axis=0, keepdims=True)
        matrix /= np.where(col_sums != 0, col_sums, 1.0)

        # Check for convergence on the rows/columns that can be normalised
        if (np.allclose(matrix.sum(axis=1)[live_rows], 1, atol=tol)
                and np.allclose(matrix.sum(axis=0)[live_cols], 1, atol=tol)):
            break

    return matrix
    
def create_random_topology(num_users, similarity_matrix, edge_probability=0.3, max_tries=None):
    """
    Sample a connected communication graph restricted to the allowed links.

    Repeatedly draws an Erdos-Renyi graph G(num_users, edge_probability),
    multiplies its adjacency matrix element-wise by ``similarity_matrix``
    (so only allowed links survive) and returns the first masked graph that
    is connected.

    Parameters:
        num_users (int): Number of nodes.
        similarity_matrix (np.ndarray): (num_users, num_users) matrix whose
            non-zero entries mark the links that may appear.
        edge_probability (float): Erdos-Renyi edge probability.
        max_tries (int | None): Maximum number of samples; None (default)
            keeps sampling until success.

    Returns:
        np.ndarray: Adjacency matrix of the connected masked graph.

    Raises:
        ValueError: If the allowed-link graph itself is disconnected, so no
            sample could ever be connected.
        RuntimeError: If ``max_tries`` samples were drawn without success.
    """
    # Masking only removes edges: a disconnected allowed-link graph would make
    # the sampling loop below spin forever.
    if not nx.is_connected(nx.from_numpy_array(np.asarray(similarity_matrix))):
        raise ValueError("similarity_matrix describes a disconnected graph; no connected topology exists")

    tries = 0
    while max_tries is None or tries < max_tries:
        tries += 1
        graph = nx.erdos_renyi_graph(num_users, edge_probability)
        masked = np.multiply(nx.to_numpy_array(graph), similarity_matrix)
        masked_graph = nx.from_numpy_array(masked)
        if nx.is_connected(masked_graph):
            return nx.to_numpy_array(masked_graph)
    raise RuntimeError(f"No connected topology found in {max_tries} tries")

def prepare_mixing_matrices(adjacency_matrix, similarity_matrices):
    """
    Build each modality's adjacency matrix and its Sinkhorn-normalised mixing matrix.

    For every (modality, similarity_matrix) pair the element-wise product
    ``adjacency_matrix * similarity_matrix`` is that modality's adjacency
    matrix; ``sinkhorn_knopp`` then normalises it into the mixing matrix
    (intended to be doubly stochastic).

    Returns:
        tuple[dict, dict]: ``(mixing_matrices, adjacency_matrices)``, both keyed
        by modality in the iteration order of ``similarity_matrices``.
    """
    adjacency_matrices = {
        modality: adjacency_matrix * sim
        for modality, sim in similarity_matrices.items()
    }
    mixing_matrices = {
        modality: sinkhorn_knopp(combined)
        for modality, combined in adjacency_matrices.items()
    }
    return mixing_matrices, adjacency_matrices




In [None]:
# Similarity matrix: similarity_matrix[i, j] = 1 iff users i != j hold at
# least one modality in common.  It restricts which links the random
# communication topology may contain.
similarity_matrix = np.zeros((no_users, no_users), dtype=int)
_modality_sets = [set(mods) for mods in user_modalities]
for i in range(no_users):
    for j in range(no_users):
        if i != j and _modality_sets[i] & _modality_sets[j]:  # no self-loops
            similarity_matrix[i, j] = 1

print("Similarity Matrix:")
print(similarity_matrix)

# Sample a connected topology over the allowed links (size follows no_users
# instead of a hard-coded 20).
adjacency_matrix = create_random_topology(no_users, similarity_matrix, edge_probability=0.3)
print(adjacency_matrix)

In [None]:
def sinkhorn_knopp(matrix, tol=1e-9, max_iter=1000):
    """
    Converts a given matrix to a doubly stochastic matrix using the Sinkhorn-Knopp algorithm.

    Rows and then columns are rescaled to sum to 1, repeatedly, until both
    sets of sums are within ``tol`` of 1 or ``max_iter`` sweeps have run.
    Rows/columns whose sum is zero cannot be rescaled: they are left
    unchanged (instead of dividing by 0, which would make the entire result
    NaN), and the convergence test ignores them.

    Parameters:
        matrix (array-like): Non-negative input matrix; it is not modified.
        tol (float): Absolute tolerance on the row/column sums.
        max_iter (int): Maximum number of row+column normalisation sweeps.

    Returns:
        np.ndarray: A float copy of ``matrix`` whose non-zero rows and columns
        sum to 1 (doubly stochastic when no row/column is all zero).  If
        ``max_iter`` is reached first, the last iterate is returned.
    """
    # Copy as float: in-place true division raises on integer arrays.
    matrix = np.array(matrix, dtype=float)
    if matrix.size == 0:
        return matrix
    # Positive rescaling never changes which rows/columns sum to zero.
    live_rows = matrix.sum(axis=1) != 0
    live_cols = matrix.sum(axis=0) != 0
    for _ in range(max_iter):
        # Normalize rows (zero rows divided by 1, i.e. left as is)
        row_sums = matrix.sum(axis=1, keepdims=True)
        matrix /= np.where(row_sums != 0, row_sums, 1.0)

        # Normalize columns
        col_sums = matrix.sum(axis=0, keepdims=True)
        matrix /= np.where(col_sums != 0, col_sums, 1.0)

        # Check for convergence on the rows/columns that can be normalised
        if (np.allclose(matrix.sum(axis=1)[live_rows], 1, atol=tol)
                and np.allclose(matrix.sum(axis=0)[live_cols], 1, atol=tol)):
            break

    return matrix
    
def create_random_topology(num_users, similarity_matrix, edge_probability=0.3, max_tries=None):
    """
    Sample a connected communication graph restricted to the allowed links.

    Repeatedly draws an Erdos-Renyi graph G(num_users, edge_probability),
    multiplies its adjacency matrix element-wise by ``similarity_matrix``
    (so only allowed links survive) and returns the first masked graph that
    is connected.

    Parameters:
        num_users (int): Number of nodes.
        similarity_matrix (np.ndarray): (num_users, num_users) matrix whose
            non-zero entries mark the links that may appear.
        edge_probability (float): Erdos-Renyi edge probability.
        max_tries (int | None): Maximum number of samples; None (default)
            keeps sampling until success.

    Returns:
        np.ndarray: Adjacency matrix of the connected masked graph.

    Raises:
        ValueError: If the allowed-link graph itself is disconnected, so no
            sample could ever be connected.
        RuntimeError: If ``max_tries`` samples were drawn without success.
    """
    # Masking only removes edges: a disconnected allowed-link graph would make
    # the sampling loop below spin forever.
    if not nx.is_connected(nx.from_numpy_array(np.asarray(similarity_matrix))):
        raise ValueError("similarity_matrix describes a disconnected graph; no connected topology exists")

    tries = 0
    while max_tries is None or tries < max_tries:
        tries += 1
        graph = nx.erdos_renyi_graph(num_users, edge_probability)
        masked = np.multiply(nx.to_numpy_array(graph), similarity_matrix)
        masked_graph = nx.from_numpy_array(masked)
        if nx.is_connected(masked_graph):
            return nx.to_numpy_array(masked_graph)
    raise RuntimeError(f"No connected topology found in {max_tries} tries")

def prepare_mixing_matrices(adjacency_matrix, similarity_matrices):
    """
    Build each modality's adjacency matrix and its Sinkhorn-normalised mixing matrix.

    For every (modality, similarity_matrix) pair the element-wise product
    ``adjacency_matrix * similarity_matrix`` is that modality's adjacency
    matrix; ``sinkhorn_knopp`` then normalises it into the mixing matrix
    (intended to be doubly stochastic).

    Returns:
        tuple[dict, dict]: ``(mixing_matrices, adjacency_matrices)``, both keyed
        by modality in the iteration order of ``similarity_matrices``.
    """
    adjacency_matrices = {
        modality: adjacency_matrix * sim
        for modality, sim in similarity_matrices.items()
    }
    mixing_matrices = {
        modality: sinkhorn_knopp(combined)
        for modality, combined in adjacency_matrices.items()
    }
    return mixing_matrices, adjacency_matrices




In [None]:
# Draw the similarity graph (an edge joins users sharing a modality),
# coloring each node by its group.
# Define colors for the groups: 1 = pos_height only, 2 = images only, 3 = both
group_colors = {1: 'red', 2: 'green', 3: 'blue'}
node_colors = [group_colors[group] for group in user_groups]  # reused by later plots
G = nx.from_numpy_array(similarity_matrix)
# No seed is passed, so the layout presumably depends on the global RNG state.
pos = nx.spring_layout(G)
nx.draw(G, pos, with_labels=True, edge_color='gray', node_size=1000, node_color=node_colors, font_size=20, font_color='black')

# Show the plot
plt.show()

In [None]:
# Draw the sampled communication topology (adjacency_matrix) with the same
# group colors.
# NOTE: this rebinds G to the topology graph; the next cell walks G's
# neighbours to build the per-modality adjacency matrices.



G = nx.from_numpy_array(adjacency_matrix)
pos = nx.spring_layout(G)
nx.draw(G, pos, with_labels=True, edge_color='gray', node_size=1000, node_color=node_colors, font_size=20, font_color='black')

# Show the plot
plt.show()
print(user_groups)

In [None]:
# Per-modality adjacency: keep an edge (node, neighbor) of the sampled
# topology G only when both endpoints hold that modality.
adj_per_modality = {}
for modality in available_modalities:
    holds = [modality in mods for mods in user_modalities]
    adj = np.zeros((no_users, no_users))
    for node in range(no_users):
        for neighbor in G.neighbors(node):
            if holds[neighbor] and holds[node]:
                adj[node, neighbor] = 1.
    adj_per_modality[modality] = adj


In [None]:
# Draw the "images" communication graph; users without images have no edges
# here and appear as isolated nodes.
G_modality = nx.from_numpy_array(adj_per_modality["images"])
pos = nx.spring_layout(G_modality)
nx.draw(G_modality, pos, with_labels=True, edge_color='gray', node_size=1000, node_color=node_colors, font_size=20, font_color='black')

# Show the plot
plt.show()

In [None]:
def construct_mixing_matrix(Adj, method="metropolis"):
    """Build a consensus (mixing) weight matrix W from a 0/1 adjacency matrix.

    Off-diagonal weights are set only where ``Adj[i, j] == 1.0``:

    * ``"metropolis"``: W[i, j] = 1 / (max(deg_i, deg_j) + 1)
    * ``"uniform"``:    W[i, j] = 1 / deg_i

    where deg_i is the sum of row i of ``Adj``.  The diagonal then takes the
    remaining mass, W[i, i] = 1 - sum_j W[i, j], so each row sums to 1 when
    ``Adj`` has no self-loops.  An isolated node gets W[i, i] = 1.  For a
    symmetric ``Adj`` the Metropolis weights are symmetric, so W is doubly
    stochastic; uniform weights are in general only row-stochastic.

    Parameters:
        Adj (np.ndarray): (n, n) adjacency matrix with 0/1 entries.
        method (str): "metropolis" (default) or "uniform".

    Returns:
        np.ndarray: (n, n) float weight matrix.

    Raises:
        ValueError: If ``method`` is not a supported name (it would otherwise
            silently yield the identity matrix).
    """
    if method not in ("metropolis", "uniform"):
        raise ValueError(f"Unknown mixing method {method!r}; expected 'metropolis' or 'uniform'")

    n = Adj.shape[0]
    W = np.zeros((n, n))  # Initialize weight matrix
    # Degrees computed once rather than re-summing row j inside the inner loop.
    degrees = np.sum(Adj, axis=1)

    for i in range(n):
        degree_i = degrees[i]
        for j in range(n):
            if Adj[i, j] == 1.0:
                if method == "metropolis":
                    W[i, j] = 1 / (max(degree_i, degrees[j]) + 1)
                else:  # "uniform"
                    W[i, j] = 1 / degree_i

        # Diagonal weight absorbs whatever the row has not used yet
        W[i, i] = 1 - np.sum(W[i, :])

    return W

# Metropolis mixing matrix per modality, with sanity prints: column sums,
# row sums, then the sorted eigenvalues.
mixing_matrices = {}
for modality in available_modalities:
    W_mod = construct_mixing_matrix(adj_per_modality[modality], method="metropolis")
    mixing_matrices[modality] = W_mod
    print(W_mod.sum(axis=0))
    print(W_mod.sum(axis=1))
    lamb = np.sort(np.linalg.eigvals(W_mod))
    print(lamb)

In [None]:
# Spectrum of the pos_height mixing matrix restricted to the largest connected
# component of the pos_height graph.  Users without pos_height have no edges in
# that graph, so construct_mixing_matrix gives them W[i, i] = 1; dropping them
# removes those trivial eigenvalues of 1 from the printout.
G_modality = nx.from_numpy_array(adj_per_modality["pos_height"])
# NOTE(review): `pos` is never used in this cell; it is kept only because
# removing it may change the global RNG state seen by later cells.
pos = nx.spring_layout(G_modality)
largest_cc = max(nx.connected_components(G_modality), key=len)

# Convert to sorted list of indices
connected_nodes = sorted(largest_cc)

# Extract the submatrix corresponding to the connected subgraph
W_reduced = mixing_matrices["pos_height"][np.ix_(connected_nodes, connected_nodes)]
lamb = np.linalg.eigvals(W_reduced)
lamb.sort()  # complex values sort by real part, then imaginary part
print(lamb)

In [None]:
# Build one MultiModalNetwork (with only the branches for the modalities the
# user holds) and one Adam optimizer per user, plus a two-branch base model.
lr = 1e-2

input_size_audio = 4  # Example size, modify according to your data
input_size_visual = 224  # Example size, modify according to your data
hidden_size = 256  # Hidden size of the network
output_size = 64  # Number of output classes
z_dim = 100  # Dimensionality of the generator's input noise
optimizers = []
all_models = []

classifier_optimizers = []  # NOTE(review): not filled in this cell; confirm it is used later
for user_id in range(no_users):
    has_pos = "pos_height" in user_modalities[user_id]
    has_img = "images" in user_modalities[user_id]
    if not (has_pos or has_img):
        # Without this check the previous user's model would silently be
        # reused (or `user_model` would be undefined for user 0).
        raise ValueError(f"User {user_id} has no supported modality: {user_modalities[user_id]}")
    # MultiModalNetwork builds a branch iff its input size is not None.
    user_model = MultiModalNetwork(
        input_size_audio=input_size_audio if has_pos else None,
        input_size_visual=input_size_visual if has_img else None,
        hidden_size=hidden_size,
        output_size=output_size,
        z_dim=z_dim,
    ).to(device)
    local_optimizer = optim.Adam(user_model.parameters(), lr=lr)

    all_models.append(user_model)
    optimizers.append(local_optimizer)

# Reference model with both modality branches.
base_models = MultiModalNetwork(
    input_size_audio=input_size_audio,
    input_size_visual=input_size_visual,
    hidden_size=hidden_size,
    output_size=output_size,
    z_dim=z_dim
).to(device)

In [None]:
import torch

def per_modelaity_decentralized_aggregation(user_models, mixing_matrices, available_modalities, user_modalities):
    """Gossip-average each modality's sub-networks over the user graph, in place.

    For every modality in ``available_modalities`` and every user i holding
    it, the user's feature extractor, shared ``common_classifier`` and
    modality-specific classifier are replaced by
    ``sum_j W[i, j] * theta_j`` over the users j that hold the modality and
    have ``W[i, j] > 0``, where ``W = mixing_matrices[modality]`` and
    ``theta_j`` are the flattened parameters captured before any write-back.
    Users without the modality are left untouched.  "pos_height" maps to the
    ``audio_*`` sub-modules; any other modality maps to the ``visual_*`` ones.

    ``common_classifier`` is shared by all modalities, so a user holding
    several of them has it averaged once per modality, in sequence.  Runs
    under ``torch.no_grad()`` and returns None.
    """
    flatten = torch.nn.utils.parameters_to_vector
    unflatten = torch.nn.utils.vector_to_parameters

    def _parts(model, modality):
        # (feature extractor, common classifier, specific classifier)
        if modality == 'pos_height':
            return (model.audio_feature_extractor, model.common_classifier,
                    model.audio_specific_classifier)
        return (model.visual_feature_extractor, model.common_classifier,
                model.visual_specific_classifier)

    n_users = len(user_models)
    with torch.no_grad():
        for modality in available_modalities:
            W = mixing_matrices[modality]
            holders = [uid for uid in range(n_users) if modality in user_modalities[uid]]

            # Snapshot every holder's flattened parameters before writing anything back.
            snapshots = {
                uid: [flatten(part.parameters()) for part in _parts(user_models[uid], modality)]
                for uid in holders
            }

            for i in holders:
                mixed = [torch.zeros_like(vec) for vec in snapshots[i]]
                for j in range(n_users):
                    if W[i, j] > 0 and j in snapshots:
                        for acc, vec in zip(mixed, snapshots[j]):
                            acc += W[i, j] * vec
                for part, vec in zip(_parts(user_models[i], modality), mixed):
                    unflatten(vec, part.parameters())


In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

def train_local_model(local_modalities, model, data_loader, optimizer, num_epochs=10, alpha0=1.0, alpha1=1.0, alpha2=1.0, alpha3=1.0, alpha_gen=1.0, alpha_kd=1.0, device="cuda"):
    """
    Train a user's local multi-modal model on the modalities it owns.

    Feature extractors (``audio_feature_extractor`` / ``visual_feature_extractor``,
    when the model has them) are frozen by setting ``requires_grad = False`` on
    their parameters. This is not undone when the function returns.

    Args:
        local_modalities (iterable of str): Modality keys this user owns.
            ``"pos_height"`` feeds the model's audio branch and ``"images"``
            its visual branch. Any other key is moved to ``device`` but unused.
        model (MultiModalNetwork): Multi-modal classification model. Its
            forward pass must return ``(logits, _)``. Its ``compute_loss`` must
            return a 7-tuple whose first element is the total loss.
        data_loader (iterable): Yields ``(inputs, labels)`` batches, where
            ``inputs`` is a dict keyed by modality name.
        optimizer (torch.optim.Optimizer): Optimizer for the model parameters.
        num_epochs (int): Number of local training epochs.
        alpha0, alpha1, alpha2, alpha3, alpha_gen, alpha_kd (float): Loss
            weighting factors, forwarded unchanged to ``model.compute_loss``.
        device (str): Device to use ("cuda" or "cpu").

    Returns:
        tuple[list[float], list[float]]: ``(training_losses, training_accuracies)``,
        one entry per epoch. The loss is the mean total loss over batches
        (0.0 for an empty loader). Accuracy is measured with a forward pass
        taken after each optimizer step.
    """
    model.to(device)
    model.train()

    # Freeze feature extractors but keep classifiers trainable.
    if hasattr(model, "audio_feature_extractor"):
        for param in model.audio_feature_extractor.parameters():
            param.requires_grad = False
    if hasattr(model, "visual_feature_extractor"):
        for param in model.visual_feature_extractor.parameters():
            param.requires_grad = False

    # Noise width for the generator. MultiModalNetwork stores it as `z_dim`;
    # fall back to 100 (the historical hard-coded value) for other models.
    z_dim = getattr(model, "z_dim", 100)

    training_losses = []
    training_accuracies = []

    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        num_batches = 0

        for inputs, labels in data_loader:
            # Sample on the CPU and then move, so the CPU RNG stream drives z.
            z = torch.randn(labels.shape[0], z_dim).to(device)

            # Move each owned modality to the device exactly once.
            modality_inputs = {mod: inputs[mod].to(device) for mod in local_modalities}
            audio_input = modality_inputs.get("pos_height")
            visual_input = modality_inputs.get("images")
            labels = labels.to(device)

            optimizer.zero_grad()

            # NOTE(review): this forward pass's output is never used. It is kept
            # because compute_loss is not visible here and might rely on state
            # set during forward. In train mode it would also update BatchNorm
            # running statistics, if the model has any. Remove once confirmed
            # unnecessary.
            model(audio_input=audio_input, visual_input=visual_input, z=z)

            loss_outputs = model.compute_loss(audio_input, visual_input, labels, z, alpha0, alpha1, alpha2, alpha3, alpha_gen, alpha_kd)
            total_loss, classification_loss, similarity_loss, auxiliary_loss, difference_loss, generation_loss, kd_loss = loss_outputs

            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            num_batches += 1

            # Accuracy uses a gradient-free forward pass with the updated weights.
            with torch.no_grad():
                predictions, _ = model(audio_input, visual_input, z)
                _, predicted_labels = torch.max(predictions, dim=1)
                correct_predictions += (predicted_labels == labels).sum().item()
                total_samples += labels.size(0)

        # Guard both ratios so an empty loader yields 0.0 instead of raising.
        avg_loss = running_loss / num_batches if num_batches > 0 else 0.0
        accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0

        training_losses.append(avg_loss)
        training_accuracies.append(accuracy)

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

    return training_losses, training_accuracies


In [None]:
import torch

def validate_user_models(user_id, model, data_loader, local_modalities, alpha0=1.0, alpha1=1.0, alpha2=1.0, alpha3=1.0, alpha_gen=1.0, alpha_kd=1.0, device="cuda"):
    """
    Evaluate one user's multi-modal model on a validation loader.

    Runs with ``model.eval()`` under ``torch.no_grad()``. A fresh Gaussian
    noise tensor ``z`` is drawn for every batch, so results can vary between
    calls unless the RNG is seeded.

    Args:
        user_id (int): Zero-based user index (printed as ``user_id + 1``).
        model (MultiModalNetwork): Multi-modal classification model. Its
            forward pass must return ``(logits, _)``. Its ``compute_loss`` must
            return a 7-tuple whose first element is the total loss.
        data_loader (iterable): Yields ``(inputs, labels)`` batches, where
            ``inputs`` is a dict keyed by modality name.
        local_modalities (iterable of str): Modalities to use. ``"pos_height"``
            feeds the audio branch and ``"images"`` the visual branch.
        alpha0, alpha1, alpha2, alpha3, alpha_gen, alpha_kd (float): Loss
            weighting factors, forwarded unchanged to ``model.compute_loss``.
        device (str or torch.device): Device (CPU/GPU).

    Returns:
        dict: ``{"loss": ..., "accuracy": ...}``. The loss is the average of
        batch total losses weighted by batch size. The accuracy is the fraction
        of correct predictions. Both are 0.0 when the loader yields no samples.
    """
    model.to(device)
    model.eval()

    # Noise width for the generator. MultiModalNetwork stores it as `z_dim`;
    # fall back to 100 (the historical hard-coded value) for other models.
    z_dim = getattr(model, "z_dim", 100)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            z = torch.randn(labels.shape[0], z_dim).to(device)

            # Prepare input data for the selected modalities.
            modality_inputs = {mod: inputs[mod].to(device) for mod in local_modalities}
            audio_input = modality_inputs.get("pos_height")
            visual_input = modality_inputs.get("images")

            labels = labels.to(device)

            # Forward pass for predictions, then the full multi-term loss.
            outputs, _ = model(audio_input=audio_input, visual_input=visual_input, z=z)
            loss_outputs = model.compute_loss(audio_input, visual_input, labels, z, alpha0, alpha1, alpha2, alpha3, alpha_gen, alpha_kd)
            loss, classification_loss, similarity_loss, auxiliary_loss, difference_loss, generation_loss, kd_loss = loss_outputs

            # Weight each batch's loss by its size so the final mean is per-sample.
            total_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, dim=1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    print(f"User {user_id + 1} - Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

    return {"loss": avg_loss, "accuracy": accuracy}


In [None]:
def cross_entropy_loss_with_l2(model, logits, targets, l2_strength=1e-3):
    """Return cross-entropy of ``logits`` vs ``targets`` plus an L2 penalty.

    The data term is ``nn.CrossEntropyLoss()`` with its default settings.
    The penalty is ``l2_strength`` times the sum of squared entries of every
    tensor yielded by ``model.parameters()``. That includes biases and
    parameters whose ``requires_grad`` is False.
    """
    data_term = nn.CrossEntropyLoss()(logits, targets)

    # Accumulate sequentially (same order as a plain sum over parameters).
    squared_norm = 0
    for weight in model.parameters():
        squared_norm = squared_norm + weight.pow(2).sum()

    return data_term + l2_strength * squared_norm

In [None]:
import numpy as np
import torch
import torch.nn as nn

# ---- Hyperparameters -------------------------------------------------------
# NOTE(review): lambda_reg, eta and beta are not referenced in this cell;
# confirm whether another cell reads them.
lambda_reg = 0.01
eta = 0.001
beta = 1.0

# Loss weights. `alpha` is the base scale; most individual weights derive from it.
alpha = 1e-1
alpha0 = 1e-2
alpha1 = alpha
alpha2 = alpha * 1e-3
alpha3 = alpha
alpha_gen = alpha
alpha_kd = alpha

# Schedule: communication rounds and local epochs per round (both epoch
# names currently hold the same value).
num_epochs = 1
global_rounds = 100
local_epochs = 1

# Per-group metric histories for groups 1, 2 and 3, each with its own list.
group_train_loss_histories = {g: [] for g in (1, 2, 3)}
group_train_accuracy_histories = {g: [] for g in (1, 2, 3)}
group_val_loss_histories = {g: [] for g in (1, 2, 3)}
group_val_accuracy_histories = {g: [] for g in (1, 2, 3)}

# Decentralized Federated Learning Loop.
# Each round: aggregate parameters across users, train every user locally,
# validate every user, then record per-group metrics (one scalar per user).
for round_num in range(global_rounds):
    print(f"Global Round {round_num + 1}")

    # Decentralized aggregation step (applied before this round's local training).
    per_modelaity_decentralized_aggregation(all_models, mixing_matrices, available_modalities, user_modalities)

    # Temporary storage for this round
    epoch_group_train_losses = {1: [], 2: [], 3: []}
    epoch_group_train_accuracies = {1: [], 2: [], 3: []}
    epoch_group_val_losses = {1: [], 2: [], 3: []}
    epoch_group_val_accuracies = {1: [], 2: [], 3: []}

    # Training phase
    for user_id in range(no_users):
        print(f"Training model for User {user_id + 1}")
        user_models = all_models[user_id]
        group = user_groups[user_id]

        # Train local model for the user's available modalities
        train_loss, train_accuracy = train_local_model(user_modalities[user_id], user_models, train_loaders[user_id], optimizers[user_id], num_epochs, alpha0, alpha1, alpha2, alpha3, alpha_gen, alpha_kd, device)

        # train_local_model returns per-epoch lists. Store only the final
        # epoch's value so each round holds one scalar per user; the histories
        # then stay 2-D (rounds x users), as the mean/std summaries expect.
        epoch_group_train_losses[group].append(train_loss[-1] if train_loss else float("nan"))
        epoch_group_train_accuracies[group].append(train_accuracy[-1] if train_accuracy else float("nan"))

    # Validation phase
    for user_id in range(no_users):
        user_models = all_models[user_id]
        val_dict = validate_user_models(
            user_id, user_models,
            val_loaders[user_id],
            user_modalities[user_id],
            alpha0, alpha1, alpha2, alpha3, alpha_gen,
            alpha_kd, device)
        group = user_groups[user_id]
        epoch_group_val_losses[group].append(val_dict["loss"])
        epoch_group_val_accuracies[group].append(val_dict["accuracy"])

    # Store final metrics for each group
    for group in [1, 2, 3]:
        group_train_loss_histories[group].append(epoch_group_train_losses[group])
        group_train_accuracy_histories[group].append(epoch_group_train_accuracies[group])
        group_val_loss_histories[group].append(epoch_group_val_losses[group])
        group_val_accuracy_histories[group].append(epoch_group_val_accuracies[group])

    # Print final results for this round
    print(f"---- Global Round {round_num + 1} Metrics ----")
    for group in [1, 2, 3]:
        print(f"  Group {group} - Train Loss: {np.mean(group_train_loss_histories[group][-1]):.4f}, Train Accuracy: {np.mean(group_train_accuracy_histories[group][-1]):.4f}")
        print(f"  Group {group} - Val Loss: {np.mean(group_val_loss_histories[group][-1]):.4f}, Val Accuracy: {np.mean(group_val_accuracy_histories[group][-1]):.4f}")

        

In [None]:
# Convert metrics to numpy arrays for easy manipulation
group_train_loss_histories = {k: np.array(v) for k, v in group_train_loss_histories.items()}
group_train_accuracy_histories = {k: np.array(v) for k, v in group_train_accuracy_histories.items()}
group_val_loss_histories = {k: np.array(v) for k, v in group_val_loss_histories.items()}
group_val_accuracy_histories = {k: np.array(v) for k, v in group_val_accuracy_histories.items()}

# Mean/std across users (axis 1) for each recorded round; 1-D arrays pass through.
group_train_loss_mean = {k: v.mean(axis=1) if v.ndim > 1 else v for k, v in group_train_loss_histories.items()}
group_train_loss_std = {k: v.std(axis=1) if v.ndim > 1 else np.zeros_like(v) for k, v in group_train_loss_histories.items()}
group_val_loss_mean = {k: v.mean(axis=1) if v.ndim > 1 else v for k, v in group_val_loss_histories.items()}
group_val_loss_std = {k: v.std(axis=1) if v.ndim > 1 else np.zeros_like(v) for k, v in group_val_loss_histories.items()}

group_train_acc_mean = {k: v.mean(axis=1) if v.ndim > 1 else v for k, v in group_train_accuracy_histories.items()}
group_train_acc_std = {k: v.std(axis=1) if v.ndim > 1 else np.zeros_like(v) for k, v in group_train_accuracy_histories.items()}
group_val_acc_mean = {k: v.mean(axis=1) if v.ndim > 1 else v for k, v in group_val_accuracy_histories.items()}
group_val_acc_std = {k: v.std(axis=1) if v.ndim > 1 else np.zeros_like(v) for k, v in group_val_accuracy_histories.items()}

# Number of recorded rounds, derived from the data rather than hard-coded.
# A count that differs from the series length makes plt.plot /
# plt.fill_between raise a shape-mismatch error.
num_epochs = len(group_val_loss_mean[1])
rounds_axis = range(1, num_epochs + 1)


def _plot_mean_band(x, mean, std, label, **line_kwargs):
    """Plot ``mean`` over ``x`` with a translucent ``mean ± std`` band."""
    plt.plot(x, mean, label=label, **line_kwargs)
    plt.fill_between(x, mean - std, mean + std, alpha=0.2)


# Combined Plot for All Modalities
plt.figure(figsize=(12, 6))

# Loss Plot
plt.subplot(1, 2, 1)
for group in [1, 2, 3]:
    _plot_mean_band(rounds_axis, group_train_loss_mean[group], group_train_loss_std[group], f"Group {group} Train Loss")
    _plot_mean_band(rounds_axis, group_val_loss_mean[group], group_val_loss_std[group], f"Group {group} Validation Loss", linestyle="dashed")

plt.title("Loss over Epochs for All Groups")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()

# Accuracy Plot
plt.subplot(1, 2, 2)
for group in [1, 2, 3]:
    _plot_mean_band(rounds_axis, group_train_acc_mean[group], group_train_acc_std[group], f"Group {group} Train Accuracy")
    _plot_mean_band(rounds_axis, group_val_acc_mean[group], group_val_acc_std[group], f"Group {group} Validation Accuracy", linestyle="dashed")

plt.title("Accuracy over Epochs for All Groups")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.ylim([0, 1])  # Ensure y-axis is between 0 and 1 for accuracy
plt.legend()
plt.grid()

plt.tight_layout()
plt.show()


In [None]:
import json

# Serialise every per-group summary curve. `.tolist()` turns each numpy array
# into plain Python numbers so `json` can encode it. Note that json writes the
# integer group ids as string keys ("1", "2", "3").
data_to_save = {
    metric_name: {group: series.tolist() for group, series in per_group.items()}
    for metric_name, per_group in (
        ("group_train_loss_mean", group_train_loss_mean),
        ("group_train_loss_std", group_train_loss_std),
        ("group_val_loss_mean", group_val_loss_mean),
        ("group_val_loss_std", group_val_loss_std),
        ("group_train_acc_mean", group_train_acc_mean),
        ("group_train_acc_std", group_train_acc_std),
        ("group_val_acc_mean", group_val_acc_mean),
        ("group_val_acc_std", group_val_acc_std),
    )
}

with open("KL_metrics_IID.json", "w") as f:
    json.dump(data_to_save, f)
print("KL_metrics_IID.json")