ViT-based model using Transfer-Learning
---
#### Model: google/vit-base-patch16-224, descongelando las últimas 5 capas del encoder
#### Epochs: 30
#### Dataset: images_3categories_balanced
#### Cambios:
- DropOut:
    - Clasificador 0
    - HiddenLayer 0
    - AttentionLayer 0
- Learning Rate 3e-4     

In [1]:
# Parameters
# NOTE(review): `_output` below (and the notebook title) mention ViT-large, but the
# checkpoint named here is the base model — confirm which one is intended.
_model = 'google/vit-base-patch16-224'
# Path of the checkpoint to load; None if there is none
_checkpoint = None
_output = 'SavedModels/Mass_ViT-large-patch16-224_A'  # path where the model is saved

# NOTE(review): absolute, user-specific paths; they must be edited to run elsewhere.
_dataset = '/Users/julio/Documentos-Local/data/VinDr-Mammo/subsets/Masses_3categories'

# Path where the dataset with its train/validation/test split is stored
_dataset_split_path = '/Users/julio/Documentos-Local/data/VinDr-Mammo/subsets/Masses_3categories_split'  

# If the dataset is already split into train, validation and test -> _dataset_split=_dataset_split_path.
# If it is not split yet -> _dataset_split=None.
_dataset_split = _dataset_split_path  

_batch_size = 16
_learning_rate = 3e-4
_epochs = 30  

# Dropout rates: classifier head, encoder hidden layers, attention layers
_dp_clasificador = 0.0
_dp_hidden_layer = 0.0
_dp_attention_layer= 0.0

num_layers_to_unfreeze = 5  # Number of encoder layers to unfreeze (counted from the end); None to unfreeze none

In [2]:
import torch
import torch.nn as nn
import pandas as pd
from datasets import load_dataset, DatasetDict, load_from_disk

from transformers import AutoImageProcessor, ViTForImageClassification

from transformers import Trainer, TrainingArguments

import evaluate

from Utils import *

## Carga de datos

In [3]:
# Load the data. `_dataset_split` is both the switch and the location: None means
# "read the raw, unsplit dataset"; any other value is the on-disk path of a
# previously saved split, so that same value is what gets loaded (loading a
# different variable would silently ignore a path set on the switch).
if _dataset_split is None:
    dataset = load_dataset(_dataset)
else:
    # Load the previously saved, already split dataset
    dataset = load_from_disk(_dataset_split)

dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 2001
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 250
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 251
    })
})

### Revisión de categorías

In [4]:
# Class names come from the 'label' feature of the training split.
labels = dataset['train'].features['label'].names
print(len(labels),labels)

# Two-way lookup between class name and integer id (id = position in `labels`).
label2id = dict(zip(labels, range(len(labels))))
id2label = dict(enumerate(labels))

3 ['benigno', 'maligno', 'sospechoso']


## Muestra de ejemplos

In [5]:
def show_samples(ds,rows,cols):
    """Plot a `rows` x `cols` grid of randomly chosen images from `ds`.

    Args:
        ds: dataset exposing `shuffle()` and `select()` whose rows carry an
            'image' and a 'label' entry.
        rows: number of grid rows.
        cols: number of grid columns.

    Each subplot is titled with the row's raw 'label' value.
    """
    # NOTE(review): `plt` is not imported in any visible cell; presumably it comes
    # from `from Utils import *` — confirm.
    n_images = rows * cols
    # range() rather than np.arange(): numpy is only imported in a later cell, so
    # calling this function right after its definition would otherwise fail.
    samples = ds.shuffle().select(range(n_images))  # random subset
    fig = plt.figure(figsize=(cols*4, rows*4))
    # Iterate rows once so each image is fetched a single time.
    for position, sample in enumerate(samples, start=1):
        fig.add_subplot(rows, cols, position)
        plt.imshow(sample['image'], cmap='gray')
        plt.title(sample['label'])
        plt.axis('off')
            
# show_samples(dataset['train'],rows=3,cols=5)

## Split Dataset

In [6]:
if _dataset_split is None:
    # 80/10/10 split. A fixed seed makes the partition written to disk
    # reproducible, and stratifying on 'label' keeps the class proportions the
    # same in train, validation and test.
    split_seed = 42
    split_dataset = dataset['train'].train_test_split(
        test_size=0.2, stratify_by_column='label', seed=split_seed
    )
    eval_dataset = split_dataset['test'].train_test_split(
        test_size=0.5, stratify_by_column='label', seed=split_seed
    )

    # Recombine the pieces into a single train/validation/test DatasetDict
    final_dataset = DatasetDict({
        'train': split_dataset['train'],
        'validation': eval_dataset['train'],
        'test': eval_dataset['test']
    })
    # Persist the split so later runs can reuse exactly the same partition
    final_dataset.save_to_disk(_dataset_split_path)

