<a href="https://colab.research.google.com/github/Smart-Lizard/Med_Image_Generation/blob/main/ChestMNIST_Classification_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive; the checkpoints saved later in this notebook are written
# underneath this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!pip install medmnist

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.6.0.tar.gz (88 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/88.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.4/88.4 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.6.0-py2.py3-none-any.whl size=117030 sha256=bfaecd4ce8a973b4d2d5b3d984dd34aa733dd48e56f647eda41d143d23853460
  Stored in directory: /root/.cache/pip/wheels/d6/6d/5d/5b73fa0f46d01a793713f8859201361e9e581ced8c75e5c6a3
Successfully built fire
Installing collected packages: fire, medmnist
Successfully installed fire-0.6.0 medmnist-3.0.2


In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO, Evaluator

In [None]:
# --- dataset configuration ---------------------------------------------------
data_flag = 'chestmnist'
download = True
BATCH_SIZE = 32

# Look up the metadata entry for the chosen dataset and resolve the dataset
# class it names inside the `medmnist` package.
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# Preprocessing: convert to a tensor, then normalise with mean 0.5 / std 0.5.
data_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])]
)

# Training split at 224x224; `mmap_mode='r'` is passed through to the dataset.
train_dataset = DataClass(
    split='train',
    transform=data_transform,
    download=download,
    size=224,
    mmap_mode='r',
)

# Shuffled mini-batches for training.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Downloading https://zenodo.org/records/10519652/files/chestmnist_224.npz?download=1 to /root/.medmnist/chestmnist_224.npz


100%|██████████| 3889293042/3889293042 [03:46<00:00, 17175237.34it/s]


In [None]:
# Validation and test splits share every dataset option except the split name.
eval_kwargs = dict(transform=data_transform, download=download, size=224, mmap_mode='r')
val_dataset = DataClass(split='val', **eval_kwargs)
test_dataset = DataClass(split='test', **eval_kwargs)

# Evaluation loaders are not shuffled, so the scores collected from them stay
# in the same order as `<dataset>.labels`.
val_loader, test_loader = (
    data.DataLoader(dataset=split_dataset, batch_size=BATCH_SIZE, shuffle=False)
    for split_dataset in (val_dataset, test_dataset)
)

# Human-readable name for each of the 14 label columns, in column order.
label_mapping = [
    'atelectasis',
    'cardiomegaly',
    'effusion',
    'infiltration',
    'mass',
    'nodule',
    'pneumonia',
    'pneumothorax',
    'consolidation',
    'edema',
    'emphysema',
    'fibrosis',
    'pleural',
    'hernia',
]

Using downloaded and verified file: /root/.medmnist/chestmnist_224.npz
Using downloaded and verified file: /root/.medmnist/chestmnist_224.npz


