In [None]:
# Reload imported modules automatically so edits to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

#### Data

MedMNIST2D

In [None]:
# Install the MedMNIST package (uncomment on first use; %pip targets the running kernel's environment)
# %pip install medmnist

Libraries

In [None]:
import os
import warnings
warnings.simplefilter("ignore")

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO

Configuration

In [None]:
# Dataset selection and loader/training hyper-parameters.
data_flag = "chestmnist"
download = True
BATCH_SIZE, NUM_EPOCHS = 128, 50

# Metadata entry that MedMNIST publishes for the selected dataset flag.
info = INFO[data_flag]
task, n_channels, label_dict = info["task"], info["n_channels"], info["label"]
n_classes = len(label_dict)

# Dataset class whose name is given by the metadata entry.
DataClass = getattr(medmnist, info["python_class"])

##### Exploratory Analysis

In [None]:
# Image resolution requested from MedMNIST and the folder the data files are kept in.
IMG_SIZE = 224
DATA_ROOT = "./data/.medmnist/"

# NOTE(review): MedMNIST is believed to reject a root directory that does not
# exist yet — create it up front so the notebook runs on a fresh checkout.
os.makedirs(DATA_ROOT, exist_ok=True)

# preprocessing: convert to a tensor, then normalise with mean 0.5 / std 0.5
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5]),
])

# load the data
train_dataset = DataClass(split="train", transform=data_transform, download=download, root=DATA_ROOT, size=IMG_SIZE)
test_dataset = DataClass(split="test", transform=data_transform, download=download, root=DATA_ROOT, size=IMG_SIZE)

# encapsulate data into dataloader form (only the training split is shuffled)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Text summaries of both splits, separated by a 100-character rule.
print(train_dataset, "=" * 100, test_dataset, sep="\n")

Visualization

In [None]:
# montage of training images; the returned object is displayed as the cell output
# (presumably a length x length grid, i.e. 4 x 4 here — confirm against the MedMNIST docs)
train_dataset.montage(length=4)

In [None]:
# One randomly drawn training image per entry of `label_dict`, laid out in a
# two-column grid whose row count follows the number of labels.
#
# Selection compares the per-sample label *sum* with the label index: for
# single-label data that is the class id itself; for multi-hot labels it picks
# an image with that many positive labels (the title then lists all of them).
n_cols = 2
n_rows = -(-len(label_dict) // n_cols)  # ceiling division

fig = plt.figure(figsize=(14, 12))
for i, label in enumerate(label_dict.keys()):
    ind = np.where(train_dataset.labels.sum(axis=-1) == int(label))[0]
    if ind.size == 0:
        print("No such image")
        continue

    label_indices = np.random.choice(ind)
    img_by_label = train_dataset.imgs[label_indices]

    # Axes are only created for slots that actually receive an image.
    ax = fig.add_subplot(n_rows, n_cols, i + 1)
    ax.imshow(img_by_label, cmap="gray")
    ax.axis("off")
    if task == "multi-label, binary-class":
        positive_names = [label_dict[str(j)] for j in np.where(train_dataset.labels[label_indices] == 1)[0]]
        ax.set_title(f"Label: {positive_names}", fontsize=7)
    else:
        ax.set_title(f"Label: {label_dict[str(label)]}")

# Layout only needs to be computed once, after every subplot exists.
fig.tight_layout()

In [None]:
def histogram_img(img, title=None, ind=0, save_dir="./vlm/visualizations"):
    """Show an image next to its intensity histogram and save the figure.

    Parameters
    ----------
    img : numpy.ndarray
        Either a 2-D grayscale image or a channel-last image with three
        channels, as stored in ``train_dataset.imgs``.
        NOTE(review): assumed to be uint8 with channels in RGB order, matching
        the channel-statistics cell that treats channel 0 as red — confirm.
    title : str, optional
        Title placed above the image panel.
    ind : int or str, optional
        Prefix of the output file name (``{ind}_{data_flag}.png``).
    save_dir : str, optional
        Folder the figure is written to; created when missing.

    Returns
    -------
    None
    """
    # Colour images are channel-last, so the channel count is the LAST axis.
    is_colour = img.ndim == 3 and img.shape[-1] == 3

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

    # The upper bound of a calcHist range is exclusive: [0, 256] is needed for
    # pixels with value 255 to be counted.
    if is_colour:
        for channel, colour in enumerate(("r", "g", "b")):
            hist = cv.calcHist([img], [channel], None, [256], [0, 256])
            ax2.plot(hist, color=colour)
    else:
        hist = cv.calcHist([img], [0], None, [256], [0, 256])
        ax2.plot(hist)
    ax2.set_title("Colour Distribution")

    if is_colour:
        ax1.imshow(img)
    else:
        # A BGR->RGB conversion is undefined for one channel; draw as grayscale.
        ax1.imshow(np.squeeze(img), cmap="gray")
    ax1.grid(False)
    ax1.axis("off")
    if title is not None:
        ax1.set_title(title, fontsize=7)

    fig.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f"{ind}_{data_flag}.png"))
    plt.show()
    # Release the figure; this function is called once per label in a loop.
    plt.close(fig)

