In [3]:
from datasets import load_dataset, Image
import os
import numpy as np

root_dir = './NIH-small/sample/'

dataset = load_dataset('imagefolder', split='train', data_dir=os.path.join(root_dir, 'images'))
# Add a filename column
def add_filename(example):
    example['filename'] = os.path.basename(example['image'].filename)
    return example

dataset = dataset.map(add_filename)

dataset = dataset.cast_column("image", Image(mode="RGB"))

# Load the metadata from the CSV file
import pandas as pd
metadata_file = os.path.join(root_dir, 'sample_labels.csv')
# Load the metadata from the CSV file
metadata_df = pd.read_csv(metadata_file)

# Create a dictionary from the metadata for quick lookup
metadata_dict = metadata_df.set_index('Image Index').to_dict(orient='index')

# Add metadata to the dataset
def add_metadata(example):
    filename = example['filename']
    if filename in metadata_dict:
        metadata = metadata_dict[filename]
        example.update(metadata)
    return example

dataset = dataset.map(add_metadata)

from datasets.features import ClassLabel, Sequence

# Split "Finding Labels" into multiple labels
metadata_df['Finding Labels'] = metadata_df['Finding Labels'].str.split('|')

# Get all unique labels
all_labels = set(label for sublist in metadata_df['Finding Labels'] for label in sublist)
# as no finding label affects so many images, most implementations remove "no finding" label.
all_labels.remove('No Finding')

### #TODO: only select some labels
all_labels = set(['Infiltration', 'Effusion', 'Atelectasis', 'Nodule', 'Pneumothorax']) 

# Create a ClassLabel feature for each unique label
class_labels = ClassLabel(names=list(all_labels))

# Define the label feature as a sequence of ClassLabel
labels_type = Sequence(class_labels)
num_labels = len(class_labels.names)


# # Remove unnecessary columns if needed
# dataset = dataset.remove_columns(['Image Index', 'Finding Labels', 'Follow-up #', 'Patient ID', 'Patient Age', 'Patient Gender'])

# Create a dictionary from the metadata for quick lookup
metadata_dict = metadata_df.set_index('Image Index').to_dict(orient='index')

# Add metadata to the dataset, including the sequence of class labels
def add_metadata(example):
    filename = example['filename']
    if filename in metadata_dict:
        metadata = metadata_dict[filename]
        example.update(metadata)
        # example['labels_list'] = [class_labels.str2int(label) if label in class_labels.names else 'No Finding' for label in metadata['Finding Labels']]
        example['labels'] = [float(class_labels.int2str(x) in metadata['Finding Labels']) for x in range(num_labels)]
    return example

# Apply the metadata and features to the dataset
dataset = dataset.map(add_metadata)

Resolving data files:   0%|          | 0/5606 [00:00<?, ?it/s]

Map:   0%|          | 0/5606 [00:00<?, ? examples/s]

In [4]:
# # filter data with no finding label; we can also down-sample it.
dataset_only_finding = dataset.filter(lambda example: sum(example['labels']) >= 1.0)
print(len(dataset), len(dataset_only_finding))
dataset = dataset_only_finding

Filter:   0%|          | 0/5606 [00:00<?, ? examples/s]

5606 2065


### data split
train : valid : test with ratio of 6:2:2.


In [5]:
train_testvalid = dataset.train_test_split(test_size=0.4, seed=42)
train_ds = train_testvalid['train']
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=42)
val_ds = test_valid['train']
test_ds = test_valid['test']

### Preprocessing the data
We will now preprocess the data. The model requires 2 things: pixel_values and labels.

We will perform data augmentaton on-the-fly using HuggingFace Datasets' set_transform method (docs can be found here). This method is kind of a lazy map: the transform is only applied when examples are accessed. This is convenient for tokenizing or padding text, or augmenting images at training time for example, as we will do here.

In [11]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)
import torch
import torch.nn as nn

size = 224

_train_transforms = Compose(
        [
            # RandomResizedCrop(size),
            # RandomHorizontalFlip(),
            Resize(size),
            ToTensor(),
            # normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(size),
            # CenterCrop(size),
            ToTensor(),
            # normalize,
        ]
    )

def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

# Set the transforms
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)

In [None]:
# from monai.transforms import (
#     Activations,
#     EnsureChannelFirst,
#     AsDiscrete,
#     Compose,
#     LoadImage,
#     RandFlip,
#     RandRotate,
#     RandZoom,
#     ScaleIntensity,
# )
# import torch
# import torch.nn as nn

# # # following: https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb

# _train_transforms = Compose(
#     [
#         RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
#         RandFlip(spatial_axis=0, prob=0.5),
#         RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
#     ]
# )

