In [1]:
from collections import defaultdict
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyprojroot
import seaborn as sns
import scipy.stats
import sklearn.metrics
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import searchnets
from searchnets import nets
from searchnets.datasets import VOCDetection
from searchnets.engine.abstract_trainer import AbstractTrainer
from searchnets.transforms.util import get_transforms
from searchnets.utils.general import make_save_path

In [2]:
def tile(a, dim, n_tile):
    """Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, consecutively.

    For ``dim=0`` and rows ``[r0, r1]`` with ``n_tile=2`` the result is
    ``[r0, r0, r1, r1]`` (interleaved repeats, unlike ``Tensor.repeat`` which
    would give ``[r0, r1, r0, r1]``).

    Parameters
    ----------
    a : torch.Tensor
        tensor to tile.
    dim : int
        dimension along which slices are repeated.
    n_tile : int
        number of consecutive copies of each slice.

    Returns
    -------
    torch.Tensor
        tensor on the same device as ``a`` whose size along ``dim`` is
        ``a.size(dim) * n_tile``; all other dimensions are unchanged.
    """
    # repeat_interleave keeps the result on a's device and also handles a
    # zero-length `dim`, both of which the manual index construction did not.
    return torch.repeat_interleave(a, n_tile, dim=dim)

In [3]:
class DetectNet(nn.Module):
    """Network that scores whether a queried class is present in an image.

    The image goes through ``vis_sys``; the query vector goes through a
    linear + ReLU embedding. Both results are concatenated along dim 1 and
    mapped to a single logit by a linear decoder.

    Parameters
    ----------
    vis_sys : torch.nn.Module
        module applied to ``img``. Its output is concatenated with the query
        embedding along dim 1, so it is expected to be 2-D with
        ``vis_sys_n_out`` columns.
    num_classes : int
        length of the query vector fed to the embedding.
    vis_sys_n_out : int
        number of features produced by ``vis_sys``.
    embedding_n_out : int
        number of features produced by the query embedding. Default is 512.
    """

    def __init__(self,
                 vis_sys,
                 num_classes,
                 vis_sys_n_out,
                 embedding_n_out=512):
        super().__init__()
        self.vis_sys = vis_sys

        query_projection = nn.Linear(in_features=num_classes,
                                     out_features=embedding_n_out)
        self.embedding = nn.Sequential(query_projection, nn.ReLU(inplace=True))

        # decoder sees visual features and query embedding side by side;
        # it always has 1 output, because it indicates target present or absent
        decoder_n_in = vis_sys_n_out + embedding_n_out
        self.decoder = nn.Linear(in_features=decoder_n_in, out_features=1)

    def forward(self, img, query):
        """Return one logit per (img, query) row, shape ``(batch, 1)``."""
        features = [self.vis_sys(img), self.embedding(query)]
        return self.decoder(torch.cat(features, dim=1))

In [4]:
# Notebook-level settings.
# NOTE(review): the cells below read dataset type, loss function and pad size
# from the parsed config (`cfg`) rather than from these constants.
PAD_SIZE = 500

DATASET_TYPE = 'VOC'
LOSS_FUNCS = [
    'CE-largest',
    'CE-random',
    'BCE',
]

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# NOTE(review): the DataLoader built below takes its worker count from
# `cfg.train.num_workers`, not from this constant.
NUM_WORKERS = 4

In [7]:
# Collect every .ini file directly inside the project's VSD config directory,
# in sorted path order.
VSD_CONFIGS_ROOT = pyprojroot.here('./data/configs/VSD')
VSD_CONFIG_INIS = list(VSD_CONFIGS_ROOT.glob('*.ini'))
VSD_CONFIG_INIS.sort()

In [8]:
# Keep the configs whose file name contains both 'alexnet' and 'transfer'.
# Matching on the file name (not the full path) means a parent directory that
# happens to contain either word cannot make every config match.
ALEXNET_VSD_CONFIG_INIS = [
    vsd_config
    for vsd_config in VSD_CONFIG_INIS
    if 'alexnet' in vsd_config.name and 'transfer' in vsd_config.name
]

In [9]:
# Use the first matching config file; fail with a clear message when the
# filter above matched nothing instead of an opaque IndexError.
if not ALEXNET_VSD_CONFIG_INIS:
    raise FileNotFoundError(
        'ALEXNET_VSD_CONFIG_INIS is empty: no matching .ini config files were found'
    )
configfile = ALEXNET_VSD_CONFIG_INIS[0]
cfg = searchnets.config.parse_config(configfile)

In [10]:
# NOTE(review): same call with the same `configfile` as the previous cell,
# so this only re-parses the config into `cfg` again.
cfg = searchnets.config.parse_config(configfile)

In [11]:
# Build the image / target transforms from the parsed config.
transform, target_transform = get_transforms(
    cfg.data.dataset_type,
    loss_func=cfg.train.loss_func,
    pad_size=cfg.data.pad_size,
)

# Training split of the VOC detection dataset (downloaded on first use).
trainset = VOCDetection(
    root=cfg.data.root,
    csv_file=cfg.data.csv_file_out,
    image_set='trainval',
    split='train',
    download=True,
    transform=transform,
    target_transform=target_transform,
)

# NOTE(review): shuffle is off, so batches come in dataset order -- confirm
# this is intended for a loader that is used for training below.
train_loader = DataLoader(
    trainset,
    batch_size=cfg.train.batch_size,
    shuffle=False,
    num_workers=cfg.train.num_workers,
    pin_memory=True,
)

