# NIH Chest X-ray Classification Demo

## Configure environment

In [None]:
# Run this cell to make the repository root the working directory, presumably
# so that `imagiq` can be imported from the source checkout.
# TODO: unnecessary once imagiq is installed with `pip install`.
import os

# Remember the directory the notebook started in, so re-running this cell is
# idempotent instead of climbing one more directory on every execution.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.dirname(_NOTEBOOK_START_DIR))

## Imports

In [None]:
import torch
import imagiq.federated as iqf
from imagiq.models import Model
from imagiq.datasets import NIHDataset

## Load dataset

In [None]:
import numpy as np
from monai.transforms import (
    AddChanneld,
    AsChannelFirstd,
    CastToTyped,
    Compose,
    Lambdad,
    LoadPNGd,
    RandFlipd,
    RandRotated,
    RandZoomd,
    Resized,
    ScaleIntensityd,
    ToTensorD,
)

IMAGE_SIZE = (224, 224)


def _to_grayscale(x):
    """Collapse a channels-last image to a single 2-D intensity plane.

    Arrays that are not 3-D are returned unchanged. For 3-D arrays the channel
    axis is axis 2 (as in the original ``np.mean(x, axis=2)``); a trailing alpha
    channel (2-channel LA or 4-channel RGBA images) is dropped before averaging
    so it does not contribute to the intensity.
    """
    if np.ndim(x) != 3:
        return x
    if x.shape[2] in (2, 4):
        x = x[..., :-1]
    return np.mean(x, axis=2)


def _preprocessing_steps():
    """Return fresh instances of the deterministic steps shared by all splits."""
    return [
        LoadPNGd("image"),
        Lambdad("image", func=_to_grayscale),
        AsChannelFirstd("image"),
        AddChanneld("image"),
        ScaleIntensityd("image"),
        # NOTE(review): nearest-neighbour downsampling can alias fine detail;
        # consider mode="area" if that matters for accuracy.
        Resized("image", spatial_size=IMAGE_SIZE, mode="nearest"),
    ]


def _tensor_steps():
    """Return the final conversion steps: tensors, with float labels for BCE loss."""
    return [
        ToTensorD(("image", "label")),
        CastToTyped("label", torch.float),
    ]


# Training adds random augmentation between preprocessing and tensor conversion.
train_transforms = Compose(
    _preprocessing_steps()
    + [
        RandRotated("image", range_x=15, prob=0.5, keep_size=True),
        RandFlipd("image", spatial_axis=0, prob=0.5),
        RandZoomd("image", min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True),
    ]
    + _tensor_steps()
)

# Validation/test use the same pipeline without augmentation.
val_transforms = Compose(_preprocessing_steps() + _tensor_steps())

# NOTE(review): an earlier comment here read "set download = None", but the
# calls pass download=[0]. What this argument selects is defined by
# NIHDataset (not visible here) -- confirm the intended value.
train_ds = NIHDataset("training", train_transforms, download=[0])
val_ds = NIHDataset("validation", val_transforms, download=[0])
test_ds = NIHDataset("test", val_transforms, download=[0])

In [None]:
# Show a summary of each split, one per line, in training/validation/test order.
for split in (train_ds, val_ds, test_ds):
    print(split)

## Display sample training images

In [None]:
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Show up to a 3x3 grid of randomly chosen training images. They are read
# straight from disk, bypassing the dataset transforms, so they appear as stored.
n_rows, n_cols = 3, 3
n_show = min(n_rows * n_cols, len(train_ds))
# Sample without replacement so the same image never appears twice in the grid.
indices = np.random.choice(len(train_ds), size=n_show, replace=False)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 8))
axes = axes.ravel()
for ax, k in zip(axes, indices):
    # The context manager closes the file handle once the pixels are copied.
    with Image.open(train_ds.data[k]["image"]) as im:
        arr = np.array(im)
    ax.imshow(arr, cmap="gray", vmin=0, vmax=255)
# Hide any grid cells left empty when the dataset has fewer than 9 images.
for ax in axes[n_show:]:
    ax.axis("off")
fig.tight_layout()
plt.show()

## Train a DenseNet-121 classifier on the dataset

In [None]:
from monai.data import DataLoader
from monai.networks.nets import densenet121

# Prefer the first CUDA GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 2-D DenseNet-121 taking single-channel input and producing 15 outputs.
net = densenet121(spatial_dims=2, in_channels=1, out_channels=15)

# Regularise the backbone by appending a Dropout(0.5) module named "drop1" to
# features[5], features[7] and features[9] (presumably the transition layers
# between dense blocks -- confirm against MONAI's DenseNet layout).
for block_index in (5, 7, 9):
    net.features[block_index].add_module("drop1", torch.nn.Dropout(0.5))

# Swap in a classification head that also applies dropout between global
# pooling and the final linear layer. 1024 is presumably the channel count of
# DenseNet-121's last feature map -- confirm if the backbone changes.
head_layers = [
    torch.nn.ReLU(inplace=True),
    torch.nn.AdaptiveAvgPool2d(output_size=1),
    torch.nn.Dropout(0.5),
    torch.nn.Flatten(start_dim=1, end_dim=-1),
    torch.nn.Linear(in_features=1024, out_features=15, bias=True),
]
net.class_layers = torch.nn.Sequential(*head_layers)

# Wrap the network in imagiq's Model.
model = Model(net)

In [None]:
# Per-class positive-example weights for BCEWithLogitsLoss, placed on the
# training device. The dtype is pinned to float32 to match the labels (cast to
# torch.float by the transforms) regardless of what getPositiveWeights()
# returns; torch.as_tensor also accepts an existing tensor without the
# copy-construct warning that torch.tensor() emits.
pos_weight = torch.as_tensor(
    train_ds.getPositiveWeights(), dtype=torch.float32
).to(device)

model.train(
    train_ds,
    torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight),
    torch.optim.Adam(model.net.parameters(), lr=1e-3, weight_decay=1e-4),
    epochs=5,
    metrics=["AUC"],
    batch_size=16,
    device=device,
    validation_dataset=val_ds,
)