## Fine-tune the Vision Transformer on Brain MRI Images Dataset


In this notebook, we are going to fine-tune a pre-trained [Vision Transformer](https://huggingface.co/docs/transformers/model_doc/vit) (which I added to [Transformers](https://github.com/huggingface/transformers)) on the Fashion Product Images dataset.

We will prepare the data using [datasets](https://github.com/huggingface/datasets), and train the model using the [Trainer](https://huggingface.co/transformers/main_classes/trainer.html). For other notebooks (such as training ViT with PyTorch Lightning), I refer to my repo [Transformers-Tutorials](https://github.com/NielsRogge/Transformers-Tutorials). 



In [1]:
!pip install -q transformers datasets   

In [2]:
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize,
                                    RandomRotation,
                                    RandomResizedCrop,
                                    RandomHorizontalFlip,
                                    RandomAdjustSharpness,
                                    Resize, 
                                    ToTensor)

image_mean, image_std = processor.image_mean, processor.image_std
height = processor.size["height"]
width = processor.size["width"]
size = (height, width)
print("Size: ", size)
print("Image mean: ", image_mean)
print("Image std: ", image_std)

normalize = Normalize(mean=image_mean, std=image_std)
_train_transforms = Compose(
        [
            Resize(size),
            RandomRotation(15),
            RandomAdjustSharpness(2),
            ToTensor(),
            normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(size),
            ToTensor(),
            normalize,
        ]
    )

def train_transforms(examples):
        # return _train_transforms(image)
    for item in examples:
        item['pixel_values'] = _train_transforms(item['image'])
    return examples

def val_transforms(examples):
    for item in examples:
        item['pixel_values'] = _val_transforms(item['image'])
    return examples

Size:  (224, 224)
Image mean:  [0.5, 0.5, 0.5]
Image std:  [0.5, 0.5, 0.5]


In [4]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from pathlib import Path
import torch
import numpy as np
from PIL import Image
class MyDataset(Dataset):
    def __init__(self, data_dir,test_frac=0.15,section="training",data_augmentation=False):
        self.num_class = 0
        self.test_frac = test_frac
        self.data_augmentation=data_augmentation
        self.section=section
        self.transform=train_transforms if self.section=="training" else val_transforms
        self.generate_data_list(data_dir)


    def __len__(self):
        return len(self.data)
    
    def generate_data_list(self,data_dir):
        # 类别名 [yes,no]
        class_names = sorted(f"{x.name}" for x in Path(data_dir).iterdir() if x.is_dir())  # folder name as the class name
        # 2
        self.num_class = len(class_names)
        image_files_list = []
        image_class = []
        # [[class1图片列表][class2图片列表]]
        image_files = [[f"{x}" for x in (Path(data_dir) / class_names[i]).iterdir()] for i in range(self.num_class)]
        num_each = [len(image_files[i]) for i in range(self.num_class)]
        # 
        max_value=max(num_each)
        enlarge_factor=[max_value//num_each[i] for i in range(self.num_class)]
        if not self.data_augmentation:
            enlarge_factor=[1]*self.num_class
        print('this is the enlarge factor',enlarge_factor)

        class_name = []
        # 读取所有图片为一个二维list [[class1图片列表][class2图片列表]]
        # 对于每一类图片
        for i in range(self.num_class):
            # 将图片列表合并 [[class1图片列表][class2图片列表]] -> [class1图片列表+class2图片列表]
            image_files_list.extend(image_files[i]*enlarge_factor[i])
            # 为每个图片标记类别，类别标签从0开始，记录index [0,0,0,1,1,1]
            image_class.extend([i] * num_each[i]*enlarge_factor[i])
            # 为每个图片标记类别名 [yes,yes,yes,no,no,no]
            class_name.extend([class_names[i]] * num_each[i]*enlarge_factor[i])
        length = len(image_files_list)
        # 生成图片索引 [0,1,2,3,4,5]
        indices = np.arange(length)
        # 打乱图片顺序
        np.random.shuffle(indices)
        test_length = int(length * self.test_frac)
        if self.section == "test":
            section_indices = indices[:test_length]
        elif self.section == "training":
            section_indices = indices[test_length:]
        else:
            raise ValueError(
                f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].'
            )
        def convert_image(image_path):
            image = Image.open(image_path)
            if image.mode != 'RGB':
                image=image.convert('RGB')
            return image
        self.data=[{"image":convert_image(image_files_list[i]),"label": image_class[i]}  for i in section_indices ]
        self.data=self.transform(self.data)
    
    def __getitem__(self, index):
        # return self.data[index]
        img=self.data[index]["pixel_values"]
        label=self.data[index]['label']
        return {'pixel_values':img,'label':label}
        # return img,label

In [5]:
# from torch.utils.data import DataLoader
# from torch.utils.data import Dataset
# import os
# from pathlib import Path
# import torch
# import numpy as np
# from PIL import Image

# class MyDataset(Dataset):
#     def __init__(self, data_dir,test_frac=0.15,section="training",balance=False):
#         self.num_class = 0
#         self.test_frac = test_frac
#         self.section=section
#         self.transform=train_transforms if self.section=="training" else val_transforms
#         self.generate_data_list(data_dir)
#         if balance:
#             self.balance_classes()


#     def __len__(self):
#         return len(self.samples)
#     #
#     def balance_classes(self):
#         from collections import Counter
#         # 3 暂时不分类
#         label2_counter=Counter(x[2] for x in self.samples)
#         print("this is label2 counter: ", label2_counter)
#         max_label2_count = max(label2_counter.values())
#         print("max label2 count: ", max_label2_count)
#         # before balance
#         print("total num before first balance: ", len(self.samples))
#         print("Deep: ", label2_counter[0])
#         print("Lobar: ", label2_counter[1])
#         print("Subtentorial: ", label2_counter[2])
#         # # 为了平衡类别，复制少数类的数据，先复制label2
#         balanced_samples = []
#         for key in label2_counter.keys():
#             factor = max_label2_count // label2_counter[key]
#             if key==3:
#                 factor=1
#             for i in range(len(self.samples)):
#                 if self.samples[i][2] == key:
#                     balanced_samples.extend([self.samples[i]]*factor)
        
#         self.samples = balanced_samples
#         print("total num after first balance: ", len(self.samples))
#         label2_counter=Counter(x[2] for x in self.samples)
#         print("counter after first balance: ", label2_counter)
#         balanced_samples=[]
#         label1_counter = Counter(x[1] for x in self.samples)
#         print("total num before second balance: ", len(self.samples))
#         print("no tumor: ", label1_counter[0])
#         print("tumor: ", label1_counter[1])
#         max_label1_count = max(label1_counter.values())
#         print("max label1 count: ", max_label1_count)
#         for label in label1_counter.keys():
#             factor = max_label1_count // label1_counter[label]
#         #     # 复制因子次每个类别的数据
#             for i in range(len(self.samples)):
#                 if self.samples[i][1] == label:
#                     balanced_samples.extend([self.samples[i]]* factor)
#         self.samples = balanced_samples
#         print("total num after balance: ", len(self.samples))
#         label1_counter = Counter(x[1] for x in self.samples)
#         print("no tumor: ", label1_counter[0])
#         print("tumor: ", label1_counter[1])


#     def generate_data_list(self,data_dir):
#         # 类别名 [yes,no]
#         # class_names = sorted(f"{x.name}" for x in Path(data_dir).iterdir() if x.is_dir())  # folder name as the class name
#         no_tumor_dir = os.path.join(data_dir, 'no')
#         no_tumor_images = [(os.path.join(no_tumor_dir, img), 0, 3) for img in os.listdir(no_tumor_dir)]
#         yes_tumor_dir = os.path.join(data_dir, 'yes')
#         tumor_classes = {'Deep': 0, 'Lobar': 1, 'Subtentorial': 2}
#         yes_tumor_images = []
#         for tumor_class, label in tumor_classes.items():
#             class_dir = os.path.join(yes_tumor_dir, tumor_class)
#             yes_tumor_images += [(os.path.join(class_dir, img), 1, label) for img in os.listdir(class_dir)]
#         self.samples = no_tumor_images + yes_tumor_images
#         # self.data=self.transform(self.data)
    
#     def __getitem__(self, index):
#         img_path, has_tumor, tumor_type = self.samples[index]
#         image = Image.open(img_path).convert('RGB')
#         image=self.transform(image)
#         #filename=self.data[index]['file_name']
#         # # return {'pixel_values':img,'label':label}
#         # return img,label,filename
#         return {'pixel_values':image,'label1':has_tumor,'label2':tumor_type}
#         # return image,has_tumor, tumor_type


In [6]:
# from torch.utils.data import random_split
data_dir='/home/jialiangfan/DTViT/dataset1/'
# dataset=MyDataset(data_dir)
# train_size = int(0.8 * len(datase/t))
    
# val_size = int(0.1*len(dataset))
# test_size=len(dataset) - train_size-val_size
# train_dataset, val_dataset,test_dataset = random_split(dataset, [train_size,val_size,test_size])
train_dataset = MyDataset(data_dir,test_frac=0.15,section="training",data_augmentation=True)
test_dataset=MyDataset(data_dir,test_frac=0.15,section="test",data_augmentation=True)

this is the enlarge factor [1, 3, 12, 1]
this is the enlarge factor [1, 3, 12, 1]


In [7]:
len(train_dataset)
# 155+98=253

18197

In [8]:
len((test_dataset))

3211

In [9]:
from torch.utils.data import DataLoader
import torch

# def collate_fn(examples):
#     pixel_values = torch.stack([example["pixel_values"] for example in examples])
#     labels = torch.tensor([example["label"] for example in examples])
#     return {"pixel_values": pixel_values, "labels": labels}


In [10]:

train_dataloader = DataLoader(train_dataset,batch_size=32)
test_dataloader = DataLoader(test_dataset,batch_size=4)

In [11]:
batch = next(iter(train_dataloader))
print(batch.keys())
# batch['pixel_values'].shape,batch['label'].shape

# print(batch['pixel_values'].shape)
# for k,v in batch.items():
  # if isinstance(v, torch.Tensor):
    # print(k, v.shape)

dict_keys(['pixel_values', 'label'])


In [12]:
batch = next(iter(test_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([4, 3, 224, 224])
label torch.Size([4])


Of course, we would like to know the actual class name, rather than the 

---

integer index. We can obtain that by creating a dictionary which maps between integer indices and actual class names (id2label):

## Preprocessing the data

We will now preprocess the data. The model requires 2 things: `pixel_values` and `labels`. 

We will perform data augmentaton **on-the-fly** using HuggingFace Datasets' `set_transform` method (docs can be found [here](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=set_transform#datasets.Dataset.set_transform)). This method is kind of a lazy `map`: the transform is only applied when examples are accessed. This is convenient for tokenizing or padding text, or augmenting images at training time for example, as we will do here. 

It's very easy to create a corresponding PyTorch DataLoader, like so:

## Define the model

Here we define the model. We define a `ViTForImageClassification`, which places a linear layer ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) on top of a pre-trained `ViTModel`. The linear layer is placed on top of the last hidden state of the [CLS] token, which serves as a good representation of an entire image. 

The model itself is pre-trained on ImageNet-21k, a dataset of 14 million labeled images. You can find all info of the model we are going to use [here](https://huggingface.co/google/vit-base-patch16-224-in21k).

We also specify the number of output neurons by setting the id2label and label2id mapping, which we be added as attributes to the configuration of the model (which can be accessed as `model.config`).

In [13]:
from transformers import ViTForImageClassification,ViTConfig
from torch import nn
config=ViTConfig()
config.num_labels=4
config.problem_type="multi_label_classification"
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k",config=config)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model.classifier

In [None]:
# for name, param in model.named_parameters():
#     if name.startswith("classifier"):
#         param.requires_grad = True
#     else:
#         model.requires_grad_=False

In [None]:

# 遍历模型中的所有参数
for name, param in model.named_parameters():
    # 检查参数是否被冻结
    if param.requires_grad:
        print(f"参数 {name} 没有被冻结，将会更新。")
    else:
        print(f"参数 {name} 被冻结，不会更新。")

To instantiate a `Trainer`, we will need to define three more things. The most important is the `TrainingArguments`, which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional.

We also set the argument "remove_unused_columns" to False, because otherwise the "img" column would be removed, which is required for the data transformations.

In [None]:
from transformers import TrainingArguments, Trainer
import os

os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['NCCL_IB_DISABLE'] = '1'

metric_name = "accuracy"
args = TrainingArguments(
    "Brain-Tumor-Detection",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5, #0.00002, #0.00002
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    remove_unused_columns=False,
    report_to="tensorboard",
)
# args.set_optimizer(name="sgd")

Here we set the evaluation to be done at the end of each epoch, tweak the learning rate, set the training and evaluation batch_sizes and customize the number of epochs for training, as well as the weight decay.

We also define a `compute_metrics` function that will be used to compute metrics at evaluation. We use "accuracy" here.


In [None]:
from sklearn.metrics import accuracy_score
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return dict(accuracy=accuracy_score(predictions, labels))

In [None]:
import torch
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    
)

## Train the model


In [None]:
import os
trainer.train()

## Evaluation

Finally, let's evaluate the model on the test set:

In [None]:
outputs = trainer.predict(test_dataset)


In [None]:
print(outputs.metrics)