State: Array of image + coordinates of previous bbox
Action: Left, Right, None

In this first phase, the image is reduced to a single number: the position of the object along a 1-D array. The state also includes the coordinates of the previous bounding box (here, just its position along the same 1-D array). The action gives the direction in which the previous bounding box should move to get closer to the actual bounding box.
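As a minimal sketch of the state and action encoding described above: the helper `step` and its `clip` flag are hypothetical names introduced here for illustration, and the action indices follow the notebook's own comment (left : 0, none : 1, right : 2).

```python
LEFT, NONE, RIGHT = 0, 1, 2  # same indices as the notebook: left : 0, none : 1, right : 2

def step(prev, a, positions=10, clip=True):
    """Return the bbox position after applying action `a` to `prev` (hypothetical helper)."""
    current = prev + (a - 1)  # -1 for left, 0 for none, +1 for right
    if clip:
        # Keep the bbox inside the 1-D array, as the testing cell does
        current = min(max(current, 0), positions - 1)
    return current

state = (4, 1)  # (object position, previous bbox position)
obj, prev = state
print(step(prev, RIGHT))  # → 2
```

Note that the training cell does not clip the position, while the testing cell does; `clip=False` reproduces the training behaviour.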

This method trains on single-action trajectories: given the previous bounding box and the current object position, it performs one Q-learning update for that single sample, which serves as the 'trajectory'. In short, this is fully supervised, i.e. we provide a label for every state-action pair.
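Written out, the update applied to each sample is the standard one-step Q-learning rule. The shorthand $s = (\text{obj}, \text{prev})$ and $s' = (\text{obj}, \text{current})$ is introduced here for illustration; $\alpha$, $\gamma$ and $r$ correspond to `alpha`, `gamma` and `r` in the code:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$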

Reward:
+1 if bbox moves in the direction of the object, or the object and bbox coincide and the bbox stays where it is
-1 otherwise
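The reward rule can be sketched as a small function. The name `reward` is an assumption made for this example; the condition itself is the one used in the training and testing cells.

```python
def reward(obj, prev, current):
    """+1 if the bbox moved closer to the object or sits on it, -1 otherwise."""
    if abs(obj - current) < abs(obj - prev) or current == obj:
        return 1
    return -1

print(reward(4, 1, 2))  # moved towards the object → 1
print(reward(3, 3, 3))  # already on the object and stayed → 1
print(reward(3, 3, 4))  # moved off the object → -1
```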

In [1]:
import numpy as np

In [2]:
positions = 10
actions = 3
states = 100 # 10 positions, and the prev bbox could correspond to any of the states
gamma = 0.9
alpha = 0.01
epsilon = 0.1

In [3]:
Q = np.random.rand(positions, positions, actions)

In [4]:
for i in range(100000):
    # Sample a random object position and previous bbox position
    obj = np.random.randint(0, positions)
    prev = np.random.randint(0, positions)
    # Epsilon-greedy behaviour policy
    if np.random.rand() <= epsilon:
        a = np.random.randint(0, actions)
    else:
        a = np.argmax(Q[obj, prev])
    # Apply the action: left : 0, none : 1, right : 2
    current = prev - 1 if a == 0 else prev if a == 1 else prev + 1
    # +1 if the bbox moved closer to the object or sits on it, -1 otherwise
    r = 1 if abs(obj - current) < abs(obj - prev) or abs(obj - current) == 0 else -1
    if 0 <= current < positions:
        # Q-learning update, bootstrapping from the best action at the new position
        Q[obj, prev, a] += alpha * (r + gamma * np.max(Q[obj, current]) - Q[obj, prev, a])
    else:
        # The action moved the bbox off the array: pin its value to 0
        Q[obj, prev, a] = 0.0

### For different numbers of episodes:

    Number of episodes       Training Accuracy
    1000                     0.65
    10000                    0.86
    100000                   1.0

In [5]:
# Testing

total = 0
correct = 0

for i in range(10000):
    obj = np.random.randint(0, positions)
    prev = np.random.randint(0, positions)
    # Greedy action, no exploration
    a = np.argmax(Q[obj, prev])
    # Apply the action, clipped to the array: left : 0, none : 1, right : 2
    current = max(prev - 1, 0) if a == 0 else prev if a == 1 else min(prev + 1, positions - 1)
    print('Object: ' + str(obj) + ' prev: ' + str(prev) + ' current: ' + str(current))
    # Same reward rule as in training; a reward of +1 counts as a correct move
    r = 1 if abs(obj - current) < abs(obj - prev) or abs(obj - current) == 0 else -1
    if r == 1:
        correct += 1
    total += 1

print('Accuracy: ' + str(correct * 1.0 / total))

Object: 4 prev: 1 current: 2
Object: 1 prev: 2 current: 1
Object: 4 prev: 8 current: 7
Object: 6 prev: 3 current: 4
Object: 1 prev: 4 current: 3
Object: 1 prev: 1 current: 1
Object: 3 prev: 0 current: 1
Object: 2 prev: 5 current: 4
Object: 3 prev: 3 current: 3
Object: 5 prev: 5 current: 5
Object: 5 prev: 9 current: 8
Object: 4 prev: 7 current: 6
Object: 1 prev: 4 current: 3
Object: 3 prev: 7 current: 6
Object: 8 prev: 9 current: 8
Object: 3 prev: 4 current: 3
Object: 2 prev: 4 current: 3
Object: 3 prev: 6 current: 5
Object: 3 prev: 7 current: 6
Object: 1 prev: 4 current: 3
Object: 9 prev: 7 current: 8
Object: 0 prev: 1 current: 0
Object: 3 prev: 3 current: 3
Object: 7 prev: 9 current: 8
Object: 9 prev: 5 current: 6
Object: 5 prev: 6 current: 5
Object: 2 prev: 3 current: 2
Object: 3 prev: 9 current: 8
Object: 3 prev: 6 current: 5
Object: 8 prev: 6 current: 7
Object: 6 prev: 4 current: 5
Object: 3 prev: 3 current: 3
Object: 8 prev: 1 current: 2
Object: 0 prev: 2 current: 1
...

Object: 0 prev: 0 current: 0
Object: 0 prev: 0 current: 0
Object: 0 prev: 2 current: 1
Object: 1 prev: 9 current: 8
Object: 0 prev: 7 current: 6
Object: 2 prev: 2 current: 2
Object: 9 prev: 1 current: 2
Object: 6 prev: 8 current: 7
Object: 3 prev: 0 current: 1
Object: 5 prev: 9 current: 8
Object: 9 prev: 4 current: 5
Object: 3 prev: 3 current: 3
Object: 3 prev: 4 current: 3
Object: 3 prev: 2 current: 3
Object: 9 prev: 6 current: 7
Object: 9 prev: 9 current: 9
Object: 5 prev: 6 current: 5
Object: 6 prev: 5 current: 6
Object: 3 prev: 6 current: 5
Object: 0 prev: 2 current: 1
Object: 2 prev: 4 current: 3
Object: 0 prev: 1 current: 0
Object: 9 prev: 1 current: 2
Object: 3 prev: 6 current: 5
Object: 1 prev: 3 current: 2
Object: 7 prev: 4 current: 5
Object: 4 prev: 6 current: 5
Object: 1 prev: 1 current: 1
Object: 5 prev: 1 current: 2
Object: 9 prev: 2 current: 3
Object: 0 prev: 0 current: 0
Object: 6 prev: 7 current: 6
Object: 9 prev: 4 current: 5
Object: 8 prev: 3 current: 4
...

Object: 0 prev: 0 current: 0
Object: 3 prev: 6 current: 5
Object: 8 prev: 0 current: 1
Object: 6 prev: 7 current: 6
Object: 6 prev: 0 current: 1
Object: 6 prev: 1 current: 2
Object: 6 prev: 7 current: 6
Object: 7 prev: 6 current: 7
Object: 7 prev: 1 current: 2
Object: 3 prev: 8 current: 7
Object: 8 prev: 8 current: 8
Object: 1 prev: 5 current: 4
Object: 5 prev: 7 current: 6
Object: 6 prev: 5 current: 6
Object: 0 prev: 0 current: 0
Object: 7 prev: 0 current: 1
Object: 8 prev: 5 current: 6
Object: 6 prev: 4 current: 5
Object: 1 prev: 2 current: 1
Object: 4 prev: 3 current: 4
Object: 3 prev: 8 current: 7
Object: 5 prev: 8 current: 7
Object: 3 prev: 6 current: 5
Object: 8 prev: 8 current: 8
Object: 0 prev: 7 current: 6
Object: 6 prev: 9 current: 8
Object: 7 prev: 6 current: 7
Object: 5 prev: 4 current: 5
Object: 0 prev: 2 current: 1
Object: 2 prev: 3 current: 2
Object: 8 prev: 8 current: 8
Object: 6 prev: 0 current: 1
Object: 0 prev: 5 current: 4
Object: 6 prev: 0 current: 1
...

Object: 4 prev: 0 current: 1
Object: 7 prev: 5 current: 6
Object: 0 prev: 1 current: 0
Object: 6 prev: 3 current: 4
Object: 9 prev: 0 current: 1
Object: 2 prev: 6 current: 5
Object: 4 prev: 7 current: 6
Object: 0 prev: 1 current: 0
Object: 9 prev: 5 current: 6
Object: 7 prev: 9 current: 8
Object: 5 prev: 0 current: 1
Object: 0 prev: 1 current: 0
Object: 4 prev: 3 current: 4
Object: 3 prev: 3 current: 3
Object: 2 prev: 2 current: 2
Object: 6 prev: 5 current: 6
Object: 9 prev: 6 current: 7
Object: 7 prev: 7 current: 7
Object: 2 prev: 8 current: 7
Object: 2 prev: 7 current: 6
Object: 2 prev: 6 current: 5
Object: 4 prev: 3 current: 4
Object: 2 prev: 9 current: 8
Object: 9 prev: 7 current: 8
Object: 7 prev: 5 current: 6
Object: 6 prev: 8 current: 7
Object: 7 prev: 9 current: 8
Object: 9 prev: 9 current: 9
Object: 7 prev: 5 current: 6
Object: 4 prev: 1 current: 2
Object: 9 prev: 7 current: 8
Object: 9 prev: 9 current: 9
Object: 7 prev: 5 current: 6
Object: 2 prev: 0 current: 1
...

Object: 1 prev: 7 current: 6
Object: 4 prev: 3 current: 4
Object: 3 prev: 3 current: 3
Object: 7 prev: 1 current: 2
Object: 9 prev: 1 current: 2
Object: 3 prev: 5 current: 4
Object: 2 prev: 9 current: 8
Object: 8 prev: 9 current: 8
Object: 3 prev: 4 current: 3
Object: 6 prev: 3 current: 4
Object: 9 prev: 2 current: 3
Object: 6 prev: 0 current: 1
Object: 5 prev: 6 current: 5
Object: 9 prev: 0 current: 1
Object: 8 prev: 8 current: 8
Object: 0 prev: 3 current: 2
Object: 6 prev: 2 current: 3
Object: 4 prev: 2 current: 3
Object: 8 prev: 6 current: 7
Object: 0 prev: 7 current: 6
Object: 0 prev: 8 current: 7
Object: 0 prev: 8 current: 7
Object: 9 prev: 0 current: 1
Object: 6 prev: 7 current: 6
Object: 3 prev: 7 current: 6
Object: 5 prev: 3 current: 4
Object: 2 prev: 2 current: 2
Object: 0 prev: 7 current: 6
Object: 7 prev: 7 current: 7
Object: 7 prev: 4 current: 5
Object: 7 prev: 3 current: 4
Object: 3 prev: 5 current: 4
Object: 0 prev: 0 current: 0
Object: 8 prev: 4 current: 5
...

Object: 2 prev: 0 current: 1
Object: 4 prev: 3 current: 4
Object: 8 prev: 0 current: 1
Object: 6 prev: 9 current: 8
Object: 2 prev: 7 current: 6
Object: 2 prev: 2 current: 2
Object: 3 prev: 7 current: 6
Object: 9 prev: 7 current: 8
Object: 8 prev: 9 current: 8
Object: 7 prev: 3 current: 4
Object: 9 prev: 7 current: 8
Object: 5 prev: 3 current: 4
Object: 2 prev: 0 current: 1
Object: 3 prev: 3 current: 3
Object: 3 prev: 4 current: 3
Object: 0 prev: 1 current: 0
Object: 8 prev: 0 current: 1
Object: 7 prev: 4 current: 5
Object: 9 prev: 1 current: 2
Object: 9 prev: 3 current: 4
Object: 5 prev: 8 current: 7
Object: 5 prev: 8 current: 7
Object: 2 prev: 2 current: 2
Object: 1 prev: 5 current: 4
Object: 8 prev: 7 current: 8
Object: 1 prev: 1 current: 1
Object: 5 prev: 3 current: 4
Object: 3 prev: 4 current: 3
Object: 8 prev: 7 current: 8
Object: 6 prev: 5 current: 6
Object: 0 prev: 4 current: 3
Object: 9 prev: 3 current: 4
Object: 7 prev: 4 current: 5
Object: 7 prev: 3 current: 4
...

Object: 7 prev: 7 current: 7
Object: 4 prev: 9 current: 8
Object: 3 prev: 4 current: 3
Object: 5 prev: 7 current: 6
Object: 4 prev: 3 current: 4
Object: 6 prev: 4 current: 5
Object: 4 prev: 7 current: 6
Object: 2 prev: 7 current: 6
Object: 0 prev: 6 current: 5
Object: 5 prev: 8 current: 7
Object: 9 prev: 0 current: 1
Object: 0 prev: 6 current: 5
Object: 0 prev: 3 current: 2
Object: 5 prev: 8 current: 7
Object: 5 prev: 7 current: 6
Object: 2 prev: 4 current: 3
Object: 3 prev: 3 current: 3
Object: 6 prev: 4 current: 5
Object: 1 prev: 1 current: 1
Object: 1 prev: 5 current: 4
Object: 0 prev: 2 current: 1
Object: 8 prev: 5 current: 6
Object: 0 prev: 1 current: 0
Object: 6 prev: 5 current: 6
Object: 6 prev: 7 current: 6
Object: 8 prev: 5 current: 6
Object: 3 prev: 0 current: 1
Object: 6 prev: 0 current: 1
Object: 9 prev: 3 current: 4
Object: 4 prev: 6 current: 5
Object: 2 prev: 2 current: 2
Object: 1 prev: 2 current: 1
Object: 7 prev: 9 current: 8
Object: 6 prev: 5 current: 6
...

### For different numbers of episodes:

    Number of episodes        Test accuracy
    1000                      0.5231
    10000                     0.8694
    100000                    1.0
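Since there are only 10 x 10 distinct (obj, prev) pairs, the greedy policy could also be checked exhaustively rather than by random sampling. The following is a sketch under that assumption, not the evaluation used for the table above; `greedy_accuracy` and `Q_ideal` are hypothetical names, and `Q_ideal` is a hand-built table that simply marks the correct action for every pair.

```python
import numpy as np

def greedy_accuracy(Q):
    """Fraction of (obj, prev) pairs for which the greedy action earns reward +1."""
    positions = Q.shape[0]
    correct = 0
    for obj in range(positions):
        for prev in range(positions):
            a = int(np.argmax(Q[obj, prev]))
            # left : 0, none : 1, right : 2, clipped to the array as in the testing cell
            current = min(max(prev + (a - 1), 0), positions - 1)
            if abs(obj - current) < abs(obj - prev) or current == obj:
                correct += 1
    return correct / float(positions * positions)

# Hypothetical table that prefers the correct action in every state
positions = 10
Q_ideal = np.zeros((positions, positions, 3))
for obj in range(positions):
    for prev in range(positions):
        Q_ideal[obj, prev, int(np.sign(obj - prev)) + 1] = 1.0

print(greedy_accuracy(Q_ideal))  # → 1.0
```

Passing the trained `Q` from the notebook to `greedy_accuracy` would give a deterministic figure to compare against the sampled test accuracy.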

In [6]:
Q[9, 9]

array([0.56321528, 5.90987594, 0.        ])

In [7]:
Q[0, 0]

array([0.        , 5.83771556, 0.83353492])

In [8]:
Q[9, 8]

array([0.99935002, 0.80560819, 5.89234116])

In [9]:
Q[7, 7]

array([0.92798844, 6.00686393, 0.86799978])

In [10]:
Q[9, 1]

array([0.95560026, 0.81832114, 6.04864212])

In [11]:
Q[9, 0]

array([0.        , 1.15292055, 6.11604542])
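
To read these rows as a policy, one can take the argmax of each and map it to an action name. This is a small sketch using the Q rows printed above; `ACTION_NAMES`, `rows` and `greedy` are names introduced for this example only.

```python
import numpy as np

ACTION_NAMES = ['left', 'none', 'right']  # index order used throughout the notebook

# Q rows copied from the outputs above, keyed by (obj, prev)
rows = {
    (9, 9): [0.56321528, 5.90987594, 0.0],
    (0, 0): [0.0, 5.83771556, 0.83353492],
    (9, 8): [0.99935002, 0.80560819, 5.89234116],
    (7, 7): [0.92798844, 6.00686393, 0.86799978],
    (9, 1): [0.95560026, 0.81832114, 6.04864212],
    (9, 0): [0.0, 1.15292055, 6.11604542],
}

# Greedy action for each inspected state
greedy = {state: ACTION_NAMES[int(np.argmax(q))] for state, q in rows.items()}
for (obj, prev) in sorted(greedy):
    print('obj=%d prev=%d -> %s' % (obj, prev, greedy[(obj, prev)]))
```

In these rows the bbox stays put when it already sits on the object and moves right when the object lies to its right. The entries that are exactly 0 belong to actions that would move the bbox off the array (left at position 0, right at position 9), which the training loop sets to 0.0.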

In [12]:
Q

array([[[0.        , 5.83771556, 0.83353492],
        [5.86951989, 0.89068243, 0.69349434],
        [5.86413787, 0.66535395, 0.86433256],
        [5.94804113, 0.94343857, 0.92839362],
        [5.96345033, 0.84303414, 1.12582671],
        [5.99895087, 1.03074813, 0.97765665],
        [6.04140319, 0.86383278, 1.25891773],
        [5.98742582, 0.91566108, 0.74771503],
        [6.00924063, 0.83768286, 0.67866838],
        [6.05098589, 0.8750526 , 0.        ]],

       [[0.        , 1.09995136, 6.09383778],
        [0.63591705, 6.05349698, 1.02860754],
        [6.04855869, 0.77490893, 0.81629614],
        [5.97746438, 0.93946965, 0.40991328],
        [5.97625359, 0.76923561, 0.83537916],
        [5.99235517, 0.98881785, 0.618754  ],
        [5.97671432, 0.55850853, 0.65736172],
        [5.90361361, 0.51524644, 0.52242773],
        [5.89183175, 0.60796245, 0.47139128],
        [5.86686248, 0.67685667, 0.        ]],

       [[0.        , 0.72261864, 6.26668563],
        [0.86540909, 1.1229690