In [None]:
import random
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf

from src import data_loader, train, test, plotting, submissions, unet

cfg = OmegaConf.load("config.yaml")

print(OmegaConf.to_yaml(cfg))

# ===== Torch config =====
# Passed as-is to `.to(device=...)` below; no availability check is done here.
device = cfg.device
# Seed every RNG the pipeline may draw from: Python's `random`, NumPy's global
# RNG and torch. This makes Restart & Run All reproduce the same run. Seeding
# only torch would leave any NumPy- or `random`-based shuffling/augmentation
# unseeded. (np.random.seed requires 0 <= seed < 2**32.)
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
# Default floating-point dtype for tensors created from here on, including the
# model parameters built below.
torch.set_default_dtype(getattr(torch, cfg.tensor_dtype))

# ===== Data Loading =====
train_loader, test_loader = data_loader.get_loader(cfg)

# ===== Model, Optimizer and Loss function =====
# nn.Module.to() returns the module itself, so build and place it in one step.
model = unet.UNet().to(device=device)
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=cfg.training.lr,
    weight_decay=cfg.training.weight_decay,
)
# F.binary_cross_entropy expects probabilities in [0, 1]. This assumes UNet's
# output already goes through a sigmoid; TODO confirm in src/unet (otherwise
# binary_cross_entropy_with_logits is the right loss).
loss_fn = torch.nn.functional.binary_cross_entropy

## Model training ##

In [None]:
# `model` is not reassigned from the return value, so later cells rely on
# train_model updating it in place. The two returned values are presumably
# loss/accuracy histories; their granularity is defined in src/train.
train_losses, train_accs = train.train_model(
    model, optimizer, loss_fn, train_loader, cfg
)

## Model testing ##

In [None]:
# Hand test_model an unreduced copy of loss_fn (reduction="none" yields one
# loss value per element). Presumably this lets it aggregate per sample itself;
# confirm in src/test.
per_element_loss = partial(loss_fn, reduction="none")
test_losses, test_accuracies = test.test_model(
    model, device, test_loader, loss_fn=per_element_loss
)

## Predictions ##

In [None]:
# Run the trained model over the test loader. The result feeds both the
# prediction plot and the submission step below.
predictions = submissions.get_predictions(
    model,
    test_loader,
    cfg,
)

## Plotting of predictions ##

In [None]:
# Visualise the raw `predictions` for the test set.
# NOTE(review): the meaning of the literal `1` (sample index? number of
# images?) isn't visible here; confirm in src/plotting and give it a name.
plotting.plot_pred_on(
    test_loader,
    predictions,
    1,
    cfg,
)

## Make submissions ##

In [None]:
# Build the submission from the raw predictions. The returned (presumably
# patch-level) predictions are kept for the plot below. Whether and where a
# file is written is decided inside src/submissions, presumably via cfg.
patched_preds = submissions.make_submission(
    predictions,
    cfg,
)

## Plotting of patched predictions ##

In [None]:
# Visualise the patched predictions, using the same plotting helper and
# arguments as for the raw predictions.
# NOTE(review): the meaning of the literal `1` isn't visible here; confirm in
# src/plotting and give it a name.
plotting.plot_pred_on(
    test_loader,
    patched_preds,
    1,
    cfg,
)