In [1]:
import os

import torch
from torch.utils.data import DataLoader

from create_dataset import data
from nix import NIX
from training import train_model, val_epoch

In [2]:
# --- Device selection ---------------------------------------------------------
# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():  # correct API; `torch.cuda_is_available` does not exist
    device = "cuda"
else:
    device = "cpu"

device = torch.device(device)

# --- Reproducibility ----------------------------------------------------------
# Seed torch's global RNG so weight initialisation and DataLoader shuffling repeat.
SEED = 42
torch.manual_seed(SEED)

# --- Paths --------------------------------------------------------------------
# The data root can be overridden with the MPDL_DATA_DIR environment variable;
# the default is the original author's local checkout.
DATA_DIR = os.environ.get("MPDL_DATA_DIR", "/Users/pauladler/MPDL_Project_2/data")
PATH_TRAIN = os.path.join(DATA_DIR, "train")
PATH_VAL = os.path.join(DATA_DIR, "val")
PATH_TEST = os.path.join(DATA_DIR, "test")
PATH_SAVE = 'nix.pth'  # where the trained state_dict is written

# --- Dataset / loader parameters ----------------------------------------------
num_workers = 4
batch_size = 8

# --- Model parameters ---------------------------------------------------------
img_width, img_height = 512, 512

# --- Training hyperparameters -------------------------------------------------
# NOTE(review): the captured training log reports "lr: 0.001" while this is 1e-4;
# confirm train_model uses this value (or that the saved output is stale).
learning_rate = 0.0001

In [3]:
# Build the train/validation datasets and wrap them in batched loaders.
train_data = data(PATH_TRAIN)
val_data = data(PATH_VAL)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
# Validation only measures loss, so a fixed order suffices and keeps runs comparable.
val_dataloader = DataLoader(val_data, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [4]:
# Instantiate the NIX network for the configured image size and move it to the selected device.
model = NIX(img_width, img_height).to(device)

In [5]:
# Run training; the value returned by train_model is rebound to `model`.
# NOTE(review): the captured output below logs "lr: 0.001" although learning_rate
# is 1e-4 — confirm train_model uses the argument, or re-run to refresh the output.
model = train_model(
    model,
    train_dataloader,
    val_dataloader,
    learning_rate,
    device,
)

[INFO] Training with lr: 0.001
[INFO] Epoch: 1


100%|██████████| 1125/1125 [32:26<00:00,  1.73s/it]


Train loss: 0.255575


100%|██████████| 63/63 [00:47<00:00,  1.34it/s]


Val loss: 0.212649
[INFO] Epoch: 2


  0%|          | 1/1125 [00:05<1:46:32,  5.69s/it]

In [None]:
# Persist only the learned parameters (state_dict) rather than the whole module object.
trained_state = model.state_dict()
torch.save(trained_state, PATH_SAVE)

In [None]:
# Load the held-out test split. The raw dataset gets its own name so it is not
# shadowed; the DataLoader keeps the name `test_data` because the evaluation cell
# passes `test_data` to val_epoch. Order does not matter for evaluation, so no shuffle.
test_dataset = data(PATH_TEST)
test_data = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [None]:
# Evaluate the trained model on the test split, reusing the validation-epoch routine
# (presumably no weight updates happen there — verify in training.py).
test_loss = val_epoch(model, test_data, device)
# This is the *test* loss; the original message mislabeled it as "Val loss".
print("[INFO] Test loss: {:.6f}".format(test_loss))