# Image Classification with the Medical MNIST Dataset

Date: 16-04-2025

**Sources:**

https://medmnist.com/

https://github.com/MedMNIST/MedMNIST

---

Copyright statement:
This material, no matter whether in printed or electronic form, may be used for personal and non-commercial educational use only. Any reproduction of this manuscript, no matter whether as a whole or in parts, no matter whether in printed or in electronic form, requires explicit prior acceptance of the authors.

## The Medical MNIST Dataset

MedMNIST is a large-scale MNIST-like collection of biomedical images, sorted into 18 pre-processed datasets: 12 for 2D (e.g. retina, dermal, tissue etc.) with ~708K images, and 6 for 3D (e.g. organs, fractures etc.) with ~10K images. MedMNIST is designed for image classification in various settings (binary/multi-class, ordinal regression, multi-label), supporting numerous research and educational purposes in biomedical image analysis. Like the MNIST dataset of handwritten digits, MedMNIST images are resized to a standard low resolution of 28x28 (28x28x28 for the 3D datasets), though larger variants (up to 224x224 for 2D) are available. More information about the MedMNIST dataset is given under "Sources" above.

## The Task

In this notebook, you will work with the 2D DermaMNIST sub-dataset, which consists of ~10K images of skin lesions (areas of skin that look abnormal compared to the surrounding skin), categorized into 7 classes. DermaMNIST is based on HAM10000, a large collection of multi-source dermatoscopic images. You will build a simple multi-class image classification pipeline in which your model maps an image of a skin lesion to one of the 7 classes.
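Concretely, a multi-class classifier outputs one score (logit) per class, and the predicted class is the index of the largest score. A minimal NumPy sketch of that final step, assuming a hypothetical batch of logits with shape `(N, 7)` (the numbers are made up for illustration, and `predict_classes` is a hypothetical helper, not part of MedMNIST):

```python
import numpy as np

NUM_CLASSES = 7  # DermaMNIST has 7 lesion classes


def predict_classes(logits: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Turn raw model scores of shape (N, 7) into class probabilities and predicted indices."""
    # Numerically stable softmax: subtract the row-wise max before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    # The predicted class is the one with the highest probability (equivalently, the highest logit)
    preds = probs.argmax(axis=1)
    return probs, preds


# Hypothetical logits for a batch of 2 images
logits = np.array([
    [0.1, 2.0, -1.0, 0.0, 0.5, 0.3, -0.2],
    [1.5, 0.2, 0.1, -0.3, 0.0, 3.1, 0.4],
])
probs, preds = predict_classes(logits)
print(preds)  # → [1 5]
```

Since softmax is monotonic, taking the argmax of the raw logits gives the same prediction; the probabilities are only needed when you want calibrated-looking scores (e.g. for AUC).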

## Setup

First, make sure to mount your Google Drive in case you are running this notebook on Google Colab:

In [None]:
from google.colab import drive
drive.mount('/content/drive')

MedMNIST builds on PyTorch, so let's import the packages used throughout this notebook (PyTorch, torchvision, NumPy and tqdm). On Google Colab these come preinstalled.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import numpy as np
from tqdm import tqdm

Next, we install MedMNIST and check the version:

In [None]:
!pip install medmnist
import medmnist


In [None]:
print(medmnist.__version__)

3.0.2


Finally, we import the `DermaMNIST` dataset class, together with the `Evaluator` helper and the `INFO` metadata dictionary:

In [None]:
from medmnist import DermaMNIST, Evaluator, INFO

The `INFO` dictionary contains general metadata about every MedMNIST subset, such as its description, download URLs and task type. Here is the entry for DermaMNIST:

In [None]:
INFO['dermamnist']

{'python_class': 'DermaMNIST',
 'description': 'The DermaMNIST is based on the HAM10000, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. The dataset consists of 10,015 dermatoscopic images categorized as 7 different diseases, formulized as a multi-class classification task. We split the images into training, validation and test set with a ratio of 7:1:2. The source images of 3×600×450 are resized into 3×28×28.',
 'url': 'https://zenodo.org/records/10519652/files/dermamnist.npz?download=1',
 'MD5': '0744692d530f8e62ec473284d019b0c7',
 'url_64': 'https://zenodo.org/records/10519652/files/dermamnist_64.npz?download=1',
 'MD5_64': 'b70a2f5635c6199aeaa28c31d7202e1f',
 'url_128': 'https://zenodo.org/records/10519652/files/dermamnist_128.npz?download=1',
 'MD5_128': '2defd784463fa5243564e855ed717de1',
 'url_224': 'https://zenodo.org/records/10519652/files/dermamnist_224.npz?download=1',
 'MD5_224': '8974907d8e169bef5f5b96bc506ae45d',
 'task': 'multi-c

We can also inspect the skin lesion class labels and the number of image channels (3 for RGB):

In [None]:
print('Labels:')
INFO['dermamnist']['label']

Labels:


{'0': 'actinic keratoses and intraepithelial carcinoma',
 '1': 'basal cell carcinoma',
 '2': 'benign keratosis-like lesions',
 '3': 'dermatofibroma',
 '4': 'melanoma',
 '5': 'melanocytic nevi',
 '6': 'vascular lesions'}
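Note that the keys of this dictionary are strings, while the labels stored in the dataset are integers (each sample's label comes as a one-element array). A small helper such as the hypothetical `label_name` below converts between the two. The `labels` dict is copied from the output above; in the notebook you would pass `INFO['dermamnist']['label']` instead:

```python
import numpy as np

# Copied from INFO['dermamnist']['label'] above
labels = {
    '0': 'actinic keratoses and intraepithelial carcinoma',
    '1': 'basal cell carcinoma',
    '2': 'benign keratosis-like lesions',
    '3': 'dermatofibroma',
    '4': 'melanoma',
    '5': 'melanocytic nevi',
    '6': 'vascular lesions',
}


def label_name(y, label_dict: dict) -> str:
    """Map an integer label (plain int or one-element array) to its human-readable name."""
    # np.asarray(...).item() unwraps one-element arrays such as array([4])
    idx = int(np.asarray(y).item())
    return label_dict[str(idx)]


print(label_name(np.array([4]), labels))  # → melanoma
print(label_name(5, labels))              # → melanocytic nevi
```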

In [None]:
print('Channels:', INFO['dermamnist']['n_channels'])

Channels: 3


## DermaMNIST pre-processing

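First, the data is downloaded. DermaMNIST ships with predefined training, validation and test splits in a 70/10/20% ratio. Image transforms are applied to every sample to convert it into a float tensor and normalize its pixel values; note that these transforms only pre-process the images and do not augment the data. Each split is then wrapped in its own `DataLoader` for easier batching during training and evaluation.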

In [None]:
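transforms = v2.Compose([
    v2.ToImage(),                           # PIL image -> image tensor (uint8, C x H x W)
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Normalize(mean=[.5], std=[.5])       # [0, 1] -> [-1, 1]
])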
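As a worked example of the resulting value range: 8-bit pixel values are first divided by 255 and then normalized via (x - 0.5) / 0.5, so 0 maps to -1, 255 maps to 1, and mid-grey values land near 0. The single-element `mean` and `std` lists are broadcast across all three RGB channels. The NumPy sketch below reproduces this arithmetic by hand; it mirrors the transforms rather than calling torchvision, and `scale_and_normalize` is a hypothetical helper:

```python
import numpy as np


def scale_and_normalize(pixels: np.ndarray, mean: float = 0.5, std: float = 0.5) -> np.ndarray:
    """Mimic ToDtype(float32, scale=True) followed by Normalize(mean, std) on uint8 pixels."""
    scaled = pixels.astype(np.float32) / 255.0  # uint8 [0, 255] -> float [0, 1]
    return (scaled - mean) / std                 # [0, 1] -> [-1, 1] for mean = std = 0.5


pixels = np.array([0, 51, 255], dtype=np.uint8)
print(scale_and_normalize(pixels))  # approximately [-1.0, -0.6, 1.0]
```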

In [None]:
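# download the predefined splits and package them into DataLoaders
train_data = DermaMNIST(split='train', transform=transforms, download=True)
val_data = DermaMNIST(split='val', transform=transforms, download=True)
test_data = DermaMNIST(split='test', transform=transforms, download=True)

# define arguments for the dataloaders
BATCH_SIZE = 128

# shuffle only the training data; the evaluation order does not matter
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)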
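Before training, it is worth pulling a single batch from a loader to check tensor shapes. The sketch below uses a stand-in `TensorDataset` of random data shaped like DermaMNIST (3x28x28 images, one integer label per sample stored as a one-element vector), so it runs without downloading anything; the sample count is an assumption chosen for illustration. It also shows one common way to flatten the `(B, 1)` label tensor into the `(B,)` class-index shape that `nn.CrossEntropyLoss` expects:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 128
NUM_SAMPLES, NUM_CLASSES = 300, 7  # illustrative sample count, not the real split size

# Stand-in for DermaMNIST: normalized RGB images in [-1, 1] and labels of shape (N, 1)
images = torch.rand(NUM_SAMPLES, 3, 28, 28) * 2 - 1
labels = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES, 1))
loader = DataLoader(TensorDataset(images, labels), batch_size=BATCH_SIZE, shuffle=True)

x, y = next(iter(loader))
print(x.shape, y.shape)  # → torch.Size([128, 3, 28, 28]) torch.Size([128, 1])

# CrossEntropyLoss expects class indices of shape (B,) with dtype long
targets = y.squeeze(1).long()
logits = torch.randn(x.shape[0], NUM_CLASSES)  # placeholder for a model's output
loss = nn.CrossEntropyLoss()(logits, targets)
print(targets.shape, loss.item())
```

With the real loaders, `next(iter(train_loader))` should yield the same shapes, since MedMNIST returns each label as a one-element integer array that the default collate function stacks into a `(B, 1)` tensor.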

In [None]:
# inspect unique labels of dataset
INFO['dermamnist']['label']

{'0': 'actinic keratoses and intraepithelial carcinoma',
 '1': 'basal cell carcinoma',
 '2': 'benign keratosis-like lesions',
 '3': 'dermatofibroma',
 '4': 'melanoma',
 '5': 'melanocytic nevi',
 '6': 'vascular lesions'}
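The dictionary only lists the possible classes. To see how often each class actually occurs, you can count the integer labels; the loaded MedMNIST datasets expose their labels as a NumPy array of shape `(N, 1)` via the `labels` attribute (e.g. `train_data.labels`). A minimal sketch, with a hypothetical `class_counts` helper and a small made-up label array in place of the real one:

```python
import numpy as np


def class_counts(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Count occurrences of each class index in an (N, 1) or (N,) integer label array."""
    flat = np.asarray(labels).reshape(-1)
    # minlength ensures classes that never occur still get a count of 0
    return np.bincount(flat, minlength=num_classes)


# Made-up labels standing in for train_data.labels
fake_labels = np.array([[5], [5], [4], [5], [1], [6], [5]])
print(class_counts(fake_labels, num_classes=7))  # → [0 1 0 0 1 4 1]
```

In the notebook, `class_counts(train_data.labels, num_classes=7)` would give the per-class totals for the training split. A strongly uneven result indicates class imbalance, in which case overall accuracy alone can be a misleading metric.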

## Image Classification