In [None]:
import copy
import json
import os
import pickle
import sys
import time
import warnings
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import JaccardIndex as IoU
from tqdm import tqdm

# NOTE(review): the original line below was a truncated import (a syntax
# error). It presumably imported train_one_epoch / validate, which the
# training loop uses but which are not defined in this file — restore the
# intended module path.
# from src.
from src import utils
from src.args import ArgumentParser
from src.build_model import build_model
from src.prepare_data import prepare_data
from src.utils import load_ckpt
from src.utils import print_log
from src.utils import save_ckpt_every_epoch

In [None]:
# Parse this run's options with the project parser from src.args; the returned
# object is used directly as the option namespace below (args.results_dir,
# args.dataset, args.epochs, ...).
args = ArgumentParser()

In [None]:
# Each run writes into its own directory
#   <results_dir>/<dataset>/<id>/<start timestamp>
# which holds weights, confusion matrices and a record of the run's arguments.
training_starttime = datetime.now().strftime("%d_%m_%Y-%H_%M_%S-%f")
ckpt_dir = os.path.join(
    args.results_dir, args.dataset, f"{args.id}", f"{training_starttime}"
)
for run_subdir in (ckpt_dir, os.path.join(ckpt_dir, "confusion_matrices")):
    os.makedirs(run_subdir, exist_ok=True)

# Store both the parsed arguments and the raw command line for reproducibility.
args_json_path = os.path.join(ckpt_dir, "args.json")
with open(args_json_path, "w") as args_file:
    json.dump(vars(args), args_file, sort_keys=True, indent=4)

with open(os.path.join(ckpt_dir, "argsv.txt"), "w") as argv_file:
    argv_file.write(" ".join(sys.argv) + "\n")

# data preparation ---------------------------------------------------------
# prepare_data yields three loaders; only train and valid are used here.
data_loaders = prepare_data(args, ckpt_dir)
train_loader, valid_loader, _ = data_loaders

n_classes_without_void = train_loader.dataset.n_classes_without_void

# Per-class loss weights: uniform unless a weighting mode was requested.
# The "disabled" sentinel is the *string* "None", not the None object.
if args.class_weighting == "None":
    class_weighting = np.ones(n_classes_without_void)
else:
    class_weighting = train_loader.dataset.compute_class_weights(
        weight_mode=args.class_weighting,
        c=args.c_for_logarithmic_weighting,
    )
# model building -----------------------------------------------------------
model, device = build_model(args, n_classes=n_classes_without_void)

# Optional warm-up phase: every parameter whose name does not contain "out"
# is frozen; the training loop makes all parameters trainable again later.
if args.freeze > 0:
    print("Freeze everything but the output layer(s).")
    for param_name, param in model.named_parameters():
        if "out" in param_name:
            # output layer(s): left exactly as build_model configured them
            continue
        param.requires_grad = False

# loss, optimizer, learning rate scheduler, csvlogger  ----------

# Loss objects. The train_loss / val_loss lists built below are ordered
#   [cross-entropy, objectosphere, MAV (OWLoss), contrastive]
# and a disabled auxiliary term is represented by None.
loss_function_train = utils.CrossEntropyLoss2d(
    weight=class_weighting, device=device
)
loss_objectosphere = utils.ObjectosphereLoss()
loss_mav = utils.OWLoss(n_classes=n_classes_without_void)
loss_contrastive = utils.ContrastiveLoss(n_classes=n_classes_without_void)

# The validation cross-entropy receives the class-weighted pixel count of the
# validation set (presumably for normalisation — see CrossEntropyLoss2dForValidData).
pixel_sum_valid_data = valid_loader.dataset.compute_class_weights(
    weight_mode="linear"
)
pixel_sum_valid_data_weighted = np.sum(pixel_sum_valid_data * class_weighting)
loss_function_valid = utils.CrossEntropyLoss2dForValidData(
    weight=class_weighting,
    weighted_pixel_sum=pixel_sum_valid_data_weighted,
    device=device,
)

# Auxiliary terms are toggled by args.obj / args.mav / args.closs and shared
# between training and validation; only the cross-entropy term differs.
train_loss = [
    loss_function_train,
    loss_objectosphere if args.obj else None,
    loss_mav if args.mav else None,
    loss_contrastive if args.closs else None,
]
val_loss = [loss_function_valid] + train_loss[1:]

# Adam over every model parameter, including any frozen during model building.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr,
    betas=(0.9, 0.999),
    weight_decay=args.weight_decay,
)

# One-cycle schedule spanning the whole run. lr_scheduler.step() is called only
# once per epoch in this script, so total_steps counts epochs, not batches.
peak_lrs = [group["lr"] for group in optimizer.param_groups]
lr_scheduler = OneCycleLR(
    optimizer,
    max_lr=peak_lrs,
    total_steps=args.epochs,
    pct_start=0.1,  # first 10% of the steps ramp up to max_lr
    div_factor=25,  # starting lr = max_lr / 25
    final_div_factor=1e4,  # final lr = starting lr / 1e4
    anneal_strategy="cos",
)

