In [1]:
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import monai
import torch
import torchio as tio
import os
import gc
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import json



In [2]:
# --- Experiment configuration -------------------------------------------------
# Site sub-directory of the dataset to use.
city = 'Beijing_Zang'

# Root folder of the dataset. The default is the original absolute location;
# set the VENTRICLE_DATASET_ROOT environment variable to run the notebook on a
# machine where the data lives somewhere else.
DATASET_ROOT = os.environ.get(
    'VENTRICLE_DATASET_ROOT',
    '/home/azureuser/cloudfiles/code/Users/rduan6/dataset',
)
IMAGE_DIR = f'{DATASET_ROOT}/{city}/MRI'
MASK_DIR = f'{DATASET_ROOT}/{city}/Ventricles'

# Number of subjects taken from the head (train) and tail (validation) of the
# sorted subject list.
TRAIN_SIZE = 15 #158
TEST_SIZE = 39
BATCH_SIZE = 1

# Label names noted by the original author. Presumably these are the four
# foreground classes of the mask, with 0 as background -- TODO confirm the
# order against the label files.
# right inf
# left inf
# right
# left
output_dir = '../results/models_train15/'
# makedirs also creates missing parent folders (e.g. '../results') and does not
# fail when the directory already exists.
os.makedirs(output_dir, exist_ok=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
def show_cuda_memory():
    """Print a memory summary for CUDA device 0, in GiB.

    Reports the device's total memory, the memory reserved by PyTorch, the
    memory currently allocated, and the difference reserved - allocated
    (labelled "Free": memory that is free *inside* the reserved pool, not
    free memory on the device as a whole).

    When CUDA is not available, a short notice is printed instead of
    raising, so the notebook still runs on the CPU fallback of ``DEVICE``.
    """
    if not torch.cuda.is_available():
        print('CUDA is not available; no GPU memory to report.')
        return

    gib = 2**30
    total = torch.cuda.get_device_properties(0).total_memory
    reserved = torch.cuda.memory_reserved(0)
    allocated = torch.cuda.memory_allocated(0)
    free_in_reserved = reserved - allocated  # free inside reserved
    print('Total:     {:0.2f} GiB'.format(total / gib))
    print('Reserved:  {:0.2f} GiB'.format(reserved / gib))
    print('Allocated: {:0.2f} GiB'.format(allocated / gib))
    print('Free:      {:0.2f} GiB'.format(free_in_reserved / gib))

show_cuda_memory()

Total:     11.17 GiB
Reserved:  0.00 GiB
Allocated: 0.00 GiB
Free:      0.00 GiB


In [4]:
def get_subjects(image_dir, mask_dir):
    """Pair every image in ``image_dir`` with the same-named mask in ``mask_dir``.

    Args:
        image_dir: Folder searched (non-recursively) for ``*.nii.gz`` images.
        mask_dir: Folder expected to hold one mask per image, under the
            same file name as the image.

    Returns:
        A list of ``tio.Subject`` objects, ordered by sorted image path, each
        holding the image under key ``t1`` and the mask under key ``label``.
    """
    image_paths = sorted(glob(f'{image_dir}/*.nii.gz'))

    def build_subject(image_path):
        # The mask is looked up by the image's own file name.
        filename = image_path.rsplit('/', 1)[-1]
        return tio.Subject(
            t1=tio.ScalarImage(image_path),
            label=tio.LabelMap(f'{mask_dir}/{filename}'),
        )

    return [
        build_subject(image_path)
        for image_path in tqdm(image_paths, desc='Creating Subjects')
    ]

In [5]:
all_subjects = get_subjects(IMAGE_DIR, MASK_DIR)

# Guard the split: the training set is the head of the list and the validation
# set is the tail, so they silently overlap when the sizes add up to more than
# the number of subjects. A size of 0 is also rejected because
# `all_subjects[-0:]` would select every subject instead of none.
assert TRAIN_SIZE > 0 and TEST_SIZE > 0, 'TRAIN_SIZE and TEST_SIZE must be positive'
assert TRAIN_SIZE + TEST_SIZE <= len(all_subjects), (
    f'TRAIN_SIZE + TEST_SIZE = {TRAIN_SIZE + TEST_SIZE} exceeds the '
    f'{len(all_subjects)} available subjects; train and validation would overlap'
)

subjects = {
    'train': all_subjects[:TRAIN_SIZE],
    'validation': all_subjects[-TEST_SIZE:],
}

Creating Subjects: 100%|██████████| 197/197 [00:02<00:00, 75.73it/s]


In [6]:
# Spatial augmentation: a small random affine (rotation range -3..3,
# translation range -0.1..0.1), wrapped in OneOf with p=0.75.
spatial = tio.OneOf(
    {
        tio.RandomAffine(degrees=(-3, 3), translation=(-0.1, 0.1)): 1.0,
    },
    p=0.75,
)

# Geometry normalisation: resample with target 1, then crop or pad to 256.
resample = tio.Compose([tio.Resample(1), tio.CropOrPad(256)])

# Intensity normalisation: rescale to [0, 1] using the 0.1 / 99.9 percentiles
# as the input range.
signal = tio.Compose([
    tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(0.1, 99.9)),
])

