In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Multi-Modal Learning Model
class MultiModalNetwork(nn.Module):
    def __init__(self, input_size_audio=None, input_size_visual=None, hidden_size=128, output_size=6):
        super(MultiModalNetwork, self).__init__()
        
        self.has_audio = input_size_audio is not None
        self.has_visual = input_size_visual is not None
        
        if self.has_audio:
            self.audio_feature_extractor = nn.Linear(input_size_audio, hidden_size)
            self.specific_audio_classifier = nn.Linear(hidden_size, output_size)
        
        if self.has_visual:
            self.visual_feature_extractor = nn.Linear(input_size_visual, hidden_size)
            self.specific_visual_classifier = nn.Linear(hidden_size, output_size)
        
        if self.has_audio and self.has_visual:
            self.common_classifier = nn.Linear(hidden_size, output_size)
        
    def forward(self, audio_input=None, visual_input=None):
        audio_features = self.audio_feature_extractor(audio_input) if self.has_audio and audio_input is not None else None
        visual_features = self.visual_feature_extractor(visual_input) if self.has_visual and visual_input is not None else None

        common_features = None
        if self.has_audio and self.has_visual and audio_features is not None and visual_features is not None:
            common_features = (audio_features + visual_features) / 2
            specific_audio_features = audio_features - common_features
            specific_visual_features = visual_features - common_features
        else:
            specific_audio_features = audio_features
            specific_visual_features = visual_features
        
        common_pred = self.common_classifier(common_features) if common_features is not None else None
        audio_pred = self.specific_audio_classifier(specific_audio_features) if self.has_audio and specific_audio_features is not None else None
        visual_pred = self.specific_visual_classifier(specific_visual_features) if self.has_visual and specific_visual_features is not None else None
        
        return common_pred, audio_pred, visual_pred, audio_features, visual_features, common_features, specific_audio_features, specific_visual_features

def similarity_loss(common_features, common_classifier, mask):
    loss = 0.0
    count = 0.0
    for modality_set in mask:
        if len(modality_set) < 2:
            continue
        modality_list = list(modality_set)
        print(modality_list)
        Dk = common_features[modality_list[0]].shape[0]
        all_pairs = [(m1, m2) for i, m1 in enumerate(modality_list) for m2 in modality_list[i + 1:]]
        count += len(all_pairs) * Dk

        for m1, m2 in all_pairs:
            for d in range(Dk):
                h_m1 = common_features[m1][d].unsqueeze(0)
                h_m2 = common_features[m2][d].unsqueeze(0)
                z1 = F.softmax(common_classifier(h_m1), dim=1)
                z2 = F.softmax(common_classifier(h_m2), dim=1)
                loss += F.kl_div(z1.log(), z2, reduction='sum')

    if count > 0:
        loss = loss / count
    else:
        loss = torch.tensor(0.0, device=loss.device)

    return loss

def auxiliary_classification_loss(common_features, common_classifier, labels, mask):
    loss = 0.0
    count = 0.0
    for modality_set in mask:
        modality_list = list(modality_set)
        Dk = common_features[modality_list[0]].shape[0]
        count += len(modality_list) * Dk

        for m in modality_list:
            for d in range(Dk):
                h_m = common_features[m][d].unsqueeze(0)
                y_m = labels[m][d].unsqueeze(0)
                pred = F.softmax(common_classifier(h_m), dim=1)
                loss += F.cross_entropy(pred, y_m)

    if count > 0:
        loss = loss / count
    else:
        loss = torch.tensor(0.0, device=loss.device)

    return loss

def difference_loss(common_features, specific_features, mask):
    loss = 0.0
    for modality_set in mask:
        for m in modality_set:
            if common_features[m] is not None and specific_features[m] is not None:
                loss += torch.norm(torch.matmul(common_features[m].T, specific_features[m]), p='fro')**2
    return loss

def feature_decomposition_loss(common_features, specific_features, common_classifier, labels, mask, alpha1, alpha2, alpha3):
    sim_loss = similarity_loss(common_features, common_classifier, mask)
    cls_loss = auxiliary_classification_loss(common_features, common_classifier, labels, mask)
    diff_loss = difference_loss(common_features, specific_features, mask)
    total_loss = alpha1 * sim_loss + alpha2 * cls_loss + alpha3 * diff_loss
    return total_loss

