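# HANet Training

Train HANet on color + depth inputs from the HANet_v2 dataset using `BCEWithLogitsLoss` and Adam (with a StepLR schedule), saving a checkpoint to `<DATA_PATH>/weight` every `SAVE_EVERY` epochs.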

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import os
from torch.utils.data import  DataLoader
from handover_grasping.model import HANet
from handover_grasping.Datavisualizer import handover_grasping_dataset

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## DataLoader and parameters

In [2]:
# Training hyper-parameters (everything a reader is likely to tune lives here).
BATCH_SIZE = 5
EPOCH = 50
SAVE_EVERY = 25   # checkpoint interval, in epochs
NUM_WORKERS = 8   # DataLoader worker processes
SEED = 0

# Dataset root. Set the HANET_DATA_PATH environment variable to point elsewhere
# instead of editing the notebook; the default is the original local path.
DATA_PATH = os.environ.get('HANET_DATA_PATH', '/home/arg/handover_grasping/data/HANet_v2_datasets')

# Seed torch's global RNG so torch-driven randomness that follows (e.g. the
# DataLoader's shuffle order) is repeatable across runs.
torch.manual_seed(SEED)

dataset = handover_grasping_dataset(DATA_PATH, color_type='png')
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

## Initialize HANet

In [3]:
# Build the network and move its parameters/buffers to the default CUDA device
# (Module.cuda() returns the module itself, so the call can be chained).
# NOTE(review): the meaning of the positional `4` is defined in
# handover_grasping.model.HANet — confirm before changing it.
net = HANet(4).cuda()

## Loss function and optimizer

In [7]:
# BCEWithLogitsLoss applies the sigmoid internally, so the network's raw
# outputs (logits) are passed to it directly.
criterion = nn.BCEWithLogitsLoss().cuda()

# Optimiser / schedule settings, named instead of inlined as magic numbers.
LR = 1e-3
LR_STEP_EPOCHS = 25   # StepLR counts scheduler.step() calls; the training loop calls it once per epoch
LR_GAMMA = 0.1        # multiply the learning rate by this factor every LR_STEP_EPOCHS

optimizer = optim.Adam(net.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_EPOCHS, gamma=LR_GAMMA)

## Start Training

In [8]:
WEIGHT_DIR = os.path.join(DATA_PATH, 'weight')
# makedirs(exist_ok=True) is idempotent on re-run and also creates any missing
# parent directories, unlike the isdir()/mkdir() pair.
os.makedirs(WEIGHT_DIR, exist_ok=True)


def _train_one_epoch(model, loader, loss_fn, opt):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    Each batch must be a mapping with 'color', 'depth' and 'label' tensors
    (the keys read below). Inputs are moved to the device holding the model's
    parameters, so this works unchanged for a CUDA or CPU model.

    Raises:
        ValueError: if `loader` yields no batches (would otherwise divide by zero).
    """
    model.train()  # explicit, in case an earlier cell switched the model to eval()
    device = next(model.parameters()).device
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError('dataloader yielded no batches; check DATA_PATH and BATCH_SIZE')

    running_loss = 0.0
    for i_batch, batch in enumerate(loader):
        print("\r[{:03.2f} %]".format(i_batch / float(n_batches) * 100.0), end="\r")
        opt.zero_grad()
        color = batch['color'].to(device)
        depth = batch['depth'].to(device)
        # permute(0, 2, 3, 1) moves axis 1 to the end.
        # NOTE(review): presumably converts channel-first labels to the layout
        # HANet's output uses — confirm against HANet.forward.
        label = batch['label'].permute(0, 2, 3, 1).to(device).float()
        predict = model(color, depth)

        loss = loss_fn(predict, label)
        loss.backward()
        opt.step()
        # .item() returns a Python float and drops the graph reference.
        running_loss += loss.item()
    return running_loss / n_batches


loss_l = []  # mean training loss per epoch
for epoch in range(EPOCH):
    ts = time.time()
    loss_l.append(_train_one_epoch(net, dataloader, criterion, optimizer))
    # StepLR is epoch-based here: step it once per epoch, after the optimiser steps.
    scheduler.step()
    if (epoch + 1) % SAVE_EVERY == 0:
        ckpt_name = 'grapnet_{}_{}.pth'.format(epoch + 1, round(loss_l[-1], 3))
        torch.save(net.state_dict(), os.path.join(WEIGHT_DIR, ckpt_name))

    print("Epoch: {}| Loss: {}| Time elapsed: {}".format(epoch + 1, round(loss_l[-1], 5), time.time() - ts))

[0.00 %]

  "See the documentation of nn.Upsample for details.".format(mode))


Epoch: 1| Loss: 0.03018| Time elasped: 37.97710728645325
Epoch: 2| Loss: 0.02137| Time elasped: 37.9552845954895
Epoch: 3| Loss: 0.02017| Time elasped: 37.87660026550293
Epoch: 4| Loss: 0.01789| Time elasped: 37.86599397659302
Epoch: 5| Loss: 0.01653| Time elasped: 37.86790728569031
Epoch: 6| Loss: 0.01474| Time elasped: 37.90761137008667
Epoch: 7| Loss: 0.01337| Time elasped: 37.99703884124756
Epoch: 8| Loss: 0.01195| Time elasped: 38.35384917259216
Epoch: 9| Loss: 0.01108| Time elasped: 38.92362713813782
Epoch: 10| Loss: 0.00967| Time elasped: 38.407872915267944
Epoch: 11| Loss: 0.00878| Time elasped: 39.7547173500061
Epoch: 12| Loss: 0.00783| Time elasped: 39.26367402076721
Epoch: 13| Loss: 0.00698| Time elasped: 38.07651877403259
Epoch: 14| Loss: 0.00686| Time elasped: 37.97936129570007
Epoch: 15| Loss: 0.00547| Time elasped: 38.27909016609192
Epoch: 16| Loss: 0.00525| Time elasped: 38.156272888183594
Epoch: 17| Loss: 0.00497| Time elasped: 38.61170315742493
Epoch: 18| Loss: 0.0050

In [5]:
# Display the architecture of the model actually being trained (built as
# HANet(4) above) instead of constructing and discarding a second network
# just for its repr.
net

HANet(
  (net): FCN_model(
    (color_trunk): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=