In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import numpy as np

from model_utils import lr_scheduler, restore_weights, save_model
from data_generator import SemanticSegmentationDataset
from train import build_model, train_loop
from evaluate import evaluate, show_sample_segmentation

from pathlib import Path

### Define parameters

In [21]:
# Notebook-wide configuration. Every value below is consumed by a later cell
# (DataLoader construction, build_model, or train_loop).

# --- batching (passed to both DataLoaders) ---
batch_size = 1
shuffle = True

# --- training schedule (all forwarded to train_loop) ---
epoch_start = 0
epochs = 100
T_save = 1      # presumably the checkpoint interval in epochs -- confirm in train.train_loop
T_print = 100   # presumably the logging interval in minibatches -- confirm in train.train_loop

# --- model ---
num_classes = 4  # passed to build_model as n_classes

# Use the GPU only when one is actually present; a hard-coded True makes the
# notebook unusable on CPU-only machines. On a CUDA machine this is still True.
cuda = torch.cuda.is_available()

# --- files and directories ---
pretrained_weight_fname = 'resnet_v2-300epoch.pth'  # passed to build_model together with weights_dir
weights_dir = 'weights/'
data_dir = 'drinks/'
train_gt_fname = 'segmentation_train.npy'
test_gt_fname = 'segmentation_test.npy'

# Create the weights directory if it does not exist yet
# (parents are created as needed; an existing directory is not an error).
Path(weights_dir).mkdir(parents=True, exist_ok=True)

### Initialize dataloaders

In [12]:
# Options shared by the train and test loaders (both come from the parameter cell).
loader_options = {'batch_size': batch_size, 'shuffle': shuffle}

# Training split: dataset built from the training ground-truth file in data_dir.
trainset = SemanticSegmentationDataset(data_dir, train_gt_fname, cuda=cuda)
trainloader = DataLoader(trainset, **loader_options)

# Test split: same dataset class, pointed at the test ground-truth file.
testset = SemanticSegmentationDataset(data_dir, test_gt_fname, cuda=cuda)
testloader = DataLoader(testset, **loader_options)

## Training

### Initialize model

In [19]:
# Describe the input the model is built for: (batch, channels, height, width).
channels = 3
height = 480
width = 640
input_shape = (batch_size, channels, height, width)

# Everything build_model needs, gathered in one place.
# weights_dir / pretrained_weight_fname point at the checkpoint file;
# presumably build_model restores it (see the cell output) -- confirm in train.build_model.
build_options = {
    'input_shape': input_shape,
    'n_classes': num_classes,
    'weights_dir': weights_dir,
    'pretrained_weight_fname': pretrained_weight_fname,
    'cuda': cuda,
}

# build_model returns two objects, unpacked here as the model and its backbone.
model, backbone = build_model(**build_options)

Restoring weights from weights/resnet_v2-100epoch.pth


### Initialize training parameters

In [None]:
# Loss, optimizer and learning-rate schedule.
# Setup follows https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
base_lr = 1e-3

# Categorical cross-entropy loss.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters, starting from base_lr.
optimizer = optim.Adam(params=model.parameters(), lr=base_lr)

# LambdaLR scales the optimizer's initial learning rate by the factor returned
# from model_utils.lr_scheduler.
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                        lr_lambda=lr_scheduler)

### Train loop

In [None]:
# Schedule and reporting options, all taken from the parameter cell.
loop_options = {
    'epochs': epochs,
    'epoch_start': epoch_start,
    'T_print': T_print,
    'T_save': T_save,
}

# Run training on the training loader with the loss, optimizer and scheduler
# set up in the previous cell.
train_loop(model, criterion, optimizer, scheduler, trainloader, **loop_options)

Learning rate: 0.001