else:
    # The dataset loaded above is already split
    final_dataset = dataset
final_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 2001
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 250
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 251
    })
})

In [7]:
print('Número de imágenes por clases en cada split')
# Count samples per class in every split.
# Only the 'label' column is read, so no image has to be loaded just to count
# labels. Counts are reindexed over all class ids so a class that is absent from
# a split shows up as 0 instead of shifting the remaining columns.
class_ids = range(len(labels))
split_rows = []
for key in final_dataset:
    counts = (
        pd.Series(list(final_dataset[key]['label']))
        .value_counts()
        .reindex(class_ids, fill_value=0)
    )
    split_rows.append([key, *counts])
# Build the table once from the collected rows; column names follow `labels`.
clases_split = pd.DataFrame(split_rows, columns=['split', *labels])
clases_split

Número de imágenes por clases en cada split


Unnamed: 0,split,benigno,maligno,sospechoso
0,train,664,683,654
1,validation,80,81,89
2,test,90,70,91


## Preprocesamiento de las imágenes

In [8]:
# Load the image processor published with the `_model` checkpoint (fast variant
# requested through `use_fast=True`) and display its configuration.
processor = AutoImageProcessor.from_pretrained(_model, use_fast=True)
processor

ViTImageProcessorFast {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessorFast",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [9]:
def transforms(batch):
    """Turn a batch of dataset rows into model inputs.

    Every image in `batch['image']` is converted to RGB (the batch entry is
    overwritten with the converted images), passed through the notebook-level
    `processor` to obtain PyTorch tensors, and the batch's 'label' values are
    attached under the 'labels' key.
    """
    rgb_images = [image.convert('RGB') for image in batch['image']]
    batch['image'] = rgb_images
    inputs = processor(rgb_images, return_tensors='pt')
    # Labels are already numeric class ids, so they are passed through unchanged
    inputs['labels'] = batch['label']
    return inputs

In [10]:
# Attach `transforms` to every split; with_transform applies it when rows are
# accessed rather than preprocessing the whole dataset up front.
processed_dataset = final_dataset.with_transform(transforms)

### Data Collation

In [11]:
def collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Args:
        batch: list of examples, each with a 'pixel_values' tensor and a
            'labels' value.

    Returns:
        dict with 'pixel_values' stacked along a new leading dimension and
        'labels' gathered into a single tensor.
    """
    pixel_values, labels = [], []
    for example in batch:
        pixel_values.append(example['pixel_values'])
        labels.append(example['labels'])
    return {
        'pixel_values': torch.stack(pixel_values),
        'labels': torch.tensor(labels),
    }

## Métricas de evaluación

In [12]:
import numpy as np
import evaluate

# Metric objects are loaded once here and reused on every evaluation call.
accuracy = evaluate.load('accuracy')
precision = evaluate.load('precision')
recall = evaluate.load('recall')
f1 = evaluate.load('f1')

def compute_metrics(eval_preds):
    """Compute accuracy plus macro-averaged precision, recall and F1.

    Args:
        eval_preds: pair `(logits, labels)`; the predicted class is the argmax
            of `logits` along axis 1.

    Returns:
        dict with the keys 'accuracy', 'precision', 'recall' and 'f1'.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=1)
    shared = {'predictions': predictions, 'references': labels}

    # Accuracy takes no averaging argument
    scores = {'accuracy': accuracy.compute(**shared)['accuracy']}

    # The remaining metrics need `average` in the multi-class setting
    macro_metrics = (('precision', precision), ('recall', recall), ('f1', f1))
    for metric_name, metric in macro_metrics:
        scores[metric_name] = metric.compute(**shared, average='macro')[metric_name]

    return scores

## Carga del modelo