# load checkpoint if parameter last_ckpt is provided
if args.last_ckpt:
    ckpt_path = args.last_ckpt
    epoch_last_ckpt, best_miou, best_miou_epoch, mav_dict, std_dict = load_ckpt(
        model, optimizer, ckpt_path, device
    )
    start_epoch = epoch_last_ckpt + 1

    # load_ckpt is given the model and optimizer but not lr_scheduler, so
    # without this the one-cycle schedule (warm-up included) would restart
    # from step 0 on every resume. The scheduler is stepped once per epoch,
    # so fast-forward it to start_epoch; cap at total_steps (= args.epochs)
    # because OneCycleLR raises when stepped past it.
    with warnings.catch_warnings():
        # Stepping before any optimizer.step() emits a UserWarning that does
        # not apply to this deliberate fast-forward.
        warnings.simplefilter("ignore", UserWarning)
        for _skipped_epoch in range(min(int(start_epoch), args.epochs)):
            lr_scheduler.step()
else:
    start_epoch = 0
    best_miou = 0
    best_miou_epoch = 0

if args.load_weights:
    # map_location lets weights saved from a GPU be loaded onto the device
    # chosen by build_model (e.g. on a CPU-only machine).
    model.load_state_dict(torch.load(args.load_weights, map_location=device))

writer = SummaryWriter("runs/" + ckpt_dir.split(args.dataset)[-1])

# start training -----------------------------------------------------------
# Epoch at which every parameter becomes trainable again. Model building
# freezes non-output parameters whenever args.freeze > 0, including on resume,
# so when resuming past args.freeze unfreeze at the first resumed epoch instead
# of never. A negative args.freeze never unfreezes (as before).
unfreeze_epoch = max(args.freeze, int(start_epoch)) if args.freeze >= 0 else None

# Statistics pickled after training. Seed them from the loaded checkpoint so
# they are defined even when no epoch is left to train.
if args.last_ckpt:
    # NOTE(review): assumes load_ckpt returns the same (mean, var) objects
    # that save_ckpt_every_epoch stores — confirm in src.utils.
    mean, var = mav_dict, std_dict
else:
    mean = var = None

# NOTE(review): train_one_epoch / validate are not defined or imported in this
# file as shown; they presumably come from the truncated `from src.` import.
for epoch in range(int(start_epoch), args.epochs):
    # unfreeze
    if args.finetune is None and epoch == unfreeze_epoch:
        for param in model.parameters():
            param.requires_grad = True

    mean, var = train_one_epoch(
        model=model,
        train_loader=train_loader,
        device=device,
        optimizer=optimizer,
        train_loss=train_loss,
        epoch=epoch,
        lr_scheduler=lr_scheduler,
        debug_mode=args.debug,
        writer=writer,
    )

    miou = validate(
        model=model,
        valid_loader=valid_loader,
        device=device,
        val_loss=val_loss,
        epoch=epoch,
        debug_mode=args.debug,
        writer=writer,
        classes=args.num_classes,
    )

    writer.flush()

    # save weights
    if not args.overfit:
        # Update best-model bookkeeping *before* writing the resumable
        # checkpoint, so it records this epoch's best_miou / best_miou_epoch.
        # Otherwise a resumed run compares against a stale best and can
        # overwrite best_miou.pth with a worse model.
        if miou > best_miou:
            best_miou = miou
            best_miou_epoch = epoch
            torch.save(
                model.state_dict(),
                os.path.join(ckpt_dir, "best_miou.pth"),
            )
        # save / overwrite latest weights (useful for resuming training)
        save_ckpt_every_epoch(
            ckpt_dir, model, optimizer, epoch, best_miou, best_miou_epoch, mean, var
        )
        if (epoch + 1) % 20 == 0:
            torch.save(
                model.state_dict(),
                os.path.join(ckpt_dir, "epoch_" + str(epoch) + ".pth"),
            )

writer.close()

# save mavs to a pickle
if mean is None:
    print(
        "No epoch was trained and no checkpoint was loaded: "
        "skipping mavs.pickle / vars.pickle."
    )
else:
    # "" keeps the original working-directory copies; ckpt_dir additionally
    # preserves this run's statistics next to its weights, because the
    # working-directory copies are overwritten by every run.
    for out_dir in ("", ckpt_dir):
        with open(os.path.join(out_dir, "mavs.pickle"), "wb") as h1:
            pickle.dump(mean, h1, protocol=pickle.HIGHEST_PROTOCOL)
        with open(os.path.join(out_dir, "vars.pickle"), "wb") as h2:
            pickle.dump(var, h2, protocol=pickle.HIGHEST_PROTOCOL)

print("Training completed ")
