In [None]:
import hydra
from omegaconf import DictConfig, OmegaConf

import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import partial

from src import data_loader, train, test, plotting, submissions, unet


@hydra.main(version_base=None, config_path=".", config_name="config")
def run_config(cfg: DictConfig) -> dict:
    """Build the model, optimizer, loss and data loaders described by ``cfg``.

    Side effects: prints the resolved config as YAML, seeds torch's global
    RNG with ``cfg.seed`` and sets torch's default dtype to
    ``getattr(torch, cfg.tensor_dtype)``.

    Args:
        cfg: Hydra/OmegaConf config. Reads ``device``, ``seed``,
            ``tensor_dtype``, ``training.lr`` and ``training.weight_decay``,
            and is forwarded whole to ``data_loader.get_loader``.

    Returns:
        dict with keys ``"model"`` (a ``unet.UNet`` moved to ``cfg.device``),
        ``"optimizer"`` (AdamW over the model's parameters), ``"loss_fn"``
        (``torch.nn.functional.binary_cross_entropy``), ``"train_loader"``,
        ``"test_loader"`` and ``"cfg"``.
    """
    print(OmegaConf.to_yaml(cfg))

    # ===== Torch config =====
    device = cfg.device
    # Seed before building loaders/model so anything below that draws from
    # torch's global RNG is reproducible.
    torch.manual_seed(cfg.seed)
    torch.set_default_dtype(getattr(torch, cfg.tensor_dtype))

    # ===== Data Loading =====
    train_loader, test_loader = data_loader.get_loader(cfg)

    # ===== Model, Optimizer and Loss function =====
    model = unet.UNet()
    model = model.to(device=device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
    )
    loss_fn = torch.nn.functional.binary_cross_entropy

    settings = {
        "model": model,
        "optimizer": optimizer,
        "loss_fn": loss_fn,
        "train_loader": train_loader,
        "test_loader": test_loader,
        "cfg": cfg,
    }

    return settings

if __name__ == "__main__":
    # Calling the @hydra.main-wrapped function with no argument makes Hydra
    # parse sys.argv (inside Jupyter: the kernel's own flags) and discard the
    # task function's return value, which left `settings` as None for every
    # later cell. Compose the config explicitly instead; passing it to the
    # wrapper calls run_config(cfg) directly and returns its dict.
    CONFIG_OVERRIDES = []  # Hydra override strings, e.g. ["training.lr=1e-4"]
    with hydra.initialize(version_base=None, config_path="."):
        cfg = hydra.compose(config_name="config", overrides=CONFIG_OVERRIDES)
    settings = run_config(cfg)

In [None]:
def run_training(settings):
    """Train the model stored in ``settings`` via ``train.train_model``.

    Args:
        settings: dict as built by ``run_config``. The entries ``"model"``,
            ``"optimizer"``, ``"loss_fn"``, ``"train_loader"`` and ``"cfg"``
            are forwarded positionally, in exactly that order.

    Returns:
        Whatever ``train.train_model`` returns. The cell below unpacks it into
        ``train_losses`` and ``train_accs``.
    """
    forwarded_keys = ("model", "optimizer", "loss_fn", "train_loader", "cfg")
    positional_args = [settings[key] for key in forwarded_keys]
    return train.train_model(*positional_args)


if __name__ == "__main__":
    (train_losses, train_accs) = run_training(settings)

In [None]:
def run_testing(settings):
    """Evaluate the model stored in ``settings`` on its test loader.

    The stored loss function is wrapped with ``reduction="none"`` so that
    ``test.test_model`` receives unreduced losses rather than a single mean.

    Args:
        settings: dict as built by ``run_config``. Uses ``"model"``,
            ``"cfg"`` (for ``.device``), ``"test_loader"`` and ``"loss_fn"``.

    Returns:
        Whatever ``test.test_model`` returns. The cell below unpacks it into
        ``test_losses`` and ``test_accuracies``.
    """
    net = settings['model']
    target_device = settings['cfg'].device
    loader = settings['test_loader']
    unreduced_loss = partial(settings['loss_fn'], reduction="none")
    return test.test_model(net, target_device, loader, loss_fn=unreduced_loss)


if __name__ == "__main__":
    (test_losses, test_accuracies) = run_testing(settings)

In [None]:
def run_predictions(settings):
    """Compute predictions for the test set.

    Thin wrapper around ``submissions.get_predictions(model, test_loader, cfg)``
    using the matching entries of ``settings`` (as built by ``run_config``).
    The result is returned unchanged.
    """
    # ===== Predictions =====
    net, loader, config = (settings[key] for key in ("model", "test_loader", "cfg"))
    return submissions.get_predictions(net, loader, config)


if __name__ == "__main__":
    predictions = run_predictions(settings)

In [None]:
def run_submissions(settings, predictions):
    """Plot predictions, call ``submissions.make_submission``, plot its output.

    Args:
        settings: dict as built by ``run_config``. Uses ``"test_loader"`` and
            ``"cfg"``.
        predictions: output of ``run_predictions`` for the same settings.

    Returns:
        None. The patched predictions produced by
        ``submissions.make_submission`` are only plotted, not returned.
    """
    def show(preds):
        # NOTE(review): the meaning of the literal 1 passed to plot_pred_on is
        # not visible here (presumably an image index or count) — confirm in
        # src.plotting.
        plotting.plot_pred_on(settings['test_loader'], preds, 1, settings['cfg'])

    # ===== Plotting of predictions =====
    show(predictions)

    # ===== Make submission =====
    patched_preds = submissions.make_submission(predictions, settings['cfg'])

    # ===== Plotting of patched predictions =====
    show(patched_preds)


if __name__ == "__main__":
    run_submissions(settings, predictions)