In [1]:
import sys
sys.path.append('../')

In [2]:
import os
import torch
import segmentation_models_pytorch as smp

from scripts.evaluation import get_test_f1
from scripts.plotting import plot_metric_per_epoch, plot_n_predictions
from scripts.preprocessing import RoadDataset, split_data
from scripts.training import train_model
from torch.utils.data import DataLoader

In [3]:
ROOT_PATH = os.path.normpath(os.getcwd() + os.sep + os.pardir)
train_directory = os.path.join(ROOT_PATH, 'data', 'raw', 'training')

In [4]:
image_path_train, image_path_test, mask_path_train, mask_path_test = split_data(train_directory, 0.1)
# Get train and val dataset instances
train_dataset = RoadDataset(image_path_train, mask_path_train)
test_dataset = RoadDataset(image_path_test, mask_path_test)

In [5]:
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid'

SEED = 13
BATCH_SIZE = 4
K_FOLD = 2
N_CPU = os.cpu_count()
N_EPOCHS = 8

LOADER_PARAMS = {
    'batch_size': BATCH_SIZE, 
    'num_workers': N_CPU, 
    'persistent_workers': True
}

In [6]:
model_ = smp.create_model("FPN", encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS)
criterion_ = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
optimizer_ = torch.optim.Adam(model_.parameters(), lr=0.0005)

train_set, val_set = torch.utils.data.random_split(train_dataset, lengths=[0.9, 0.1])

train_loader = DataLoader(train_dataset, **LOADER_PARAMS)
valid_loader = DataLoader(train_dataset, **LOADER_PARAMS)

In [None]:
train_losses, valid_losses, train_f1s, valid_f1s = train_model(
    model_, (train_loader, valid_loader), criterion_, optimizer_, N_EPOCHS
)

0 Loss: 0.3571 Acc: 0.6252
1 Loss: 0.2364 Acc: 0.7474
2 Loss: 0.1921 Acc: 0.7872
3 Loss: 0.1747 Acc: 0.8093
4 Loss: 0.1571 Acc: 0.8216


In [None]:
plot_metric_per_epoch(train_losses, valid_losses, 'DiceLoss')

In [None]:
plot_metric_per_epoch(train_f1s, valid_f1s, 'f1')

In [None]:
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
plot_n_predictions(model_, test_loader)

In [None]:
'f1 score for the test dataset {:.3f}.'.format(get_test_f1(model_, test_loader))