In [1]:
import importlib
import json
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim

import base_config
import loss_functions as L
import model
import pan_loader
import utils

# Restrict which CUDA device this process may see. Both variables are set
# through the process environment, after the imports above.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})


In [2]:
# Dataset root. Every path below is produced by plain string concatenation,
# so each directory string keeps its trailing slash.
data_dir = "/home/aravind/dataset/"
ann_dir = data_dir + "annotations/panoptic/"

# Image folders for the two splits.
train_img_dir, val_img_dir = data_dir + "train2017/", data_dir + "val2017/"

# Segmentation folders for the two splits.
train_seg_dir, val_seg_dir = ann_dir + "panoptic_train2017/", ann_dir + "panoptic_val2017/"

# Annotation JSON files for the two splits.
train_ann_json, val_ann_json = ann_dir + "panoptic_train2017.json", ann_dir + "panoptic_val2017.json"

# Disabled shortcut: uncomment to make the "train" split point at the
# validation data instead.
# train_img_dir, train_seg_dir, train_ann_json = val_img_dir, val_seg_dir, val_ann_json

In [3]:
def _read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is opened with an explicit UTF-8 encoding so the result does not
    depend on the platform's default locale, and it is closed as soon as
    parsing finishes.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# `json` comes from the notebook's import cell; without that import this cell
# fails with a NameError on a fresh kernel.
val_ann = _read_json(val_ann_json)
train_ann = _read_json(train_ann_json)

In [4]:
# Instantiate the project's configuration object with its defaults; the same
# instance is passed to both data loaders created in the next cell.
config = base_config.Config()

In [5]:
# Each split is described by (image folder, segmentation folder, parsed
# annotation JSON); both loaders share the same config object.
train_split = (train_img_dir, train_seg_dir, train_ann)
val_split = (val_img_dir, val_seg_dir, val_ann)

train_loader = pan_loader.get_loader(*train_split, config)
val_loader = pan_loader.get_loader(*val_split, config)

In [6]:
# Build the network from the project's `model` module using its default
# constructor arguments.
net = model.hgmodel()

# Optional warm start (currently disabled). When enabled, the lines below load
# a saved state dict from models/first_0.pt and copy into the network only the
# entries whose keys also exist in the network's own state dict.
# model_dir="models/"
# model_name="first_0.pt"
# pretrained_dict = torch.load(model_dir+model_name)
# net_dict = net.state_dict()

# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net_dict}
# net_dict.update(pretrained_dict) 
# net.load_state_dict(net_dict)


In [7]:
def set_trainable(module, state):
    """Set ``requires_grad`` to ``state`` on every parameter of ``module``.

    ``module.parameters()`` is iterated as-is, so parameters of nested
    sub-modules are included.
    """
    for weight in module.parameters():
        weight.requires_grad = state

# Start from a fully frozen network, then switch on only the parts to train.
set_trainable(net, False)

# net.mb1 is not listed here, so it stays frozen.
for unfrozen_block in (net.mb0, net.cb):
    set_trainable(unfrozen_block, True)

# Freeze the whole of iresnet0 before selectively re-enabling parts of it
# (net.iresnet1 is not touched in this cell).
set_trainable(net.iresnet0, False)

# Walk every sub-module of iresnet0 in the order named_modules() yields them.
# A module is trainable only if its qualified name contains 'wing' and does not
# contain one of the copy tags; everything else is frozen.
copy_tags = ('copy_bn', 'copy_conv')
for module_name, submodule in net.iresnet0.named_modules():
    is_copy = any(tag in module_name for tag in copy_tags)
    is_wing = 'wing' in module_name
    set_trainable(submodule, is_wing and not is_copy)

    # Only BatchNorm2d layers that are neither copy nor wing modules are put
    # into eval mode here.
    # NOTE(review): copy_bn modules are frozen but not switched to eval mode,
    # and a later net.train() call would presumably undo these eval() calls -
    # confirm both are intended.
    if not (is_copy or is_wing) and isinstance(submodule, nn.BatchNorm2d):
        submodule.eval()
        
# SGD hyper-parameters shared by every trainable parameter group.
LEARNING_RATE = 1e-4
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4


def _sgd_group(params):
    """Return one optimizer parameter-group dict with the shared lr/momentum.

    ``params`` may be any iterable of parameters; it is materialised into a
    list so the group can be inspected again after the optimizer is built.
    """
    return {'params': list(params), 'lr': LEARNING_RATE, 'momentum': MOMENTUM}


# net.mb0 and net.cb are trained in full; net.mb1 is not optimised.
param_lr = [
    _sgd_group(net.mb0.parameters()),
    _sgd_group(net.cb.parameters()),
]

