### Training

Next, we train the model. We use the train and validation splits prepared in the `splits` directory above, and all logging goes to TensorBoard. A checkpoint is saved every `save_every` epochs. The checkpoint with the best validation IoU is also saved under the name `model_best`, and a final checkpoint is written when training finishes.

In [1]:
from addict import Dict
from pathlib import Path

# Root of the glacier dataset and the directory holding the prepared splits.
data_dir = Path("glacier_data")
process_dir = data_dir.joinpath("splits")

# Run-level training options. addict's Dict allows attribute access
# (e.g. args.batch_size) in the training cell below.
args = Dict(
    batch_size=16,
    run_name="demo",
    epochs=200,
    save_every=50,
    loss_type="dice",
    device="cuda:0",
)


In [2]:
from glacier_mapping.data.data import fetch_loaders
from glacier_mapping.models.frame import Framework
import glacier_mapping.train as tr
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from glacier_mapping.models.metrics import diceloss
import yaml
import torch
import json

# Load the training configuration; the context manager ensures the file
# handle is closed once parsed.
with open("conf/train.yaml", "r") as conf_file:
    conf = Dict(yaml.safe_load(conf_file))
loaders = fetch_loaders(process_dir, args.batch_size)
device = torch.device(args.device)

# Only a dice loss is constructed here; for any other loss_type, loss_fn stays
# None and is passed to Framework as-is.
loss_fn = None
outchannels = conf.model_opts.args.outchannels
if args.loss_type == "dice":
    loss_fn = diceloss(
        act=torch.nn.Softmax(dim=1),
        w=[0.6, 0.9, 0.2],  # clean ice, debris, background
        outchannels=outchannels,
        label_smoothing=0.2,
    )

frame = Framework(
    model_opts=conf.model_opts,
    optimizer_opts=conf.optim_opts,
    reg_opts=conf.reg_opts,
    device=device,
    loss_fn=loss_fn,
)

# Setup logging
writer = SummaryWriter(f"{data_dir}/{args.run_name}/logs/")
# `args` is an addict Dict (a dict subclass), so serialize it directly.
# vars(args) would return the object's internal attributes rather than the
# actual training arguments.
writer.add_text("Arguments", json.dumps(args))
writer.add_text("Configuration Parameters", json.dumps(conf))
out_dir = f"{data_dir}/{args.run_name}/models/"

best_epoch, best_iou = None, 0
try:
    for epoch in range(args.epochs):
        loss_d = {}
        for phase in ["train", "val"]:
            loss_d[phase], metrics = tr.train_epoch(loaders[phase], frame, conf.metrics_opts)
            tr.log_metrics(writer, metrics, loss_d[phase], epoch, phase, mask_names=conf.log_opts.mask_names)

        # Periodic checkpoint plus sample-image logging every `save_every` epochs.
        writer.add_scalars("Loss", loss_d, epoch)
        if (epoch + 1) % args.save_every == 0:
            frame.save(out_dir, epoch)
            tr.log_images(writer, frame, next(iter(loaders["train"])), epoch)
            tr.log_images(writer, frame, next(iter(loaders["val"])), epoch, "val")

        # `metrics` holds the results of the last phase above ("val"), so the
        # best checkpoint tracks validation IoU. Index 0 is presumably the first
        # class (clean ice, per the loss-weight comment) — TODO confirm.
        if best_iou <= metrics['IoU'][0]:
            best_iou = metrics['IoU'][0]
            best_epoch = epoch
            frame.save(out_dir, "best")

        print(f"{epoch}/{args.epochs} | train: {loss_d['train']} | val: {loss_d['val']}")

    frame.save(out_dir, "final")
finally:
    # Flush and close TensorBoard logs even if training is interrupted.
    writer.close()

  y = torch.tensor(y, dtype=torch.long, device=self.device)


0/200 | train: 0.054354348008377434 | val: 0.06277614831924438
1/200 | train: 0.051559342105457116 | val: 0.05159737305207686
2/200 | train: 0.04373037737281142 | val: 0.05053310069170865
3/200 | train: 0.04218079838366795 | val: 0.046574049646204166
4/200 | train: 0.04125892524619638 | val: 0.046028152379122646
5/200 | train: 0.04040143975389844 | val: 0.04516075090928511
6/200 | train: 0.03955614099925865 | val: 0.04464033408598466
7/200 | train: 0.038797521248809974 | val: 0.043514727462421764
8/200 | train: 0.03800808854264939 | val: 0.04281104369597002
9/200 | train: 0.03688909101735207 | val: 0.04028119065544822
10/200 | train: 0.03624522094626962 | val: 0.04007151343605735
11/200 | train: 0.035609041274994535 | val: 0.03901559222828258
12/200 | train: 0.03527434360887612 | val: 0.03795452659780329
13/200 | train: 0.03484513118435135 | val: 0.03853121128949252
14/200 | train: 0.0342893503197802 | val: 0.03747232773087241
15/200 | train: 0.03408183278053921 | val: 0.03803189613602