In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd "/content/drive/MyDrive/experiments/MedMNIST2D"


/content/drive/MyDrive/experiments/MedMNIST2D


In [3]:
pip install -r requirements.txt

Collecting ACSConv==0.1.1 (from -r requirements.txt (line 1))
  Downloading ACSConv-0.1.1.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting matplotlib-inline==0.1.6 (from -r requirements.txt (line 3))
  Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)
Collecting medmnist==3.0.1 (from -r requirements.txt (line 4))
  Downloading medmnist-3.0.1-py3-none-any.whl (25 kB)
Collecting numpy==1.24.4 (from -r requirements.txt (line 5))
  Downloading numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m74.4 MB/s[0m eta [36m0:00:00[0m
Collecting pillow==10.2.0 (from -r requirements.txt (line 8))
  Downloading pillow-10.2.0-cp310-cp310-manylinux_2_28_x86_64.whl (4.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m69.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scikit-image==0.22.0 (from -r requirem

In [4]:
import sys
sys.path.append('/content/drive/MyDrive/experiments/MedMNIST2D/models')


In [5]:
import argparse
import os
import time
from collections import OrderedDict
from copy import deepcopy

import medmnist
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from medmnist import INFO, Evaluator
from models import ResNet18, ResNet50
from torchvision.models import resnet18, resnet50
from tqdm import tqdm

ImportError: cannot import name 'ResNet18' from 'models' (/content/drive/MyDrive/experiments/MedMNIST2D/models/__init__.py)

See the number of channels in each dataset

In [None]:
data_list=['pathmnist','chestmnist','dermamnist','octmnist','pneumoniamnist','retinamnist','breastmnist','bloodmnist','tissuemnist','organamnist','organcmnist','organsmnist']


for data_flag in data_list:
    info = INFO[data_flag]
    n_channels = info['n_channels']
    print(f'{data_flag}: {n_channels} channels')
task = info['task']

n_classes = len(info['label'])

pathmnist: 3 channels
chestmnist: 1 channels
dermamnist: 3 channels
octmnist: 1 channels
pneumoniamnist: 1 channels
retinamnist: 3 channels
breastmnist: 1 channels
bloodmnist: 3 channels
tissuemnist: 1 channels
organamnist: 1 channels
organcmnist: 1 channels
organsmnist: 1 channels


see the labels from one particular dataset

In [None]:
info=INFO["pathmnist"]
DataClass = getattr(medmnist, info['python_class'])
d=DataClass(split='train',download=True)

Downloading https://zenodo.org/records/10519652/files/pathmnist.npz?download=1 to /root/.medmnist/pathmnist.npz


100%|██████████| 205615438/205615438 [00:07<00:00, 27877060.17it/s]


In [None]:
type(d)

In [None]:
[d[i][1] for i in range(100)]

[array([0]),
 array([4]),
 array([7]),
 array([5]),
 array([5]),
 array([8]),
 array([3]),
 array([3]),
 array([5]),
 array([2]),
 array([8]),
 array([5]),
 array([8]),
 array([2]),
 array([1]),
 array([1]),
 array([5]),
 array([4]),
 array([3]),
 array([2]),
 array([8]),
 array([2]),
 array([5]),
 array([8]),
 array([3]),
 array([8]),
 array([5]),
 array([2]),
 array([3]),
 array([3]),
 array([1]),
 array([3]),
 array([7]),
 array([0]),
 array([1]),
 array([8]),
 array([3]),
 array([0]),
 array([8]),
 array([5]),
 array([4]),
 array([3]),
 array([7]),
 array([1]),
 array([6]),
 array([5]),
 array([6]),
 array([0]),
 array([1]),
 array([2]),
 array([0]),
 array([1]),
 array([0]),
 array([1]),
 array([8]),
 array([4]),
 array([6]),
 array([5]),
 array([1]),
 array([5]),
 array([3]),
 array([7]),
 array([4]),
 array([0]),
 array([0]),
 array([6]),
 array([2]),
 array([7]),
 array([1]),
 array([5]),
 array([6]),
 array([6]),
 array([5]),
 array([8]),
 array([0]),
 array([7]),
 array([2]),

## data processing

In [None]:
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

In [None]:
def convert_to_rgb(pil_img):

    return pil_img.convert('RGB')


In [None]:
def process_train_datasets(data_list, split='train',download=True,transform=None):
    all_images = []
    all_labels = []
    current_max_label = 0
    label_mapping = {}

    for data_flag in data_list:
        dataset_info = INFO[data_flag]
        num_labels = len(dataset_info["label"])
        DataClass = getattr(medmnist, dataset_info['python_class'])
        dataset_instance = DataClass(split=split,transform=transform,download=download)

        for i in range(len(dataset_instance)):
            img, label = dataset_instance[i]

            # if grayscale, convert to RGB
            if dataset_info['n_channels'] == 1:
                img = convert_to_rgb(img)

            all_images.append(img)

            # update label id
            if dataset_info["task"]=="multi-label, binary-class":
                one_hot_label = torch.tensor(label)
                label = torch.nonzero(one_hot_label, as_tuple=False).squeeze().tolist()
                for l in label:
                    new_label=current_max_label+l
                    if new_label not in label_mapping:
                        str_label=str(l)
                        label_mapping[new_label] = dataset_info["label"][str_label]
            else:
                new_label=current_max_label+label[0]
                if new_label not in label_mapping:
                    str_label=str(label[0])
                    label_mapping[new_label] = dataset_info["label"][str_label]

            all_labels.append(new_label)
        current_max_label = num_labels + current_max_label
    return all_images, all_labels, label_mapping

In [None]:
from torch.utils.data import Dataset,DataLoader

class Dataset2D(Dataset):
    def __init__(self, images, labels, transform=data_transform):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
all_train_images, all_train_labels, label_mapping = process_train_datasets(data_list, split='train',download=True)
all_test_images, all_test_labels, label_mapping_test = process_train_datasets(data_list, split='test',download=True)

Using downloaded and verified file: /home/hanwzhan/.medmnist/pathmnist.npz
Downloading https://zenodo.org/records/10519652/files/chestmnist.npz?download=1 to /home/hanwzhan/.medmnist/chestmnist.npz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 82802576/82802576 [00:02<00:00, 38572245.52it/s]


Downloading https://zenodo.org/records/10519652/files/dermamnist.npz?download=1 to /home/hanwzhan/.medmnist/dermamnist.npz


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19725078/19725078 [00:00<00:00, 137552660.77it/s]


Downloading https://zenodo.org/records/10519652/files/octmnist.npz?download=1 to /home/hanwzhan/.medmnist/octmnist.npz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 54938180/54938180 [00:01<00:00, 33055820.12it/s]


Downloading https://zenodo.org/records/10519652/files/pneumoniamnist.npz?download=1 to /home/hanwzhan/.medmnist/pneumoniamnist.npz


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4170669/4170669 [00:00<00:00, 58719453.49it/s]


Downloading https://zenodo.org/records/10519652/files/retinamnist.npz?download=1 to /home/hanwzhan/.medmnist/retinamnist.npz


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3291041/3291041 [00:00<00:00, 55482151.62it/s]


Downloading https://zenodo.org/records/10519652/files/breastmnist.npz?download=1 to /home/hanwzhan/.medmnist/breastmnist.npz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 559580/559580 [00:00<00:00, 32094635.95it/s]


Downloading https://zenodo.org/records/10519652/files/bloodmnist.npz?download=1 to /home/hanwzhan/.medmnist/bloodmnist.npz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35461855/35461855 [00:00<00:00, 50277945.85it/s]


Downloading https://zenodo.org/records/10519652/files/tissuemnist.npz?download=1 to /home/hanwzhan/.medmnist/tissuemnist.npz


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 124962739/124962739 [00:02<00:00, 54070380.09it/s]


Downloading https://zenodo.org/records/10519652/files/organamnist.npz?download=1 to /home/hanwzhan/.medmnist/organamnist.npz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 38247708/38247708 [00:00<00:00, 50538826.00it/s]


Downloading https://zenodo.org/records/10519652/files/organcmnist.npz?download=1 to /home/hanwzhan/.medmnist/organcmnist.npz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15526411/15526411 [00:00<00:00, 80441134.29it/s]


Downloading https://zenodo.org/records/10519652/files/organsmnist.npz?download=1 to /home/hanwzhan/.medmnist/organsmnist.npz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16528359/16528359 [00:00<00:00, 55809920.34it/s]


Using downloaded and verified file: /home/hanwzhan/.medmnist/pathmnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/chestmnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/dermamnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/octmnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/pneumoniamnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/retinamnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/breastmnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/bloodmnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/tissuemnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/organamnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/organcmnist.npz
Using downloaded and verified file: /home/hanwzhan/.medmnist/organsmnist.npz


In [None]:
label_mapping=dict(sorted(label_mapping.items()))
label_mapping_test=dict(sorted(label_mapping_test.items()))
label_mapping==label_mapping_test

True

In [None]:
new_old_label_mapping = {original_label: new_label for new_label, original_label in enumerate(sorted(label_mapping.keys()))}

In [None]:
new_label_mapping = new_labels_list = {new_old_label_mapping[key]: value for key, value in label_mapping.items()}
all_new_train_labels=[new_old_label_mapping[label] for label in all_train_labels]
all_new_test_labels=[new_old_label_mapping[label] for label in all_test_labels]

In [None]:
BATCH_SIZE = 128

train_dataset=Dataset2D(all_train_images, all_new_train_labels)
test_dataset=Dataset2D(all_test_images, all_new_test_labels)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

In [None]:
n_channels = 3
n_classes = len(label_mapping)
NUM_EPOCHS = 3
lr = 0.001

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


model = ResNet18(in_channels=n_channels, num_classes=n_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
print(max(new_label_mapping.keys()))

79


In [None]:
for epoch in range(NUM_EPOCHS):
    train_correct = 0
    train_total = 0
    test_correct = 0
    test_total = 0

    model.train()
    for inputs, targets in tqdm(train_loader):
        # forward + backward + optimize
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)

        targets = targets.to(torch.long)
        loss = criterion(outputs, targets)


        loss.backward()
        optimizer.step()

  return F.conv2d(input, weight, bias, self.stride,
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4049/4049 [02:00<00:00, 33.58it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4049/4049 [01:53<00:00, 35.54it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4049/4049 [01:53<00:00, 35.55it/s]


In [None]:
# evaluation

def test(split):
    model.eval()
    y_true = torch.tensor([])
    y_score = torch.tensor([])

    data_loader = train_loader_at_eval if split == 'train' else test_loader

    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                outputs = outputs.softmax(dim=-1)
            else:
                targets = targets.squeeze().long()
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().resize_(len(targets), 1)

            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)

        y_true = y_true.numpy()
        y_score = y_score.detach().numpy()

        evaluator = Evaluator(data_flag, split)
        metrics = evaluator.evaluate(y_score)

        print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))


print('==> Evaluating ...')
test('train')
test('test')