# Example training notebook file

In [2]:
# Make the repository root importable from this notebook.
import os
import sys
import torch

# The notebook is expected to be run from its own folder, one level below the
# repository root. Resolve the parent directory to an absolute path so the
# entry keeps working if the working directory changes later, and add it only
# once so re-running this cell does not pile up duplicate sys.path entries.
root_dir = os.path.abspath(os.path.dirname("../"))
if root_dir not in sys.path:
    sys.path.append(root_dir)

## Loading Figaro dataset using get_loader

In [3]:
# Data loader factory
from data import get_loader

# Transforms have to be defined up front; they are applied to the images and
# masks while the dataset is being loaded. The ones below are examples.
# Joint (image + mask) transforms are project-local; image-only and mask-only
# transforms come from torchvision.
from utils import joint_transforms as jnt_trnsf
import torchvision.transforms as std_trnsf


# Geometric changes (resize, rotate, crop, flip) must hit the image and its
# mask identically, so they go through the joint transforms implemented in
# ./utils/joint_transforms.py
joint_steps = [
    jnt_trnsf.Resize(256),
    jnt_trnsf.RandomRotate(5),
    jnt_trnsf.CenterCrop(224),
    jnt_trnsf.RandomHorizontallyFlip(),
]
joint_transforms = jnt_trnsf.Compose(joint_steps)


# Per-channel normalisation constants shared by the train and test pipelines
# (these match the ImageNet statistics quoted in the torchvision docs)
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Image-only transforms: colour jittering, normalising, blurring, etc.
# Use torchvision.transforms, or implement additional transforms in 'utils'.
train_image_transforms = std_trnsf.Compose([
    std_trnsf.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    std_trnsf.ToTensor(),
    std_trnsf.Normalize(norm_mean, norm_std),
])

# Same as the training pipeline minus the colour jitter
test_image_transforms = std_trnsf.Compose([
    std_trnsf.ToTensor(),
    std_trnsf.Normalize(norm_mean, norm_std),
])

# Mask-only transforms: just convert the mask to a tensor
mask_transforms = std_trnsf.Compose([std_trnsf.ToTensor()])

# Remaining loader arguments
batch_size = 4
num_workers = 1
data_dir = '../data/Figaro1k/'

In [4]:
# Evaluation should see the same geometry every epoch: reuse the resize and
# centre crop from the training pipeline but drop the random rotation / flip,
# otherwise validation metrics fluctuate with the augmentation.
# NOTE(review): assumes Resize and CenterCrop in utils/joint_transforms.py are
# deterministic - confirm.
test_joint_transforms = jnt_trnsf.Compose([
    jnt_trnsf.Resize(256),
    jnt_trnsf.CenterCrop(224)
])

# Training loader: shuffle so that each epoch visits the samples in a
# different order instead of always producing the same batches.
train_loader = get_loader(dataset='figaro',
                          data_dir=data_dir,
                          train=True,
                          joint_transforms=joint_transforms,
                          image_transforms=train_image_transforms,
                          mask_transforms=mask_transforms,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers)

# Test (validation) loader: one image per batch, fixed order, no random
# augmentation.
test_loader = get_loader(dataset='figaro',
                         data_dir=data_dir,
                         train=False,
                         joint_transforms=test_joint_transforms,
                         image_transforms=test_image_transforms,
                         mask_transforms=mask_transforms,
                         batch_size=1,
                         shuffle=False,
                         num_workers=num_workers)

In [5]:
# There are two ways to pull batches out of a dataloader.

# 1. With a for loop: grab the first batch, then stop.
for step, batch in enumerate(train_loader):
    data, target = batch
    break

# Show the step index and the shapes of the image / mask batches.
step, data.shape, target.shape


(0, torch.Size([4, 3, 224, 224]), torch.Size([4, 1, 224, 224]))

In [6]:
# 2. using iterator
batch_iterator = iter(train_loader)

# Pull ten batches with the built-in next(): iterators are only required to
# implement __next__, and the Python 2 style `.next()` method is not available
# on DataLoader iterators in current PyTorch releases.
for _ in range(10):
    data, target = next(batch_iterator)

# Shapes of the last image / mask batch that was fetched.
data.size(), target.size()

(torch.Size([4, 3, 224, 224]), torch.Size([4, 1, 224, 224]))

Process Process-2:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/opt/conda/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._p

## Importing model

