In [1]:
#Create dataset
import os
import datasets

#model_id = "google/vit-base-patch16-224"
model_id = 'microsoft/swin-tiny-patch4-window7-224'

def create_image_folder_dataset(root_path):
  """creates `Dataset` from image folder structure"""

  # get class names by folders names
  _CLASS_NAMES= os.listdir(root_path)
  # defines `datasets` features`
  features=datasets.Features({
                      "img": datasets.Image(),
                      "label": datasets.features.ClassLabel(names=_CLASS_NAMES),
                  })
  # temp list holding datapoints for creation
  img_data_files=[]
  label_data_files=[]
  # load images into list for creation
  for img_class in os.listdir(root_path):
    for img in os.listdir(os.path.join(root_path,img_class)):
      path_=os.path.join(root_path,img_class,img)
      img_data_files.append(path_)
      label_data_files.append(img_class)
  # create dataset
  ds = datasets.Dataset.from_dict({"img":img_data_files,"label":label_data_files},features=features)
  return ds

In [2]:
ds = create_image_folder_dataset("data")

In [3]:
ds

Dataset({
    features: ['img', 'label'],
    num_rows: 6402
})

In [4]:
#Classes names
labels = ds.features["label"].names
print(labels)

['gamma', 'iron', 'proton']


In [5]:
# test size will be 15% of train dataset
test_size=.15

ds_split = ds.shuffle().train_test_split(test_size=test_size)

In [6]:
ds_split

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 5441
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 961
    })
})

In [7]:
ds_split['train'][0]

{'img': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=288x288>,
 'label': 1}

We take a look at an example. The image field contains a PIL image and each label is an integer that represents a class. We create a dictionary that maps a label name to an integer and vice versa. The mapping will help the model recover the label name from the label number.

In [8]:
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

Now we can covert the label number to a label name.

In [9]:
id2label[str(0)]

'gamma'

In [10]:
id2label[str(1)]

'iron'

In [11]:
id2label[str(2)]

'proton'

Now we load the ViT feature extractor to process the image into a tensor.

In [12]:
from transformers import AutoFeatureExtractor
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

This feature extractor will resize every image to the resolution that the model expects and normalize channels. 

We define 2 functions, one for training and one for validation, including resizing, center cropping and normalizing.

In [13]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
train_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

val_transforms = Compose(
        [
            Resize(feature_extractor.size),
            CenterCrop(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["img"]
    ]
    return example_batch

def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["img"]]
    return example_batch

Next, we can preprocess our dataset by applying these functions.

In [14]:
# split up training into training + validation
splits = ds_split["train"].train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

In [15]:
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)



In [16]:
train_ds[0]