In [13]:
# Custom class that adds a dropout layer right before the final classification layer
class CustomViTForImageClassification(ViTForImageClassification):
    """ViT classifier with an extra dropout applied to the [CLS] embedding.

    The dropout rate is read from the notebook-level `_dp_clasificador` when the
    model is constructed. `forward` returns plain tensors: `(loss, logits)` when
    labels are given, otherwise just `logits`.
    """

    def __init__(self, config):
        super().__init__(config)

        # Extra dropout placed in front of the classification layer
        self.additional_dropout = nn.Dropout(_dp_clasificador)
        # Linear head sized from the config; this assignment replaces any
        # `classifier` attribute set by the parent constructor.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pixel_values, labels=None):
        """Run the ViT backbone and classify from the token at position 0.

        Args:
            pixel_values: image tensor forwarded to `self.vit`.
            labels: optional class ids; when given, a cross-entropy loss is
                computed against the logits.

        Returns:
            `(loss, logits)` if `labels` is not None, else `logits`.
        """
        hidden_states = self.vit(pixel_values).last_hidden_state
        cls_embedding = hidden_states[:, 0]  # the [CLS] token sits at position 0

        logits = self.classifier(self.additional_dropout(cls_embedding))

        if labels is None:
            return logits

        # NOTE(review): labels are used as given; the original comment asks the
        # caller to make sure they are of integer (long) type.
        loss = nn.CrossEntropyLoss()(logits, labels)
        return (loss, logits)

# Instantiate the custom classifier from the pretrained `_model` checkpoint.
model = CustomViTForImageClassification.from_pretrained(
    _model,
    # Classification head: one output per dataset class, with name<->id maps
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    # The checkpoint's head has a different number of outputs (see the load
    # message below), so mismatching shapes must be tolerated
    ignore_mismatched_sizes=True,
    # Dropout inside the ViT: hidden layers and attention probabilities
    attention_probs_dropout_prob=_dp_attention_layer,
    hidden_dropout_prob=_dp_hidden_layer,
)

Some weights of CustomViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Arquitectura del modelo

In [14]:
# Display the module tree of the instantiated model
model

CustomViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear

### Congelar todas las capas, menos el clasificador

In [15]:
# Freeze every parameter whose name does not start with 'classifier'
for param_name, param in model.named_parameters():
    if param_name.startswith('classifier'):
        continue
    param.requires_grad_(False)

In [16]:
# Single pass over the parameters: total size and trainable size
num_params = 0
trainable_params = 0
for p in model.parameters():
    num_params += p.numel()
    if p.requires_grad:
        trainable_params += p.numel()

print(f"{num_params = :,} | {trainable_params = :,}")

num_params = 85,800,963 | trainable_params = 2,307


### Descongelar capas del encoder para fine-tuning

In [17]:
# Number of layers in the encoder (the original note expects 24 for ViT-Large
# and 12 for ViT-base)
num_total_layers = len(model.vit.encoder.layer)
print(num_total_layers)

12


In [18]:
# Only touch the encoder when a number of layers to unfreeze was configured
if num_layers_to_unfreeze is not None:
    # Index of the first encoder layer that becomes trainable
    unfreeze_from = num_total_layers - num_layers_to_unfreeze

    # Layers at or after `unfreeze_from` are made trainable, earlier ones frozen;
    # requires_grad_ sets the flag on every parameter of the layer.
    for idx, layer in enumerate(model.vit.encoder.layer):
        layer.requires_grad_(idx >= unfreeze_from)


In [19]:
# Report the total and trainable parameter counts after unfreezing
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
num_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, is_trainable in param_sizes if is_trainable)
print(f"Después de descongelar las últimas {num_layers_to_unfreeze} capas:")
print(f"Total de parámetros: {num_params:,}")
print(f"Parámetros entrenables: {trainable_params:,}")

Después de descongelar las últimas 5 capas:
Total de parámetros: 85,800,963
Parámetros entrenables: 35,441,667


In [20]:
# Review of which parameters are trainable.
# The previous loop computed a status per parameter and then discarded it; the
# statuses are now kept in `param_status` (one row per named parameter) and a
# compact summary is displayed. Inspect `param_status` for the per-parameter view.
param_status = pd.DataFrame(
    [
        (name, "Trainable" if param.requires_grad else "Frozen", param.numel())
        for name, param in model.named_parameters()
    ],
    columns=['parameter', 'status', 'num_params'],
)
# Number of parameter tensors and total element count per status
param_status.groupby('status')['num_params'].agg(['count', 'sum'])

## Training

In [21]:
training_args = TrainingArguments(
    # Where outputs are written, and how many checkpoints are kept
    output_dir=_output,
    save_total_limit=2,
    push_to_hub=False,
    # Optimisation settings taken from the parameters cell
    num_train_epochs=_epochs,
    learning_rate=_learning_rate,
    per_device_train_batch_size=_batch_size,
    # Evaluate and save once per epoch; reload the best model when training ends
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Logging
    logging_steps=100,
    report_to='tensorboard',
    # Keep the dataset columns as they are (the 'image' column is consumed by
    # the on-the-fly transform)
    remove_unused_columns=False,
)

