In [1]:
import glob
import pickle
import re
from generators import *
from helpers import *
from networks import *
import torch
from torch.utils import data
from torch import nn
import torch.optim as optim
import numpy as np
import pandas as pd
from datetime import datetime

In [2]:
# ---- Training parameters ----
BATCH_SIZE = 8
epochs = 10
# Per-label sampling counts; label legend: 0 bckg, 1 bladder, 2 kidneys, 3 liver, 4 pancreas, 5 spleen.
# (Presumably the number of patches drawn around each label per subject -- confirm in generators.)
sampling = [0, 0, 5, 5, 10, 5]
outpath = '/home/eva/Desktop/research/PROJEKT2-DeepLearning/AnatomyAwareDL/Results/'
outpath_nets = outpath + 'Networks/'

# ---- Training dataset and loader ----
# Both lists are sorted so that subject i is paired with label i
# (assumes subject and label files share the same numbering scheme).
subjekti = sorted(glob.glob('/home/eva/Desktop/research/PROJEKT2-DeepLearning/AnatomyAwareDL/Data/TRAINdata/sub*'))
labele = sorted(glob.glob('/home/eva/Desktop/research/PROJEKT2-DeepLearning/AnatomyAwareDL/Data/TRAINdata/lab*'))

# Total patches requested per dataset item; the training loop uses it to normalise epoch metrics.
subbatch = sum(sampling)
dataset = POEMDatasetMultiInput(
    subjekti[0:10], labele[0:10], sampling, 25,
    channels=[0, 1], subsample=True, channels_sub=[0, 1],
)
train_loader = data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=multi_input_collate,
)



In [3]:
# Network, optimiser and loss configuration.
network_type = "DeepMedOrig"

# Train on the first CUDA device if available, otherwise on CPU. Defined first so the
# network can be placed on it before the optimiser is built.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Alternative single-pathway model:
# net = OnePathway(in_channels=4, num_classes=6, dropoutrateCon=0.2, dropoutrateFC=0.5)
# Channel counts match channels=[0,1] / channels_sub=[0,1] used by the training dataset.
net = DualPathway(in_channels_orig=2, in_channels_subs=2, num_classes=6,
                  dropoutrateCon=0.2, dropoutrateFC=0.5, nonlin="prelu")
# Cast to float32 and move to `device`: the training loop places its inputs on `device`,
# so the parameters must live there too (previously they always stayed on the CPU).
net = net.float().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001)
# weight: optional per-class weights for a weighted cross-entropy.
# ignore_index=1: targets equal to 1 (bladder in the label legend) are excluded from the loss.
napaka = nn.CrossEntropyLoss(weight=None, ignore_index=1, reduction='mean')

log_interval = 1  # print a progress line every `log_interval` batches
val_interval = 5  # run validation every `val_interval` epochs (validation is still a TODO)

In [None]:
# Training loop: one pass over `train_loader` per epoch, logging loss and per-class Dice.
# Metrics are computed without autograd so they never extend the computation graph.
net.to(device)  # idempotent guard: parameters must be on the same device as the inputs

train_losses = []   # per-epoch averaged loss
training_Dice = []  # per-epoch averaged per-class Dice (one numpy array per epoch)
val_losses = []
val_Dice = []

# Normaliser for epoch averages: items in the dataset times patches per item
# (assumes every item yields `subbatch` patches -- confirm in generators).
vsehslik = len(train_loader.dataset) * subbatch