# Local Training Function
def train_local(model, devices, num_epochs=5, alpha1=1.0, alpha2=1.0, alpha3=1.0):
    for device_id, device_data in devices.items():
        local_model = model
        optimizer = optim.Adam(local_model.parameters(), lr=0.001)
        
        for epoch in range(num_epochs):
            inputs_audio, inputs_visual, labels = device_data['data']
            optimizer.zero_grad()
            
            common_pred, audio_pred, visual_pred, audio_features, visual_features, common_features, specific_audio_features, specific_visual_features = local_model(inputs_audio, inputs_visual)
            print("Shape of common_pred:", common_pred.shape)
            print("Shape of audio_pred:", audio_pred.shape)
            print("Shape of visual_pred:", visual_pred.shape)
            print("Shape of audio_features:", audio_features.shape)
            print("Shape of visual_features:", visual_features.shape)
            print("Shape of common_features:", common_features.shape)
            print("Shape of specific_audio_features:", specific_audio_features.shape)
            print("Shape of specific_visual_features:", specific_visual_features.shape)
            print()
            loss = 0
            mask = []
            if audio_features is not None:
                if visual_features is not None:
                    mask.append({'audio', 'visual'})
                else:
                    mask.append({'audio'})
            elif visual_features is not None:
                mask.append({'visual'})

            features = {"audio": audio_features, "visual": visual_features}
            specifics = {"audio": specific_audio_features, "visual": specific_visual_features}
            labels_dict = {"audio": labels['audio'], "visual": labels['visual']}

            # Compute loss
            feature_loss = feature_decomposition_loss(common_features, specifics, local_model.common_classifier, labels_dict, mask, alpha1, alpha2, alpha3)
            total_loss = feature_loss
            total_loss.backward()
            
            optimizer.step()
            
        # After training on the device, print final model loss (optional)
        print(f"Device {device_id} training completed")




In [45]:
# Initialize Model Parameters
input_size_audio = 256  # Example size, modify according to your data
input_size_visual = 256  # Example size, modify according to your data
hidden_size = 128  # Hidden size of the network
output_size = 6  # Number of output classes

# Initialize Models
model = MultiModalNetwork(input_size_audio=input_size_audio, input_size_visual=input_size_visual, hidden_size=hidden_size, output_size=output_size)

# Setup Devices (simulating multiple devices in a decentralized setup)
# This should be a dictionary where each device has its local data and epochs.
# Example setup:
devices = {
    0: {'data': (torch.randn(100, input_size_audio), torch.randn(100, input_size_visual), {'audio': torch.randint(0, 6, (100,)), 'visual': torch.randint(0, 6, (100,))}), 'local_epochs': 5},
    1: {'data': (torch.randn(100, input_size_audio), torch.randn(100, input_size_visual), {'audio': torch.randint(0, 6, (100,)), 'visual': torch.randint(0, 6, (100,))}), 'local_epochs': 5},
    2: {'data': (torch.randn(100, input_size_audio), torch.randn(100, input_size_visual), {'audio': torch.randint(0, 6, (100,)), 'visual': torch.randint(0, 6, (100,))}), 'local_epochs': 5},
}


# Training Hyperparameters
num_epochs = 5  # Number of local training epochs per device

# Train the model locally on each device
train_local(model=model, devices=devices, num_epochs=num_epochs)

# The model is now trained locally on each device without aggregation or knowledge distillation

Shape of common_pred: torch.Size([100, 6])
Shape of audio_pred: torch.Size([100, 6])
Shape of visual_pred: torch.Size([100, 6])
Shape of audio_features: torch.Size([100, 128])
Shape of visual_features: torch.Size([100, 128])
Shape of common_features: torch.Size([100, 128])
Shape of specific_audio_features: torch.Size([100, 128])
Shape of specific_visual_features: torch.Size([100, 128])

['visual', 'audio']


