In [1]:
import os

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

from chainercv.datasets import VOCBboxDataset
from chainercv.datasets import voc_bbox_label_names
from chainercv.visualizations import vis_bbox
import torch
import torchvision


from model_init import faster_rcnn
from model.utils.transform import image_normalize


# PASCAL VOC 2007 bounding-box data, loaded through chainercv.
# `trainval_dataset_original` is the full trainval split that each training
# epoch draws a random subset from; `test_dataset` is the held-out split used
# for the qualitative prediction plots.
trainval_dataset_original = VOCBboxDataset(split='trainval', year='2007')
test_dataset = VOCBboxDataset(split='test', year='2007')




  from ._conv import register_converters as _register_converters


In [2]:
def adjust_learning_rate(optimizer, epoch, init_lr, lr_decay_factor=0.1, lr_decay_epoch=50):
    """Apply step decay to the learning rate of every parameter group.

    The rate becomes ``init_lr * lr_decay_factor ** (epoch // lr_decay_epoch)``.
    Until the first decay boundary is reached (``epoch // lr_decay_epoch`` not
    positive) the optimizer is left untouched and nothing is printed.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts with an
            ``'lr'`` key (e.g. a ``torch.optim`` optimizer).
        epoch: current epoch number.
        init_lr: learning rate before any decay.
        lr_decay_factor: multiplier applied once per completed decay period.
        lr_decay_epoch: length of one decay period, in epochs.
    """
    n_decays = epoch // lr_decay_epoch
    if not n_decays > 0:
        # Still inside the first period: keep whatever LR the optimizer has.
        return

    new_lr = init_lr * lr_decay_factor ** n_decays
    print('LR is set to {}'.format(new_lr))
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def _save_atomically(path, write):
    """Write a checkpoint file through a temporary sibling plus an atomic rename.

    ``write`` is called with a binary file object opened on ``path + '.tmp'``;
    only after it returns is the temp file moved over ``path``. The training
    loop below is stopped with KeyboardInterrupt, so an interrupt landing in the
    middle of a save now leaves the previous checkpoint intact instead of a
    truncated file.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        write(f)
    os.replace(tmp_path, path)


# Resume bookkeeping: epoch.npy stores [last_finished_epoch, initial_flag].
# On a first run the file does not exist yet, so start from scratch.
if os.path.exists('epoch.npy'):
    epoch_and_initial = np.load('epoch.npy')
    epoch = epoch_and_initial[0]+1
    initial = epoch_and_initial[1]
else:
    epoch = 0
    initial = True

if initial:
    vgg16 = torchvision.models.vgg16(pretrained=True)
    model = faster_rcnn(20, vgg16)           # n_class = 20, model is an object of the class Faster_RCNN
    loss_list = []
else:
    model = torch.load('weights.pt')
    loss_list = list(np.load('loss.npy'))


if torch.cuda.is_available():
    model = model.cuda()

optimizer = model.get_optimizer(is_adam=False)

model.train()

# Number of images drawn per "epoch" (a random subset of trainval).
N_SAMPLES_PER_EPOCH = 1000
# Index range derived from the dataset itself (5011 for VOC2007 trainval)
# rather than hard-coded.
idx = np.arange(len(trainval_dataset_original))

# Runs until interrupted; every completed epoch is checkpointed to disk.
while True:
    adjust_learning_rate(optimizer, epoch, 0.001, lr_decay_epoch=50)
    # Sample without replacement so no image is repeated within one epoch;
    # cap at the dataset size so replace=False can never raise.
    train_idx = np.random.choice(idx, min(N_SAMPLES_PER_EPOCH, len(idx)), replace=False)

    for i, sample_idx in enumerate(train_idx):
        # Fetch each sample lazily; indexing the dataset with the whole index
        # array would decode every selected image into memory up front.
        img, bbox, label = trainval_dataset_original[int(sample_idx)]
        img = img/255

        loss = model.loss(img, bbox, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = float(loss.cpu().data.numpy())
        loss_list.append(loss_value)

        print('[epoch:{}]  [batch:{}/{}]  [sample_loss:{:.4f}]  '
              .format(epoch, i, len(train_idx), loss_value))

    _save_atomically('weights.pt', lambda f: torch.save(model, f))
    _save_atomically('loss.npy', lambda f: np.save(f, loss_list))
    # epoch.npy is written last so the resume point only advances once the
    # weights and loss history for this epoch are safely on disk.
    _save_atomically('epoch.npy', lambda f: np.save(f, [epoch, False]))
    epoch += 1


LR is set to 0.0001
[epoch:79]  [batch:0/1000]  [sample_loss:0.2767]  
[epoch:79]  [batch:1/1000]  [sample_loss:1.3417]  
[epoch:79]  [batch:2/1000]  [sample_loss:1.1162]  
[epoch:79]  [batch:3/1000]  [sample_loss:1.4929]  
[epoch:79]  [batch:4/1000]  [sample_loss:1.3981]  
[epoch:79]  [batch:5/1000]  [sample_loss:0.2591]  
[epoch:79]  [batch:6/1000]  [sample_loss:1.3183]  
[epoch:79]  [batch:7/1000]  [sample_loss:1.2298]  
[epoch:79]  [batch:8/1000]  [sample_loss:0.6005]  
[epoch:79]  [batch:9/1000]  [sample_loss:0.3033]  
[epoch:79]  [batch:10/1000]  [sample_loss:0.7632]  
[epoch:79]  [batch:11/1000]  [sample_loss:0.2392]  
[epoch:79]  [batch:12/1000]  [sample_loss:1.4216]  
[epoch:79]  [batch:13/1000]  [sample_loss:0.7181]  
[epoch:79]  [batch:14/1000]  [sample_loss:0.4915]  
[epoch:79]  [batch:15/1000]  [sample_loss:1.1628]  
[epoch:79]  [batch:16/1000]  [sample_loss:0.8798]  
[epoch:79]  [batch:17/1000]  [sample_loss:0.4545]  
[epoch:79]  [batch:18/1000]  [sample_loss:0.3600]  
[e

[epoch:79]  [batch:157/1000]  [sample_loss:0.8857]  
[epoch:79]  [batch:158/1000]  [sample_loss:0.4027]  
[epoch:79]  [batch:159/1000]  [sample_loss:0.8532]  
[epoch:79]  [batch:160/1000]  [sample_loss:1.1281]  
[epoch:79]  [batch:161/1000]  [sample_loss:0.7067]  
[epoch:79]  [batch:162/1000]  [sample_loss:0.3066]  
[epoch:79]  [batch:163/1000]  [sample_loss:0.9584]  
[epoch:79]  [batch:164/1000]  [sample_loss:0.5237]  
[epoch:79]  [batch:165/1000]  [sample_loss:1.3726]  
[epoch:79]  [batch:166/1000]  [sample_loss:1.2596]  
[epoch:79]  [batch:167/1000]  [sample_loss:0.2160]  
[epoch:79]  [batch:168/1000]  [sample_loss:0.6524]  
[epoch:79]  [batch:169/1000]  [sample_loss:0.1733]  
[epoch:79]  [batch:170/1000]  [sample_loss:0.2684]  
[epoch:79]  [batch:171/1000]  [sample_loss:0.7597]  
[epoch:79]  [batch:172/1000]  [sample_loss:0.6606]  
[epoch:79]  [batch:173/1000]  [sample_loss:0.9848]  
[epoch:79]  [batch:174/1000]  [sample_loss:0.4277]  
[epoch:79]  [batch:175/1000]  [sample_loss:0.2

[epoch:79]  [batch:312/1000]  [sample_loss:1.0986]  
[epoch:79]  [batch:313/1000]  [sample_loss:0.2886]  
[epoch:79]  [batch:314/1000]  [sample_loss:0.3055]  
[epoch:79]  [batch:315/1000]  [sample_loss:0.7936]  
[epoch:79]  [batch:316/1000]  [sample_loss:0.2623]  
[epoch:79]  [batch:317/1000]  [sample_loss:0.1717]  
[epoch:79]  [batch:318/1000]  [sample_loss:0.5425]  
[epoch:79]  [batch:319/1000]  [sample_loss:1.0225]  
[epoch:79]  [batch:320/1000]  [sample_loss:0.9148]  
[epoch:79]  [batch:321/1000]  [sample_loss:0.2702]  
[epoch:79]  [batch:322/1000]  [sample_loss:0.8475]  
[epoch:79]  [batch:323/1000]  [sample_loss:0.6773]  
[epoch:79]  [batch:324/1000]  [sample_loss:0.3120]  
[epoch:79]  [batch:325/1000]  [sample_loss:0.2988]  
[epoch:79]  [batch:326/1000]  [sample_loss:0.8024]  
[epoch:79]  [batch:327/1000]  [sample_loss:0.7877]  
[epoch:79]  [batch:328/1000]  [sample_loss:1.0284]  
[epoch:79]  [batch:329/1000]  [sample_loss:0.6810]  
[epoch:79]  [batch:330/1000]  [sample_loss:0.1

[epoch:79]  [batch:467/1000]  [sample_loss:1.4658]  
[epoch:79]  [batch:468/1000]  [sample_loss:0.9190]  
[epoch:79]  [batch:469/1000]  [sample_loss:1.4846]  
[epoch:79]  [batch:470/1000]  [sample_loss:0.4729]  
[epoch:79]  [batch:471/1000]  [sample_loss:0.3997]  
[epoch:79]  [batch:472/1000]  [sample_loss:0.4075]  
[epoch:79]  [batch:473/1000]  [sample_loss:0.4454]  
[epoch:79]  [batch:474/1000]  [sample_loss:0.9680]  
[epoch:79]  [batch:475/1000]  [sample_loss:0.4638]  
[epoch:79]  [batch:476/1000]  [sample_loss:0.3177]  
[epoch:79]  [batch:477/1000]  [sample_loss:0.9090]  
[epoch:79]  [batch:478/1000]  [sample_loss:1.0358]  
[epoch:79]  [batch:479/1000]  [sample_loss:1.3099]  
[epoch:79]  [batch:480/1000]  [sample_loss:0.8231]  
[epoch:79]  [batch:481/1000]  [sample_loss:1.6484]  
[epoch:79]  [batch:482/1000]  [sample_loss:0.1920]  
[epoch:79]  [batch:483/1000]  [sample_loss:0.2232]  
[epoch:79]  [batch:484/1000]  [sample_loss:0.2591]  
[epoch:79]  [batch:485/1000]  [sample_loss:0.2

[epoch:79]  [batch:622/1000]  [sample_loss:0.6567]  
[epoch:79]  [batch:623/1000]  [sample_loss:0.5653]  
[epoch:79]  [batch:624/1000]  [sample_loss:1.1841]  
[epoch:79]  [batch:625/1000]  [sample_loss:0.1780]  
[epoch:79]  [batch:626/1000]  [sample_loss:0.8230]  
[epoch:79]  [batch:627/1000]  [sample_loss:0.3575]  
[epoch:79]  [batch:628/1000]  [sample_loss:0.8620]  
[epoch:79]  [batch:629/1000]  [sample_loss:0.6935]  
[epoch:79]  [batch:630/1000]  [sample_loss:0.4098]  
[epoch:79]  [batch:631/1000]  [sample_loss:0.1722]  
[epoch:79]  [batch:632/1000]  [sample_loss:1.2505]  
[epoch:79]  [batch:633/1000]  [sample_loss:0.8504]  
[epoch:79]  [batch:634/1000]  [sample_loss:0.2923]  
[epoch:79]  [batch:635/1000]  [sample_loss:0.3337]  
[epoch:79]  [batch:636/1000]  [sample_loss:0.8426]  
[epoch:79]  [batch:637/1000]  [sample_loss:0.4799]  
[epoch:79]  [batch:638/1000]  [sample_loss:0.9031]  
[epoch:79]  [batch:639/1000]  [sample_loss:0.8829]  
[epoch:79]  [batch:640/1000]  [sample_loss:0.2

[epoch:79]  [batch:777/1000]  [sample_loss:0.6761]  
[epoch:79]  [batch:778/1000]  [sample_loss:1.5553]  
[epoch:79]  [batch:779/1000]  [sample_loss:0.1707]  
[epoch:79]  [batch:780/1000]  [sample_loss:0.3203]  
[epoch:79]  [batch:781/1000]  [sample_loss:0.3338]  
[epoch:79]  [batch:782/1000]  [sample_loss:0.4556]  
[epoch:79]  [batch:783/1000]  [sample_loss:1.1318]  
[epoch:79]  [batch:784/1000]  [sample_loss:0.4531]  
[epoch:79]  [batch:785/1000]  [sample_loss:1.5666]  
[epoch:79]  [batch:786/1000]  [sample_loss:1.0480]  
[epoch:79]  [batch:787/1000]  [sample_loss:0.3297]  
[epoch:79]  [batch:788/1000]  [sample_loss:0.3516]  
[epoch:79]  [batch:789/1000]  [sample_loss:0.8587]  
[epoch:79]  [batch:790/1000]  [sample_loss:1.3168]  
[epoch:79]  [batch:791/1000]  [sample_loss:0.5895]  
[epoch:79]  [batch:792/1000]  [sample_loss:0.2408]  
[epoch:79]  [batch:793/1000]  [sample_loss:0.9940]  
[epoch:79]  [batch:794/1000]  [sample_loss:1.5245]  
[epoch:79]  [batch:795/1000]  [sample_loss:0.6

[epoch:79]  [batch:932/1000]  [sample_loss:1.4385]  
[epoch:79]  [batch:933/1000]  [sample_loss:0.3048]  
[epoch:79]  [batch:934/1000]  [sample_loss:0.5142]  
[epoch:79]  [batch:935/1000]  [sample_loss:0.4077]  
[epoch:79]  [batch:936/1000]  [sample_loss:1.3580]  
[epoch:79]  [batch:937/1000]  [sample_loss:0.1924]  
[epoch:79]  [batch:938/1000]  [sample_loss:1.0182]  
[epoch:79]  [batch:939/1000]  [sample_loss:0.4157]  
[epoch:79]  [batch:940/1000]  [sample_loss:0.2665]  
[epoch:79]  [batch:941/1000]  [sample_loss:1.4169]  
[epoch:79]  [batch:942/1000]  [sample_loss:0.6295]  
[epoch:79]  [batch:943/1000]  [sample_loss:0.6018]  
[epoch:79]  [batch:944/1000]  [sample_loss:1.5589]  
[epoch:79]  [batch:945/1000]  [sample_loss:0.2367]  
[epoch:79]  [batch:946/1000]  [sample_loss:0.9149]  
[epoch:79]  [batch:947/1000]  [sample_loss:0.7830]  
[epoch:79]  [batch:948/1000]  [sample_loss:0.2428]  
[epoch:79]  [batch:949/1000]  [sample_loss:0.2272]  
[epoch:79]  [batch:950/1000]  [sample_loss:0.8

[epoch:80]  [batch:89/1000]  [sample_loss:0.8356]  
[epoch:80]  [batch:90/1000]  [sample_loss:0.5302]  
[epoch:80]  [batch:91/1000]  [sample_loss:0.2163]  
[epoch:80]  [batch:92/1000]  [sample_loss:0.5834]  
[epoch:80]  [batch:93/1000]  [sample_loss:1.7243]  
[epoch:80]  [batch:94/1000]  [sample_loss:0.1581]  
[epoch:80]  [batch:95/1000]  [sample_loss:1.2358]  
[epoch:80]  [batch:96/1000]  [sample_loss:0.3180]  
[epoch:80]  [batch:97/1000]  [sample_loss:0.3965]  
[epoch:80]  [batch:98/1000]  [sample_loss:0.5049]  
[epoch:80]  [batch:99/1000]  [sample_loss:0.3062]  
[epoch:80]  [batch:100/1000]  [sample_loss:0.4281]  
[epoch:80]  [batch:101/1000]  [sample_loss:0.1271]  
[epoch:80]  [batch:102/1000]  [sample_loss:0.6967]  
[epoch:80]  [batch:103/1000]  [sample_loss:1.1924]  
[epoch:80]  [batch:104/1000]  [sample_loss:0.5989]  
[epoch:80]  [batch:105/1000]  [sample_loss:1.0062]  
[epoch:80]  [batch:106/1000]  [sample_loss:0.8417]  
[epoch:80]  [batch:107/1000]  [sample_loss:0.2707]  
[epo

[epoch:80]  [batch:244/1000]  [sample_loss:1.4176]  
[epoch:80]  [batch:245/1000]  [sample_loss:0.1912]  
[epoch:80]  [batch:246/1000]  [sample_loss:0.4793]  
[epoch:80]  [batch:247/1000]  [sample_loss:0.5802]  
[epoch:80]  [batch:248/1000]  [sample_loss:0.3479]  
[epoch:80]  [batch:249/1000]  [sample_loss:1.3255]  
[epoch:80]  [batch:250/1000]  [sample_loss:1.4896]  
[epoch:80]  [batch:251/1000]  [sample_loss:0.2000]  
[epoch:80]  [batch:252/1000]  [sample_loss:1.1681]  
[epoch:80]  [batch:253/1000]  [sample_loss:0.4906]  
[epoch:80]  [batch:254/1000]  [sample_loss:1.0923]  
[epoch:80]  [batch:255/1000]  [sample_loss:0.6764]  
[epoch:80]  [batch:256/1000]  [sample_loss:0.3037]  
[epoch:80]  [batch:257/1000]  [sample_loss:1.1341]  
[epoch:80]  [batch:258/1000]  [sample_loss:1.2612]  
[epoch:80]  [batch:259/1000]  [sample_loss:0.2898]  
[epoch:80]  [batch:260/1000]  [sample_loss:1.1574]  
[epoch:80]  [batch:261/1000]  [sample_loss:0.1488]  
[epoch:80]  [batch:262/1000]  [sample_loss:1.0

[epoch:80]  [batch:399/1000]  [sample_loss:0.3737]  
[epoch:80]  [batch:400/1000]  [sample_loss:0.8178]  
[epoch:80]  [batch:401/1000]  [sample_loss:1.0699]  
[epoch:80]  [batch:402/1000]  [sample_loss:0.6042]  
[epoch:80]  [batch:403/1000]  [sample_loss:0.6704]  
[epoch:80]  [batch:404/1000]  [sample_loss:0.7469]  
[epoch:80]  [batch:405/1000]  [sample_loss:1.8037]  
[epoch:80]  [batch:406/1000]  [sample_loss:1.5981]  
[epoch:80]  [batch:407/1000]  [sample_loss:0.8399]  
[epoch:80]  [batch:408/1000]  [sample_loss:0.3432]  
[epoch:80]  [batch:409/1000]  [sample_loss:0.4561]  
[epoch:80]  [batch:410/1000]  [sample_loss:0.8645]  
[epoch:80]  [batch:411/1000]  [sample_loss:0.5484]  
[epoch:80]  [batch:412/1000]  [sample_loss:0.5432]  
[epoch:80]  [batch:413/1000]  [sample_loss:0.3786]  
[epoch:80]  [batch:414/1000]  [sample_loss:1.4349]  
[epoch:80]  [batch:415/1000]  [sample_loss:0.9333]  
[epoch:80]  [batch:416/1000]  [sample_loss:0.2962]  
[epoch:80]  [batch:417/1000]  [sample_loss:0.5

[epoch:80]  [batch:554/1000]  [sample_loss:0.4065]  
[epoch:80]  [batch:555/1000]  [sample_loss:0.3688]  
[epoch:80]  [batch:556/1000]  [sample_loss:1.3305]  
[epoch:80]  [batch:557/1000]  [sample_loss:1.1425]  
[epoch:80]  [batch:558/1000]  [sample_loss:0.3210]  
[epoch:80]  [batch:559/1000]  [sample_loss:0.7186]  
[epoch:80]  [batch:560/1000]  [sample_loss:1.2397]  
[epoch:80]  [batch:561/1000]  [sample_loss:0.3736]  
[epoch:80]  [batch:562/1000]  [sample_loss:1.1546]  
[epoch:80]  [batch:563/1000]  [sample_loss:0.2364]  
[epoch:80]  [batch:564/1000]  [sample_loss:0.9587]  
[epoch:80]  [batch:565/1000]  [sample_loss:0.6423]  
[epoch:80]  [batch:566/1000]  [sample_loss:0.2507]  
[epoch:80]  [batch:567/1000]  [sample_loss:0.2617]  
[epoch:80]  [batch:568/1000]  [sample_loss:1.5340]  
[epoch:80]  [batch:569/1000]  [sample_loss:0.4991]  
[epoch:80]  [batch:570/1000]  [sample_loss:0.1474]  
[epoch:80]  [batch:571/1000]  [sample_loss:1.2823]  
[epoch:80]  [batch:572/1000]  [sample_loss:0.2

[epoch:80]  [batch:709/1000]  [sample_loss:1.5745]  
[epoch:80]  [batch:710/1000]  [sample_loss:0.5059]  
[epoch:80]  [batch:711/1000]  [sample_loss:1.0791]  
[epoch:80]  [batch:712/1000]  [sample_loss:0.3758]  
[epoch:80]  [batch:713/1000]  [sample_loss:0.9610]  
[epoch:80]  [batch:714/1000]  [sample_loss:0.4100]  
[epoch:80]  [batch:715/1000]  [sample_loss:0.2829]  
[epoch:80]  [batch:716/1000]  [sample_loss:0.6924]  
[epoch:80]  [batch:717/1000]  [sample_loss:0.8232]  
[epoch:80]  [batch:718/1000]  [sample_loss:1.4207]  
[epoch:80]  [batch:719/1000]  [sample_loss:0.7508]  
[epoch:80]  [batch:720/1000]  [sample_loss:0.3676]  
[epoch:80]  [batch:721/1000]  [sample_loss:0.3816]  
[epoch:80]  [batch:722/1000]  [sample_loss:0.2068]  
[epoch:80]  [batch:723/1000]  [sample_loss:0.3198]  
[epoch:80]  [batch:724/1000]  [sample_loss:0.1826]  
[epoch:80]  [batch:725/1000]  [sample_loss:0.1356]  
[epoch:80]  [batch:726/1000]  [sample_loss:0.9853]  
[epoch:80]  [batch:727/1000]  [sample_loss:0.4

[epoch:80]  [batch:864/1000]  [sample_loss:0.2065]  
[epoch:80]  [batch:865/1000]  [sample_loss:0.4283]  
[epoch:80]  [batch:866/1000]  [sample_loss:0.5287]  
[epoch:80]  [batch:867/1000]  [sample_loss:1.1831]  
[epoch:80]  [batch:868/1000]  [sample_loss:0.7238]  
[epoch:80]  [batch:869/1000]  [sample_loss:1.0112]  
[epoch:80]  [batch:870/1000]  [sample_loss:1.4159]  
[epoch:80]  [batch:871/1000]  [sample_loss:0.8406]  
[epoch:80]  [batch:872/1000]  [sample_loss:1.1665]  
[epoch:80]  [batch:873/1000]  [sample_loss:0.0915]  
[epoch:80]  [batch:874/1000]  [sample_loss:0.9278]  
[epoch:80]  [batch:875/1000]  [sample_loss:0.2202]  
[epoch:80]  [batch:876/1000]  [sample_loss:0.4535]  
[epoch:80]  [batch:877/1000]  [sample_loss:1.9576]  
[epoch:80]  [batch:878/1000]  [sample_loss:0.9582]  
[epoch:80]  [batch:879/1000]  [sample_loss:0.8253]  
[epoch:80]  [batch:880/1000]  [sample_loss:1.0474]  
[epoch:80]  [batch:881/1000]  [sample_loss:1.3969]  
[epoch:80]  [batch:882/1000]  [sample_loss:0.7

[epoch:81]  [batch:19/1000]  [sample_loss:0.8133]  
[epoch:81]  [batch:20/1000]  [sample_loss:0.4709]  
[epoch:81]  [batch:21/1000]  [sample_loss:0.2059]  
[epoch:81]  [batch:22/1000]  [sample_loss:1.0973]  
[epoch:81]  [batch:23/1000]  [sample_loss:0.5343]  
[epoch:81]  [batch:24/1000]  [sample_loss:0.2965]  
[epoch:81]  [batch:25/1000]  [sample_loss:0.2613]  
[epoch:81]  [batch:26/1000]  [sample_loss:0.5166]  
[epoch:81]  [batch:27/1000]  [sample_loss:1.2155]  
[epoch:81]  [batch:28/1000]  [sample_loss:0.4650]  
[epoch:81]  [batch:29/1000]  [sample_loss:0.4740]  
[epoch:81]  [batch:30/1000]  [sample_loss:1.3047]  
[epoch:81]  [batch:31/1000]  [sample_loss:0.7466]  
[epoch:81]  [batch:32/1000]  [sample_loss:1.0415]  
[epoch:81]  [batch:33/1000]  [sample_loss:0.6763]  
[epoch:81]  [batch:34/1000]  [sample_loss:0.6383]  
[epoch:81]  [batch:35/1000]  [sample_loss:0.1815]  
[epoch:81]  [batch:36/1000]  [sample_loss:1.0802]  
[epoch:81]  [batch:37/1000]  [sample_loss:0.9744]  
[epoch:81]  

KeyboardInterrupt: 

In [3]:
model = torch.load('weights.pt')
model.eval()

# Number of test images to visualise; use len(test_dataset) for the full split.
n_vis = 20
for i in range(n_vis):
    print(i)
    img, _, _ = test_dataset[i]
    imgx = img/255
    bbox_out, class_out, prob_out = model.predict(imgx, prob_threshold=0.9)

    # Create the figure explicitly (at its final size) and hand its axes to
    # vis_bbox. With ax=None vis_bbox opens a new figure on every call, and
    # those figures were never closed, so they accumulated across iterations.
    fig = plt.figure(figsize=(11, 5))
    ax = fig.add_subplot(1, 1, 1)
    vis_bbox(img, bbox_out, class_out, prob_out,
             label_names=voc_bbox_label_names, ax=ax)
    fig.savefig('test_'+str(i)+'.jpg', dpi=100)
    plt.close(fig)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
