# Improving Gating Network with Attention Mechanism

## Overview
In this notebook, we explore an **improved approach** to the **Gating Network** by integrating an **attention block** into the CNN architecture. The addition of the **Self-Attention** mechanism is aimed at enhancing the gating process, allowing the network to **focus on relevant features** more effectively when deciding which expert (in our MoE setup) should be activated for a given input.

### Key Enhancements:
- **Gating Network**: Traditionally, the gating network uses a CNN to process input features and generate a distribution over the experts. The goal is to assign the appropriate expert based on the input data.
- **Self-Attention Block**: Adding an attention mechanism inside the CNN lets the model **learn contextual relationships** between spatial positions of the feature map. The aim is to help the network focus on **important regions** of the input image and assign experts more reliably in complex scenarios.
- **Benefits of Attention**:
  - **Better Feature Selection**: The attention mechanism highlights the most relevant parts of the input, allowing the gating network to make more informed decisions.
  - **Improved Context Understanding**: Attention helps the model learn long-range dependencies within the image, improving the performance of the gating network, especially in challenging scenarios.

This notebook aims to assess whether integrating attention improves the **efficiency and accuracy** of the gating network and, in turn, the overall performance of the **Mixture of Experts (MoE)** model.
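
The gating network produces a distribution over experts, but the notebook does not show how that distribution is consumed downstream. The sketch below assumes two common options: soft weighting of every expert's output, or hard top-1 routing to a single expert. It also assumes one expert per scenario (four in total). The helper names `softmax`, `combine_experts` and `route_top1`, and the toy shapes, are hypothetical.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def combine_experts(gate_logits: np.ndarray, expert_outputs: np.ndarray) -> np.ndarray:
    """Soft MoE combination (an assumption, not taken from this notebook).

    gate_logits:    [B, E] raw gating logits, one per expert
    expert_outputs: [B, E, D] output of every expert for every input
    returns:        [B, D] gate-weighted sum of the expert outputs
    """
    weights = softmax(gate_logits)                       # [B, E], each row sums to 1
    return np.einsum("be,bed->bd", weights, expert_outputs)

def route_top1(gate_logits: np.ndarray) -> np.ndarray:
    """Hard routing: index of the single expert to activate for each input."""
    return gate_logits.argmax(axis=-1)

# Toy example: 2 inputs, 4 experts (one per scenario), 3-dim expert outputs
gate_logits = np.array([[4.0, 0.0, 0.0, 0.0],    # confident: expert 0
                        [0.0, 0.0, 0.0, 0.0]])   # undecided: uniform weights
expert_outputs = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)

print(route_top1(gate_logits))
print(combine_experts(gate_logits, expert_outputs))
```

With uniform logits the soft combination reduces to the plain average of the expert outputs, while top-1 routing still has to pick one expert. This is one reason a sharper, better-calibrated gate matters for hard routing.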


In [1]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split


### Fire Scenario Dataset Preparation

We define a custom `FireScenarioDataset` class that loads fire-scenario classification images together with their labels: each image `name.jpg` (or `.jpeg`/`.png`) is paired with a text file `name.txt` in the labels directory that contains a single integer class index. Each image is resized to 224×224 and converted to a tensor. The dataset is split 80/20 into training and validation sets, and PyTorch `DataLoader`s handle batching and shuffling during training and evaluation.
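
The image/label pairing convention can be checked on its own. This sketch mirrors the path logic in `FireScenarioDataset.__getitem__` using two hypothetical helpers, `label_path_for` and `read_label`, and a throwaway directory. The file name `scene_001` and its class index are made up for illustration.

```python
import os
import tempfile

# Same extension filter as FireScenarioDataset (compared after lower-casing)
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

def label_path_for(image_name: str, labels_dir: str) -> str:
    """Map '<stem>.<ext>' to '<labels_dir>/<stem>.txt', as the dataset class does."""
    stem = os.path.splitext(image_name)[0]
    return os.path.join(labels_dir, stem + ".txt")

def read_label(label_path: str) -> int:
    """Each label file holds a single integer class index (surrounding whitespace ignored)."""
    with open(label_path, "r") as f:
        return int(f.read().strip())

# Build a temporary labels directory containing one hypothetical label file
with tempfile.TemporaryDirectory() as root:
    labels_dir = os.path.join(root, "labels")
    os.makedirs(labels_dir)
    with open(os.path.join(labels_dir, "scene_001.txt"), "w") as f:
        f.write("2\n")  # hypothetical class index

    path = label_path_for("scene_001.jpg", labels_dir)
    label = read_label(path)
    print(os.path.basename(path), label)
```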


In [2]:
class FireScenarioDataset(Dataset):
    def __init__(self, images_dir, labels_dir, transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_name = self.image_files[idx]
        image_path = os.path.join(self.images_dir, image_name)
        label_path = os.path.join(self.labels_dir, os.path.splitext(image_name)[0] + '.txt')

        image = Image.open(image_path).convert("RGB")
        with open(label_path, 'r') as f:
            label = int(f.read().strip())

        if self.transform:
            image = self.transform(image)

        return image, label


In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = FireScenarioDataset(
    images_dir="DataSets/scenario_classifier_dataset/images",
    labels_dir="DataSets/scenario_classifier_dataset/labels",
    transform=transform
)

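from torch.utils.data import Subset

# Train-validation split (80/20) on indices. Splitting the Dataset object directly
# makes train_test_split load and transform every image into memory up front;
# wrapping index lists in Subset keeps loading lazy (same split, same random_state).
train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=42)
train_data = Subset(dataset, train_idx)
val_data = Subset(dataset, val_idx)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)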


### Gating Network with Self-Attention

We define a `GatingCNNWithAttention` model to classify fire images into four high-level scenarios (e.g., Indoor, Outdoor, Far-field, Satellite). The model consists of convolutional layers with a **self-attention block** inserted after the second pooling stage to capture spatial dependencies. It outputs raw logits; applying a softmax converts them into class probabilities, which serve as the gating weights in the Mixture of Experts (MoE) system.

Key components:
- **Self-Attention Module**: Enhances feature representation by computing long-range dependencies across spatial positions.
- **Convolutional Backbone**: Three Conv-BatchNorm-ReLU blocks for feature extraction; the first two are each followed by 2×2 max pooling.
- **Global Average Pooling**: Collapses the final 128×56×56 feature map into a 128-dimensional vector before the linear classifier.
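
The layer-by-layer shapes noted in the comments of `GatingCNNWithAttention` follow from the standard output-size formula floor((H + 2p - k) / s) + 1. The sketch below recomputes them and estimates the size of the self-attention's N×N energy matrix, assuming float32 activations and the training batch size of 32. It counts only that one tensor; the softmax output and the stored gradients add further copies of the same size.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Spatial output size of Conv2d / MaxPool2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size: int, kernel: int = 2) -> int:
    """MaxPool2d(k) defaults to stride = k and no padding."""
    return conv_out(size, kernel=kernel, stride=kernel, padding=0)

size = 224
size = pool_out(conv_out(size))   # block 1: 3x3 conv keeps 224, pool -> 112
size = pool_out(conv_out(size))   # block 2: 3x3 conv keeps 112, pool -> 56
attn_positions = size * size      # N spatial positions seen by SelfAttention(64)
size = conv_out(size)             # block 3: 3x3 conv keeps 56, then global average pooling

# Memory of the B x N x N energy matrix alone (float32, batch of 32)
batch_size, bytes_per_float = 32, 4
energy_bytes = batch_size * attn_positions ** 2 * bytes_per_float

print(f"feature map before GAP: 128 x {size} x {size}")
print(f"attention positions N = {attn_positions}")
print(f"energy matrix: {energy_bytes / 2**30:.2f} GiB per batch")
```

Because this cost grows with N², placing the attention block one pooling stage later (28×28, N = 784) would shrink the energy matrix by a factor of 16. That is a possible trade-off, not something tested in this notebook.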


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(dim, dim // 8, 1)
        self.key = nn.Conv2d(dim, dim // 8, 1)
        self.value = nn.Conv2d(dim, dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # Learnable scaling

    def forward(self, x):
        batch_size, C, height, width = x.size()  # x.size() is (B, C, H, W)
        N = height * width  # number of spatial positions

        proj_query = self.query(x).view(batch_size, -1, N).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key(x).view(batch_size, -1, N)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = F.softmax(energy, dim=-1)  # each query row sums to 1 over the keys

        proj_value = self.value(x).view(batch_size, -1, N)  # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(batch_size, C, height, width)

        out = self.gamma * out + x  # Residual connection
        return out

class GatingCNNWithAttention(nn.Module):
    def __init__(self, num_classes=4):
        super(GatingCNNWithAttention, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),  # [B, 32, 224, 224]
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # [B, 32, 112, 112]

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # [B, 64, 112, 112]
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # [B, 64, 56, 56]

            SelfAttention(64),  # self-attention over the 64 x 56 x 56 feature map

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # [B, 128, 56, 56]
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))  # Global Average Pooling [B, 128, 1, 1]
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, num_classes)  # Final output for 4 scenarios
        )

    def forward(self, x):
        x = self.features(x)
        logits = self.classifier(x)
        return logits  # raw logits (we apply softmax separately when needed)
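
The shape bookkeeping in `SelfAttention.forward` can be checked with a NumPy mirror of the same computation. This is SAGAN-style attention with the 1×1 convolutions written as matrix products; biases are omitted, and the weights are random stand-ins rather than trained parameters. The function name `self_attention` is hypothetical. The mirror confirms two properties of the module. First, each attention row is a probability distribution over the N positions. Second, with `gamma = 0`, its value at initialisation, the block is an exact identity mapping, so the network starts out as a plain CNN and learns how much attention to mix in.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v, gamma):
    """NumPy mirror of SelfAttention.forward (biases omitted).

    x:        [B, C, H, W] feature map
    w_q, w_k: [C', C] weights of the 1x1 query/key convs
    w_v:      [C, C] weight of the 1x1 value conv
    gamma:    scalar residual scale (initialised to 0 in the PyTorch module)
    """
    b, c, h, w = x.shape
    n = h * w
    flat = x.reshape(b, c, n)                        # B x C x N (row-major, like .view)
    q = np.einsum("oc,bcn->bon", w_q, flat)          # B x C' x N (1x1 conv = matmul)
    k = np.einsum("oc,bcn->bon", w_k, flat)          # B x C' x N
    v = np.einsum("oc,bcn->bon", w_v, flat)          # B x C x N

    energy = np.einsum("bci,bcj->bij", q, k)         # B x N x N: query i vs key j
    energy -= energy.max(axis=-1, keepdims=True)     # stabilise the softmax
    attn = np.exp(energy)
    attn /= attn.sum(axis=-1, keepdims=True)         # softmax over keys (dim=-1)

    out = np.einsum("bcj,bij->bci", v, attn)         # out[:, :, i] = sum_j v[:, :, j] * attn[i, j]
    return gamma * out.reshape(b, c, h, w) + x, attn

rng = np.random.default_rng(0)
B, C, H, W = 2, 16, 4, 5                              # non-square map exercises the reshape
x = rng.standard_normal((B, C, H, W))
w_q = rng.standard_normal((C // 8, C))
w_k = rng.standard_normal((C // 8, C))
w_v = rng.standard_normal((C, C))

y0, attn = self_attention(x, w_q, w_k, w_v, gamma=0.0)
y1, _ = self_attention(x, w_q, w_k, w_v, gamma=0.5)
print(attn.shape, np.allclose(y0, x))
```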


### Training the Gating Network

We train the `GatingCNNWithAttention` with cross-entropy loss and the Adam optimizer for 20 epochs, using the 80/20 training/validation split prepared earlier. Training and validation accuracy are computed after every epoch to monitor performance.

**Training Details**:
- **Loss Function**: CrossEntropyLoss
- **Optimizer**: Adam (learning rate = 1e-3)
- **Batch Size**: 32
- **Epochs**: 20
- **Device**: Automatically selects GPU if available
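
`nn.CrossEntropyLoss` takes the raw logits that the model returns and applies log-softmax internally, which is why `forward` does not apply softmax itself. Below is a NumPy sketch of the loss (with the default mean reduction) and of the accuracy bookkeeping done in the loop, on made-up logits. The helper names `cross_entropy` and `accuracy` are hypothetical.

```python
import math
import numpy as np

def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean of -log softmax(logits)[label]: what nn.CrossEntropyLoss computes by default."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

def accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mirrors outputs.max(1) followed by preds.eq(labels): argmax prediction vs label."""
    return float((logits.argmax(axis=1) == labels).mean())

logits = np.array([[2.0, 0.5, 0.1, -1.0],   # predicts class 0
                   [0.0, 0.0, 3.0, 0.0],    # predicts class 2
                   [0.0, 0.0, 0.0, 0.0]])   # uniform; argmax falls back to class 0
labels = np.array([0, 1, 3])

print(f"loss = {cross_entropy(logits, labels):.4f}, acc = {accuracy(logits, labels):.4f}")
print(f"uniform 4-class loss = ln 4 = {math.log(4):.4f}")
```

A 4-class model whose outputs are still close to uniform scores a loss near ln 4 ≈ 1.386. This gives a quick sanity check for the first-epoch loss in the log below.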


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim

# Reuse `train_loader` and `val_loader` from the data-preparation cell;
# re-running train_test_split here would only rebuild the same 80/20 split.

# Initialize model, optimizer, loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GatingCNNWithAttention(num_classes=4).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)  # raw logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    train_acc = correct / total
    avg_loss = running_loss / len(train_loader)

    # Validation
    model.eval()
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            _, preds = outputs.max(1)
            val_correct += preds.eq(labels).sum().item()
            val_total += labels.size(0)

    val_acc = val_correct / val_total

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f} - Train Acc: {train_acc:.4f} - Val Acc: {val_acc:.4f}")

print("✅ Training finished!")


Epoch 1/20 - Loss: 0.6966 - Train Acc: 0.7500 - Val Acc: 0.8013
Epoch 2/20 - Loss: 0.4436 - Train Acc: 0.8250 - Val Acc: 0.8375
Epoch 3/20 - Loss: 0.3480 - Train Acc: 0.8775 - Val Acc: 0.8750
Epoch 4/20 - Loss: 0.2958 - Train Acc: 0.8947 - Val Acc: 0.7325
Epoch 5/20 - Loss: 0.2430 - Train Acc: 0.9169 - Val Acc: 0.8838
Epoch 6/20 - Loss: 0.2178 - Train Acc: 0.9250 - Val Acc: 0.9062
Epoch 7/20 - Loss: 0.2142 - Train Acc: 0.9228 - Val Acc: 0.9187
Epoch 8/20 - Loss: 0.1871 - Train Acc: 0.9334 - Val Acc: 0.9175
Epoch 9/20 - Loss: 0.1795 - Train Acc: 0.9403 - Val Acc: 0.9237
Epoch 10/20 - Loss: 0.1791 - Train Acc: 0.9353 - Val Acc: 0.9050
Epoch 11/20 - Loss: 0.1674 - Train Acc: 0.9406 - Val Acc: 0.9450
Epoch 12/20 - Loss: 0.1626 - Train Acc: 0.9413 - Val Acc: 0.9175
Epoch 13/20 - Loss: 0.1650 - Train Acc: 0.9416 - Val Acc: 0.8788
Epoch 14/20 - Loss: 0.1540 - Train Acc: 0.9463 - Val Acc: 0.9175
Epoch 15/20 - Loss: 0.1423 - Train Acc: 0.9500 - Val Acc: 0.9250
Epoch 16/20 - Loss: 0.1442 - Train
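
Validation accuracy fluctuates noticeably between epochs (0.8750 at epoch 3, then 0.7325 at epoch 4). As a result, the weights from the final epoch, which the next cell saves, are not necessarily the best ones. The sketch below parses the fifteen complete log lines above to make this explicit; the output from epoch 16 onward is cut off and is not reconstructed. The regex assumes the exact `print` format used in the training loop.

```python
import re

# The complete epoch lines from the training output above (epoch 16 onward is truncated)
LOG = """\
Epoch 1/20 - Loss: 0.6966 - Train Acc: 0.7500 - Val Acc: 0.8013
Epoch 2/20 - Loss: 0.4436 - Train Acc: 0.8250 - Val Acc: 0.8375
Epoch 3/20 - Loss: 0.3480 - Train Acc: 0.8775 - Val Acc: 0.8750
Epoch 4/20 - Loss: 0.2958 - Train Acc: 0.8947 - Val Acc: 0.7325
Epoch 5/20 - Loss: 0.2430 - Train Acc: 0.9169 - Val Acc: 0.8838
Epoch 6/20 - Loss: 0.2178 - Train Acc: 0.9250 - Val Acc: 0.9062
Epoch 7/20 - Loss: 0.2142 - Train Acc: 0.9228 - Val Acc: 0.9187
Epoch 8/20 - Loss: 0.1871 - Train Acc: 0.9334 - Val Acc: 0.9175
Epoch 9/20 - Loss: 0.1795 - Train Acc: 0.9403 - Val Acc: 0.9237
Epoch 10/20 - Loss: 0.1791 - Train Acc: 0.9353 - Val Acc: 0.9050
Epoch 11/20 - Loss: 0.1674 - Train Acc: 0.9406 - Val Acc: 0.9450
Epoch 12/20 - Loss: 0.1626 - Train Acc: 0.9413 - Val Acc: 0.9175
Epoch 13/20 - Loss: 0.1650 - Train Acc: 0.9416 - Val Acc: 0.8788
Epoch 14/20 - Loss: 0.1540 - Train Acc: 0.9463 - Val Acc: 0.9175
Epoch 15/20 - Loss: 0.1423 - Train Acc: 0.9500 - Val Acc: 0.9250
"""

PATTERN = re.compile(
    r"Epoch (\d+)/\d+ - Loss: ([\d.]+) - Train Acc: ([\d.]+) - Val Acc: ([\d.]+)"
)

# (epoch, loss, train_acc, val_acc) for every complete line
records = [
    (int(m[1]), float(m[2]), float(m[3]), float(m[4]))
    for m in PATTERN.finditer(LOG)
]
best_epoch, _, _, best_val = max(records, key=lambda r: r[3])
val_accs = [r[3] for r in records]

print(f"{len(records)} epochs parsed; best val acc {best_val:.4f} at epoch {best_epoch}")
print(f"val acc range: {min(val_accs):.4f} to {max(val_accs):.4f}")
```

One option is to keep a copy of the state dict whenever validation accuracy improves, so the best checkpoint is saved rather than the last one. That change is a suggestion and is not part of the original training loop.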

In [8]:
torch.save(model.state_dict(), "improved_gating_cnn.pth")
print("✅ Gating CNN saved!")

✅ Gating CNN saved!
