In [1]:
%%capture
!pip install neptune-client psutil
!git clone https://github.com/Cho-D-YoungRae/URP_PD.git
%cd URP_PD

In [2]:
import dataset
import object_detection
from utils import *
import eval

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF
import torch.nn.functional as F

import os
import json
import numpy as np
import argparse
from tqdm.auto import tqdm
import time
from datetime import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using <{device}> device")

loading annotations into memory...
Done (t=0.08s)
creating index...
index created!
Using <cpu> device


In [3]:
# ====== constants ======#
label_map = {'background': 0, 'person': 1}
rev_label_map = {v: k for k, v in label_map.items()} 

## setting

In [4]:
from torch.backends import cudnn
import random

cudnn.benchmark = True

# ====== Random Seed Initialization ====== #
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

parser = argparse.ArgumentParser()
args = parser.parse_args("")
args.baselineID = 37

# ====== Dataset ====== #
args.ch_option = {'num_ch': 1,
                  'img_type': 'lwir',
                  'one_ch_option': 'mean'}
args.val_split = 0.1

# ====== Model ====== #
args.base_model = 'VGG16bnBase'
args.n_classes = len(label_map)
args.is_sds = True
args.usages_seg_feats = [True, True, False, False, False, False]
args.get_separate_loss = True


# ====== Optimizer & Training ====== #
args.optim = 'Adam'
args.lr = 5e-4
args.twice_b_lr = True
args.weight_decay = 5e-4

args.epochs = 150
args.train_batch_size = 32
args.test_batch_size = 64

args.decay_lr_at = [int(args.epochs/6)*4,
                    int(args.epochs/6)*5]
args.decay_lr_to = 0.1

## neptune init

In [5]:
import neptune.new as neptune

api_token = 
run = neptune.init(
    project='jodyr/urp',
    source_files=['*.py'],
    api_token=api_token,
    run='PD-100',
    )

run["parameters"] = vars(args)


Info (NVML): Driver Not Loaded. GPU usage metrics may not be reported. For more information, see https://docs-legacy.neptune.ai/logging-and-managing-experiment-results/logging-experiment-data.html#hardware-consumption 


https://app.neptune.ai/jodyr/urp/e/PD-100
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


## copy-paste augmentation

In [None]:
def limit_point_boundary(bbox, boundary=(0, 0, 300, 300)):
    for i in range(2):
        if bbox[i] < 0:
            bbox[i] = 0
    for i in range(2, 4):
        if bbox[i] > 300:
            bbox[i] = 300

    return bbox

def get_all_objects(images, bboxes):
    img_w, img_h = images.size(3), images.size(2)
    dims = torch.FloatTensor((img_w, img_h, img_w, img_h)).unsqueeze(0)
    bboxes = map(lambda x: torch.round(dims * x), bboxes)
    
    all_objects = []
    img_idx = []
    for i, (image, bboxes_in_img) in enumerate(zip(images, bboxes)):
        for bbox in bboxes_in_img:
            bbox = torch.round(bbox).long()
            x_min, y_min, x_max, y_max = limit_point_boundary(bbox)
            obj = image[:, y_min:y_max+1, x_min:x_max+1]
            img_idx.append(i)
            all_objects.append(obj)

    return all_objects, img_idx

In [None]:
def flip_objects(all_objects):
    for i in range(len(all_objects)):
        if random.random() < 0.5:
            all_objects[i] = torch.flip(all_objects[i], dims=(0, 2))

    return all_objects

In [None]:
def get_to_imgs(img_idxs, batch_size):
    counts = [0] * batch_size
    for i in img_idxs:
        counts[i] += 1
    max_count, min_count = max(counts), min(counts)
    weights = [max_count - count + min_count for count in counts]
    to_imgs = random.choices(
        range(batch_size), weights=weights, k=len(img_idxs))
    
    return to_imgs

