# MTUNet++ Architecture Overview

## High-Level Architecture
1. Input Processing:
   - Query image and support set images are processed through a CNN backbone
   - A feature map $F_{map} \in \mathbb{R}^{a \times h \times w}$ is extracted for each image

2. Main Components:
   - CNN Backbone (Modified ResNet-18)
   - Pattern Extractor Module
   - Pairwise Matching Module (MLP-based)

## Pattern Extractor Module Architecture
1. Feature Processing:
   - 1×1 convolution layer followed by ReLU activation
   - Reduces the channel dimension from $a$ to $b$
   - Flattening: $F'_{map} \in \mathbb{R}^{b \times v}$, where $v = h \times w$
   - Addition of a learnable positional embedding $P_l$

2. Attention Mechanism:
   - Iterative process, repeated $R$ times
   - Uses a Gated Recurrent Unit with Skip Connections (GRUsc)
   - Pattern updates: $K^{(r+1)} = \mathrm{GRUsc}(W^{(r)}, K^{(r)})$
   - Attention maps computed by applying a normalization function $\varrho$ to the pattern-feature similarities

3. Pattern Processing:
   - Self-attention mechanism over spatial dimensions
   - Dot-product similarity calculation
   - Pattern refinement using GRUsc
   - Attention map adjustment using Hadamard product
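The source names GRUsc but does not give its equations. The NumPy sketch below assumes a standard GRU cell, following the gate equations of PyTorch's `nn.GRUCell`, applied to each of the $u$ pattern rows, and takes the skip connection to be a residual addition of the previous patterns $K^{(r)}$. The class name, initialization, input width, and residual form are assumptions, not the published module.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class GRUsc:
    """Hypothetical GRU cell with a residual skip, applied row-wise to patterns."""

    def __init__(self, in_dim, hidden_dim, seed=0):
        rng = np.random.default_rng(seed)
        s = 1.0 / np.sqrt(hidden_dim)
        # Weights for the reset (r), update (z) and candidate (n) gates, stacked.
        self.W_x = rng.uniform(-s, s, (in_dim, 3 * hidden_dim))
        self.W_h = rng.uniform(-s, s, (hidden_dim, 3 * hidden_dim))
        self.b_x = np.zeros(3 * hidden_dim)
        self.b_h = np.zeros(3 * hidden_dim)

    def __call__(self, x, h):
        # x: (u, in_dim) update input W^(r); h: (u, hidden_dim) patterns K^(r)
        xr, xz, xn = np.split(x @ self.W_x + self.b_x, 3, axis=1)
        hr, hz, hn = np.split(h @ self.W_h + self.b_h, 3, axis=1)
        r = sigmoid(xr + hr)              # reset gate
        z = sigmoid(xz + hz)              # update gate
        n = np.tanh(xn + r * hn)          # candidate state
        h_gru = (1.0 - z) * n + z * h     # standard GRU output
        return h_gru + h                  # assumed skip connection


u, in_dim, hidden = 7, 64, 256            # u and hidden from the config; in_dim hypothetical
cell = GRUsc(in_dim, hidden)
K = np.zeros((u, hidden))
K_next = cell(np.ones((u, in_dim)), K)
print(K_next.shape)  # (7, 256)
```

With the residual term, each iteration only has to learn a correction to the existing patterns rather than a full replacement.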

## Configuration Details

### CNN Backbone (ResNet-18 Modifications):
- Removed first two downsampling layers
- First conv layer kernel: 7×7 → 3×3
- Output channels: 512 (so $a = 512$)
- Parameters kept fixed while the pattern extractor is trained (the CNN is fine-tuned later at 10⁻⁵)
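For the 80×80 inputs used here, the backbone's output resolution follows from the standard convolution output-size formula. The sketch assumes that removing the first two downsampling layers means the stem convolution runs at stride 1 and the max-pool is dropped; that reading, and the helper names, are assumptions rather than details from the source.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a conv or pooling layer (floor convention)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_spatial(size, modified):
    """Output height/width of a ResNet-18 trunk for a square input."""
    if modified:
        # Assumed reading of the modifications: 3x3 stride-1 stem, no max-pool.
        size = conv_out(size, 3, 1, 1)
    else:
        size = conv_out(size, 7, 2, 3)   # standard 7x7 stem, stride 2
        size = conv_out(size, 3, 2, 1)   # standard 3x3 max-pool, stride 2
    # layer1 keeps the resolution; layer2-4 each open with a stride-2 3x3 conv.
    for _ in range(3):
        size = conv_out(size, 3, 2, 1)
    return size


print(resnet18_spatial(80, modified=False))  # 3
print(resnet18_spatial(80, modified=True))   # 10
```

Under this reading the feature map is 10×10 (so $v = 100$) instead of 3×3, which leaves many more spatial positions for the patterns to attend to.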

### Pattern Extractor Module:
- GRUsc hidden dimension: 256
- Update iterations ($R$): 3
- Number of patterns ($u$): 7
- Networks $g_q$ and $g_M$: 3 fully connected layers each, with ReLU activations
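The networks $g_q$ and $g_M$ map patterns and spatial features into a common space whose dot products give the pattern-position similarities. A NumPy sketch of such 3-layer stacks, assuming hidden widths of 256, a hypothetical output width of 128, a hypothetical feature width $b = 64$, and no ReLU after the final layer:

```python
import numpy as np


def make_mlp(dims, seed=0):
    """Stack of fully connected layers, e.g. dims=[256, 256, 256, 128] -> 3 layers.

    ReLU follows every layer except the last (an assumption).
    """
    rng = np.random.default_rng(seed)
    layers = [(rng.standard_normal((d_in, d_out)) / np.sqrt(d_in), np.zeros(d_out))
              for d_in, d_out in zip(dims[:-1], dims[1:])]

    def forward(x):
        for i, (weight, bias) in enumerate(layers):
            x = x @ weight + bias
            if i < len(layers) - 1:
                x = np.maximum(x, 0.0)   # ReLU between layers
        return x

    return forward


HIDDEN, U = 256, 7                       # GRUsc hidden dim and number of patterns u
B, V, KEY = 64, 100, 128                 # hypothetical b, positions v, output width
g_q = make_mlp([HIDDEN, 256, 256, KEY])          # patterns -> shared space
g_M = make_mlp([B, 256, 256, KEY], seed=1)       # per-position features -> shared space

patterns = np.random.default_rng(2).standard_normal((U, HIDDEN))
features = np.random.default_rng(3).standard_normal((V, B))  # F'_map, transposed
similarity = g_q(patterns) @ g_M(features).T
print(similarity.shape)  # (7, 100)
```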

### Training Configuration:
1. Initial Phase:
   - Learning rate: 10⁻⁴
   - Rate reduction: 10× at epoch 40
   - Total epochs: 150

2. Fine-tuning Phase:
   - CNN and pattern extractor learning rate: 10⁻⁵
   - 20 iterations
   - 500 episodes per epoch for 2-way tasks
   - Other components: Initial learning rate 10⁻⁴
   - Rate reduction: 10× at epoch 10
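Both phases use piecewise-constant schedules with a single 10× drop. A plain-Python sketch; it treats the 20 fine-tuning iterations as epochs and applies the epoch-10 drop only to the non-CNN components, both of which are assumptions about how the configuration above is meant. The dictionary keys are hypothetical.

```python
def step_lr(epoch, base_lr, milestone, factor=10.0):
    """Learning rate at a 0-indexed epoch with one 10x drop at `milestone`."""
    return base_lr / factor if epoch >= milestone else base_lr


# Initial phase: 150 epochs at 1e-4, divided by 10 from epoch 40 on.
initial = [step_lr(e, 1e-4, 40) for e in range(150)]

# Fine-tuning phase (20 iterations treated as epochs, an assumption):
# CNN + pattern extractor stay at 1e-5; the other components start at 1e-4
# and drop 10x at epoch 10 (assumed to apply to them only).
finetune = [
    {"cnn_and_pattern_extractor": 1e-5, "other": step_lr(e, 1e-4, 10)}
    for e in range(20)
]

print(initial[39], initial[40])
print(finetune[9]["other"], finetune[10]["other"])
```

In PyTorch, `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40], gamma=0.1)` produces the initial-phase schedule; for fine-tuning, `LambdaLR` accepts one multiplier function per parameter group, so the CNN group can stay constant while the others decay.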

### Implementation Details:
- Framework: PyTorch
- Optimizer: AdaBelief
- Input image size: 80×80
- Data augmentation: Random flipping and affine transformations
- Evaluation: 2,000 episodes of 2-way classification
- Support images per class: N = 5 or 10
- Query images: 15 per class
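Each evaluation episode draws 2 classes, N support images per class, and 15 query images per class. A minimal sampler sketch (the function name and the dict-of-lists dataset format are hypothetical) that keeps support and query images disjoint within each class:

```python
import random


def sample_episode(images_by_class, n_way=2, n_shot=5, n_query=15, rng=random):
    """Sample one few-shot episode.

    images_by_class: dict mapping class name -> list of image ids
    returns (support, query): lists of (image_id, episode_label) pairs
    """
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # One draw per class so support and query images never overlap.
        picks = rng.sample(images_by_class[cls], n_shot + n_query)
        support += [(img, label) for img in picks[:n_shot]]
        query += [(img, label) for img in picks[n_shot:]]
    return support, query


data = {c: [f"{c}_{i}" for i in range(30)] for c in ["cat", "dog", "fox"]}
support, query = sample_episode(data, n_shot=5, rng=random.Random(0))
print(len(support), len(query))  # 10 30
```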

## Training Process Flow
1. Task-based training of the backbone CNN
2. Independent training of the attention module (pattern extractor)
3. Training of the few-shot classifier
4. Model selection based on validation performance over 2,000 episodes
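Step 4 amounts to keeping the checkpoint with the best mean accuracy over the validation episodes. A plain-Python sketch; the function name and the per-episode accuracy format are hypothetical, and the numbers are toy values.

```python
from statistics import mean


def select_best_checkpoint(val_accuracies):
    """Return the checkpoint name with the highest mean validation accuracy.

    val_accuracies maps checkpoint name -> list of per-episode accuracies
    (one entry per validation episode, e.g. 2,000 entries).
    """
    return max(val_accuracies, key=lambda name: mean(val_accuracies[name]))


toy = {"epoch_40": [0.80, 0.90, 0.70], "epoch_80": [0.90, 0.90, 0.80]}
print(select_best_checkpoint(toy))  # epoch_80
```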

## Mathematical Formulations
1. Feature Extraction:
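   - $F_{map} = f_\phi(x) \in \mathbb{R}^{a \times h \times w}$

2. Pattern Attention:
   - $\mathrm{Att} = f_{pe}(F_{map}) \in \mathbb{R}^{u \times v}$

3. Similarity Scoring:
   - $\mathrm{score}(O_q, O_m) = \sigma(f_\theta([O_q, O_m]))$

4. Classification:
   - $m^* = \arg\max_m \mathrm{score}(O_q, O_m)$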

## Architecture Diagram

```mermaid
flowchart TD
    subgraph Input
        QI[Query Image]
        SS[Support Set Images]
    end

    subgraph CNN["CNN Backbone (ResNet-18)"]
        F1[Feature Extraction]
    end

    subgraph PE["Pattern Extractor Module"]
        C1[1x1 Conv + ReLU]
        FT[Flatten Operation]
        PE1[Positional Embedding]
        AT1[Self-Attention]
        GRU[GRU with Skip Connections]
        AGG[Attention Aggregation]
        AP[Average Pooling]
    end

    subgraph PM["Pairwise Matching"]
        CON[Feature Concatenation]
        MLP[Multi-Layer Perceptron]
        SC[Similarity Score]
    end

    %% Main flow
    QI --> F1
    SS --> F1
    F1 -->|"Fmap ∈ ℝ^(a×h×w)"| C1
    C1 -->|"Reduced dim (a → b)"| FT
    FT -->|"Fmap' ∈ ℝ^(b×v)"| PE1
    PE1 --> AT1
    AT1 -->|"Att'"| GRU
    GRU -->|"Final attention"| AGG
    AGG -->|"Att''"| AP
    AP -->|"O"| CON
    CON --> MLP
    MLP -->|"score"| SC

    %% Iterative loop: updated patterns return to the attention step, R times in total
    GRU -->|"K(r+1), R iterations"| AT1
```

# MTUNet++ Data Flow Process

## 1. Input Processing
- **Query Image ($x_q$)**: Single image for classification
- **Support Set ($D_s$)**: Set of labeled images for comparison
  - $M$ classes with $N$ images per class
  - Total: $M \times N$ support images

## 2. Feature Extraction (CNN Stage)
1. **Input → Feature Maps**
   - CNN processes both query and support images
   - Output: $F_{map} = f_\phi(x) \in \mathbb{R}^{a \times h \times w}$
   - Uses modified ResNet-18 backbone

## 3. Pattern Extractor Flow
1. **Dimensionality Reduction**
   - Input: $F_{map} \in \mathbb{R}^{a \times h \times w}$
   - 1×1 convolution + ReLU
   - Output: channel dimension reduced from $a$ to $b$

2. **Spatial Processing**
   - Flatten operation: $F'_{map} \in \mathbb{R}^{b \times v}$, where $v = h \times w$
   - Add positional embedding: $F'_{map} \leftarrow F'_{map} + P_l$

3. **Attention Mechanism (Iterative Process)**
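   - Input: flattened features $F'_{map}$ (with positional embedding)
   - Pattern Generation:
     1. Calculate similarities: $\mathrm{Att}'^{(r)} = g_q(K^{(r)})\, g_M(F'_{map})$
     2. Apply normalization: $\mathrm{Att}^{(r)} = \varrho(\mathrm{Att}'^{(r)})$
     3. Update patterns: $K^{(r+1)} = \mathrm{GRUsc}(W^{(r)}, K^{(r)})$
   - Iterations: $R$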
   - Output: Final attention maps

4. **Feature Aggregation**
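   - Aggregate attention: $\mathrm{Att}'' = \frac{1}{u} \sum_{k=1}^{u} \mathrm{Att}_k$, the mean of the $u$ final attention maps, reshaped to $h \times w$
   - Average pooling: $O = \frac{1}{h w} \sum_{i=1}^{h} \sum_{j=1}^{w} \mathrm{Att}''_{ij} \, F_{map,ij}$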
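Steps 1 to 4 can be put together in a short NumPy sketch. To stay compact it makes several assumptions the source does not state: $\varrho$ is a sigmoid, $W^{(r)}$ is the attention-weighted sum of features (in the spirit of slot attention), $g_q$ and $g_M$ are single linear maps, and GRUsc is replaced by a residual stand-in that keeps the patterns at width $b$. All parameter names and the sizes $b$ and $c$ are hypothetical. Pooling uses the backbone map $F_{map}$, as step 4 is written, so $O$ has $a$ channels here; pooling $F'_{map}$ instead would give a $b$-dimensional $O$.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def pattern_extractor(fmap, params, R=3):
    """Hypothetical pattern-extractor forward pass (steps 1-4).

    fmap:   backbone feature map F_map, shape (a, h, w)
    params: illustrative weights; every key name here is hypothetical
    returns O (shape (a,)) and the final attention maps (shape (u, v))
    """
    assert R >= 1
    a, h, w = fmap.shape
    # 1. A 1x1 conv is a (b, a) matrix applied at every position, then ReLU.
    flat = fmap.reshape(a, h * w)                                   # (a, v)
    reduced = np.maximum(params["conv_w"] @ flat + params["conv_b"][:, None], 0.0)
    # 2. Already flattened to (b, v); add the positional embedding P_l.
    feats = reduced + params["pos_embed"]                           # F'_map
    # 3. Iterative attention; single linear maps stand in for g_q and g_M.
    K = params["K0"]                                                # (u, b)
    keys = feats.T @ params["g_M"]                                  # (v, c)
    for _ in range(R):
        att = sigmoid((K @ params["g_q"]) @ keys.T)                 # (u, v), rho assumed sigmoid
        W = att @ feats.T                                           # assumed W^(r), (u, b)
        K = K + params["step"] * W                                  # residual stand-in for GRUsc
    # 4. Mean over the u maps, then attention-weighted average pooling.
    att_mean = att.mean(axis=0).reshape(h, w)                       # Att''
    O = (fmap * att_mean).sum(axis=(1, 2)) / (h * w)                # (a,)
    return O, att


rng = np.random.default_rng(0)
a, b, c, u, h, w = 512, 64, 32, 7, 10, 10       # b and c are hypothetical sizes
params = {
    "conv_w": rng.standard_normal((b, a)) / np.sqrt(a),
    "conv_b": np.zeros(b),
    "pos_embed": 0.02 * rng.standard_normal((b, h * w)),
    "K0": 0.02 * rng.standard_normal((u, b)),
    "g_q": rng.standard_normal((b, c)) / np.sqrt(b),
    "g_M": rng.standard_normal((b, c)) / np.sqrt(b),
    "step": 0.01,
}
O, att = pattern_extractor(rng.standard_normal((a, h, w)), params)
print(O.shape, att.shape)  # (512,) (7, 100)
```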

## 4. Pairwise Matching Flow
1. **Feature Processing**
   - Query features: $O_q$
   - Support features: $O_m$ for class $m$ (averaged over its $N$ support images if $N > 1$)

2. **Similarity Computation**
   - Concatenate features: $[O_q, O_m]$
   - MLP processing: $f_\theta([O_q, O_m])$
   - Output: similarity score $\sigma(f_\theta([O_q, O_m]))$

3. **Classification**
   - Compare scores across all support classes
   - Select class with highest similarity score
   - Final output: predicted class $m^*$
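The matching and classification steps reduce to a few lines. This NumPy sketch assumes σ is the logistic sigmoid and uses a hand-written, distance-based function in place of the learned MLP $f_\theta$ so that the toy example is deterministic; all names are illustrative.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def classify(O_q, support_feats, f_theta):
    """Pairwise matching: score the query against every class, return argmax.

    O_q:           query feature, shape (d,)
    support_feats: shape (M, N, d), N support features per class
    f_theta:       maps a concatenated (2d,) vector to a scalar logit
    """
    O_m = support_feats.mean(axis=1)                       # average N shots -> (M, d)
    scores = np.array([sigmoid(f_theta(np.concatenate([O_q, o]))) for o in O_m])
    return int(np.argmax(scores)), scores                  # m*, one score per class


d = 4
rng = np.random.default_rng(0)
prototypes = np.stack([np.zeros(d), np.ones(d)])           # two toy classes
support = prototypes[:, None, :] + 0.01 * rng.standard_normal((2, 5, d))


def stand_in_f_theta(z):
    # NOT the learned MLP: negative squared distance between the two halves.
    return -np.sum((z[:d] - z[d:]) ** 2)


m_star, scores = classify(np.ones(d), support, stand_in_f_theta)
print(m_star)  # 1
```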

## Data Dimensions at Key Points
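1. Initial Features: $\mathbb{R}^{a \times h \times w}$
2. Reduced Features: $\mathbb{R}^{b \times v}$
3. Attention Maps: $\mathbb{R}^{u \times v}$
4. Final Features: $\mathbb{R}^{b}$
5. Similarity Scores: $\mathbb{R}^{M}$ ($M$ = number of classes)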

## Key Transformations
1. **Spatial → Pattern Space**
   - Feature maps → Pattern attention
   - Dimension: $(a \times h \times w) \rightarrow (u \times v)$

2. **Pattern → Classification Space**
   - Pattern features → Similarity scores
   - Dimension: $(b) \rightarrow (M)$