In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

sys.path.append('../..')
sys.path.append('../classification/')

from dataset.dataset.image import ImagesBatch
from dataset import Dataset, DatasetIndex, B, V
from detection_mnist import DetectionMnist
from faster_rcnn import FRCNNModel
%matplotlib inline

In [2]:
IMAGE_SHAPE = (64, 64)
MAP_SHAPE = (8, 8)
N_ANCHORS = MAP_SHAPE[0] * MAP_SHAPE[1] * 9
MNIST_PER_IMAGE = 5
BATCH_SIZE = 64

In [3]:
ind = DatasetIndex(np.arange(20000))          
mnist = Dataset(ind, batch_class=DetectionMnist)   
mnist.cv_split([0.9, 0.1]) 

In [4]:
import sys
import tensorflow as tf

sys.path.append('../task_03')

from dataset.dataset.models.tf.layers import conv_block
from dataset.dataset.models.tf import TFModel
from vgg import VGGModel

class FRCNNModel(TFModel):
    """LinkNet as TFModel"""
    def _build(self, *args, **kwargs):

        #n_classes = self.num_channels('masks')
        _, inp2 = self._make_inputs(['images'])
        data_format = self.data_format('images')
        dim = self.spatial_dim('images')
        b_norm = self.get_from_config('batch_norm', True)

        conv = {'data_format': data_format}
        batch_norm = {'momentum': 0.1}

        kwargs = {'conv': conv, 'batch_norm': batch_norm}
        
        inp = inp2['images']
        with tf.variable_scope('FRCNN'): # pylint: disable=not-context-manager
            net = VGGModel.fully_conv_block(dim, inp, b_norm, 'VGG7', **kwargs)
            net = conv_block(dim, net, 512, 3, 'ca', **kwargs)
            reg = conv_block(dim, net, 4*9, 1, 'ca', **kwargs)
            cls = conv_block(dim, net, 1*9, 1, 'c', **kwargs)

        reg = tf.reshape(reg, [-1, N_ANCHORS, 4], name='RoI')
        cls = tf.reshape(cls, [-1, N_ANCHORS], name='IoU')
        true_cls = tf.placeholder(tf.float32, shape = [None, N_ANCHORS], name='proposal_targets')
        true_reg = tf.placeholder(tf.float32, shape = [None, N_ANCHORS, 4], name='bbox_targets')
        
        loss = self.rpn_loss(reg, cls, true_reg, true_cls)
        loss = tf.identity(loss, name='loss')
        tf.losses.add_loss(loss)
    
    def rpn_loss(self, reg, cls, true_reg, true_cls):
        cls_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=true_cls, logits=cls) / MNIST_PER_IMAGE   
        cls_loss = tf.reduce_sum(cls_loss, axis=-1)
        cls_loss = tf.reduce_mean(cls_loss, name='cls_loss')

        sums = tf.reduce_sum((true_reg - reg) ** 2, axis=-1)
        reg_mask = tf.cast(true_cls, dtype=tf.float32)
        reg_mask = tf.reshape(reg_mask, shape=[-1, N_ANCHORS])

        reg_loss = sums * true_cls
        reg_loss = tf.reduce_sum(reg_loss, axis=-1)
        reg_loss = tf.reduce_mean(reg_loss, name='reg_loss')
        
        loss = reg_loss * 10 + cls_loss
        return loss

In [5]:
placeholders_config = {
                       'images': {'shape': IMAGE_SHAPE + (1,),
                                 'dtype': 'float32',
                                 'data_format': 'channels_last',
                                 'name': 'reshaped_images'},
                       }

In [6]:
model_config = {'inputs': placeholders_config,
                'batch_norm': True,
                'optimizer': 'Adam'}

In [7]:
train_feed_dict = {'images': B('images'),
                   'proposal_targets': B('clsf'),
                   'bbox_targets': B('reg')}        

test_feed_dict = {'images': B('images'),
                   'proposal_targets': B('clsf'),
                   'bbox_targets': B('reg')}  

