In [3]:
import os
import cv2
import numpy as np
import copy
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [4]:
class MicroExpressionDataset(Dataset):
    """Video-clip dataset laid out as ``<directory>/<category>/<clip>.avi``.

    Every item is a ``(frames, label)`` pair: ``frames`` holds all decoded RGB
    frames of one clip and ``label`` is the integer code of the category
    folder the clip was found in.

    Args:
        directory: Root folder whose sub-folders are named after categories.
        transform: Optional callable applied to each frame individually
            (each frame is an RGB array as produced by ``extract_frames``).
        label_encoding: Optional mapping from category folder name to integer
            label. Defaults to ``{'anger': 0, 'disgust': 1, 'happiness': 2}``.
    """

    def __init__(self, directory, transform=None, label_encoding=None):
        self.directory = directory
        self.transform = transform
        self.samples = []
        if label_encoding is None:
            label_encoding = {'anger': 0, 'disgust': 1, 'happiness': 2}  # Update as needed
        # Copy so later edits to the caller's mapping cannot change labels.
        self.label_encoding = dict(label_encoding)

        for category in sorted(os.listdir(directory)):
            class_dir = os.path.join(directory, category)
            if not os.path.isdir(class_dir):
                continue
            # Sort the clips as well: os.listdir order is arbitrary, and an
            # unstable order makes index -> clip mapping differ between runs.
            for file in sorted(os.listdir(class_dir)):
                # Case-insensitive so clips saved as ".AVI" are not dropped.
                if file.lower().endswith('.avi'):
                    self.samples.append((os.path.join(class_dir, file), category))

    def __len__(self):
        """Return the number of clips discovered under ``directory``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Decode clip ``idx`` and return ``(frames, label)``.

        Without a transform, ``frames`` is the raw array from
        ``extract_frames``. With a transform, the per-frame results are
        stacked along a new leading (frame) axis, so a transform that yields
        channel-first frames produces a (frames, channels, H, W) array.

        Raises:
            KeyError: If the clip's category is missing from ``label_encoding``.
            ValueError: If no frame could be decoded from the clip.
        """
        file_path, category = self.samples[idx]
        # Resolve the label first so an unknown category fails before the
        # (much more expensive) video decode.
        label = self.encode_label(category)

        frames = self.extract_frames(file_path)
        if len(frames) == 0:
            # Name the offending file instead of failing later inside np.stack.
            raise ValueError(f"No frames could be decoded from video: {file_path}")

        if self.transform:
            # Apply transformations to each frame
            frames = np.stack([self.transform(frame) for frame in frames])

        return frames, label

    def extract_frames(self, video_path):
        """Read every frame of ``video_path`` and return them as one RGB array.

        Returns an empty array when the file cannot be opened or has no
        readable frames.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert BGR to RGB
        finally:
            # Always free the capture handle, even if decoding raised.
            cap.release()
        return np.array(frames)

    def encode_label(self, category):
        """Map a category folder name to its integer label.

        Raises:
            KeyError: If ``category`` is not a key of ``label_encoding``.
        """
        try:
            return self.label_encoding[category]
        except KeyError:
            raise KeyError(
                f"Unknown category {category!r}; known categories: {list(self.label_encoding)}"
            ) from None

# Per-frame preprocessing: array -> PIL image -> 224x224 -> tensor -> normalized
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# One dataset per split folder ('train' first, then 'test'), sharing the same
# per-frame pipeline
train_dataset, test_dataset = (
    MicroExpressionDataset(directory=split_dir, transform=transform)
    for split_dir in ('train', 'test')
)

# Same batch size for both splits; only the training split is shuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)


In [5]:
# Define the 3D CNN model architecture
class MicroExpressionCNN(nn.Module):
    """Small 3D CNN that classifies a video clip.

    Expects input of shape (batch, 3, frames, height, width), the layout
    ``nn.Conv3d`` requires. Two conv + 2x max-pool stages are followed by an
    adaptive average pool that always yields a 4x4x4 feature volume, so the
    fully connected head works for any clip length / resolution that survives
    the two pools (at least 4 frames and 4x4 pixels).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super(MicroExpressionCNN, self).__init__()
        self.conv1 = nn.Conv3d(in_channels=3, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)
        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=1, padding=1)
        # Add additional layers as needed...
        # Parameter-free, so state-dict keys are unchanged. It pins the feature
        # volume to 4x4x4 regardless of input size (and is the identity when
        # the volume already is 4x4x4), which is what fc1 is sized for.
        self.adaptive_pool = nn.AdaptiveAvgPool3d((4, 4, 4))
        self.fc1 = nn.Linear(128 * 4 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Add additional layers as needed...
        x = self.adaptive_pool(x)
        # Flatten everything except the batch axis. The previous
        # view(-1, 128*4*4*4) silently folded surplus features into the batch
        # dimension whenever the input was not exactly 16x16x16.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [6]:
# Define the FedAvg algorithm function
def federated_averaging(global_model, client_models, client_weights=None):
    """Overwrite ``global_model`` with the (weighted) mean of the client models.

    Args:
        global_model: Model whose state dict defines the keys, dtypes and
            devices of the result; it is updated in place.
        client_models: Non-empty iterable of models sharing the global model's
            state-dict keys.
        client_weights: Optional per-client non-negative weights (for example
            local sample counts). They are normalised to sum to 1. When
            omitted, every client contributes equally.

    Returns:
        ``global_model`` with the averaged state loaded.

    Raises:
        ValueError: If ``client_models`` is empty, or ``client_weights`` has
            the wrong length, a negative entry, or a non-positive sum.
    """
    client_models = list(client_models)
    if not client_models:
        # Averaging nothing used to load an all-zero state into the global
        # model; refuse instead of silently wiping it.
        raise ValueError("federated_averaging requires at least one client model")

    if client_weights is None:
        weights = [1.0 / len(client_models)] * len(client_models)
    else:
        weights = [float(w) for w in client_weights]
        if len(weights) != len(client_models):
            raise ValueError(
                f"Expected {len(client_models)} client weights, got {len(weights)}"
            )
        if any(w < 0 for w in weights):
            raise ValueError("client_weights must be non-negative")
        total = sum(weights)
        if total <= 0:
            raise ValueError("client_weights must sum to a positive value")
        weights = [w / total for w in weights]

    global_state_dict = global_model.state_dict()
    client_state_dicts = [model.state_dict() for model in client_models]
    averaged_state_dict = {}

    with torch.no_grad():
        for key, reference in global_state_dict.items():
            is_fractional = reference.is_floating_point() or reference.is_complex()
            # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked)
            # cannot accumulate fractional values in place, so average them in
            # float32 and convert back afterwards.
            acc_dtype = reference.dtype if is_fractional else torch.float32
            accumulator = torch.zeros_like(reference, dtype=acc_dtype)
            for weight, state in zip(weights, client_state_dicts):
                accumulator += state[key].to(device=reference.device, dtype=acc_dtype) * weight
            if not is_fractional:
                # Round before casting so float error (2.9999) does not
                # truncate to the wrong integer.
                accumulator = torch.round(accumulator)
            averaged_state_dict[key] = accumulator.to(reference.dtype)

    # Load the averaged parameters back into the global model
    global_model.load_state_dict(averaged_state_dict)
    return global_model

In [None]:
# Initialize the global model
num_classes = 3
global_model = MicroExpressionCNN(num_classes=num_classes)

# Federated-training hyperparameters
num_rounds = 1  # Number of communication rounds
num_clients = 1  # Number of simulated clients
num_local_epochs = 3  # Local epochs each client runs per round

# Give every client its own disjoint shard of the training set. The previous
# placeholder list (`[...]`) held a single Ellipsis, so zip() trained only the
# first client no matter what num_clients was, and the untrained copies then
# diluted the average.
num_train_samples = len(train_dataset)
if num_train_samples < num_clients:
    raise ValueError(
        f"Cannot split {num_train_samples} training samples across {num_clients} clients"
    )
shard_sizes = [
    num_train_samples // num_clients + (1 if client_idx < num_train_samples % num_clients else 0)
    for client_idx in range(num_clients)
]
client_subsets = torch.utils.data.random_split(
    train_dataset,
    shard_sizes,
    generator=torch.Generator().manual_seed(0),  # fixed seed -> reproducible shards
)
# NOTE(review): the default collate can only batch clips with identical frame
# counts; confirm all clips have the same length or supply a collate_fn.
client_dataloaders = [
    DataLoader(client_subset, batch_size=5, shuffle=True)
    for client_subset in client_subsets
]

criterion = nn.CrossEntropyLoss()  # Loss is stateless, so one instance serves all clients

# Training loop
for round_idx in range(num_rounds):  # `round_idx` avoids shadowing the builtin round()
    print(f"Starting round {round_idx+1}/{num_rounds}")
    # Every client starts the round from the current global weights
    client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

    # `client_loader` (not `train_loader`) so the notebook-level train_loader
    # defined in an earlier cell is not overwritten.
    for client_model, client_loader in zip(client_models, client_dataloaders):
        optimizer = optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)
        client_model.train()  # Set model to training mode

        for epoch in range(num_local_epochs):
            for data, target in client_loader:
                # The dataset stacks per-frame (channels, H, W) tensors along a
                # leading frame axis, so a batch is (batch, frames, channels, H, W).
                # Conv3d needs channels before the depth/frame axis.
                data = data.permute(0, 2, 1, 3, 4)

                optimizer.zero_grad()
                output = client_model(data)  # Forward pass
                loss = criterion(output, target)  # Compute the loss
                loss.backward()  # Backward pass
                optimizer.step()  # Update parameters
                print(f"Client model loss: {loss.item()}")

    # Update the global model using a federated average of the client models
    global_model = federated_averaging(global_model, client_models)