The dataset is the Kaggle "Cityscapes Image Pairs" dataset, taken from https://www.kaggle.com/datasets/dansbecker/cityscapes-image-pairs?select=cityscapes_data

In [None]:
import os

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from func import CustomData, pixel_accuracy, dsc_score, CrossEntropyWithDiceLoss, plot_learning_curve, mask2image, plot_masks
from models import FCN, Unet

In [None]:
# Build the training and validation datasets from the 'train' and 'val' splits.
dataset_train, dataset_valid = CustomData('train'), CustomData('val')

# Report how many samples each split contains.
print(len(dataset_train), len(dataset_valid))

In [None]:
batch_size = 32

# Only the training loader is shuffled; validation batches keep dataset order.
data_train = DataLoader(
    dataset=dataset_train,
    batch_size=batch_size,
    shuffle=True,
)
data_valid = DataLoader(
    dataset=dataset_valid,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
def train(model, data_train, lr, n_epochs, device, seed=None):
    """Train ``model`` and evaluate it on the validation loader after each epoch.

    The model is optimised with Adam and ``CrossEntropyWithDiceLoss``. After
    every epoch the mean (over batches) loss, pixel accuracy and DSC score
    are recorded for both the training and the validation data. Every 10th
    epoch the model's ``state_dict`` is written to ``backup_<epoch>.tar`` in
    the current working directory.

    NOTE: the validation loader is not a parameter; it is read from the
    module-level name ``data_valid``, which must be defined before calling.

    Args:
        model: Network to train; it is moved to ``device``.
        data_train: Iterable of ``(x, y)`` training batches that supports
            ``len()`` (presumably a ``DataLoader`` -- confirm with caller).
        lr: Learning rate passed to Adam.
        n_epochs: Number of passes over ``data_train``.
        device: Device the model and the batches are moved to.
        seed: Optional seed for the torch CPU and CUDA random generators.
            It is applied inside this function, so it cannot influence
            anything created before the call (e.g. the model's weights).

    Returns:
        Tuple ``(model, history_loss, history_accuracy, history_dsc)``. Each
        history is a dict of per-epoch lists keyed ``train_<metric>`` /
        ``valid_<metric>``.
    """
    if seed is not None:
        # Set seed for reproducibility
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    model = model.to(device)

    optimizer = Adam(params=model.parameters(), lr=lr)
    loss_func = CrossEntropyWithDiceLoss()

    # Per-epoch metric history. Each pair needs two separate lists:
    # unpacking a single ``[]`` into two names raises ValueError.
    train_loss_lst, valid_loss_lst = [], []
    train_accuracy_lst, valid_accuracy_lst = [], []
    train_dsc_lst, valid_dsc_lst = [], []

    for epoch in range(n_epochs):
        # ---- Training pass ----
        train_loss, train_acc, train_dsc = 0, 0, 0
        model.train()
        for x_train, y_train in data_train:

            x_train = x_train.to(device)
            y_train = y_train.to(device)

            predict = model(x_train)
            loss = loss_func(predict, y_train.long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += pixel_accuracy(predict, y_train)
            train_dsc += dsc_score(predict, y_train.long())

        # Take mean of metrics by all batches, add to list for history
        train_loss /= len(data_train)
        train_loss_lst.append(train_loss)
        train_acc /= len(data_train)
        train_accuracy_lst.append(train_acc)
        train_dsc /= len(data_train)
        train_dsc_lst.append(train_dsc)

        # ---- Validation pass (no gradient tracking, eval mode) ----
        valid_loss, valid_acc, valid_dsc = 0, 0, 0
        model.eval()
        with torch.no_grad():
            # ``data_valid`` is the module-level validation loader.
            for x_valid, y_valid in data_valid:
                x_valid, y_valid = x_valid.to(device), y_valid.to(device)
                predict = model(x_valid)
                valid_loss += loss_func(predict, y_valid.long()).item()
                valid_acc += pixel_accuracy(predict, y_valid)
                valid_dsc += dsc_score(predict, y_valid.long())

        # Take mean of metrics by all batches, add to list for history
        valid_loss /= len(data_valid)
        valid_loss_lst.append(valid_loss)
        valid_acc /= len(data_valid)
        valid_accuracy_lst.append(valid_acc)
        valid_dsc /= len(data_valid)
        valid_dsc_lst.append(valid_dsc)

        # Periodic checkpoint so a long run can be recovered.
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f'backup_{epoch+1}.tar')

        print(f'Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.4f}| Valid Loss: {valid_loss:.4f} | Valid Accuracy: {valid_acc:.4f}')

    history_loss = dict(
        train_loss=train_loss_lst,
        valid_loss=valid_loss_lst,
    )

    history_accuracy = dict(
        train_accuracy=train_accuracy_lst,
        valid_accuracy=valid_accuracy_lst,
    )

    history_dsc = dict(
        train_dsc=train_dsc_lst,
        valid_dsc=valid_dsc_lst,
    )

    return model, history_loss, history_accuracy, history_dsc

In [None]:
# Seed *before* constructing the model: ``train`` only seeds after it has
# received the already-built network, so without this the initial weights
# would not be covered by the seed.
seed = 0
torch.manual_seed(seed)

# Train an FCN for 100 epochs with learning rate 0.01.
# NOTE(review): the meaning of the 0.1 argument is defined in models.FCN -- confirm there.
net, history_loss, history_accuracy, history_dsc = train(
    FCN(0.1),
    data_train,
    lr=0.01,
    n_epochs=100,
    device=device,
    seed=seed,
)

In [None]:
# Save model

model_fold = 'models'
model_name = type(net).__name__  # class name of the trained network, reused in file names
filename = f'{model_name}.tar'

# Make sure the target folder exists; saving fails if it is missing.
os.makedirs(model_fold, exist_ok=True)
torch.save(net.state_dict(), os.path.join(model_fold, filename))

In [None]:
plot_fold = 'plots'

# Make sure the target folder exists; savefig fails if it is missing.
os.makedirs(plot_fold, exist_ok=True)

# One learning-curve figure per metric history (loss, pixel accuracy, DSC).
fig = plot_learning_curve(history_loss)
fig.savefig(os.path.join(plot_fold, f'{model_name}_learning_curve_loss.png'))

fig = plot_learning_curve(history_accuracy)
fig.savefig(os.path.join(plot_fold, f'{model_name}_learning_curve_accuracy.png'))

fig = plot_learning_curve(history_dsc)
fig.savefig(os.path.join(plot_fold, f'{model_name}_learning_curve_dsc.png'))

In [None]:
# Plot several images from valid data

device = 'cpu'

net = net.to(device)
net.eval()  # make sure the network is in evaluation mode for inference

# Gradients are not needed for visualisation, so skip autograd bookkeeping.
with torch.no_grad():
    for idx in range(5):
        img, target_mask = dataset_valid[idx]  # Take image from dataset
        # Add a batch dimension for the forward pass, then drop it again.
        predict = net(img.unsqueeze(0).to(device)).squeeze()  # Make prediction
        # Index of the highest score along dim 0
        # (assumes dim 0 is the class dimension -- TODO confirm against the model).
        predict_mask = predict.argmax(dim=0)  # Extract prediction mask

        # Construct images from input and masks
        img_in = img.permute(1, 2, 0)
        img_target = mask2image(target_mask.long())
        img_predict = mask2image(predict_mask)

        # Plot images
        fig = plot_masks(img_in, img_target, img_predict)

# Only the figure of the last plotted sample is written to disk.
# (The former bare ``fig.savefig()`` call had no file name and raised
# before this line could run.)
fig.savefig(os.path.join(plot_fold, f'{model_name}_result_example.png'))