In [9]:
# ===========================
# Import Libraries
# ===========================

# Numerical & Array Operations
import numpy as np

# Deep Learning & Torch Utilities
import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
import torchvision
from torchvision import datasets, transforms, models
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor

# Image Processing & Computer Vision
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Machine Learning & Evaluation Metrics
from sklearn.preprocessing import normalize
from sklearn.metrics import accuracy_score
from sklearn.svm import LinearSVC

# Utilities
from pathlib import Path
from tqdm import tqdm

In [2]:
# Charger le fichier .npz
data = np.load("dataset_fine_tuning.npz", allow_pickle=True)

# Récupérer les ensembles de données
X_train = data['X_train']
y_train = data['y_train']
X_test = data['X_test']
y_test = data['y_test']

# Vérifier les dimensions (facultatif)
print(f"X_train: {len(X_train)}, y_train: {len(y_train)}")
print(f"X_test: {len(X_test)}, y_test: {len(y_test)}")


X_train: 80000, y_train: 80000
X_test: 20000, y_test: 20000


In [3]:
# Définir une classe Dataset personnalisée
class CustomDataset(Dataset):
    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        image = self.X[idx]
        image = np.array(image)
        image = np.transpose(image, (1, 2, 0))
        y = self.y[idx]
        if self.transform is not None:
            augmentations = self.transform(image=image)
            image = augmentations["image"]

        return image, y


In [31]:
# Charger ResNet-34 pré-entraîné
model = models.resnet34(pretrained=False)

# Chemin vers votre fichier de poids
weights_path = "resnet34_cerberus_torchvision .pth"

# Charger les poids sauvegardés
state_dict = torch.load(weights_path)
model.load_state_dict(state_dict)

  state_dict = torch.load(weights_path)


<All keys matched successfully>

In [32]:
# Freeze all layers: disable gradient tracking
for p in model.parameters():
    p.requires_grad = False
# "Thaw" last layer (or whatever is relevant for you)
for p in model.layer1.parameters():
    p.requires_grad = True
for p in model.layer2.parameters():
    p.requires_grad = True
for p in model.layer3.parameters():
    p.requires_grad = True
for p in model.layer4.parameters():
    p.requires_grad = True
# Replace fully-connected part by some other classifier, e.g.

cnn_features = 512
num_classes = 6

model.fc =  nn.Sequential(
    nn.Linear(cnn_features, 100, bias=True),
    nn.ReLU(),
    nn.Linear(100, num_classes, bias=True),
)

In [33]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model=model.to(device)

Using device: cuda:3


In [5]:
def loss_accuracy(loss, Yhat, Y):


    L = loss(Yhat, Y)
    acc = (torch.sum(torch.argmax(Yhat, dim=1) == Y) / Y.size(0)) * 100


    return L, acc

In [14]:
# Évaluation en mini-batches, la RAM ne support pas le forward sur tout le test_loader
def evaluate_model(model, test_loader, loss_fn, device):
    model.eval()  # Mode évaluation
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():  # Pas de calcul des gradients
        for X_batch, Y_batch in test_loader:
            X_batch, Y_batch = X_batch.float().to(device), Y_batch.to(device)  # Déplacer sur GPU
            Yhat = model(X_batch)  # Prédictions
            loss = loss_fn(Yhat, Y_batch)  # Calcul de la perte
            total_loss += loss.item() * X_batch.size(0)  # Ajouter perte pondérée par la taille du batch
            total_correct += (torch.argmax(Yhat, dim=1) == Y_batch).sum().item()  # Compter les bons résultats
            total_samples += X_batch.size(0)  # Total des échantillons
    model.train()
    avg_loss = total_loss / total_samples  # Moyenne des pertes
    avg_accuracy = (total_correct / total_samples) * 100  # Précision en pourcentage
    return avg_loss, avg_accuracy

# Dataset of 100K (Cerberus was trained on a subset of this dataset)

In [10]:
# Transformer X_train, X_test en tensors pour l'évaluation complète
X_train = torch.stack([x for x in X_train])
X_test = torch.stack([x for x in X_test])
y_train = torch.tensor(y_train)
y_test = torch.tensor(y_test)

In [26]:
nx=1000
nh_1=100
nh_2=100
ny=10
eta=0.1
n_iter=10
loss = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.train()

# Planificateur StepLR : Réduction du LR tous les 10 epochs par un facteur 0.1
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
batch_size=32

# epoch
for iteration in tqdm(range(n_iter), desc="Epochs", unit="epoch"):
    
        train_transform = A.Compose([
            A.ElasticTransform(alpha=1.0, sigma=50, p=0.25),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.25),
            A.Rotate(limit=35, p=0.7),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.GaussianBlur(blur_limit=(3, 7), p=0.2),
            ToTensorV2(),
        ])
        
        val_transforms = A.Compose(
        [ 
            ToTensorV2(),
        ],
        )
        
        # Recréer les datasets PyTorch
        train_dataset = CustomDataset(X_train, y_train, train_transform)
        test_dataset = CustomDataset(X_test, y_test, val_transforms)
        
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
        
        
    
        for X_batch, Y_batch in tqdm(train_loader, desc="Batches", unit="batch", leave=False):
                X_batch, Y_batch = X_batch.to(device), Y_batch.to(device) 
                optimizer.zero_grad()
                Yhat = model(X_batch)  # Prédictions pour le batch actuel
                L, acc = loss_accuracy(loss, Yhat, Y_batch)  # Calcul de la perte et de la précision
                L.backward()
                optimizer.step()
        scheduler.step()
    
        
        Ltest, acctest = evaluate_model(model, test_loader, loss, device)
    
        title = "Iter {}:, acc test {:.1f}% ({:.2f})".format(iteration, acctest, Ltest)

        
    
        print(title)

