In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
# Repository root, given relative to the working directory; it is put on sys.path so the
# project packages imported further down (config, dataloaders, lib) can be resolved.
codebase = '../../'
# Guarded so that re-running this cell does not stack duplicate entries on sys.path.
if codebase not in sys.path:
    sys.path.append(codebase)

In [2]:
import torch
# Display the installed PyTorch version so the environment of this run is recorded in the output.
torch.__version__

'1.8.2'

In [3]:
# Experiment tag; it is interpolated into the TensorBoard and checkpoint directory options below.
exp_name = 'exp_045_rep'
# Expose only GPU 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [4]:
from dataloaders.visual_genome import VGDataLoader, VG
import numpy as np
from torch import optim
import torch
import pandas as pd
import time
import os

from config import ModelConfig, BOX_SCALE, IM_SCALE
from torch.nn import functional as F
from lib.pytorch_misc import optimistic_restore, de_chunkize, clip_grad_norm
from lib.evaluation.sg_eval import BasicSceneGraphEvaluator, calculate_mR_from_evaluator_list, eval_entry
from lib.pytorch_misc import print_para
from torch.optim.lr_scheduler import ReduceLROnPlateau

from lib.my_model_24 import KERN


In [5]:
# Build the run configuration from a command-line-style option string.
# The parsed values are echoed in this cell's output. Short flags map as follows:
#   -m -> mode, -p -> print_interval, -clip -> clip, -b -> batch_size, -ngpu -> num_gpus.
# The remaining flags (-tb_log_dir, -save_dir, -ckpt, -val_size, -adam, -lr) share the
# name of the hyperparameter they set.
conf = ModelConfig(f'''
-m predcls -p 1000 -clip 5 
-tb_log_dir summaries/kern_predcls/{exp_name} 
-save_dir checkpoints/kern_predcls/{exp_name}
-ckpt checkpoints/vgdet/vg-24.tar 
-val_size 5000 
-adam 
-b 4
-ngpu 1
-lr 1e-4 
''')

~~~~~~~~ Hyperparameters used: ~~~~~~~
ckpt : checkpoints/vgdet/vg-24.tar
save_dir : checkpoints/kern_predcls/exp_045_rep
num_gpus : 1
num_workers : 1
lr : 0.0001
batch_size : 4
val_size : 5000
l2 : 0.0001
adamwd : 0.0
clip : 5.0
print_interval : 1000
mode : predcls
cache : 
adam : True
test : False
num_epochs : 50
use_resnet : False
use_proposals : False
pooling_dim : 4096
use_ggnn_obj : False
ggnn_obj_time_step_num : 3
ggnn_obj_hidden_dim : 512
ggnn_obj_output_dim : 512
use_obj_knowledge : False
obj_knowledge : 
use_ggnn_rel : False
ggnn_rel_time_step_num : 3
ggnn_rel_hidden_dim : 512
ggnn_rel_output_dim : 512
use_rel_knowledge : False
rel_knowledge : 
tb_log_dir : summaries/kern_predcls/exp_045_rep
save_rel_recall : 


In [6]:
# Show the working directory; the relative paths in this notebook are presumably resolved
# against it (NOTE(review): confirm, since ModelConfig may anchor some paths elsewhere).
os.getcwd()

'/home/zhanwen/gbnet/ipynb/train_predcls'

In [8]:
# Build the training and validation datasets; the third value returned by VG.splits is not used here.
train, val, _ = VG.splits(
    num_val_im=conf.val_size,
    filter_duplicate_rels=True,
    use_proposals=conf.use_proposals,
    # Only switched on when the run mode is 'sgdet'.
    filter_non_overlap=(conf.mode == 'sgdet'),
)


In [9]:
# Index -> predicate-name lookup taken from the training split.
# Index 0 means "no relationship"; val_epoch skips it when building the per-predicate evaluators.
ind_to_predicates = train.ind_to_predicates # ind_to_predicates[0] means no relationship

In [10]:
# Wrap both datasets in loaders, using the batch size, worker count and GPU count from conf.
train_loader, val_loader = VGDataLoader.splits(
    train,
    val,
    mode='rel',
    batch_size=conf.batch_size,
    num_workers=conf.num_workers,
    num_gpus=conf.num_gpus,
)


