# Training Q Model
## Load Files

In [14]:
from FileUtils import FileUtils
import matplotlib.pyplot as plt 
data = FileUtils()

Start to load data
Start to load images: 1000=>2000=>3000=>4000=>5000=>6000=>7000=>8000=>9000=>10000=>11000=>11540!
Start to load annotations: 1000=>2000=>3000=>4000=>5000=>6000=>7000=>8000=>9000=>10000=>11000=>11540!
End


In [15]:
data.filter_by_class("bird")
total_num = len(data.images) # the length of the cat images
total_num

773

## Model Loading

In [16]:
from Models import QModel
from Settings import Settings
import torch
import torch.optim as optim

In [17]:
device = torch.device("cuda" if Settings.cuda else "cpu")

In [None]:
policy_model = QModel()
target_model = QModel()
target_model.load_state_dict(policy_model.state_dict())

if Settings.cuda:
    policy_model.cuda()
    target_model.cuda()
target_model = target_model.eval()

In [None]:
from ReplayMemory import ReplayMemory
memory = ReplayMemory(1000)

In [None]:
optimizer = optim.Adam(policy_model.parameters(), lr=1e-6)

## Training Loop

In [None]:
import time
from Agent import Agent
from IoU import *
from TrainUtils import optimize_model
import json

In [None]:
epoch_num = 100
train_num = int(0.8 * total_num)
update_num = 3
eps = Settings.eps_start

In [None]:
Settings.iou_threshold, Settings.gamma

(0.6, 0.25)

In [None]:
print("Start Model Training")
loss_list = list()
min_loss = float("inf")


for epoch in range(epoch_num):
#     print("==> Epoch {} start ...".format(epoch))
    start = time.time()
    cur_loss = 0.0
    policy_model.train()
    
    for i, image in enumerate(data.images[:train_num]):
        annotation_list = data.annotations[i]
        agent = Agent(image)
        done = False
        old_iou_list = None
        
        for step in range(Settings.max_step):
            iou_list = [iou_calculator(agent.boundary, x) for x in annotation_list]
            max_index = max(range(len(iou_list)), key=lambda x: iou_list[x])
            iou = max(iou_list)

            old_iou = old_iou_list[max_index] if old_iou_list else 0
            old_iou_list = iou_list
            cur_state = agent.get_state()
            
            # determine if we should end the result 
            if iou > Settings.iou_threshold:
                action = torch.tensor(6).to(device)
            else:
                action = agent.get_next_action(policy_model, eps)

            if action == 6:
                reward = reward_terminal(iou)
                agent.update_history_vector(action)
                done = True
                next_state = None
                
            else:
                
                agent.hierarchical_move(action)
                agent.update_history_vector(action)
                
                if agent.sub_image.shape[0] * agent.sub_image.shape[1] == 0:
                    done = True
                    next_state = None
                
                else:
                    next_state = agent.get_state()
                    reward = reward_move(old_iou, iou)
            
            memory.push(cur_state, action, next_state, reward)
            cur_loss = optimize_model(policy_model, target_model, memory, optimizer)
            
            if done:
                break
    
    min_loss = min(cur_loss, min_loss)
    if min_loss == cur_loss:
        print("Save min loss network")
        with open("{}model_{}.pt".format(Settings.model_path, epoch), 'wb') as f:
            torch.save(policy_model.cpu().state_dict(), f)
            if Settings.cuda:
                policy_model.cuda()
                
    if epoch % update_num == 0:
        print("Update Network")
        target_model.load_state_dict(policy_model.state_dict())
        target_model.eval()
        loss_list.append(cur_loss)
        
        # save logs
        with open("{}log_1.json".format(Settings.model_path), "w") as f:
            json.dump(loss_list, f)

    if eps > 0.11:
        eps -= 0.1
    else:
        eps = 0.1
    
    time_cost = time.time() - start
    print("==> Epoch {} End, time cost = {}, current loss = {}, next eps = {}".format(epoch, round(time_cost,4), cur_loss, round(eps,2)))


Start Model Training
Save min loss network
Update Network
==> Epoch 0 End, time cost = 54.2607, current loss = 2.6791, next eps = 0.8
Save min loss network
==> Epoch 1 End, time cost = 58.2235, current loss = 1.5933, next eps = 0.7
Save min loss network
==> Epoch 2 End, time cost = 64.2051, current loss = 0.8984, next eps = 0.6
Update Network
==> Epoch 3 End, time cost = 69.2652, current loss = 1.3867, next eps = 0.5
==> Epoch 4 End, time cost = 75.1861, current loss = 0.9113, next eps = 0.4
Update Network
==> Epoch 6 End, time cost = 81.1185, current loss = 1.1357, next eps = 0.2
Save min loss network
==> Epoch 7 End, time cost = 89.653, current loss = 0.8641, next eps = 0.1
==> Epoch 8 End, time cost = 92.8349, current loss = 1.2574, next eps = 0.1
Save min loss network
Update Network
==> Epoch 9 End, time cost = 96.5364, current loss = 0.778, next eps = 0.1
Save min loss network
==> Epoch 10 End, time cost = 93.1028, current loss = 0.751, next eps = 0.1
Save min loss network
==> Epo

# Test Model

In [None]:
def testModelOnOneImage(image_index):
    

In [None]:
plt.imshow(data.images[])

In [None]:
from Image import Image
image = Image(data,4)
annotation_list = image.objects
agent = Agent(image.image)
done = False
old_iou_list = None

for step in range(Settings.max_step):
    print("Step {}".format(step), end="=>")
    

    iou_list = [iou_calculator(agent.boundary, x) for x in annotation_list]
    max_index = max(range(len(iou_list)), key=lambda x: iou_list[x])
    iou = max(iou_list)
    
    print("current iou = {}".format(iou), end=" || ")
    old_iou = old_iou_list[max_index] if old_iou_list else 0
    old_iou_list = iou_list

    cur_state = agent.get_state()

    # determine if we should end the result 
    if iou > Settings.iou_threshold:
        action = torch.tensor(6).to(device)
    else:
        action = agent.get_next_action(target_model, eps=0)

    if action == 6:
        reward = reward_terminal(iou)
        done = True
        next_state = None
    else:
        agent.hierarchical_move(action)
        agent.update_history_vector(action)
        image.draw_one_box(agent.boundary)
        image.add_text(step, (agent.boundary["xmin"], agent.boundary["ymin"]))
        if agent.sub_image.shape[0] * agent.sub_image.shape[1] == 0:
            done = True
            next_state = None
        else:
            next_state = agent.get_state()
            reward = reward_move(old_iou, iou)
            
    print("current action = {}".format(int(action)), end=" || ")
    print("current reward = {}".format(float(reward)))
    if done:
        break
image.show()    

In [None]:

import json

In [None]:
plt.plot(loss_list)