In [1]:
import argparse
import logging
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet_model import UNet
from wholeslidedata.iterators import create_batch_iterator
import numpy as np
from matplotlib import pyplot as plt
from plot_utils import init_plot, plot_batch, show_plot
from shapely.prepared import prep
#import albumentations as A

In [2]:
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
dir_checkpoint = Path('./checkpoints/')
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet_model import UNet

cuda


In [4]:
def dice_loss(output, target):

    eps = 0.0001

    intersection = output * target
    numerator = 2 * intersection.sum(0).sum(1).sum(1)
    denominator = output + target
    denominator = denominator.sum(0).sum(1).sum(1) + eps
    loss_per_channel = (1 - (numerator / denominator))

    return loss_per_channel.sum() / output.size(1)

class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        
        intersection = (inputs * targets).sum()                            
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss
        
        return Dice_BCE

In [5]:
# number of batches
batches = 64
# batch iterator config
user_config = 'user_config.yml'
# number of cpus used to extract patches on multiple cores
cpus = 1

# create iterators for the dataset
if torch.cuda.is_available():
    try:
        training_iterator = create_batch_iterator(user_config=user_config, mode='training', cpus=1, number_of_batches=batches)
        test_iterator = create_batch_iterator(user_config=user_config, mode='validation', cpus=1, number_of_batches=batches)    
    except:
        print("Exception!!!")
        sys.exit()

In [6]:
net = UNet(n_channels=3, n_classes=2, bilinear=True)
net.to(device=device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr= 1e-4, weight_decay=0.0001)

In [7]:
train_losses = []
test_losses = []

textfile = open("train_dice_loss.txt", "w")
textfile.close()
textfile = open("test_dice_loss.txt", "w")
textfile.close()

In [8]:
for epoch in range(30):
    running_train_loss = 0.0
    running_test_loss = 0.0
    with tqdm(training_iterator, unit="batch") as tepoch:
        for i, data in enumerate(tepoch, 0):
            tepoch.set_description(f"Epoch {epoch}")
            
            # normalize the data
            images, true_masks = data[0], data[1]
            info = data[2]
            
            # convert images to tensors
            images = images.astype(np.float64)
            true_masks = true_masks.astype(np.float64)
            images = torch.Tensor(images)
            true_masks = torch.Tensor(true_masks)
            images = torch.permute(images,(0,3,1,2))
            assert images.shape[1] == net.n_channels, \
                f'Network has been defined with {net.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'
            images = images.to(device=device, dtype=torch.float32)/255
            true_masks = true_masks.to(device=device, dtype=torch.long)
            
            pred_masks = net(images)
            loss = criterion(pred_masks, true_masks) \
                               + dice_loss(F.softmax(pred_masks, dim=1).float(),
                                           F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float())
            
            # zero the parameter gradients
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            #detaching
            loss_train = loss.detach().cpu().item()
            
            #tracking_loss
            running_train_loss += loss_train
            tepoch.set_postfix(loss=loss_train)
            
            #cleanup
            del images
            del true_masks
            del loss
            del pred_masks
            
            torch.cuda.empty_cache()
    print("Training Done")
    with torch.no_grad():
        for i, data in enumerate(test_iterator, 0):
            images, true_masks = data[0], data[1]
            info = data[2] 
            
            # convert images to tensors
            images = images.astype(np.float64)
            true_masks = true_masks.astype(np.float64)
            images = torch.Tensor(images)
            true_masks = torch.Tensor(true_masks)
            images = torch.permute(images,(0,3,1,2))
            assert images.shape[1] == net.n_channels, \
                f'Network has been defined with {net.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'
            images = images.to(device=device, dtype=torch.float32)/255
            true_masks = true_masks.to(device=device, dtype=torch.long)
            
            pred_masks = net(images)
            loss = criterion(pred_masks, true_masks) \
                               + dice_loss(F.softmax(pred_masks, dim=1).float(),
                                           F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float())
            #detaching
            loss_test = loss.detach().cpu().item()
            
            #running_loss
            running_test_loss += loss_test
            
            #cleanup
            del images
            del true_masks
            del loss
            del pred_masks
            
            torch.cuda.empty_cache()
    print("Validation Done")
    textfile = open("train_dice_loss.txt", "a")
    textfile.write(str(running_train_loss/len(training_iterator)) + "\n")
    textfile.close()

    train_losses.append(running_train_loss/len(training_iterator))

    textfile = open("test_dice_loss.txt", "a")
    textfile.write(str(running_test_loss/len(test_iterator)) + "\n")
    textfile.close()

    test_losses.append(running_test_loss/len(test_iterator))
    print(train_losses[-1])
    print(test_losses[-1])
    # LR scheduler step
    #scheduler.step(loss_test)
    # save checkpoint weights every 20 epochs 
    if ((epoch + 1) % 2 == 0):
        torch.save(net.state_dict(), './checkpoints/weights_' + str(epoch+1) + '.pth')