In [11]:
# Instantiate the relationship model.
# NOTE(review): the edge graph is read from graphs/005 while the embedding matrix and the
# predicate counts are read from graphs/001 -- confirm that this mix is intended.
detector = KERN(
    # Vocabularies come from the training split.
    classes=train.ind_to_classes,
    rel_classes=train.ind_to_predicates,
    # Settings taken from the run configuration.
    num_gpus=conf.num_gpus,
    mode=conf.mode,
    require_overlap_det=True,
    use_resnet=conf.use_resnet,
    use_proposals=conf.use_proposals,
    pooling_dim=conf.pooling_dim,
    # Relation GGNN sizes are fixed here rather than read from conf.
    ggnn_rel_time_step_num=3,
    ggnn_rel_hidden_dim=1024,
    ggnn_rel_output_dim=None,
    # Knowledge resources, located relative to the repository root.
    graph_path=os.path.join(codebase, 'graphs/005/all_edges.pkl'),
    emb_path=os.path.join(codebase, 'graphs/001/emb_mtx.pkl'),
    rel_counts_path=os.path.join(codebase, 'graphs/001/pred_counts.pkl'),
    use_knowledge=True,
    use_embedding=True,
    refine_obj_cls=False,
    class_volume=1.0,
)



Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/zhanwen/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100.0%
  init.xavier_normal(self.linear.weight)


In [12]:
# Freeze the detector: none of its parameters will receive gradients.
for frozen_param in detector.detector.parameters():
    frozen_param.requires_grad = False

In [13]:
# Print the parameter summary: total count, then name, shape, size and a "grad" marker per parameter.
print(print_para(detector), flush=True)


 444.6M total parameters 
 ----- 
 