In [7]:
def get_transform(std):
    """Build the train and validation pipelines for one noise level.

    Both pipelines run ``resample`` -> noise -> ``signal``; the train pipeline
    additionally starts with the ``spatial`` augmentation. The noise stage is
    applied before the intensity rescaling in ``signal``, and the same noise
    transform object is shared by both pipelines.

    Args:
        std: Standard deviation of the zero-mean random noise. It is passed as
            the degenerate range ``(std, std)``.

    Returns:
        A dict with keys ``'train'`` and ``'validation'``, each mapping to a
        ``tio.Compose`` pipeline.
    """
    noise = tio.Compose([tio.RandomNoise(mean=0, std=(std, std))])
    shared_stages = [resample, noise, signal]
    return {
        'train': tio.Compose([spatial, *shared_stages]),
        'validation': tio.Compose(list(shared_stages)),
    }

def get_dataloader(transform):
    """Create the train and validation data loaders.

    Reads the module-level ``subjects`` split and ``BATCH_SIZE``.

    Args:
        transform: Dict with keys ``'train'`` and ``'validation'`` holding the
            transform applied to each subject of the respective split.

    Returns:
        A dict with keys ``'train'`` and ``'validation'``, each mapping to a
        ``torch.utils.data.DataLoader`` over a ``tio.SubjectsDataset``.
    """
    # os.cpu_count() returns None when the count cannot be determined, which
    # is not a valid num_workers value; fall back to loading in the main
    # process in that case.
    num_workers = os.cpu_count() or 0

    # `shuffle` is deliberately left at its default: train() uses the batch
    # index in the file names of the predictions it saves.
    return {
        mode: torch.utils.data.DataLoader(
            tio.SubjectsDataset(
                subjects[mode],
                transform=transform[mode],
            ),
            batch_size=BATCH_SIZE,
            num_workers=num_workers,
        )
        for mode in ('train', 'validation')
    }

