In [None]:
%load_ext autoreload
%autoreload 2

import os
import yaml
import time
import torch
import random
import numpy as np
from tqdm import tqdm

from torch.utils import data

from tensorboardX import SummaryWriter

import data_loader
import models
from optimizer import get_optimizer
from loss.cross_entropy_2d import CrossEntropy2d
from utils.metrics import MetricsComp, AverageComp

def _seed_everything(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and CUDA) RNGs with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def _log_scalar(writer, tag, value, step=None):
    """Forward to ``writer.add_scalar`` when a writer was given; do nothing otherwise."""
    if writer is not None:
        writer.add_scalar(tag, value, step)


def _run_validation(model, val_data, loss_fn, n_classes, device, max_batches=11):
    """Evaluate ``model`` on ``val_data`` and return statistics for this call only.

    Fresh ``MetricsComp`` / ``AverageComp`` accumulators are built on every call so
    the numbers reported for an epoch are not mixed with those of earlier epochs.

    Args:
        model: the network under evaluation (switched to ``eval()`` mode here).
        val_data: iterable of ``(images, labels)`` batches.
        loss_fn: loss callable ``loss_fn(logits, labels)``.
        n_classes: number of segmentation classes, passed to ``MetricsComp``.
        device: torch device the batches are moved to.
        max_batches: upper bound on the number of batches evaluated (11 matches the
            original ``ctr > 10`` cut-off).

    Returns:
        ``(mean_val_loss, overall_scores, class_iou)`` where the last two are what
        ``MetricsComp.get_results()`` returns.
    """
    metrics_comp = MetricsComp(n_classes)
    val_loss_avg_comp = AverageComp()

    model.eval()
    with torch.no_grad():
        for batch_idx, (images_val, labels_val) in enumerate(tqdm(val_data)):
            if batch_idx >= max_batches:
                break
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)
            out = model(images_val)  # logits, one channel per class along dim 1
            val_loss = loss_fn(out, labels_val)
            val_loss_avg_comp.update(val_loss.item())

            _, pred = out.max(1)
            pred_np = pred.cpu().numpy()
            labels_np = labels_val.cpu().numpy()

            # Debug print: which classes are predicted vs present, with pixel counts.
            pred_uv, pred_uc = np.unique(pred_np, return_counts=True)
            ps = [f"{p_uv} ({p_uc})" for p_uv, p_uc in zip(pred_uv, pred_uc)]
            print(f"pred: {ps}")

            lbl_uv, lbl_uc = np.unique(labels_np, return_counts=True)
            lbls = [f"{l_uv} ({l_uc})" for l_uv, l_uc in zip(lbl_uv, lbl_uc)]
            print(f"labels_val: {lbls}")

            metrics_comp.update(label_trues=labels_np, label_preds=pred_np)

    overall_scores, class_iou = metrics_comp.get_results()
    return val_loss_avg_comp.avg, overall_scores, class_iou


def _save_checkpoint(config, epoch, model, optimizer, epoch_loss, epoch_time):
    """Write model/optimizer state to ``checkpoints/<exp_name>_<epoch % 3>.pkl``.

    The ``epoch % 3`` suffix rotates over three files, keeping the last three epochs.
    """
    ckpoint = {"epoch": epoch,
               "model_state": model.state_dict(),
               "optimizer_state": optimizer.state_dict(),
               "epoch_loss": epoch_loss,
               "epoch_time": epoch_time,
               }
    ckpoint_name = f"checkpoints/{config['exp_name']}_{epoch % 3}.pkl"
    # torch.save does not create parent directories.
    os.makedirs(os.path.dirname(ckpoint_name), exist_ok=True)
    torch.save(ckpoint, ckpoint_name)