# _val_transforms = Compose([])

# def train_transforms(examples):
#     examples['pixel_values'] = [_train_transforms(image) for image in examples['image']]
#     return examples

# def val_transforms(examples):
#     examples['pixel_values'] = [_val_transforms(image) for image in examples['image']]
#     return examples

# # Set the transforms
# train_ds.set_transform(train_transforms)
# val_ds.set_transform(val_transforms)
# test_ds.set_transform(val_transforms)

In [None]:
# from monai.transforms import LoadImageD, EnsureChannelFirstD, ScaleIntensityD, Compose

# transform = Compose(
#     [
#         LoadImageD(keys="image", image_only=True),
#         EnsureChannelFirstD(keys="image"),
#         ScaleIntensityD(keys="image"),
#     ]
# )
# transform(train_ds[0]['image'])

In [12]:
from torch.utils.data import DataLoader

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples]).to(device)
    labels = torch.tensor([example["labels"] for example in examples]).to(device).float() # change for one-hot multilabels
    return {"pixel_values": pixel_values, "labels": labels}

train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=4)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=4)

In [13]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)
    if k == 'labels':
      print(v)

pixel_values torch.Size([4, 3, 224, 224])
labels torch.Size([4, 5])
tensor([[0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.]], device='mps:0')


### Define the model

In [14]:
import torchvision.models as models

# Define the model
class ResNetMultiLabel(nn.Module):
    def __init__(self, num_classes):
        super(ResNetMultiLabel, self).__init__()
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        
    def forward(self, x):
        return self.resnet(x)

# Instantiate the model
model = ResNetMultiLabel(num_labels).to(device)



### Visualize the model

In [None]:
# def compute_freq(ground_labels):
#     num_samples = ground_labels.shape[0]
#     pos_samples = np.sum(ground_labels,axis=0)
#     neg_samples = num_samples-pos_samples
#     pos_samples = pos_samples/float(num_samples)
#     neg_samples = neg_samples/float(num_samples)
#     return pos_samples, neg_samples

# ground_labels = []
# for i in train_ds:
#     ground_labels.append(i['labels'])
# ground_labels = np.array(ground_labels)
# print(ground_labels.shape)
# freq_pos, freq_neg = compute_freq(ground_labels)

# freq_pos, freq_neg

In [15]:
from transformers import TrainingArguments, Trainer
from torch import nn
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from tqdm import trange
from tqdm.notebook import tqdm


# # Loss function and optimizer; @TODO: Alternative way is to find the best thresholds for labels on the validation set.
# weights = np.array(freq_neg, dtype=np.float32) / np.array(freq_pos, dtype=np.float32)
# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weights, dtype=torch.float).to(device))

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.0
    last_loss = 0.
    pbar = tqdm(enumerate(train_dataloader), unit="batch", total=len(train_dataloader))
    for i, data in pbar:
        inputs, labels = data['pixel_values'], data['labels']
        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        pbar.set_description('  batch {} loss: {}'.format(i + 1, loss.item()))
        if i % 10 == 9:
            last_loss = running_loss / 10 # loss per batch
            # pbar.set_description('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_dataloader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('logs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 10

best_vloss = 1_000_000.

