# P4_neural-odes-segmentation-master_CRAG

## Google Colab

In [None]:
# Mount Google Drive into the Colab runtime so that the project directory on
# Drive becomes reachable below /content/drive (the next cell cd's into it).
from google.colab import drive

drive.mount("/content/drive")

In [None]:
cd drive/MyDrive/BA_SemanticSegmentation_JonasHeinke/___P4_neural-odes-segmentation-master_CRAG/

In [None]:
ls

In [None]:
# Capture the output of `nvidia-smi` via IPython shell capture (a list of
# output lines) and join it into a single string.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# If the output contains 'failed', presumably no GPU is attached to this
# runtime: tell the user how to enable one; otherwise show the GPU details.
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

### Installation des Moduls torchdiffeq

In [None]:
!pip install git+https://github.com/rtqichen/torchdiffeq

## Change these flags to train a specific model

In [None]:
# Select the architecture to train; exactly one flag is expected to be True.
TRAIN_RESNET = False
TRAIN_UNODE = False
TRAIN_UNET = True


def get_title():
    """Return the display name of the model selected by the TRAIN_* flags.

    The flags are read at call time and checked in the order U-NODE, RESNET,
    UNET; the first one that is set wins. Returns ``None`` if no flag is set.
    """
    candidates = (
        (TRAIN_UNODE, 'U-NODE'),
        (TRAIN_RESNET, 'RESNET'),
        (TRAIN_UNET, 'UNET'),
    )
    for is_selected, title in candidates:
        if is_selected:
            return title
    return None

---

In [None]:
import os
import glob
import random

import torch
import torch.utils.data

import PIL
import skimage.measure
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm, tqdm_notebook

%matplotlib inline

from models import ConvODEUNet, ConvResUNet, ODEBlock, Unet
#from dataloader import GLaSDataLoader
from dataloader_crag import Crag_DataLoader
from train_utils import plot_losses

from IPython.display import clear_output



## Datenset laden
- Dieser Teil weicht vom Original ab, da er an das zu ladende Datenset CRAG_v2 angepasst wurde.

#### Download the dataset and load its filenames

MILD-Net: "Colorectal Adenocarcinoma Gland (CRAG) Dataset"

https://warwick.ac.uk/services/its/intranet/projects/webdev/sandbox/juliemoreton/research-copy/tia/data/mildnet

Datenset zum Herunterladen:

https://drive.google.com/u/0/uc?id=1p3dZXpgeA1IcGO6vXhStbVLMku-fZTmQ&export=download

In [None]:
# Remind the user to place the CRAG_v2 dataset in the project directory
# when it cannot be found there.
dataset_missing = not os.path.exists('CRAG_v2')
if dataset_missing:
    for hint in ('Bitte laden sie das Datenset in das Projektverzeichnis!',
                 'Das Verzeichnis lautet "CRAG_v2".'):
        print(hint)

In [None]:
from jh_path import Path as PATH   # project helper: dataset paths and file names
path = PATH()  # get_filenames is called on an instance

path_images = PATH.dataset / 'train/Images/'
path_targets = PATH.dataset / 'train/Annotation/'

# Input images and their target annotations.
filenames_inputs = path.get_filenames(path=path_images, dateifilter='*.png')
filenames_targets = path.get_filenames(path=path_targets, dateifilter='*.png')

print('Anzahl der Bilder      : ', len(filenames_inputs))
print('Anzahl der Annotationen: ', len(filenames_targets))

# The datasets below receive both lists together with the same index tensor,
# so images and annotations are presumably paired by list position. Actually
# enforce equal lengths instead of only printing them: a mismatch would pair
# images with the wrong masks.
if len(filenames_inputs) != len(filenames_targets):
    raise ValueError(
        f'{len(filenames_inputs)} images but {len(filenames_targets)} '
        f'annotations found; every image needs exactly one annotation.'
    )

## Define datasets

In [None]:

# Reproducible random split: 10 images for validation, the rest for training.
torch.manual_seed(0)

n_images = len(filenames_inputs)

