# ----------------------------------------------- #
# DEBUT
# ----------------------------------------------- #

### Diviser les données en ensembles d'entraînement, de validation et de test

In [1]:
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, AdamW
from torch.utils.tensorboard import SummaryWriter
from torch import nn
import torch
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Select the compute device: CUDA if available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

# Data location and preprocessing: every image is resized to 224x224 and
# converted to a float tensor.
# NOTE(review): the Hugging Face preprocessor for
# google/vit-base-patch16-224-in21k also normalises each channel with
# mean=std=0.5; no normalisation is applied here. Confirm this is intentional
# (changing it would invalidate checkpoints trained without it).
dataset_path = "/datasets/rakuten-images/images/"
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# ImageFolder assigns one class index per sub-directory of dataset_path.
full_dataset = datasets.ImageFolder(dataset_path, transform=transform)


# Split: 70% train, 15% test, remainder (~15%, absorbs rounding) validation.
# A fixed-seed generator makes the split reproducible across sessions, so a
# checkpoint reloaded later is evaluated on the same held-out images rather
# than on a fresh random subset that may overlap its training data.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_size = int(0.7 * len(full_dataset))
test_size = int(0.15 * len(full_dataset))
val_size = len(full_dataset) - train_size - test_size
train_dataset, remaining_dataset = random_split(
    full_dataset,
    [train_size, len(full_dataset) - train_size],
    generator=split_generator,
)
test_dataset, val_dataset = random_split(
    remaining_dataset,
    [test_size, val_size],
    generator=split_generator,
)

# Loaders: only the training set is shuffled; evaluation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


Using cuda device


### Chargement du modèle ViT

In [2]:
# Model: ViT-Base/16 backbone loaded from the ImageNet-21k checkpoint, with a
# newly initialised classification head. Its size is taken from the number of
# class folders ImageFolder found, so it cannot drift from the dataset.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(full_dataset.classes),
)
model.to(device)

Downloading config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/330M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_

### Configuration de l'entraînement

In [3]:
# Checkpointing
def save_checkpoint(model, optimizer, filename='checkpoint.pth.tar'):
    """Write the model and optimizer state dicts to ``filename`` via torch.save.

    The saved object is a dict with two keys:
      - ``'state_dict'``: ``model.state_dict()`` (weights and buffers)
      - ``'optimizer'``: ``optimizer.state_dict()``
    Restore with ``model.load_state_dict(ckpt['state_dict'])`` and
    ``optimizer.load_state_dict(ckpt['optimizer'])``.
    """
    payload = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(payload, filename)
    
# Early stopping
class EarlyStopping:
    """Signal that training should stop after ``patience`` non-improving calls.

    The instance is called once per epoch with a loss value, where lower is
    better. To monitor a metric where higher is better, pass its negation. The
    loss is negated into a score (higher is better). A new score counts as
    "no improvement" when it is below ``best_score + delta``. Otherwise it
    becomes the new best and the counter resets. Once ``counter`` reaches
    ``patience``, ``early_stop`` is set to True.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        verbose: stored on the instance; not used by this class.
        delta: minimum score gain required to count as an improvement.
    """

    def __init__(self, patience=7, verbose=True, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_score = None   # best negated-loss score seen so far
        self.early_stop = False  # flips to True once patience is exhausted
        self.counter = 0         # consecutive calls without improvement

    def __call__(self, val_loss, model):
        """Record one epoch's ``val_loss``; ``model`` is accepted but unused."""
        current = -val_loss

        # The first observation only initialises the reference score.
        if self.best_score is None:
            self.best_score = current
            return

        stalled = current < self.best_score + self.delta
        if not stalled:
            self.best_score = current
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

# Evaluation
def evaluate(model, dataloader, device=None):
    """Return the weighted F1 score of ``model`` over all batches of ``dataloader``.

    Each batch from ``dataloader`` is expected to be an ``(inputs, labels)``
    pair. ``model(inputs)`` must return an object with a ``.logits`` tensor
    whose dimension 1 indexes classes. Predictions are the arg-max class per
    sample. The score is ``sklearn.metrics.f1_score(..., average='weighted')``.

    Args:
        model: classifier to evaluate. It is switched to eval mode and left
            there; callers that keep training must call ``model.train()``.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device to move inputs to. Defaults to the device of the
            model's first parameter.

    Returns:
        float: weighted F1 score.
    """
    if device is None:
        # Send inputs wherever the model's weights live.
        device = next(model.parameters()).device
    model.eval()
    all_preds, all_labels = [], []
    # Inference only: without no_grad, autograd would keep every batch's
    # activations alive for a backward pass that never happens.
    with torch.no_grad():
        for inputs, labels in dataloader:
            logits = model(inputs.to(device)).logits
            preds = logits.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return f1_score(all_labels, all_preds, average='weighted')



### Entraînement

In [None]:
# Optimizer: torch.optim.AdamW replaces the deprecated transformers.AdamW.
# eps=1e-6 and weight_decay=0.0 reproduce transformers.AdamW's defaults
# (torch's own defaults are eps=1e-8 and weight_decay=1e-2).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)

# Early stopping: stop after 3 epochs without validation-F1 improvement.
early_stopping = EarlyStopping(patience=3, verbose=True)

# Training monitoring: a single TensorBoard writer for the whole run, closed
# in the `finally` clause below.
writer = SummaryWriter()

# The loss module is stateless, so one instance serves every batch.
criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 10
best_f1 = 0

try:
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        val_f1 = evaluate(model, val_loader)
        writer.add_scalar('Training Loss', train_loss, epoch)
        writer.add_scalar('Validation F1', val_f1, epoch)
        print(f'Epoch {epoch+1}, Loss: {train_loss}, Validation F1: {val_f1}')

        if val_f1 > best_f1:
            best_f1 = val_f1
            # File names use the 0-based epoch index:
            # printed "Epoch N" is saved as checkpoint_epoch_{N-1}.
            save_checkpoint(model, optimizer, filename=f'checkpoint_epoch_{epoch}.pth.tar')
            print("Checkpoint saved")

        # EarlyStopping expects a loss (lower is better), so the F1 score
        # (higher is better) is negated.
        early_stopping(-val_f1, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
finally:
    # Flush and release the event files even if training is interrupted.
    writer.close()





Epoch 1, Loss: 2.3137997477380527, Validation F1: 0.4531042035314736
Checkpoint saved
Epoch 2, Loss: 1.6319120150656952, Validation F1: 0.5477437822046588
Checkpoint saved
Epoch 3, Loss: 1.32934395559894, Validation F1: 0.5901230945035656
Checkpoint saved


### Évaluation

In [4]:
# Load the checkpoint
#checkpoint = torch.load('checkpoint_epoch_2.pth.tar')
#model.load_state_dict(checkpoint['state_dict'])

# Ensure the model is in evaluation mode
#model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_

In [None]:
# Test evaluation: weighted F1 on the held-out test split.
# NOTE(review): this scores the weights currently held by `model` (the last
# epoch trained, unless a checkpoint was reloaded beforehand), which are not
# necessarily the best-validation checkpoint.
test_f1 = evaluate(model, test_loader)
print('Test Weighted F1 Score: {}'.format(test_f1))

# ----------------------------------------------- #
# FIN
# ----------------------------------------------- #