In [None]:
%load_ext autoreload
%autoreload 2

import os
import yaml
import time
import random
import numpy as np
from tqdm import tqdm

import torch
from torch.utils import data
from tensorboardX import SummaryWriter

# Datasets
from data_loader.mit_scene_parsing import MITSceneParsingLoader
from data_loader.simulated_data import SimulatedDataLoader
# Models
from models.fcn import FCN32s
# Loss
from loss.cross_entropy_2d import CrossEntropy2d
# Metrics
from utils.metrics import MetricsComp, AverageComp

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def _class_weights(train_data_set, device):
    """Return the dataset's ``class_imbalance_weight`` moved to ``device``, or None.

    Datasets that do not define the attribute (or set it to None) get an
    unweighted loss. Any other failure (e.g. a weight object without ``.to``)
    is surfaced instead of being silently ignored.
    """
    weight = getattr(train_data_set, "class_imbalance_weight", None)
    if weight is None:
        return None
    return weight.to(device)


def _value_counts(array):
    """Format the distinct values of a NumPy array with their counts, e.g. ``['0 (10)', '3 (2)']``."""
    values, counts = np.unique(array, return_counts=True)
    return [f"{v} ({c})" for v, c in zip(values, counts)]


def _train_one_epoch(model, train_data, loss_fn, optimizer, device, accum_samples):
    """Run one training epoch with gradient accumulation.

    An optimizer step is taken once at least ``accum_samples`` samples have
    contributed gradients. Any trailing partial window is also applied at the
    end of the epoch rather than being thrown away.

    Returns:
        (sum of per-batch losses, summed per-batch wall time in seconds excluding data loading)
    """
    model.train()
    optimizer.zero_grad()
    epoch_loss = 0.0
    epoch_time = 0.0
    pending = 0  # samples whose gradients are accumulated but not yet applied

    for images, labels in tqdm(train_data):
        start_ts = time.time()
        images = images.to(device)
        labels = labels.to(device)

        out = model(images)  # [batch_sz, n_classes, H, W]
        pending += out.shape[0]

        loss = loss_fn(out, labels)
        loss.backward()

        if pending >= accum_samples:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        epoch_loss += loss.item()
        epoch_time += time.time() - start_ts

    # Without this, the next epoch's zero_grad() would discard these gradients.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

    return epoch_loss, epoch_time


def _validate(model, val_data, loss_fn, n_classes, device, max_batches):
    """Evaluate ``model`` on the first ``max_batches`` validation batches.

    The metric accumulators are created fresh on every call, so the results
    describe this evaluation only and not every epoch so far.

    Returns:
        (average validation loss, overall scores dict, per-class IoU dict)
    """
    metrics_comp = MetricsComp(n_classes)
    val_loss_avg_comp = AverageComp()

    model.eval()
    with torch.no_grad():
        for b_i, (images_val, labels_val) in enumerate(tqdm(val_data)):
            if b_i >= max_batches:
                break
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            out = model(images_val)  # [batch_sz, n_classes, H, W]
            val_loss_avg_comp.update(loss_fn(out, labels_val).item())

            pred_np = out.argmax(1).cpu().numpy()
            labels_np = labels_val.cpu().numpy()
            # Debug output: distinct predicted vs. ground-truth values per batch.
            print(f"pred: {_value_counts(pred_np)}")
            print(f"labels_val: {_value_counts(labels_np)}")

            metrics_comp.update(label_trues=labels_np, label_preds=pred_np)

    overall_scores, class_iou = metrics_comp.get_results()
    return val_loss_avg_comp.avg, overall_scores, class_iou