200-th minibatch	loss: 0.024956	68.73466062545776 secs elapsed
400-th minibatch	loss: 0.048901	137.26825499534607 secs elapsed
600-th minibatch	loss: 0.074410	205.35577869415283 secs elapsed
800-th minibatch	loss: 0.099458	273.7226142883301 secs elapsed
1000-th minibatch	loss: 0.123409	342.10432481765747 secs elapsed
Learning rate: 0.001
epoch 201	loss: 24.681771
342.1044669151306 secs elapsed
200-th minibatch	loss: 0.024225	410.50092220306396 secs elapsed
400-th minibatch	loss: 0.048879	479.0910303592682 secs elapsed
600-th minibatch	loss: 0.072854	547.5881795883179 secs elapsed
800-th minibatch	loss: 0.096035	615.9639573097229 secs elapsed
1000-th minibatch	loss: 0.121564	684.2775404453278 secs elapsed
Learning rate: 0.001
epoch 202	loss: 24.312790
684.277651309967 secs elapsed
200-th minibatch	loss: 0.024069	752.5967044830322 secs elapsed
400-th minibatch	loss: 0.047832	820.9738640785217 secs elapsed
600-th minibatch	loss: 0.072155	889.3108744621277 secs elapsed
800-th minibatch	los

800-th minibatch	loss: 0.077374	7097.65584397316 secs elapsed
1000-th minibatch	loss: 0.096172	7166.02033162117 secs elapsed
Learning rate: 0.001
epoch 221	loss: 19.234474
7166.020485877991 secs elapsed
200-th minibatch	loss: 0.019481	7234.3568546772 secs elapsed
400-th minibatch	loss: 0.038412	7302.711222887039 secs elapsed
600-th minibatch	loss: 0.056933	7371.055485010147 secs elapsed
800-th minibatch	loss: 0.076528	7439.425366163254 secs elapsed
1000-th minibatch	loss: 0.095072	7507.713925361633 secs elapsed
Learning rate: 0.001
epoch 222	loss: 19.014339
7507.714034318924 secs elapsed
200-th minibatch	loss: 0.018554	7576.050037622452 secs elapsed
400-th minibatch	loss: 0.037297	7644.345289707184 secs elapsed
600-th minibatch	loss: 0.055677	7712.707197189331 secs elapsed
800-th minibatch	loss: 0.075159	7781.027522087097 secs elapsed
1000-th minibatch	loss: 0.094148	7849.375440597534 secs elapsed
Learning rate: 0.001
epoch 223	loss: 18.829624
7849.375777721405 secs elapsed
200-th mini

200-th minibatch	loss: 0.016572	14058.159142255783 secs elapsed
400-th minibatch	loss: 0.031774	14126.437273025513 secs elapsed
600-th minibatch	loss: 0.046592	14194.74758720398 secs elapsed
800-th minibatch	loss: 0.062632	14263.05306816101 secs elapsed
1000-th minibatch	loss: 0.078743	14331.32479929924 secs elapsed
Learning rate: 0.0005
epoch 242	loss: 15.748660
14331.324951171875 secs elapsed
200-th minibatch	loss: 0.015500	14398.949434757233 secs elapsed
400-th minibatch	loss: 0.031329	14466.461828947067 secs elapsed
600-th minibatch	loss: 0.047267	14533.978209257126 secs elapsed
800-th minibatch	loss: 0.062539	14601.732395887375 secs elapsed
1000-th minibatch	loss: 0.077904	14670.064280748367 secs elapsed
Learning rate: 0.0005
epoch 243	loss: 15.580856
14670.064632415771 secs elapsed
200-th minibatch	loss: 0.014889	14738.386830806732 secs elapsed
400-th minibatch	loss: 0.029711	14806.662728786469 secs elapsed
600-th minibatch	loss: 0.045428	14874.986599206924 secs elapsed
800-th mi

In [0]:
# Score the model on the test loader. evaluate returns two values, unpacked as
# m_iou and m_pla (presumably mean IoU and mean pixel-level accuracy -- confirm
# in evaluate.evaluate).
m_iou, m_pla = evaluate(model, testloader)

# Show the metrics: without this the cell computes them but the notebook never
# displays the result.
print(f'm_iou: {m_iou}\tm_pla: {m_pla}')

Saving weights to drive/My Drive/coe197f/weights/resnet_v2-4th-epoch.pth


In [None]:
# Visual check: hand the model and the test loader to
# evaluate.show_sample_segmentation. Because testloader is built with
# shuffle=True, the batches it yields are in random order on each run.
show_sample_segmentation(model, testloader)