In [1]:
from datasets import load_dataset, Image
import os
import numpy as np

root_dir = './NIH-small/sample/'

dataset = load_dataset('imagefolder', split='train', data_dir=os.path.join(root_dir, 'images'))
# Add a filename column
def add_filename(example):
    example['filename'] = os.path.basename(example['image'].filename)
    return example

dataset = dataset.map(add_filename)

dataset = dataset.cast_column("image", Image(mode="RGB"))

# Load the metadata from the CSV file
import pandas as pd
metadata_file = os.path.join(root_dir, 'sample_labels.csv')
# Load the metadata from the CSV file
metadata_df = pd.read_csv(metadata_file)

# Create a dictionary from the metadata for quick lookup
metadata_dict = metadata_df.set_index('Image Index').to_dict(orient='index')

# Add metadata to the dataset
def add_metadata(example):
    filename = example['filename']
    if filename in metadata_dict:
        metadata = metadata_dict[filename]
        example.update(metadata)
    return example

dataset = dataset.map(add_metadata)

from datasets.features import ClassLabel, Sequence

# Split "Finding Labels" into multiple labels
metadata_df['Finding Labels'] = metadata_df['Finding Labels'].str.split('|')

# Get all unique labels
all_labels = set(label for sublist in metadata_df['Finding Labels'] for label in sublist)
# as no finding label affects so many images, most implementations remove "no finding" label.
all_labels.remove('No Finding')

# Create a ClassLabel feature for each unique label
class_labels = ClassLabel(names=list(all_labels))

# Define the label feature as a sequence of ClassLabel
labels_type = Sequence(class_labels)
num_labels = len(class_labels.names)


# # Remove unnecessary columns if needed
# dataset = dataset.remove_columns(['Image Index', 'Finding Labels', 'Follow-up #', 'Patient ID', 'Patient Age', 'Patient Gender'])

# Create a dictionary from the metadata for quick lookup
metadata_dict = metadata_df.set_index('Image Index').to_dict(orient='index')

# Add metadata to the dataset, including the sequence of class labels
def add_metadata(example):
    filename = example['filename']
    if filename in metadata_dict:
        metadata = metadata_dict[filename]
        example.update(metadata)
        # example['labels_list'] = [class_labels.str2int(label) if label in class_labels.names else 'No Finding' for label in metadata['Finding Labels']]
        example['labels'] = [float(class_labels.int2str(x) in metadata['Finding Labels']) for x in range(num_labels)]
    return example

# Apply the metadata and features to the dataset
dataset = dataset.map(add_metadata)



Resolving data files:   0%|          | 0/5606 [00:00<?, ?it/s]

Map:   0%|          | 0/5606 [00:00<?, ? examples/s]

In [2]:
# # filter data with no finding label; we can also down-sample it.
dataset_only_finding = dataset.filter(lambda example: sum(example['labels']) >= 1.0)
print(len(dataset), len(dataset_only_finding))
dataset = dataset_only_finding

Filter:   0%|          | 0/5606 [00:00<?, ? examples/s]

5606 2562


### data split
train : valid : test with ratio of 6:2:2.


In [4]:
train_testvalid = dataset.train_test_split(test_size=0.4, seed=42)
train_ds = train_testvalid['train']
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=42)
val_ds = test_valid['train']
test_ds = test_valid['test']

### Preprocessing the data
We will now preprocess the data. The model requires 2 things: pixel_values and labels.

We will perform data augmentaton on-the-fly using HuggingFace Datasets' set_transform method (docs can be found here). This method is kind of a lazy map: the transform is only applied when examples are accessed. This is convenient for tokenizing or padding text, or augmenting images at training time for example, as we will do here.

In [7]:
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")



In [8]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)
import torch
import torch.nn as nn

image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

normalize = Normalize(mean=image_mean, std=image_std)
_train_transforms = Compose(
        [
            RandomResizedCrop(size),
            RandomHorizontalFlip(),
            # Resize(size),
            # CenterCrop(size),
            ToTensor(),
            normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(size),
            ToTensor(),
            normalize,
        ]
    )

def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

# Set the transforms
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)

In [9]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples]).float() # change for one-hot multilabels
    return {"pixel_values": pixel_values, "labels": labels}