In [None]:
def copy_paste_aug(images, bboxes, category_ids, seg_labels,
                   max_paste_try=5,
                   resize_min_max=(0.8, 2),
                   max_overlap=0
                   ):
    batch_size, num_ch, img_h, img_w = images.size()
    img_dim = torch.Tensor([img_w, img_h, img_w, img_h])
    all_objects, img_idxs = get_all_objects(images, bboxes)
    all_objects = flip_objects(all_objects)
    to_imgs = get_to_imgs(img_idxs, batch_size)

    for obj, to_img in zip(all_objects, to_imgs):
        _, obj_h, obj_w = obj.size()
        max_rescale = min(img_w / obj_w, img_h / obj_h)
        bboxes_in_img = bboxes[to_img]  # Tensor (n_objects, 4)
        cat_ids_in_img = category_ids[to_img]
        for _ in range(max_paste_try):
            rescale = random.uniform(*resize_min_max)
            rescale = min(rescale, max_rescale)
            r_h, r_w = int(rescale*obj_h), int(rescale*obj_w)
            r_h = 1 if r_h == 0 else r_h
            r_w = 1 if r_w == 0 else r_w
            resize_obj = F.interpolate(
                obj.unsqueeze(0), size=(r_h, r_w), mode='nearest'
            ).squeeze(0)
            paste_x = random.randint(0, img_w - r_w)
            paste_y = random.randint(0, img_h - r_h)
            paste_box = torch.Tensor(
                [paste_x, paste_y, paste_x+r_w, paste_y+r_h])   # (4,)
            scaled_paste_box = (paste_box / img_dim).unsqueeze(0)   # (1, 4)
            overlaps = find_jaccard_overlap(
                scaled_paste_box, bboxes_in_img
            )
            if overlaps.max().item() <= max_overlap:
                x_min, y_min, x_max, y_max = paste_box.long()
                seg_labels[to_img, 0, y_min:y_max, x_min:x_max] = 1
                images[to_img, :, y_min:y_max, x_min:x_max] = resize_obj
                bboxes_in_img = torch.cat((bboxes_in_img, scaled_paste_box))
                bboxes[to_img] = bboxes_in_img
                cat_ids_in_img = torch.cat((cat_ids_in_img,
                                           torch.LongTensor([1])))
                category_ids[to_img] = cat_ids_in_img
                break
            

    return images, bboxes, category_ids, seg_labels

## train

In [None]:
from neptune.new.types import File

