In [None]:
import pandas as pd
import os
import albumentations as alb
import tensorboard
#import tensorflow as tf
import datetime
import torch
import numpy as np

from model import UNet
from utils import plot, get_data_loaders, evaluate, get_dice_score
#from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
from albumentations.pytorch.transforms import ToTensorV2

In [None]:
# Define constants

# Directory layout (paths are relative to the notebook's working directory).
data_dir = 'data'                  # dataset root -- change to the directory containing the data
trained_models = 'trained_models'  # destination for saved model weights
train_dir = 'train'                # NOTE(review): not referenced in the cells shown here
log_dir = 'runs'                   # parent directory for TensorBoard runs

In [None]:
# Training configuration (hyperparameters)

## Data
test_split = .2  # fraction of the dataset held out for testing (20%)
# validation_split = .2  # not used yet: only a train/test split is made

# Fixed seed so shuffling/splitting and weight initialisation are reproducible.
# np.random.seed() returns None, so the seed must be an explicit value that is
# both applied here and handed on to get_data_loaders.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)
shuffle_dataset = True

# When called as transform(image=..., mask=...), albumentations treats `image` and
# `mask` as built-in targets, so RandomCrop/HorizontalFlip are applied identically
# to both. No `additional_targets` mapping is needed for that.
# NOTE(review): get_data_loaders receives this single transform, so the random
# crop/flip presumably also apply to the test split -- TODO confirm that is intended.
# NOTE(review): ToTensorV2 neither rescales nor changes dtype; if the images are
# uint8 they need a float conversion (e.g. alb.Normalize()) -- TODO confirm.
transform = alb.Compose([
    alb.RandomCrop(width=256, height=256),
    alb.HorizontalFlip(p=0.5),
    ToTensorV2(),
])

## model architecture
in_channels = 3     # number of input image channels (3, e.g. RGB)
min_channels = 16   # presumably the UNet width at the first level -- see model.UNet
max_channels = 256  # presumably the cap on UNet width at the deepest level -- see model.UNet
num_classes = 5     # number of segmentation classes (presumably UNet output channels)

## Training
learning_rate = 0.1  # NOTE(review): high for plain SGD on a UNet; lower it if the loss diverges
batch_size = 16
epochs = 50

In [None]:
# Training

# Set up the training environment: data loaders, model, optimizer, loss, logger.

# init data loaders/generators
train_dataloader, test_dataloader = get_data_loaders(data_dir, transform, shuffle_dataset, test_split, random_seed, batch_size)

# init model, optimizer and loss function
model = UNet(in_channels, min_channels, max_channels, num_classes)
opt = torch.optim.SGD(model.parameters(), learning_rate)
# The loss must be a differentiable function of the raw model output.
# - argmax has no gradient.
# - .detach().numpy() cuts the autograd graph.
# - A Dice *score* is something to maximize, not minimize.
# CrossEntropyLoss takes per-class logits (N, C, H, W) and integer class targets (N, H, W).
# NOTE(review): assumes the masks from get_data_loaders hold class indices in
# [0, num_classes) with shape (N, H, W) -- TODO confirm against utils.get_data_loaders.
loss_func = torch.nn.CrossEntropyLoss()

# Set up summary writer for tensorboard
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(os.path.join(log_dir, current_time))

# start training
global_step = 0  # counts optimizer steps so each batch gets its own point in TensorBoard
for epoch in range(epochs):
    model.train()
    for x, y in train_dataloader:
        # ToTensorV2 keeps the source dtype; conv layers need float input.
        logits = model(x.float())
        # assumes UNet returns raw per-class logits of shape (N, num_classes, H, W) -- TODO confirm
        loss = loss_func(logits, y.long())

        opt.zero_grad()
        loss.backward()
        opt.step()

        writer.add_scalar('Loss', loss.item(), global_step)
        global_step += 1

    # Evaluate model; no_grad avoids building autograd graphs during evaluation
    model.eval()
    with torch.no_grad():
        evaluate(model, writer, test_dataloader, epoch)

# Save model after training. Create the output directory first so a long run
# does not fail at its very last step.
os.makedirs(trained_models, exist_ok=True)
torch.save(model.state_dict(), os.path.join(trained_models, current_time))
writer.close()  # flush pending TensorBoard events

In [None]:
# Model evaluation

# A state_dict only stores weights, so the model must be rebuilt with the same
# architecture hyperparameters it was trained with; UNet(0, 0, 0, 0) cannot match.
# Defaults to the weights written by the training cell; point model_path at another
# file under trained_models to evaluate an earlier run.
model_path = os.path.join(trained_models, current_time)
model = UNet(in_channels, min_channels, max_channels, num_classes)
# map_location lets weights saved on a GPU machine load on a CPU-only one.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()