In [1]:
# Pick the pretrained checkpoint to fine-tune. Exactly one `model_id` assignment
# should be active; the commented-out lines are alternative checkpoints. The cell
# that loads the model must use a class compatible with the checkpoint chosen here.

#model_id = "google/vit-base-patch16-224"
#model_id = 'microsoft/swin-tiny-patch4-window7-224'
#model_id = 'facebook/deit-base-patch16-224'
model_id = 'microsoft/resnet-50'

Now we load the feature extractor to process the image into a tensor.

In [2]:
from transformers import AutoFeatureExtractor, ViTFeatureExtractor
# Load the preprocessing configuration published with the checkpoint. Later cells
# read its `size`, `image_mean` and `image_std` to build the torchvision pipelines.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
# ViT-specific alternative to the Auto class:
#feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)

  from .autonotebook import tqdm as notebook_tqdm
Downloading: 100%|██████████| 266/266 [00:00<00:00, 212kB/s]


This feature extractor will resize every image to the resolution that the model expects and normalize channels. 

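We define two preprocessing functions: the training one applies a random resized crop, while the validation one applies a deterministic resize followed by a center crop. Both convert the image to a tensor and normalize its channels.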

In [3]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

def _resolve_image_sizes(size):
    """Translate a feature-extractor ``size`` into torchvision arguments.

    Args:
        size: Either a plain value (int or sequence, as exposed by older
            ``transformers`` releases) or a dict such as
            ``{"height": H, "width": W}`` / ``{"shortest_edge": S}``.

    Returns:
        A ``(resize_to, crop_to)`` pair: ``resize_to`` is meant for ``Resize`` and
        ``crop_to`` for ``CenterCrop`` / ``RandomResizedCrop``.

    Raises:
        ValueError: If ``size`` is a dict without any of the supported keys.
    """
    if isinstance(size, dict):
        if "height" in size and "width" in size:
            target = (size["height"], size["width"])
            return target, target
        if "shortest_edge" in size:
            edge = size["shortest_edge"]
            # Resize(int) rescales the shorter side; the crop then has to be square.
            return edge, (edge, edge)
        raise ValueError(f"Unsupported feature extractor size: {size!r}")
    # Non-dict sizes are passed through untouched, exactly as before.
    return size, size


resize_to, crop_to = _resolve_image_sizes(feature_extractor.size)

# Channel statistics come from the checkpoint's own preprocessing configuration.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Training pipeline: random crop rescaled to the model resolution, then tensor + normalization.
train_transforms = Compose(
        [
            RandomResizedCrop(crop_to),
            ToTensor(),
            normalize,
        ]
    )

# Validation/test pipeline: deterministic resize + center crop, then tensor + normalization.
val_transforms = Compose(
        [
            Resize(resize_to),
            CenterCrop(crop_to),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_train(example_batch):
    """Add a ``pixel_values`` entry built with the training pipeline.

    Every image under ``example_batch["img"]`` is converted to RGB and passed
    through ``train_transforms``; the same batch dict is returned with the
    results stored as a list under ``"pixel_values"``.
    """
    transformed = []
    for picture in example_batch["img"]:
        rgb_picture = picture.convert("RGB")
        transformed.append(train_transforms(rgb_picture))
    example_batch["pixel_values"] = transformed
    return example_batch

def preprocess_val(example_batch):
    """Add a ``pixel_values`` entry built with the validation pipeline.

    Every image under ``example_batch["img"]`` is converted to RGB and passed
    through ``val_transforms``; the same batch dict is returned with the
    results stored as a list under ``"pixel_values"``.
    """
    transformed = []
    for picture in example_batch["img"]:
        rgb_picture = picture.convert("RGB")
        transformed.append(val_transforms(rgb_picture))
    example_batch["pixel_values"] = transformed
    return example_batch

Next, we can preprocess our dataset by applying these functions.

In [4]:
#Load data
import datasets
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
# Read the dataset previously saved to a local folder (path is relative to the
# notebook's working directory) and display its structure.
ds = load_from_disk('./data_dict_cut')
ds

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 2050
    })
    val: Dataset({
        features: ['img', 'label'],
        num_rows: 228
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 402
    })
})

In [5]:
# The dataset loaded from disk already contains separate train / val / test
# splits, so each one is simply pulled out of the DatasetDict by name.
train_ds, val_ds, test_ds = ds["train"], ds["val"], ds["test"]

In [6]:
# Class names, read from the `label` feature of the training split; a name's
# position in this list is its integer class index.
labels = train_ds.features["label"].names
print(labels)

['proton', 'gamma']


In [7]:
# Register the preprocessing functions on the splits. Accessing an item afterwards
# (e.g. `train_ds[0]`) returns an extra `pixel_values` entry next to `img` and `label`.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)



In [8]:
# Sanity check: one training example should now carry `img`, `label` and `pixel_values`.
train_ds[0]

{'img': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=288x288>,
 'label': 0,
 'pixel_values': tensor([[[ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          ...,
          [ 2.1804,  2.1804,  2.1804,  ...,  1.8037,  1.8550,  1.8893],
          [ 1.8722,  1.8722,  1.8379,  ...,  1.7694,  1.8208,  1.8550],
          [ 1.7009,  1.6838,  1.6495,  ...,  1.7523,  1.8037,  1.8208]],
 
         [[ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          ...,
          [ 2.0784,  2.0784,  2.0784,  ..., -0.4951, -0.4426, -0.4076],
          [ 1.6758,  1.6758,  1.6758,  ..., -0.5301, -0.4776, -0.4426],
          [ 1.4307,  1.4307,  1.4482,  ..., -0.5476, -0.4951, -0.4776]],


Now that our data is ready, we can download the pretrained model and fine-tune it. We use the image-classification model class that matches the chosen checkpoint (`ResNetForImageClassification` for `microsoft/resnet-50`).

In [9]:
# Build the two lookup tables passed to the model so predictions can be reported
# by class name: `id2label` maps the integer class index to its name and
# `label2id` is the inverse. Indices are kept as ints (not strings) so they match
# the integer `label` values stored in the dataset.
id2label = dict(enumerate(labels))
label2id = {label_name: class_id for class_id, label_name in id2label.items()}

In [14]:
from transformers import ViTForImageClassification, SwinForImageClassification,ResNetForImageClassification, TrainingArguments, Trainer

# Resolve the model class from the checkpoint itself, so that changing `model_id`
# never leaves this cell loading the weights into a mismatched architecture
# (for 'microsoft/resnet-50' this resolves to ResNetForImageClassification).
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    model_id,
    label2id=label2id,
    id2label=id2label,
    # The checkpoint ships a classification head for its original label set; allow
    # it to be replaced by a freshly initialised head sized for our labels.
    ignore_mismatched_sizes=True,
)

Downloading: 100%|██████████| 67.9k/67.9k [00:00<00:00, 3.17MB/s]
Downloading: 100%|██████████| 97.8M/97.8M [00:01<00:00, 73.5MB/s]
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([2, 2048]) in the model instantiated
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The warning is telling us we are throwing away some weights (the weights and bias of the classifier layer) and randomly initializing some other (the weights and bias of a new classifier layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.

To instantiate a Trainer, we will need to define the training configuration and the evaluation metric. The most important is the TrainingArguments, which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model.

Most of the training arguments are pretty self-explanatory, but one that is quite important here is remove_unused_columns=False. When this option is True, the Trainer drops any features not used by the model's call function. By default it's True because usually it's ideal to drop unused feature columns, making it easier to unpack inputs into the model's call function. But, in our case, we need the unused features ('img' in particular) in order to create 'pixel_values'.

In [15]:
# Hyper-parameters. Samples per optimizer step on one device =
# batch_size * gradient_accumulation_steps.
epochs = 3
learning_rate = 5e-5
batch_size = 32
gradient_accumulation_steps = 4
warmup_ratio = 0.1
logging_steps = 10

# The output folder is named after the last path segment of the checkpoint id.
model_name = model_id.rsplit("/", 1)[-1]

args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-ds",
    num_train_epochs=epochs,
    learning_rate=learning_rate,
    warmup_ratio=warmup_ratio,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # Evaluate and save a checkpoint every epoch, then reload the checkpoint with
    # the best validation accuracy once training finishes.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_steps=logging_steps,
    # Keep the raw 'img' column: the dataset transform needs it to build 'pixel_values'.
    remove_unused_columns=False,
    # push_to_hub=True,
)

Next, we load the Accuracy metric, which we'll use to evaluate our model both during and after training, and define a function that computes it from the predictions. The only preprocessing we have to do is to take the argmax of our predicted logits:

In [16]:
import numpy as np

from datasets import load_metric

# Accuracy metric object used by `compute_metrics` below.
# NOTE(review): `load_metric` is reported as deprecated in recent `datasets`
# releases in favour of the `evaluate` package — confirm against the installed version.
metric = load_metric("accuracy")

# `eval_pred` is a named tuple with two fields: `predictions` holds the model's
# logits and `label_ids` the ground-truth labels, both as NumPy arrays.
def compute_metrics(eval_pred):
    """Return the accuracy of the arg-max class against the reference labels."""
    logits = eval_pred.predictions
    references = eval_pred.label_ids
    # The predicted class is the index of the largest logit in each row.
    predicted_classes = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

We also define a collate_fn, which will be used to batch examples together. Each batch consists of 2 keys, namely pixel_values and labels.

In [17]:
import torch

def collate_fn(examples):
    """Combine per-example dicts into a single batch dict.

    Returns a dict with ``pixel_values`` (the examples' tensors stacked along a
    new leading dimension) and ``labels`` (a tensor built from the examples'
    ``label`` values).
    """
    stacked_images = torch.stack(tuple(sample["pixel_values"] for sample in examples))
    targets = torch.tensor(list(sample["label"] for sample in examples))
    return dict(pixel_values=stacked_images, labels=targets)

Then we just need to pass all of this along with our datasets to the Trainer:

In [18]:
# Wire the model, training configuration, data splits, batching function and
# metric together. The feature extractor is handed over through the `tokenizer`
# argument; the training log shows it being saved alongside each checkpoint.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

Now we can finetune our model by calling the train method:

In [19]:
# Run the fine-tuning loop; the return value is kept for later inspection.
train_results = trainer.train()


***** Running training *****
  Num examples = 2050
  Num Epochs = 3
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 4
  Total optimization steps = 48


Epoch,Training Loss,Validation Loss,Accuracy
0,0.6796,0.657536,0.72807
1,0.6227,0.622656,0.72807
2,0.6247,0.603353,0.72807


***** Running Evaluation *****
  Num examples = 228
  Batch size = 32
Saving model checkpoint to resnet-50-finetuned-ds/checkpoint-16
Configuration saved in resnet-50-finetuned-ds/checkpoint-16/config.json
Model weights saved in resnet-50-finetuned-ds/checkpoint-16/pytorch_model.bin
Feature extractor saved in resnet-50-finetuned-ds/checkpoint-16/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 228
  Batch size = 32
Saving model checkpoint to resnet-50-finetuned-ds/checkpoint-32
Configuration saved in resnet-50-finetuned-ds/checkpoint-32/config.json
Model weights saved in resnet-50-finetuned-ds/checkpoint-32/pytorch_model.bin
Feature extractor saved in resnet-50-finetuned-ds/checkpoint-32/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 228
  Batch size = 32
Saving model checkpoint to resnet-50-finetuned-ds/checkpoint-48
Configuration saved in resnet-50-finetuned-ds/checkpoint-48/config.json
Model weights saved in resnet-50-finetuned-ds/ch

In [20]:

# Evaluate on the validation split that was given to the Trainer as `eval_dataset`.
# NOTE(review): in the recorded run the reported eval_loss equals the first epoch's
# validation loss, which suggests `load_best_model_at_end` restored the first
# checkpoint because accuracy was identical in every epoch — confirm.
metrics = trainer.evaluate()
# Print the metrics as a readable table and store them through the Trainer helpers.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** Running Evaluation *****
  Num examples = 228
  Batch size = 32


***** eval metrics *****
  epoch                   =       2.98
  eval_accuracy           =     0.7281
  eval_loss               =     0.6575
  eval_runtime            = 0:00:18.36
  eval_samples_per_second =     12.412
  eval_steps_per_second   =      0.436


In [21]:

# The test split gets the same deterministic preprocessing as the validation split.
test_ds.set_transform(preprocess_val)

In [22]:
# Run inference on the held-out test split, then reduce each row of logits to the
# index of its highest-scoring class.
outputs = trainer.predict(test_ds)
y_pred = np.argmax(outputs.predictions, axis=1)

***** Running Prediction *****
  Num examples = 402
  Batch size = 32


In [23]:
# Test-set accuracy: the prediction output exposes the `predictions` and
# `label_ids` fields that `compute_metrics` reads.
compute_metrics(outputs)

{'accuracy': 0.7288557213930348}

In [24]:
# Show the raw prediction output (logits per test image).
# NOTE(review): this prints the full array; consider displaying only a short slice.
outputs

PredictionOutput(predictions=array([[ 7.73780793e-02, -1.12476796e-01],
       [ 1.09268889e-01, -1.52451143e-01],
       [ 8.39075893e-02, -6.37701899e-02],
       [ 8.58538970e-02, -1.07397363e-01],
       [ 1.21771477e-01, -7.28272647e-02],
       [ 7.17547238e-02, -4.51459102e-02],
       [ 9.37825069e-03, -1.31042555e-01],
       [ 8.27188045e-02, -8.10401663e-02],
       [ 6.50262833e-02, -1.31769314e-01],
       [ 3.62734646e-02, -7.09352717e-02],
       [ 7.29701668e-02, -1.13420218e-01],
       [ 8.73838663e-02, -6.89674467e-02],
       [ 6.41680211e-02, -6.62222877e-02],
       [ 1.35265470e-01, -6.14006855e-02],
       [ 5.60993180e-02, -8.21944475e-02],
       [ 1.02933310e-01, -1.07396565e-01],
       [ 9.58039686e-02, -1.90653384e-01],
       [ 9.72029567e-02, -2.19756886e-01],
       [ 4.21429649e-02, -1.09505095e-01],
       [ 7.34182447e-03, -1.26007527e-01],
       [ 4.98240367e-02, -1.29639894e-01],
       [-1.89225413e-02, -7.11746514e-02],
       [ 8.91918689e-02, 

In [25]:
# Predicted class index for every test image.
y_pred

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [26]:
# Ground-truth labels for the test split, in the same order as `y_pred`.
# They are taken from the prediction output (the labels collected while batching
# the test set) rather than from `test_ds[:]`: slicing the whole dataset would run
# `preprocess_val` on every image again just to read the label column.
y_true = np.asarray(outputs.label_ids)
y_true

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0,
       1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0,
       1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1,
       0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
       0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
       1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0,
       1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [27]:
from sklearn.metrics import confusion_matrix
# Confusion matrix of true (rows) vs. predicted (columns) classes on the test set.
# NOTE(review): in the recorded run the second column is all zeros, i.e. the model
# predicts class 0 for every test image, and the test accuracy equals the share of
# class-0 samples (293 / 402). The fine-tuning has not yet learned to separate the
# classes, so the reported accuracy should not be read as model skill.
confusion_matrix(y_true, y_pred)

array([[293,   0],
       [109,   0]])