def train(train_loader, model, criterion, optimizer, run=None):
    """
    One epoch's training.

    :param train_loader: DataLoader for training data
    :param model: model
    :param criterion: MultiBox loss
    :param optimizer: optimizer
    :param epoch: epoch number
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.train()  # training mode enables dropout

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    if criterion.get_separate_loss:
        sep_losses = {'conf_loss': AverageMeter(),
                      'loc_loss': AverageMeter(),
                      'total_seg_loss': AverageMeter()}

    print_freq = len(train_loader) // 4

    start = time.time()

    # Batches
    for i, (images, bboxes, category_ids, is_crowds, seg_labels) in enumerate(train_loader):
        data_time.update(time.time() - start)
        images, bboxes, category_ids, seg_labels = \
            copy_paste_aug(images, bboxes, category_ids, seg_labels)
        images = images.to(device)  # (batch_size (N), 3, 300, 300)
        bboxes = [b.to(device) for b in bboxes]
        category_ids = [c.to(device) for c in category_ids]
        seg_labels = seg_labels.to(device)

        # Forward prop.
        predicted_locs, predicted_scores, pred_segs = model(images)  # (N, 8732, 4), (N, 8732, n_classes)

        if criterion.get_separate_loss:
            loss, separate_loss = criterion(
                predicted_locs, predicted_scores, bboxes,
                category_ids,pred_segs, seg_labels)  # scalar
            for loss_name, sep_loss in separate_loss.items():
                sep_losses[loss_name].update(sep_loss, images.size(0))

        else:
            loss = criterion(predicted_locs, predicted_scores, 
                             bboxes, category_ids, pred_segs, seg_labels)  # scalar

        # Backward prop.
        optimizer.zero_grad()
        loss.backward()

        # Update model
        optimizer.step()

        losses.update(loss.item(), images.size(0))
        batch_time.update(time.time() - start)

        start = time.time()
        
        if i == 0 and run:
            ft_names = ['conv4_3', 'conv7', 'conv8_2', 'conv9_2', 'conv10_2', 'conv11_2']
            for ft_name, pred_seg in zip(ft_names, [pred_seg.detach() 
                                                    for pred_seg in pred_segs]):
                pred_seg = F.softmax(pred_seg, dim=1)
                pred_seg_grid = make_grid(pred_seg)
                pred_seg_grid = TF.to_pil_image(pred_seg_grid[1])
                run[f'train/pred_seg_{ft_name}'].log(File.as_image(pred_seg_grid))

            img_grid = make_grid(images)
            img_grid = TF.to_pil_image(img_grid)
            run['train/input'].log(File.as_image(img_grid))
            seg_gt_grid = make_grid(seg_labels)
            seg_gt_grid = TF.to_pil_image(seg_gt_grid)
            run['train/seg_gt'].log(File.as_image(seg_gt_grid))

        # Print status
        if i % print_freq == 0:
            print(f'[{i}/{len(train_loader)}]\t'
                  f'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  f'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  f'Loss {losses.val:.4f} ({losses.avg:.4f})\t')
    del predicted_locs, predicted_scores, images, bboxes, category_ids

    if run:
        run['train/loss'].log(losses.avg)
        if criterion.get_separate_loss:
            for loss_name, sep_loss in sep_losses.items():
                run[f'train/{loss_name}'].log(sep_loss.avg)

## validation

In [None]:
# def validation(val_loader, model, criterion):
#     model.eval()

#     num_batches = len(val_loader)
#     losses = AverageMeter()
#     with torch.no_grad():
#         for i, (images, bboxes, category_ids, _) in enumerate(val_loader):
#             images = images.to(device)
#             bboxes = [b.to(device) for b in bboxes]
#             category_ids = [l.to(device) for l in category_ids]

#             predicted_locs, predicted_scores = model(images)
#             loss = criterion(predicted_locs, predicted_scores, bboxes, category_ids).item()

#             losses.update(loss, images.size(0))

#     val_loss = losses.avg
#     return val_loss

## checkpoint

In [None]:
checkpoint = os.path.join('/content/drive/MyDrive/2021.summer_URP/PD/checkpoint',
                          str(args.baselineID)+'.pth.tar')
checkpoint = checkpoint if os.path.isfile(checkpoint) else None
print(f"checkpoint: {checkpoint}")

checkpoint: /content/drive/MyDrive/2021.summer_URP/PD/checkpoint/37.pth.tar


In [None]:
if checkpoint is None:
    start_epoch = 1
    lr = args.lr
    model = object_detection.SDSSSD300(n_classes=args.n_classes,
                                       base=args.base_model,
                                       ch_option=args.ch_option,
                                       usages_seg_feats=args.usages_seg_feats)
    if args.twice_b_lr:
        biases = list()
        not_biases = list()
        for param_name, param in model.named_parameters():
            if param.requires_grad:
                if param_name.endswith('.bias'):
                    biases.append(param)
                else:
                    not_biases.append(param)
        optimizer = getattr(torch.optim, args.optim)(params=[{'params': biases, 'lr': 2 * lr}, 
                                                            {'params': not_biases}],
                                                     lr=lr,
                                                     weight_decay=args.weight_decay)
    else:
        optimizer = getattr(torch.optim, args.optim)(params=model.parameters(),
                                                     lr=lr,
                                                     weight_decay=args.weight_decay)

else:
    checkpoint = torch.load(checkpoint)
    start_epoch = checkpoint['epoch'] + 1
    print('\nLoaded checkpoint from epoch %d.\n' % start_epoch)
    model = checkpoint['model']
    optimizer = checkpoint['optimizer']


model = model.to(device)
criterion = object_detection.SDSMultiBoxLoss(
    priors_cxcy=model.priors_cxcy, 
    usages_seg_feats=model.usages_seg_feats,
    get_separate_loss=args.get_separate_loss).to(device)


Loaded checkpoint from epoch 126.





## dataset init

In [None]:
workers = 4
train_dataset = dataset.KaistPDDataset(ch_option=args.ch_option,
                                       is_sds=args.is_sds)
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=args.train_batch_size, 
                                           shuffle=True,
                                           collate_fn=dataset.sds_collate_fn,
                                           num_workers=workers,
                                           pin_memory=True)

  cpuset_checked))


## experiment

In [None]:
checkpoint_dir = '/content/drive/MyDrive/2021.summer_URP/PD/checkpoint'
checkpoint_path = os.path.join(checkpoint_dir,
                               str(args.baselineID)+'.pth.tar')

In [None]:
epochs = args.epochs
decay_lr_at = args.decay_lr_at
save_freq = 5
eval_freq = 10

# Epochs
for epoch in range(start_epoch, epochs+1):
    print(f"# ====== Epoch {epoch} ====== # {datetime.now()}")
    # Decay learning rate at particular epochs
    if epoch in decay_lr_at:
        adjust_learning_rate(optimizer, args.decay_lr_to)

    # One epoch's training
    train(train_loader=train_loader, model=model,
          criterion=criterion, optimizer=optimizer, run=run)
    if epoch % save_freq == 0:
        save_checkpoint(epoch, model, optimizer, checkpoint_path)

    if epoch % eval_freq == 0:
        eval.evaluate(model, ch_option=args.ch_option)



  cpuset_checked))


[0/273]	Batch Time 4.829 (4.829)	Data Time 4.191 (4.191)	Loss 1.5444 (1.5444)	
[68/273]	Batch Time 2.490 (1.092)	Data Time 1.867 (0.557)	Loss 1.5115 (1.5664)	
[136/273]	Batch Time 2.197 (1.052)	Data Time 1.635 (0.519)	Loss 1.6478 (1.5734)	
[204/273]	Batch Time 2.574 (1.052)	Data Time 2.084 (0.517)	Loss 1.6635 (1.5705)	
[272/273]	Batch Time 0.312 (1.024)	Data Time 0.020 (0.493)	Loss 1.5675 (1.5658)	
[0/273]	Batch Time 4.129 (4.129)	Data Time 3.549 (3.549)	Loss 1.5503 (1.5503)	
[68/273]	Batch Time 0.848 (1.052)	Data Time 0.277 (0.517)	Loss 1.5959 (1.5731)	
[136/273]	Batch Time 1.962 (1.038)	Data Time 1.451 (0.506)	Loss 1.5818 (1.5689)	
[204/273]	Batch Time 1.308 (1.034)	Data Time 0.768 (0.505)	Loss 1.5381 (1.5716)	
[272/273]	Batch Time 0.436 (1.032)	Data Time 0.137 (0.503)	Loss 1.6852 (1.5675)	
[0/273]	Batch Time 5.368 (5.368)	Data Time 4.741 (4.741)	Loss 1.6232 (1.6232)	
[68/273]	Batch Time 2.240 (1.102)	Data Time 1.619 (0.563)	Loss 1.3931 (1.5671)	
[136/273]	Batch Time 0.516 (1.074)	Da

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  image_boxes.append(class_decoded_locs[1 - suppress])
  image_scores.append(class_scores[1 - suppress])



Loading and preparing results...
DONE (t=0.13s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=0.33s).
Accumulating evaluation results...
DONE (t=0.01s).
 Average Miss Rate  (MR) @ Reasonable         [ IoU=0.50      | height=[55:10000000000] | visibility=[none+partial_occ] ] = 25.72%
Recall: 0.8373408769448374


  cpuset_checked))


[0/273]	Batch Time 5.066 (5.066)	Data Time 4.439 (4.439)	Loss 1.4422 (1.4422)	
[68/273]	Batch Time 2.230 (1.146)	Data Time 1.660 (0.611)	Loss 1.6100 (1.5715)	
[136/273]	Batch Time 2.403 (1.095)	Data Time 1.801 (0.563)	Loss 1.4716 (1.5841)	
[204/273]	Batch Time 2.833 (1.074)	Data Time 2.290 (0.539)	Loss 1.4574 (1.5804)	
[272/273]	Batch Time 0.306 (1.055)	Data Time 0.000 (0.525)	Loss 1.4152 (1.5775)	
[0/273]	Batch Time 4.813 (4.813)	Data Time 4.248 (4.248)	Loss 1.5348 (1.5348)	
[68/273]	Batch Time 0.609 (1.086)	Data Time 0.001 (0.551)	Loss 1.6146 (1.5621)	
[136/273]	Batch Time 2.321 (1.052)	Data Time 1.786 (0.514)	Loss 1.6082 (1.5607)	
[204/273]	Batch Time 1.849 (1.048)	Data Time 1.341 (0.512)	Loss 1.5045 (1.5620)	
[272/273]	Batch Time 0.294 (1.037)	Data Time 0.000 (0.505)	Loss 1.8120 (1.5615)	
[0/273]	Batch Time 4.780 (4.780)	Data Time 4.233 (4.233)	Loss 1.4209 (1.4209)	
[68/273]	Batch Time 0.632 (1.102)	Data Time 0.010 (0.571)	Loss 1.4712 (1.5582)	
[136/273]	Batch Time 0.449 (1.074)	Da

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  image_boxes.append(class_decoded_locs[1 - suppress])
  image_scores.append(class_scores[1 - suppress])



Loading and preparing results...
DONE (t=0.02s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=0.29s).
Accumulating evaluation results...
DONE (t=0.01s).
 Average Miss Rate  (MR) @ Reasonable         [ IoU=0.50      | height=[55:10000000000] | visibility=[none+partial_occ] ] = 26.02%
Recall: 0.8380480905233381


  cpuset_checked))


[0/273]	Batch Time 4.451 (4.451)	Data Time 3.788 (3.788)	Loss 1.4818 (1.4818)	
[68/273]	Batch Time 0.601 (1.082)	Data Time 0.009 (0.545)	Loss 1.2766 (1.5663)	
[136/273]	Batch Time 0.898 (1.052)	Data Time 0.340 (0.518)	Loss 1.3889 (1.5620)	
[204/273]	Batch Time 2.774 (1.057)	Data Time 2.200 (0.526)	Loss 1.5190 (1.5635)	
[272/273]	Batch Time 0.307 (1.036)	Data Time 0.000 (0.510)	Loss 1.5647 (1.5641)	
[0/273]	Batch Time 4.706 (4.706)	Data Time 4.096 (4.096)	Loss 1.6047 (1.6047)	
[68/273]	Batch Time 2.055 (1.106)	Data Time 1.558 (0.578)	Loss 1.5885 (1.5474)	
[136/273]	Batch Time 2.774 (1.071)	Data Time 2.245 (0.540)	Loss 1.4924 (1.5546)	
[204/273]	Batch Time 2.530 (1.059)	Data Time 1.991 (0.528)	Loss 1.4320 (1.5539)	
[272/273]	Batch Time 0.298 (1.034)	Data Time 0.000 (0.506)	Loss 1.4989 (1.5511)	
[0/273]	Batch Time 4.510 (4.510)	Data Time 3.914 (3.914)	Loss 1.4981 (1.4981)	
[68/273]	Batch Time 2.696 (1.098)	Data Time 2.135 (0.565)	Loss 1.4232 (1.5346)	
[136/273]	Batch Time 1.750 (1.063)	Da

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  image_boxes.append(class_decoded_locs[1 - suppress])
  image_scores.append(class_scores[1 - suppress])



Loading and preparing results...
DONE (t=0.02s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=0.30s).
Accumulating evaluation results...
DONE (t=0.01s).
 Average Miss Rate  (MR) @ Reasonable         [ IoU=0.50      | height=[55:10000000000] | visibility=[none+partial_occ] ] = 25.73%
Recall: 0.8399153737658674