def train(config, writer=None):
    """Train an FCN-32s segmentation model on the simulated dataset.

    Args:
        config: dict with keys ``seed``, ``batch_sz``, ``num_workers``, ``epochs``
            and ``exp_name``. Optional keys: ``resume_ckpoint`` (path to a
            checkpoint produced by a previous run) and ``acc_samples`` (number of
            samples whose gradients are accumulated before each optimizer step,
            default 100).
        writer: optional tensorboardX ``SummaryWriter``. When ``None``, nothing is
            logged to TensorBoard.

    Raises:
        FileNotFoundError: if ``resume_ckpoint`` is set but the file does not exist.
    """
    _seed_everything(config["seed"])
    torch.cuda.empty_cache()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TODO(ah): set up augmentations.

    # Data loaders (built before the model, as before, so RNG consumption order
    # and hence reproducibility is unchanged).
    sim_data = data_loader.get_loader("simulated")()
    train_data = data.DataLoader(sim_data, batch_size=config["batch_sz"],
                                 num_workers=config["num_workers"], shuffle=True)
    val_data = data.DataLoader(data_loader.get_loader("simulated")(n_imgs=10),
                               batch_size=config["batch_sz"],
                               num_workers=config["num_workers"])
    n_classes = sim_data.n_classes

    model = models.get_model("fcn32s", n_classes).to(device)
    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    optimizer = get_optimizer("adam")(model.parameters())
    print(f"Optimizer: {optimizer}")

    # TODO(ah): learning-rate scheduler (torch.optim.lr_scheduler).

    # Class 0 (background) gets weight 1 and every other class 10, i.e. the
    # background is weighted 10x less. Sized from n_classes rather than a fixed 10.
    weight = torch.tensor([1.0] + (n_classes - 1) * [10.0]).to(device)
    loss_fn = CrossEntropy2d(weight=weight)

    _log_scalar(writer, "batch_size", config["batch_sz"])

    # Number of samples whose gradients are summed before one optimizer step.
    acc_samples = config.get("acc_samples", 100)

    i = 0
    resume_path = config.get("resume_ckpoint")
    if resume_path is not None:
        if not os.path.isfile(resume_path):
            raise FileNotFoundError("Unable to load saved checkpoint!, Quitting")
        print(f"Loading checkpoint: {resume_path}")
        # map_location allows resuming a GPU-saved checkpoint on a CPU-only host.
        ckpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(ckpoint["model_state"])
        optimizer.load_state_dict(ckpoint["optimizer_state"])
        i = ckpoint["epoch"]
        print(f"Saved epoch: {i}, loss: {ckpoint['epoch_loss']}, time: {ckpoint['epoch_time']}")

    while i < config["epochs"]:
        i += 1
        epoch_loss = 0.0
        epoch_time = 0.0
        print(f"Epoch {i}")

        model.train()
        optimizer.zero_grad()
        acc_gradients_batch = 0
        for b_i, (images, labels) in enumerate(tqdm(train_data)):
            start_ts = time.time()
            images = images.to(device)
            labels = labels.to(device)

            out = model(images)  # logits, one channel per class along dim 1
            acc_gradients_batch += out.shape[0]

            loss = loss_fn(out, labels)
            loss.backward()

            # Gradient accumulation: step once at least `acc_samples` samples
            # have contributed gradients.
            if acc_gradients_batch >= acc_samples:
                optimizer.step()
                optimizer.zero_grad()
                acc_gradients_batch = 0

            epoch_loss += loss.item()
            epoch_time += time.time() - start_ts
            _log_scalar(writer, f"{i}_batch_loss", loss.item(), b_i)

        # Apply the gradients of the final, partial accumulation window instead of
        # letting the next epoch's zero_grad() discard them.
        if acc_gradients_batch > 0:
            optimizer.step()
            optimizer.zero_grad()

        print(f"Epoch Loss: {epoch_loss}")

        # Validation with per-epoch (not cumulative) loss and metrics.
        val_loss, overall_scores, class_iou = _run_validation(
            model, val_data, loss_fn, n_classes, device)
        _log_scalar(writer, "val_loss", val_loss, i)
        print(f"Scores after epoch {i}: {overall_scores}")
        for k, v in overall_scores.items():
            _log_scalar(writer, f"val_metrics/{k}", v, i)
        for k, v in class_iou.items():
            _log_scalar(writer, f"val_metrics/cls_iou_{k}", v, i)

        _save_checkpoint(config, i, model, optimizer, epoch_loss, epoch_time)
    
            
