In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
from os.path import join as pjoin
import collections
import json
import torch
import imageio
import numpy as np
import scipy.misc as m
import scipy.io as io
import matplotlib.pyplot as plt
import glob
import pandas as pd
from PIL import Image
from tqdm import tqdm
from torch.utils import data
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from Dataloader import Trainloader, Testloader
from model import R2UNet
from metrics import evaluate

In [None]:

# HYPERPARAMETERS and run configuration
BATCH_SIZE = 4      # samples per mini-batch, shared by the train and validation loaders
STEP_SIZE = 0.007   # SGD learning rate
EPOCHS = 1          # number of passes over the training loader
SEED = 0            # global torch seed

# Seed torch's global RNG before the model is constructed so weight initialisation
# (and any shuffling that draws from the global generator) is repeatable.
torch.manual_seed(SEED)

# Checkpoint file used both to resume and to save weights. Override it with the
# R2UNET_WEIGHTS environment variable instead of editing an absolute path here.
wgtFile = os.environ.get("R2UNET_WEIGHTS", "/home/s9dxschm/r2unet.pt")

In [None]:
# Build the R2UNet segmentation model with its default constructor arguments.
# If the architecture needs configuring, pass the parameters to the constructor here.
model: nn.Module = R2UNet()

In [None]:
# Training and validation data loaders, both using the configured batch size.
trainloader, testloader = (
    Trainloader(BATCH_SIZE=BATCH_SIZE),
    Testloader(BATCH_SIZE=BATCH_SIZE),
)



In [None]:
# Plain SGD over every parameter of the current `model`, at the configured learning rate.
# Note: this binds to the parameters of the model object that exists *now*.
optimizer = torch.optim.SGD(params=model.parameters(), lr=STEP_SIZE)
# Cross-entropy between the model's class scores and integer class labels.
loss_f = nn.CrossEntropyLoss()

In [None]:
# Resume from a previous checkpoint if one exists.
# Load the weights into the *existing* model rather than creating a new R2UNet.
# The optimizer was built from this model's parameters; rebinding `model` to a
# fresh instance would leave the optimizer stepping parameters that the forward
# pass never uses, so training would silently have no effect.
# map_location="cpu": the model is kept on the CPU here.
# torch.load unpickles the file, so only load checkpoints you produced yourself.
if os.path.isfile(wgtFile):
    model.load_state_dict(torch.load(wgtFile, map_location="cpu"))

model.train()  # training-mode behaviour for dropout / batch-norm layers, if the model has any

for epoch in range(EPOCHS):
    for images_batch, labels_batch in tqdm(trainloader, desc=f"epoch {epoch + 1}/{EPOCHS}"):
        # When converted to a tensor the labels are divided by 255, so multiply back
        # to recover class ids. Round before casting: `.long()` truncates, and a value
        # like 2.9999998 would otherwise become class 2 instead of 3.
        labels_batch = (labels_batch * 255).round().long()

        # The reshape below assumes a full batch, so an incomplete final batch is skipped.
        if labels_batch.shape[0] != BATCH_SIZE:
            continue
        labels_batch = torch.reshape(labels_batch, (BATCH_SIZE, 256, 512))

        optimizer.zero_grad()
        outputs = model(images_batch)
        loss = loss_f(outputs, labels_batch)
        loss.backward()
        optimizer.step()

# Persist the trained weights to the same file they are resumed from.
torch.save(model.state_dict(), wgtFile)

In [None]:
# Test the model on the validation dataset and write each metric to its own text file.
# Set MAX_EVAL_BATCHES to an int to evaluate only the first N batches (quick check);
# None evaluates the whole validation set.
MAX_EVAL_BATCHES = None

model.eval()  # inference-mode behaviour for dropout / batch-norm layers, if the model has any

# Collect per-batch tensors and concatenate once (repeated torch.cat in a loop is quadratic).
# Note: the loop variable is not called `data`, which would shadow `torch.utils.data`.
gt_chunks, pred_chunks = [], []
with torch.no_grad():
    for batch_idx, (images_batch, labels_batch) in enumerate(testloader):
        if MAX_EVAL_BATCHES is not None and batch_idx >= MAX_EVAL_BATCHES:
            break
        # Undo the /255 scaling from tensor conversion. Round so float error cannot
        # leave a label at e.g. 2.9999998, which would never equal predicted class 3.
        gt_chunks.append((labels_batch * 255).round())
        outputs = model(images_batch)
        _, predicted = torch.max(outputs, 1)  # class index with the highest score along dim 1
        pred_chunks.append(predicted)

if not gt_chunks:
    raise RuntimeError("testloader yielded no batches; nothing to evaluate")

ground_truth = torch.cat(gt_chunks, 0)
predictions = torch.cat(pred_chunks, 0)
Acc, SE, SP, F1, Dice = evaluate(ground_truth, predictions)

# One single-value text file per metric; the file names are kept as-is.
metric_files = {
    "Accuracy": Acc,
    "Sensitivity": SE,
    "Specificity": SP,
    "F1 score": F1,
    "Dice coef": Dice,
}
for filename, value in metric_files.items():
    np.savetxt(filename, np.array([value]))