for epoch in range(epochs):
    print('Train Epoch: {}'.format(epoch))
    net.train()
    running_loss = 0.0
    epoch_loss = 0.0
    epoch_dice = 0.0

    for batch_idx, (notr, target) in enumerate(train_loader):
        # `notr` is a list of inputs, one per network pathway; cast to float32 like `net`.
        notr = [torch.as_tensor(notri, device=device).float() for notri in notr]
        target = torch.as_tensor(target, device=device)

        optimizer.zero_grad()  # clear gradients left over from the previous step
        ven = net(*notr)
        loss = napaka(ven, target.long())
        loss.backward()
        optimizer.step()

        # Metrics only: under no_grad and on detached outputs they build no graph. Before,
        # `epoch_dice` accumulated a graph across every batch of the epoch (memory leak).
        with torch.no_grad():
            probs = torch.softmax(ven.detach(), dim=1)  # computed once, shared by both Dice calls
            dice = dice_coeff(probs, target, nb_classes=6, weights=np.array([0, 1, 1, 1, 1, 1]))
            dice_organs = dice_coeff_per_class(probs, target, nb_classes=6)
            epoch_dice += dice_organs.sum(0).squeeze()

        # loss.item() returns a plain Python float, so no graph is retained here either.
        running_loss += loss.item()
        epoch_loss += loss.item() * ven.shape[0]  # weight by number of patches in the batch
        if batch_idx % log_interval == 0:
            print('[{:.0f}%]\tAccumulated batch loss: {:.6f}\tBatch generalized Dice: {:.6f}'.format(
                100. * batch_idx / len(train_loader),
                running_loss / (batch_idx + 1), dice.sum() / len(notr[0])))
            # .cpu() is required before .numpy() when training on CUDA.
            print('Dice (batch averages) by organ: ', dice_organs.mean(0).cpu().numpy())

    if (epoch + 1) % val_interval == 0:
        # TODO: validation -- with net.eval() and torch.no_grad(), iterate a `val_loader`,
        # append results to val_losses / val_Dice and print them.
        pass

    epoch_loss = epoch_loss / vsehslik
    epoch_dice_np = (epoch_dice.squeeze() / vsehslik).cpu().numpy()
    print('>>> Epoch {} finished. Averaged loss: {:.6f}, average Dices: {} \n'.format(
        epoch, epoch_loss, epoch_dice_np))
    train_losses.append(epoch_loss)
    training_Dice.append(epoch_dice_np)

    
# Timestamp shared by every artefact of this run, formatted "_MMDDYYYY_HHMMSS".
# A single clock read keeps the date and time parts consistent (two separate reads
# could straddle midnight and produce a mismatched pair).
timestamp = datetime.now().strftime("_%m%d%Y_%H%M%S")

print("Training finished. Saving metrics...")
# One row per epoch: loss followed by the per-class Dice values.
dejta = np.column_stack([np.array(train_losses), np.array(training_Dice)])
df = pd.DataFrame(
    data=dejta,
    columns=['Loss', 'Dice Bckg', 'Dice Bladder', 'Dice Kidneys',
             'Dice Liver', 'Dice Pancreas', 'Dice Spleen'],
)
df.to_csv(outpath_nets + network_type + timestamp + '_DiceAndLoss.csv')

print("Saving trained network...")
# Pickles the whole module object; loading it later needs the network class importable
# (presumably from the `networks` module).
torch.save(net, outpath_nets + network_type + timestamp)

# Record the run's configuration next to the network for quick lookup.
run_dict = {"networkType": network_type, "channels": dataset.channels, "segment": dataset.segment_size,
            "channelsSubsampled": dataset.channels_sub, "useSubsampled": dataset.subsample, "sampling": sampling,
            "segment2": dataset.segment_size2, "channels2": dataset.channels2, "batchsize": BATCH_SIZE, "lastEpoch": epochs}
with open(outpath_nets + network_type + timestamp + '.txt', 'w') as f:
    print(run_dict, file=f)

Train Epoch: 0


In [None]:
# Load a previously saved network (a whole pickled module, as written by torch.save(net, ...)).
# By default this is the file saved by the training cell; change the path to load another one.
saved_network = outpath_nets + network_type + timestamp

# map_location lets a network saved on GPU be loaded on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles arbitrary objects -- only load files you trust. On PyTorch >= 2.6
# torch.load defaults to weights_only=True, which rejects whole-module pickles; pass
# weights_only=False there (older releases do not accept that keyword).
net = torch.load(saved_network, map_location=device)
net = net.to(device)

In [None]:
# EVALUATION ON TEST IMAGES
#  1. cut test volumes into overlapping patches (cached on disk; re-cut when parameters change)
#  2. run inference on every patch
#  3. glue the patches back into whole volumes
# Dice scores logged during training were per patch; recompute them on the glued volumes.

