In [1]:
import model_lib
import numpy as np
import warnings
warnings.filterwarnings('ignore', '.*output shape of zoom.*')
import pickle
import importlib
importlib.reload(model_lib)
import os
import time
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"


In [2]:
# config to train
# TODO: check Config is correct
class ProposalConfig():
    NAME = "InSegm"
    GPU_COUNT = 1
    # online training
    IMAGES_PER_GPU = 32
    STEPS_PER_EPOCH = 100
    # not going to use these
    N_DISTORTIONS = 0
    MAX_DISTORTION = 0.3
    MIN_DISTORTION = -0.1
    NUM_WORKERS = 16
    PIN_MEMORY = True
    VALIDATION_STEPS = 20
    # including gt
    NUM_CLASSES = 81
    # only flips
    IMAGE_AUGMENT = True
    DATA_ORDER = "ins"
    MEAN_PIXEL = np.array([0.485, 0.456, 0.406],dtype=np.float32).reshape(1,1,-1)
    STD_PIXEL = np.array([0.229, 0.224, 0.225],dtype=np.float32).reshape(1,1,-1)
    MAX_GT_INSTANCES = 100
    DETECTION_MAX_INSTANCES = 100
    DETECTION_MIN_CONFIDENCE = 0.7
    CLASS_NAMES = [
        'BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]
    DETECTION_NMS_THRESHOLD = 0.3
    LEARNING_RATE = 0.05
    LEARNING_MOMENTUM = 0.9
    WEIGHT_DECAY = 0.0001
    WIDTH = 224
    HEIGHT = 224
    MASK_SHAPE = (64,64)
    GRID_WIDTH = 16
    GRID_HEIGHT = 16
    CLUE_SHAPE = (20,20)
    GRID_SHAPE = (GRID_WIDTH, GRID_HEIGHT)
    GRID_RESOLUTION = (1, 1)
    IS_PADDED = True
    MASK_THRESOLD = 0.7
    CROP_SIZE = 224
    def __init__(self):
        self.BATCH_SIZE = self.IMAGES_PER_GPU * self.GPU_COUNT
        self.IMAGE_SHAPE = (self.WIDTH, self.HEIGHT,3)
        self.MAX_BATCH_SIZE = self.BATCH_SIZE*32

    def display(self):
        """Display Configuration values."""
        print("\nConfigurations:")
        for a in dir(self):
            if not a.startswith("__") and not callable(getattr(self, a)):
                print("{:30} {}".format(a, getattr(self, a)))
        print("\n")

In [3]:
# train_image_dir = "/media/data/nishanth/aravind/train2017/"
train_image_dir = "/media/data/nishanth/aravind/val2017/"
val_image_dir = "/media/data/nishanth/aravind/val2017/"
config = ProposalConfig()
model_dir = "./models/"
# train_pickle = "/home/aravind/re/data/train_cwid.pickle"
train_pickle = "/home/aravind/re/data/val_cwid.pickle"
val_pickle = "/home/aravind/re/data/val_cwid.pickle"

In [4]:
with open(train_pickle,"rb") as train_ann:
    train_cwid = pickle.load(train_ann)
with open(val_pickle,"rb") as val_ann:
    val_cwid = pickle.load(val_ann)

In [5]:
train_loader = model_lib.get_loader(train_cwid,config,train_image_dir)
val_loader = model_lib.get_loader(val_cwid,config,val_image_dir)
# for j in enumerate(train_loader):
#     k = j

In [6]:
import torch.optim as optim
import torch
net = model_lib.SimpleHGModel()
# net.vgg.load_state_dict(torch.load("./models/vgg11_features.pt"))
# net.load_state_dict(torch.load("./models/model_big_bce_1_4200.pt"))
pretrained_dict = torch.load(model_dir+"model_vgg_class_only.pt")
net = model_lib.SimpleHGModel()
net_dict = net.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net_dict}
net_dict.update(pretrained_dict) 
net.load_state_dict(net_dict)

net = net.cuda()


In [7]:
# for name,child in net.named_children():
#     if name == 'vgg':
#         for k,gc in child.named_children():
#             if k[:-1] == 'layer':
#                 print("no grad to",k)
#                 for param in gc.parameters():
#                     param.requires_grad = False
for name,child in net.named_children():
    if name == "mask_predictor":
        for param in child.parameters():
            param.requires_grad = True
    elif name == "vgg":
        for n,gc in child.named_children():
            if n[:-1] =="wing_conv":
                for param in gc.parameters():
                    param.requires_grad = True
    else:
        for param in child.parameters():
            param.requires_grad = False

In [8]:
# net.load_state_dict(torch.load(model_dir+"model_bce_0_2600.pt"))
# optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),lr = 0.001)
net_size = sum([i.numel() for i in net.parameters()])
print(net_size)

iters_per_checkpoint = 60
for epoch in range(10000):  # loop over the dataset multiple times
    running_loss = 0.0
    for i,data in enumerate(train_loader,0):
        batch_images,batch_impulses,batch_gt_responses,batch_one_hot = data
        print(batch_gt_responses.squeeze().sum(-1).sum(-1))
        batch_images,batch_impulses,batch_gt_responses,batch_one_hot = batch_images.cuda(),batch_impulses.cuda(),batch_gt_responses.cuda(),batch_one_hot.cuda()
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        pred_class,pred_mask = net([batch_images,batch_impulses])
        # we are giving no weighting for classes...
        class_loss,mask_loss = model_lib.loss_criterion(pred_class,batch_one_hot,pred_mask,batch_gt_responses)
        loss = class_loss+0*mask_loss
        print(class_loss.item(),mask_loss.item())
        loss.backward()
        optimizer.step()
        running_loss += loss.item() 
        if i % iters_per_checkpoint == 0:
            print("batch: ",i,"epoch: ",epoch, "loss: %0.5f" % (running_loss/iters_per_checkpoint))
            torch.save(net.state_dict(),model_dir+("model_mask_vgg_%d_%d.pt")%(0,0))
            running_loss = 0.0
    # print("batch: %d time:%0.3f sec" %(i, end-start)); print(loss.item())
