In [1]:
## classique 

import numpy as np
import torch 
import h5py # pour gérer les formats de données utilisés ici 
import torch
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

## handmade
import train
import test

On commence par charger les données du split `enroll` (via `test.load_data`), que l'on utilisera plus tard avec le modèle de few-shot learning présenté ci-dessous.

In [None]:
# Load the "enroll" split through the project helper `test.load_data`.
# It unpacks into (x, y, snr) -- presumably inputs, labels and per-sample SNR; TODO confirm in test.py.
# NOTE(review): no visible cell below uses x / y / snr (training is driven only by `opt`
# passed to train.main), and this cell was not executed in the saved run (`In [None]`).
x, y, snr = test.load_data(type="enroll")


### Few-Shot Classification

In few-shot classification we are given a small support set of $ N $ labeled examples $ S = \{(x_1, y_1), \dots, (x_N, y_N)\} $, where each $ x_i \in \mathbb{R}^D $ is the $ D $-dimensional feature vector of an example and $ y_i \in \{1, \dots, K\} $ is the corresponding label. $ S_k $ denotes the set of examples labeled with class $ k $.

### Prototypical Networks classification

Prototypical networks compute an $ M $-dimensional representation $ c_k \in \mathbb{R}^M $, or prototype, of each class through an embedding function $ f_\phi : \mathbb{R}^D \to \mathbb{R}^M $ with learnable parameters $ \phi $. Each prototype is the mean vector of the embedded support points belonging to its class:

$$
c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i) \tag{1}
$$

Given a distance function $ d : \mathbb{R}^M \times \mathbb{R}^M \to [0, +\infty) $, prototypical networks produce a distribution over classes for a query point $ x $ based on a softmax over distances to the prototypes in the embedding space:

$$
p_\phi(y = k \mid x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_{k'} \exp(-d(f_\phi(x), c_{k'}))} \tag{2}
$$

Learning proceeds by minimizing, via stochastic gradient descent (SGD), the negative log-probability of the true class $ k $:

$$
J(\phi) = -\log p_\phi(y = k \mid x)
$$

Training episodes are formed by randomly selecting a subset of classes from the training set, then choosing a subset of examples within each class to act as the support set and a subset of the remainder to serve as query points. 


In [4]:
# Hyper-parameters for episodic (few-shot) training, passed to train.main(opt) below.
# Each training episode samples `classes_per_it_tr` classes, with `num_support_tr`
# support and `num_query_tr` query examples per class; `iterations` episodes form one epoch.
opt = {
    'train_dataset_path': 'train.hdf5',  # Path to the training dataset (HDF5 file)
    'epochs': 50,  # Number of epochs for training
    'learning_rate': 0.001,  # Learning rate for the optimizer
    'lr_scheduler_step': 20,  # Step size for the learning rate scheduler
    'lr_scheduler_gamma': 0.5,  # Multiplicative factor for learning rate decay
    'iterations': 100,  # Number of episodes per epoch
    'classes_per_it_tr': 5,  # Number of classes per training iteration
    'num_support_tr': 20,  # Number of support samples per class for training
    'num_query_tr': 20,  # Number of query samples per class for training
    'classes_per_it_val': 5,  # Number of classes per validation iteration
    'num_support_val': 5,  # Number of support samples per class for validation
    'num_query_val': 15,  # Number of query samples per class for validation
    'manual_seed': 7,  # Manual seed for reproducibility
    # Torch device string used for computation. Prefer the first GPU, but fall back to
    # the CPU so that Restart & Run All still works on a kernel without CUDA.
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
}

In [5]:
# Run episodic prototypical-network training with the options defined in `opt`.
# train.main returns a 4-tuple, unpacked here as (model, best_acc, train_loss, train_acc);
# presumably train_loss / train_acc are per-episode or per-epoch histories -- TODO confirm in train.py.
model, best_acc, train_loss, train_acc = train.main(opt)

=== Epoch: 0 ===


100%|██████████| 100/100 [00:37<00:00,  2.70it/s]


Avg Train Loss: 1.1148308420181274, Avg Train Acc: 0.4220999984443188
=== Epoch: 1 ===


100%|██████████| 100/100 [00:37<00:00,  2.65it/s]


Avg Train Loss: 0.7292481771111489, Avg Train Acc: 0.6186999994516372
=== Epoch: 2 ===


100%|██████████| 100/100 [00:36<00:00,  2.72it/s]


Avg Train Loss: 0.6008706933259964, Avg Train Acc: 0.670800002515316
=== Epoch: 3 ===


