In [1]:
%%capture
!pip install neptune-client psutil
!git clone https://github.com/Cho-D-YoungRae/URP_PD.git
%cd URP_PD
!pwd

In [2]:
import dataset
import object_detection
from utils import *

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as FT

import os
import json
from PIL import Image
import numpy as np
import argparse
from tqdm.auto import tqdm
import time
from datetime import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using <{device}> device")

Using <cuda> device


In [3]:
# ====== constants ======#
label_map = {'background': 0, 'person': 1}
rev_label_map = {v: k for k, v in label_map.items()} 

## setting

In [4]:
from torch.backends import cudnn
cudnn.benchmark = True

# ====== Random Seed Initialization ====== #
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

parser = argparse.ArgumentParser()
args = parser.parse_args("")
args.baselineID = 5

# ====== Dataset ====== #
args.img_type = 'lwir'
args.val_split = 0.1

# ====== Model ====== #
args.base_model = 'VGGBase'
args.n_classes = len(label_map)
args.one_ch_option = "mean"


# ====== Optimizer & Training ====== #
args.optim = 'SGD' 
args.lr = 5e-4
args.twice_b_lr = True
args.momentum = 0.9
args.weight_decay = 5e-4

args.epochs = 150
args.train_batch_size = 32
args.test_batch_size = 64

args.decay_lr_at = [int(args.epochs/6)*4,
                    int(args.epochs/6)*5]
args.decay_lr_to = 0.1

## neptune init

In [5]:
import neptune.new as neptune

api_token = 
run = neptune.init(project='jodyr/urp',
                #    run=''
                   api_token=api_token)

run["parameters"] = vars(args)


https://app.neptune.ai/jodyr/urp/e/PD-21
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


## train

