In [1]:
import importlib
import json
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim

import base_config
import loss_functions as L
import model
import pan_loader
import utils

# Restrict this process to a single GPU: devices are enumerated in PCI bus
# order and only device "1" is made visible.
# NOTE(review): these are set after `import torch`; that only takes effect
# while CUDA has not been initialised yet -- confirm no imported module
# touches the GPU at import time.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})


In [2]:
# Root of the dataset. The COCO_DATA_DIR environment variable overrides the
# original machine-specific location, which remains the default. Joining with
# "" guarantees a trailing separator, because every path below is built by
# plain string concatenation.
data_dir = os.path.join(os.environ.get("COCO_DATA_DIR") or "/home/aravind/dataset/", "")
ann_dir = data_dir + "annotations/panoptic/"

# Training split: images, segmentation directory, annotation JSON.
train_img_dir = data_dir + "train2017/"
train_seg_dir = ann_dir + "panoptic_train2017/"
train_ann_json = ann_dir + "panoptic_train2017.json"

# Validation split, laid out the same way.
val_img_dir = data_dir + "val2017/"
val_seg_dir = ann_dir + "panoptic_val2017/"
val_ann_json = ann_dir + "panoptic_val2017.json"

# Debug toggle: uncomment to point the training paths at the validation split.
# train_img_dir = val_img_dir
# train_seg_dir = val_seg_dir
# train_ann_json = val_ann_json

In [3]:
def _load_json(path):
    """Parse the JSON file at ``path`` and return the decoded object."""
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Needs `import json` in the notebook's import cell (the original cell used
# `json` without importing it and failed on a fresh kernel).
val_ann = _load_json(val_ann_json)
train_ann = _load_json(train_ann_json)

In [4]:
# Instantiate the project's configuration object; the same instance is handed
# to both data loaders in the next cell.
config = base_config.Config()

In [5]:
# Build one loader per split from its image directory, segmentation directory
# and parsed annotation JSON.
# NOTE(review): both loaders receive the same `config` -- confirm get_loader
# does not apply training-only behaviour to the validation loader.
train_loader = pan_loader.get_loader(train_img_dir, train_seg_dir, train_ann, config)
val_loader = pan_loader.get_loader(val_img_dir, val_seg_dir, val_ann, config)

In [6]:
# Construct the network defined in the project's `model` module.
net = model.hgmodel()
# Disabled warm start: when uncommented, this loads a saved state dict, keeps
# only the entries whose keys also exist in the freshly built network, merges
# them into the network's own state dict and loads the result.
# model_dir="models/"
# model_name="first_0.pt"
# pretrained_dict = torch.load(model_dir+model_name)
# net_dict = net.state_dict()

# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net_dict}
# net_dict.update(pretrained_dict) 
# net.load_state_dict(net_dict)


In [7]:
def set_trainable(module, state):
    """Set ``requires_grad`` to ``state`` on every parameter of ``module``."""
    for p in module.parameters():
        p.requires_grad = state


def _is_copy_layer(name):
    """Return True when a sub-module's qualified name marks it as a copy layer."""
    return 'copy_bn' in name or 'copy_conv' in name


def _sgd_group(params):
    """Wrap ``params`` in an SGD parameter group with this run's hyper-parameters."""
    return {'params': params, 'lr': 1e-2, 'momentum': 0.9}


# Start from a fully frozen network, re-enable the two trained heads, then
# make sure the backbone itself is frozen before opening up its copy layers.
# (The original cell also carried disabled lines for net.mb1 / net.iresnet1;
# they are not trained here.)
set_trainable(net, False)
for head in (net.mb0, net.cb):
    set_trainable(head, True)
set_trainable(net.iresnet0, False)