100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


Avg Train Loss: 0.5666881918907165, Avg Train Acc: 0.6677999982237816
=== Epoch: 4 ===


100%|██████████| 100/100 [00:38<00:00,  2.60it/s]


Avg Train Loss: 0.5519236138463021, Avg Train Acc: 0.6754999977350234
=== Epoch: 5 ===


100%|██████████| 100/100 [00:36<00:00,  2.71it/s]


Avg Train Loss: 0.5675782862305642, Avg Train Acc: 0.6634000021219254
=== Epoch: 6 ===


100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


Avg Train Loss: 0.5833797383308411, Avg Train Acc: 0.6606999981403351
=== Epoch: 7 ===


100%|██████████| 100/100 [00:36<00:00,  2.71it/s]


Avg Train Loss: 0.5440657445788384, Avg Train Acc: 0.6737000012397766
=== Epoch: 8 ===


100%|██████████| 100/100 [00:37<00:00,  2.64it/s]


Avg Train Loss: 0.5599497389793396, Avg Train Acc: 0.6581000021100044
=== Epoch: 9 ===


100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


Avg Train Loss: 0.5553716641664505, Avg Train Acc: 0.6569000008702278
=== Epoch: 10 ===


100%|██████████| 100/100 [00:37<00:00,  2.69it/s]


Avg Train Loss: 0.5405883860588073, Avg Train Acc: 0.6740999987721443
=== Epoch: 11 ===


100%|██████████| 100/100 [00:36<00:00,  2.75it/s]


Avg Train Loss: 0.5404969346523285, Avg Train Acc: 0.6674000036716461
=== Epoch: 12 ===


100%|██████████| 100/100 [00:37<00:00,  2.68it/s]


Avg Train Loss: 0.5496581789851188, Avg Train Acc: 0.6666999983787537
=== Epoch: 13 ===


100%|██████████| 100/100 [00:37<00:00,  2.64it/s]


Avg Train Loss: 0.5458073112368583, Avg Train Acc: 0.6735000011324882
=== Epoch: 14 ===


100%|██████████| 100/100 [00:37<00:00,  2.69it/s]


Avg Train Loss: 0.5320988896489144, Avg Train Acc: 0.676099998652935
=== Epoch: 15 ===


100%|██████████| 100/100 [00:36<00:00,  2.73it/s]


Avg Train Loss: 0.5401873514056206, Avg Train Acc: 0.6643999961018562
=== Epoch: 16 ===


100%|██████████| 100/100 [00:37<00:00,  2.65it/s]


Avg Train Loss: 0.527891592681408, Avg Train Acc: 0.6817000016570092
=== Epoch: 17 ===


100%|██████████| 100/100 [00:36<00:00,  2.71it/s]


Avg Train Loss: 0.5368800872564315, Avg Train Acc: 0.6790000009536743
=== Epoch: 18 ===


100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


Avg Train Loss: 0.5061358344554902, Avg Train Acc: 0.6876999998092651
=== Epoch: 19 ===


100%|██████████| 100/100 [00:37<00:00,  2.70it/s]


Avg Train Loss: 0.5215732455253601, Avg Train Acc: 0.6780999994277954
=== Epoch: 20 ===


100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


Avg Train Loss: 0.502420184314251, Avg Train Acc: 0.7003999984264374
=== Epoch: 21 ===


100%|██████████| 100/100 [00:36<00:00,  2.70it/s]


Avg Train Loss: 0.5092365726828575, Avg Train Acc: 0.6984999978542328
=== Epoch: 22 ===


100%|██████████| 100/100 [00:38<00:00,  2.62it/s]


Avg Train Loss: 0.48342092722654345, Avg Train Acc: 0.7255999994277954
=== Epoch: 23 ===


100%|██████████| 100/100 [00:37<00:00,  2.69it/s]


Avg Train Loss: 0.4601048603653908, Avg Train Acc: 0.7387999987602234
=== Epoch: 24 ===


100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


Avg Train Loss: 0.4814364966750145, Avg Train Acc: 0.7320000022649765
=== Epoch: 25 ===


100%|██████████| 100/100 [00:37<00:00,  2.70it/s]


Avg Train Loss: 0.4444878526031971, Avg Train Acc: 0.7650000005960464
=== Epoch: 26 ===


100%|██████████| 100/100 [00:37<00:00,  2.68it/s]


Avg Train Loss: 0.4518947121500969, Avg Train Acc: 0.7603000003099442
=== Epoch: 27 ===


100%|██████████| 100/100 [00:36<00:00,  2.75it/s]