In [8]:
def validate(model, loss_fn, metric, losses, dscs, std, dataloader):
    """Run one validation pass, record loss/DSC and checkpoint the best model.

    The mean loss is weighted by batch size and divided by the number of
    samples actually seen, so it stays correct if the validation set size
    differs from the ``TEST_SIZE`` constant.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loss_fn: Callable ``loss_fn(pred, one_hot_label)`` returning a scalar
            tensor.
        metric: Cumulative metric object; it is called per batch, aggregated
            once at the end and then reset.
        losses: Dict of lists; the mean loss is appended to
            ``losses['validation']``.
        dscs: Dict of lists; the aggregated metric (as a list) is appended to
            ``dscs['validation']``.
        std: Noise level of the current experiment; only used in the
            checkpoint file name.
        dataloader: Dict whose ``'validation'`` entry yields batches with
            ``'t1'`` and ``'label'`` images.

    Raises:
        ValueError: If the validation loader yields no samples.

    Side effects:
        Saves ``model.state_dict()`` to ``{output_dir}/UNet_std{std}.pth``
        whenever the current mean loss is the lowest recorded so far, and
        prints the loss and DSC.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for subject in dataloader['validation']:
            image = subject['t1'][tio.DATA].to(DEVICE)  # (B, 1, 256, 256, 256)
            label = subject['label'][tio.DATA].to(DEVICE)  # (B, 1, 256, 256, 256)

            pred = model(image)  # (B, 5, 256, 256, 256)
            one_hot_label = monai.networks.utils.one_hot(
                label, num_classes=5, dim=1
            ).to(DEVICE)  # (B, 5, 256, 256, 256)

            loss = loss_fn(pred, one_hot_label)
            batch_size = image.shape[0]
            # Accumulate a plain Python float rather than a device tensor.
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

            # Hard prediction: arg-max class, re-encoded as one-hot for the metric.
            one_hot_pred = monai.networks.utils.one_hot(
                torch.argmax(pred, dim=1, keepdim=True),
                num_classes=5,
                dim=1
            ).to(DEVICE)
            metric(one_hot_pred, one_hot_label)

            # Release the large per-volume tensors before the next batch.
            del image, label, pred, one_hot_label, loss, one_hot_pred
            gc.collect()
            torch.cuda.empty_cache()

        if n_samples == 0:
            raise ValueError('Validation dataloader yielded no samples')

        mean_loss = loss_sum / n_samples
        losses['validation'].append(mean_loss)
        print(f'Validation Loss: {mean_loss}')

        # The loss was just appended, so equality with the minimum means this
        # epoch is the best (or ties the best) seen so far.
        if mean_loss == min(losses['validation']):
            torch.save(model.state_dict(), f'{output_dir}/UNet_std{std}.pth')

        mean_dsc = metric.aggregate().tolist()
        metric.reset()
        dscs['validation'].append(mean_dsc)
        print(f'Validation DSC: {mean_dsc}\n')

In [9]:
def train(model, n_epochs, dataloader, std):
    """Train ``model`` for ``n_epochs`` epochs, validating after each one.

    Uses Adam (lr=1e-3) with a soft-max Dice loss, and a Dice metric that
    excludes the background class and is averaged over the batch.

    Args:
        model: Network to train; it must already be on ``DEVICE``.
        n_epochs: Number of passes over ``dataloader['train']``.
        dataloader: Dict with ``'train'`` and ``'validation'`` loaders yielding
            batches with ``'t1'`` and ``'label'`` images.
        std: Noise level of the current experiment; only used in output file
            names.

    Returns:
        ``(losses, dscs)``: two dicts with keys ``'train'`` and
        ``'validation'``, each holding one entry per epoch.

    Raises:
        ValueError: If the training loader yields no samples.

    Side effects:
        Saves the one-hot prediction of every training batch of every epoch
        to ``{output_dir}/pred_std{std}_epoch{E}_image{I}.npy`` (``np.save``
        appends the extension). NOTE(review): each file holds the full
        one-hot volume, so disk usage grows quickly with epochs and subjects.
        ``validate`` additionally writes the best checkpoint.
    """
    losses, dscs = defaultdict(list), defaultdict(list)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = monai.losses.DiceLoss(softmax=True, squared_pred=True).to(DEVICE)
    metric = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')

    for epoch in range(n_epochs):
        print(f'Epoch {epoch+1}/{n_epochs}')
        model.train()
        loss_sum = 0.0
        n_samples = 0
        for i, subject in enumerate(dataloader['train']):
            image = subject['t1'][tio.DATA].to(DEVICE)  # (B, 1, 256, 256, 256)
            label = subject['label'][tio.DATA].to(DEVICE)  # (B, 1, 256, 256, 256)

            pred = model(image)  # (B, 5, 256, 256, 256)
            one_hot_label = monai.networks.utils.one_hot(
                label, num_classes=5, dim=1
            ).to(DEVICE)  # (B, 5, 256, 256, 256)

            loss = loss_fn(pred, one_hot_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = image.shape[0]
            # .item() detaches the value: summing the loss tensor itself would
            # keep extending an autograd history across the whole epoch.
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

            # Hard prediction: arg-max class, re-encoded as one-hot for the metric.
            one_hot_pred = monai.networks.utils.one_hot(
                torch.argmax(pred, dim=1, keepdim=True),
                num_classes=5,
                dim=1
            ).to(DEVICE)
            metric(one_hot_pred, one_hot_label)

            np.save(
                f'{output_dir}/pred_std{std}_epoch{epoch+1}_image{i+1}',
                one_hot_pred.detach().cpu().numpy()
            )

            # Release the large per-volume tensors before the next batch.
            del image, label, pred, one_hot_label, loss, one_hot_pred
            gc.collect()
            torch.cuda.empty_cache()

        if n_samples == 0:
            raise ValueError('Training dataloader yielded no samples')

        # Average over the samples actually seen, not a global constant.
        mean_loss = loss_sum / n_samples
        losses['train'].append(mean_loss)
        print(f'Train Loss: {mean_loss}')

        mean_dsc = metric.aggregate().tolist()
        metric.reset()
        dscs['train'].append(mean_dsc)
        print(f'Train DSC: {mean_dsc}')

        # The metric object is reused for validation; it was reset above.
        validate(model, loss_fn, metric, losses, dscs, std, dataloader)

    return losses, dscs

In [10]:
# Train one fresh UNet per noise level and persist its loss / DSC history.
for std in [1000, 2000]:
    transform = get_transform(std=std)
    dataloader = get_dataloader(transform)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=5,
        channels=(32, 64, 128, 256, 512),
        strides=(2, 2, 2, 2),
        norm=monai.networks.layers.Norm.BATCH,
    ).to(DEVICE)
    losses, dscs = train(model, n_epochs=50, dataloader=dataloader, std=std)

    # Context managers close the files even if serialisation fails, and
    # `losses` / `dscs` keep their dict type instead of being rebound to
    # JSON strings.
    with open(f'{output_dir}/losses_std{std}.json', 'w') as f:
        json.dump(losses, f)

    with open(f'{output_dir}/dscs_std{std}.json', 'w') as f:
        json.dump(dscs, f)

    # Drop the references before the next run so GPU memory can be reclaimed.
    del model, transform, dataloader, losses, dscs
    gc.collect()
    torch.cuda.empty_cache()

Epoch 1/50
Train Loss: 0.8730594635009765
Train DSC: [0.0009869291679933667, 0.0011891480535268784, 0.0001376793225063011, 0.00028293437208049]
Validation Loss: 0.8782900296724759
Validation DSC: [0.001413295860402286, 0.000813145365100354, 0.00016728368063922971, 0.00023681667516939342]

Epoch 2/50
Train Loss: 0.8134078979492188
Train DSC: [0.005527791567146778, 0.0038023900706321, 0.0009054889669641852, 0.0007054306333884597]
Validation Loss: 0.818069360195062
Validation DSC: [0.009693700820207596, 0.011499617248773575, 0.0010941297514364123, 0.0013355028349906206]

Epoch 3/50
Train Loss: 0.7822691599527994
Train DSC: [0.06192946806550026, 0.06087280064821243, 0.008442134596407413, 0.009225575253367424]
Validation Loss: 0.9285631424341446
Validation DSC: [0.0013179085217416286, 0.001524399733170867, 0.00016528925334569067, 0.00030544211040250957]

Epoch 4/50
Train Loss: 0.7304951985677083
Train DSC: [0.1344667375087738, 0.17764711380004883, 0.01129075512290001, 0.003965189680457115]


In [None]:
!zip -r file.zip ../results/models_train15

  adding: ../results/models_train15/ (stored 0%)
  adding: ../results/models_train15/.amlignore (deflated 32%)
  adding: ../results/models_train15/.amlignore.amltmp (deflated 32%)
  adding: ../results/models_train15/dscs_std1000.json (deflated 54%)
  adding: ../results/models_train15/dscs_std2000.json (deflated 53%)
  adding: ../results/models_train15/losses_std1000.json (deflated 52%)
  adding: ../results/models_train15/losses_std2000.json (deflated 52%)
  adding: ../results/models_train15/pred_std0_epoch10_image1.pth.npy (deflated 100%)
  adding: ../results/models_train15/pred_std0_epoch10_image10.pth.npy (deflated 100%)
  adding: ../results/models_train15/pred_std0_epoch10_image11.pth.npy (deflated 100%)
  adding: ../results/models_train15/pred_std0_epoch10_image12.pth.npy (deflated 100%)
  adding: ../results/models_train15/pred_std0_epoch10_image13.pth.npy (deflated 100%)
  adding: ../results/models_train15/pred_std0_epoch10_image14.pth.npy (deflated 100%)
  adding: ../results/mode