IndexError: too many indices for tensor of dimension 2

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiModalNetwork(nn.Module):
    def __init__(self, input_size_audio=None, input_size_visual=None, hidden_size=128, output_size=6):
        super(MultiModalNetwork, self).__init__()
        
        self.has_audio = input_size_audio is not None
        self.has_visual = input_size_visual is not None
        
        if self.has_audio:
            # Audio feature extractor
            self.audio_feature_extractor = nn.Linear(input_size_audio, hidden_size)
            # Split audio output into common and specific features
            self.audio_common_classifier = nn.Linear(hidden_size // 2, output_size)  # Common classifier
            self.audio_specific_classifier = nn.Linear(hidden_size // 2, output_size)  # Specific classifier
        
        if self.has_visual:
            # Visual feature extractor
            self.visual_feature_extractor = nn.Linear(input_size_visual, hidden_size)
            # Split visual output into common and specific features
            self.visual_common_classifier = nn.Linear(hidden_size // 2, output_size)  # Common classifier
            self.visual_specific_classifier = nn.Linear(hidden_size // 2, output_size)  # Specific classifier
        
        # Common classifier shared by both modalities
        self.common_classifier = nn.Linear(hidden_size // 2, output_size)  # Same size for common features
    
    def forward(self, audio_input=None, visual_input=None):
        audio_features = self.audio_feature_extractor(audio_input) if self.has_audio and audio_input is not None else None
        visual_features = self.visual_feature_extractor(visual_input) if self.has_visual and visual_input is not None else None

        common_audio_features = None
        common_visual_features = None
        specific_audio_features = None
        specific_visual_features = None
        
        if self.has_audio and audio_features is not None:
            # Split audio features into common and specific parts
            common_audio_features = audio_features[:, :audio_features.size(1) // 2]
            specific_audio_features = audio_features[:, audio_features.size(1) // 2:]
            audio_common_pred = self.audio_common_classifier(common_audio_features)
            audio_specific_pred = self.audio_specific_classifier(specific_audio_features)
        
        if self.has_visual and visual_features is not None:
            # Split visual features into common and specific parts
            common_visual_features = visual_features[:, :visual_features.size(1) // 2]
            specific_visual_features = visual_features[:, visual_features.size(1) // 2:]
            visual_common_pred = self.visual_common_classifier(common_visual_features)
            visual_specific_pred = self.visual_specific_classifier(specific_visual_features)

        # Combine the common and specific predictions with a bias term
        final_audio_pred = audio_common_pred + audio_specific_pred if audio_features is not None else None
        final_visual_pred = visual_common_pred + visual_specific_pred if visual_features is not None else None
        
        # If both modalities are present, combine the final predictions
        combined_pred = None
        if final_audio_pred is not None and final_visual_pred is not None:
            combined_pred = final_audio_pred + final_visual_pred
        elif final_audio_pred is not None:
            combined_pred = final_audio_pred
        elif final_visual_pred is not None:
            combined_pred = final_visual_pred
        
        return combined_pred, audio_common_pred, audio_specific_pred, visual_common_pred, visual_specific_pred

# Initialize Model Parameters
input_size_audio = 256  # Example size, modify according to your data
input_size_visual = 256  # Example size, modify according to your data
hidden_size = 128  # Hidden size of the network
output_size = 6  # Number of output classes

# Initialize the MultiModalNetwork model
model = MultiModalNetwork(input_size_audio=input_size_audio, input_size_visual=input_size_visual, hidden_size=hidden_size, output_size=output_size)

# Example input data
audio_input = torch.randn(10, input_size_audio)  # Batch of 10 audio samples
visual_input = torch.randn(10, input_size_visual)  # Batch of 10 visual samples

# Forward pass through the model
combined_pred, audio_common_pred, audio_specific_pred, visual_common_pred, visual_specific_pred = model(audio_input, visual_input)

print("Combined Prediction: ", combined_pred)
print("Audio Common Prediction: ", audio_common_pred)
print("Audio Specific Prediction: ", audio_specific_pred)
print("Visual Common Prediction: ", visual_common_pred)
print("Visual Specific Prediction: ", visual_specific_pred)

print("Shape of Combined Prediction:", combined_pred.shape)
print("Shape of Audio Common Prediction:", audio_common_pred.shape)
print("Shape of Audio Specific Prediction:", audio_specific_pred.shape)
print("Shape of Visual Common Prediction:", visual_common_pred.shape)
print("Shape of Visual Specific Prediction:", visual_specific_pred.shape)


Combined Prediction:  tensor([[-0.3139,  0.8427, -0.2166, -0.2099,  0.0227, -0.2840],
        [ 0.2369,  0.7433, -0.9447,  0.8359, -0.3473,  0.7672],
        [-0.5883,  0.7150, -0.7838,  0.4420,  0.7502,  0.1397],
        [-0.2176,  0.3576, -1.3419,  0.4045, -0.7248, -0.1803],
        [ 0.9697, -0.0764,  1.4621, -1.8019, -0.8623, -0.1169],
        [ 0.2682,  1.1872,  0.6554,  0.1557,  1.2513,  0.4830],
        [-0.0960, -0.5458, -0.7141, -0.6288, -1.0434,  0.5657],
        [ 0.5358,  0.6907, -0.9732, -0.7303, -1.0202, -0.6421],
        [-0.3643,  0.6754, -0.4350, -0.2614,  0.7190, -0.8152],
        [ 0.3813,  0.5324,  1.5932, -0.2552, -0.5837, -0.2788]],
       grad_fn=<AddBackward0>)
Audio Common Prediction:  tensor([[-3.9049e-01,  3.7859e-01, -1.0550e-01,  4.4970e-02,  5.3332e-02,
         -1.0326e-01],
        [ 2.9309e-01, -8.4598e-02, -6.6774e-01,  1.5795e-01,  1.8329e-01,
         -8.3928e-02],
        [-1.5160e-01, -1.9393e-01, -7.3705e-01, -1.5110e-02, -9.0311e-02,
         -8.

In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiModalNetwork(nn.Module):
    def __init__(self, input_size_audio=None, input_size_visual=None, hidden_size=128, output_size=6):
        super(MultiModalNetwork, self).__init__()
        
        self.has_audio = input_size_audio is not None
        self.has_visual = input_size_visual is not None
        
        if self.has_audio:
            # Audio feature extractor
            self.audio_feature_extractor = nn.Linear(input_size_audio, hidden_size)
            # Split audio output into common and specific features
            self.audio_common_classifier = nn.Linear(hidden_size // 2, output_size)  # Common classifier
            self.audio_specific_classifier = nn.Linear(hidden_size // 2, output_size)  # Specific classifier
        
        if self.has_visual:
            # Visual feature extractor
            self.visual_feature_extractor = nn.Linear(input_size_visual, hidden_size)
            # Split visual output into common and specific features
            self.visual_common_classifier = nn.Linear(hidden_size // 2, output_size)  # Common classifier
            self.visual_specific_classifier = nn.Linear(hidden_size // 2, output_size)  # Specific classifier
        
        # Common classifier shared by both modalities
        self.common_classifier = nn.Linear(hidden_size // 2, output_size)  # Same size for common features

    def forward(self, audio_input=None, visual_input=None):
        audio_features = self.audio_feature_extractor(audio_input) if self.has_audio and audio_input is not None else None
        visual_features = self.visual_feature_extractor(visual_input) if self.has_visual and visual_input is not None else None

        common_audio_features = None
        common_visual_features = None
        specific_audio_features = None
        specific_visual_features = None
        
        if self.has_audio and audio_features is not None:
            # Split audio features into common and specific parts
            common_audio_features = audio_features[:, :audio_features.size(1) // 2]
            specific_audio_features = audio_features[:, audio_features.size(1) // 2:]
            audio_common_pred = self.audio_common_classifier(common_audio_features)
            audio_specific_pred = self.audio_specific_classifier(specific_audio_features)
        else:
            audio_common_pred = None
            audio_specific_pred = None
        
        if self.has_visual and visual_features is not None:
            # Split visual features into common and specific parts
            common_visual_features = visual_features[:, :visual_features.size(1) // 2]
            specific_visual_features = visual_features[:, visual_features.size(1) // 2:]
            visual_common_pred = self.visual_common_classifier(common_visual_features)
            visual_specific_pred = self.visual_specific_classifier(specific_visual_features)
        else:
            visual_common_pred = None
            visual_specific_pred = None

        # Combine the common and specific predictions with a bias term
        final_audio_pred = audio_common_pred + audio_specific_pred if audio_common_pred is not None and audio_specific_pred is not None else None
        final_visual_pred = visual_common_pred + visual_specific_pred if visual_common_pred is not None and visual_specific_pred is not None else None
        
        # If both modalities are present, combine the final predictions
        combined_pred = None
        if final_audio_pred is not None and final_visual_pred is not None:
            combined_pred = final_audio_pred + final_visual_pred
        elif final_audio_pred is not None:
            combined_pred = final_audio_pred
        elif final_visual_pred is not None:
            combined_pred = final_visual_pred
        
        return combined_pred, audio_common_pred, audio_specific_pred, visual_common_pred, visual_specific_pred, common_audio_features, common_visual_features, specific_audio_features, specific_visual_features

    def compute_loss(self, audio_input, visual_input, labels, alpha1=1.0, alpha2=1.0, alpha3=1.0):
        # Forward pass
        combined_pred, audio_common_pred, audio_specific_pred, visual_common_pred, visual_specific_pred, common_audio_features, common_visual_features, specific_audio_features, specific_visual_features = self(audio_input, visual_input)

        # Initialize losses to zero
        similarity_loss = 0.0
        auxiliary_loss = 0.0
        difference_loss = 0.0

        # 1) Similarity Loss (F_sim_k)
        if common_audio_features is not None and common_visual_features is not None:
            kl_loss_audio = self.compute_kl_divergence(common_audio_features, common_visual_features)
            similarity_loss = kl_loss_audio / 2  # Normalized by the number of modalities

        # 2) Auxiliary Classification Loss (F_cls_k)
        if common_audio_features is not None:
            auxiliary_loss += self.compute_auxiliary_classification_loss(common_audio_features, labels)
        if common_visual_features is not None:
            auxiliary_loss += self.compute_auxiliary_classification_loss(common_visual_features, labels)

        # 3) Difference Loss (F_dif_k)
        if common_audio_features is not None and specific_audio_features is not None:
            difference_loss += self.compute_difference_loss(common_audio_features, specific_audio_features)
        if common_visual_features is not None and specific_visual_features is not None:
            difference_loss += self.compute_difference_loss(common_visual_features, specific_visual_features)

        # 4) Total Loss (F_dec_k)
        total_loss = alpha1 * similarity_loss + alpha2 * auxiliary_loss + alpha3 * difference_loss
        return total_loss, similarity_loss, auxiliary_loss, difference_loss

    def compute_kl_divergence(self, common_audio_features, common_visual_features):
        # Apply softmax to features and compute KL divergence
        softmax_audio = F.softmax(common_audio_features, dim=-1)
        softmax_visual = F.softmax(common_visual_features, dim=-1)
        kl_divergence = F.kl_div(softmax_audio.log(), softmax_visual, reduction='batchmean')
        return kl_divergence

    def compute_auxiliary_classification_loss(self, common_features, labels):
        # Cross-entropy loss for auxiliary classification
        return F.cross_entropy(self.common_classifier(common_features), labels)

    def compute_difference_loss(self, common_features, specific_features):
        # Orthogonality loss to ensure modality-common and modality-specific features are distinct
        return torch.norm(torch.matmul(common_features.T, specific_features), p='fro')**2


# Initialize Model Parameters
input_size_audio = 256  # Example size, modify according to your data
input_size_visual = 256  # Example size, modify according to your data
hidden_size = 128  # Hidden size of the network
output_size = 6  # Number of output classes

# Initialize the MultiModalNetwork model
model = MultiModalNetwork(input_size_audio=input_size_audio, input_size_visual=input_size_visual, hidden_size=hidden_size, output_size=output_size)

# Example input data
audio_input = torch.randn(10, input_size_audio)  # Batch of 10 audio samples
visual_input = torch.randn(10, input_size_visual)  # Batch of 10 visual samples
labels = torch.randint(0, output_size, (10,))  # Random labels for the batch

# Compute the total loss with both modalities present
total_loss, similarity_loss, auxiliary_loss, difference_loss = model.compute_loss(audio_input, visual_input, labels)

print("Total Loss:", total_loss)
print("Similarity Loss:", similarity_loss)
print("Auxiliary Classification Loss:", auxiliary_loss)
print("Difference Loss:", difference_loss)

# Compute the total loss with only the audio modality
total_loss_audio, similarity_loss_audio, auxiliary_loss_audio, difference_loss_audio = model.compute_loss(audio_input, None, labels)

print("\nWhen only audio is present:")
print("Total Loss:", total_loss_audio)
print("Similarity Loss:", similarity_loss_audio)
print("Auxiliary Classification Loss:", auxiliary_loss_audio)
print("Difference Loss:", difference_loss_audio)

# Compute the total loss with only the visual modality
total_loss_visual, similarity_loss_visual, auxiliary_loss_visual, difference_loss_visual = model.compute_loss(None, visual_input, labels)

print("\nWhen only visual is present:")
print("Total Loss:", total_loss_visual)
print("Similarity Loss:", similarity_loss_visual)
print("Auxiliary Classification Loss:", auxiliary_loss_visual)
print("Difference Loss:", difference_loss_visual)


Total Loss: tensor(9519.3359, grad_fn=<AddBackward0>)
Similarity Loss: tensor(0.1491, grad_fn=<DivBackward0>)
Auxiliary Classification Loss: tensor(3.4934, grad_fn=<AddBackward0>)
Difference Loss: tensor(9515.6934, grad_fn=<AddBackward0>)

When only audio is present:
Total Loss: tensor(4791.9634, grad_fn=<AddBackward0>)
Similarity Loss: 0.0
Auxiliary Classification Loss: tensor(1.7387, grad_fn=<AddBackward0>)
Difference Loss: tensor(4790.2246, grad_fn=<AddBackward0>)

When only visual is present:
Total Loss: tensor(4727.2241, grad_fn=<AddBackward0>)
Similarity Loss: 0.0
Auxiliary Classification Loss: tensor(1.7547, grad_fn=<AddBackward0>)
Difference Loss: tensor(4725.4692, grad_fn=<AddBackward0>)