In [None]:
train_pp = (mnist.train.p
            .init_model('static', FRCNNModel, 'frcnn', config=model_config)
            .init_variable('loss_history', init_on_each_run=list)
            .init_variable('IoU_predictions', init_on_each_run=list)
            .load_images()
            .generate_multimnist_images(image_shape=IMAGE_SHAPE, max_dig=MNIST_PER_IMAGE)
            .create_anchors(IMAGE_SHAPE, MAP_SHAPE)
            .create_reg_cls()
            .param_reg()
            .train_model('frcnn', 
                         fetches=['loss', 'reg_loss', 'cls_loss'],
                         feed_dict=train_feed_dict,
                         save_to=V('loss_history'), 
                         mode='a'
            ))

In [None]:
for i in range(500):
    batch = train_pp.next_batch(BATCH_SIZE, n_epochs=None, shuffle=True)
    if (i+1) % 10 == 0:
        #print(batch.data.bbox_batch_sizes)
        print('Iter {}: {}'.format(i+1, train_pp.get_variable('loss_history')[-1]))

Iter 10: [16.120089, 1.0291331, 5.8287587]
Iter 20: [14.707017, 1.0664941, 4.0420761]
Iter 30: [13.128283, 1.0180069, 2.9482136]
Iter 40: [12.571494, 0.98477441, 2.7237501]
Iter 50: [13.363638, 1.0740654, 2.6229837]
Iter 60: [10.742109, 0.85752404, 2.1668689]
Iter 70: [11.860249, 0.95682609, 2.2919879]
Iter 80: [12.33462, 0.99792016, 2.3554187]
Iter 90: [13.056274, 1.0633078, 2.423197]
Iter 100: [13.954223, 1.1554145, 2.4000778]
Iter 110: [11.507006, 0.94614482, 2.0455565]
Iter 120: [11.549006, 0.94783378, 2.0706675]
Iter 130: [11.108252, 0.93359059, 1.7723455]
Iter 140: [11.489969, 0.96212578, 1.8687111]
Iter 150: [11.947775, 0.99760282, 1.9717464]
Iter 160: [10.60153, 0.87902176, 1.8113127]
Iter 170: [12.191529, 1.0286627, 1.9049022]
Iter 180: [11.639009, 0.97798145, 1.8591949]
Iter 190: [11.639383, 0.98386306, 1.8007524]
Iter 200: [11.43763, 0.95782238, 1.8594055]
Iter 210: [11.11263, 0.93504912, 1.7621385]
Iter 220: [11.1482, 0.94081646, 1.7400347]
Iter 230: [12.695911, 1.0784878, 

In [None]:
reg_loss = np.array(train_pp.get_variable('loss_history'))[:,1]
cls_loss = np.array(train_pp.get_variable('loss_history'))[:,2]
total_loss = np.array(train_pp.get_variable('loss_history'))[:,0]

In [None]:
plt.plot(np.log(total_loss), label='Total loss')
plt.plot(np.log(reg_loss), label='Reg loss')
plt.plot(np.log(cls_loss), label='Cls loss')
plt.legend()

In [None]:
test_pp = (mnist.test.p
            .import_model('frcnn', train_pp)
            .load_images()
            .generate_multimnist_images(image_shape=IMAGE_SHAPE, max_dig=MNIST_PER_IMAGE)
            .create_anchors(IMAGE_SHAPE, MAP_SHAPE)
            .create_reg_cls()
            .param_reg()
            .predict_model('frcnn', 
                           fetches=['RoI','IoU'],
                           feed_dict=test_feed_dict,
                           save_to=[B('RoI_predictions'), B('IoU_predictions')])
            .unparam_predictions())

batch = test_pp.next_batch(BATCH_SIZE)

In [None]:
i = 1
im = batch.data.images[i]
bboxes = batch.data.bboxes[i]

fig, ax = plt.subplots(1)

ax.imshow(im)

for bbox in bboxes:    
    rect = patches.Rectangle((bbox[1], bbox[0]), bbox[2], bbox[3] ,linewidth=1,edgecolor='g',facecolor='none')
    ax.add_patch(rect)
plt.show()

In [None]:
for i in range(10):
    im = batch.data.images[i]
    bboxes = batch.data.RoI_predictions[i]
    selected_RoI = batch.data.IoU_predictions[i]

    fig, ax = plt.subplots(1)

    ax.imshow(im)

    for bbox in bboxes[selected_RoI > 0.5]:    
        rect = patches.Rectangle((bbox[1], bbox[0]), bbox[2], bbox[3], linewidth=1, edgecolor='r',facecolor='none')
        ax.add_patch(rect)
    plt.show()