Avg Train Loss: 0.4136989685893059, Avg Train Acc: 0.7787000018358231
=== Epoch: 28 ===


100%|██████████| 100/100 [00:37<00:00,  2.70it/s]


Avg Train Loss: 0.41361864551901817, Avg Train Acc: 0.7703000009059906
=== Epoch: 29 ===


100%|██████████| 100/100 [00:36<00:00,  2.73it/s]


Avg Train Loss: 0.41889304041862485, Avg Train Acc: 0.7658000016212463
=== Epoch: 30 ===


100%|██████████| 100/100 [00:36<00:00,  2.72it/s]


Avg Train Loss: 0.4107272346317768, Avg Train Acc: 0.7718000018596649
=== Epoch: 31 ===


100%|██████████| 100/100 [00:36<00:00,  2.72it/s]


Avg Train Loss: 0.4025590988993645, Avg Train Acc: 0.7700000029802322
=== Epoch: 32 ===


100%|██████████| 100/100 [00:37<00:00,  2.66it/s]


Avg Train Loss: 0.3872591580450535, Avg Train Acc: 0.7825000041723251
=== Epoch: 33 ===


100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


Avg Train Loss: 0.4174938052892685, Avg Train Acc: 0.7669000011682511
=== Epoch: 34 ===


100%|██████████| 100/100 [00:36<00:00,  2.70it/s]


Avg Train Loss: 0.3715371511876583, Avg Train Acc: 0.7899000012874603
=== Epoch: 35 ===


100%|██████████| 100/100 [00:38<00:00,  2.63it/s]


Avg Train Loss: 0.3908589242398739, Avg Train Acc: 0.7768000024557113
=== Epoch: 36 ===


100%|██████████| 100/100 [00:36<00:00,  2.73it/s]


Avg Train Loss: 0.39036192908883094, Avg Train Acc: 0.7824000012874603
=== Epoch: 37 ===


100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


Avg Train Loss: 0.39052356496453283, Avg Train Acc: 0.7763000005483627
=== Epoch: 38 ===


100%|██████████| 100/100 [00:37<00:00,  2.65it/s]


Avg Train Loss: 0.4012675406038761, Avg Train Acc: 0.7749999988079072
=== Epoch: 39 ===


100%|██████████| 100/100 [00:36<00:00,  2.70it/s]


Avg Train Loss: 0.3687115746736527, Avg Train Acc: 0.7902000015974044
=== Epoch: 40 ===


100%|██████████| 100/100 [00:35<00:00,  2.80it/s]


Avg Train Loss: 0.34435475662350656, Avg Train Acc: 0.8048999989032746
=== Epoch: 41 ===


100%|██████████| 100/100 [00:37<00:00,  2.69it/s]


Avg Train Loss: 0.3692013287544251, Avg Train Acc: 0.7901000010967255
=== Epoch: 42 ===


100%|██████████| 100/100 [00:37<00:00,  2.69it/s]


Avg Train Loss: 0.3696120096743107, Avg Train Acc: 0.791599999666214
=== Epoch: 43 ===


100%|██████████| 100/100 [00:38<00:00,  2.61it/s]


Avg Train Loss: 0.3705327617377043, Avg Train Acc: 0.7886000037193298
=== Epoch: 44 ===


100%|██████████| 100/100 [00:38<00:00,  2.62it/s]


Avg Train Loss: 0.3393215827643871, Avg Train Acc: 0.8112000006437302
=== Epoch: 45 ===


100%|██████████| 100/100 [00:38<00:00,  2.60it/s]


Avg Train Loss: 0.3510974821448326, Avg Train Acc: 0.7952999991178512
=== Epoch: 46 ===


100%|██████████| 100/100 [00:38<00:00,  2.62it/s]


Avg Train Loss: 0.3619718247652054, Avg Train Acc: 0.7953999984264374
=== Epoch: 47 ===


100%|██████████| 100/100 [00:38<00:00,  2.60it/s]


Avg Train Loss: 0.35193532906472685, Avg Train Acc: 0.7949000018835067
=== Epoch: 48 ===


100%|██████████| 100/100 [00:55<00:00,  1.81it/s]


Avg Train Loss: 0.34706170186400415, Avg Train Acc: 0.8003999996185303
=== Epoch: 49 ===


100%|██████████| 100/100 [00:45<00:00,  2.21it/s]


Avg Train Loss: 0.3588626051694155, Avg Train Acc: 0.7949000006914139
Final model saved to /home/onyxia/work/PrototypicalFewShots-4/last_model.pth
