In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset, random_split, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import numpy as np
import os
from google.colab import drive
from collections import Counter

In [2]:
drive.mount('/content/drive')
import sys
sys.path.append('/content/drive/MyDrive/TR_DIMA/Hybrid-SC')

from model import ResNet50DualBranch, CombinedPSCLoss

Mounted at /content/drive


In [3]:
# Path to the dataset and model
path_train_dataset = "/content/drive/MyDrive/TR_DIMA/training_set_reduit"

base_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),  # Convert images to PyTorch tensors
])

augmented_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.GaussianBlur(kernel_size=9, sigma=(0.01, 5)),
        transforms.RandomResizedCrop(size=224, scale=(0.7, 1.0), ratio = (0.2,5)),
        transforms.ToTensor(),
    ])



dataset_init = datasets.ImageFolder(root=path_train_dataset, transform=augmented_transform)

train_size = int(0.8 * len(dataset_init))
valid_size = len(dataset_init) - train_size

train_dataset, valid_dataset = random_split(dataset_init, [train_size, valid_size])



In [4]:
# Sampler for class balancing

train_indices = train_dataset.indices
train_labels = [dataset_init.samples[i][1] for i in train_indices]

class_counts = Counter(train_labels)
class_weights = {label: 1.0 / count for label, count in class_counts.items()}
sample_weights = [class_weights[label] for label in train_labels]

sampler = WeightedRandomSampler(weights=sample_weights, num_samples= len(sample_weights), replacement=True)

# Create DataLoaders for training and validation datasets

train_loader = DataLoader(train_dataset, batch_size=56, sampler=sampler, num_workers=16, pin_memory=True)
val_loader = DataLoader(valid_dataset, batch_size=56, shuffle=False, num_workers=16, pin_memory=True)

valid_classes = dataset_init.classes

#Initialisation model

model = ResNet50DualBranch(num_classes=len(valid_classes), proj_dim=128, hidden_dim=2048).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

#Loss function

criterion = CombinedPSCLoss(alpha=1.0, temperature=0.1).cuda()

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 188MB/s]


In [7]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100):

    for epoch in range(num_epochs):
        model.train()
        alpha = 1.0 - (epoch / num_epochs)**2  # Linear decay of alpha
        running_loss = 0.0
        correct = 0
        total = 0
        criterion = CombinedPSCLoss(alpha=alpha, temperature=0.1).cuda()

        for images, labels in tqdm(train_loader):
            images, labels = images.cuda(), labels.cuda()

            optimizer.zero_grad()
            out_cls, out_proj, proto = model(x_cls=images, x_proj=images)
            loss = criterion(out_cls, out_proj, labels, proto)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(out_cls.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.cuda(), labels.cuda()
                out_cls, out_proj, proto = model(x_cls=images, x_proj=images)
                loss = criterion(out_cls, out_proj, labels, proto)
                val_loss += loss.item() * images.size(0)
                _, predicted = torch.max(out_cls.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_epoch_loss = val_loss / len(val_loader.dataset)
        val_epoch_acc = val_correct / val_total
        print(f"Validation Loss: {val_epoch_loss:.4f}, Validation Accuracy: {val_epoch_acc:.4f}")



    return model

In [9]:
model_trained=train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100)

100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [1/100], Loss: -7.8846, Accuracy: 0.9749





Validation Loss: -7.0856, Validation Accuracy: 0.9345


100%|██████████| 692/692 [01:40<00:00,  6.85it/s]

Epoch [2/100], Loss: -7.9009, Accuracy: 0.9753





Validation Loss: -6.8085, Validation Accuracy: 0.9250


100%|██████████| 692/692 [01:40<00:00,  6.85it/s]

Epoch [3/100], Loss: -7.9136, Accuracy: 0.9765





Validation Loss: -6.9272, Validation Accuracy: 0.9268


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [4/100], Loss: -7.9037, Accuracy: 0.9756





Validation Loss: -7.1594, Validation Accuracy: 0.9423


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [5/100], Loss: -7.9174, Accuracy: 0.9772





Validation Loss: -6.9330, Validation Accuracy: 0.9345


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [6/100], Loss: -7.9040, Accuracy: 0.9766





Validation Loss: -6.6398, Validation Accuracy: 0.9274


100%|██████████| 692/692 [01:40<00:00,  6.90it/s]

Epoch [7/100], Loss: -7.9128, Accuracy: 0.9766





Validation Loss: -7.0568, Validation Accuracy: 0.9389


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [8/100], Loss: -7.9282, Accuracy: 0.9786





Validation Loss: -6.7805, Validation Accuracy: 0.9297


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [9/100], Loss: -7.9183, Accuracy: 0.9783





Validation Loss: -6.9551, Validation Accuracy: 0.9372


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [10/100], Loss: -7.9143, Accuracy: 0.9788





Validation Loss: -6.8485, Validation Accuracy: 0.9330


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [11/100], Loss: -7.9295, Accuracy: 0.9806





Validation Loss: -6.9841, Validation Accuracy: 0.9365


100%|██████████| 692/692 [01:40<00:00,  6.90it/s]

Epoch [12/100], Loss: -7.9018, Accuracy: 0.9803





Validation Loss: -6.9155, Validation Accuracy: 0.9335


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [13/100], Loss: -7.9012, Accuracy: 0.9813





Validation Loss: -6.9015, Validation Accuracy: 0.9341


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [14/100], Loss: -7.8544, Accuracy: 0.9794





Validation Loss: -6.7179, Validation Accuracy: 0.9236


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [15/100], Loss: -7.8804, Accuracy: 0.9830





Validation Loss: -6.6860, Validation Accuracy: 0.9220


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [16/100], Loss: -7.8302, Accuracy: 0.9806





Validation Loss: -6.4370, Validation Accuracy: 0.9157


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]