def train(config, data_set):
    """Train an FCN32s segmentation model on ``data_set``, logging to TensorBoard.

    Args:
        config: dict with required keys ``seed``, ``batch_sz``, ``num_workers``,
            ``epochs`` and ``save_ckpoint``. The optional keys are:
            ``exp_name`` (default ""); ``resume_ckpoint`` (checkpoint path to
            resume from); ``accum_samples`` (samples per optimizer step,
            default 256); and ``max_val_batches`` (validation batches per
            epoch, default 11).
        data_set: dataset class, called as ``data_set(split="training")`` and
            ``data_set(split="validation")``. The training instance must expose
            ``n_classes`` and ``name``.

    Raises:
        FileNotFoundError: if ``resume_ckpoint`` is given but is not a file.
    """
    _seed_everything(config["seed"])
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TODO: data augmentation.
    train_data_set = data_set(split="training")
    train_data = data.DataLoader(train_data_set, batch_size=config["batch_sz"],
                                 num_workers=config["num_workers"], shuffle=True)
    val_data = data.DataLoader(data_set(split="validation"), batch_size=config["batch_sz"],
                               num_workers=config["num_workers"])
    n_classes = train_data_set.n_classes

    model = FCN32s(n_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters())
    # TODO: learning-rate scheduler (torch.optim.lr_scheduler).
    loss_fn = CrossEntropy2d(weight=_class_weights(train_data_set, device))

    # TensorBoard logs go to checkpoints/<uniq_name>; checkpoint files are saved next to them.
    exp_name = config.get("exp_name", "")
    uniq_name = f"{train_data_set.name}_{model.name}_{exp_name}_logs"
    writer = SummaryWriter(log_dir="checkpoints/" + uniq_name)
    try:
        writer.add_scalar("batch_size", config["batch_sz"])

        epoch = 0
        resume_path = config.get("resume_ckpoint")
        if resume_path is not None:
            if not os.path.isfile(resume_path):
                raise FileNotFoundError("Unable to load saved checkpoint!, Quitting")
            print(f"Loading checkpoint: {resume_path}")
            # map_location allows a checkpoint saved on GPU to be resumed on a CPU-only machine.
            ckpoint = torch.load(resume_path, map_location=device)
            model.load_state_dict(ckpoint["model_state"])
            optimizer.load_state_dict(ckpoint["optimizer_state"])
            epoch = ckpoint["epoch"]
            print(f"Saved epoch: {epoch}, loss: {ckpoint['epoch_loss']}, time: {ckpoint['epoch_time']}")

        accum_samples = config.get("accum_samples", 256)
        max_val_batches = config.get("max_val_batches", 11)

        while epoch < config["epochs"]:
            epoch += 1
            print(f"Epoch {epoch}")

            epoch_loss, epoch_time = _train_one_epoch(
                model, train_data, loss_fn, optimizer, device, accum_samples)
            print(f"Epoch Loss: {epoch_loss}")
            writer.add_scalar("epoch_loss", epoch_loss, epoch)

            val_loss, overall_scores, class_iou = _validate(
                model, val_data, loss_fn, n_classes, device, max_val_batches)
            writer.add_scalar("val_loss", val_loss, epoch)
            print(f"Scores after epoch {epoch}: {overall_scores}")
            for k, v in overall_scores.items():
                writer.add_scalar(f"val_metrics/{k}", v, epoch)
            for k, v in class_iou.items():
                writer.add_scalar(f"val_metrics/cls_iou_{k}", v, epoch)

            if config["save_ckpoint"]:
                ckpoint = {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "epoch_loss": epoch_loss,
                    "epoch_time": epoch_time,
                }
                # Rotate through three slots (epoch % 3), keeping the last three epochs on disk.
                torch.save(ckpoint, f"checkpoints/{uniq_name}_{epoch % 3}.pkl")
    finally:
        writer.close()
    
# Training configuration for the current experiment (batch size 8).
# To continue a previous run, add "resume_ckpoint" with a checkpoint path,
# e.g. "checkpoints/mit_sceneparsing_FCN32s_full_logs_0.pkl".
config = {
    "exp_name": "pad_fix",
    "batch_sz": 8,
    "epochs": 500,
    "seed": 3642,
    "num_workers": 1,
    "save_ckpoint": True,
}

# Pass SimulatedDataLoader instead to train on the simulated dataset.
train(config, MITSceneParsingLoader)

Found 20209 training images
Found 2000 validation images
Loss weight: None
Epoch 1


463it [05:41,  1.35it/s]

In [None]:
!pwd

In [None]:
!ls -alh checkpoints/logs