In [2]:
import sys
sys.path.append('..')

import torch
from torch import nn
from torch.utils.data import DataLoader
import os
from tqdm.auto import tqdm, trange
from collections import OrderedDict
from PIL import Image
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt

from utils.dataset import SegmentationDataset, transforms as T
from utils.vis import show_torch_batch
from models import deeplabv3

In [3]:
# Run configuration: tune these values before executing the cells below.
num_classes = 7   # passed to deeplabv3.create_model
num_epochs = 3    # full passes over the training loader
batch_size = 4    # samples per DataLoader batch
disable_cuda = False  # set True to force CPU even when a GPU is present

# Use the GPU only when it is both permitted and actually available.
use_cuda = (not disable_cuda) and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

In [4]:
# Training data is read from the Calvert_2012 directory with the train-time
# transforms; validation data comes from Calvert_2015 with the test-time ones.
ds_train = SegmentationDataset("../data/datasets/Calvert_2012", 
                         transform=T.train_transforms, 
                         target_transform=T.train_target_transforms)

ds_val = SegmentationDataset("../data/datasets/Calvert_2015", 
                         transform=T.test_transforms, 
                         target_transform=T.test_target_transforms)

# Options shared by both loaders. `drop_last` keeps every training batch at
# exactly `batch_size` samples.
dataloader_opts = {
    "batch_size": batch_size, 
    "pin_memory": True, 
    "drop_last": True,
    "num_workers": 8
}
dataloaders = {
    'train': DataLoader(ds_train, shuffle=True, **dataloader_opts),
    # Evaluation must cover every validation sample, so the final partial
    # batch is kept here instead of being dropped.
    'eval': DataLoader(ds_val, shuffle=False,
                       **{**dataloader_opts, "drop_last": False})
}

In [7]:
# Build the network and move it to the selected device before the optimizer
# is created, so the optimizer tracks the on-device parameters.
model = deeplabv3.create_model(num_classes).to(device)

# Loss applied to the model's 'out' entry against the target batch.
criterion = nn.CrossEntropyLoss()

# Adam is the active optimizer; the SGD variant is kept for reference.
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [8]:
def overfit_model_to_single_batch(model, x, y, optimizer, criterion, num_epochs):
    """Fit ``model`` repeatedly to one fixed batch.

    The same ``(x, y)`` pair is used for ``num_epochs`` optimisation steps.
    The current loss is shown in the progress bar and the batch is
    visualised with ``show_torch_batch`` on every 10th step.

    Args:
        model: Network whose forward pass returns a mapping with an ``'out'`` entry.
        x: Input batch; moved to the notebook-level ``device``.
        y: Target batch; moved to the notebook-level ``device``.
        optimizer: Optimizer stepped once per iteration.
        criterion: Callable invoked as ``criterion(out['out'], y)``.
        num_epochs: Number of optimisation steps to run.

    Returns:
        The ``model`` that was passed in, updated in place.
    """
    postfix = OrderedDict()

    inputs = x.to(device)
    targets = y.to(device)
    model.train()

    progress = trange(1, num_epochs+1)
    for step in progress:
        optimizer.zero_grad()

        prediction = model(inputs)
        batch_loss = criterion(prediction['out'], targets)

        batch_loss.backward()
        optimizer.step()

        postfix['loss'] = batch_loss.detach().cpu().item()
        progress.set_postfix(postfix)

        # Only visualise on every 10th step.
        if step % 10 != 0:
            continue
        show_torch_batch(inputs, targets, prediction['out'])

    return model
    
# model = overfit_model_to_single_batch(model, *next(iter(dataloaders['train'])), optimizer, criterion, 100)

In [9]:
def train_model(model, dataloaders, optimizer, criterion, num_epochs, metrics, save_path):
    """Train ``model`` for ``num_epochs`` passes over ``dataloaders['train']``.

    Args:
        model: Network whose forward pass returns a mapping with an ``'out'``
            entry; that entry is what is handed to ``criterion``.
        dataloaders: Mapping from phase name to an iterable of ``(x, y)``
            batches. Only the ``'train'`` phase is currently run.
        optimizer: Optimizer stepped once per training batch.
        criterion: Callable invoked as ``criterion(out['out'], y)``; its
            result must support ``.backward()``.
        num_epochs: Number of epochs to run.
        metrics: Currently unused; accepted so existing callers keep working.
        save_path: Destination for ``model.state_dict()``, written at the end
            of every epoch so a long run is not lost if it is interrupted.
            Pass ``None`` to skip checkpointing.

    Returns:
        The ``model`` that was passed in, updated in place.

    Note:
        Batches are moved to the notebook-level ``device``.
    """
    info = OrderedDict()

    for epoch in range(1, num_epochs+1):
        print('Epoch {}/{}'.format(epoch, num_epochs))

        for phase in ['train']:  # 'eval' is currently disabled
            is_train = phase == 'train'

            # Switch train/eval mode once per phase rather than once per batch.
            model.train(is_train)

            pbar = tqdm(dataloaders[phase], desc=phase)
            for i, (x, y) in enumerate(pbar):
                x = x.to(device)
                y = y.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Only build the autograd graph when a backward pass follows.
                with torch.set_grad_enabled(is_train):
                    out = model(x)
                    loss = criterion(out['out'], y)

                if is_train:
                    loss.backward()
                    optimizer.step()

                # Visualise the current batch every 1000 iterations.
                if (i+1) % 1000 == 0:
                    show_torch_batch(x, y, out['out'])

                info['loss'] = loss.item()
                pbar.set_postfix(info)

        # Checkpoint after each epoch; previously `save_path` was ignored.
        if save_path is not None:
            torch.save(model.state_dict(), save_path)

    return model

In [None]:
# Launch training; a no-op callable is supplied for `metrics`.
model = train_model(
    model,
    dataloaders,
    optimizer,
    criterion,
    num_epochs=num_epochs,
    metrics=lambda a: None,
    save_path='./model.pth',
)

Epoch 1/3


HBox(children=(FloatProgress(value=0.0, description='train', max=4406.0, style=ProgressStyle(description_width…

In [None]:
# TODO:
# - Add evaluation metrics
# - Train on a combined, larger dataset
# - Checkpoint model weights