# MTUNet++ Architecture Overview

## High-Level Architecture
1. Input Processing:
   - Query image and support set images are processed through a CNN backbone
   - Feature maps (Fmap ∈ ℝ^(a×h×w)) are extracted for each image

2. Main Components:
   - CNN Backbone (Modified ResNet-18)
   - Pattern Extractor Module
   - Pairwise Matching Module (MLP-based)

## Pattern Extractor Module Architecture
1. Feature Processing:
   - 1×1 convolution layer followed by ReLU activation
   - Channel dimensionality reduction from a to b
   - Flattening of the spatial dimensions: Fmap' ∈ ℝ^(b×v) (v = h×w)
   - Addition of a learnable positional embedding (Pl)

2. Attention Mechanism:
   - Iterative process (R times)
   - Uses Gated Recurrent Unit with Skip Connections (GRUsc)
   - Pattern updates: K(r+1) = GRUsc(W(r), K(r))
   - Attention maps are obtained by applying a normalization function 𝜚 to the raw similarity scores

3. Pattern Processing:
   - Self-attention mechanism over spatial dimensions
   - Dot-product similarity calculation
   - Pattern refinement using GRUsc
   - Attention map adjustment using Hadamard product
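
The feature-processing steps listed under "Feature Processing" can be sketched with NumPy. This is a minimal illustration under stated assumptions, not the module's implementation: the sizes (a=512, b=256, h=w=10), the random stand-in weights, and the helper name `reduce_and_flatten` are chosen for the example only. It treats the 1×1 convolution as a per-position linear map and assumes a positional embedding with one vector per spatial location.

```python
import numpy as np

def reduce_and_flatten(fmap, weight, bias, pos_embed):
    """1x1 conv + ReLU, flatten spatial dims, add a positional embedding.

    fmap:      [a, h, w] backbone feature map
    weight:    [b, a]    1x1 convolution kernel (a per-position linear map)
    bias:      [b]
    pos_embed: [b, h*w]  one vector per spatial position (assumed shape)
    """
    a, h, w = fmap.shape
    flat = fmap.reshape(a, h * w)              # [a, v] with v = h*w
    reduced = weight @ flat + bias[:, None]    # [b, v], equivalent to a 1x1 conv
    reduced = np.maximum(reduced, 0.0)         # ReLU
    return reduced + pos_embed                 # [b, v]

rng = np.random.default_rng(0)
a, b, h, w = 512, 256, 10, 10                  # illustrative sizes
fmap = rng.standard_normal((a, h, w))
weight = rng.standard_normal((b, a)) * 0.01
bias = np.zeros(b)
pos_embed = rng.standard_normal((b, h * w)) * 0.01

fmap_prime = reduce_and_flatten(fmap, weight, bias, pos_embed)
print(fmap_prime.shape)
```

Because the embedding has a separate column for each of the v positions, two locations with identical backbone features still end up with different vectors in Fmap'.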

## Configuration Details

### CNN Backbone (ResNet-18 Modifications):
- Removed first two downsampling layers
- First conv layer: 7×7 → 3×3
- Output feature maps: 512
- Fixed parameters during training

### Pattern Extractor Module:
- GRUsc hidden dimension: 256
- Update iterations: 3
- Number of patterns: 7
- Networks gq and gM: 3 fully connected layers with ReLU

### Training Configuration:
1. Initial Phase:
   - Learning rate: 10⁻⁴
   - Rate reduction: 10× at epoch 40
   - Total epochs: 150

2. Fine-tuning Phase:
   - CNN and pattern extractor learning rate: 10⁻⁵
   - 20 iterations
   - 500 episodes per epoch for 2-way tasks
   - Other components: Initial learning rate 10⁻⁴
   - Rate reduction: 10× at epoch 10
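
The step schedules above can be written as a small function. This is a sketch under two assumptions that the configuration does not state explicitly: epochs are counted from 1, and the reduced rate applies from the named epoch onward. The function name `step_lr` is hypothetical, and the 20-epoch length of the fine-tuning phase is taken from the "20 iterations" entry.

```python
def step_lr(base_lr, epoch, decay_epoch, factor=0.1):
    """Learning rate for a 1-indexed epoch under a single step decay."""
    return base_lr * factor if epoch >= decay_epoch else base_lr

# Initial phase: 1e-4, reduced 10x at epoch 40, 150 epochs in total.
initial = [step_lr(1e-4, e, decay_epoch=40) for e in range(1, 151)]

# Fine-tuning phase, components other than the CNN and pattern extractor:
# 1e-4, reduced 10x at epoch 10 (assuming 20 epochs).
finetune_other = [step_lr(1e-4, e, decay_epoch=10) for e in range(1, 21)]

print(initial[38], initial[39])                # epochs 39 and 40
print(finetune_other[8], finetune_other[9])    # epochs 9 and 10
```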

### Implementation Details:
- Framework: PyTorch
- Optimizer: AdaBelief
- Input image size: 80×80
- Data augmentation: Random flipping and affine transformations
- Evaluation: 2000 episodes of 2-way classification
- Support images: N = 5 or 10
- Query images: 15 per class

## Training Process Flow
1. Task-based training of backbone CNN
2. Independent training of attention module
3. Training of few-shot classifier
4. Model selection based on validation performance (2,000 episodes)

## Mathematical Formulations
1. Feature Extraction:
   - Fmap = f𝜙(x) ∈ ℝ^(a×h×w)

2. Pattern Attention:
   - Att = fpe(Fmap) ∈ ℝ^(u×v)

3. Similarity Scoring:
   - score(Oq, Om) = σ(f𝜃([Oq, Om]))

4. Classification:
   - m* = argmax_m score(Oq, Om)

flowchart TD
    subgraph Input
        QI[Query Image]
        SS[Support Set Images]
    end

    subgraph CNN["CNN Backbone (ResNet-18)"]
        F1[Feature Extraction]
    end

    subgraph PE["Pattern Extractor Module"]
        C1[1x1 Conv + ReLU]
        FT[Flatten Operation]
        PE1[Positional Embedding]
        AT1[Self-Attention]
        GRU[GRU with Skip Connections]
        AGG[Attention Aggregation]
        AP[Average Pooling]
    end

    subgraph PM["Pairwise Matching"]
        CON[Feature Concatenation]
        MLP[Multi-Layer Perceptron]
        SC[Similarity Score]
    end

    %% Main flow
    QI --> F1
    SS --> F1
    F1 --> |Fmap ∈ ℝa×h×w| C1
    C1 --> |Reduced Dim| FT
    FT --> |Fmap' ∈ ℝb×v| PE1
    PE1 --> AT1
    AT1 --> |Att'| GRU
    GRU --> |K(r+1)| AT1
    GRU --> |Final Attention| AGG
    AGG --> |Att''| AP
    AP --> |O| CON
    CON --> MLP
    MLP --> |score| SC

    %% Iterative loop
    AT1 --> |R iterations| AT1

# MTUNet++ Data Flow Process

## 1. Input Processing
- **Query Image (xq)**: Single image for classification
- **Support Set (Ds)**: Set of labeled images for comparison
  - M classes with N images per class
  - Total: M×N support images
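
To make the episode structure concrete, here is a minimal sampling sketch using only the standard library. The function name `sample_episode`, the toy class names, and the choice of episode-local labels (0 to M-1) are assumptions for illustration; the defaults of 15 query images per class and N=5 follow the configuration above.

```python
import random

def sample_episode(images_by_class, n_way=2, n_support=5, n_query=15, rng=None):
    """Sample one episode and return (support, query) lists of (item, label)."""
    rng = rng or random.Random()
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for label, name in enumerate(classes):     # episode-local labels: 0..n_way-1
        picked = rng.sample(images_by_class[name], n_support + n_query)
        support += [(item, label) for item in picked[:n_support]]
        query += [(item, label) for item in picked[n_support:]]
    return support, query

# Toy data: 4 hypothetical classes with 30 file names each.
toy = {f"class_{c}": [f"class_{c}/img_{i}.jpg" for i in range(30)] for c in "abcd"}
support, query = sample_episode(toy, rng=random.Random(0))
print(len(support), len(query))  # 10 30
```

Support and query images are drawn without replacement from the same pool, so no image appears on both sides of an episode.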

## 2. Feature Extraction (CNN Stage)
1. **Input → Feature Maps**
   - CNN processes both query and support images
   - Output: Fmap = f𝜙(x) ∈ ℝ^(a×h×w)
   - Uses modified ResNet-18 backbone

## 3. Pattern Extractor Flow
1. **Dimensionality Reduction**
   - Input: Fmap ∈ ℝ^(a×h×w)
   - 1×1 convolution + ReLU
   - Output: feature map with the channel dimension reduced from a to b

2. **Spatial Processing**
   - Flatten operation: Fmap' ∈ ℝ^(b×v) (v = h×w)
   - Add positional embedding: Fmap' = Fmap' + Pl

3. **Attention Mechanism (Iterative Process)**
   - Input: Flattened features
   - Pattern Generation:
     1. Calculate similarity: Att'(r) = gq(K(r)) · gM(Fmap')
     2. Apply normalization: Att(r) = 𝜚(Att'(r))
     3. Update patterns: K(r+1) = GRUsc(W(r), K(r)), where W(r) is the attention-weighted feature summary (`context` in the code)
   - Iterations: R times
   - Output: Final attention maps

4. **Feature Aggregation**
   - Aggregate attention: Att'' = (1/u) × ∑_k Att_k(R), the mean of the u pattern maps after the final iteration R
   - Average pooling: O = 1/(h×w) × ∑_ij Att''_ij × Fmap_ij
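
One attention step followed by the aggregation above can be sketched in NumPy. Several things here are assumptions for illustration: the networks gq and gM are replaced by the identity, the normalization 𝜚 is taken to be a softmax over spatial positions (the choice made in the `PatternExtractor` code), and pooling is applied to the reduced features Fmap' so that the result lies in ℝ^b. The name `attend_and_pool` is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend_and_pool(patterns, feats):
    """
    patterns: [u, b]  current patterns K
    feats:    [b, v]  flattened features Fmap'
    returns:  att [u, v], att_agg [v], pooled [b]
    """
    b, v = feats.shape
    att_raw = patterns @ feats / np.sqrt(b)      # Att': scaled dot-product similarity
    att = softmax(att_raw, axis=-1)              # Att: normalized over positions
    att_agg = att.mean(axis=0)                   # Att'' = (1/u) * sum over patterns
    pooled = (feats * att_agg).sum(axis=1) / v   # O = (1/v) * sum_j Att''_j * Fmap'_j
    return att, att_agg, pooled

rng = np.random.default_rng(0)
u, b, v = 7, 256, 100                            # 7 patterns, 10x10 positions
att, att_agg, pooled = attend_and_pool(rng.standard_normal((u, b)),
                                       rng.standard_normal((b, v)))
print(att.shape, att_agg.shape, pooled.shape)
```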

## 4. Pairwise Matching Flow
1. **Feature Processing**
   - Query features: Oq
   - Support features: Om (averaged if N > 1)

2. **Similarity Computation**
   - Concatenate features: [Oq, Om]
   - MLP processing: f𝜃([Oq, Om])
   - Output: similarity score

3. **Classification**
   - Compare scores across all support classes
   - Select class with highest similarity score
   - Final output: predicted class m*
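
The matching flow above (per-class averaging, concatenation, scoring, argmax) can be sketched in NumPy. The trained MLP f𝜃 is replaced here by a hypothetical stand-in, `toy_scorer`, which returns the negative squared distance between the two halves of each pair; the feature size and the synthetic, well-separated clusters are also assumptions made so the example has an unambiguous answer.

```python
import numpy as np

def class_prototypes(support_feats, support_labels, n_way):
    """Average support features per class: [M*N, b] -> [M, b]."""
    return np.stack([support_feats[support_labels == m].mean(axis=0)
                     for m in range(n_way)])

def pairwise_inputs(query_feats, protos):
    """Build the [Q, M, 2b] concatenations [Oq, Om] for every query/class pair."""
    q, m = query_feats.shape[0], protos.shape[0]
    q_exp = np.repeat(query_feats[:, None, :], m, axis=1)   # [Q, M, b]
    p_exp = np.repeat(protos[None, :, :], q, axis=0)        # [Q, M, b]
    return np.concatenate([q_exp, p_exp], axis=-1)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def toy_scorer(pairs):
    """Stand-in for the MLP: negative squared distance between the two halves."""
    b = pairs.shape[-1] // 2
    diff = pairs[..., :b] - pairs[..., b:]
    return -(diff ** 2).sum(axis=-1)

rng = np.random.default_rng(0)
b, n_way, n_support = 8, 2, 5
centers = np.array([[-2.0] * b, [2.0] * b])                 # two separated classes
support_labels = np.repeat(np.arange(n_way), n_support)
support_feats = centers[support_labels] + 0.1 * rng.standard_normal((n_way * n_support, b))
query_labels = np.array([0, 1, 1, 0])
query_feats = centers[query_labels] + 0.1 * rng.standard_normal((4, b))

protos = class_prototypes(support_feats, support_labels, n_way)
scores = sigmoid(toy_scorer(pairwise_inputs(query_feats, protos)))   # [Q, M]
pred = scores.argmax(axis=1)                                          # m* per query
print(scores.shape, pred)
```

Averaging before scoring gives one score per class rather than one per support image, which is what the argmax over classes requires.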

## Data Dimensions at Key Points
1. Initial Features: ℝ^(a×h×w)
2. Reduced Features: ℝ^(b×v)
3. Attention Maps: ℝ^(u×v)
4. Final Features: ℝ^b
5. Similarity Scores: ℝ^M (M = number of classes)
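
As a worked example, the dimensions above can be filled in with the configured values (a=512, b=256, u=7, 80×80 inputs, 2-way tasks). The overall downsampling factor of 8 is an assumption derived from the strides in the backbone code (stride 2 in the first convolution and in the last two residual stages); the helper name `trace_shapes` is hypothetical.

```python
def trace_shapes(image_size=80, a=512, b=256, u=7, n_way=2, downsample=8):
    """Return the tensor shape at each key point for one image."""
    h = w = image_size // downsample
    v = h * w
    return {
        "initial_features": (a, h, w),
        "reduced_features": (b, v),
        "attention_maps": (u, v),
        "final_features": (b,),
        "similarity_scores": (n_way,),
    }

shapes = trace_shapes()
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```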

## Key Transformations
1. **Spatial → Pattern Space**
   - Feature maps → Pattern attention
   - Dimension: (a×h×w) → (u×v)

2. **Pattern → Classification Space**
   - Pattern features → Similarity scores
   - Dimension: (b) → (M)

In [None]:
# This code installs the necessary Python packages for the project.
# - torch: A deep learning framework for building and training neural networks.
# - torchvision: A package that provides datasets, model architectures, and image transformations for computer vision.
# - tqdm: A library for adding progress bars to loops.
# - pillow: A library for image processing.
# - adabelief-pytorch: The AdaBelief optimizer, an Adam variant that scales each step by the variance of the gradient around its running mean.
!pip install torch torchvision tqdm pillow
!pip install adabelief-pytorch

In [None]:
"""
This script sets up the necessary imports for a PyTorch-based deep learning project. 
It includes the following libraries and modules:

- torch: The main PyTorch library.
- torch.nn: A sub-library containing neural network layers and functions.
- torch.nn.functional: A sub-library containing functional interfaces for neural network layers.
- torchvision.models: Pre-trained models provided by the torchvision library.
- torchvision.transforms: Common image transformations provided by the torchvision library.
- torch.utils.data: Utilities for data handling, including Dataset and DataLoader classes.
- numpy: A library for numerical operations.
- PIL (Python Imaging Library): A library for opening, manipulating, and saving images.
- os: A module for interacting with the operating system.
- random: A module for generating random numbers.
- tqdm: A library for creating progress bars.
- adabelief_pytorch: The AdaBelief optimizer for PyTorch.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
import random
from tqdm import tqdm
from adabelief_pytorch import AdaBelief

In [None]:
"""
ModifiedResNet18 is a custom neural network model based on ResNet-18 with the following modifications:
- The first convolutional layer is modified to have a kernel size of 3 and a stride of 2.
- The first two residual layers (layer1 and layer2) are modified to remove downsampling by setting their stride to 1.
- All parameters are frozen to prevent them from being updated during training.

Attributes:
    conv1 (nn.Conv2d): Modified first convolutional layer.
    bn1 (nn.BatchNorm2d): Batch normalization layer from the original ResNet-18.
    relu (nn.ReLU): ReLU activation function from the original ResNet-18.
    layer1 (nn.Sequential): Modified first residual layer with no downsampling.
    layer2 (nn.Sequential): Modified second residual layer with no downsampling.
    layer3 (nn.Sequential): Third residual layer from the original ResNet-18.
    layer4 (nn.Sequential): Fourth residual layer from the original ResNet-18.

Methods:
    _modify_layer(layer, stride):
        Modifies the given residual layer to set the stride of the first convolutional layer and the downsample layer (if present) to the specified stride.
        
    forward(x):
        Defines the forward pass of the network.
        Args:
            x (torch.Tensor): Input tensor of shape [B, 3, H, W].
        Returns:
            torch.Tensor: Output tensor of shape [B, 512, H/8, W/8] (stride 2 in conv1, layer3 and layer4 only).
"""

class ModifiedResNet18(nn.Module):
    def __init__(self):
        super(ModifiedResNet18, self).__init__()
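        # `pretrained=True` is deprecated in torchvision >= 0.13; request the ImageNet weights explicitly
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Replacement 3x3 first conv with stride 2. It is newly initialized,
        # because the pretrained 7x7 kernel does not fit this shape.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)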
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        
        # Modify layer1 and layer2 to remove downsampling
        self.layer1 = self._modify_layer(resnet.layer1, stride=1)
        self.layer2 = self._modify_layer(resnet.layer2, stride=1)
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        
        # Freeze parameters
        for param in self.parameters():
            param.requires_grad = False
    
    def _modify_layer(self, layer, stride):
        for block in layer:
            block.conv1.stride = (stride, stride)
            block.conv2.stride = (1, 1)
            if block.downsample is not None:
                block.downsample[0].stride = (stride, stride)
        return layer
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        return x  # Output: [B, 512, H/8, W/8]
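
The output resolution of this backbone can be checked with the standard convolution size formula, without running the network. The sketch below is an illustration only: it follows the main 3×3 convolution of each stage as configured in the class above (the forward pass has no max-pooling step), and the helper names `conv_out` and `backbone_spatial_size` are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

def backbone_spatial_size(size):
    size = conv_out(size, kernel=3, stride=2, padding=1)  # conv1 (3x3, stride 2)
    size = conv_out(size, kernel=3, stride=1, padding=1)  # layer1 (stride 1)
    size = conv_out(size, kernel=3, stride=1, padding=1)  # layer2 (stride set to 1)
    size = conv_out(size, kernel=3, stride=2, padding=1)  # layer3 (stride 2)
    size = conv_out(size, kernel=3, stride=2, padding=1)  # layer4 (stride 2)
    return size

print(backbone_spatial_size(80))  # 10
```

For the 80×80 inputs used here the feature map is therefore 10×10, an overall reduction by a factor of 8.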


In [None]:
"""
GRUWithSkipConnection is a custom neural network module that combines a GRU cell with a skip connection.

Attributes:
    gru (nn.GRUCell): A GRU cell that performs a single recurrent update of the hidden state.
    skip_proj (nn.Linear): A linear layer that projects the input to the hidden dimension for the skip connection.

Methods:
    __init__(input_dim, hidden_dim):
        Initializes the GRUWithSkipConnection module with the specified input and hidden dimensions.
    
    forward(x, h):
        Performs a forward pass through the GRU cell and adds a skip connection.
        
        Args:
            x (Tensor): The input tensor of shape (batch_size, input_dim).
            h (Tensor): The hidden state tensor of shape (batch_size, hidden_dim).
        
        Returns:
            Tensor: The output tensor of shape (batch_size, hidden_dim) after applying the GRU cell and skip connection.
"""
class GRUWithSkipConnection(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GRUWithSkipConnection, self).__init__()
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.skip_proj = nn.Linear(input_dim, hidden_dim)
        
    def forward(self, x, h):
        h_new = self.gru(x, h)
        skip = self.skip_proj(x)
        return h_new + skip
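
For reference, the arithmetic behind this module can be written out in NumPy: a standard GRU-cell update followed by a linear projection of the input added to the new state. This is a sketch, not a replacement for `nn.GRUCell`. The weight names and dictionary layout are hypothetical, and the two bias vectors that PyTorch keeps for the reset and update gates are merged into one each.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_params(d_in, d_h, rng):
    """Random weights in a hypothetical layout (input-major matrices)."""
    p = {k: rng.standard_normal((d_in, d_h)) * 0.1 for k in ("W_ir", "W_iz", "W_in", "W_skip")}
    p.update({k: rng.standard_normal((d_h, d_h)) * 0.1 for k in ("W_hr", "W_hz", "W_hn")})
    p.update({k: np.zeros(d_h) for k in ("b_r", "b_z", "b_in", "b_hn", "b_skip")})
    return p

def gru_skip_step(x, h, p):
    """One GRU-cell update plus a linear skip path from the input.

    x: [n, d_in] input, h: [n, d_h] previous hidden state.
    """
    r = sigmoid(x @ p["W_ir"] + h @ p["W_hr"] + p["b_r"])                     # reset gate
    z = sigmoid(x @ p["W_iz"] + h @ p["W_hz"] + p["b_z"])                     # update gate
    n = np.tanh(x @ p["W_in"] + p["b_in"] + r * (h @ p["W_hn"] + p["b_hn"]))  # candidate
    h_new = (1.0 - z) * n + z * h                                             # GRU state
    return h_new + x @ p["W_skip"] + p["b_skip"]                              # skip connection

rng = np.random.default_rng(0)
params = init_params(d_in=4, d_h=4, rng=rng)
out = gru_skip_step(rng.standard_normal((3, 4)), np.zeros((3, 4)), params)
print(out.shape)
```

Unlike the plain GRU state, the output is not bounded to (-1, 1), because the skip term is added after the gating.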

In [None]:
"""
PatternExtractor is a neural network module designed to extract patterns from input feature maps using iterative self-attention and GRU-based refinement.

Attributes:
    hidden_dim (int): The dimensionality of the hidden layers.
    num_patterns (int): The number of patterns to extract.
    num_iterations (int): The number of iterations for pattern refinement.
    conv1x1 (nn.Conv2d): A 1x1 convolutional layer for initial feature processing.
    positional_embedding (nn.Parameter): A learnable positional embedding added to the input features.
    attention_gate (nn.Parameter): A learnable parameter for attention gating.
    grusc (GRUWithSkipConnection): A GRU module with skip connections for pattern refinement.
    pattern_init (nn.Sequential): A network for initializing patterns.
    query_net (nn.Sequential): A network for generating query vectors for self-attention.
    key_net (nn.Sequential): A network for generating key vectors for self-attention.

Methods:
    forward(x):
        Forward pass of the PatternExtractor.
        
        Args:
            x (torch.Tensor): Input feature map of shape [B, in_channels, H, W].
        
        Returns:
            patterns (torch.Tensor): Extracted patterns of shape [B, num_patterns, hidden_dim].
            attn (torch.Tensor): Attention maps of shape [B, num_patterns, H, W].
"""

class PatternExtractor(nn.Module):
    def __init__(self, in_channels=512, hidden_dim=256, num_patterns=7, num_iterations=3):
        super(PatternExtractor, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_patterns = num_patterns
        self.num_iterations = num_iterations
        
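        # Initial feature processing
        self.conv1x1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
        # Note: with shape [1, hidden_dim, 1, 1] this embedding broadcasts over all
        # spatial positions, so it acts as a per-channel offset rather than a
        # position-dependent term.
        self.positional_embedding = nn.Parameter(torch.randn(1, hidden_dim, 1, 1))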
        
        # Attention gating parameter
        self.attention_gate = nn.Parameter(torch.ones(1))
        
        # GRU with skip connections
        self.grusc = GRUWithSkipConnection(hidden_dim, hidden_dim)
        
        # Pattern networks
        self.pattern_init = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_patterns * hidden_dim)
        )
        
        # Attention networks
        self.query_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.key_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
    def forward(self, x):
        batch_size = x.size(0)
        
        # Initial feature processing
        x = F.relu(self.conv1x1(x))  # 1x1 conv + ReLU -> [B, hidden_dim, H, W]
        x = x + self.positional_embedding
        h, w = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, hidden_dim]
        
        # Initialize patterns
        patterns = self.pattern_init(x.mean(1))  # [B, num_patterns * hidden_dim]
        patterns = patterns.view(batch_size, self.num_patterns, self.hidden_dim)
        
        # Iterative pattern refinement
        h_state = torch.zeros(batch_size * self.num_patterns, self.hidden_dim).to(x.device)
        
        for _ in range(self.num_iterations):
            # Self-attention
            q = self.query_net(patterns.reshape(-1, self.hidden_dim))
            k = self.key_net(x.reshape(-1, self.hidden_dim))
            
            # Compute attention scores
            attn = torch.matmul(q.view(batch_size, self.num_patterns, -1),
                              k.view(batch_size, -1, self.hidden_dim).transpose(1, 2))
            attn = F.softmax(attn / (self.hidden_dim ** 0.5), dim=-1)
            
            # Apply attention gating
            attn = attn * torch.sigmoid(self.attention_gate)
            
            # Update patterns
            context = torch.bmm(attn, x)  # [B, num_patterns, hidden_dim]
            
            # GRU update
            context_flat = context.reshape(-1, self.hidden_dim)
            h_state = self.grusc(context_flat, h_state)
            patterns = h_state.view(batch_size, self.num_patterns, self.hidden_dim)
        
        return patterns, attn.view(batch_size, self.num_patterns, h, w)


In [None]:
"""
PairwiseMatchingModule is a neural network module designed for pairwise matching of query and support features.

Args:
    hidden_dim (int): The dimension of the hidden layers in the matching network.

Attributes:
    matching_net (nn.Sequential): A sequential neural network consisting of linear layers and ReLU activations, 
                                  which processes the concatenated query and support features to produce a matching score.

Methods:
    forward(query_features, support_features):
        Computes the matching scores between query features and support features.

        Args:
            query_features (torch.Tensor): A tensor of shape [B, H] or [B, *, H] representing the query features.
            support_features (torch.Tensor): A tensor of shape [S, H] or [S, *, H] representing the support features.

        Returns:
            torch.Tensor: A tensor of shape [B, S] containing the matching scores for each query-support pair.
"""

class PairwiseMatchingModule(nn.Module):
    def __init__(self, hidden_dim):
        super(PairwiseMatchingModule, self).__init__()
        self.matching_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, query_features, support_features):
        # Ensure both inputs are 2D
        if query_features.dim() > 2:
            query_features = query_features.mean(1)  # Average across any extra dimensions
        if support_features.dim() > 2:
            support_features = support_features.mean(1)  # Average across any extra dimensions
            
        batch_size = query_features.size(0)
        
        # Reshape for pairwise comparison
        query_expanded = query_features.unsqueeze(1)  # [B, 1, H]
        
        # Combine features
        combined = torch.cat([
            query_expanded.expand(-1, support_features.size(0), -1),  # [B, S, H]
            support_features.unsqueeze(0).expand(batch_size, -1, -1)  # [B, S, H]
        ], dim=-1)
        
        # Get scores
        scores = self.matching_net(combined.view(-1, combined.size(-1)))
        return scores.view(batch_size, -1)


In [None]:
"""
MTUNetPlusPlus is a neural network model designed for few-shot learning tasks. It consists of a backbone network for feature extraction, a pattern extractor for identifying relevant patterns, and a pairwise matching module for comparing query and support images.

Attributes:
    backbone (nn.Module): The feature extraction backbone network, here a modified ResNet18.
    pattern_extractor (PatternExtractor): Module to extract patterns from the features.
    matching_module (PairwiseMatchingModule): Module to perform pairwise matching between query and support features.

Methods:
    __init__(hidden_dim=256):
        Initializes the MTUNetPlusPlus model with the specified hidden dimension for the pattern extractor.
    
    forward(query_img, support_imgs=None, return_features=False):
        Forward pass of the model.
        
        Args:
            query_img (Tensor): The query image tensor.
            support_imgs (Tensor, optional): The support images tensor. Defaults to None.
            return_features (bool, optional): If True, returns the extracted features instead of matching scores. Defaults to False.
        
        Returns:
            If return_features is True, returns a tuple of query patterns and attention maps.
            If support_imgs is provided, returns the matching scores between query and support images.
"""

class MTUNetPlusPlus(nn.Module):
    def __init__(self, hidden_dim=256):
        super(MTUNetPlusPlus, self).__init__()
        self.backbone = ModifiedResNet18()
        self.pattern_extractor = PatternExtractor(in_channels=512, hidden_dim=hidden_dim)
        self.matching_module = PairwiseMatchingModule(hidden_dim)
        
    def forward(self, query_img, support_imgs=None, return_features=False):
        query_features = self.backbone(query_img)
        query_patterns, query_attn = self.pattern_extractor(query_features)
        
        if support_imgs is None or return_features:
            return query_patterns, query_attn
        
        support_features = self.backbone(support_imgs)
        support_patterns, _ = self.pattern_extractor(support_features)
        
        # Average across patterns to get one feature vector per image
        query_features = query_patterns.mean(1)      # [Q, hidden_dim]
        support_features = support_patterns.mean(1)  # [S, hidden_dim]

        # Note: scores has one column per support image, i.e. shape [Q, S];
        # support features are not averaged per class here.
        scores = self.matching_module(query_features, support_features)
        
        return scores


In [None]:
"""
ModelTrainer class for training and validating a model with support and query images.

Attributes:
    model (torch.nn.Module): The model to be trained.
    train_loader (DataLoader): DataLoader for the training data.
    val_loader (DataLoader): DataLoader for the validation data.
    optimizer (torch.optim.Optimizer): Optimizer for training the model.
    device (torch.device): Device to run the model on (e.g., 'cpu' or 'cuda').
    checkpoint_dir (str): Directory to save model checkpoints.
    best_accuracy (float): Best validation accuracy achieved.
    best_epoch (int): Epoch at which the best validation accuracy was achieved.

Methods:
    train_episode(support_images, support_labels, query_images, query_labels):
        Trains the model for one episode using support and query images and labels.
        
    validate_episode(support_images, support_labels, query_images, query_labels):
        Validates the model for one episode using support and query images and labels.
        
    train_epoch(epoch):
        Trains the model for one epoch and returns the average loss and accuracy.
        
    validate():
        Validates the model on the validation set and returns the average loss and accuracy.
        
    save_checkpoint(epoch, accuracy):
        Saves a checkpoint of the model at the given epoch with the given accuracy.
        
    _adjust_learning_rates():
        Adjusts the learning rates for different parts of the model.
        
    train_phase(num_epochs, phase_name="Training"):
        Trains the model for a specified number of epochs in a given phase.
        
    train_full():
        Completes a two-phase training process: initial training and fine-tuning.
"""
class ModelTrainer:
    def __init__(self, model, train_loader, val_loader, optimizer, device,
                 checkpoint_dir='checkpoints'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        self.best_accuracy = 0.0
        self.best_epoch = 0
    
    # Run one training episode (a single optimization step)
    def train_episode(self, support_images, support_labels, query_images, query_labels):
        self.model.train()
        
        # Reshape tensors to correct dimensions
        support_images = support_images.squeeze(0)  # Remove extra batch dimension
        query_images = query_images.squeeze(0)
        support_labels = support_labels.squeeze(0)
        query_labels = query_labels.squeeze(0)
        
        # Move to device
        support_images = support_images.to(self.device)
        support_labels = support_labels.to(self.device)
        query_images = query_images.to(self.device)
        query_labels = query_labels.to(self.device)
        
        self.optimizer.zero_grad()
        scores = self.model(query_images, support_images)
        loss = F.cross_entropy(scores, query_labels)
        
        loss.backward()
        self.optimizer.step()
        
        predictions = scores.max(1)[1]
        accuracy = (predictions == query_labels).float().mean()
        
        return loss.item(), accuracy.item()
        
    
    def validate_episode(self, support_images, support_labels, query_images, query_labels):
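        self.model.eval()

        with torch.no_grad():
            # Remove the DataLoader batch dimension (batch_size=1), as in train_episode
            support_images = support_images.squeeze(0).to(self.device)
            support_labels = support_labels.squeeze(0).to(self.device)
            query_images = query_images.squeeze(0).to(self.device)
            query_labels = query_labels.squeeze(0).to(self.device)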
            
            scores = self.model(query_images, support_images)
            loss = F.cross_entropy(scores, query_labels)
            
            predictions = scores.max(1)[1]
            accuracy = (predictions == query_labels).float().mean()
        
        return loss.item(), accuracy.item()
    
    def train_epoch(self, epoch):
        total_loss = 0
        total_accuracy = 0
        
        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
        for batch_idx, (support_imgs, support_labels, query_imgs, query_labels) in pbar:
            loss, accuracy = self.train_episode(support_imgs, support_labels, query_imgs, query_labels)
            
            total_loss += loss
            total_accuracy += accuracy
            
            pbar.set_description(f'Epoch {epoch} | Loss: {loss:.4f} | Acc: {accuracy:.4f}')
        
        return total_loss / len(self.train_loader), total_accuracy / len(self.train_loader)
    
    def validate(self):
        total_loss = 0
        total_accuracy = 0
        
        for support_imgs, support_labels, query_imgs, query_labels in tqdm(self.val_loader):
            loss, accuracy = self.validate_episode(support_imgs, support_labels, query_imgs, query_labels)
            total_loss += loss
            total_accuracy += accuracy
        
        return total_loss / len(self.val_loader), total_accuracy / len(self.val_loader)
    
    def save_checkpoint(self, epoch, accuracy):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'accuracy': accuracy
        }
        
        path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save(checkpoint, path)
        
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_epoch = epoch
            best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)
    
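    def _adjust_learning_rates(self):
        # Compare parameters by identity: `p in <generator of tensors>` falls back to
        # element-wise tensor equality, which is not a valid membership test.
        slow_ids = {id(p) for p in self.model.backbone.parameters()}
        slow_ids |= {id(p) for p in self.model.pattern_extractor.parameters()}
        # The 1e-4 branch only takes effect if the optimizer was built with
        # separate parameter groups; with a single group every parameter gets 1e-5.
        for param_group in self.optimizer.param_groups:
            if any(id(p) in slow_ids for p in param_group['params']):
                param_group['lr'] = 1e-5
            else:
                param_group['lr'] = 1e-4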
    
    def train_phase(self, num_epochs, phase_name="Training"):
        for epoch in range(num_epochs):
            print(f"\n{phase_name} - Epoch {epoch+1}/{num_epochs}")
            
            train_loss, train_acc = self.train_epoch(epoch)
            print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
            
            val_loss, val_acc = self.validate()
            print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")
            
            self.save_checkpoint(epoch, val_acc)
            
            if phase_name == "Initial Training" and epoch == 40:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] *= 0.1
       
    def train_full(self):
        """Complete two-phase training process"""
        # Phase 1: Initial training
        print("Starting Phase 1: Initial Training")
        self.train_phase(num_epochs=150, phase_name="Initial Training")
        
        # Phase 2: Fine-tuning
        print("\nStarting Phase 2: Fine-tuning")
        self._adjust_learning_rates()
        self.train_phase(num_epochs=20, phase_name="Fine-tuning")
        
        print(f"\nTraining completed! Best accuracy: {self.best_accuracy:.4f} at epoch {self.best_epoch}")


In [None]:
def split_classes(root_dir, val_split=0.2, random_seed=42):
    """
    Split classes into training and validation sets.

    This function takes a root directory containing class folders, shuffles the class names, 
    and splits them into training and validation sets based on the specified validation split proportion.

    Args:
        root_dir (str): Path to the data directory containing class folders.
        val_split (float, optional): Proportion of classes to use for validation. Default is 0.2.
        random_seed (int, optional): Random seed for reproducibility. Default is 42.

    Returns:
        tuple: A tuple containing two lists:
            - train_classes (list): List of class names for training.
            - val_classes (list): List of class names for validation.
    """
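    # Reconstructed from the docstring (this step is missing in the source cell):
    # list the class folders, shuffle them reproducibly, and split by proportion.
    classes = sorted(
        d for d in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, d))
    )
    random.Random(random_seed).shuffle(classes)
    split_idx = int(len(classes) * (1 - val_split))

    train_classes = classes[:split_idx]
    val_classes = classes[split_idx:]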
    
    return train_classes, val_classes


In [None]:
class EpisodeDataset(Dataset):
    def __init__(self, root_dir, allowed_classes, transform=None, n_way=2, n_support=5, n_query=15):
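        """
        Dataset class for few-shot learning episodes.

        Args:
            root_dir (str): Root directory containing class folders.
            allowed_classes (list): List of class names this dataset can use.
            transform (callable, optional): Image transformations to be applied.
            n_way (int, optional): Number of classes per episode. Default is 2.
            n_support (int, optional): Number of support examples per class. Default is 5.
            n_query (int, optional): Number of query examples per class. Default is 15.

        Attributes:
            root_dir (str): Root directory containing class folders.
            transform (callable, optional): Image transformations to be applied.
            n_way (int): Number of classes per episode.
            n_support (int): Number of support examples per class.
            n_query (int): Number of query examples per class.
            classes (list): List of class names this dataset can use.
            class_to_idx (dict): Dictionary mapping class names to indices.
            images_by_class (dict): Dictionary mapping class names to lists of image paths.

        Methods:
            __len__(): Returns the number of episodes per epoch.
            __getitem__(idx): Generates one episode of data.
                Args:
                    idx (int): Index of the episode.
                Returns:
                    tuple: Tuple containing support images, support labels, query images, and query labels.
        """
        # The source cell is truncated between this docstring and the tail of
        # __getitem__. The code down to the first inner loop is a reconstruction
        # from the documented attributes, not the original implementation.
        self.root_dir = root_dir
        self.transform = transform
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.classes = list(allowed_classes)
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        self.images_by_class = {
            name: [os.path.join(root_dir, name, f)
                   for f in sorted(os.listdir(os.path.join(root_dir, name)))]
            for name in self.classes
        }
        self.episodes_per_epoch = 500  # assumption, from "500 episodes per epoch"

    def __len__(self):
        return self.episodes_per_epoch

    def __getitem__(self, idx):
        support_images, support_labels = [], []
        query_images, query_labels = [], []

        # Assumption: labels are episode-local indices in 0..n_way-1
        episode_classes = random.sample(self.classes, self.n_way)
        for label, class_name in enumerate(episode_classes):
            selected_images = random.sample(
                self.images_by_class[class_name], self.n_support + self.n_query)

            for img_path in selected_images[:self.n_support]:
                image = Image.open(img_path).convert('RGB')
                if self.transform:
                    image = self.transform(image)
                support_images.append(image)
                support_labels.append(label)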
            
            for img_path in selected_images[self.n_support:]:
                image = Image.open(img_path).convert('RGB')
                if self.transform:
                    image = self.transform(image)
                query_images.append(image)
                query_labels.append(label)
        
        support_images = torch.stack(support_images)
        support_labels = torch.tensor(support_labels)
        query_images = torch.stack(query_images)
        query_labels = torch.tensor(query_labels)
        
        return support_images, support_labels, query_images, query_labels


In [None]:
"""
Main function to set up and train a model using episodic training for few-shot learning.

Steps:
1. Set random seeds for reproducibility.
2. Set the device to GPU if available, otherwise CPU.
3. Define image transformations for data augmentation and normalization.
4. Split the dataset into training and validation classes.
5. Create datasets for training and validation using the specified class splits and transformations.
6. Create data loaders for the training and validation datasets.
7. Initialize the model and optimizer.
8. Create a ModelTrainer instance and start the training process.

Functions:
- split_classes(data_path): Splits the dataset into training and validation classes.
- EpisodeDataset: Custom dataset class for episodic training.
- MTUNetPlusPlus: Model architecture.
- AdaBelief: Optimizer.
- ModelTrainer: Class to handle the training process.

Usage:
Run the script to start the training process.
"""

def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Define transforms
    transform = transforms.Compose([
        transforms.Resize((80, 80)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Split classes into train and validation sets
    data_path = '/kaggle/input/ham10000-and-gan/synthetic_images'  # Update this to your actual data path
    train_classes, val_classes = split_classes(data_path)
    
    print(f"Number of training classes: {len(train_classes)}")
    print(f"Number of validation classes: {len(val_classes)}")
    
    # Create datasets with respective class splits
    train_dataset = EpisodeDataset(
        root_dir=data_path,
        allowed_classes=train_classes,
        transform=transform,
        n_way=2,
        n_support=5,
        n_query=15
    )
    
    val_dataset = EpisodeDataset(
        root_dir=data_path,
        allowed_classes=val_classes,
        transform=transform,
        n_way=2,
        n_support=5,
        n_query=15
    )
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    
    # Initialize model and optimizer
    model = MTUNetPlusPlus(hidden_dim=256).to(device)
    optimizer = AdaBelief(model.parameters(), lr=1e-4)
    
    trainer = ModelTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device
    )
    
    trainer.train_full()

if __name__ == '__main__':
    main()