In [1]:
## classique 

import numpy as np
import torch 
import h5py # pour gérer les formats de données utilisés ici 
import torch
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

## handmade
import train
import test

On commence par charger les données d'entraînement, que l'on va plus tard utiliser pour entraîner le modèle de few-shot learning en question. 

In [None]:
# Load the "enroll" split through the project helper `test.load_data`. It returns
# three values, unpacked as `x`, `y` and `snr` (presumably inputs, labels and
# per-sample signal-to-noise ratios — TODO confirm against test.py).
# NOTE(review): these arrays are not passed to `train.main(opt)`, which only receives
# the `opt` dict (dataset path 'train.hdf5'); confirm where they are consumed.
# NOTE(review): `test` here is the project's local module, which shadows Python's
# stdlib `test` package; this relies on the local module being found first on sys.path.
(x, y, snr) = test.load_data(
    type="enroll",
)


### Few-Shot Classification

In few-shot classification, we are given a small support set of $ N $ labeled examples $ S = \{(x_1, y_1), \dots, (x_N, y_N)\} $, where each $ x_i \in \mathbb{R}^D $ is the $ D $-dimensional feature vector of an example and $ y_i \in \{1, \dots, K\} $ is the corresponding label. $ S_k $ denotes the set of examples labeled with class $ k $.

### Prototypical Networks classification

Prototypical networks compute an $ M $-dimensional representation $ c_k \in \mathbb{R}^M $, or prototype, of each class through an embedding function $ f_\phi : \mathbb{R}^D \to \mathbb{R}^M $ with learnable parameters $ \phi $. Each prototype is the mean vector of the embedded support points belonging to its class:

$$
c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i) \tag{1}
$$

Given a distance function $ d : \mathbb{R}^M \times \mathbb{R}^M \to [0, +\infty) $, prototypical networks produce a distribution over classes for a query point $ x $ based on a softmax over distances to the prototypes in the embedding space:

$$
p_\phi(y = k \mid x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_{k'} \exp(-d(f_\phi(x), c_{k'}))} \tag{2}
$$

Learning proceeds by minimizing, via stochastic gradient descent (SGD), the negative log-probability of the true class $ k $:

$$
J(\phi) = -\log p_\phi(y = k \mid x) \tag{3}
$$

Training episodes are formed by randomly selecting a subset of classes from the training set, then choosing a subset of examples within each class to act as the support set and a subset of the remainder to serve as query points. 


In [2]:
# Hyper-parameters for episodic prototypical-network training, passed as one dict
# to `train.main`. Each training episode draws `classes_per_it_tr` classes with
# `num_support_tr` support and `num_query_tr` query samples per class;
# `iterations` episodes make up one epoch.
# NOTE(review): the *_val keys presumably configure validation episodes, but the
# recorded training log shows no validation output — confirm train.main uses them.
opt = {
    'train_dataset_path': 'train.hdf5',  # Path to the training dataset (HDF5 file)
    'epochs': 50,  # Number of epochs for training
    'learning_rate': 0.001,  # Learning rate for the optimizer
    'lr_scheduler_step': 20,  # Step size for the learning rate scheduler
    'lr_scheduler_gamma': 0.5,  # Multiplicative factor for learning rate decay
    'iterations': 100,  # Number of episodes per epoch
    'classes_per_it_tr': 5,  # Number of classes per training iteration
    'num_support_tr': 30,  # Number of support samples per class for training
    'num_query_tr': 30,  # Number of query samples per class for training
    'classes_per_it_val': 5,  # Number of classes per validation iteration
    'num_support_val': 5,  # Number of support samples per class for validation
    'num_query_val': 15,  # Number of query samples per class for validation
    'manual_seed': 7,  # Manual seed for reproducibility
    # Torch device string used for computation. Fall back to the CPU so the notebook
    # still runs (Restart & Run All) on machines without a CUDA GPU.
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
}

In [3]:
# Run episodic training configured by `opt`. `train.main` returns four values: the
# model, a best-accuracy value and the training loss / accuracy (presumably per-epoch
# histories — TODO confirm in train.py). It logs the average train loss and accuracy
# for every epoch, then the path of the saved final checkpoint.
# NOTE(review): in the recorded run this takes roughly 35-50 minutes (50 epochs at
# ~40-60 s each); consider a cached/guarded path before re-running the whole notebook.
(model, best_acc,
 train_loss, train_acc) = train.main(opt)

=== Epoch: 0 ===


100%|██████████| 100/100 [00:43<00:00,  2.32it/s]


Avg Train Loss: 1.0285357028245925, Avg Train Acc: 0.47293333321809766
=== Epoch: 1 ===


100%|██████████| 100/100 [00:42<00:00,  2.38it/s]


Avg Train Loss: 0.6827497184276581, Avg Train Acc: 0.6455333346128463
=== Epoch: 2 ===


