In [None]:
%load_ext autoreload
%autoreload 2

import os
import yaml
import time
import torch
import random
import numpy as np

from torch.utils import data

from tensorboardX import SummaryWriter

import data_loader
import models
from optimizer import get_optimizer
from loss.cross_entropy_2d import CrossEntropy2d

# Module-level RNG seed constant.
SEED: int = 1337

def _seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable, so this is safe on CPU-only hosts.
    torch.cuda.manual_seed_all(seed)


def train(config, writer=None):
    """Train the ``fcn32s`` model on the ``mit_sceneparsing_benchmark`` training split.

    Prints the optimizer, then one line per epoch and one line per batch
    with that batch's loss and the wall-clock time of the step.

    Args:
        config: Run settings. Required keys are ``"batch_sz"``, ``"epochs"``
            and ``"num_workers"``. ``"seed"`` is optional and falls back to
            the module-level ``SEED``.
        writer: Optional ``tensorboardX.SummaryWriter``. When given, each
            batch loss is logged via ``add_scalar("train/loss", ...)`` with a
            global step counter that starts at 1.

    Returns:
        None.
    """
    _seed_everything(config.get("seed", SEED))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TODO: set up augmentations.

    mit_data = data_loader.get_loader("mit_sceneparsing_benchmark")
    train_data_raw = mit_data("training")
    train_data = data.DataLoader(
        train_data_raw,
        batch_size=config["batch_sz"],
        num_workers=config["num_workers"],
        shuffle=True,
    )
    # Built up front but not consumed yet.
    # TODO: run a validation pass and compute metrics with this loader after each epoch.
    val_data = data.DataLoader(
        mit_data("validation"),
        batch_size=config["batch_sz"],
        num_workers=config["num_workers"],
    )

    model = models.get_model("fcn32s", train_data_raw.n_classes).to(device)
    # For multi-GPU: torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    optimizer = get_optimizer("adam")(model.parameters())
    print(f"Optimizer: {optimizer}")

    # TODO: learning-rate scheduler (torch.optim.lr_scheduler).
    loss_fn = CrossEntropy2d()

    step = 0
    for epoch in range(1, config["epochs"] + 1):
        print(f"Epoch {epoch}")
        # Training mode only needs setting once per epoch, not on every batch.
        model.train()
        for images, labels in train_data:
            start_ts = time.time()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            out = model(images)
            loss = loss_fn(out, labels)
            loss.backward()
            optimizer.step()

            # .item() yields a plain float detached from the autograd graph;
            # on CUDA it also synchronises, so the timing below covers the step.
            loss_value = loss.item()
            step += 1
            if writer is not None:
                writer.add_scalar("train/loss", loss_value, step)
            print(f"Batch Loss : {loss_value} ... time : {time.time() - start_ts}")
            
            
# Run settings consumed by train(); key order matches the original literal.
config = dict(
    batch_sz=8,
    epochs=10,
    seed=42,
    num_workers=1,
)

train(config)

Found 20209 training images
Found 2000 validation images
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
Epoch 1
Batch Loss : 5.0217132568359375 ... time : 1.6245520114898682
Batch Loss : 4.0412678718566895 ... time : 1.706756830215454
Batch Loss : 4.938187122344971 ... time : 1.703571081161499
Batch Loss : 4.866754055023193 ... time : 1.720219612121582
Batch Loss : 4.109326362609863 ... time : 1.6966876983642578
Batch Loss : 7.257319450378418 ... time : 1.6827590465545654
Batch Loss : 4.650835990905762 ... time : 1.7122094631195068
Batch Loss : 4.490235805511475 ... time : 1.7124762535095215
Batch Loss : 3.7909131050109863 ... time : 1.7008717060089111
Batch Loss : 7.016555309295654 ... time : 1.6941075325012207
Batch Loss : 3.822023391723633 ... time : 1.7022602558135986
Batch Loss : 4.353139877319336 ... time : 1.7137632369995117
Batch Loss : 4.478705406188965 ... time : 1.7235069274902344
Batch Loss :

In [None]:
# Scratch cell: echo a fixed marker string followed by a newline.
print("foo", end="\n")