# Select the iresnet0 parameters to optimise by their requires_grad flag.
# The flags are set over named_modules(), which descends to any depth, whereas
# scanning named_children() for 'wing' only sees direct children: a nested wing
# module would be marked trainable but never handed to the optimizer. Using the
# flag keeps the optimizer in step with whatever is actually unfrozen,
# including copy_bn / copy_conv modules if they are unfrozen later.
iresnet0_trainable = [p for p in net.iresnet0.parameters() if p.requires_grad]
if iresnet0_trainable:
    param_lr.append(_sgd_group(iresnet0_trainable))

# To train the entire network instead, replace the groups above with
# param_lr = [_sgd_group(net.parameters())].

# Report total vs. trainable parameter counts.
net_size = sum(p.numel() for p in net.parameters())
trainable_params = [p for p in net.parameters() if p.requires_grad]
trainable_size = sum(p.numel() for p in trainable_params)
print(net_size, trainable_size)

# Sanity check: every trainable parameter should be covered by some group,
# otherwise it silently never gets updated.
optimized_size = sum(
    p.numel()
    for group in param_lr
    for p in group['params']
    if p.requires_grad
)
if optimized_size != trainable_size:
    print("WARNING: optimizer covers {} of {} trainable parameters".format(
        optimized_size, trainable_size))

# The learning rate and momentum come from the per-group entries; only weight
# decay is passed as an optimizer-wide default.
optimizer = optim.SGD(param_lr, weight_decay=WEIGHT_DECAY)

37761792 11222096


In [8]:
# Multi-GPU wrapping is disabled; the model is moved to a single CUDA device.
# net = nn.DataParallel(net, device_ids=[0,1])
# NOTE(review): the optimizer was constructed from net's parameters in the
# previous cell, before this move. PyTorch's optimizer documentation recommends
# calling .cuda() before building the optimizer - confirm the current order is
# safe on the installed version.
net = net.cuda()

In [None]:
# Checkpoint helper from utils: it is handed each step's loss together with
# the network (presumably for periodic logging and saving - see utils.Checkpoint).
ckpt = utils.Checkpoint(iters_per_epoch=20, model_dir="./models/", model_name="second")

# Debug aid (disabled): torch.set_printoptions(float('nan'))

# One pass over the training loader. Intermediate tensors are deleted as soon
# as they are no longer needed so their references do not outlive the step.
for i, batch in enumerate(train_loader):
    optimizer.zero_grad()

    images, impulses, instance_masks, cat_ids = utils.cudify_data(batch)
    del batch

    # NOTE(review): `impulses` is unpacked but never used, while
    # `instance_masks` is passed to the network here and also used as a loss
    # target below - confirm this is the intended input.
    predictions = net([images, instance_masks])
    # Debug aid (disabled): print(predictions[0])
    del images, impulses

    loss = L.loss_criterion(predictions, [instance_masks, cat_ids])
    del instance_masks, cat_ids, predictions

    # Hand the loss value to the checkpoint helper before back-propagating.
    ckpt.update(loss.data, net)
    loss.backward()
    del loss

    optimizer.step()

Step: 20	0.74008
Step: 40	0.62779
Step: 60	0.63501
Step: 80	0.62628
Step: 100	0.62956
Step: 120	0.62726
Step: 140	0.62230
Step: 160	0.62394
Step: 180	0.63522
Step: 200	0.61174
Step: 220	0.62770
Step: 240	0.62537
Step: 260	0.63358
Step: 280	0.62213
Step: 300	0.63597
Step: 320	0.62760
Step: 340	0.62963
Step: 360	0.63404
Step: 380	0.63106
Step: 400	0.61842
Step: 420	0.62727
Step: 440	0.63402
Step: 460	0.63308
Step: 480	0.64058
Step: 500	0.61893
Step: 520	0.63143
Step: 540	0.61802
Step: 560	0.61967
Step: 580	0.61689
Step: 600	0.62451
Step: 620	0.63328
Step: 640	0.62647
Step: 660	0.62049
Step: 680	0.62926
Step: 700	0.61909
Step: 720	0.64378
Step: 740	0.62772
Step: 760	0.64284
Step: 780	0.62606
Step: 800	0.60594
Step: 820	0.62421
Step: 840	0.62898
Step: 860	0.62507
Step: 880	0.62217
Step: 900	0.64682
Step: 920	0.63329
Step: 940	0.62759
Step: 960	0.62723
Step: 980	0.62593
Step: 1000	0.62578
Step: 1020	0.62822
Step: 1040	0.62307
Step: 1060	0.63017
Step: 1080	0.61533
Step: 1100	0.63054
Step: 11