# randperm draws WITHOUT replacement, so the validation set holds distinct
# images. LongTensor(10).random_() drew with replacement and could repeat an
# index, silently shrinking the validation set and duplicating images in it.
# Note: this selects different images than the old sampling did.
val_set_idx = torch.randperm(n_images)[:10]
train_set_idx = torch.arange(0, n_images)

# Remove every validation index from the training indices.
overlapping = (train_set_idx[..., None] == val_set_idx).any(-1)
train_set_idx = torch.masked_select(train_set_idx, ~overlapping)


## Größe der Eingangsbild-Masken-Paare vereinbaren
- Verhältnis von Höhe zu Breite: h/w = 1

In [None]:

# Size (height, width) passed to the datasets for the image/mask pairs.
# Original CRAG images are about 1516 x 1509; earlier tries used 1024, 758/754, 752.
h = 512
w = 512

trainset = Crag_DataLoader(
    filenames_inputs, filenames_targets, (h, w),
    dataset_repeat=1, images=train_set_idx,
)
valset = Crag_DataLoader(
    filenames_inputs, filenames_targets, (h, w),
    dataset_repeat=1, images=val_set_idx, validation=True,
)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=1, shuffle=True, num_workers=4,
)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=1, shuffle=False, num_workers=4,
)


# Plotting train data

In [None]:
fig, ax = plt.subplots(nrows=5, ncols=6, figsize=(24, 15))

# Rows: training samples 0..4. Each sample is fetched three times and drawn as
# (image, channel 0 of the target) pairs; repeated fetches presumably show
# different random augmentations.
for row in range(5):
    for rep in range(3):
        sample = trainset[row]
        image, target = sample[0], sample[1]
        image_ax, target_ax = ax[row, 2 * rep], ax[row, 2 * rep + 1]
        image_ax.imshow(image.numpy().transpose(1, 2, 0))  # CHW -> HWC
        target_ax.imshow(target[0])
        image_ax.axis('off')
        target_ax.axis('off')

plt.show();

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(24, 15))

# First training sample: image, target channel 0, and the sum over all target channels.
sample = trainset[0]
image, target = sample[0], sample[1]
ax[1].imshow(target[0].numpy())
ax[2].imshow(target.sum(dim=0))
ax[0].imshow(image.numpy().transpose(1, 2, 0))  # CHW -> HWC

# Plotting validation data

In [None]:
fig, ax = plt.subplots(nrows=5, ncols=6, figsize=(24, 15))

# Rows: validation samples 0..4, each drawn three times as
# (image, channel 1 of the target) pairs.
for row in range(5):
    for rep in range(3):
        sample = valset[row]
        image, target = sample[0], sample[1]
        image_ax, target_ax = ax[row, 2 * rep], ax[row, 2 * rep + 1]
        image_ax.imshow(image.numpy().transpose(1, 2, 0))  # CHW -> HWC
        target_ax.imshow(target[1])
        image_ax.axis('off')
        target_ax.axis('off')

plt.show();

# Define network

In [None]:
device = torch.device('cuda')

# ODE-based U-Net with 2 output channels (Module.to moves it in place and returns it).
if TRAIN_UNODE:
    net = ConvODEUNet(
        num_filters=16,
        output_dim=2,
        time_dependent=True,
        non_linearity='lrelu',
        adjoint=True,
        tol=1e-3,
    ).to(device)

In [None]:
# Residual U-Net baseline with 2 output channels.
if TRAIN_RESNET:
    net = ConvResUNet(
        num_filters=16,
        output_dim=2,
        non_linearity='lrelu',
    ).to(device)

In [None]:
# Plain U-Net (depth 5, 64 base filters, 2 output channels) placed on the GPU.
if TRAIN_UNET:
    net = Unet(depth=5, num_filters=64, output_dim=2)
    net.cuda()
    net.to(device)

---

In [None]:
# Kaiming-normal initialisation of every 2-D convolution in the network;
# convolution biases are zeroed.
for m in net.modules():
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight)
        # Conv2d layers built with bias=False have m.bias = None, and
        # constant_(None, 0) would raise, so only touch existing biases.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