In [7]:
# Option A: build the model through the factory. To make your own
# architecture available this way, add it to the get_network function in
# ./networks/__init__.py
from networks import get_network
model = get_network(name='SegNet', num_class=1)

# Option B: import the class and instantiate it directly
# (this assignment replaces the model created above)
from networks.segnet import SegNet
model = SegNet(num_class=1)

## Defining Optimizer & Scheduler & loss & device

In [8]:
# Optimizer (torch.optim)
adam_betas = (0.5, 0.999)  # beta1 acts like 'momentum' in SGD
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=adam_betas)

# Learning-rate scheduler (torch.optim.lr_scheduler) with its default settings
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# Loss function (torch.nn)
loss = torch.nn.BCEWithLogitsLoss()

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Using Pytorch Ignite

In [9]:
# ignite modules
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss

# custom modules
from utils.metrics import Accuracy, MeanIU

In [10]:
# trainer and evaluator
trainer = create_supervised_trainer(model, optimizer, loss, device=device)
evaluator = create_supervised_evaluator(model,
                                        metrics={
                                            'pix-acc': Accuracy(),
                                            'mean-iu': MeanIU(0.5),
                                            'loss': Loss(loss)
                                            },
                                        device=device)

In [11]:
# saving training state if you want
from utils import update_state
state = update_state(model.state_dict(), 0, 0, 0, 0)

In [None]:
# Checkpoint settings used by the validation handler below.
# NOTE(review): the original cell used `ckpt_path`, `networks` and
# `save_ckpt_file` without defining or importing them anywhere in this
# notebook, so the first completed epoch failed with NameError. The values
# here are placeholders - adjust them to your own layout.
networks = 'segnet'
ckpt_dir = '../ckpt'
ckpt_path = os.path.join(ckpt_dir, '{network}_epoch_{epoch}.pth')


@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
    """Print the training loss every 20 iterations of the current epoch."""
    # Fold the running iteration counter back to a 1-based position inside
    # the current epoch.
    num_iter = (trainer.state.iteration - 1) % len(train_loader) + 1
    if num_iter % 20 == 0:
        print("Epoch[{}] Iter[{:03d}] Loss: {:.2f}".format(
            trainer.state.epoch, num_iter, trainer.state.output))


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    """Evaluate on the training set and record the training loss in `state`."""
    global state

    # evaluate training set
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {}  Pix-acc: {:.3f} MeanIU: {:.3f} Avg-loss: {:.3f}".format(
        trainer.state.epoch, metrics['pix-acc'], metrics['mean-iu'], metrics['loss']))

    # update_state returns the state (that is how `state` is first created),
    # so its result has to be assigned; discarding it left `state` unchanged.
    # Only the weights and training loss change here; validation values are
    # carried over from the previous state.
    state = update_state(model.state_dict(), metrics['loss'],
                         state['val_loss'], state['val_pix_acc'], state['val_miu'])


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    """Evaluate on the test set, step the LR scheduler and save a checkpoint."""
    global state

    # evaluate test(validation) set
    evaluator.run(test_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {}  Pix-acc: {:.2f} MeanIU: {:.3f} Avg-loss: {:.2f}".format(
        trainer.state.epoch, metrics['pix-acc'], metrics['mean-iu'], metrics['loss']))

    # ReduceLROnPlateau needs the monitored value: use the validation loss
    scheduler.step(metrics['loss'])

    # Keep the training loss recorded by log_training_results (registered
    # first, so it has already run for this epoch) and refresh the rest.
    state = update_state(model.state_dict(), state['train_loss'],
                         metrics['loss'], metrics['pix-acc'], metrics['mean-iu'])

    # Save one checkpoint per epoch, creating the target directory on demand.
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(
        state,
        ckpt_path.format(network=networks, epoch=trainer.state.epoch))


trainer.run(train_loader, max_epochs=100)

Process Process-3:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/opt/conda/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._p

## To do all of this in one go

```bash
# run this from the repository root
# Command-line counterpart of the notebook above: same network, scheduler,
# batch size, epoch count, learning rate and optimizer.
# NOTE(review): --num_workers is 4 here while the notebook uses 1, and
# --momentum 0.5 presumably maps to Adam's beta1 (cf. betas=(0.5, 0.999) in
# the optimizer cell) - confirm against main.py.

python3 main.py \
  --networks segnet \
  --scheduler ReduceLROnPlateau \
  --batch_size 4 \
  --epochs 100 \
  --lr 1e-3 \
  --num_workers 4 \
  --optimizer adam \
  --momentum 0.5
```