Epoch 0: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.958]


Training Done


Epoch 1:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
1.108336791396141
1.036374220624566


Epoch 1: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.846]


Training Done
Validation Done
0.9634148059412837
1.1744936304166913


Epoch 2: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.83] 


Training Done


Epoch 3:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.889397744089365
1.1979550085961819


Epoch 3: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.863]


Training Done
Validation Done
0.854425179772079
1.3544434355571866


Epoch 4: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.793]


Training Done


Epoch 5:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.8233824362978339
1.3481671819463372


Epoch 5: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.664]


Training Done
Validation Done
0.8020014967769384
1.4288484575226903


Epoch 6: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.738]


Training Done


Epoch 7:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.8002567151561379
1.377670873887837


Epoch 7: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.695]


Training Done
Validation Done
0.7768257362768054
1.4718863684684038


Epoch 8: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.751]


Training Done


Epoch 9:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.7499118773266673
1.465862319804728


Epoch 9: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.626]


Training Done
Validation Done
0.7633857745677233
1.485161941498518


Epoch 10: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.689]


Training Done


Epoch 11:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.7376043703407049
1.5139679834246635


Epoch 11: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.707]


Training Done
Validation Done
0.7247852901928127
1.535516919568181


Epoch 12: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.745]


Training Done


Epoch 13:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.7103573963977396
1.5018460499122739


Epoch 13: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.769]


Training Done
Validation Done
0.7050458891317248
1.6488411212339997


Epoch 14: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.712]


Training Done


Epoch 15:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.6913357335142791
1.7039190046489239


Epoch 15: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.599]


Training Done
Validation Done
0.6543138981796801
1.7093359855934978


Epoch 16: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.626]


Training Done


Epoch 17:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.648634220007807
1.7811280079185963


Epoch 17: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.588]


Training Done
Validation Done
0.6494343387894332
1.7404990419745445


Epoch 18: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.67] 


Training Done


Epoch 19:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.643398605287075
1.7643175311386585


Epoch 19: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.562]


Training Done
Validation Done
0.6077534193173051
1.7935045715421438


Epoch 20: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.542]


Training Done


Epoch 21:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.5899172690697014
1.8838801849633455


Epoch 21: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.506]


Training Done
Validation Done
0.5783139234408736
1.8905819449573755


Epoch 22: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.574]


Training Done


Epoch 23:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.5900816265493631
1.9312606863677502


Epoch 23: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.56] 


Training Done
Validation Done
0.557763752527535
1.8800755646079779


Epoch 24: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.533]


Training Done


Epoch 25:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.5196823431178927
1.8952613435685635


Epoch 25: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.629]


Training Done
Validation Done
0.5553075629286468
1.853947477415204


Epoch 26: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.547]


Training Done


Epoch 27:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.5736003369092941
1.845513829961419


Epoch 27: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.602]


Training Done
Validation Done
0.530089573469013
1.8458305764943361


Epoch 28: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.561]


Training Done


Epoch 29:   0%|          | 0/64 [00:00<?, ?batch/s]

Validation Done
0.5129474471323192
1.9471442755311728


Epoch 29: 100%|██████████| 64/64 [00:44<00:00,  1.44batch/s, loss=0.543]


Training Done
Validation Done
0.5280507509596646
1.9370951037853956