Epochs:   0%|                                                                                 | 0/10 [00:00<?, ?epoch/s]
Batches:   0%|                                                                              | 0/2500 [00:00<?, ?batch/s][A
Batches:   0%|                                                                      | 1/2500 [00:00<06:18,  6.60batch/s][A
Batches:   0%|                                                                      | 2/2500 [00:00<05:09,  8.07batch/s][A
Batches:   0%|                                                                      | 3/2500 [00:00<05:07,  8.11batch/s][A
Batches:   0%|                                                                      | 4/2500 [00:00<04:44,  8.76batch/s][A
Batches:   0%|▏                                                                     | 6/2500 [00:00<04:30,  9.22batch/s][A
Batches:   0%|▏                                                                     | 7/2500 [00:00<04:33,  9.11batch/s][A
Batches:   

Batches:   8%|█████▌                                                              | 205/2500 [00:20<03:56,  9.71batch/s][A
Batches:   8%|█████▋                                                              | 207/2500 [00:20<03:52,  9.85batch/s][A
Batches:   8%|█████▋                                                              | 209/2500 [00:21<03:49,  9.98batch/s][A
Batches:   8%|█████▋                                                              | 210/2500 [00:21<03:54,  9.76batch/s][A
Batches:   8%|█████▋                                                              | 211/2500 [00:21<04:09,  9.18batch/s][A
Batches:   9%|█████▊                                                              | 213/2500 [00:21<04:00,  9.49batch/s][A
Batches:   9%|█████▊                                                              | 214/2500 [00:21<04:05,  9.33batch/s][A
Batches:   9%|█████▉                                                              | 216/2500 [00:21<03:57,  9.60batch/s][A
Batches:

Batches:  16%|██████████▊                                                         | 398/2500 [00:40<03:34,  9.78batch/s][A
Batches:  16%|██████████▊                                                         | 399/2500 [00:40<03:40,  9.54batch/s][A
Batches:  16%|██████████▉                                                         | 400/2500 [00:40<03:44,  9.36batch/s][A
Batches:  16%|██████████▉                                                         | 402/2500 [00:40<03:41,  9.46batch/s][A
Batches:  16%|██████████▉                                                         | 403/2500 [00:41<03:39,  9.54batch/s][A
Batches:  16%|██████████▉                                                         | 404/2500 [00:41<03:37,  9.64batch/s][A
Batches:  16%|███████████                                                         | 405/2500 [00:41<03:40,  9.50batch/s][A
Batches:  16%|███████████                                                         | 406/2500 [00:41<03:43,  9.35batch/s][A
Batches:

Batches:  24%|████████████████                                                    | 589/2500 [01:00<03:22,  9.42batch/s][A
Batches:  24%|████████████████                                                    | 590/2500 [01:00<03:24,  9.32batch/s][A
Batches:  24%|████████████████                                                    | 591/2500 [01:00<03:30,  9.08batch/s][A
Batches:  24%|████████████████                                                    | 592/2500 [01:00<03:29,  9.11batch/s][A
Batches:  24%|████████████████▏                                                   | 593/2500 [01:00<03:28,  9.13batch/s][A
Batches:  24%|████████████████▏                                                   | 594/2500 [01:00<03:26,  9.24batch/s][A
Batches:  24%|████████████████▏                                                   | 595/2500 [01:00<03:30,  9.06batch/s][A
Batches:  24%|████████████████▏                                                   | 597/2500 [01:00<03:18,  9.59batch/s][A
Batches:

Batches:  29%|███████████████████▉                                                | 733/2500 [01:15<03:18,  8.90batch/s][A
Batches:  29%|███████████████████▉                                                | 734/2500 [01:16<03:24,  8.64batch/s][A
Batches:  29%|███████████████████▉                                                | 735/2500 [01:16<03:20,  8.81batch/s][A
Batches:  29%|████████████████████                                                | 736/2500 [01:16<03:23,  8.68batch/s][A
Batches:  29%|████████████████████                                                | 737/2500 [01:16<03:19,  8.84batch/s][A
Batches:  30%|████████████████████                                                | 738/2500 [01:16<03:18,  8.87batch/s][A
Batches:  30%|████████████████████▏                                               | 740/2500 [01:16<03:10,  9.26batch/s][A
Batches:  30%|████████████████████▏                                               | 741/2500 [01:16<03:09,  9.27batch/s][A
Batches:

Batches:  35%|███████████████████████▊                                            | 875/2500 [01:31<02:56,  9.20batch/s][A
Batches:  35%|███████████████████████▊                                            | 877/2500 [01:31<02:56,  9.18batch/s][A
Batches:  35%|███████████████████████▉                                            | 878/2500 [01:32<02:56,  9.17batch/s][A
Batches:  35%|███████████████████████▉                                            | 879/2500 [01:32<02:59,  9.05batch/s][A
Batches:  35%|███████████████████████▉                                            | 880/2500 [01:32<03:03,  8.85batch/s][A
Batches:  35%|███████████████████████▉                                            | 881/2500 [01:32<03:03,  8.81batch/s][A
Batches:  35%|███████████████████████▉                                            | 882/2500 [01:32<03:03,  8.81batch/s][A
Batches:  35%|████████████████████████                                            | 883/2500 [01:32<02:58,  9.04batch/s][A
Batches:

Batches:  41%|███████████████████████████▍                                       | 1023/2500 [01:47<02:37,  9.39batch/s][A
Batches:  41%|███████████████████████████▍                                       | 1024/2500 [01:48<02:40,  9.22batch/s][A
Batches:  41%|███████████████████████████▍                                       | 1025/2500 [01:48<02:36,  9.40batch/s][A
Batches:  41%|███████████████████████████▍                                       | 1026/2500 [01:48<02:38,  9.32batch/s][A
Batches:  41%|███████████████████████████▌                                       | 1027/2500 [01:48<02:37,  9.35batch/s][A
Batches:  41%|███████████████████████████▌                                       | 1028/2500 [01:48<02:39,  9.25batch/s][A
Batches:  41%|███████████████████████████▌                                       | 1029/2500 [01:48<02:36,  9.40batch/s][A
Batches:  41%|███████████████████████████▌                                       | 1030/2500 [01:48<02:39,  9.20batch/s][A
Batches:

Batches:  46%|███████████████████████████████▏                                   | 1162/2500 [02:03<02:26,  9.13batch/s][A
Batches:  47%|███████████████████████████████▏                                   | 1163/2500 [02:03<02:22,  9.37batch/s][A
Batches:  47%|███████████████████████████████▏                                   | 1164/2500 [02:03<02:31,  8.83batch/s][A
Batches:  47%|███████████████████████████████▏                                   | 1165/2500 [02:03<02:30,  8.89batch/s][A
Batches:  47%|███████████████████████████████▏                                   | 1166/2500 [02:03<02:26,  9.11batch/s][A
Batches:  47%|███████████████████████████████▎                                   | 1167/2500 [02:03<02:29,  8.89batch/s][A
Batches:  47%|███████████████████████████████▎                                   | 1168/2500 [02:04<02:32,  8.71batch/s][A
Batches:  47%|███████████████████████████████▎                                   | 1169/2500 [02:04<02:31,  8.78batch/s][A
Batches:

Batches:  52%|███████████████████████████████████                                | 1307/2500 [02:19<02:11,  9.09batch/s][A
Batches:  52%|███████████████████████████████████                                | 1308/2500 [02:19<02:09,  9.22batch/s][A
Batches:  52%|███████████████████████████████████                                | 1309/2500 [02:19<02:07,  9.37batch/s][A
Batches:  52%|███████████████████████████████████                                | 1310/2500 [02:19<02:07,  9.37batch/s][A
Batches:  52%|███████████████████████████████████▏                               | 1311/2500 [02:19<02:13,  8.93batch/s][A
Batches:  52%|███████████████████████████████████▏                               | 1312/2500 [02:20<02:12,  8.97batch/s][A
Batches:  53%|███████████████████████████████████▏                               | 1313/2500 [02:20<02:09,  9.14batch/s][A
Batches:  53%|███████████████████████████████████▏                               | 1314/2500 [02:20<02:09,  9.17batch/s][A
Batches:

Batches:  58%|██████████████████████████████████████▉                            | 1455/2500 [02:35<01:54,  9.10batch/s][A
Batches:  58%|███████████████████████████████████████                            | 1456/2500 [02:35<01:54,  9.12batch/s][A
Batches:  58%|███████████████████████████████████████                            | 1457/2500 [02:36<01:59,  8.74batch/s][A
Batches:  58%|███████████████████████████████████████                            | 1458/2500 [02:36<01:57,  8.88batch/s][A
Batches:  58%|███████████████████████████████████████                            | 1459/2500 [02:36<01:56,  8.96batch/s][A
Batches:  58%|███████████████████████████████████████▏                           | 1460/2500 [02:36<01:58,  8.80batch/s][A
Batches:  58%|███████████████████████████████████████▏                           | 1461/2500 [02:36<02:02,  8.50batch/s][A
Batches:  58%|███████████████████████████████████████▏                           | 1462/2500 [02:36<01:59,  8.69batch/s][A
Batches:

Batches:  64%|██████████████████████████████████████████▉                        | 1600/2500 [02:51<01:40,  8.98batch/s][A
Batches:  64%|██████████████████████████████████████████▉                        | 1601/2500 [02:51<01:39,  9.06batch/s][A
Batches:  64%|██████████████████████████████████████████▉                        | 1602/2500 [02:51<01:41,  8.87batch/s][A
Batches:  64%|██████████████████████████████████████████▉                        | 1603/2500 [02:52<01:38,  9.09batch/s][A
Batches:  64%|██████████████████████████████████████████▉                        | 1604/2500 [02:52<01:40,  8.92batch/s][A
Batches:  64%|███████████████████████████████████████████                        | 1605/2500 [02:52<01:37,  9.14batch/s][A
Batches:  64%|███████████████████████████████████████████                        | 1606/2500 [02:52<01:37,  9.21batch/s][A
Batches:  64%|███████████████████████████████████████████                        | 1607/2500 [02:52<01:36,  9.24batch/s][A
Batches:

Batches:  70%|███████████████████████████████████████████████                    | 1755/2500 [03:08<01:19,  9.40batch/s][A
Batches:  70%|███████████████████████████████████████████████                    | 1756/2500 [03:08<01:19,  9.39batch/s][A
Batches:  70%|███████████████████████████████████████████████                    | 1757/2500 [03:08<01:20,  9.22batch/s][A
Batches:  70%|███████████████████████████████████████████████                    | 1758/2500 [03:08<01:20,  9.23batch/s][A
Batches:  70%|███████████████████████████████████████████████▏                   | 1759/2500 [03:08<01:18,  9.41batch/s][A
Batches:  70%|███████████████████████████████████████████████▏                   | 1760/2500 [03:09<01:20,  9.17batch/s][A
Batches:  70%|███████████████████████████████████████████████▏                   | 1761/2500 [03:09<01:19,  9.28batch/s][A
Batches:  70%|███████████████████████████████████████████████▏                   | 1762/2500 [03:09<01:21,  9.10batch/s][A
Batches:

Batches:  76%|██████████████████████████████████████████████████▉                | 1900/2500 [03:24<01:06,  9.00batch/s][A
Batches:  76%|██████████████████████████████████████████████████▉                | 1902/2500 [03:24<01:04,  9.25batch/s][A
Batches:  76%|███████████████████████████████████████████████████                | 1903/2500 [03:24<01:06,  8.92batch/s][A
Batches:  76%|███████████████████████████████████████████████████                | 1904/2500 [03:25<01:09,  8.56batch/s][A
Batches:  76%|███████████████████████████████████████████████████                | 1905/2500 [03:25<01:08,  8.74batch/s][A
Batches:  76%|███████████████████████████████████████████████████                | 1906/2500 [03:25<01:08,  8.65batch/s][A
Batches:  76%|███████████████████████████████████████████████████▏               | 1908/2500 [03:25<01:05,  9.04batch/s][A
Batches:  76%|███████████████████████████████████████████████████▏               | 1909/2500 [03:25<01:05,  9.04batch/s][A
Batches:

Batches:  82%|██████████████████████████████████████████████████████▋            | 2039/2500 [03:40<00:49,  9.25batch/s][A
Batches:  82%|██████████████████████████████████████████████████████▋            | 2040/2500 [03:40<00:49,  9.34batch/s][A
Batches:  82%|██████████████████████████████████████████████████████▋            | 2041/2500 [03:40<00:50,  9.06batch/s][A
Batches:  82%|██████████████████████████████████████████████████████▋            | 2042/2500 [03:40<00:50,  9.08batch/s][A
Batches:  82%|██████████████████████████████████████████████████████▊            | 2043/2500 [03:40<00:50,  9.13batch/s][A
Batches:  82%|██████████████████████████████████████████████████████▊            | 2044/2500 [03:40<00:51,  8.81batch/s][A
Batches:  82%|██████████████████████████████████████████████████████▊            | 2045/2500 [03:40<00:50,  8.98batch/s][A
Batches:  82%|██████████████████████████████████████████████████████▊            | 2046/2500 [03:40<00:49,  9.16batch/s][A
Batches:

Batches:  87%|██████████████████████████████████████████████████████████▍        | 2181/2500 [03:55<00:36,  8.63batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▍        | 2182/2500 [03:55<00:36,  8.76batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▌        | 2183/2500 [03:56<00:35,  8.87batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▌        | 2184/2500 [03:56<00:35,  9.01batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▌        | 2185/2500 [03:56<00:35,  8.83batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▌        | 2186/2500 [03:56<00:35,  8.79batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▌        | 2187/2500 [03:56<00:36,  8.63batch/s][A
Batches:  88%|██████████████████████████████████████████████████████████▋        | 2188/2500 [03:56<00:35,  8.78batch/s][A
Batches:

Batches:  93%|██████████████████████████████████████████████████████████████     | 2314/2500 [04:12<00:22,  8.10batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████     | 2315/2500 [04:12<00:22,  8.05batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████     | 2317/2500 [04:12<00:21,  8.67batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████     | 2318/2500 [04:12<00:21,  8.63batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▏    | 2319/2500 [04:12<00:20,  8.85batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▏    | 2320/2500 [04:13<00:20,  8.58batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▏    | 2321/2500 [04:13<00:22,  8.06batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▏    | 2322/2500 [04:13<00:22,  7.87batch/s][A
Batches:

Batches:  98%|█████████████████████████████████████████████████████████████████▌ | 2448/2500 [04:29<00:06,  8.03batch/s][A
Batches:  98%|█████████████████████████████████████████████████████████████████▋ | 2449/2500 [04:29<00:06,  8.00batch/s][A
Batches:  98%|█████████████████████████████████████████████████████████████████▋ | 2450/2500 [04:29<00:06,  7.68batch/s][A
Batches:  98%|█████████████████████████████████████████████████████████████████▋ | 2451/2500 [04:29<00:06,  7.73batch/s][A
Batches:  98%|█████████████████████████████████████████████████████████████████▋ | 2452/2500 [04:29<00:06,  7.60batch/s][A
Batches:  98%|█████████████████████████████████████████████████████████████████▋ | 2453/2500 [04:29<00:06,  7.80batch/s][A
Batches:  98%|█████████████████████████████████████████████████████████████████▊ | 2454/2500 [04:30<00:05,  7.85batch/s][A
Batches:  98%|█████████████████████████████████████████████████████████████████▊ | 2455/2500 [04:30<00:05,  7.54batch/s][A
Batches:

Iter 0:, acc test 96.1% (0.21)



Batches:   0%|                                                                              | 0/2500 [00:00<?, ?batch/s][A
Batches:   0%|                                                                      | 1/2500 [00:00<04:20,  9.60batch/s][A
Batches:   0%|                                                                      | 3/2500 [00:00<04:24,  9.43batch/s][A
Batches:   0%|▏                                                                     | 5/2500 [00:00<04:18,  9.64batch/s][A
Batches:   0%|▏                                                                     | 7/2500 [00:00<04:08, 10.03batch/s][A
Batches:   0%|▎                                                                     | 9/2500 [00:00<04:18,  9.62batch/s][A
Batches:   0%|▎                                                                    | 10/2500 [00:01<04:22,  9.49batch/s][A
Batches:   0%|▎                                                                    | 11/2500 [00:01<04:29,  9.23batch/s][A
Batches

Batches:   6%|████                                                                | 148/2500 [00:16<04:10,  9.39batch/s][A
Batches:   6%|████                                                                | 149/2500 [00:16<04:12,  9.33batch/s][A
Batches:   6%|████                                                                | 150/2500 [00:16<04:20,  9.02batch/s][A
Batches:   6%|████                                                                | 151/2500 [00:16<04:16,  9.17batch/s][A
Batches:   6%|████▏                                                               | 152/2500 [00:16<04:18,  9.10batch/s][A
Batches:   6%|████▏                                                               | 153/2500 [00:16<04:22,  8.95batch/s][A
Batches:   6%|████▏                                                               | 154/2500 [00:16<04:25,  8.82batch/s][A
Batches:   6%|████▏                                                               | 155/2500 [00:16<04:23,  8.91batch/s][A
Batches:

Batches:  12%|████████▏                                                           | 300/2500 [00:32<03:35, 10.20batch/s][A
Batches:  12%|████████▏                                                           | 302/2500 [00:32<03:33, 10.29batch/s][A
Batches:  12%|████████▎                                                           | 304/2500 [00:32<03:29, 10.49batch/s][A
Batches:  12%|████████▎                                                           | 306/2500 [00:33<03:35, 10.16batch/s][A
Batches:  12%|████████▍                                                           | 308/2500 [00:33<03:34, 10.20batch/s][A
Batches:  12%|████████▍                                                           | 310/2500 [00:33<03:28, 10.50batch/s][A
Batches:  12%|████████▍                                                           | 312/2500 [00:33<03:33, 10.26batch/s][A
Batches:  13%|████████▌                                                           | 314/2500 [00:33<03:31, 10.35batch/s][A
Batches:

Batches:  21%|██████████████                                                      | 515/2500 [00:54<03:20,  9.88batch/s][A
Batches:  21%|██████████████                                                      | 516/2500 [00:54<03:23,  9.76batch/s][A
Batches:  21%|██████████████                                                      | 518/2500 [00:54<03:18, 10.00batch/s][A
Batches:  21%|██████████████▏                                                     | 520/2500 [00:54<03:15, 10.14batch/s][A
Batches:  21%|██████████████▏                                                     | 522/2500 [00:54<03:16, 10.05batch/s][A
Batches:  21%|██████████████▏                                                     | 523/2500 [00:55<03:20,  9.86batch/s][A
Batches:  21%|██████████████▎                                                     | 524/2500 [00:55<03:26,  9.55batch/s][A
Batches:  21%|██████████████▎                                                     | 525/2500 [00:55<03:25,  9.62batch/s][A
Batches:

Batches:  28%|███████████████████                                                 | 702/2500 [01:13<03:08,  9.54batch/s][A
Batches:  28%|███████████████████                                                 | 703/2500 [01:13<03:13,  9.30batch/s][A
Batches:  28%|███████████████████▏                                                | 705/2500 [01:13<03:04,  9.72batch/s][A
Batches:  28%|███████████████████▏                                                | 706/2500 [01:13<03:07,  9.58batch/s][A
Batches:  28%|███████████████████▏                                                | 707/2500 [01:13<03:09,  9.45batch/s][A
Batches:  28%|███████████████████▎                                                | 708/2500 [01:14<03:09,  9.45batch/s][A
Batches:  28%|███████████████████▎                                                | 710/2500 [01:14<02:58, 10.02batch/s][A
Batches:  28%|███████████████████▎                                                | 712/2500 [01:14<03:02,  9.79batch/s][A
Batches:

Batches:  35%|███████████████████████▉                                            | 878/2500 [01:31<02:55,  9.22batch/s][A
Batches:  35%|███████████████████████▉                                            | 879/2500 [01:32<02:56,  9.20batch/s][A
Batches:  35%|███████████████████████▉                                            | 880/2500 [01:32<02:57,  9.14batch/s][A
Batches:  35%|███████████████████████▉                                            | 881/2500 [01:32<02:54,  9.25batch/s][A
Batches:  35%|███████████████████████▉                                            | 882/2500 [01:32<03:03,  8.82batch/s][A
Batches:  35%|████████████████████████                                            | 883/2500 [01:32<03:05,  8.74batch/s][A
Batches:  35%|████████████████████████                                            | 884/2500 [01:32<03:01,  8.90batch/s][A
Batches:  35%|████████████████████████                                            | 885/2500 [01:32<02:56,  9.15batch/s][A
Batches:

Batches:  41%|███████████████████████████▌                                       | 1028/2500 [01:48<02:47,  8.79batch/s][A
Batches:  41%|███████████████████████████▌                                       | 1029/2500 [01:48<02:47,  8.80batch/s][A
Batches:  41%|███████████████████████████▌                                       | 1030/2500 [01:48<02:42,  9.06batch/s][A
Batches:  41%|███████████████████████████▋                                       | 1031/2500 [01:48<02:39,  9.23batch/s][A
Batches:  41%|███████████████████████████▋                                       | 1032/2500 [01:48<02:39,  9.23batch/s][A
Batches:  41%|███████████████████████████▋                                       | 1033/2500 [01:49<02:48,  8.70batch/s][A
Batches:  41%|███████████████████████████▋                                       | 1034/2500 [01:49<02:46,  8.83batch/s][A
Batches:  41%|███████████████████████████▋                                       | 1035/2500 [01:49<02:47,  8.75batch/s][A
Batches:

Batches:  47%|███████████████████████████████▎                                   | 1169/2500 [02:04<02:42,  8.17batch/s][A
Batches:  47%|███████████████████████████████▎                                   | 1170/2500 [02:04<02:41,  8.24batch/s][A
Batches:  47%|███████████████████████████████▍                                   | 1171/2500 [02:04<02:37,  8.44batch/s][A
Batches:  47%|███████████████████████████████▍                                   | 1172/2500 [02:04<02:37,  8.45batch/s][A
Batches:  47%|███████████████████████████████▍                                   | 1173/2500 [02:04<02:32,  8.68batch/s][A
Batches:  47%|███████████████████████████████▍                                   | 1174/2500 [02:05<02:31,  8.73batch/s][A
Batches:  47%|███████████████████████████████▍                                   | 1175/2500 [02:05<02:26,  9.06batch/s][A
Batches:  47%|███████████████████████████████▌                                   | 1176/2500 [02:05<02:25,  9.12batch/s][A
Batches:

Batches:  52%|███████████████████████████████████                                | 1306/2500 [02:19<02:15,  8.84batch/s][A
Batches:  52%|███████████████████████████████████                                | 1307/2500 [02:19<02:10,  9.16batch/s][A
Batches:  52%|███████████████████████████████████                                | 1308/2500 [02:19<02:10,  9.15batch/s][A
Batches:  52%|███████████████████████████████████                                | 1309/2500 [02:20<02:15,  8.79batch/s][A
Batches:  52%|███████████████████████████████████                                | 1310/2500 [02:20<02:15,  8.76batch/s][A
Batches:  52%|███████████████████████████████████▏                               | 1311/2500 [02:20<02:10,  9.08batch/s][A
Batches:  52%|███████████████████████████████████▏                               | 1312/2500 [02:20<02:11,  9.03batch/s][A
Batches:  53%|███████████████████████████████████▏                               | 1313/2500 [02:20<02:10,  9.07batch/s][A
Batches:

Batches:  58%|██████████████████████████████████████▉                            | 1453/2500 [02:35<02:02,  8.52batch/s][A
Batches:  58%|██████████████████████████████████████▉                            | 1454/2500 [02:35<02:03,  8.46batch/s][A
Batches:  58%|██████████████████████████████████████▉                            | 1455/2500 [02:36<01:58,  8.82batch/s][A
Batches:  58%|███████████████████████████████████████                            | 1456/2500 [02:36<01:59,  8.71batch/s][A
Batches:  58%|███████████████████████████████████████                            | 1457/2500 [02:36<01:57,  8.85batch/s][A
Batches:  58%|███████████████████████████████████████                            | 1458/2500 [02:36<02:00,  8.64batch/s][A
Batches:  58%|███████████████████████████████████████                            | 1459/2500 [02:36<01:59,  8.68batch/s][A
Batches:  58%|███████████████████████████████████████▏                           | 1460/2500 [02:36<01:59,  8.72batch/s][A
Batches:

Batches:  64%|██████████████████████████████████████████▊                        | 1598/2500 [02:51<01:38,  9.18batch/s][A
Batches:  64%|██████████████████████████████████████████▊                        | 1599/2500 [02:51<01:37,  9.20batch/s][A
Batches:  64%|██████████████████████████████████████████▉                        | 1601/2500 [02:51<01:34,  9.47batch/s][A
Batches:  64%|██████████████████████████████████████████▉                        | 1602/2500 [02:52<01:34,  9.54batch/s][A
Batches:  64%|██████████████████████████████████████████▉                        | 1603/2500 [02:52<01:33,  9.58batch/s][A
Batches:  64%|██████████████████████████████████████████▉                        | 1604/2500 [02:52<01:37,  9.21batch/s][A
Batches:  64%|███████████████████████████████████████████                        | 1605/2500 [02:52<01:38,  9.11batch/s][A
Batches:  64%|███████████████████████████████████████████                        | 1606/2500 [02:52<01:40,  8.93batch/s][A
Batches:

Batches:  70%|██████████████████████████████████████████████▌                    | 1739/2500 [03:07<01:22,  9.19batch/s][A
Batches:  70%|██████████████████████████████████████████████▋                    | 1740/2500 [03:07<01:24,  9.01batch/s][A
Batches:  70%|██████████████████████████████████████████████▋                    | 1741/2500 [03:07<01:23,  9.08batch/s][A
Batches:  70%|██████████████████████████████████████████████▋                    | 1742/2500 [03:07<01:27,  8.66batch/s][A
Batches:  70%|██████████████████████████████████████████████▋                    | 1743/2500 [03:07<01:27,  8.66batch/s][A
Batches:  70%|██████████████████████████████████████████████▋                    | 1744/2500 [03:07<01:26,  8.79batch/s][A
Batches:  70%|██████████████████████████████████████████████▊                    | 1745/2500 [03:07<01:25,  8.83batch/s][A
Batches:  70%|██████████████████████████████████████████████▊                    | 1747/2500 [03:08<01:22,  9.11batch/s][A
Batches:

Batches:  75%|██████████████████████████████████████████████████▍                | 1881/2500 [03:23<01:10,  8.72batch/s][A
Batches:  75%|██████████████████████████████████████████████████▍                | 1882/2500 [03:23<01:10,  8.83batch/s][A
Batches:  75%|██████████████████████████████████████████████████▍                | 1883/2500 [03:23<01:11,  8.61batch/s][A
Batches:  75%|██████████████████████████████████████████████████▍                | 1884/2500 [03:23<01:13,  8.37batch/s][A
Batches:  75%|██████████████████████████████████████████████████▌                | 1886/2500 [03:23<01:09,  8.87batch/s][A
Batches:  75%|██████████████████████████████████████████████████▌                | 1887/2500 [03:23<01:10,  8.66batch/s][A
Batches:  76%|██████████████████████████████████████████████████▌                | 1888/2500 [03:23<01:11,  8.62batch/s][A
Batches:  76%|██████████████████████████████████████████████████▋                | 1889/2500 [03:24<01:09,  8.76batch/s][A
Batches:

Batches:  81%|██████████████████████████████████████████████████████▎            | 2025/2500 [03:38<00:52,  9.12batch/s][A
Batches:  81%|██████████████████████████████████████████████████████▎            | 2026/2500 [03:39<00:51,  9.19batch/s][A
Batches:  81%|██████████████████████████████████████████████████████▎            | 2027/2500 [03:39<00:51,  9.15batch/s][A
Batches:  81%|██████████████████████████████████████████████████████▎            | 2028/2500 [03:39<00:52,  9.07batch/s][A
Batches:  81%|██████████████████████████████████████████████████████▍            | 2029/2500 [03:39<00:52,  8.98batch/s][A
Batches:  81%|██████████████████████████████████████████████████████▍            | 2030/2500 [03:39<00:53,  8.84batch/s][A
Batches:  81%|██████████████████████████████████████████████████████▍            | 2031/2500 [03:39<00:53,  8.82batch/s][A
Batches:  81%|██████████████████████████████████████████████████████▍            | 2032/2500 [03:39<00:53,  8.75batch/s][A
Batches:

Batches:  87%|██████████████████████████████████████████████████████████▏        | 2170/2500 [03:54<00:37,  8.76batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▏        | 2171/2500 [03:55<00:37,  8.71batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▏        | 2172/2500 [03:55<00:37,  8.73batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▏        | 2173/2500 [03:55<00:37,  8.73batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▎        | 2174/2500 [03:55<00:36,  9.05batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▎        | 2176/2500 [03:55<00:34,  9.50batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▎        | 2177/2500 [03:55<00:34,  9.44batch/s][A
Batches:  87%|██████████████████████████████████████████████████████████▎        | 2178/2500 [03:55<00:34,  9.37batch/s][A
Batches:

Batches:  93%|██████████████████████████████████████████████████████████████▎    | 2324/2500 [04:11<00:18,  9.69batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▎    | 2325/2500 [04:11<00:18,  9.63batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▎    | 2326/2500 [04:11<00:18,  9.44batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▎    | 2327/2500 [04:11<00:18,  9.33batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▍    | 2328/2500 [04:11<00:18,  9.14batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▍    | 2329/2500 [04:12<00:19,  9.00batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▍    | 2330/2500 [04:12<00:18,  9.01batch/s][A
Batches:  93%|██████████████████████████████████████████████████████████████▍    | 2332/2500 [04:12<00:17,  9.47batch/s][A
Batches:

Batches:  99%|██████████████████████████████████████████████████████████████████▏| 2469/2500 [04:27<00:03,  9.01batch/s][A
Batches:  99%|██████████████████████████████████████████████████████████████████▏| 2470/2500 [04:27<00:03,  9.15batch/s][A
Batches:  99%|██████████████████████████████████████████████████████████████████▏| 2471/2500 [04:27<00:03,  9.09batch/s][A
Batches:  99%|██████████████████████████████████████████████████████████████████▎| 2473/2500 [04:27<00:02,  9.34batch/s][A
Batches:  99%|██████████████████████████████████████████████████████████████████▎| 2474/2500 [04:27<00:02,  9.42batch/s][A
Batches:  99%|██████████████████████████████████████████████████████████████████▎| 2476/2500 [04:28<00:02,  9.72batch/s][A
Batches:  99%|██████████████████████████████████████████████████████████████████▍| 2477/2500 [04:28<00:02,  9.47batch/s][A
Batches:  99%|██████████████████████████████████████████████████████████████████▍| 2478/2500 [04:28<00:02,  9.18batch/s][A
Batches:

Iter 1:, acc test 10.3% (nan)



Batches:   0%|                                                                              | 0/2500 [00:00<?, ?batch/s][A
Batches:   0%|                                                                      | 1/2500 [00:00<05:09,  8.07batch/s][A
Batches:   0%|                                                                      | 2/2500 [00:00<04:39,  8.95batch/s][A
Batches:   0%|                                                                      | 3/2500 [00:00<04:30,  9.22batch/s][A
Batches:   0%|                                                                      | 4/2500 [00:00<04:45,  8.75batch/s][A
Batches:   0%|▏                                                                     | 5/2500 [00:00<04:33,  9.12batch/s][A
Batches:   0%|▏                                                                     | 7/2500 [00:00<04:26,  9.34batch/s][A
Batches:   0%|▏                                                                     | 8/2500 [00:00<04:27,  9.32batch/s][A
Batches

KeyboardInterrupt: 

# EBHI-SEG Dataset

In [37]:
# Custom dataset to apply Albumentations transformations
class CustomDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]  # Image is a PIL Image
        image = np.array(image)  # Convert to NumPy array
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label
# Load the dataset
train_transform = A.Compose([
            A.ElasticTransform(alpha=1.0, sigma=50, p=0.25),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.25),
            A.Rotate(limit=35, p=0.7),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.GaussianBlur(blur_limit=(3, 7), p=0.2),
            ToTensorV2(),
        ])
dataset_path = "./EBHI-SEG-Class/"
base_dataset = ImageFolder(root=dataset_path)
custom_dataset = CustomDataset(base_dataset, transform=train_transform)

# Générer des indices aléatoires pour mélanger les données
indices = np.random.permutation(len(custom_dataset))

# Diviser en 80% entraînement et 20% test
train_indices = indices[:int(0.8 * len(indices))]
test_indices = indices[int(0.8 * len(indices)):]

# Créer des sous-ensembles
train_dataset = Subset(custom_dataset, train_indices)
test_dataset = Subset(custom_dataset, test_indices)

# Créer des DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [34]:
nx=1000
nh_1=100
nh_2=100
ny=10
eta=0.1
n_iter=30
loss = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.train()

# Planificateur StepLR : Réduction du LR tous les 10 epochs par un facteur 0.1
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
batch_size=32

# epoch
for iteration in tqdm(range(n_iter), desc="Epochs", unit="epoch"):
    
        train_transform = A.Compose([
            A.ElasticTransform(alpha=1.0, sigma=50, p=0.25),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.25),
            A.Rotate(limit=35, p=0.7),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.GaussianBlur(blur_limit=(3, 7), p=0.2),
            ToTensorV2(),
        ])
        
        val_transforms = A.Compose(
        [ 
            ToTensorV2(),
        ],
        )
        
        # Recréer les datasets PyTorch
        custom_dataset = CustomDataset(base_dataset, transform=train_transform)
    

        # Créer des sous-ensembles
        train_dataset = Subset(custom_dataset, train_indices)
        test_dataset = Subset(custom_dataset, test_indices)

        # Créer des DataLoaders
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
        
        
    
        for X_batch, Y_batch in tqdm(train_loader, desc="Batches", unit="batch", leave=False):
                X_batch, Y_batch = X_batch.float().to(device), Y_batch.to(device) 
                optimizer.zero_grad()
                Yhat = model(X_batch)  # Prédictions pour le batch actuel
                L, acc = loss_accuracy(loss, Yhat, Y_batch)  # Calcul de la perte et de la précision
                L.backward()
                optimizer.step()
        scheduler.step()
    
        
        Ltest, acctest = evaluate_model(model, test_loader, loss, device)
    
        title = "Iter {}:, acc test {:.1f}% ({:.2f})".format(iteration, acctest, Ltest)

        
    
        print(title)

Epochs:   0%|                                                                                 | 0/30 [00:00<?, ?epoch/s]
Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:14,  3.71batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:11,  4.66batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:10,  5.28batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  5.81batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  5.93batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:01<00:08,  6.10batch/s][A
Batches:  1

Iter 0:, acc test 63.7% (1.09)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.47batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:09,  5.97batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.22batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.49batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.42batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.37batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.36batch/s][A
Batches

Iter 1:, acc test 77.6% (0.69)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:07,  7.40batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.69batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.78batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.55batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.39batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.30batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.40batch/s][A
Batches

Iter 2:, acc test 74.7% (0.74)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:07,  7.21batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:07,  6.94batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.89batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.78batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.88batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.70batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.78batch/s][A
Batches

Iter 3:, acc test 78.9% (0.61)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.66batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:07,  7.02batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.72batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.54batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.63batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.68batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.57batch/s][A
Batches

Iter 4:, acc test 79.1% (0.57)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.13batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.13batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.45batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.33batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.38batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.42batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.55batch/s][A
Batches

Iter 5:, acc test 77.1% (0.60)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.35batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.48batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.35batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.47batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.56batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.40batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.30batch/s][A
Batches

Iter 6:, acc test 82.1% (0.51)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.51batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.49batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.56batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.56batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.33batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.30batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.29batch/s][A
Batches

Iter 7:, acc test 80.9% (0.53)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.86batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.43batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.42batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.33batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.22batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.41batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.60batch/s][A
Batches

Iter 8:, acc test 80.0% (0.52)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.44batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.63batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.56batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.45batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.71batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.55batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.46batch/s][A
Batches

Iter 9:, acc test 81.4% (0.48)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.44batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:09,  5.87batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  5.95batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  5.93batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.28batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.31batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.20batch/s][A
Batches

Iter 10:, acc test 80.9% (0.47)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.26batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.31batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.37batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.21batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.22batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:08,  6.21batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.15batch/s][A
Batches

Iter 11:, acc test 80.3% (0.52)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.68batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:07,  6.85batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.47batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.43batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.36batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.39batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.25batch/s][A
Batches

Iter 12:, acc test 80.5% (0.50)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:09,  5.73batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:09,  5.88batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:09,  5.85batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.01batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  5.84batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:01<00:08,  5.95batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:08,  5.80batch/s][A
Batches

Iter 13:, acc test 81.8% (0.47)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.52batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.50batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.68batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.77batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.60batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.64batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.63batch/s][A
Batches

Iter 14:, acc test 83.4% (0.47)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.19batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.59batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.26batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.44batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.33batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.29batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.34batch/s][A
Batches

Iter 15:, acc test 84.8% (0.41)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:07,  7.03batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:07,  6.89batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.69batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.55batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.24batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.42batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.46batch/s][A
Batches

Iter 16:, acc test 83.4% (0.45)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:07,  7.09batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.56batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.47batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.55batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.32batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.28batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.23batch/s][A
Batches

Iter 17:, acc test 82.1% (0.45)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.19batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.46batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.44batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.40batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.46batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.44batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.38batch/s][A
Batches

Iter 18:, acc test 83.4% (0.44)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:09,  6.03batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.45batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.70batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.51batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.40batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.33batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.35batch/s][A
Batches

Iter 19:, acc test 84.1% (0.44)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:09,  5.65batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:10,  5.31batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:09,  5.39batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:09,  5.31batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:09,  5.44batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:01<00:09,  5.50batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:08,  5.54batch/s][A
Batches

Iter 20:, acc test 82.5% (0.44)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.42batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.66batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.44batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.57batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.60batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.69batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.72batch/s][A
Batches

Iter 21:, acc test 81.6% (0.45)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:09,  6.05batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.35batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.77batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.50batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.45batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.39batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.47batch/s][A
Batches

Iter 22:, acc test 84.3% (0.41)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.76batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.69batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.69batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.52batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.62batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.49batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.43batch/s][A
Batches

Iter 23:, acc test 84.1% (0.40)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.75batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:07,  6.87batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:07,  6.64batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.60batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.52batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.49batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.36batch/s][A
Batches

Iter 24:, acc test 82.5% (0.45)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.46batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.30batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.18batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  5.99batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  6.22batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.44batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.27batch/s][A
Batches

Iter 25:, acc test 84.8% (0.41)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:08,  6.67batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.31batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.18batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.41batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.46batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.48batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.35batch/s][A
Batches

Iter 26:, acc test 83.9% (0.42)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:09,  5.76batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:09,  5.85batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:09,  5.72batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:09,  5.72batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:08,  5.72batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:01<00:08,  5.57batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:08,  5.76batch/s][A
Batches

Iter 27:, acc test 81.2% (0.45)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:09,  5.92batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.20batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.27batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:08,  6.29batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.42batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:07,  6.45batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.62batch/s][A
Batches

Iter 28:, acc test 84.8% (0.39)



Batches:   0%|                                                                                | 0/56 [00:00<?, ?batch/s][A
Batches:   2%|█▎                                                                      | 1/56 [00:00<00:07,  6.97batch/s][A
Batches:   4%|██▌                                                                     | 2/56 [00:00<00:08,  6.61batch/s][A
Batches:   5%|███▊                                                                    | 3/56 [00:00<00:08,  6.40batch/s][A
Batches:   7%|█████▏                                                                  | 4/56 [00:00<00:07,  6.61batch/s][A
Batches:   9%|██████▍                                                                 | 5/56 [00:00<00:07,  6.44batch/s][A
Batches:  11%|███████▋                                                                | 6/56 [00:00<00:08,  6.23batch/s][A
Batches:  12%|█████████                                                               | 7/56 [00:01<00:07,  6.21batch/s][A
Batches

Iter 29:, acc test 85.2% (0.40)





# SVM

In [50]:
# Custom dataset to apply Albumentations transformations
class CustomDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]  # Image is a PIL Image
        image = np.array(image)  # Convert to NumPy array
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label
# Load the dataset
train_transform = A.Compose([
            ToTensorV2(),
        ])
dataset_path = "./EBHI-SEG-Class/"
base_dataset = ImageFolder(root=dataset_path)
custom_dataset = CustomDataset(base_dataset, transform=train_transform)

# Générer des indices aléatoires pour mélanger les données
indices = np.random.permutation(len(custom_dataset))

# Diviser en 80% entraînement et 20% test
train_indices = indices[:int(0.8 * len(indices))]
test_indices = indices[int(0.8 * len(indices)):]

# Créer des sous-ensembles
train_dataset = Subset(custom_dataset, train_indices)
test_dataset = Subset(custom_dataset, test_indices)

# Créer des DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [51]:
# Charger ResNet-34 pré-entraîné
model = models.resnet34(pretrained=False)

# Chemin vers votre fichier de poids
weights_path = "resnet34_cerberus_torchvision .pth"

# Charger les poids sauvegardés
state_dict = torch.load(weights_path)
model.load_state_dict(state_dict)

  state_dict = torch.load(weights_path)


<All keys matched successfully>

In [52]:
PRINT_INTERVAL = 50
def extract_features_Res(data, model):
    #####################
    ## Initialization   ##
    #####################
    # Initialize feature and label storage
    X = []
    y = []

    #####################
    ## Feature Extraction ##
    #####################
    for i, (input, target) in enumerate(data):
        # Print batch progress
        if i % PRINT_INTERVAL == 0:
            print('Batch {0:03d}/{1:03d}'.format(i, len(data)))



        # Extract features
        with torch.no_grad():  # Disable gradients for efficiency
            x = model(input.float())

        # Detach from computation graph and move to CPU if needed
        x = x.detach().cpu().numpy()
        target = target.cpu().numpy()

        # Append the features and labels
        X.append(x)
        y.append(target)

    # Concatenate all batches into single arrays and convert to tensor
    X = torch.tensor(np.concatenate(X, axis=0))  # X as a PyTorch tensor
    y = torch.tensor(np.concatenate(y, axis=0))  # y as a PyTorch tensor

    return X, y

In [53]:
model.eval()
    # Extraction des features
print('Feature extraction')
X_train, y_train = extract_features_Res(train_loader, model)
X_test, y_test = extract_features_Res(test_loader, model)

Feature extraction
Batch 000/056
Batch 050/056
Batch 000/014


In [54]:
print('Apprentissage des SVM')
accuracy = 0
svm = LinearSVC(C=1.0).fit(X_train, y_train)
y_hat = svm.predict(X_test)
y_test = y_test.cpu().numpy() if isinstance(y_test, torch.Tensor) else y_test
accuracy = accuracy_score(y_test, y_hat)
print(f"Accuracy = {round(accuracy*100,2)} %")

Apprentissage des SVM
Accuracy = 59.19 %