detector.roi_fmap.0.weight                        : [4096,25088]    (102760448) (    )
roi_fmap.1.0.weight                               : [4096,25088]    (102760448) (grad)
roi_fmap_obj.0.weight                             : [4096,25088]    (102760448) (grad)
detector.roi_fmap.3.weight                        : [4096,4096]     (16777216) (    )
roi_fmap.1.3.weight                               : [4096,4096]     (16777216) (grad)
roi_fmap_obj.3.weight                             : [4096,4096]     (16777216) (grad)
ggnn_rel_reason.ggnn.fc_mp_receive_ont_ent.model.0.linear.weight: [3328,3328]     (11075584) (grad)
ggnn_rel_reason.obj_proj.weight                   : [1024,4096]     ( 4194304) (grad)
ggnn_rel_reason.rel_proj.weight                   : [1024,4096]     ( 4194304) (grad)
ggnn_rel_reason.ggnn.fc_mp_receive_ont_pred.model.0.linear.weight: [2048,2048]     ( 4194304) (grad)
ggnn_rel_reason.ggnn.fc_mp_receive_ont_ent.model.2.linear.weight: [1024

In [14]:
def get_optim(lr):
    """
    Create the optimizer over every trainable parameter of ``detector``.

    Trainable parameters whose name starts with ``roi_fmap`` (this prefix also covers
    ``roi_fmap_obj``) are trained at ``lr / 10``; all other trainable parameters use ``lr``.

    :param lr: base learning rate.
    :return: an ``optim.Adam`` instance when ``conf.adam`` is set, otherwise an
             ``optim.SGD`` instance with momentum 0.9.
    """
    # Lowering the rate on the VGG fully connected layers by 1/10th is a hack, but it helps
    # stabilize the models.
    fc_params, non_fc_params = [], []
    for name, param in detector.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are never handed to the optimizer
        bucket = fc_params if name.startswith('roi_fmap') else non_fc_params
        bucket.append(param)

    param_groups = [
        {'params': fc_params, 'lr': lr / 10.0},
        {'params': non_fc_params},  # falls back to the optimizer-wide lr
    ]

    if conf.adam:
        return optim.Adam(param_groups, weight_decay=conf.adamwd, lr=lr, eps=1e-3)
    return optim.SGD(param_groups, weight_decay=conf.l2, lr=lr, momentum=0.9)



In [15]:
# Deserialize the checkpoint onto the CPU. detector.cuda() is only called after the weights
# have been restored, so this keeps the checkpoint copy out of GPU memory. It also means
# loading does not depend on the CUDA device the checkpoint was saved from being visible
# under the CUDA_VISIBLE_DEVICES setting above.
ckpt = torch.load(conf.ckpt, map_location='cpu')


In [16]:
# Restore weights from the loaded checkpoint and decide where training resumes.
# A checkpoint whose file name is 'vgrel-<suffix>' (the pattern the training loop saves
# under) is restored into the whole model; anything else is treated as detector-only weights.
# Only the file name is inspected, so directory names cannot influence the decision and a
# name without '-' simply falls through to the detector-only branch.
ckpt_state = ckpt['state_dict']
if os.path.basename(conf.ckpt).rsplit('-', 1)[0] == 'vgrel':
    print("Loading EVERYTHING")
    start_epoch = ckpt['epoch']

    # A falsy return from optimistic_restore resets the epoch counter, i.e. the run is
    # treated as starting from scratch.
    if not optimistic_restore(detector, ckpt_state):
        start_epoch = -1
else:
    start_epoch = -1
    optimistic_restore(detector.detector, ckpt_state)

    # Initialise both trainable FC stacks from the checkpoint's roi_fmap layers 0 and 3.
    for fc_stack in (detector.roi_fmap[1], detector.roi_fmap_obj):
        for layer_idx in (0, 3):
            fc_stack[layer_idx].weight.data.copy_(ckpt_state['roi_fmap.{}.weight'.format(layer_idx)])
            fc_stack[layer_idx].bias.data.copy_(ckpt_state['roi_fmap.{}.bias'.format(layer_idx)])


In [17]:
# Move the model to the GPU; the trailing ';' suppresses the cell's output.
detector.cuda();


In [18]:
def train_epoch(epoch_num):
    """
    Train ``detector`` for one pass over ``train_loader``.

    Whenever the batch index is a non-zero multiple of ``conf.print_interval``, prints the
    mean of each loss term over the last ``conf.print_interval`` batches, together with the
    measured seconds per batch and the projected minutes per epoch.

    :param epoch_num: epoch index; used only in the progress printout.
    :return: DataFrame built by concatenating the per-batch loss Series column-wise
             (one column per batch, one row per loss term).
    """
    detector.train()
    batch_losses = []
    num_batches = len(train_loader)
    interval_start = time.time()
    for b, batch in enumerate(train_loader):
        # The verbose flag is raised ten times less often than the loss report below.
        _, loss_pd = train_batch(batch, verbose=b % (conf.print_interval * 10) == 0)
        batch_losses.append(loss_pd)

        if b % conf.print_interval == 0 and b >= conf.print_interval:
            recent_mean = pd.concat(batch_losses[-conf.print_interval:], axis=1).mean(1)
            time_per_batch = (time.time() - interval_start) / conf.print_interval
            print("\ne{:2d}b{:5d}/{:5d} {:.3f}s/batch, {:.1f}m/epoch".format(
                epoch_num, b, num_batches, time_per_batch, num_batches * time_per_batch / 60))
            print(recent_mean)
            print('-----------', flush=True)
            # Restart the timer so the next report covers only the next interval.
            interval_start = time.time()
    return pd.concat(batch_losses, axis=1)


In [19]:
def train_batch(b, verbose=False):
    """
    Run one optimization step of ``detector`` on a single training batch.

    The step uses the module-level ``optimizer``, which must exist before this is called.

    :param b: one batch as yielded by ``train_loader``; it is handed to the model unchanged
              through ``detector[b]``.
    :param verbose: forwarded to ``clip_grad_norm`` (presumably enables its per-parameter
                    gradient-norm printout -- confirm against lib.pytorch_misc).
    :return: ``(result, loss_pd)`` where ``result`` is the model output for the batch and
             ``loss_pd`` is a ``pd.Series`` of Python floats keyed by ``class_loss``,
             ``rel_loss`` and ``total``.
    """
    result = detector[b]

    losses = {
        'class_loss': detector.obj_loss(result),
        'rel_loss': detector.rel_loss(result),
    }
    total_loss = sum(losses.values())
    losses['total'] = total_loss

    optimizer.zero_grad()
    total_loss.backward()

    # Only parameters that actually received a gradient take part in clipping.
    named_grads = [(name, param) for name, param in detector.named_parameters()
                   if param.grad is not None]
    clip_grad_norm(named_grads, max_norm=conf.clip, verbose=verbose, clip=True)
    optimizer.step()

    loss_pd = pd.Series({term: value.item() for term, value in losses.items()})
    return result, loss_pd


In [20]:
def val_epoch():
    """
    Evaluate ``detector`` on every batch of ``val_loader``.

    The model is put in eval mode and the forward passes run under ``torch.no_grad()``, so
    no autograd graph is built during validation. Training mode is restored on exit, even
    if evaluation raises.

    :return: ``(recall, recall_mp, mean_recall, mean_recall_mp)``: the overall statistics
             printed by the evaluators for ``conf.mode`` and the mean recall computed from
             the per-predicate evaluators; the ``_mp`` values come from the evaluators
             created with ``multiple_preds=True``.
    """
    detector.eval()
    try:
        # One evaluator set per predicate, skipping index 0 ("no relationship").
        predicates = [(index, name) for index, name in enumerate(ind_to_predicates) if index != 0]
        evaluator_list = [
            (index, name, BasicSceneGraphEvaluator.all_modes())
            for index, name in predicates
        ]
        evaluator_multiple_preds_list = [
            (index, name, BasicSceneGraphEvaluator.all_modes(multiple_preds=True))
            for index, name in predicates
        ]
        # Evaluators over all predicates together.
        evaluator = BasicSceneGraphEvaluator.all_modes()
        evaluator_multiple_preds = BasicSceneGraphEvaluator.all_modes(multiple_preds=True)

        with torch.no_grad():
            for val_b, batch in enumerate(val_loader):
                # The first argument is the dataset index of the batch's first image.
                # Assumes each validation batch holds conf.num_gpus images -- TODO confirm.
                val_batch(conf.num_gpus * val_b, batch, evaluator, evaluator_multiple_preds,
                          evaluator_list, evaluator_multiple_preds_list)

        recall = evaluator[conf.mode].print_stats()
        recall_mp = evaluator_multiple_preds[conf.mode].print_stats()

        mean_recall = calculate_mR_from_evaluator_list(evaluator_list, conf.mode)
        mean_recall_mp = calculate_mR_from_evaluator_list(
            evaluator_multiple_preds_list, conf.mode, multiple_preds=True)
    finally:
        detector.train()
    return recall, recall_mp, mean_recall, mean_recall_mp


In [21]:
def val_batch(batch_num, b, evaluator, evaluator_multiple_preds, evaluator_list, evaluator_multiple_preds_list):
    """
    Run the model on one validation batch and feed every image's result to the evaluators.

    :param batch_num: index into ``val`` of the first image in this batch.
    :param b: the batch, handed to the model unchanged through ``detector[b]``.
    :param evaluator: evaluators over all predicates.
    :param evaluator_multiple_preds: same, created with ``multiple_preds=True``.
    :param evaluator_list: per-predicate ``(index, name, evaluators)`` tuples.
    :param evaluator_multiple_preds_list: same, created with ``multiple_preds=True``.
    :return: None; the evaluators are passed to ``eval_entry``, which presumably
             accumulates results in them.
    """
    det_res = detector[b]
    # With a single GPU the model returns one result tuple rather than a list of them.
    if conf.num_gpus == 1:
        det_res = [det_res]

    for offset, (boxes_i, objs_i, obj_scores_i, rels_i, pred_scores_i) in enumerate(det_res):
        img_idx = batch_num + offset
        gt_entry = dict(
            gt_classes=val.gt_classes[img_idx].copy(),
            gt_relations=val.relationships[img_idx].copy(),
            gt_boxes=val.gt_boxes[img_idx].copy(),
        )
        # Both ends of every predicted relation must carry a class index above 0.
        assert np.all(objs_i[rels_i[:, 0]] > 0) and np.all(objs_i[rels_i[:, 1]] > 0)

        pred_entry = dict(
            pred_boxes=boxes_i * BOX_SCALE/IM_SCALE,
            pred_classes=objs_i,
            pred_rel_inds=rels_i,
            obj_scores=obj_scores_i,
            rel_scores=pred_scores_i,  # hack for now.
        )

        eval_entry(conf.mode, gt_entry, pred_entry, evaluator, evaluator_multiple_preds,
                   evaluator_list, evaluator_multiple_preds_list)



In [22]:
# TensorBoard logging is optional: it is enabled only when a log directory is configured.
if conf.tb_log_dir is not None:
    # Imported here so tensorboardX is only required when logging is enabled.
    from tensorboardX import SummaryWriter
    # exist_ok makes this a no-op when the directory already exists and avoids the
    # check-then-create race.
    os.makedirs(conf.tb_log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=conf.tb_log_dir)
    use_tb = True
else:
    # Bind writer on this branch too, so a writer left over from an earlier run of this
    # cell cannot linger in the kernel.
    writer = None
    use_tb = False


In [None]:
print("Training starts now!")

# Schedule for this experiment.
NUM_EPOCHS = 30              # training covers epoch indices up to NUM_EPOCHS - 1
LR_DECAY_EPOCHS = (10, 20)   # every param group's lr is divided when one of these epochs starts
LR_DECAY_FACTOR = 10

# The base learning rate is multiplied by conf.num_gpus * conf.batch_size.
optimizer = get_optim(conf.lr * conf.num_gpus * conf.batch_size)

# Continue after the epoch stored in the checkpoint instead of always restarting at 0, which
# would overwrite the vgrel-<epoch>.tar files and TensorBoard steps of the earlier run.
# start_epoch is -1 when only detector weights were restored, so a fresh run still starts at 0.
first_epoch = start_epoch + 1
# The optimizer is created at the undecayed rate; apply the decays a resumed run has already passed.
missed_decays = sum(1 for milestone in LR_DECAY_EPOCHS if milestone < first_epoch)
if missed_decays:
    for param_group in optimizer.param_groups:
        param_group['lr'] /= LR_DECAY_FACTOR ** missed_decays

if conf.save_dir is not None:
    # Make sure the checkpoint directory exists before a full epoch is spent training.
    os.makedirs(conf.save_dir, exist_ok=True)

for epoch in range(first_epoch, NUM_EPOCHS):
    if epoch in LR_DECAY_EPOCHS:
        for param_group in optimizer.param_groups:
            param_group['lr'] /= LR_DECAY_FACTOR

    rez = train_epoch(epoch)
    # Per-term mean over all batches of the epoch, computed once and reused below.
    epoch_means = rez.mean(1)
    print("overall{:2d}: ({:.3f})\n{}".format(epoch, epoch_means['total'], epoch_means), flush=True)

    if use_tb:
        for term in ('rel_loss', 'class_loss', 'total'):
            writer.add_scalar('loss/' + term, epoch_means[term], epoch)

    if conf.save_dir is not None:
        # Only the model weights and the epoch index are stored; the optimizer state is not.
        torch.save({
            'epoch': epoch,
            'state_dict': detector.state_dict(),
        }, os.path.join(conf.save_dir, '{}-{}.tar'.format('vgrel', epoch)))

    recall, recall_mp, mean_recall, mean_recall_mp = val_epoch()
    if use_tb:
        eval_groups = (
            ('_with_constraint/', recall),
            ('_without_constraint/', recall_mp),
            ('_with_constraint/mean ', mean_recall),
            ('_without_constraint/mean ', mean_recall_mp),
        )
        for tag_suffix, metrics in eval_groups:
            for key, value in metrics.items():
                writer.add_scalar('eval_' + conf.mode + tag_suffix + key, value, epoch)



Training starts now!
---Total norm 0.687 clip coef 7.278-----------------
roi_fmap.1.0.weight                               : 0.497, (torch.Size([4096, 25088]))
roi_fmap.1.3.weight                               : 0.202, (torch.Size([4096, 4096]))
ggnn_rel_reason.ggnn.fc_output_proj_ont_pred.model.0.linear.weight: 0.157, (torch.Size([1024, 1024]))
ggnn_rel_reason.rel_proj.weight                   : 0.146, (torch.Size([1024, 4096]))
roi_fmap_obj.0.weight                             : 0.123, (torch.Size([4096, 25088]))
ggnn_rel_reason.ggnn.fc_output_proj_ont_pred.model.2.linear.weight: 0.115, (torch.Size([1024, 1024]))
ggnn_rel_reason.ggnn.fc_output_proj_img_pred.model.2.linear.weight: 0.112, (torch.Size([1024, 1024]))
ggnn_rel_reason.ggnn.fc_output_proj_img_pred.model.0.linear.weight: 0.111, (torch.Size([1024, 1024]))
ggnn_rel_reason.ggnn.fc_eq5_u_img_pred.weight     : 0.101, (torch.Size([1024, 1024]))
ggnn_rel_reason.ggnn.fc_eq5_w_ont_pred.weight     : 0.096, (torch.Size([1024, 1024]))