100%|██████████| 100/100 [00:41<00:00,  2.41it/s]


Avg Train Loss: 0.564268915951252, Avg Train Acc: 0.6633333340287209
=== Epoch: 3 ===


100%|██████████| 100/100 [00:42<00:00,  2.35it/s]


Avg Train Loss: 0.5468684497475624, Avg Train Acc: 0.6630666670203209
=== Epoch: 4 ===


100%|██████████| 100/100 [00:42<00:00,  2.35it/s]


Avg Train Loss: 0.5513446366786957, Avg Train Acc: 0.6607999986410141
=== Epoch: 5 ===


100%|██████████| 100/100 [00:43<00:00,  2.30it/s]


Avg Train Loss: 0.5701562124490738, Avg Train Acc: 0.652533331811428
=== Epoch: 6 ===


100%|██████████| 100/100 [00:41<00:00,  2.40it/s]


Avg Train Loss: 0.6177090102434158, Avg Train Acc: 0.6555999994277955
=== Epoch: 7 ===


100%|██████████| 100/100 [00:40<00:00,  2.45it/s]


Avg Train Loss: 0.5215285468101502, Avg Train Acc: 0.6791999995708465
=== Epoch: 8 ===


100%|██████████| 100/100 [00:41<00:00,  2.41it/s]


Avg Train Loss: 0.5257063579559326, Avg Train Acc: 0.6781333315372468
=== Epoch: 9 ===


100%|██████████| 100/100 [00:42<00:00,  2.34it/s]


Avg Train Loss: 0.5300084775686265, Avg Train Acc: 0.6778000018000603
=== Epoch: 10 ===


100%|██████████| 100/100 [00:42<00:00,  2.36it/s]


Avg Train Loss: 0.5103520885109901, Avg Train Acc: 0.6833333358168602
=== Epoch: 11 ===


100%|██████████| 100/100 [00:41<00:00,  2.39it/s]


Avg Train Loss: 0.5277219668030739, Avg Train Acc: 0.6749333328008652
=== Epoch: 12 ===


100%|██████████| 100/100 [00:42<00:00,  2.34it/s]


Avg Train Loss: 0.5217875558137893, Avg Train Acc: 0.6771999973058701
=== Epoch: 13 ===


100%|██████████| 100/100 [00:59<00:00,  1.68it/s]


Avg Train Loss: 0.5006496405601502, Avg Train Acc: 0.6960666662454605
=== Epoch: 14 ===


100%|██████████| 100/100 [00:49<00:00,  2.02it/s]


Avg Train Loss: 0.5187704727053642, Avg Train Acc: 0.6905333322286605
=== Epoch: 15 ===


100%|██████████| 100/100 [00:54<00:00,  1.82it/s]


Avg Train Loss: 0.4838590520620346, Avg Train Acc: 0.7211999994516373
=== Epoch: 16 ===


100%|██████████| 100/100 [00:48<00:00,  2.05it/s]


Avg Train Loss: 0.49815275490283967, Avg Train Acc: 0.7244000029563904
=== Epoch: 17 ===


100%|██████████| 100/100 [00:59<00:00,  1.67it/s]


Avg Train Loss: 0.46820365995168683, Avg Train Acc: 0.7502666717767715
=== Epoch: 18 ===


100%|██████████| 100/100 [00:49<00:00,  2.00it/s]


Avg Train Loss: 0.4422785377502441, Avg Train Acc: 0.7558000022172928
=== Epoch: 19 ===


100%|██████████| 100/100 [00:48<00:00,  2.07it/s]


Avg Train Loss: 0.41187009543180464, Avg Train Acc: 0.7676000010967254
=== Epoch: 20 ===


100%|██████████| 100/100 [00:53<00:00,  1.86it/s]


Avg Train Loss: 0.39184481278061867, Avg Train Acc: 0.7789333319664001
=== Epoch: 21 ===


100%|██████████| 100/100 [00:47<00:00,  2.11it/s]


Avg Train Loss: 0.4012125761806965, Avg Train Acc: 0.7633333319425583
=== Epoch: 22 ===


100%|██████████| 100/100 [00:41<00:00,  2.44it/s]


Avg Train Loss: 0.400803996771574, Avg Train Acc: 0.7730000001192093
=== Epoch: 23 ===


100%|██████████| 100/100 [00:42<00:00,  2.38it/s]


Avg Train Loss: 0.3734773500263691, Avg Train Acc: 0.7900666677951813
=== Epoch: 24 ===


100%|██████████| 100/100 [00:42<00:00,  2.37it/s]


Avg Train Loss: 0.3886082346737385, Avg Train Acc: 0.7762000006437302
=== Epoch: 25 ===


100%|██████████| 100/100 [00:41<00:00,  2.40it/s]


Avg Train Loss: 0.39055453687906266, Avg Train Acc: 0.7756666702032089
=== Epoch: 26 ===