In [None]:
def count_parameters(model):
    """Return the number of trainable scalar parameters of ``model``.

    Only parameters with ``requires_grad=True`` are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Number of trainable parameters of the selected network (shown as cell output).
count_parameters(net)

# Train model

In [None]:
# Default losses: binary cross-entropy on raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
val_criterion = torch.nn.BCEWithLogitsLoss()

if TRAIN_UNET:
    cross_entropy = torch.nn.BCEWithLogitsLoss()

    def criterion(conf, labels):
        """BCE-with-logits loss after centre-cropping ``labels`` to ``conf``.

        The crop presumably accounts for the UNet producing a spatially
        smaller output than its input (inference uses ``shouldpad`` for the
        UNet). Dimensions 2 and 3 are treated as height and width.
        """
        out_h, out_w = conf.shape[2:4]
        label_h, label_w = labels.shape[2:4]

        # Offsets that centre the output window inside the label map. The row
        # offset must come from the heights and the column offset from the
        # widths; previously both were computed from the widths, which is only
        # correct for square inputs.
        top = (label_h - out_h) // 2
        left = (label_w - out_w) // 2

        return cross_entropy(conf, labels[:, :, top:top + out_h, left:left + out_w])

In [None]:
# Adam over all network parameters; run() overrides the learning rate.
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-4)

In [None]:
# Let cuDNN benchmark the available convolution algorithms and reuse the
# fastest; this pays off when input shapes stay the same between iterations.
torch.backends.cudnn.benchmark = True

In [None]:
# Per-epoch mean training / validation losses, appended to by run().
losses, val_losses = [], []
# Nine separate lists for U-NODE (passed to plot_losses as nfe), otherwise None.
nfe = [[] for _ in range(9)] if TRAIN_UNODE else None

In [None]:
accumulate_batch = 8  # batches whose gradients are summed per optimizer step (effective mini-batch size)
accumulated = 0

# Checkpoint file for the best model of the selected architecture.
# NOTE(review): the priority here is RESNET > UNODE > UNET whereas get_title()
# uses UNODE > RESNET > UNET; they only agree when exactly one flag is set.
if TRAIN_RESNET: filename = 'best_border_resnet_model.pt'
elif TRAIN_UNODE: filename = 'best_border_unode_model.pt'
elif TRAIN_UNET: filename = 'best_border_unet_model.pt'


def run(lr=1e-3, epochs=100):
    """Train the global ``net`` for ``epochs`` epochs at learning rate ``lr``.

    Each epoch makes one pass over ``trainloader`` with gradient accumulation
    (an optimizer step every ``accumulate_batch`` batches), then one pass over
    ``valloader``. Mean losses are appended to the global ``losses`` and
    ``val_losses`` lists. The whole model is saved to ``filename`` (and a
    line appended to ``_protokoll.txt``) whenever the epoch's mean validation
    loss is the lowest so far and below 0.4.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    for _ in range(epochs):
        # --- training pass with gradient accumulation ---
        net.train()
        running_loss = 0.0
        accumulated = 0
        optimizer.zero_grad()
        for data in tqdm(trainloader):
            inputs, labels = data[0].cuda(), data[1].cuda()
            outputs = net(inputs)
            # Scale so the summed gradient over accumulate_batch batches is their mean.
            loss = criterion(outputs, labels) / accumulate_batch
            loss.backward()
            accumulated += 1
            if accumulated == accumulate_batch:
                optimizer.step()
                optimizer.zero_grad()
                accumulated = 0

            # Undo the scaling so the logged value is the per-batch loss.
            running_loss += loss.item() * accumulate_batch

        # Apply the gradients of a trailing, incomplete accumulation window.
        # Previously they were wiped by the next epoch's zero_grad() while the
        # counter carried over, so the next epoch's first step saw too few batches.
        if accumulated:
            optimizer.step()
            optimizer.zero_grad()

        losses.append(running_loss / len(trainloader))

        # --- validation pass ---
        net.eval()
        with torch.no_grad():
            running_loss = 0.0
            for data in valloader:
                inputs, labels = data[0].cuda(), data[1].cuda()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item()

            val_loss = running_loss / len(valloader)
            val_losses.append(val_loss)

            # Save when this epoch's MEAN validation loss is the best so far and
            # below 0.4 (the threshold used to be checked against the loss of
            # the last validation batch only).
            if np.argmin(val_losses) == len(val_losses) - 1 and val_loss < 0.4:
                torch.save(net, filename)
                # len(losses) = total epochs trained so far, also across
                # repeated run() calls (the loop counter restarts at 0).
                with open('_protokoll.txt', 'a') as protokolldatei:
                    protokolldatei.write('---------------------------------------------\n')
                    protokolldatei.write(f'Speicherung des Modells nach {len(losses)} Epochen, loss: {val_loss}\n')

            clear_output(wait=True)
            plot_losses(inputs, outputs, losses, val_losses, get_title(), nfe, net=net)