Epoch [17/100], Loss: -7.8102, Accuracy: 0.9815
Validation Loss: -6.8880, Validation Accuracy: 0.9362


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [18/100], Loss: -7.8160, Accuracy: 0.9836





Validation Loss: -6.6193, Validation Accuracy: 0.9275


100%|██████████| 692/692 [01:41<00:00,  6.85it/s]

Epoch [19/100], Loss: -7.7582, Accuracy: 0.9822





Validation Loss: -6.8173, Validation Accuracy: 0.9410


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [20/100], Loss: -7.7474, Accuracy: 0.9828





Validation Loss: -6.7192, Validation Accuracy: 0.9309


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [21/100], Loss: -7.7437, Accuracy: 0.9844





Validation Loss: -6.7903, Validation Accuracy: 0.9346


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [22/100], Loss: -7.6834, Accuracy: 0.9831





Validation Loss: -6.6771, Validation Accuracy: 0.9358


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [23/100], Loss: -7.6553, Accuracy: 0.9833





Validation Loss: -6.7293, Validation Accuracy: 0.9414


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [24/100], Loss: -7.6295, Accuracy: 0.9842





Validation Loss: -6.9221, Validation Accuracy: 0.9485


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [25/100], Loss: -7.6034, Accuracy: 0.9848





Validation Loss: -6.7955, Validation Accuracy: 0.9409


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [26/100], Loss: -7.5552, Accuracy: 0.9843





Validation Loss: -6.7448, Validation Accuracy: 0.9409


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [27/100], Loss: -7.5353, Accuracy: 0.9856





Validation Loss: -6.6639, Validation Accuracy: 0.9386


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [28/100], Loss: -7.4725, Accuracy: 0.9846





Validation Loss: -6.3048, Validation Accuracy: 0.9298


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [29/100], Loss: -7.4322, Accuracy: 0.9844





Validation Loss: -6.7053, Validation Accuracy: 0.9452


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [30/100], Loss: -7.4029, Accuracy: 0.9860





Validation Loss: -6.4885, Validation Accuracy: 0.9385


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [31/100], Loss: -7.3621, Accuracy: 0.9862





Validation Loss: -6.3440, Validation Accuracy: 0.9334


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [32/100], Loss: -7.3173, Accuracy: 0.9861





Validation Loss: -6.4749, Validation Accuracy: 0.9420


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [33/100], Loss: -7.2610, Accuracy: 0.9858





Validation Loss: -6.5263, Validation Accuracy: 0.9488


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [34/100], Loss: -7.2172, Accuracy: 0.9867





Validation Loss: -6.3855, Validation Accuracy: 0.9440


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [35/100], Loss: -7.1620, Accuracy: 0.9863





Validation Loss: -6.0980, Validation Accuracy: 0.9293


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [36/100], Loss: -7.0991, Accuracy: 0.9868





Validation Loss: -6.3491, Validation Accuracy: 0.9491


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [37/100], Loss: -7.0470, Accuracy: 0.9872





Validation Loss: -6.2790, Validation Accuracy: 0.9444


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [38/100], Loss: -6.9870, Accuracy: 0.9864





Validation Loss: -6.2977, Validation Accuracy: 0.9515


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [39/100], Loss: -6.9300, Accuracy: 0.9872





Validation Loss: -6.2014, Validation Accuracy: 0.9452


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [40/100], Loss: -6.8638, Accuracy: 0.9865





Validation Loss: -6.0392, Validation Accuracy: 0.9471


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [41/100], Loss: -6.8136, Accuracy: 0.9875





Validation Loss: -6.1359, Validation Accuracy: 0.9517


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [42/100], Loss: -6.7357, Accuracy: 0.9863





