In [None]:
%load_ext autoreload
%autoreload 2

import os
import yaml
import time
import torch
import random
import numpy as np
from tqdm import tqdm

from torch.utils import data

from tensorboardX import SummaryWriter

import data_loader
import models
from optimizer import get_optimizer
from loss.cross_entropy_2d import CrossEntropy2d
from utils.metrics import MetricsComp, AverageComp

# NOTE(review): SEED is not referenced anywhere in the visible code; train()
# seeds its RNGs from config["seed"] instead. Confirm whether this constant
# is still needed.
SEED = 1337

def train(config, writer=None):
    """Train the "fcn32s" model on the MIT scene-parsing data, validating each epoch.

    After every epoch a capped number of validation batches is evaluated, the
    results are logged, and a checkpoint is written to
    ``checkpoints/<exp_name>_<epoch>.pkl``.

    Args:
        config: dict of run settings. Required keys: ``"seed"``,
            ``"batch_sz"``, ``"num_workers"``, ``"epochs"`` and
            ``"exp_name"``. Optional keys: ``"resume_ckpoint"`` (path of a
            checkpoint to resume from) and ``"max_val_batches"`` (number of
            validation batches evaluated per epoch, default 11).
        writer: optional object exposing ``add_scalar(tag, value, step)``
            (e.g. a tensorboardX ``SummaryWriter``). When ``None``, scalar
            logging is skipped.

    Raises:
        FileNotFoundError: if ``config["resume_ckpoint"]`` is set but does not
            point to an existing file.
    """
    # Fail fast on a bad resume path, before datasets and model are built.
    resume_path = config.get("resume_ckpoint")
    if resume_path is not None and not os.path.isfile(resume_path):
        raise FileNotFoundError("Unable to load saved checkpoint!, Quitting")

    def log_scalar(tag, value, step=None):
        # `writer` defaults to None, so every logging call must tolerate that.
        if writer is not None:
            writer.add_scalar(tag, value, step)

    # setup seeds (Python's own RNG included, in case data code relies on it)
    random.seed(config["seed"])
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed(config["seed"])
    torch.cuda.empty_cache()

    # setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # setup augmentations
    # TODO: not implemented yet

    # setup dataloader
    mit_data = data_loader.get_loader("mit_sceneparsing_benchmark")
    train_data_raw = mit_data("training")
    n_classes = train_data_raw.n_classes
    train_data = data.DataLoader(
        train_data_raw,
        batch_size=config["batch_sz"],
        num_workers=config["num_workers"],
        shuffle=True,
    )
    val_data = data.DataLoader(
        mit_data("validation"),
        batch_size=config["batch_sz"],
        num_workers=config["num_workers"],
    )

    # setup model
    model = models.get_model("fcn32s", n_classes).to(device)
    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # setup optimizer
    optimizer = get_optimizer("adam")(model.parameters())
    print(f"Optimizer: {optimizer}")

    # setup learning rate scheduler (optim.lr_scheduler)
    # TODO: not implemented yet
    loss_fn = CrossEntropy2d()

    log_scalar("batch_size", config["batch_sz"])

    # Load a saved checkpoint
    start_epoch = 0
    if resume_path is not None:
        print(f"Loading checkpoint: {resume_path}")
        # map_location lets a checkpoint saved on GPU be resumed on a
        # CPU-only machine (and vice versa).
        ckpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(ckpoint["model_state"])
        optimizer.load_state_dict(ckpoint["optimizer_state"])
        start_epoch = ckpoint["epoch"]
        print(f"Saved epoch loss: {ckpoint['epoch_loss']}, time: {ckpoint['epoch_time']}")

    # Original behaviour evaluated 11 validation batches per epoch.
    max_val_batches = config.get("max_val_batches", 11)

    # torch.save does not create missing directories.
    os.makedirs("checkpoints", exist_ok=True)

    for epoch in range(start_epoch + 1, config["epochs"] + 1):
        epoch_loss = 0.0
        epoch_time = 0.0
        print(f"Epoch {epoch}")

        model.train()
        # Wrapping the loader itself (not enumerate) lets tqdm show the total.
        for b_i, (images, labels) in enumerate(tqdm(train_data)):
            # Free cached GPU memory (kept from the original; it slows each
            # step, so consider removing once memory headroom is confirmed).
            torch.cuda.empty_cache()
            start_ts = time.time()
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            out = model(images)  # original note: [batch_sz, n_classes, H=512, W=512]

            loss = loss_fn(out, labels)
            loss.backward()
            optimizer.step()

            batch_loss = float(loss.item())
            epoch_loss += batch_loss
            epoch_time += time.time() - start_ts
            log_scalar(f"{epoch}_batch_loss", batch_loss, b_i)

        # NOTE(review): the logged run reports a hugely negative epoch loss;
        # a cross-entropy loss should not go negative - check CrossEntropy2d
        # and the label value range.

        # Run through validation with fresh accumulators, so the numbers
        # reported for this epoch are not blended with earlier epochs.
        metrics_comp = MetricsComp(n_classes)
        val_loss_avg_comp = AverageComp()
        model.eval()
        with torch.no_grad():
            for v_i, (images_val, labels_val) in enumerate(tqdm(val_data)):
                if v_i >= max_val_batches:
                    break
                images_val = images_val.to(device)
                labels_val = labels_val.to(device)
                out = model(images_val)
                val_loss = loss_fn(out, labels_val)
                val_loss_avg_comp.update(val_loss.item())

                # Predicted class = index of the max score along dim 1.
                _, pred = out.max(1)
                metrics_comp.update(label_trues=labels_val.cpu().numpy(),
                                    label_preds=pred.cpu().numpy())

        # Add validation results to tensorboard writer
        log_scalar("val_loss", val_loss_avg_comp.avg, epoch)
        overall_scores, class_iou = metrics_comp.get_results()
        print(f"Scores after epoch {epoch}: {overall_scores}")
        for k, v in overall_scores.items():
            log_scalar(f"val_metrics/{k}", v, epoch)
        for k, v in class_iou.items():
            log_scalar(f"val_metrics/cls_iou_{k}", v, epoch)

        # Save the model checkpoint
        ckpoint = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "epoch_loss": epoch_loss,
            "epoch_time": epoch_time,
        }
        ckpoint_name = f"checkpoints/{config['exp_name']}_{epoch}.pkl"
        torch.save(ckpoint, ckpoint_name)
    
            
# Experiment settings consumed by train(); "resume_ckpoint" makes the run
# continue from a previously saved checkpoint.
config = {
    "exp_name": "fcn32s",
    "batch_sz": 8,
    "epochs": 30,
    "seed": 42,
    "num_workers": 1,
    "resume_ckpoint": "checkpoints/fcn32s_15.pkl",
}
writer = SummaryWriter(log_dir="checkpoints/logs")

#torch.backends.cudnn.benchmark = True
try:
    train(config, writer)
finally:
    # Flush and close the event file even if training fails or is
    # interrupted, so already-logged scalars are not lost.
    writer.close()

Found 20209 training images
Found 2000 validation images
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
Loading checkpoint: checkpoints/fcn32s_15.pkl
Saved epoch loss: -3.6297646389862915e+29, time: 635.1497781276703
Epoch 16


285it [08:56,  1.89s/it]

In [None]:
!pwd

In [None]:
!ls -alh checkpoints/logs