In [1]:
import numpy as np
import pandas as pd 
import torch
from glob import glob
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as T
import utils

import os
import pandas as pd
# Notebook root, captured once when this cell runs. Every training run below
# writes its artefacts under f"{cwd}/results/<run name>/", so the kernel's
# working directory at this point determines where results end up.
cwd = os.getcwd()

In [2]:
def save_results(name, train_loss, val_loss, base_dir=None):
    """Write the train/validation loss histories of one run to disk.

    Files are written to ``<base_dir>/results/<name>/<name>_train.txt`` and
    ``<base_dir>/results/<name>/<name>_val.txt``. Despite the ``.txt``
    extension, the content is produced by ``DataFrame.to_csv`` (comma-separated,
    with pandas' default index column and header).

    Args:
        name: Run name; used both as the sub-directory and as the file prefix.
        train_loss: Training-loss history (anything ``pd.DataFrame`` accepts).
        val_loss: Validation-loss history (anything ``pd.DataFrame`` accepts).
        base_dir: Directory containing ``results/``. Defaults to the
            notebook-level ``cwd``.
    """
    root = cwd if base_dir is None else base_dir
    out_dir = os.path.join(root, "results", name)
    # to_csv does not create parent directories; make sure the run folder exists.
    os.makedirs(out_dir, exist_ok=True)

    pd.DataFrame(train_loss).to_csv(os.path.join(out_dir, f"{name}_train.txt"))
    pd.DataFrame(val_loss).to_csv(os.path.join(out_dir, f"{name}_val.txt"))


In [3]:
import matplotlib.pyplot as plt 

def plot_results(name, train_loss, val_loss, base_dir=None):
    """Plot train and validation loss curves and save the figure.

    The figure is saved to ``<base_dir>/results/<name>/<name>_graph``. There is
    no file extension, so matplotlib uses its default output format
    (``rcParams["savefig.format"]``).

    Args:
        name: Run name; selects the output sub-directory and titles the plot.
        train_loss: Training-loss history, plotted in blue.
        val_loss: Validation-loss history, plotted in red.
        base_dir: Directory containing ``results/``. Defaults to the
            notebook-level ``cwd``.
    """
    root = cwd if base_dir is None else base_dir
    out_dir = os.path.join(root, "results", name)
    # savefig does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(train_loss, color='b', label='train loss')
    ax.plot(val_loss, color='r', label='val_loss')
    ax.set_title(name)
    ax.set_ylabel("loss")
    ax.legend()
    fig.savefig(os.path.join(out_dir, f"{name}_graph"))

In [4]:
from model import UNet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def run(train_data, val_data, name, lr=0.01, epochs=30, show_every=5, seed=None):
    """Train a fresh UNet on one dataset and persist its artefacts.

    Everything is written under ``<cwd>/results/<name>/``:
    ``model.pt`` (the final ``state_dict``), the loss histories (via
    ``save_results``) and the loss plot (via ``plot_results``).
    ``utils.train`` additionally receives ``<cwd>/results/<name>/<name>`` as
    its ``save_path``.

    Args:
        train_data: Training dataset, passed straight to ``utils.train``.
        val_data: Validation dataset, passed straight to ``utils.train``.
        name: Run name; selects the output sub-directory.
        lr: Adam learning rate.
        epochs: Forwarded to ``utils.train`` (previously hard-coded to 30).
        show_every: Forwarded to ``utils.train`` (previously hard-coded to 5).
        seed: If given, seeds torch's global RNG before the model is built so
            torch-based randomness from here on is reproducible.
    """
    save_path = f"{cwd}/results/{name}/"
    # Create the output folder up front: a missing directory would otherwise
    # only surface at torch.save / to_csv time, after the whole training run.
    os.makedirs(save_path, exist_ok=True)

    if seed is not None:
        torch.manual_seed(seed)

    # NOTE(review): the meaning of UNet's positional argument (3) is defined in
    # model.UNet — presumably the channel count; confirm there.
    model = UNet(3).float().to(device)
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # utils.train must return exactly two items (train and validation loss
    # histories): they are unpacked into save_results/plot_results below.
    train_info = utils.train(
        model, train_data, val_data, loss_function, optimizer,
        epochs=epochs, show_every=show_every, save_path=f"{save_path}{name}"
    )

    # Persist the weights first; they are the most expensive artefact to
    # reproduce, so a later failure in saving/plotting must not lose them.
    torch.save(model.state_dict(), f"{save_path}model.pt")
    save_results(name, *train_info)
    plot_results(name, *train_info)

In [5]:
# Square side length (pixels) that every input and label image is resized to.
IMG_SIZE = 256

# One preprocessing pipeline: convert the loaded image to a tensor, then
# resize it to IMG_SIZE x IMG_SIZE.
# NOTE(review): Resize uses torchvision's default interpolation (bilinear), and
# this same pipeline is applied to the label images; if labels encode class
# ids, nearest-neighbour resizing would avoid blended values at boundaries.
transformImage = T.Compose(
    [
        T.ToTensor(),
        T.Resize((IMG_SIZE, IMG_SIZE)),
    ]
)

# Keyword arguments forwarded to utils.AnyDataset: inputs and segmentation
# targets share the same transform.
ds_kwargs = dict(
    transform_orig=transformImage,
    transform_seg=transformImage,
)

### Cityscapes

In [None]:
# Location of the Cityscapes image pairs; set CITYSCAPES_DIR to override the
# author's local default on other machines.
DATA_DIR = os.environ.get(
    "CITYSCAPES_DIR",
    "/home/will/Documents/datasets/cityscapes-image-pairs/cityscapes_data",
)

# glob() returns files in arbitrary, filesystem-dependent order; sort so the
# dataset order (and therefore the run) is reproducible.
train_path = sorted(glob(f'{DATA_DIR}/train/*'))
valid_path = sorted(glob(f'{DATA_DIR}/val/*'))

# Fail fast on a wrong DATA_DIR instead of silently building empty datasets.
assert train_path, f"no training files found under {DATA_DIR}/train"
assert valid_path, f"no validation files found under {DATA_DIR}/val"

train_data = utils.AnyDataset(train_path, **ds_kwargs)
val_data = utils.AnyDataset(valid_path, **ds_kwargs)

run(train_data, val_data, "cityscapes")

### Carla

In [None]:
# Location of the CARLA images/labels; set CARLA_DIR to override the author's
# local default on other machines.
DATA_DIR = os.environ.get("CARLA_DIR", "/home/will/Documents/datasets/carla")

# The two lists are passed to AnyDataset separately and are presumably paired by
# index (images[i] <-> labels[i]). glob() order is arbitrary, so sort both;
# this assumes matching image/label files sort into the same positions
# (e.g. identical basenames) — confirm for the dataset layout.
images = sorted(glob(f'{DATA_DIR}/images/*'))
labels = sorted(glob(f'{DATA_DIR}/labels/*'))
assert images, f"no images found under {DATA_DIR}/images"
assert len(images) == len(labels), (
    f"{len(images)} images but {len(labels)} labels under {DATA_DIR}")

full_dataset = utils.AnyDataset(images, labels, transformImage, transformImage)

# 70/30 train/validation split. A seeded generator makes the split identical
# across kernel restarts, so results are comparable between runs.
SPLIT_SEED = 42
split = int(len(full_dataset) * 0.7)
train_data, val_data = torch.utils.data.random_split(
    full_dataset, [split, len(full_dataset) - split],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

run(train_data, val_data, "carla")