Validation Loss: -5.8775, Validation Accuracy: 0.9416


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [43/100], Loss: -6.6893, Accuracy: 0.9884





Validation Loss: -5.9583, Validation Accuracy: 0.9480


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [44/100], Loss: -6.6165, Accuracy: 0.9881





Validation Loss: -5.7946, Validation Accuracy: 0.9405


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [45/100], Loss: -6.5363, Accuracy: 0.9876





Validation Loss: -5.9370, Validation Accuracy: 0.9498


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [46/100], Loss: -6.4836, Accuracy: 0.9885





Validation Loss: -5.6347, Validation Accuracy: 0.9371


100%|██████████| 692/692 [01:41<00:00,  6.84it/s]

Epoch [47/100], Loss: -6.3949, Accuracy: 0.9879





Validation Loss: -5.7172, Validation Accuracy: 0.9451


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [48/100], Loss: -6.3333, Accuracy: 0.9883





Validation Loss: -5.7317, Validation Accuracy: 0.9539


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [49/100], Loss: -6.2451, Accuracy: 0.9883





Validation Loss: -5.5004, Validation Accuracy: 0.9417


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [50/100], Loss: -6.1772, Accuracy: 0.9885





Validation Loss: -5.3705, Validation Accuracy: 0.9414


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [51/100], Loss: -6.0974, Accuracy: 0.9887





Validation Loss: -5.4697, Validation Accuracy: 0.9499


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [52/100], Loss: -6.0185, Accuracy: 0.9890





Validation Loss: -5.3134, Validation Accuracy: 0.9470


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [53/100], Loss: -5.9315, Accuracy: 0.9889





Validation Loss: -5.3319, Validation Accuracy: 0.9476


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [54/100], Loss: -5.8316, Accuracy: 0.9878





Validation Loss: -5.2857, Validation Accuracy: 0.9539


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [55/100], Loss: -5.7541, Accuracy: 0.9894





Validation Loss: -5.0893, Validation Accuracy: 0.9453


100%|██████████| 692/692 [01:41<00:00,  6.85it/s]

Epoch [56/100], Loss: -5.6748, Accuracy: 0.9895





Validation Loss: -4.9717, Validation Accuracy: 0.9413


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [57/100], Loss: -5.5789, Accuracy: 0.9892





Validation Loss: -5.0158, Validation Accuracy: 0.9527


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [58/100], Loss: -5.4974, Accuracy: 0.9898





Validation Loss: -4.8569, Validation Accuracy: 0.9494


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [59/100], Loss: -5.4023, Accuracy: 0.9893





Validation Loss: -4.8221, Validation Accuracy: 0.9502


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [60/100], Loss: -5.3170, Accuracy: 0.9902





Validation Loss: -4.5680, Validation Accuracy: 0.9402


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [61/100], Loss: -5.1960, Accuracy: 0.9888





Validation Loss: -4.5778, Validation Accuracy: 0.9455


100%|██████████| 692/692 [01:40<00:00,  6.85it/s]

Epoch [62/100], Loss: -5.1162, Accuracy: 0.9901





Validation Loss: -4.4690, Validation Accuracy: 0.9451


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [63/100], Loss: -5.0059, Accuracy: 0.9901





Validation Loss: -4.4847, Validation Accuracy: 0.9556


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [64/100], Loss: -4.9151, Accuracy: 0.9905





Validation Loss: -4.1485, Validation Accuracy: 0.9372


100%|██████████| 692/692 [01:41<00:00,  6.85it/s]

Epoch [65/100], Loss: -4.8086, Accuracy: 0.9908





Validation Loss: -4.1917, Validation Accuracy: 0.9491


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [66/100], Loss: -4.7063, Accuracy: 0.9906





Validation Loss: -4.1449, Validation Accuracy: 0.9514


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [67/100], Loss: -4.6033, Accuracy: 0.9912





Validation Loss: -3.9983, Validation Accuracy: 0.9474


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [68/100], Loss: -4.4862, Accuracy: 0.9903





Validation Loss: -4.0722, Validation Accuracy: 0.9571


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [69/100], Loss: -4.3710, Accuracy: 0.9900





Validation Loss: -3.8999, Validation Accuracy: 0.9555


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [70/100], Loss: -4.2739, Accuracy: 0.9914





Validation Loss: -3.8592, Validation Accuracy: 0.9588


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [71/100], Loss: -4.1592, Accuracy: 0.9913





Validation Loss: -3.6876, Validation Accuracy: 0.9530


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [72/100], Loss: -4.0359, Accuracy: 0.9906





Validation Loss: -3.4162, Validation Accuracy: 0.9385


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [73/100], Loss: -3.9089, Accuracy: 0.9902