train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=4)

In [None]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)
    if k == 'labels':
      print(v)

### Define the model
Here we define the model. We define a ViTForImageClassification, which places a linear layer (nn.Linear) on top of a pre-trained ViTModel. The linear layer is placed on top of the last hidden state of the [CLS] token, which serves as a good representation of an entire image.

The model itself is pre-trained on ImageNet-21k, a dataset of 14 million labeled images. You can find all info of the model we are going to use here.

We also specify the number of output neurons by setting the id2label and label2id mapping, which we be added as attributes to the configuration of the model (which can be accessed as ```model.config```).

In [10]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k')

# Load the model and configure it with the number of labels
labels_list = class_labels.names
num_labels = len(labels_list)
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_labels,
    id2label = dict(zip(list(range(0, num_labels)), labels_list)),
    label2id = dict(zip(labels_list, list(range(0, num_labels))))
)

# # Print model configuration to verify
# print(model.config)

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Visualize the model

In [11]:
def compute_freq(ground_labels):
    num_samples = ground_labels.shape[0]
    pos_samples = np.sum(ground_labels,axis=0)
    neg_samples = num_samples-pos_samples
    pos_samples = pos_samples/float(num_samples)
    neg_samples = neg_samples/float(num_samples)
    return pos_samples, neg_samples

ground_labels = []
for i in train_ds:
    ground_labels.append(i['labels'])
ground_labels = np.array(ground_labels)
print(ground_labels.shape)
freq_pos, freq_neg = compute_freq(ground_labels)

(1537, 14)


In [16]:
len(labels_list), labels_list

(14,
 ['Nodule',
  'Emphysema',
  'Effusion',
  'Consolidation',
  'Atelectasis',
  'Infiltration',
  'Pneumothorax',
  'Edema',
  'Hernia',
  'Fibrosis',
  'Pneumonia',
  'Pleural_Thickening',
  'Mass',
  'Cardiomegaly'])

In [12]:
freq_pos, freq_neg

(array([0.12752115, 0.04814574, 0.24333116, 0.0826285 , 0.19323357,
        0.3812622 , 0.09629148, 0.04424203, 0.00455433, 0.0325309 ,
        0.02667534, 0.06701366, 0.10930384, 0.05465192]),
 array([0.87247885, 0.95185426, 0.75666884, 0.9173715 , 0.80676643,
        0.6187378 , 0.90370852, 0.95575797, 0.99544567, 0.9674691 ,
        0.97332466, 0.93298634, 0.89069616, 0.94534808]))

In [17]:
from transformers import TrainingArguments, Trainer
from torch import nn
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score

metric_name = "f1"

args = TrainingArguments(
    f"fine-tune-ViT-on-NIH",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    remove_unused_columns=False,
)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    y_true = labels
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= 0.5)] = 1
    # 
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

# Compute weights for each class
weights = np.array(freq_neg, dtype=np.float32) / np.array(freq_pos, dtype=np.float32)

class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get('logits')
        device = 'cpu'
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        loss_fct = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weights, dtype=torch.float).to(device))
        # loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))
        return (loss, outputs) if return_outputs else loss
    
trainer = MultilabelTrainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

In [23]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir logs/
# %tensorboard dev upload --logdir 'logs/'

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


ERROR: Failed to launch TensorBoard (exited with 1).
Contents of stderr:
TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Address already in use
Port 6006 is in use by another program. Either identify and stop that program, or start the server with a different port.
Contents of stdout:

In [20]:
trainer.train()

  0%|          | 0/970 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.181014060974121, 'eval_f1': 0.28662976988057093, 'eval_roc_auc': 0.6434523335932403, 'eval_accuracy': 0.0, 'eval_runtime': 17.4238, 'eval_samples_per_second': 29.385, 'eval_steps_per_second': 1.837, 'epoch': 1.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.1598867177963257, 'eval_f1': 0.25787662165740005, 'eval_roc_auc': 0.6129286475820697, 'eval_accuracy': 0.0, 'eval_runtime': 17.8767, 'eval_samples_per_second': 28.641, 'eval_steps_per_second': 1.79, 'epoch': 2.0}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.1510090827941895, 'eval_f1': 0.2325715938675152, 'eval_roc_auc': 0.5774115653294719, 'eval_accuracy': 0.00390625, 'eval_runtime': 17.4591, 'eval_samples_per_second': 29.326, 'eval_steps_per_second': 1.833, 'epoch': 3.0}