# ---- Test-data parameters ----
# NOTE(review): these subjects are also among subjekti[0:10] used for training,
# so the resulting scores are not held-out test scores.
test_subjekti = subjekti[0:5]
test_labele = labele[0:5]
patchsize = 52
overlap = 8
# ---- Switches ----
need_to_cut = True   # set to True whenever any test-data parameter above changes
do_inference = True

# Always cut with all 4 channels; the test dataset selects the channels it actually uses.
test_list = (
    cut_patches(test_subjekti, patchsize, overlap * 2, channels=4,
                outpath=outpath, subsampledinput=True)
    if need_to_cut
    else glob.glob(outpath + "subj_*[0-9].npy")
)

# ---- Test dataset and loader ----
test_dataset = POEMDatasetTEST(test_list, channels=[0, 1], subsampled=True, channels_sub=[0, 1],
                               input2=None, channels2=None)
test_loader = data.DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=test_collate)
##############

# Run the network over every test patch and save its raw output next to the patch name.
test_losses = []
test_dices = []
net.eval()
if do_inference:
    # Inference only: no_grad avoids building a graph, so outputs can go straight to numpy.
    with torch.no_grad():
        for i, (slike, names) in enumerate(test_loader):
            print(f"Doing test inference, batch nr. {i+1}/{len(test_loader)}")
            # Same preparation as training: move to `device` and cast to float32 like `net`.
            slike = [torch.as_tensor(slika, device=device).float() for slika in slike]
            segm = net(*slike)
            # Save raw outputs (no softmax): the loss computed after gluing expects logits.
            for patchnr, name in enumerate(names):
                # .cpu() is required before .numpy() when running on CUDA.
                np.save(name, np.squeeze(segm[patchnr, ...].cpu().numpy()))


# After inference on the whole test set: glue the saved patch outputs back into volumes,
# compute loss and per-class Dice per subject, and write the metrics to CSV.

# Border cropped from each label axis so it matches the glued prediction. `or None` keeps
# the full axis when overlap == 0 (a plain `0:-0` slice would be empty).
crop = slice(overlap, -overlap or None)

vsitestirani = []  # subject numbers, used as the CSV index
for subj in test_labele:
    nr = re.findall(r'.*label([0-9]*)\.pickle', subj)
    print("Saving tests on subject nr ", nr)
    vsitestirani.append(nr[0])
    # Glues one subject's patches (the helper also saves the result). Presumably returns
    # per-class scores of shape (1, C, D, H, W), as required by the loss below -- confirm.
    segmentacija = torch.from_numpy(
        glue_patches(nr[0], outpath, patchsize, overlap, nb_classes=6)
    )

    # NOTE: pickle.load can execute arbitrary code -- only load trusted label files.
    with open(subj, 'rb') as f:
        tarca = pickle.load(f)
    tarca = torch.from_numpy(tarca[np.newaxis, crop, crop, crop])

    with torch.no_grad():
        # CrossEntropyLoss takes raw logits plus integer class targets (softmax applied inside).
        test_loss = napaka(segmentacija, tarca.long())
        # The Dice helper expects probabilities, hence the explicit softmax.
        dajs = dice_coeff_per_class(torch.softmax(segmentacija, dim=1), tarca, nb_classes=6)
    test_losses.append(test_loss.item())
    test_dices.append(dajs.cpu().numpy().squeeze())

    print('{}: Loss {:.4f}, \t Dices {}'.format(nr, test_loss.item(), dajs.cpu().numpy()))

print("Testing finished. Saving metrics...")
dejta = np.column_stack([np.array(test_losses), np.array(test_dices)])
df = pd.DataFrame(
    data=dejta,
    index=vsitestirani,
    columns=['Loss', 'Dice Bckg', 'Dice Bladder', 'Dice Kidneys',
             'Dice Liver', 'Dice Pancreas', 'Dice Spleen'],
)
# Fixed: was `timestamp_'...'`, a SyntaxError that prevented the whole cell from running.
df.to_csv(outpath_nets + network_type + timestamp + '_DiceAndLoss_Test.csv')