Validation Loss: -3.3789, Validation Accuracy: 0.9443


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [74/100], Loss: -3.8077, Accuracy: 0.9910





Validation Loss: -3.3387, Validation Accuracy: 0.9528


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [75/100], Loss: -3.6757, Accuracy: 0.9906





Validation Loss: -3.1788, Validation Accuracy: 0.9437


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [76/100], Loss: -3.5630, Accuracy: 0.9909





Validation Loss: -3.1236, Validation Accuracy: 0.9519


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [77/100], Loss: -3.4451, Accuracy: 0.9919





Validation Loss: -2.9382, Validation Accuracy: 0.9479


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [78/100], Loss: -3.3156, Accuracy: 0.9914





Validation Loss: -2.8683, Validation Accuracy: 0.9483


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [79/100], Loss: -3.1792, Accuracy: 0.9903





Validation Loss: -2.7697, Validation Accuracy: 0.9474


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [80/100], Loss: -3.0638, Accuracy: 0.9918





Validation Loss: -2.6802, Validation Accuracy: 0.9521


100%|██████████| 692/692 [01:40<00:00,  6.89it/s]

Epoch [81/100], Loss: -2.9323, Accuracy: 0.9919





Validation Loss: -2.5982, Validation Accuracy: 0.9584


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [82/100], Loss: -2.7917, Accuracy: 0.9914





Validation Loss: -2.4097, Validation Accuracy: 0.9514


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [83/100], Loss: -2.6620, Accuracy: 0.9915





Validation Loss: -2.1898, Validation Accuracy: 0.9426


100%|██████████| 692/692 [01:40<00:00,  6.85it/s]

Epoch [84/100], Loss: -2.5311, Accuracy: 0.9922





Validation Loss: -2.1888, Validation Accuracy: 0.9571


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [85/100], Loss: -2.3840, Accuracy: 0.9912





Validation Loss: -2.0809, Validation Accuracy: 0.9572


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [86/100], Loss: -2.2552, Accuracy: 0.9924





Validation Loss: -1.9562, Validation Accuracy: 0.9568


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [87/100], Loss: -2.1154, Accuracy: 0.9919





Validation Loss: -1.7980, Validation Accuracy: 0.9564


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [88/100], Loss: -1.9797, Accuracy: 0.9936





Validation Loss: -1.6908, Validation Accuracy: 0.9567


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [89/100], Loss: -1.8319, Accuracy: 0.9928





Validation Loss: -1.5501, Validation Accuracy: 0.9540


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [90/100], Loss: -1.6849, Accuracy: 0.9927





Validation Loss: -1.3867, Validation Accuracy: 0.9526


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [91/100], Loss: -1.5376, Accuracy: 0.9922





Validation Loss: -1.2710, Validation Accuracy: 0.9525


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [92/100], Loss: -1.3903, Accuracy: 0.9928





Validation Loss: -1.1823, Validation Accuracy: 0.9628


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [93/100], Loss: -1.2411, Accuracy: 0.9930





Validation Loss: -1.0147, Validation Accuracy: 0.9587


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [94/100], Loss: -1.0900, Accuracy: 0.9935





Validation Loss: -0.8405, Validation Accuracy: 0.9524


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [95/100], Loss: -0.9383, Accuracy: 0.9942





Validation Loss: -0.7391, Validation Accuracy: 0.9588


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [96/100], Loss: -0.7810, Accuracy: 0.9937





Validation Loss: -0.5873, Validation Accuracy: 0.9572


100%|██████████| 692/692 [01:40<00:00,  6.86it/s]

Epoch [97/100], Loss: -0.6224, Accuracy: 0.9932





Validation Loss: -0.4157, Validation Accuracy: 0.9554


100%|██████████| 692/692 [01:40<00:00,  6.88it/s]

Epoch [98/100], Loss: -0.4671, Accuracy: 0.9941





Validation Loss: -0.2408, Validation Accuracy: 0.9479


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [99/100], Loss: -0.3056, Accuracy: 0.9943





Validation Loss: -0.1615, Validation Accuracy: 0.9623


100%|██████████| 692/692 [01:40<00:00,  6.87it/s]

Epoch [100/100], Loss: -0.1428, Accuracy: 0.9941





Validation Loss: 0.0023, Validation Accuracy: 0.9614


In [10]:
torch.save(model.state_dict(), "/content/drive/MyDrive/TR_DIMA/Entrainement/hybridsc.pth")


In [12]:
backbone_classifier_state = {
    'backbone': model.backbone.state_dict(),
    'classifier': model.classifier.state_dict()
}

torch.save(backbone_classifier_state, "/content/drive/MyDrive/TR_DIMA/Entrainement/hybridsc_V2.pth")