pbar = trange(EPOCHS)
for epoch in pbar:
    pbar.set_description('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    roc_auc = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    sigmoid = torch.nn.Sigmoid()
    with torch.no_grad():
        for i, vdata in enumerate(val_dataloader):
            vinputs, vlabels = vdata['pixel_values'], vdata['labels']
            voutputs = sigmoid(model(vinputs))
            vloss = criterion(voutputs, vlabels)
            roc_auc += roc_auc_score(vlabels.cpu().numpy(), voutputs.cpu().numpy(), average = 'micro')
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {} valid_roc_auc {}'.format(avg_loss, avg_vloss, roc_auc / len(val_dataloader)))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
EPOCH 1::   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 2::  10%|█         | 1/10 [00:49<07:21, 49.05s/it]

LOSS train 0.5073820292949677 valid 0.7439332604408264 valid_roc_auc70.75150814463316


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 3::  20%|██        | 2/10 [01:38<06:32, 49.10s/it]

LOSS train 0.2960722461342812 valid 0.7433245778083801 valid_roc_auc71.67767891830391


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 4::  30%|███       | 3/10 [02:25<05:37, 48.22s/it]

LOSS train 0.16382842659950256 valid 0.7385877966880798 valid_roc_auc71.92209061771561


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 5::  40%|████      | 4/10 [03:10<04:42, 47.10s/it]

LOSS train 0.13172145262360574 valid 0.7303011417388916 valid_roc_auc73.24285457597958


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 6::  50%|█████     | 5/10 [03:56<03:52, 46.51s/it]

LOSS train 0.13462840691208838 valid 0.7256736159324646 valid_roc_auc73.25528443778444


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 7::  60%|██████    | 6/10 [04:41<03:04, 46.14s/it]

LOSS train 0.08356831427663565 valid 0.722332239151001 valid_roc_auc73.9336976911977


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 8::  70%|███████   | 7/10 [05:26<02:17, 45.80s/it]

LOSS train 0.09317317251116038 valid 0.7263154983520508 valid_roc_auc74.54063818126318


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 9::  80%|████████  | 8/10 [06:11<01:31, 45.61s/it]

LOSS train 0.07238347614184022 valid 0.7274338603019714 valid_roc_auc76.15835567210566


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 10::  90%|█████████ | 9/10 [06:57<00:45, 45.61s/it]

LOSS train 0.04813537690788507 valid 0.7241875529289246 valid_roc_auc74.12788003662996


  0%|          | 0/310 [00:00<?, ?batch/s]

EPOCH 10:: 100%|██████████| 10/10 [07:43<00:00, 46.36s/it]

LOSS train 0.049435443803668024 valid 0.7280676960945129 valid_roc_auc75.66183864746365





### Evaluate on Testds 
Consider metrics for multi-class classification


In [16]:
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from transformers import EvalPrediction
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, probs.cpu().numpy(), average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}

    print(classification_report(y_true=y_true.astype(int), y_pred=y_pred, target_names=class_labels.names))
    # labels = train_ds.features['labels_list']
    # cm = confusion_matrix(y_true, y_pred)
    # disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    # disp.plot(xticks_rotation=45)

    return metrics

# y_true = outputs.label_ids
# y_pred = outputs.predictions #.argmax(1)

# multi_label_metrics(y_pred, y_true)

model.eval()
y_true = torch.tensor([], dtype=torch.long)
y_pred = torch.tensor([])
with torch.no_grad():
    test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=32)
    for i, vdata in tqdm(enumerate(test_dataloader), unit="batch", total=len(test_dataloader)):
        vinputs, vlabels = vdata['pixel_values'], vdata['labels'].cpu()
        voutputs = model(vinputs).cpu()
        y_pred = torch.cat((y_pred, voutputs), 0)
        y_true = torch.cat((y_true, vlabels), 0)
multi_label_metrics(y_pred, y_true.numpy())
        


  0%|          | 0/13 [00:00<?, ?batch/s]

              precision    recall  f1-score   support

Pneumothorax       0.45      0.22      0.29        60
 Atelectasis       0.58      0.33      0.42        98
      Nodule       0.29      0.25      0.26        65
Infiltration       0.55      0.70      0.62       183
    Effusion       0.50      0.53      0.51       114

   micro avg       0.51      0.48      0.49       520
   macro avg       0.47      0.40      0.42       520
weighted avg       0.50      0.48      0.48       520
 samples avg       0.46      0.50      0.46       520



  _warn_prf(average, modifier, msg_start, len(result))


{'f1': 0.4925816023738872,
 'roc_auc': 0.7435499128703013,
 'accuracy': 0.26150121065375304}

In [17]:
def evaluate(test_dataloader):
    y_true = torch.tensor([], dtype=torch.long)
    y_pred = torch.tensor([])
    with torch.no_grad():
        for i, vdata in tqdm(enumerate(test_dataloader), unit="batch", total=len(test_dataloader)):
            vinputs, vlabels = vdata['pixel_values'], vdata['labels'].cpu()
            voutputs = model(vinputs).cpu()
            y_pred = torch.cat((y_pred, voutputs), 0)
            y_true = torch.cat((y_true, vlabels), 0)
    return multi_label_metrics(y_pred, y_true.numpy())
evaluate(train_dataloader)

  0%|          | 0/310 [00:00<?, ?batch/s]

              precision    recall  f1-score   support

Pneumothorax       0.96      0.85      0.90       154
 Atelectasis       0.89      0.88      0.89       305
      Nodule       0.80      0.86      0.83       185
Infiltration       0.85      0.93      0.89       607
    Effusion       0.85      0.92      0.88       387

   micro avg       0.86      0.90      0.88      1638
   macro avg       0.87      0.89      0.88      1638
weighted avg       0.86      0.90      0.88      1638
 samples avg       0.85      0.90      0.86      1638



  _warn_prf(average, modifier, msg_start, len(result))


{'f1': 0.8796185935637664,
 'roc_auc': 0.9759203929710842,
 'accuracy': 0.7280064568200162}