In [6]:
def train(train_loader, model, criterion, optimizer):
    """
    One epoch's training.

    :param train_loader: DataLoader for training data
    :param model: model
    :param criterion: MultiBox loss
    :param optimizer: optimizer
    :param epoch: epoch number
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.train()  # training mode enables dropout

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    print_freq = len(train_loader) // 4

    start = time.time()

    # Batches
    for i, (images, bboxes, category_ids, is_crowds) in enumerate(train_loader):
        data_time.update(time.time() - start)

        images = images.to(device)  # (batch_size (N), 3, 300, 300)
        bboxes = [b.to(device) for b in bboxes]
        category_ids = [c.to(device) for c in category_ids]

        # Forward prop.
        predicted_locs, predicted_scores = model(images)  # (N, 8732, 4), (N, 8732, n_classes)

        # Loss
        loss = criterion(predicted_locs, predicted_scores, bboxes, category_ids)  # scalar

        # Backward prop.
        optimizer.zero_grad()
        loss.backward()

        # Update model
        optimizer.step()

        losses.update(loss.item(), images.size(0))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print(f'[{i}/{len(train_loader)}]\t'
                  f'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  f'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  f'Loss {losses.val:.4f} ({losses.avg:.4f})\t')
    del predicted_locs, predicted_scores, images, bboxes, category_ids

    train_loss = losses.avg
    return train_loss        

## validation

In [7]:
def validation(val_loader, model, criterion):
    model.eval()

    num_batches = len(val_loader)
    losses = AverageMeter()
    with torch.no_grad():
        for i, (images, bboxes, category_ids, _) in enumerate(val_loader):
            images = images.to(device)
            bboxes = [b.to(device) for b in bboxes]
            category_ids = [l.to(device) for l in category_ids]

            predicted_locs, predicted_scores = model(images)
            loss = criterion(predicted_locs, predicted_scores, bboxes, category_ids).item()

            losses.update(loss, images.size(0))

    val_loss = losses.avg
    return val_loss

## checkpoint

In [8]:
checkpoint = os.path.join('/content/drive/MyDrive/2021.summer_URP/PD/checkpoint',
                          str(args.baselineID)+'.pth.tar')
checkpoint = checkpoint if os.path.isfile(checkpoint) else None
print(f"checkpoint: {checkpoint}")

checkpoint: None


In [9]:
if checkpoint is None:
    start_epoch = 1
    lr = args.lr
    model = object_detection.SSD300(n_classes=args.n_classes,
                                    base=args.base_model,
                                    one_ch_option=args.one_ch_option)
    if args.twice_b_lr:
        biases = list()
        not_biases = list()
        for param_name, param in model.named_parameters():
            if param.requires_grad:
                if param_name.endswith('.bias'):
                    biases.append(param)
                else:
                    not_biases.append(param)
        optimizer = getattr(torch.optim, args.optim)(params=[{'params': biases, 'lr': 2 * lr}, 
                                                            {'params': not_biases}],
                                                     lr=lr,
                                                     momentum=args.momentum,
                                                     weight_decay=args.weight_decay)
    else:
        optimizer = getattr(torch.optim, args.optim)(params=model.parameters(),
                                                     lr=lr,
                                                     momentum=args.momentum,
                                                     weight_decay=args.weight_decay)

else:
    checkpoint = torch.load(checkpoint)
    start_epoch = checkpoint['epoch'] + 1
    print('\nLoaded checkpoint from epoch %d.\n' % start_epoch)
    model = checkpoint['model']
    optimizer = checkpoint['optimizer']


model = model.to(device)
criterion = object_detection.MultiBoxLoss(priors_cxcy=model.priors_cxcy).to(device)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))



Loaded base model.





## dataset init

In [10]:
train_dataset = dataset.KaistPDDataset()

total_train_size = len(train_dataset)
val_size = int(total_train_size * args.val_split)
train_val_split = [total_train_size - val_size, val_size]
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, 
                                                           train_val_split)

In [11]:
workers = 4
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=args.train_batch_size, 
                                           shuffle=True,
                                           collate_fn=dataset.collate_fn,
                                           num_workers=workers,
                                           pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, 
                                         batch_size=args.test_batch_size, 
                                         shuffle=True,
                                         collate_fn=dataset.collate_fn,
                                         num_workers=workers,
                                         pin_memory=True)

  cpuset_checked))


In [12]:
checkpoint_dir = '/content/drive/MyDrive/2021.summer_URP/PD/checkpoint'
checkpoint_path = os.path.join(checkpoint_dir,
                               str(args.baselineID)+'.pth.tar')

epochs = args.epochs
decay_lr_at = args.decay_lr_at
val_save_freq = 5


# Epochs
for epoch in range(start_epoch, epochs+1):
    print(f"# ====== Epoch {epoch} ====== # {datetime.now()}")
    # Decay learning rate at particular epochs
    if epoch in decay_lr_at:
        adjust_learning_rate(optimizer, args.decay_lr_to)

    # One epoch's training
    train_loss = train(train_loader=train_loader,
                        model=model,
                        criterion=criterion,
                        optimizer=optimizer)
    run['train/loss'].log(train_loss)
    if epoch % val_save_freq == 0:
        val_loss = validation(val_loader, model, criterion)
        print(f"*** checkpoint train loss: {train_loss}, val loss: {val_loss}***")
        run['val/loss'].log(val_loss)
        save_checkpoint(epoch, model, optimizer, checkpoint_path)



  cpuset_checked))
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[0/246]	Batch Time 107.794 (107.794)	Data Time 102.118 (102.118)	Loss 17.7478 (17.7478)	
[61/246]	Batch Time 5.560 (9.930)	Data Time 5.031 (9.312)	Loss 8.3241 (8.3509)	
[122/246]	Batch Time 0.531 (9.066)	Data Time 0.000 (8.488)	Loss 9.9937 (9.1169)	
[183/246]	Batch Time 6.327 (8.782)	Data Time 5.793 (8.219)	Loss 8.6224 (9.5157)	
[244/246]	Batch Time 21.961 (8.528)	Data Time 21.417 (7.972)	Loss 25.5514 (9.7013)	
*** checkpoint train loss: 9.726133549988717, val loss: 18.46685275616378***
[0/246]	Batch Time 4.036 (4.036)	Data Time 3.363 (3.363)	Loss 18.5879 (18.5879)	
[61/246]	Batch Time 0.609 (0.997)	Data Time 0.004 (0.376)	Loss 6.1693 (11.5769)	
[122/246]	Batch Time 0.593 (0.967)	Data Time 0.001 (0.348)	Loss 5.2752 (9.4917)	
[183/246]	Batch Time 0.639 (0.944)	Data Time 0.001 (0.327)	Loss 7.4359 (8.2861)	
[244/246]	Batch Time 0.530 (0.935)	Data Time 0.000 (0.320)	Loss 4.2154 (7.4895)	
*** checkpoint train loss: 7.481080447048028, val loss: 4.363192404258702***
[0/246]	Batch Time 3.997 (

Experiencing connection interruptions. Will try to reestablish communication with Neptune.
Communication with Neptune restored!


[122/246]	Batch Time 0.585 (0.951)	Data Time 0.005 (0.334)	Loss 2.1860 (2.1412)	
[183/246]	Batch Time 0.612 (0.936)	Data Time 0.001 (0.318)	Loss 2.0764 (2.1319)	
[244/246]	Batch Time 0.533 (0.924)	Data Time 0.000 (0.309)	Loss 2.1238 (2.1323)	
*** checkpoint train loss: 2.1322176504699546, val loss: 2.140442431997188***
[0/246]	Batch Time 4.515 (4.515)	Data Time 3.884 (3.884)	Loss 2.4240 (2.4240)	
[61/246]	Batch Time 0.638 (0.995)	Data Time 0.003 (0.374)	Loss 2.0012 (2.1442)	
[122/246]	Batch Time 1.556 (0.967)	Data Time 0.956 (0.348)	Loss 2.3013 (2.1303)	
[183/246]	Batch Time 1.389 (0.953)	Data Time 0.763 (0.334)	Loss 2.0761 (2.1211)	
[244/246]	Batch Time 0.532 (0.940)	Data Time 0.000 (0.323)	Loss 2.5366 (2.1101)	
*** checkpoint train loss: 2.1115285854417283, val loss: 2.1489962097705435***
[0/246]	Batch Time 4.081 (4.081)	Data Time 3.468 (3.468)	Loss 1.9053 (1.9053)	
[61/246]	Batch Time 0.594 (0.981)	Data Time 0.002 (0.364)	Loss 2.2221 (2.1403)	
[122/246]	Batch Time 0.612 (0.953)	Data

Experiencing connection interruptions. Will try to reestablish communication with Neptune.
Communication with Neptune restored!


[183/246]	Batch Time 0.618 (0.946)	Data Time 0.002 (0.324)	Loss 2.2822 (2.1241)	
[244/246]	Batch Time 0.656 (0.931)	Data Time 0.123 (0.311)	Loss 2.1599 (2.1185)	
*** checkpoint train loss: 2.118068502166882, val loss: 2.150101358811334***
[0/246]	Batch Time 4.098 (4.098)	Data Time 3.467 (3.467)	Loss 2.0840 (2.0840)	
[61/246]	Batch Time 0.632 (1.006)	Data Time 0.013 (0.385)	Loss 2.1361 (2.1398)	
[122/246]	Batch Time 1.144 (0.965)	Data Time 0.510 (0.344)	Loss 2.0046 (2.1221)	
[183/246]	Batch Time 0.634 (0.945)	Data Time 0.005 (0.322)	Loss 2.3351 (2.1223)	
[244/246]	Batch Time 0.532 (0.936)	Data Time 0.000 (0.315)	Loss 2.1788 (2.1197)	


Experiencing connection interruptions. Will try to reestablish communication with Neptune.
Communication with Neptune restored!


*** checkpoint train loss: 2.1204028900030445, val loss: 2.1579742699281184***
[0/246]	Batch Time 4.886 (4.886)	Data Time 4.261 (4.261)	Loss 1.8283 (1.8283)	
[61/246]	Batch Time 0.818 (0.965)	Data Time 0.201 (0.344)	Loss 2.2097 (2.1018)	
[122/246]	Batch Time 0.655 (0.948)	Data Time 0.001 (0.328)	Loss 1.9459 (2.1121)	
[183/246]	Batch Time 0.606 (0.937)	Data Time 0.001 (0.318)	Loss 2.0578 (2.1041)	
[244/246]	Batch Time 0.534 (0.927)	Data Time 0.000 (0.310)	Loss 2.4669 (2.1110)	
*** checkpoint train loss: 2.112324381778907, val loss: 2.1620938117015265***
[0/246]	Batch Time 3.940 (3.940)	Data Time 3.287 (3.287)	Loss 2.1734 (2.1734)	
[61/246]	Batch Time 0.643 (0.997)	Data Time 0.004 (0.373)	Loss 2.2252 (2.1212)	
[122/246]	Batch Time 2.170 (0.957)	Data Time 1.569 (0.333)	Loss 1.8349 (2.1108)	
[183/246]	Batch Time 0.651 (0.949)	Data Time 0.010 (0.328)	Loss 2.1855 (2.1013)	
[244/246]	Batch Time 0.533 (0.936)	Data Time 0.000 (0.315)	Loss 2.0210 (2.1128)	
*** checkpoint train loss: 2.1125309233