KeyboardInterrupt: 

### Evaluate on Testds 
Consider metrics for multi-class classification


In [26]:
outputs = trainer.predict(test_ds)
print(outputs.metrics)


  0%|          | 0/33 [00:00<?, ?it/s]

{'test_loss': 1.368661880493164, 'test_f1': 0.26852713178294574, 'test_roc_auc': 0.6082420281660861, 'test_accuracy': 0.001949317738791423, 'test_runtime': 18.0475, 'test_samples_per_second': 28.425, 'test_steps_per_second': 1.829}


In [30]:
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from transformers import EvalPrediction
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}

    print(classification_report(y_true=y_true.astype(int), y_pred=y_pred, target_names=class_labels.names))
    # labels = train_ds.features['labels_list']
    # cm = confusion_matrix(y_true, y_pred)
    # disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    # disp.plot(xticks_rotation=45)

    return metrics

y_true = outputs.label_ids
y_pred = outputs.predictions #.argmax(1)

multi_label_metrics(y_pred, y_true)

                    precision    recall  f1-score   support

            Nodule       0.13      0.53      0.20        59
         Emphysema       0.29      0.18      0.22        28
          Effusion       0.38      0.61      0.47       133
     Consolidation       0.15      0.68      0.24        47
       Atelectasis       0.31      0.42      0.36       107
      Infiltration       0.38      0.53      0.44       187
      Pneumothorax       0.31      0.61      0.41        66
             Edema       0.11      0.80      0.19        25
            Hernia       0.00      0.00      0.00         5
          Fibrosis       0.07      0.40      0.12        20
         Pneumonia       0.05      0.73      0.09        11
Pleural_Thickening       0.10      0.79      0.18        39
              Mass       0.31      0.13      0.19        68
      Cardiomegaly       0.06      0.86      0.12        28

         micro avg       0.18      0.53      0.27       823
         macro avg       0.19      0.5

  _warn_prf(average, modifier, msg_start, len(result))


{'f1': 0.26852713178294574,
 'roc_auc': 0.6082420281660861,
 'accuracy': 0.001949317738791423}

In [35]:
outputs = trainer.predict(train_ds)
print(outputs.metrics)
y_true = outputs.label_ids
y_pred = outputs.predictions #.argmax(1)

multi_label_metrics(y_pred, y_true)

  0%|          | 0/97 [00:00<?, ?it/s]

{'test_loss': 1.1852531433105469, 'test_f1': 0.2862757265487641, 'test_roc_auc': 0.6425897990322267, 'test_accuracy': 0.0, 'test_runtime': 51.632, 'test_samples_per_second': 29.768, 'test_steps_per_second': 1.879}
                    precision    recall  f1-score   support

            Nodule       0.18      0.74      0.29       196
         Emphysema       0.19      0.41      0.26        74
          Effusion       0.34      0.71      0.46       374
     Consolidation       0.13      0.61      0.21       127
       Atelectasis       0.28      0.40      0.33       297
      Infiltration       0.43      0.58      0.49       586
      Pneumothorax       0.19      0.59      0.29       148
             Edema       0.11      0.81      0.20        68
            Hernia       0.00      0.00      0.00         7
          Fibrosis       0.10      0.64      0.17        50
         Pneumonia       0.06      0.66      0.11        41
Pleural_Thickening       0.10      0.83      0.17       103
     

  _warn_prf(average, modifier, msg_start, len(result))


{'f1': 0.2862757265487641, 'roc_auc': 0.6425897990322267, 'accuracy': 0.0}

### Save the best fine-tuned model`

In [None]:
trainer.save_model()

In [None]:
my_model = ViTForImageClassification.from_pretrained("./fine-tune-ViT-on-NIH/")

In [None]:
my_trainer = MultilabelTrainer(
    my_model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

In [None]:
outputs = my_trainer.predict(test_ds)
print(outputs.metrics)

y_true = outputs.label_ids
y_pred = outputs.predictions

multi_label_metrics(y_pred, y_true)