In [None]:
!apt-get install git
!git clone https://github.com/rsm-13/classifying-chestMNIST.git
%cd classifying-chestMNIST

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git is already the newest version (1:2.34.1-1ubuntu1.11).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.
Cloning into 'classifying-chestMNIST'...
remote: Enumerating objects: 17, done.[K
remote: Counting objects: 100% (6/6), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 17 (delta 2), reused 0 (delta 0), pack-reused 11 (from 1)[K
Receiving objects: 100% (17/17), 158.40 MiB | 12.28 MiB/s, done.
Resolving deltas: 100% (2/2), done.
/content/classifying-chestMNIST


In [None]:
import sys

# Directory the model repository was cloned into. The previous entry
# ('/chestMNIST') does not exist, so the import below only worked because the
# notebook had already changed into the clone with `%cd`.
# NOTE(review): assumes `models.py` sits at the repository root - confirm.
REPO_DIR = '/content/classifying-chestMNIST'
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

from models import ResNet18

# Three input channels (grayscale batches are repeated to 3 channels before the
# forward pass) and one output logit per ChestMNIST label.
net = ResNet18(in_channels=3, num_classes=14)

#### Hyperparameters and Testing Loop

In [None]:
num_epochs = 100
lr = 0.001
gamma = 0.1
# Decay the learning rate by `gamma` at 50% and 75% of training. The milestones
# must be integers: the scheduler compares them against an integer epoch
# counter, so a fractional value (e.g. 7.5 when num_epochs = 10) would never
# match and that decay step would be silently skipped.
milestones = [int(0.5 * num_epochs), int(0.75 * num_epochs)]

# Optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# Scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# Loss function: binary cross-entropy on the raw logits (the network emits one
# logit per label and the loss applies the sigmoid itself).
loss_func = nn.BCEWithLogitsLoss()

In [None]:
def getAUC(y_true, y_score):
    '''Mean per-label ROC AUC.

    Each label column is scored independently with ``roc_auc_score`` and the
    per-label values are averaged with equal weight.

    :param y_true: the ground truth labels, shape: (n_samples, n_labels) or (n_samples,) if n_labels==1
    :param y_score: the predicted score of each label, same layout as ``y_true``
    :return: the unweighted mean of the per-label AUC values
    :raises ValueError: if an input is empty or the two shapes disagree after
        normalisation (``roc_auc_score`` may raise its own errors as well)
    '''
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    if y_true.size == 0 or y_score.size == 0:
        raise ValueError('getAUC needs at least one sample and one label')

    # Normalise both inputs to (n_samples, n_labels). A blanket squeeze() is not
    # used because it also drops the sample axis of a one-sample batch and turns
    # single-label input into a 1-D array that cannot be indexed by column.
    y_true = y_true.reshape(-1, 1) if y_true.ndim < 2 else y_true.reshape(y_true.shape[0], -1)
    y_score = y_score.reshape(-1, 1) if y_score.ndim < 2 else y_score.reshape(y_score.shape[0], -1)
    if y_true.shape != y_score.shape:
        raise ValueError(
            f'y_true and y_score disagree after reshaping: {y_true.shape} vs {y_score.shape}'
        )

    n_labels = y_score.shape[1]
    label_aucs = [roc_auc_score(y_true[:, i], y_score[:, i]) for i in range(n_labels)]

    return sum(label_aucs) / n_labels

In [None]:
def getACC(y_true, y_score, threshold=0.5):
    '''Mean per-label accuracy.

    Scores strictly greater than ``threshold`` count as positive predictions.
    Accuracy is computed for every label column and the per-label values are
    averaged with equal weight.

    :param y_true: the ground truth labels, shape: (n_samples, n_labels) or (n_samples,) if n_labels==1
    :param y_score: the predicted score of each label, same layout as ``y_true``
    :param threshold: the decision threshold applied to ``y_score``
    :return: the unweighted mean of the per-label accuracies, as a float
    :raises ValueError: if an input is empty or the two shapes disagree after
        normalisation
    '''
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    if y_true.size == 0 or y_score.size == 0:
        raise ValueError('getACC needs at least one sample and one label')

    # Normalise both inputs to (n_samples, n_labels). A blanket squeeze() is not
    # used because it also drops the sample axis of a one-sample batch and turns
    # single-label input into a 1-D array that cannot be indexed by column.
    y_true = y_true.reshape(-1, 1) if y_true.ndim < 2 else y_true.reshape(y_true.shape[0], -1)
    y_score = y_score.reshape(-1, 1) if y_score.ndim < 2 else y_score.reshape(y_score.shape[0], -1)
    if y_true.shape != y_score.shape:
        raise ValueError(
            f'y_true and y_score disagree after reshaping: {y_true.shape} vs {y_score.shape}'
        )

    y_pre = y_score > threshold
    # Fraction of samples predicted correctly, one value per label column.
    per_label_acc = (y_true == y_pre).mean(axis=0)

    return float(per_label_acc.mean())

In [None]:
from sklearn.metrics import roc_auc_score, accuracy_score
def test(model, split_labels, data_loader, criterion, device='cuda', raw=False):
    model.cuda()
    model.eval()

    total_loss = []
    y_score = torch.tensor([]).to('cpu')

    with torch.no_grad():  # No gradient computation in evaluation mode
        for batch in data_loader:
            inp, labels = batch

            # Expand from 1 channel to 3 channels in validation/test as well
            if inp.shape[1] == 1:  # If grayscale, repeat the channel to make it RGB
                inp = inp.repeat(1, 3, 1, 1)

            inp = inp.cuda(non_blocking=True).float()
            out = model(inp)
            labels = labels.to(torch.float32).cuda(non_blocking=True)
            loss = criterion(out, labels)

            # Get predictions from scores
            sigmoid = torch.nn.Sigmoid()
            answers = sigmoid(out).data.cpu()

            # Recording values
            y_score = torch.cat((y_score, answers), 0)
            total_loss.append(loss.item())

        y_score = y_score.cpu().data.numpy()
        auc = getAUC(split_labels, y_score)
        acc = getACC(split_labels, y_score)

        testing_loss = np.mean(total_loss)

        if raw:
            return [testing_loss, auc, acc, split_labels, y_score]

        return [testing_loss, auc, acc]

### Training Loop

In [None]:
net.cuda()

best_epoch = 0
best_auc = 0
# NOTE: `best_model` is an alias of `net`, not a snapshot - it always holds the
# latest weights. The best weights are preserved by the checkpoint saved below.
best_model = net
val_labels = val_dataset.labels

for epoch in range(num_epochs):  # one full pass over the training set per epoch
    losses = []
    net.train()
    for inp, labels in train_loader:
        optimizer.zero_grad()

        # Expand from 1 channel to 3 channels
        if inp.shape[1] == 1:
            inp = inp.repeat(1, 3, 1, 1)

        # Forward pass
        inp = inp.cuda().to(torch.float32)
        out = net(inp)
        labels = labels.to(torch.float32).cuda()
        loss = loss_func(out, labels)
        losses.append(loss.item())

        # Backward pass
        loss.backward()
        optimizer.step()

    train_loss = np.mean(losses)
    val_metrics = test(net, val_labels, val_loader, loss_func)

    # Report every epoch so that progress is visible between improvements.
    print(
        f"Epoch {epoch}: train loss = {train_loss:.5f}, val loss = {val_metrics[0]:.5f}, "
        f"val AUC = {val_metrics[1]:.5f}, val ACC = {val_metrics[2]:.5f}"
    )

    # Checkpoint whenever the validation AUC improves.
    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        print(f"Epoch {best_epoch} is the best yet with Val AUC = {best_auc}")
        torch.save(net.state_dict(), '/content/drive/MyDrive/best_model.pth')

    scheduler.step()

Epoch 0 is the best yet with Val AUC = 0.6778519563574121
Epoch 1 is the best yet with Val AUC = 0.7027811745649607


In [None]:
# Persist the weights as they are after the last epoch (kept separately from
# the best-validation checkpoint).
final_weights = net.state_dict()
torch.save(final_weights, '/content/drive/MyDrive/final_model.pth')

### Metrics

In [None]:
from sklearn.metrics import *

In [None]:
# Evaluate on the held-out test split. `test_dataset` / `test_loader` are the
# objects created earlier in this notebook (the previous names `chest` and
# `test_dataloader` were never defined). raw=True also returns the labels and
# the predicted scores for the metric cells below.
test_metrics = test(net, test_dataset.labels, test_loader, loss_func, raw=True)

In [None]:
# With raw=True the last two entries of the result are the labels and the scores.
y_true, y_score = test_metrics[-2], test_metrics[-1]
# `.5f` prints five decimals, matching the other metric cells (`5f` only set a
# minimum field width).
print(f"Test AUC: {test_metrics[1]:.5f} \nTest ACC: {test_metrics[2]:.5f}")

In [None]:
# Binarise the predicted scores at 0.5 to obtain hard label predictions.
y_pre = y_score > 0.5

# Sample-averaged precision / recall / F1; a sample with an undefined ratio is
# scored as 1 (zero_division=1).
prf_samples = precision_recall_fscore_support(
    y_true,
    y_pre,
    average='samples',
    zero_division=1,
)
precision, recall, f1, _ = prf_samples
print(f"Test Precision: {precision:.5f}\nTest Recall: {recall:.5f}\nTest F1: {f1:.5f}")

This high precision and low recall means that our model predicts few samples as positive – the positives it does predict are mostly correct (few false positives), but it misses many of the true positives (a high number of false negatives).

Intuitively, you can think of this as the model casting a narrow net: most of what it catches is what we want, but a lot of the positive samples slip past it.

We're calculating the above metrics sample-wise, i.e., we compute all three metrics for each sample separately and return their average over all samples.


We're not calculating the metrics above class-wise, so it makes sense that they won't be similar to the class-wise precision, recall, and F1 scores computed below. I would recommend sticking to the overall accuracies listed above, as they are the most representative.

In [None]:
# Per-class precision / recall / F1 for the positive label of each column.
# average='binary' is required here: with average='micro' on a single binary
# column, precision, recall and F1 all collapse to the plain accuracy.
for i in range(y_pre.shape[1]):
    precision, recall, f1, _ = precision_recall_fscore_support(y_true[:, i], y_pre[:, i], average='binary', zero_division=1)
    print(f"Class {i}: Precision: {precision:.5f}\tRecall: {recall:.5f}\tF1: {f1:.5f}")

In [None]:
# One 2x2 confusion matrix per label column (class-wise, not sample-wise).
per_class_confusion = multilabel_confusion_matrix(y_true, y_pre, samplewise=False)
per_class_confusion

In [None]:
# One confusion-matrix figure per label, titled with the label name so that
# each figure can be read on its own.
for i in range(y_score.shape[1]):
    disp = ConfusionMatrixDisplay.from_predictions(y_true[:, i], y_pre[:, i])
    disp.ax_.set_title(f"Class {i}: {label_mapping[i]}")