In [None]:
# Run this to change the working directory to the notebook's parent directory
# (presumably the repository root, so that `imagiq` is importable below).
# TODO: pip install will resolve this issue...
import os

# Remember the directory the kernel started in. A bare `os.chdir("..")` moves up
# one more level every time the cell is re-run; anchoring on the original
# directory makes the cell idempotent. A kernel restart clears globals(), so the
# start directory is re-captured on a fresh run.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))

In [None]:
import numpy as np
import torch
from imagiq.models import Model
from imagiq.datasets import NIHDataset
from monai.transforms import (
    Compose,
    LoadImaged,
    ScaleIntensityd,
    SqueezeDimd,
    AddChanneld,
    AsChannelFirstd,
    Lambdad,
    ToTensord,
    Resized,
    RandRotated,
    RandFlipd,
    RandHistogramShiftd,
    RandGaussianNoised,
)
from monai.networks.nets import densenet121
from monai.data import CacheDataset
import sys

In [None]:
# Load the NIH chest X-ray dataset that every split below is drawn from.
# NOTE(review): `download=[0]` presumably restricts the download to the first
# archive part only -- confirm against NIHDataset.
# TODO: Read all, not just the test section
# TODO: Train test split
master_dataset = NIHDataset(
    download=[0],
    section="test",
)
print(master_dataset)

In [None]:
# Node 1 is biased towards Atelectasis
# Node 2 is biased towards Infiltration
# NOTE(review): no node-specific biasing is implemented in this cell yet; the
# split below is a plain random 70/15/15 partition.
# TODO: bias towards AP/Lateral views
# TODO: Bias towards male/female
# Count of samples in class index 0 (currently unused in this section).
N_normal = master_dataset.class_count[0]

# Seeded generator so the train/val/test split is identical across runs and
# does not depend on (or disturb) numpy's global random state.
SPLIT_SEED = 0
split_rng = np.random.default_rng(SPLIT_SEED)

train_data = []
val_data = []
test_data = []
for data in master_dataset:
    # Relabel a copy instead of mutating `data`. If the dataset yields its
    # stored dicts, in-place mutation would make a re-run of this cell fail,
    # because `label` would already be a scalar.
    # NOTE(review): assumes label[0] is the "normal" indicator, so this yields
    # a binary target with 0 = class 0 and 1 = anything else -- confirm.
    record = dict(data)
    record["label"] = 1 - data["label"][0]

    r = split_rng.random()
    if r < 0.7:
        train_data.append(record)
    elif r < 0.85:
        val_data.append(record)
    else:
        test_data.append(record)

In [None]:
def _preprocessing_transforms():
    """Deterministic preprocessing shared by the training and evaluation pipelines.

    Returns a fresh list of transform instances on every call, so the train and
    eval ``Compose`` objects never share transform objects and cannot drift apart.
    """
    return [
        LoadImaged("image"),
        # Collapse a trailing channel axis (if present) to a single grayscale plane.
        Lambdad("image", func=lambda x: np.mean(x, axis=2) if len(x.shape) == 3 else x),
        # NOTE(review): on a 2D array, moving the last axis to the front is a
        # transpose. It is applied identically to train and eval, but confirm it
        # is intended.
        AsChannelFirstd("image"),
        AddChanneld("image"),
        ScaleIntensityd("image"),
        # NOTE(review): nearest-neighbour downsampling can alias; consider
        # mode="area" for shrinking images.
        Resized("image", spatial_size=(224, 224), mode="nearest"),
    ]


# Training adds light random augmentation on top of the shared preprocessing.
train_transforms = Compose(
    _preprocessing_transforms()
    + [
        RandHistogramShiftd("image", prob=0.2),
        RandGaussianNoised("image", prob=0.2),
        # pi/12 = 15 degrees (assuming range_x is in radians).
        RandRotated("image", range_x=np.pi / 12, prob=0.2),
        # NOTE(review): with no spatial_axis, this presumably flips every spatial
        # axis (including vertical), which is unusual for chest X-rays -- confirm.
        RandFlipd("image", prob=0.2),
    ]
)

# Validation and test use only the deterministic preprocessing.
test_transforms = Compose(_preprocessing_transforms())

train_dataset = CacheDataset(train_data, train_transforms)
val_dataset = CacheDataset(val_data, test_transforms)
test_dataset = CacheDataset(test_data, test_transforms)

In [None]:
# DenseNet-121 over single-channel 2D images, producing two output channels that
# are fed to CrossEntropyLoss for the binary label built during the split.
model = Model(
    densenet121(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
    )
)

In [None]:
# Adam with an initial learning rate of 5e-3 over the wrapped network's weights.
optimizer = torch.optim.Adam(model.net.parameters(), lr=5e-3)

# Cut the learning rate by 10x once the monitored value stops decreasing for
# `patience` scheduler steps.
# NOTE(review): which value is passed to scheduler.step() (presumably the
# validation loss, given mode="min") is decided inside Model.train -- confirm.
# With epochs=10 and patience=5, at most one reduction can happen if stepped
# once per epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=5,
)

history = model.train(
    train_dataset,
    torch.nn.CrossEntropyLoss(),
    optimizer,
    validation_dataset=val_dataset,
    scheduler=scheduler,
    metrics=["AUC"],
    epochs=10,
    batch_size=16,
    device="cpu",
    dirpath="path/to/save/model/",
)

In [None]:
# Show which per-epoch series Model.train recorded; the plotting cell below
# indexes 'loss', 'val_loss', 'auc' and 'val_auc'.
history.keys()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.figure( figsize=(15, 5) ) 
plt.subplot( 1, 2, 1 )
plt.plot( history['loss'] )
plt.plot( history['val_loss'] )
plt.title( 'loss vs epoch' )
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend( ['train', 'validation'] )

plt.subplot( 1, 2, 2 )
plt.plot( history['auc'] )
plt.plot( history['val_auc'] )
plt.title( 'AUC vs epoch' )
plt.xlabel('epochs')
plt.ylabel('auc')
plt.legend( ['train', 'validation'] )

In [None]:
# Run inference on the held-out test split and display the result.
# TODO: evaluate() method is incomplete
test_predictions = model.predict(test_dataset, device="cpu")
test_predictions