In [None]:
import torch
import torch.nn as nn
from networks import UNet

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import os
import re
from utils import *

# Patch id = first run of digits in each filename.
# - Files without any digits (e.g. hidden/OS metadata files) are skipped instead
#   of raising IndexError.
# - Ids are deduplicated so several files sharing an id cannot place the same
#   patch in both the train and the test split.
# - Ids are sorted because os.listdir order is arbitrary and filesystem-
#   dependent; train_test_split is order-sensitive, so without sorting the
#   seeded split below would differ between machines.
_id_matches = (re.search(r"[0-9]+", name) for name in os.listdir("./segmentation/patches"))
ids = sorted({match.group(0) for match in _id_matches if match is not None})

# Hold out 20 patches for testing; `rest` is split into train/validation later.
rest, test = train_test_split(ids, test_size=20, random_state=42)
test_dataset = SegmentationDataset(test)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=12)

In [None]:
learning_rate = 1e-4
batch_size = 10
train_loss_fn = nn.CrossEntropyLoss()
# Summed (not averaged) so the epoch total can be divided by the exact number
# of per-element loss terms seen, independent of how batches are sized.
valid_loss_fn = nn.CrossEntropyLoss(reduction="sum")
plateau_window = 5  # epochs without validation-loss improvement before stopping
n_runs = 10


losses = [ [] for _ in range(n_runs) ]
dice_scores = [ [] for _ in range(n_runs) ]
test_scores = []

# The split uses a fixed random_state, so it is identical for every run; build
# it (and its loaders) once instead of on every iteration of the run loop.
# NOTE(review): runs therefore differ only in weight init and shuffle order.
# If a different split per run was intended, use random_state=n in the loop.
train, validation = train_test_split(rest, test_size=20, random_state=42)
train_dataset = SegmentationDataset(train, 1, 1)
validation_dataset = SegmentationDataset(validation, 1, 1)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, num_workers=12)

for n in range(n_runs):
    # Distinct seed per run: results are reproducible, yet runs still differ.
    torch.manual_seed(n)
    network = UNet(1, 2, [16, 32, 64, 128, 256, 512]).to(device)
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

    best_loss = float("inf")
    best_state = None
    since = 0  # epochs since the last validation-loss improvement

    while since < plateau_window:
        # Training
        network.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            x_hat = network(X)
            loss = train_loss_fn(x_hat, y)
            loss.backward()
            optimizer.step()

        # Validating. eval() puts layers such as BatchNorm/Dropout (if the UNet
        # contains any) into inference mode so validation neither uses batch
        # statistics nor updates running statistics with validation data.
        network.eval()
        with torch.no_grad():
            val_loss = 0.0
            n_terms = 0
            dice_score_ = 0
            for X, y in validation_loader:
                X, y = X.to(device), y.to(device)
                x_hat = network(X)
                val_loss += valid_loss_fn(x_hat, y).item()
                # CrossEntropyLoss produces one term per element of every dim
                # except the class dim (1), e.g. batch * H * W for 2-D logits.
                n_terms += x_hat.numel() // x_hat.shape[1]
                dice_score_ += dice_score(x_hat.cpu(), y.cpu(), reduction="sum")
        val_loss = val_loss / n_terms
        # float() also unwraps a 0-d tensor, in case dice_score returns one.
        dice_score_ = float(dice_score_) / len(validation_dataset)
        dice_scores[n].append(dice_score_)
        losses[n].append(val_loss)

        since += 1
        if val_loss < best_loss:
            best_loss = val_loss
            since = 0
            # Snapshot the best weights so testing uses them rather than the
            # weights from plateau_window epochs past the optimum.
            best_state = {k: v.detach().clone() for k, v in network.state_dict().items()}
        print("Loss:", val_loss, "Dice score:", dice_score_, "N:", n)

    # Testing with the best validation checkpoint (None only if every
    # validation loss was NaN; then the final weights are used).
    if best_state is not None:
        network.load_state_dict(best_state)
    network.eval()
    with torch.no_grad():
        dice_score_ = 0
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            x_hat = network(X)
            dice_score_ += dice_score(x_hat.cpu(), y.cpu(), reduction="sum")
        dice_score_ = float(dice_score_) / len(test_dataset)
        test_scores.append(dice_score_)
        print(dice_score_)
    print("Run ended")

In [None]:
postfix = "base"


def _write_runs(path, runs):
    """Write each run as a "model <i>" header line followed by one value per line."""
    with open(path, "w") as out_file:
        for run_index, run_values in enumerate(runs):
            out_file.write("model {}\n".format(run_index))
            out_file.writelines(str(value) + "\n" for value in run_values)


# Per-run validation curves (one block per model).
_write_runs("loss_val_" + postfix, losses)
_write_runs("dice_val_" + postfix, dice_scores)

# Final test Dice score of each run, one per line, no headers.
with open("dice_test_" + postfix, "w") as out_file:
    out_file.writelines(str(score) + "\n" for score in test_scores)

In [None]:
import matplotlib.pyplot as plt 


# Validation Dice score per epoch, one line per training run.
fig, ax = plt.subplots()
for run_index, run_scores in enumerate(dice_scores):
    ax.plot(run_scores, label="run {}".format(run_index))

ax.set(title="Validation Dice score per epoch", xlabel="Epoch", ylabel="Dice score")
ax.grid()
ax.legend()
plt.show()