# Draw and save one image/histogram figure per entry of `label_dict`.
# `histogram_img` opens its own figure, so no figure is created here (an extra
# `plt.figure` would only leave an empty canvas behind).
for i, label in enumerate(label_dict.keys()):
    # Same selection rule as the sample grid: per-sample label sum == label index.
    ind = np.where(train_dataset.labels.sum(axis=-1) == int(label))[0]
    if ind.size == 0:
        print("No such image")
        continue

    label_indices = np.random.choice(ind)
    img_by_label = train_dataset.imgs[label_indices]
    if task == "multi-label, binary-class":
        title = f"Label: {[label_dict[str(j)] for j in np.where(train_dataset.labels[label_indices] == 1)[0]]}"
    else:
        title = f"Label: {label_dict[str(label)]}"
    histogram_img(img_by_label, title=title, ind=i)

In [None]:
# Pixel-intensity mean and standard deviation of the training images, per label.
# `intensity_by_label` collects every label's result; `intensity_dict` holds the
# entry of the label currently being processed and is printed as before.
intensity_by_label = {}
is_multi_label = "multi-label" in task

for label in label_dict:
    if is_multi_label:
        # Multi-hot label matrix: column `label` flags the samples carrying it.
        label_indices = np.where(train_dataset.labels[:, int(label)] == 1)[0]
    else:
        label_indices = np.where(train_dataset.labels == int(label))[0]

    if label_indices.size == 0:
        # Mean/std of an empty selection would only yield NaN.
        print("No such image")
        continue

    img_by_label = train_dataset.imgs[label_indices]

    if info["n_channels"] == 3:
        # One (mean, sd) pair per channel, in R, G, B order (channel-last images).
        intensity_dict = {
            label_dict[label]: tuple(
                (
                    np.mean(img_by_label[:, :, :, channel]).round(2),
                    np.std(img_by_label[:, :, :, channel]).round(2),
                )
                for channel in range(3)
            )
        }
    elif info["n_channels"] == 1:
        intensity_dict = {
            label_dict[label]: (
                np.mean(img_by_label).round(2),
                np.std(img_by_label).round(2),
            )
        }
    else:
        raise ValueError(f"Unsupported number of channels: {info['n_channels']}")

    intensity_by_label.update(intensity_dict)
    print(intensity_dict)

Data Distribution

In [None]:
import itertools
from collections import Counter
def data_dist_plot(dataset, title):
    """Draw a horizontal bar chart of the label distribution of ``dataset``.

    For multi-label tasks (``"multi-label" in task``) each bar counts the images
    that carry a given number of positive labels. For every other task there is
    one bar per class, in ascending class-id order so that plots of different
    splits line up. Bar colour is assigned by frequency rank.

    Parameters
    ----------
    dataset
        Object exposing ``labels`` (one row per sample) and ``len()``.
    title : str
        Title of the axes.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the bars were drawn on (the current axes).

    Raises
    ------
    AssertionError
        If the counted frequencies do not add up to ``len(dataset)``.
    """
    if "multi-label" in task:
        positives_per_image = np.asarray(dataset.labels).sum(axis=-1).astype(np.int64)
        # bincount grows past `minlength` on its own, so an image with every
        # label set is still counted instead of breaking the total below.
        counts = np.bincount(positives_per_image, minlength=len(label_dict))
        count_dict = {n_labels: int(n_images) for n_labels, n_images in enumerate(counts)}
        categories = list(count_dict)
        y_label = "# of labels"
    else:
        freq = Counter(np.asarray(dataset.labels).ravel().tolist())
        count_dict = {class_id: freq[class_id] for class_id in sorted(freq)}
        categories = [label_dict[str(class_id)] for class_id in count_dict]
        y_label = "Label"

    values = list(count_dict.values())
    assert sum(values) == len(dataset), "Mismatch in frequency and total count"

    # Palette has one shade per bar; each bar gets the shade at its own
    # frequency rank (rank 0 = least frequent).
    ascending = sorted(range(len(values)), key=values.__getitem__)
    shade_of_bar = [0] * len(values)
    for shade, bar in enumerate(ascending):
        shade_of_bar[bar] = shade
    pal = sns.color_palette("Purples_d", len(values))
    bar_colors = [pal[shade] for shade in shade_of_bar]

    ax = sns.barplot(y=categories, x=values, palette=bar_colors, legend=False, orient="h")
    for container in ax.containers:
        ax.bar_label(container)
    ax.set_title(title)
    # Horizontal bars: counts run along x, categories along y.
    ax.set_xlabel("# of images")
    ax.set_ylabel(y_label)

    return ax


# Train and test label distributions stacked in a single two-row figure.
plt.figure(figsize=(10, 10))
for position, (split_dataset, split_name) in enumerate(
    [(train_dataset, "train"), (test_dataset, "test")], start=1
):
    plt.subplot(2, 1, position)
    data_dist_plot(split_dataset, split_name)
plt.tight_layout()