In [None]:
# ODE and ResNet variants train at 1e-3, the plain UNet at 1e-4.
lr = 1e-3 if (TRAIN_UNODE or TRAIN_RESNET) else 1e-4

# Train up to 200 epochs in total: losses holds one entry per finished epoch,
# so re-running this cell continues from where training stopped.
run(lr, 200 - len(losses))

## Calculate results

In [None]:
# load best model
# Load the best model, i.e. the whole module pickled by run() via torch.save(net, filename).
# NOTE(review): since PyTorch 2.6 torch.load defaults to weights_only=True, which
# refuses fully pickled modules; on newer PyTorch this likely needs
# weights_only=False (only for checkpoints you trust).
net = torch.load(filename)

In [None]:
# Re-evaluate the reloaded checkpoint on the validation split.
# Evaluation mode, consistent with the validation pass in run(); the module
# may have been pickled in training mode.
net.eval()
with torch.no_grad():
    running_loss = 0.0
    for data in tqdm(valloader):
        inputs, labels = data[0].cuda(), data[1].cuda()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

    print("Check validation loss:", running_loss / len(valloader))

## Visualize results on validation set

In [None]:
#+ from inference_utils import inference_image, postprocess
from inference_utils_crag import inference_image, postprocess #+
import numpy as np              #+
import matplotlib.pyplot as plt #*

In [None]:
# Predictions of the selected model for the first five validation images:
# column 0 = input image, 1 = binarised ground truth, 2 = prediction overlay.
fig, ax = plt.subplots(nrows=5, ncols=3, figsize=(4*3, 3*5))

ax[0, 0].set_title('Image')
ax[0, 1].set_title('Ground-truth')
ax[0, 2].set_title(get_title())

net.eval()  # inference mode; harmless if inference_image already sets it

for row in range(5):
    # val_set_idx holds positions in filenames_inputs / filenames_targets (the
    # same indices the validation dataset was built from), not the number in
    # the file name. Building 'train_{index}.png' from the position opened a
    # different, possibly non-existent, file, so look the files up by position.
    index = int(val_set_idx[row])
    print(f'idx: {index}, ', end='')
    image = PIL.Image.open(filenames_inputs[index])
    gt = PIL.Image.open(filenames_targets[index])

    # Run inference once per image (it used to be repeated for every column).
    with torch.no_grad():
        result, input_image = inference_image(net, image, shouldpad=TRAIN_UNET)
        result = postprocess(result, gt)

    ax[row, 0].imshow(image)
    ax[row, 1].imshow(np.array(gt) > 0)
    ax[row, 2].imshow(image)
    ax[row, 2].imshow(result, alpha=0.5)
    for col in range(3):
        ax[row, col].set_axis_off()

plt.show();

# Calculate metrics on test set

In [None]:
# Evaluation helpers: object-level metrics and CRAG-specific inference utilities.
from metrics import ObjectDice, ObjectHausdorff, F1score
import torch
import numpy as np
import PIL
import skimage.measure
from tqdm import tqdm, tqdm_notebook
from inference_utils_crag import inference_image, postprocess
import matplotlib.pyplot as plt
########################################################
# NOTE(review): TRANSFORM is not used in the visible part of this notebook.
from img_array_transform_jh import ArrayTransform as TRANSFORM
# NOTE(review): the training part imports the path helper from `jh_path`;
# confirm that `_path` is the intended module here.
from _path import Path as PATH   # paths and file names
path=PATH() # an instance is required to call get_filenames