print('Finished Training')

15944252
tensor([   145.,   1178.,  18766.,    542.,   3718.,   1025.,   9067.,
          2176.,   9749.,   3077.,   1915.,    787.,   4956.,    733.,
            77.,    172.,  33145.,   8958.,   1617.,   9848.,   2671.,
           217.,    276.,   4578.,   7293.,    559.,   1854.,    281.,
           441.,    174.,   3799.,   1174.])
0.02026868239045143 1.253913164138794
batch:  0 epoch:  0 loss: 0.00034
tensor([   611.,    643.,   1948.,   5856.,   1032.,  10504.,    248.,
           291.,   6620.,    897.,   6721.,   1258.,    436.,   1789.,
          1779.,     91.,    815.,   5140.,   1120.,  30440.,    252.,
         33035.,   7408.,     46.,    454.,   2543.,    288.,   1348.,
          1363.,    252.,    102.,    497.])
0.01579604111611843 1.1895676851272583
tensor([ 13716.,    892.,    487.,   2814.,   6784.,  27277.,   1440.,
           931.,    351.,   1555.,   1277.,   2862.,    184.,   1814.,
         35243.,    595.,   1824.,    217.,    645.,   2330.,    853.,
         

tensor([    83.,   1069.,    941.,    230.,    281.,    553.,     56.,
          6529.,   1465.,     61.,  22818.,    198.,   6982.,     94.,
          2656.,   4059.,   4539.,    200.,   5924.,    644.,  21533.,
          1077.,    994.,   2394.,   3394.,  10150.,  30348.,   5072.,
           615.,     96.,   4927.,    231.])
0.021825211122632027 1.256609320640564
tensor([  7530.,   3499.,   3688.,    340.,   3288.,   1993.,   1528.,
           107.,   1925.,   1980.,   1995.,   2010.,    175.,   2944.,
          1011.,    615.,    975.,   2738.,     80.,    479.,    884.,
           598.,   3781.,    341.,   6170.,    532.,   5055.,   3217.,
         10093.,   1649.,     49.,   2793.])
0.015352061949670315 1.3553285598754883
tensor([   128.,   1029.,   4380.,  18512.,  12964.,    150.,  22708.,
           400.,   3260.,   1222.,   1577.,    946.,   1224.,   1017.,
           176.,   8993.,    290.,  33662.,   7607.,  12021.,    614.,
          2627.,   1296.,   3324.,    285.,    622

tensor([  2948.,   1511.,    352.,    542.,     66.,   8220.,   1430.,
            86.,   5749.,   1664.,    360.,   3716.,   1901.,   3565.,
           779.,   3527.,   1168.,  15963.,   2198.,    130.,    180.,
          1237.,   1786.,   1296.,   4698.,   2918.,   2806.,    142.,
          1047.,   4204.,   7540.,   2961.])
0.014335543848574162 1.313193917274475
tensor([  1531.,    711.,    343.,     98.,   2434.,   4617.,    147.,
          8399.,    644.,   2780.,    168.,   6177.,    260.,  26353.,
           272.,   6807.,   1329.,   6980.,   5570.,    316.,    264.,
          4982.,   1486.,   1667.,   2518.,    997.,   5896.,   1714.,
          3585.,    457.,    145.,     36.])
0.013572797179222107 1.3137270212173462
tensor([  1264.,   2573.,   1575.,   5953.,   2045.,   1385.,   3713.,
          1080.,   6025.,   2857.,   1059.,   2279.,    795.,   8039.,
          1188.,   2385.,   3662.,    152.,     76.,  12797.,    658.,
          5803.,   5577.,    361.,   4542.,   2967

tensor([ 13463.,   2809.,   5565.,    279.,    296.,    204.,   2368.,
          1175.,     60.,   5366.,    414.,  23616.,  18171.,   1930.,
          2623.,    500.,   6105.,   7023.,     82.,    213.,  14959.,
          4511.,   4282.,   2372.,   2775.,    240.,   1909.,    614.,
          3577.,   3858.,    905.,    566.])
0.15275375545024872 1.2275389432907104
tensor([   665.,   3524.,    118.,    446.,    180.,   2552.,    249.,
          2531.,   5363.,   7686.,    170.,  12498.,   3133.,   2199.,
         28119.,    729.,   9047.,    211.,    821.,   1960.,     62.,
           222.,   1678.,    126.,    356.,   2529.,   1552.,   1901.,
          2808.,    757.,   1876.,   1379.])
0.15298143029212952 1.2160931825637817


Process Process-1:
Traceback (most recent call last):
  File "/home/aravind/anaconda3/envs/myenv/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/aravind/anaconda3/envs/myenv/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/aravind/anaconda3/envs/myenv/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 57, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/aravind/anaconda3/envs/myenv/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 57, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/aravind/re/model_lib.py", line 43, in __getitem__
    image, masks, is_crowd = self.load_image_gt(class_id, cwid)
  File "/home/aravind/re/model_lib.py", line 185, in load_image_gt
    image = self.read_image(image_id)
  File "/home/aravind/re/model_lib.py", line 165, in read_im

KeyboardInterrupt: 