In [22]:
trainer = Trainer(
    model=model,
    args=training_args,
    # Data: transformed train/validation splits, batched by `collate_fn`
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
    data_collator=collate_fn,
    # Metrics reported at every evaluation
    compute_metrics=compute_metrics,
    # NOTE(review): the image processor is passed through the `tokenizer`
    # argument; newer transformers releases reportedly rename it to
    # `processing_class` — confirm against the installed version.
    tokenizer=processor,
)

In [23]:
# Report whether the MPS backend can be used, and why not when it cannot
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    print("MPS enabled")
elif torch.backends.mps.is_built():
    print("MPS not available because the current MacOS version is not 12.3+ "
          "and/or you do not have an MPS-enabled device on this machine.")
else:
    print("MPS not available because the current PyTorch install was not "
          "built with MPS enabled.")

MPS enabled


In [None]:
# Run training with the arguments configured above
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9595,0.833345,0.584,0.568464,0.595926,0.561729
2,0.7678,0.847668,0.62,0.671727,0.635491,0.572687
3,0.6248,0.969052,0.616,0.627359,0.629372,0.584272
4,0.2889,1.00021,0.632,0.651216,0.634221,0.638523
5,0.1973,1.195677,0.552,0.613769,0.543468,0.516801


In [25]:
# Save the trained model; no path is given, so presumably it goes to the
# trainer's configured output directory (`_output`) — confirm.
trainer.save_model()

### Evaluación del modelo

In [25]:
# Evaluate on the held-out test split (loss plus the metrics from compute_metrics)
trainer.evaluate(processed_dataset['test'])

{'eval_loss': 2.0670344829559326,
 'eval_accuracy': 0.649402390438247,
 'eval_precision': 0.6472742563535402,
 'eval_recall': 0.6686202686202686,
 'eval_f1': 0.6534963244640665}

## Inferencia en conjunto de test 

In [26]:
# Raw test split (kept for plotting) and its transformed view for the model
samples = final_dataset['test']
processed_samples = samples.with_transform(transforms)
# Predicted class id per test sample = argmax over the returned predictions
test_logits = trainer.predict(processed_samples).predictions
predictions = test_logits.argmax(axis=1)

In [28]:
# Plot a 5x5 grid of test images with their predictions.
# NOTE(review): the recorded output of this cell is a TypeError ("Wrong key type
# ... numpy.int64"). `show_predictions` is not defined in this notebook
# (presumably it comes from `Utils`), so the helper is where the fix belongs —
# it presumably indexes the dataset with a numpy integer; confirm.
show_predictions(rows=5,cols=5, samples_=samples, predictions_=predictions, id2label_=id2label)

TypeError: Wrong key type: '173' of type '<class 'numpy.int64'>'. Expected one of int, slice, range, str or Iterable.

<Figure size 2000x2000 with 0 Axes>

### Matriz de confusión ❌

In [None]:
# Confusion matrix of the test predictions; `confusion_matrix` is not defined in
# this notebook (presumably provided by `Utils` — confirm).
confusion_matrix(samples_=samples, predictions_=predictions)

## Iterar por más epochs ❌

In [None]:
# NOTE(review): the original comment said "to train up to epoch 20" while the
# value set here is 30 (the same as `_epochs`) — confirm the intended budget.
trainer.args.num_train_epochs = 30  # epoch budget for the continued run
# `_checkpoint` is None in the parameters cell, so no checkpoint path is passed
# here unless that setting is changed.
trainer.train(resume_from_checkpoint=_checkpoint)

In [25]:
# Save the model again after the additional training run
trainer.save_model()

In [26]:
# Re-evaluate on the test split after the additional training run
trainer.evaluate(processed_dataset['test'])

{'eval_loss': 0.8503161072731018,
 'eval_accuracy': 0.611185086551265,
 'eval_precision': 0.6176820549739529,
 'eval_recall': 0.6137534374966717,
 'eval_f1': 0.597174189764872,
 'eval_runtime': 34.361,
 'eval_samples_per_second': 21.856,
 'eval_steps_per_second': 2.736,
 'epoch': 20.0}

In [29]:
# NOTE(review): `predictions` is not recomputed in the cells between the
# continued training and this call, so this matrix would reflect the earlier
# predictions unless the inference cell is re-run first — confirm.
confusion_matrix(samples_=samples, predictions_=predictions)