# Training configuration for this notebook run.
config = {
    "exp_name": "fcn32s",
    "batch_sz": 8,
    "epochs": 5000,
    "seed": 3642,
    "num_workers": 1,
}

# train() saves checkpoints round-robin to checkpoints/<exp_name>_{0,1,2}.pkl
# (epoch % 3), so a fixed "_0" slot is not necessarily the newest one. Resume from
# the most recently written slot. If none exists, keep the original "_0" path so
# train() still raises FileNotFoundError as before.
_ckpt_slots = [f"checkpoints/{config['exp_name']}_{k}.pkl" for k in range(3)]
_existing_slots = [p for p in _ckpt_slots if os.path.isfile(p)]
config["resume_ckpoint"] = (max(_existing_slots, key=os.path.getmtime)
                            if _existing_slots else _ckpt_slots[0])

writer = SummaryWriter(log_dir="checkpoints/logs")
try:
    train(config, writer)
finally:
    # Flush and close the event file even when the run is interrupted
    # (e.g. stopping the notebook cell), so logged scalars are not lost.
    writer.close()

Generating 500 images
Generating 10 images
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
********  weight: tensor([ 1., 10., 10., 10., 10., 10., 10., 10., 10., 10.], device='cuda:0')
********  weight: tensor([ 1., 10., 10., 10., 10., 10., 10., 10., 10., 10.], device='cuda:0')
Loading checkpoint: checkpoints/fcn32s_0.pkl
Saved epoch: 39, loss: 15.323740854859352, time: 92.98452639579773
Epoch 40


63it [01:33,  1.49s/it]

Epoch Loss: 14.441147327423096



1it [00:00,  1.19it/s]

pred: ['0 (1760038)', '2 (203314)', '3 (133800)']
labels_val: ['0 (1903905)', '1 (1639)', '2 (132736)', '3 (58872)']


2it [00:01,  1.82it/s]

pred: ['0 (419286)', '2 (62371)', '3 (42631)']
labels_val: ['0 (458026)', '1 (527)', '2 (41859)', '3 (23876)']
Scores after epoch 40: {'Overall Acc: ': 0.9199081420898437, 'FreqW Acc : ': 1970094.100567627, 'Mean IoU : ': 241148.4}





Epoch 41


63it [01:34,  1.49s/it]

Epoch Loss: 13.719032257795334



1it [00:00,  1.18it/s]

pred: ['0 (1784519)', '2 (194042)', '3 (118591)']
labels_val: ['0 (1903282)', '1 (2237)', '2 (137866)', '3 (53767)']


2it [00:01,  1.82it/s]

pred: ['0 (411127)', '2 (65940)', '3 (47221)']
labels_val: ['0 (450540)', '1 (537)', '2 (52695)', '3 (20516)']
Scores after epoch 41: {'Overall Acc: ': 0.9258615493774414, 'FreqW Acc : ': 3949670.778859329, 'Mean IoU : ': 485418.1}





Epoch 42


63it [01:34,  1.50s/it]

Epoch Loss: 13.329099968075752



1it [00:00,  1.19it/s]

pred: ['0 (1756799)', '2 (153091)', '3 (187262)']
labels_val: ['0 (1921142)', '1 (1592)', '2 (104647)', '3 (69771)']


2it [00:01,  1.81it/s]

pred: ['0 (436478)', '2 (43410)', '3 (44400)']
labels_val: ['0 (475721)', '1 (704)', '2 (31095)', '3 (16768)']
Scores after epoch 42: {'Overall Acc: ': 0.9234747568766276, 'FreqW Acc : ': 5960611.617123414, 'Mean IoU : ': 726250.1}