In [None]:
# Select which saved checkpoint the evaluation cells below load and score.
TEST_RESNET = TEST_UNODE = False
TEST_UNET = True

In [None]:
# Load the pickled best model of the selected architecture (first matching flag wins).
# NOTE(review): on PyTorch >= 2.6 loading a fully pickled module likely needs
# weights_only=False.
if TEST_UNODE:
    net = torch.load('best_border_unode_model.pt')
elif TEST_RESNET:
    net = torch.load('best_border_resnet_model.pt')
elif TEST_UNET:
    net = torch.load('best_border_unet_model.pt')

In [None]:
# The CRAG 'valid' split serves as the test set here.
path_testimages = PATH.dataset / 'valid/Images/'
path_testtargets = PATH.dataset / 'valid/Annotation/'

filenames_testinputs, filenames_testtargets = [
    path.get_filenames(path=test_dir, dateifilter='*.png')
    for test_dir in (path_testimages, path_testtargets)
]

for caption, listing in (('Anzahl der Bilder      : ', filenames_testinputs),
                         ('Anzahl der Annotationen: ', filenames_testtargets)):
    print(caption, len(listing))

anzahl_testimages = len(filenames_testinputs)

### Identnummernbezogene Bewertung
Id 0 = Hintergrund; die Drüsen werden je Maske fortlaufend ab Id 1 durchnummeriert (Id = 1, 2, 3, ...).

- Anpassungen, da nur ein Testset vorhanden ist
- Protokoll zum Abspeichern der Ergebnisse

In [None]:

# Object-level evaluation (F1, object Hausdorff, object Dice) of the loaded
# model on the test split. Metrics are averaged over the images that were
# evaluated without error; per-image results are appended to _protokoll.txt.
dice, hausdorff, f1, dice_full = 0, 0, 0, 0

if TEST_UNODE: folder = 'results_unode'
elif TEST_UNET: folder = 'results_unet'
elif TEST_RESNET: folder = 'results_resnet'

# Test files are expected to be named test_1.png ... test_<anzahl_testimages>.png
names = [f'test_{index}.png' for index in range(1, anzahl_testimages + 1)]
i_error = 0  # images whose metric computation failed
anzahl = 0   # images evaluated successfully

net.eval()  # inference mode; harmless if inference_image already sets it

for i, fname in tqdm_notebook(enumerate(names), total=anzahl_testimages):
    image = PIL.Image.open(path_testimages / fname)
    gt = PIL.Image.open(path_testtargets / fname)

    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        result, resized = inference_image(net, image, shouldpad=TEST_UNET)
    result = postprocess(result, gt)

    # Label connected components so every ground-truth gland has its own id.
    gt = skimage.measure.label(np.array(gt))

    # Only the metric computation is guarded; a bare `except:` would also
    # swallow KeyboardInterrupt and hid the cause of the failure.
    try:
        f1_img = F1score(result, gt)
        hausdorff_img = ObjectHausdorff(result, gt)
        dice_img = ObjectDice(result, gt)
    except Exception as err:
        i_error += 1
        print('Error: ', i_error, 'Zyklus: ', i, 'Dateiname: ', fname, repr(err))
        continue

    f1 += f1_img
    hausdorff += hausdorff_img
    dice += dice_img
    print(i, ', ', fname, ' : ', f1_img, hausdorff_img, dice_img)
    anzahl += 1

    with open('_protokoll.txt', 'a') as protokolldatei:
        protokolldatei.write('---------------------------------------------\n')
        protokolldatei.write(f'i: {i}: filename: {fname}, f1_img: {f1_img}, hausdorff: {hausdorff_img}, dice_img: {dice_img} \n')

print('--Mittelwerte, Drüsen- bzw. Identnummernbezogen------------------------------------')
if anzahl:
    print('ObjectDice:', dice / anzahl)
    print('Hausdorff:', hausdorff / anzahl)
    print('F1:', f1 / anzahl)
else:
    # Avoid ZeroDivisionError when no image could be evaluated.
    print('Keine Bilder erfolgreich ausgewertet.')
print('Anzahl io.: ', anzahl)
print('Errors: ', i_error)
