In [None]:
import os

import monai.transforms as mt
import torch
from monai.apps import MedNISTDataset
from monai.data import DataLoader
from monai.engines import SupervisedTrainer
from monai.inferers import SimpleInferer
from monai.networks import eval_mode
from monai.networks.nets import densenet121

# Data root for MedNIST; override with the ROOTDIR environment variable, else use the working directory.
root_dir = os.getenv("ROOTDIR", ".")

In [None]:
max_epochs = 2
# Use the first GPU when one is available, otherwise fall back to CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Single-channel 2D input; one output per MedNIST class (the six `class_names`
# used in the evaluation cell).
net = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(device)

# Dictionary-based preprocessing applied to the "image" key of every item:
# load the file, move the channel dimension first, then rescale intensities.
transform = mt.Compose([
    mt.LoadImaged(keys="image", image_only=True),
    mt.EnsureChannelFirstd(keys="image"),
    mt.ScaleIntensityd(keys="image"),
])

# Training split. The training cell below consumes `dataset`, so it must be
# defined here for Restart & Run All to work. `download=True` fetches MedNIST
# into `root_dir` on the first run.
dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section="training", download=True)


In [None]:

# Mini-batches of the training split, reshuffled each epoch and loaded by 4 worker processes.
train_dl = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4)

# Optimisation objects are built up front so the trainer call below reads as pure wiring.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)
loss_function = torch.nn.CrossEntropyLoss()

# MONAI's SupervisedTrainer runs the epoch/iteration loop for `max_epochs`
# epochs, feeding each batch through `SimpleInferer` (a plain forward pass).
trainer = SupervisedTrainer(
    device=device,
    max_epochs=max_epochs,
    train_data_loader=train_dl,
    network=net,
    optimizer=optimizer,
    loss_function=loss_function,
    inferer=SimpleInferer(),
)
trainer.run()


In [None]:

# Export the trained network as TorchScript. It is written into the bundle's
# models directory because that is the path the state_dict conversion cell reads.
bundle_models_dir = os.path.join(".", "MedNISTClassifier", "models")
os.makedirs(bundle_models_dir, exist_ok=True)
with eval_mode(net):
    # Script while in eval mode so the exported module's `training` flag is False;
    # eval_mode restores the network's previous mode on exit.
    torch.jit.script(net).save(os.path.join(bundle_models_dir, "mednist.ts"))

# Index i of the network output is read as class_names[i]; presumably this matches
# the integer labels MedNISTDataset assigns (sorted class names) — confirm.
class_names = ("AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT")
# Built without `download=True`, so this relies on the data fetched for the training split.
testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section="test", runtime_cache=True)

# Spot-check only the first few test items rather than printing the whole split.
max_items_to_print = 10
eval_dl = DataLoader(testdata[:max_items_to_print], batch_size=1, num_workers=0)
with eval_mode(net):
    for item in eval_dl:
        result = net(item["image"].to(device))
        # Raw logits of the single item in the batch (not softmax probabilities);
        # their argmax is the predicted class index.
        prob = result.detach().to("cpu")[0]
        pred = class_names[prob.argmax().item()]
        gt = item["class_name"][0]
        print(f"Prediction: {pred}. Ground-truth: {gt}")

In [None]:
# Convert the exported TorchScript module into a plain state_dict file next to it.
# map_location="cpu" keeps this working (and the .pt file portable) on machines
# without a GPU even if the network was scripted on CUDA.
obj = torch.jit.load("./MedNISTClassifier/models/mednist.ts", map_location="cpu")
torch.save(obj.state_dict(), "./MedNISTClassifier/models/mednist.pt")

In [None]:
# Show what is in the working directory (e.g. to confirm exported files exist).
# NOTE(review): the `%%bash` magic that follows is only recognised as the first
# line of a cell — it must be moved into its own cell to run.
os.listdir(os.curdir)

%%bash
# Train the MedNISTClassifier bundle from its config files, then move the resulting model into the bundle.
# TODO: move this step out of the notebook (driven purely by the bundle config files) later.
# NOTE(review): there is no `set -e`, so the commands below still run if training fails.

BUNDLE="./MedNISTClassifier"

# run the bundle with epochs set to 2 for speed during testing, change this to get a better result
# `train` names the config entry to run; `--max_epochs 2` overrides the `max_epochs` value,
# presumably defined in common.yaml/train.yaml — confirm against the configs.
python -m monai.bundle run train \
    --meta_file "$BUNDLE/configs/metadata.json" \
    --config_file "['$BUNDLE/configs/common.yaml','$BUNDLE/configs/train.yaml']" \
    --max_epochs 2

# we'll use the trained network as the model object for this bundle
# NOTE(review): assumes the training run writes `model.ts` to the current directory — confirm in train.yaml.
mv model.ts $BUNDLE/models/model.ts

# generate the saved dictionary file as well
# (switch into the bundle's models directory first)
cd "$BUNDLE/models"