Epoch 43


63it [01:34,  1.49s/it]

Epoch Loss: 13.141244605183601



1it [00:00,  1.16it/s]

pred: ['0 (1808880)', '2 (131790)', '3 (156482)']
labels_val: ['0 (1940995)', '1 (1819)', '2 (89042)', '3 (65296)']


2it [00:01,  1.80it/s]

pred: ['0 (471908)', '2 (27120)', '3 (25260)']
labels_val: ['0 (492948)', '1 (268)', '2 (17748)', '3 (13324)']
Scores after epoch 43: {'Overall Acc: ': 0.9260461807250977, 'FreqW Acc : ': 8074454.550830651, 'Mean IoU : ': 971029.8}





Epoch 44


63it [01:33,  1.49s/it]

Epoch Loss: 13.267536997795105



1it [00:00,  1.17it/s]

pred: ['0 (1825859)', '2 (137281)', '3 (134012)']
labels_val: ['0 (1938180)', '1 (1588)', '2 (102213)', '3 (55171)']


2it [00:01,  1.81it/s]

pred: ['0 (467318)', '2 (17190)', '3 (39780)']
labels_val: ['0 (488829)', '1 (456)', '2 (13934)', '3 (21069)']
Scores after epoch 44: {'Overall Acc: ': 0.9288525390625, 'FreqW Acc : ': 10195371.661184005, 'Mean IoU : ': 1217465.6}





Epoch 45


63it [01:33,  1.49s/it]

Epoch Loss: 13.519232332706451



1it [00:00,  1.17it/s]

pred: ['0 (1792020)', '2 (149461)', '3 (155671)']
labels_val: ['0 (1929182)', '1 (1849)', '2 (105781)', '3 (60340)']


2it [00:01,  1.80it/s]

pred: ['0 (448327)', '2 (45000)', '3 (30961)']
labels_val: ['0 (484694)', '1 (368)', '2 (30054)', '3 (9172)']
Scores after epoch 45: {'Overall Acc: ': 0.9288066864013672, 'FreqW Acc : ': 12260962.51689682, 'Mean IoU : ': 1460886.6}





Epoch 46


63it [01:33,  1.49s/it]

Epoch Loss: 12.814737424254417



1it [00:00,  1.16it/s]

pred: ['0 (1784246)', '2 (173913)', '3 (138993)']
labels_val: ['0 (1888594)', '1 (1271)', '2 (135210)', '3 (72077)']


2it [00:01,  1.79it/s]

pred: ['0 (476378)', '2 (16200)', '3 (31710)']
labels_val: ['0 (497766)', '1 (527)', '2 (10322)', '3 (15673)']
Scores after epoch 46: {'Overall Acc: ': 0.93061888558524, 'FreqW Acc : ': 14320309.586826544, 'Mean IoU : ': 1707693.1}





Epoch 47


63it [01:33,  1.49s/it]

Epoch Loss: 12.492851302027702



1it [00:00,  1.19it/s]

pred: ['0 (1788359)', '2 (143761)', '3 (165032)']
labels_val: ['0 (1931133)', '1 (1894)', '2 (95001)', '3 (69124)']


2it [00:01,  1.82it/s]

pred: ['0 (445387)', '2 (46381)', '3 (32520)']
labels_val: ['0 (482625)', '1 (249)', '2 (31242)', '3 (10172)']
Scores after epoch 47: {'Overall Acc: ': 0.9297134399414062, 'FreqW Acc : ': 16378950.493991993, 'Mean IoU : ': 1949750.4}





Epoch 48


36it [00:54,  1.48s/it]

In [None]:
# Show the working directory that the relative "checkpoints/..." paths resolve against.
!pwd

In [None]:
# List the files SummaryWriter(log_dir="checkpoints/logs") has written.
!ls -alh checkpoints/logs