# One walk over the backbone: copy_bn / copy_conv sub-modules become
# trainable, everything else is frozen and its BatchNorm2d layers are put in
# eval mode. This depends on named_modules() yielding a container before its
# children, so a container's blanket freeze never undoes a copy layer that is
# enabled afterwards.
copy_layers = []
for qual_name, sub in net.iresnet0.named_modules():
    if _is_copy_layer(qual_name):
        set_trainable(sub, True)
        copy_layers.append(sub)
        continue
    set_trainable(sub, False)
    if isinstance(sub, nn.BatchNorm2d):
        sub.eval()

# Parameter groups: the two heads first, then one group per copy layer in
# traversal order.
param_lr = [_sgd_group(net.mb0.parameters()), _sgd_group(net.cb.parameters())]
param_lr.extend(_sgd_group(layer.parameters()) for layer in copy_layers)

# Report total vs. trainable parameter counts.
net_size = sum(p.numel() for p in net.parameters())
trainable_params = [p for p in net.parameters() if p.requires_grad]
trainable_size = sum(p.numel() for p in trainable_params)
print(net_size,trainable_size)
optimizer = optim.SGD(param_lr)

50590008 25612216


In [8]:
# Disabled multi-GPU wrapper; the run uses a single device.
# net = nn.DataParallel(net, device_ids=[0,1])
# Move the model to the GPU.
# NOTE(review): the optimizer in the previous cell was constructed before this
# move; PyTorch's optim documentation recommends moving the model first --
# confirm the ordering is intentional.
net = net.cuda()

In [None]:
# Checkpoint helper from the project's utils module; it receives the loss and
# the network on every step.
# NOTE(review): presumably it logs/saves every `iters_per_epoch` updates (the
# recorded output prints every 20 steps) -- confirm in utils.Checkpoint.
ckpt = utils.Checkpoint(iters_per_epoch=20, model_dir="./models/", model_name="first")

# One pass over the training loader. Intermediate tensors are dropped as soon
# as they are no longer needed so their memory can be released before the
# backward pass.
for batch in train_loader:
    optimizer.zero_grad()

    images, impulses, instance_masks, cat_ids = utils.cudify_data(batch)
    del batch
    outs = net([images, impulses])
    del images, impulses
    loss = L.loss_criterion(outs, [instance_masks, cat_ids])
    del instance_masks, cat_ids, outs
    # detach() rather than .data: same graph-free view of the loss, but
    # autograd still notices if the value is modified in place.
    ckpt.update(loss.detach(), net)
    loss.backward()
    del loss
    optimizer.step()

Step: 20	0.93147
Step: 40	0.90026
Step: 60	0.92321
Step: 80	0.91132
Step: 100	0.87631
Step: 120	0.90102
Step: 140	0.87082
Step: 160	0.90526
Step: 180	0.91003
Step: 200	0.91944
Step: 220	0.88891
Step: 240	0.87260
Step: 260	0.89098
Step: 280	0.91110
Step: 300	0.89898
Step: 320	0.87101
Step: 340	0.89507
Step: 360	0.87914
Step: 380	0.89551
Step: 400	0.91259
Step: 420	0.85297
Step: 440	0.84652
Step: 460	0.90812
Step: 480	0.88938
Step: 500	0.87586
Step: 520	0.89239
Step: 540	0.89513
Step: 560	0.89195
Step: 580	0.89171
Step: 600	0.90204
Step: 620	0.88442
Step: 640	0.87494
Step: 660	0.88223
Step: 680	0.88652
Step: 700	0.90307
Step: 720	0.88182
Step: 740	0.86434
Step: 760	0.87727
Step: 780	0.88058
Step: 800	0.89721
Step: 820	0.90023
Step: 840	0.84204
Step: 860	0.88710
Step: 880	0.87204
Step: 900	0.87763
Step: 920	0.90314
Step: 940	0.85096
Step: 960	0.88396
Step: 980	0.87518
Step: 1000	0.85864
Step: 1020	0.88437
Step: 1040	0.89730
problem loading image index: 18586 91492
Step: 1060	0.90414
Step:

Step: 8660	0.66571
