# Malaria
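Classify cell images as **Parasitized** or **Uninfected**. Dataset: [Malaria cell images — shuffled and split (Kaggle)](https://www.kaggle.com/sagnikmazumder37/malaria-cell-imagesshuffled-and-split)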

### Import libraries

In [1]:
import os
import glob
import warnings
import random as rnd
from collections import defaultdict
from copy import deepcopy
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import PIL
from PIL import Image as im

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.transforms.functional import to_tensor
import torchvision.models as models
import torch.nn as nn
import torch.optim
import torch.backends.cudnn as cudnn

import albumentations as A
from albumentations.pytorch import ToTensorV2


cudnn.benchmark = True
warnings.filterwarnings('ignore', category=UserWarning)
plt.ion()
%matplotlib inline

In [2]:
def imread(path: str) -> im.Image:
    """Open an image file and decode its pixel data immediately.

    Args:
        path (string): Path to the image.

    Returns:
        PIL.Image.Image: The decoded image. ``Image.open`` may return an image of
        any format PIL supports, so the base ``Image`` class is the honest
        return type (not a PNG-specific subclass).
    """
    img = im.open(path)
    # Image.open is lazy; load() forces the pixel data to be read now, while the
    # caller still expects the file to be accessible.
    img.load()
    return img


def display_image_grid(images_filepaths, predicted_labels=(), cols=5):
    """Show images in a grid, titled with their predicted (or true) label.

    The true label of an image is the name of its parent folder. A title is drawn
    in green when the shown label equals the true label and in red otherwise.

    Args:
        images_filepaths (sequence of str): Paths of the images to show.
        predicted_labels (sequence of str, optional): One predicted label per path,
            in the same order. When empty, the true labels are shown.
        cols (int): Maximum number of grid columns.
    """
    n_images = len(images_filepaths)
    if n_images == 0:
        return
    cols = max(1, min(cols, n_images))
    # Ceil division so a partial last row is drawn instead of indexing past the grid.
    rows = -(-n_images // cols)
    # squeeze=False keeps `ax` a 2-D array even for a 1x1 grid, so .ravel() always works.
    _, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6), squeeze=False)
    axes = ax.ravel()
    # len() instead of truthiness so NumPy arrays of labels are accepted too.
    has_predictions = len(predicted_labels) > 0
    for i, image_filepath in enumerate(images_filepaths):
        # Decode with PIL (imported above); convert to RGB so imshow gets 3 channels.
        with im.open(image_filepath) as img:
            image = np.asarray(img.convert('RGB'))
        true_label = os.path.normpath(image_filepath).split(os.sep)[-2]
        predicted_label = predicted_labels[i] if has_predictions else true_label
        color = 'green' if true_label == predicted_label else 'red'
        axes[i].imshow(image)
        axes[i].set_title(predicted_label, color=color)
    # Hide axes everywhere, including unused cells of a partially filled last row.
    for axis in axes:
        axis.set_axis_off()
    plt.tight_layout()
    plt.show()

### Exploring dataset

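Both `train` and `val` contain two equally sized class folders: `Parasitized` and `Uninfected`.

**Train:** 10334 * 2 = 20668 images (**~86%** of train + val)

**Val:** 1723 * 2 = 3446 images (**~14%** of train + val). These counts match the progress bars printed when the datasets are loaded below. A separate `test` split is used at the end of the notebook.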

### Datasets

In [3]:
class MalariaDataset(Dataset):
    """Malaria dataset with parasitized and uninfected cells.

    Every ``*.png`` under ``<root_path>/Parasitized`` and ``<root_path>/Uninfected``
    is read into memory on construction. Labels are floats: ``1.0`` for images in
    the ``Parasitized`` folder and ``0.0`` for images in ``Uninfected``.
    """

    def __init__(self, root_path: str, transform=None):
        """
            Args:
                root_path (string): Path to 'Parasitized' and 'Uninfected' folders.
                transform (callable, optional): Optional transform to be applied
                on a sample.

            Raises:
                FileNotFoundError: If neither folder contains any ``*.png`` file.
        """
        self.transform = transform
        self.path = root_path
        # Sort within each class so sample indices are reproducible; glob order is
        # filesystem-dependent.
        self.all_paths = (
            sorted(glob.glob(os.path.join(self.path, 'Parasitized', '*.png')))
            + sorted(glob.glob(os.path.join(self.path, 'Uninfected', '*.png')))
        )
        if not self.all_paths:
            # Without this check the mean-size computation below fails with an
            # opaque ZeroDivisionError.
            raise FileNotFoundError(
                f"No .png images found in '{root_path}/Parasitized' or '{root_path}/Uninfected'"
            )
        self.X = []
        self.y = []
        self.sum_w, self.sum_h = 0, 0
        for path in tqdm(self.all_paths):
            img = imread(path)
            w, h = img.size
            self.sum_w += w
            self.sum_h += h
            self.X.append(img)
            self.y.append(self._label_from_path(path))
        # Integer mean width/height; standartize() resizes every image to this size.
        self.mean_w, self.mean_h = self.sum_w // len(self.X), self.sum_h // len(self.X)

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.X)

    def __getitem__(self, index: int):
        """Return ``(image, label)`` for ``index``.

        When a transform is set it is called albumentations-style:
        ``transform(image=img)['image']``.
        """
        img = self.X[index]
        label = self.y[index]
        if self.transform is not None:
            img = self.transform(image=img)['image']
        return img, label

    @staticmethod
    def _label_from_path(path: str) -> float:
        """Return 0.0 if the image's parent folder is 'Uninfected', else 1.0.

        Uses the parent directory name, so the result does not depend on how many
        components ``root_path`` has.
        """
        class_name = os.path.basename(os.path.dirname(os.path.normpath(path)))
        return 0. if class_name == 'Uninfected' else 1.

    @staticmethod
    def _hwc_to_chw(array):
        """Convert an (H, W, C) array into a contiguous (C, H, W) torch tensor.

        A transpose (axis permutation) is required here: reshaping (H, W, C) data
        to (C, ...) keeps memory order and scrambles pixels across channels.
        """
        return torch.from_numpy(np.ascontiguousarray(array.transpose(2, 0, 1)))

    def standartize(self):
        """Resize every image to (mean_w, mean_h) and store it as a (3, mean_h, mean_w) tensor.

        Safe to call more than once: entries that are already tensors are skipped.
        """
        size = (self.mean_w, self.mean_h)  # PIL's resize takes (width, height)
        for i in tqdm(range(len(self.X))):
            img = self.X[i]
            if isinstance(img, torch.Tensor):
                continue
            # convert('RGB') guarantees exactly 3 channels (e.g. drops an alpha channel).
            resized = img.convert('RGB').resize(size)
            self.X[i] = self._hwc_to_chw(np.array(resized))

### Loading data (the augmentation pipelines below are currently commented out)

In [4]:
# train_transform = transforms.Compose(
#     [
#         A.SmallestMaxSize(max_size=160),
#         A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
#         A.RandomCrop(height=128, width=128),
#         A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
#         A.RandomBrightnessContrast(p=0.5),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#         transforms.ToTensor(),
#     ]
# )

# Read every training image into memory, then resize all of them to the training
# set's mean size and convert them to tensors via MalariaDataset.standartize().
train_ds = MalariaDataset('data/train')
train_ds.standartize()

100%|██████████| 20668/20668 [00:11<00:00, 1731.62it/s]
100%|██████████| 20668/20668 [00:04<00:00, 4564.91it/s]


In [5]:
# val_transform = transforms.Compose(
#     [
#         A.SmallestMaxSize(max_size=160),
#         A.CenterCrop(height=128, width=128),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#         transforms.ToTensor(),
#     ]
# )

# Same loading/resizing as for training, but on the validation split.
# NOTE(review): standartize() uses this dataset's own mean_w/mean_h, which can differ
# from train_ds's; consider resizing validation images to the training size instead.
val_ds = MalariaDataset('data/val')
val_ds.standartize()

100%|██████████| 3446/3446 [00:01<00:00, 1793.85it/s]
100%|██████████| 3446/3446 [00:00<00:00, 4937.42it/s]


In [6]:
# def visualize_augmentations(dataset, title='', idx=0, samples=10, cols=5):
#     dataset = deepcopy(dataset)
# #     dataset.transform = transforms.Compose([t for t in dataset.transform if not isinstance(t, transforms.ToTensor)])
#     rows = samples // cols
#     _, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))
#     for i in range(samples):
#         image, _ = dataset[idx]
#         ax.ravel()[i].imshow(image)
#         ax.ravel()[i].set_axis_off()
#     plt.suptitle(title, fontsize=25, color='g')
#     plt.tight_layout()
#     plt.show()

In [7]:
# rnd.seed(42)
# visualize_augmentations(train_ds, title='Train images')

In [8]:
# rnd.seed(42)
# visualize_augmentations(val_ds, title='Validation images')

### Training helpers

In [9]:
def calculate_accuracy(output, target):
    """Fraction of samples whose thresholded prediction equals the binary target.

    Args:
        output (torch.Tensor): Raw logits; a sample counts as positive when
            sigmoid(logit) >= 0.5.
        target (torch.Tensor): Labels where 1.0 marks the positive class; expected
            to have the same shape as ``output`` (callers reshape it to (N, 1)).

    Returns:
        float: Number of matches (summed over dim 0) divided by ``output.size(0)``.
    """
    predicted_positive = output.sigmoid().ge(0.5)
    actually_positive = target.eq(1.0)
    n_correct = predicted_positive.eq(actually_positive).sum(dim=0)
    return torch.true_divide(n_correct, output.size(0)).item()

In [10]:
class MetricMonitor:
    """Running averages of named scalar metrics, formatted for progress-bar descriptions."""

    def __init__(self, float_precision=3):
        """
        Args:
            float_precision (int): Number of decimal places used by ``__str__``.
        """
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        """Forget every recorded metric."""
        # 'val' holds the running sum (not the last value); the key name is kept as-is
        # for anything that reads self.metrics directly.
        self.metrics = defaultdict(lambda: {'val': 0, 'count': 0, 'avg': 0})

    def update(self, metric_name, val):
        """Record one observation of ``metric_name`` and refresh its running average."""
        metric = self.metrics[metric_name]
        metric['val'] += val
        metric['count'] += 1
        metric['avg'] = metric['val'] / metric['count']

    def __str__(self):
        """Render as ``'name: avg | name: avg'`` with ``float_precision`` decimals."""
        # The format spec must be inside the replacement field ({value:.Nf}); placing it
        # after the closing brace printed the raw float followed by a literal ':.3f'.
        return ' | '.join(
            f'{metric_name}: {metric["avg"]:.{self.float_precision}f}'
            for metric_name, metric in self.metrics.items()
        )

### Define training parameters

In [11]:
params = {
    'model': 'resnet50',
    # Use a GPU when one is available; the recorded CPU run below takes ~15 s per batch.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'lr': 0.001,
    'batch_size': 64,
    'num_workers': 0,
    'epochs': 10,
}

### Preparing for training and validation

In [12]:
model = getattr(models, params['model'])(
    pretrained=False,  # random initialisation: train from scratch
    num_classes=1,  # a single logit per image, as BCEWithLogitsLoss expects
).to(params['device'])
criterion = nn.BCEWithLogitsLoss().to(params['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])

In [13]:
# Pinned host memory only speeds up host-to-GPU copies; on a CPU run it wastes RAM
# and PyTorch warns about it, so enable it only for CUDA devices.
train_loader = DataLoader(
    train_ds,
    batch_size=params['batch_size'],
    shuffle=True,
    num_workers=params['num_workers'],
    pin_memory=str(params['device']).startswith('cuda'),
)
val_loader = DataLoader(
    val_ds,
    batch_size=params['batch_size'],
    shuffle=False,
    num_workers=params['num_workers'],
    pin_memory=str(params['device']).startswith('cuda'),
)

In [14]:
def train(train_loader, model, criterion, optimizer, epoch, params):
    """Run one training epoch, showing running loss/accuracy in a tqdm progress bar.

    Args:
        train_loader: Iterable of ``(images, target)`` batches.
        model (nn.Module): Network being trained; put into training mode here.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer stepping the model's parameters.
        epoch (int): Epoch number, used only in the progress-bar text.
        params (dict): Must contain ``'device'``.
    """
    device = params['device']
    metric_monitor = MetricMonitor()
    model.train()
    progress = tqdm(train_loader)
    for images, target in progress:
        images = images.to(device, non_blocking=True).float()
        # Float targets shaped (N, 1) to line up with the model's output.
        target = target.to(device, non_blocking=True).float().view(-1, 1)
        output = model(images)
        batch_loss = criterion(output, target)
        batch_accuracy = calculate_accuracy(output, target)
        metric_monitor.update('Loss', batch_loss.item())
        metric_monitor.update('Accuracy', batch_accuracy)
        # Clear gradients from the previous step before back-propagating this batch.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        progress.set_description(f'Epoch: {epoch}. Train.      {metric_monitor}')

In [15]:
def validate(val_loader, model, criterion, epoch, params):
    """Evaluate the model over ``val_loader`` without tracking gradients.

    Running loss/accuracy are shown in a tqdm progress bar.

    Args:
        val_loader: Iterable of ``(images, target)`` batches.
        model (nn.Module): Network to evaluate; put into eval mode here.
        criterion: Loss called as ``criterion(output, target)``.
        epoch (int): Epoch number, used only in the progress-bar text.
        params (dict): Must contain ``'device'``.
    """
    device = params['device']
    metric_monitor = MetricMonitor()
    model.eval()
    progress = tqdm(val_loader)
    with torch.no_grad():
        for images, target in progress:
            images = images.to(device, non_blocking=True).float()
            target = target.to(device, non_blocking=True).float().view(-1, 1)
            output = model(images)
            batch_loss = criterion(output, target)
            batch_accuracy = calculate_accuracy(output, target)
            metric_monitor.update('Loss', batch_loss.item())
            metric_monitor.update('Accuracy', batch_accuracy)
            progress.set_description(f'Epoch: {epoch}. Validation. {metric_monitor}')

### Training and validation

In [None]:
# One training pass followed by one validation pass per epoch; epochs are numbered
# from 1 so the progress-bar text reads naturally.
for epoch in range(1, params['epochs'] + 1):
    train(train_loader, model, criterion, optimizer, epoch, params)
    validate(val_loader, model, criterion, epoch, params)

Epoch: 1. Train.      Loss: 0.7091145515441895:.3f | Accuracy: 0.484375:.3f:   0%|          | 1/323 [00:14<1:19:47, 14.87s/it]

### Test

In [None]:
class MalariaInferenceDataset(Dataset):
    """Image-only dataset for inference over a split's 'Parasitized'/'Uninfected' folders.

    ``all_paths`` keeps the order in which images are served, so predictions from a
    non-shuffled DataLoader can be matched back to their files.
    """

    def __init__(self, root_path, transform=None):
        """
        Args:
            root_path (string): Path to 'Parasitized' and 'Uninfected' folders.
            transform (callable, optional): Albumentations-style transform, called as
                ``transform(image=array)['image']``.
        """
        # Previously assigned from an undefined name (images_filepaths) -> NameError.
        self.path = root_path
        self.transform = transform
        # Sorted within each class so the order is reproducible across runs.
        self.all_paths = (
            sorted(glob.glob(os.path.join(self.path, 'Parasitized', '*.png')))
            + sorted(glob.glob(os.path.join(self.path, 'Uninfected', '*.png')))
        )

    def __len__(self):
        """Number of images found under ``root_path``."""
        return len(self.all_paths)

    def __getitem__(self, idx):
        """Return image ``idx`` as an (H, W, 3) array, or the transform's output if one is set."""
        path = self.all_paths[idx]
        # Albumentations operates on NumPy arrays, not PIL images; RGB drops any alpha channel.
        img = np.array(imread(path).convert('RGB'))
        if self.transform is not None:
            img = self.transform(image=img)['image']
        return img

In [None]:
# NOTE(review): training/validation tensors come from MalariaDataset.standartize(),
# which resizes to the dataset's mean size and does NOT normalise (values stay 0-255).
# This test pipeline crops to 128x128 and applies ImageNet normalisation, so the model
# sees differently preprocessed inputs at test time; align the two before trusting results.
test_transform = A.Compose(
    [
        A.SmallestMaxSize(max_size=160),
        A.CenterCrop(height=128, width=128),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

test_ds = MalariaInferenceDataset('data/test', transform=test_transform)
# Pinned memory only helps host-to-GPU copies, so enable it only for CUDA devices.
test_loader = DataLoader(
    test_ds,
    batch_size=params['batch_size'],
    shuffle=False,
    num_workers=params['num_workers'],
    pin_memory=str(params['device']).startswith('cuda'),
)

In [None]:
model = model.eval()
predicted_labels = []
with torch.no_grad():
    for images in test_loader:
        # .float() matches how train()/validate() feed the model.
        images = images.to(params['device'], non_blocking=True).float()
        output = model(images)
        # MalariaDataset labels 'Parasitized' as 1.0, so sigmoid >= 0.5 means parasitized.
        # (Previously mapped to 'Cat'/'Dog', which never match the folder-name labels.)
        is_parasitized = (torch.sigmoid(output) >= 0.5)[:, 0].cpu().numpy()
        predicted_labels += ['Parasitized' if flag else 'Uninfected' for flag in is_parasitized]

In [None]:
# test_images_filepaths was never defined; test_ds.all_paths lists the files in the same
# order as predicted_labels because test_loader uses shuffle=False. Only a small sample
# is shown, since a grid of the whole test split would be unreadable.
display_image_grid(test_ds.all_paths[:10], predicted_labels[:10])