Using downloaded and verified file: /home/bart/Documents/data/voc/VOCtrainval_11-May-2012.tar


In [12]:
# Build the pretrained VGG16 and drop the final layer of its classifier, so the
# network outputs the features that fed that layer.
vgg16 = nets.vgg16.build(pretrained=True)
vgg16.classifier = vgg16.classifier[:-1]

In [13]:
# Probe the visual system with one real batch to find the width of its output.
a_batch = next(iter(train_loader))
tmp_img = a_batch['img']
with torch.no_grad():  # only the output shape is needed, so skip building an autograd graph
    tmp_out = vgg16(tmp_img)
n_out = tmp_out.shape[-1]

In [14]:
# Assemble the detector: the truncated VGG16 is the visual system and `n_out`
# (measured above) is its output width.
# NOTE(review): num_classes=20 is presumably the number of VOC object
# categories -- confirm against the dataset's target vector length.
model = DetectNet(vis_sys=vgg16,
                  num_classes=20,
                  vis_sys_n_out=n_out,
                  embedding_n_out=512)

In [15]:
# Pick the compute device (GPU when CUDA is available), report it, and move
# the model onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

model.to(device);

device: cuda


In [16]:
# Parameters that will be optimized: the VGG16 classifier, the query
# embedding and the decoder, in that order.
new_learn_rate_params = [
    param
    for module in (model.vis_sys.classifier, model.embedding, model.decoder)
    for param in module.parameters()
]

In [17]:
MOMENTUM = 0.9

# A single SGD optimizer over the collected parameters, kept in a list because
# the training loop iterates over `optimizers`.
optimizers = [
    torch.optim.SGD(new_learn_rate_params,
                    lr=cfg.train.new_layer_learning_rate,
                    momentum=MOMENTUM),
]

In [18]:
# Freeze the convolutional feature extractor: none of its parameters
# will receive gradients.
for frozen_param in model.vis_sys.features.parameters():
    frozen_param.requires_grad_(False)

In [19]:
# Binary cross-entropy computed on raw logits: DetectNet's decoder is a plain
# linear layer with one output and no sigmoid.
criterion = nn.BCEWithLogitsLoss()

In [20]:
# One pass over `train_loader`. Every (image, class) pair in a loader batch is
# a candidate example: the query is a one-hot class vector and the label is
# that class's entry in the image's target vector. For each optimizer step a
# balanced subset of those pairs is used (as many label-0 pairs as label-1).
total_loss = 0.0  # running sum of per-batch losses, kept as a plain float
half_batch_size = cfg.train.batch_size // 2

batch_total = len(train_loader)
batch_pbar = tqdm(train_loader)
for i, batch in enumerate(batch_pbar):
    img, query = batch['img'], batch['target']
    batch_size, n_classes = query.shape
    # one copy of each image per class: row k of img_tile is image k // n_classes
    img_tile = tile(img, 0, n_classes)
    # flattened labels: entry k is the label of image k // n_classes for class k % n_classes
    target = query.flatten()
    # row k is the one-hot query vector for class k % n_classes
    query_expanded = torch.eye(n_classes).repeat(batch_size, 1)

    # -- make half of batch be target present, half target absent --
    all_target_present_inds = torch.nonzero(target == 1).flatten()
    target_present_to_use = torch.randperm(all_target_present_inds.shape[0])[:half_batch_size]
    target_present_inds = all_target_present_inds[target_present_to_use]

    n_absent = target_present_inds.shape[0]  # might be less than half_batch_size
    all_target_absent_inds = torch.nonzero(target == 0).flatten()
    target_absent_to_use = torch.randperm(all_target_absent_inds.shape[0])[:n_absent]
    target_absent_inds = all_target_absent_inds[target_absent_to_use]

    batch_inds = torch.cat((target_present_inds, target_absent_inds)).sort()[0]
    if batch_inds.numel() == 0:
        # no label-1 pair in this loader batch, so there is nothing balanced to train on
        continue

    target = target.unsqueeze(1)  # add back non-batch ind, so target matches output shape

    img = img_tile[batch_inds].to(device)
    query = query_expanded[batch_inds].to(device)
    target = target[batch_inds].to(device)

    output = model(img, query)
    loss = criterion(output, target)

    for optimizer in optimizers:
        optimizer.zero_grad()
    loss.mean().backward()  # mean needed for multiple GPUs
    for optimizer in optimizers:
        optimizer.step()

    # convert to a Python float before accumulating, so the running total does
    # not keep a reference to every batch's autograd history
    loss_value = loss.item()
    batch_pbar.set_description(f'batch {i} of {batch_total}, loss: {loss_value: 7.3f}')
    total_loss += loss_value

batch 360 of 361, loss:   0.608: 100%|██████████| 361/361 [02:18<00:00,  2.96it/s]


In [27]:
# Scratch cell: display the current value of `target`.
# NOTE(review): the recorded output is a 0-dim tensor, which does not match the
# (N, 1) `target` the training loop leaves behind, and the execution count is
# out of order -- this output likely comes from a different kernel state.
target

tensor(8., device='cuda:0')

In [53]:
# Scratch cell: dtype of the target tensor in the batch fetched for the shape probe.
a_batch['target'].dtype

torch.float32

In [48]:
# Scratch cell: show the type (metaclass) of `torch.FloatTensor` itself.
type(torch.FloatTensor)

torch.tensortype

In [24]:
# Scratch cell: number of parameter tensors in the visual system's classifier.
len(list(model.vis_sys.classifier.parameters()))

4