100%|██████████| 100/100 [00:42<00:00,  2.36it/s]


Avg Train Loss: 0.3810998845100403, Avg Train Acc: 0.7840000003576278
=== Epoch: 27 ===


100%|██████████| 100/100 [00:43<00:00,  2.32it/s]


Avg Train Loss: 0.40343222975730897, Avg Train Acc: 0.7758666694164276
=== Epoch: 28 ===


100%|██████████| 100/100 [00:42<00:00,  2.34it/s]


Avg Train Loss: 0.3809508068859577, Avg Train Acc: 0.783333335518837
=== Epoch: 29 ===


100%|██████████| 100/100 [00:42<00:00,  2.34it/s]


Avg Train Loss: 0.3583733157813549, Avg Train Acc: 0.7930666655302048
=== Epoch: 30 ===


100%|██████████| 100/100 [00:41<00:00,  2.40it/s]


Avg Train Loss: 0.3616869989037514, Avg Train Acc: 0.7906666696071625
=== Epoch: 31 ===


100%|██████████| 100/100 [00:40<00:00,  2.46it/s]


Avg Train Loss: 0.37087161481380465, Avg Train Acc: 0.7890000009536743
=== Epoch: 32 ===


100%|██████████| 100/100 [00:41<00:00,  2.42it/s]


Avg Train Loss: 0.3664218567311764, Avg Train Acc: 0.7875333338975906
=== Epoch: 33 ===


100%|██████████| 100/100 [00:41<00:00,  2.42it/s]


Avg Train Loss: 0.37300139099359514, Avg Train Acc: 0.7809999978542328
=== Epoch: 34 ===


100%|██████████| 100/100 [00:40<00:00,  2.46it/s]


Avg Train Loss: 0.3618406347930431, Avg Train Acc: 0.7944666659832
=== Epoch: 35 ===


100%|██████████| 100/100 [00:40<00:00,  2.44it/s]


Avg Train Loss: 0.362326797246933, Avg Train Acc: 0.7920000028610229
=== Epoch: 36 ===


100%|██████████| 100/100 [00:42<00:00,  2.38it/s]


Avg Train Loss: 0.3898339001834393, Avg Train Acc: 0.778466666340828
=== Epoch: 37 ===


100%|██████████| 100/100 [00:42<00:00,  2.34it/s]


Avg Train Loss: 0.36028827920556067, Avg Train Acc: 0.7905333322286606
=== Epoch: 38 ===


100%|██████████| 100/100 [01:00<00:00,  1.65it/s]


Avg Train Loss: 0.3691503643989563, Avg Train Acc: 0.7840000015497207
=== Epoch: 39 ===


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


Avg Train Loss: 0.36287069275975226, Avg Train Acc: 0.7951333355903626
=== Epoch: 40 ===


100%|██████████| 100/100 [00:54<00:00,  1.82it/s]


Avg Train Loss: 0.3424629950523376, Avg Train Acc: 0.8050666683912278
=== Epoch: 41 ===


100%|██████████| 100/100 [00:45<00:00,  2.18it/s]


Avg Train Loss: 0.3474241590499878, Avg Train Acc: 0.8014666682481766
=== Epoch: 42 ===


100%|██████████| 100/100 [00:42<00:00,  2.38it/s]


Avg Train Loss: 0.3311369703710079, Avg Train Acc: 0.813866668343544
=== Epoch: 43 ===


100%|██████████| 100/100 [00:42<00:00,  2.34it/s]


Avg Train Loss: 0.3371964531391859, Avg Train Acc: 0.8058000028133392
=== Epoch: 44 ===


100%|██████████| 100/100 [00:41<00:00,  2.42it/s]


Avg Train Loss: 0.33845368564128875, Avg Train Acc: 0.8045333349704742
=== Epoch: 45 ===


100%|██████████| 100/100 [00:40<00:00,  2.46it/s]


Avg Train Loss: 0.3337714618444443, Avg Train Acc: 0.8052666699886322
=== Epoch: 46 ===


100%|██████████| 100/100 [00:41<00:00,  2.39it/s]


Avg Train Loss: 0.31568843707442285, Avg Train Acc: 0.8206666684150696
=== Epoch: 47 ===


100%|██████████| 100/100 [00:41<00:00,  2.39it/s]


Avg Train Loss: 0.33001531079411506, Avg Train Acc: 0.809600002169609
=== Epoch: 48 ===


100%|██████████| 100/100 [00:42<00:00,  2.33it/s]


Avg Train Loss: 0.3173763997107744, Avg Train Acc: 0.8175333344936371
=== Epoch: 49 ===


100%|██████████| 100/100 [00:41<00:00,  2.41it/s]


Avg Train Loss: 0.3311737933754921, Avg Train Acc: 0.8092666709423065
Final model saved to /home/onyxia/work/PrototypicalFewShots-4/last_model.pth