{'img': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=288x288>,
 'label': 1,
 'pixel_values': tensor([[[2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          ...,
          [2.1633, 2.1462, 2.1290,  ..., 2.2318, 2.2318, 2.2489],
          [2.1462, 2.0605, 1.9407,  ..., 2.2489, 2.2489, 2.2489],
          [1.9920, 1.8379, 1.5810,  ..., 2.2489, 2.2489, 2.2489]],
 
         [[2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          ...,
          [0.8529, 0.8529, 0.8179,  ..., 1.4832, 1.5882, 1.6583],
          [0.8880, 0.7829, 0.6429,  ..., 1.7108, 1.8333, 1.9384],
          [0.7654, 0.5903, 0.3102,  ..., 1.6583, 1.8508, 1.9734]],
 
         [[2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
    

Now that our data is ready, we can download the pretrained model and fine-tune it. We use the modelViTForImageClassification.  

In [18]:
from transformers import SwinForImageClassification, TrainingArguments, Trainer

model = SwinForImageClassification.from_pretrained(model_id,
                                                 label2id=label2id,
                                                 id2label=id2label,
                                                 ignore_mismatched_sizes = True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The warning is telling us we are throwing away some weights (the weights and bias of the classifier layer) and randomly initializing some other (the weights and bias of a new classifier layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.

To instantiate a Trainer, we will need to define the training configuration and the evaluation metric. The most important is the TrainingArguments, which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model.

Most of the training arguments are pretty self-explanatory, but one that is quite important here is remove_unused_columns=False. This one will drop any features not used by the model's call function. By default it's True because usually it's ideal to drop unused feature columns, making it easier to unpack inputs into the model's call function. But, in our case, we need the unused features ('image' in particular) in order to create 'pixel_values'.

In [19]:
model_name = model_id.split("/")[-1]
batch_size = 32
args = TrainingArguments(
    f"{model_name}-finetuned-ds",
    remove_unused_columns=False,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
   # push_to_hub=True,
)

Next, we need to define a function for how to compute the metrics from the predictions, which will just use the metric we loaded earlier. Let us also load the Accuracy metric, which we'll use to evaluate our model both during and after training. The only preprocessing we have to do is to take the argmax of our predicted logits:

In [20]:
import numpy as np

from datasets import load_metric

metric = load_metric("accuracy")

# the compute_metrics function takes a Named Tuple as input:
# predictions, which are the logits of the model as Numpy arrays,
# and label_ids, which are the ground-truth labels as Numpy arrays.
def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

We also define a collate_fn, which will be used to batch examples together. Each batch consists of 2 keys, namely pixel_values and labels.

In [21]:
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

Then we just need to pass all of this along with our datasets to the Trainer:

In [22]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

Now we can finetune our model by calling the train method:

In [23]:
train_results = trainer.train()
# rest is optional but nice to have
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

***** Running training *****
  Num examples = 4896
  Num Epochs = 3
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 4
  Total optimization steps = 114


Epoch,Training Loss,Validation Loss,Accuracy
0,0.1636,0.037612,0.992661
1,0.065,0.05535,0.979817
2,0.0391,0.024472,0.988991


***** Running Evaluation *****
  Num examples = 545
  Batch size = 32
Saving model checkpoint to swin-tiny-patch4-window7-224-finetuned-ds/checkpoint-38
Configuration saved in swin-tiny-patch4-window7-224-finetuned-ds/checkpoint-38/config.json
Model weights saved in swin-tiny-patch4-window7-224-finetuned-ds/checkpoint-38/pytorch_model.bin
Feature extractor saved in swin-tiny-patch4-window7-224-finetuned-ds/checkpoint-38/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 545
  Batch size = 32
Saving model checkpoint to swin-tiny-patch4-window7-224-finetuned-ds/checkpoint-76
Configuration saved in swin-tiny-patch4-window7-224-finetuned-ds/checkpoint-76/config.json
Model weights saved in swin-tiny-patch4-window7-224-finetuned-ds/checkpoint-76/pytorch_model.bin
Feature extractor saved in swin-tiny-patch4-window7-224-finetuned-ds/checkpoint-76/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 545
  Batch size = 32
Saving model checkpoint to swin-

***** train metrics *****
  epoch                    =        2.99
  total_flos               = 339280398GF
  train_loss               =      0.1747
  train_runtime            =  1:25:59.33
  train_samples_per_second =       2.847
  train_steps_per_second   =       0.022


In [24]:
metrics = trainer.evaluate()

***** Running Evaluation *****
  Num examples = 545
  Batch size = 32


In [25]:
test_ds = ds_split['test']
test_ds.set_transform(preprocess_val)

In [26]:
outputs = trainer.predict(test_ds)
y_pred = outputs.predictions.argmax(1)

***** Running Prediction *****
  Num examples = 961
  Batch size = 32


In [27]:
compute_metrics(outputs)

{'accuracy': 0.9895941727367326}

In [28]:
outputs

PredictionOutput(predictions=array([[ 5.5055723, -5.744572 , -2.035857 ],
       [ 5.0418024, -3.0681725, -0.8998853],
       [ 4.5115333, -4.938079 , -1.4383147],
       ...,
       [-3.7176938,  4.834112 ,  1.6651603],
       [ 5.483938 , -5.6335926, -1.8940908],
       [-3.6593235,  6.2780037,  0.5556666]], dtype=float32), label_ids=array([0, 0, 0, 0, 1, 0, 2, 1, 2, 0, 0, 2, 1, 1, 1, 1, 1, 2, 0, 1, 0, 0,
       0, 1, 2, 0, 1, 2, 1, 2, 1, 1, 1, 1, 0, 1, 0, 0, 1, 2, 0, 1, 0, 0,
       1, 0, 0, 0, 1, 2, 0, 1, 1, 0, 1, 2, 0, 1, 2, 0, 1, 0, 1, 1, 2, 1,
       0, 2, 2, 1, 0, 0, 0, 2, 0, 2, 2, 1, 1, 2, 2, 1, 1, 2, 1, 0, 0, 1,
       0, 1, 1, 0, 1, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 2, 1, 0, 2, 2,
       0, 0, 2, 0, 0, 1, 2, 0, 2, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0,
       2, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 1, 2, 1, 0,
       0, 2, 0, 0, 0, 2, 0, 2, 1, 0, 1, 0, 1, 0, 2, 0, 1, 1, 1, 0, 0, 1,
       2, 0, 1, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 0, 1, 0, 0, 1, 0, 1, 2, 1,
     

In [29]:
y_pred

array([0, 0, 0, 0, 1, 0, 2, 1, 2, 0, 0, 2, 1, 1, 1, 1, 1, 2, 0, 1, 0, 0,
       0, 1, 2, 0, 1, 2, 1, 2, 1, 1, 1, 1, 0, 1, 0, 0, 1, 2, 0, 1, 0, 0,
       1, 0, 0, 0, 1, 2, 0, 1, 1, 0, 1, 2, 0, 1, 2, 0, 1, 0, 1, 1, 2, 1,
       0, 2, 2, 1, 0, 0, 0, 2, 0, 2, 2, 1, 1, 2, 2, 1, 1, 2, 1, 0, 0, 1,
       0, 1, 1, 0, 1, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 2, 1, 0, 2, 2,
       0, 0, 2, 0, 0, 1, 2, 0, 2, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0,
       2, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 1, 2, 1, 0,
       0, 2, 0, 0, 0, 2, 0, 2, 1, 0, 1, 0, 1, 0, 2, 0, 1, 1, 1, 0, 0, 1,
       2, 0, 1, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 0, 1, 0, 0, 1, 0, 1, 2, 1,
       0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 2, 0, 2, 1, 0, 2, 1, 2, 2, 1,
       1, 0, 0, 1, 0, 1, 1, 0, 2, 0, 2, 1, 0, 1, 1, 2, 0, 0, 1, 0, 0, 2,
       1, 1, 2, 1, 0, 2, 1, 0, 1, 1, 2, 1, 0, 0, 1, 1, 1, 2, 0, 0, 0, 1,
       1, 0, 0, 2, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 2, 0, 0,
       0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0,