In [1]:
import datetime
import os
import sys
import logging

import torch
import torch.utils.data
import torch.optim as optim
from torchsummary import summary
import tensorboardX

from utility import train_utility

from utility.data import get_custom_dataset
from utility.peggnet_model import PEGG_NET
import utility.io_processing as iop
# Visible confirmation in the cell output that every import above succeeded.
print("imported everything")
# Configure the root logger at INFO level; the cells below report their progress via logging.info.
logging.basicConfig(level=logging.INFO)

imported everything


In [2]:
def validate(net: PEGG_NET, device, val_data: torch.utils.data.DataLoader, use_sparse_loss=False):
    """
    Evaluate the network on a validation loader.
    :param net: Network (switched to eval mode here and left in eval mode)
    :param device: Torch device the inputs and targets are moved to
    :param val_data: Validation DataLoader; every batch adds exactly one entry to the
                     correct/failed counters, so the counts equal the number of batches
    :param use_sparse_loss: If True use net.compute_loss_sparse (with score_min_mask=0),
                            otherwise net.compute_loss
    :return: dict with 'correct' and 'failed' counts, the mean 'loss' over all batches and
             the per-component means under 'losses'
    """
    net.eval()

    n_batches = len(val_data)
    n_correct = 0
    n_failed = 0
    mean_loss = 0
    mean_components = {}

    with torch.no_grad():
        for batch in val_data:
            # The rotation / zoom entries of the batch are not used below.
            x, y, didx, _rot, _zoom_factor, score, angle_at_fault = batch

            inputs = x.to(device)
            targets = [target.to(device) for target in y]

            if use_sparse_loss:
                out = net.compute_loss_sparse(inputs, targets, angle_at_fault=angle_at_fault,
                                              score=score, score_min_mask=0)
            else:
                out = net.compute_loss(inputs, targets)

            # Dividing each term by the batch count accumulates the mean directly.
            mean_loss += out['loss'].item()/n_batches
            for name, value in out['losses'].items():
                mean_components[name] = mean_components.get(name, 0) + value.item()/n_batches

            pred = out['pred']
            q_out, ang_out, w_out = iop.process_raw_output(pred['pos'], pred['cos'],
                                                           pred['sin'], pred['width'])

            # Ground truth is requested with fixed 0 and 1.0 arguments instead of the batch's
            # rot / zoom_factor values (presumably "no rotation, no zoom" - TODO confirm
            # against the dataset's get_gtbb signature).
            ground_truth = val_data.dataset.get_gtbb(didx, 0, 1.0)
            matched = train_utility.calculate_iou_match(q_out, ang_out,
                                                        ground_truth,
                                                        no_grasps=1,
                                                        grasp_width=w_out,
                                                        )

            if matched:
                n_correct += 1
            else:
                n_failed += 1

    return {
        'correct': n_correct,
        'failed': n_failed,
        'loss': mean_loss,
        'losses': mean_components,
    }

def train(epoch: int, net: "PEGG_NET", device, train_data: torch.utils.data.DataLoader, optimizer: optim.Adam, batches_per_epoch, vis=False):
    """
    Run one training epoch
    :param epoch: Current epoch (only used in the progress log lines)
    :param net: Network
    :param device: Torch device
    :param train_data: Training DataLoader; it is re-iterated as often as needed to reach
                       batches_per_epoch batches
    :param optimizer: Optimizer
    :param batches_per_epoch:  Exact number of data batches to train on
    :param vis:  Currently unused; kept so existing callers keep working
    :return:  Average Losses for Epoch (all zero if batches_per_epoch is 0)
    :raises ValueError: if train_data yields no batches at all
    """
    results = {
        'loss': 0,
        'losses': {
        }
    }

    net.train()

    batch_idx = 0  # number of batches trained on so far
    # Use batches per epoch to make training on different sized datasets (cornell/jacquard) more equivalent.
    while batch_idx < batches_per_epoch:
        batches_before_pass = batch_idx
        for x, y, _, _, _, score, angle_at_fault in train_data:
            xc = x.to(device)
            yc = [yy.to(device) for yy in y]
            lossd = net.compute_loss_sparse(xc, yc, angle_at_fault=angle_at_fault, score=score)

            loss = lossd['loss']

            # Count the batch only once it is actually processed, so that batch_idx is
            # both the 1-based batch number in the log and the divisor for the averages.
            batch_idx += 1

            if batch_idx % 100 == 0:
                logging.info('Epoch: {}, Batch: {}, Loss: {:0.4f}'.format(epoch, batch_idx, loss.item()))

            results['loss'] += loss.item()
            for ln, l in lossd['losses'].items():
                if ln not in results['losses']:
                    results['losses'][ln] = 0
                results['losses'][ln] += l.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Stop only after the final batch has been trained on.
            if batch_idx >= batches_per_epoch:
                break

        if batch_idx == batches_before_pass:
            # A full pass produced nothing; without this check the while loop never ends.
            raise ValueError('train_data yielded no batches')

    if batch_idx > 0:
        results['loss'] /= batch_idx
        for l in results['losses']:
            results['losses'][l] /= batch_idx

    return results

In [3]:
# ---- Starting weights -------------------------------------------------------
# `model` (path to a fully pickled network) wins over `state_dict` (path to saved
# weights) when both are set; with both None the network starts untrained.
model = None
state_dict = "saved_models/baseline_peggnet/epoch_17_iou_0.91_statedict.pt"
network = "peggnet"  # not referenced by the visible cells

# ---- Data -------------------------------------------------------------------
dataset = "custom"  # name handed to get_custom_dataset
dataset_train = "training/training"
dataset_eval_sparse = "training/validation_sparse"
dataset_eval_annotated = "training/validation_annotated"
input_size = 480  # passed to the datasets as output_size and used for the model summary
max_width = 150
use_depth = True  # contributes 1 input channel
use_rgb = True  # contributes 3 input channels

# Forwarded unchanged to the dataset constructors.
split = 1.0  # `end` bound of the training dataset
ds_rotate = 0.0
image_wise = False
random_seed = 10
num_workers = 8

# Single switch for every training-time augmentation (rotate, zoom, symmetry,
# brightness, contrast); the validation datasets never augment.
augment = True

# ---- Optimisation -----------------------------------------------------------
lr = 0.001
lr_step = [10, 20, 30, 40]  # epochs at which lr is multiplied by lr_step_coeff
lr_step_coeff = 0.8
batch_size = 8
epochs = 50
batches_per_epoch = 1000

# ---- Outputs ----------------------------------------------------------------
force_save_every = 1  # checkpoint every N epochs even without an IoU improvement
description = ""  # appended to the timestamp in the run folder name
outdir = "output/models"
logdir = "tensorboard/"
vis = False

In [4]:
# !!! USING CV2 VIS IN A JUPYTER NOTEBOOK MAKES THE KERNEL CRASH
# if vis:
#     cv2.namedWindow('Display', cv2.WINDOW_NORMAL)


# Set-up output directories
# Each run gets its own folder named "<yymmdd_HHMM>_<description words joined by '_'>".
dt = datetime.datetime.now().strftime('%y%m%d_%H%M')
net_desc = '{}_{}'.format(dt, '_'.join(description.split()))
save_folder = os.path.join(outdir, net_desc)
# exist_ok avoids the check-then-create race and makes re-running the cell within the same minute safe.
os.makedirs(save_folder, exist_ok=True)
# TensorBoard events go to a folder with the same run name under logdir.
tb = tensorboardX.SummaryWriter(os.path.join(logdir, net_desc))
# Load Dataset
logging.info('Loading {} Dataset...'.format(dataset.title()))
Dataset = get_custom_dataset(dataset)


def _build_dataset(file_path, end, augmented):
    """
    Build one dataset from the shared notebook configuration.
    :param file_path: Dataset location handed to the Dataset constructor
    :param end: Value for the constructor's `end` argument (`start` is always 0.0)
    :param augmented: Single flag applied to all five random_* augmentation options
    :return: The constructed dataset
    """
    return Dataset(file_path=file_path,
                   output_size=input_size,
                   start=0.0,
                   end=end,
                   ds_rotate=ds_rotate,
                   image_wise=image_wise,
                   random_seed=random_seed,
                   random_rotate=augmented,
                   random_zoom=augmented,
                   random_symmetry=augmented,
                   random_brightness=augmented,
                   random_contrast=augmented,
                   include_depth=use_depth,
                   include_rgb=use_rgb,
                   max_width=max_width)


def _build_eval_loader(eval_dataset):
    """
    Wrap a validation dataset in an unshuffled loader with one sample per batch.
    validate() records exactly one success/failure per batch, so batch_size stays 1.
    """
    return torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers
    )


train_dataset = _build_dataset(dataset_train, end=split, augmented=augment)
train_data = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers
)

# Validation data is never augmented and always uses end=1.0.
val_dataset_sparse = _build_dataset(dataset_eval_sparse, end=1.0, augmented=False)
val_data_sparse = _build_eval_loader(val_dataset_sparse)

val_dataset_annotated = _build_dataset(dataset_eval_annotated, end=1.0, augmented=False)
val_data_annotated = _build_eval_loader(val_dataset_annotated)

logging.info('Done')
# Image counts are taken from the datasets themselves for all three lines
# (with batch_size=1 the validation loaders have the same length).
logging.info('Number of training images: {}'.format(len(train_dataset)))
logging.info('Number of sparse validation images: {}'.format(len(val_dataset_sparse)))
logging.info('Number of annotated validation images: {}'.format(len(val_dataset_annotated)))
logging.info('Data augmentation (for training only): {}'.format(augment))

# Load the network
logging.info('Loading Network...')
# One channel for depth plus three for RGB, depending on which inputs are enabled.
input_channels = int(use_depth) + 3 * int(use_rgb)
logging.info("Number of input channels: {}".format(input_channels))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info('Using device: {}'.format(device))


if model is not None:
    # `model` is a fully pickled network object. Recent PyTorch releases default to
    # weights_only=True, which refuses such files, so the full unpickler is requested
    # explicitly.
    # SECURITY: unpickling can execute arbitrary code - only load checkpoints you trust.
    net = torch.load(model, map_location=device, weights_only=False)
else:
    net = PEGG_NET(input_channels=input_channels)
    if state_dict is not None:
        # Plain tensors only, so the restricted (safe) loader is sufficient.
        net.load_state_dict(torch.load(state_dict, map_location=device, weights_only=True))
net = net.to(device)
optimizer = optim.Adam(net.parameters(), lr=lr)
logging.info('Network Loaded')

# Print model architecture.
summary(net, (input_channels, input_size, input_size))

# summary() only prints, so write a second copy to arch.txt by pointing stdout at the file
# for the duration of the call.
# The previous stream is saved and restored explicitly: inside a notebook sys.stdout is the
# kernel's own stream, not sys.__stdout__, so resetting to sys.__stdout__ would send every
# later print() away from the notebook output.
previous_stdout = sys.stdout
with open(os.path.join(save_folder, 'arch.txt'), 'w') as f:
    sys.stdout = f
    try:
        summary(net, (input_channels, input_size, input_size))
    finally:
        # Restore stdout even if summary() raises; the with-block then closes the file.
        sys.stdout = previous_stdout

INFO:root:Loading Custom Dataset...
INFO:root:Done
INFO:root:Number of training images: 265
INFO:root:Number of sparse validation images: 20
INFO:root:Number of annotated validation images: 40
INFO:root:Data augmentation (for training only): True
INFO:root:Loading Network...
INFO:root:Number of input channels: 4
INFO:root:Using device: cuda
  net.load_state_dict(torch.load(state_dict, map_location=device))
INFO:root:Network Loaded


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 480, 480]           1,152
       BatchNorm2d-2         [-1, 32, 480, 480]              64
              Mish-3         [-1, 32, 480, 480]               0
Conv_Bn_Activation-4         [-1, 32, 480, 480]               0
            Conv2d-5         [-1, 32, 480, 480]           1,024
       BatchNorm2d-6         [-1, 32, 480, 480]              64
              Mish-7         [-1, 32, 480, 480]               0
Conv_Bn_Activation-8         [-1, 32, 480, 480]               0
            Conv2d-9         [-1, 32, 480, 480]           9,216
      BatchNorm2d-10         [-1, 32, 480, 480]              64
             Mish-11         [-1, 32, 480, 480]               0
Conv_Bn_Activation-12         [-1, 32, 480, 480]               0
         ResBlock-13         [-1, 32, 480, 480]               0
           Conv2d-14         [-1, 64, 

In [None]:
def _validate_and_log(tag, loader, use_sparse_loss, epoch):
    """
    Validate the current network on one validation set and log the outcome.
    :param tag: Set name ('sparse' or 'annotated') used in the log line and TensorBoard tags
    :param loader: Validation DataLoader handed to validate()
    :param use_sparse_loss: Forwarded to validate() to select the loss function
    :param epoch: Current epoch, used as the TensorBoard step
    :return: (success ratio, results dict returned by validate())
    """
    logging.info('Validating on {} validation set...'.format(tag))
    val_results = validate(net, device, loader, use_sparse_loss=use_sparse_loss)

    correct = val_results['correct']
    total = correct + val_results['failed']
    # An empty validation set scores 0.0 instead of dividing by zero.
    iou = correct / total if total else 0.0
    logging.info('%d/%d = %f' % (correct, total, iou))

    # Log validation results to tensorboard
    tb.add_scalar('loss/IOU_' + tag, iou, epoch)
    tb.add_scalar('loss/val_loss_' + tag, val_results['loss'], epoch)
    for n, l in val_results['losses'].items():
        tb.add_scalar('val_loss_' + tag + '/' + n, l, epoch)

    return iou, val_results


best_iou_sparse = 0.0
best_iou_annotated = 0.0
for epoch in range(1, epochs + 1):
    logging.info('Beginning Epoch {:02d}'.format(epoch))

    if epoch in lr_step:
        # NOTE: this updates the notebook-level `lr`, so re-running this cell without
        # re-running the configuration cell continues from the already decayed value.
        lr = lr * lr_step_coeff
        logging.info('Drop LR to {}'.format(lr))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    train_results = train(
        epoch, net, device, train_data, optimizer, batches_per_epoch, vis=vis
    )

    # Log training losses to tensorboard
    tb.add_scalar('loss/train_loss', train_results['loss'], epoch)
    for n, l in train_results['losses'].items():
        tb.add_scalar('train_loss/' + n, l, epoch)

    # Run Validation on both sets; test_results ends up holding the annotated results.
    iou_sparse, test_results = _validate_and_log('sparse', val_data_sparse, True, epoch)
    iou_annotated, test_results = _validate_and_log('annotated', val_data_annotated, False, epoch)

    # Save when either IoU improved, or every `force_save_every` epochs
    # (force_save_every = 0 disables the forced saves).
    improved = iou_sparse > best_iou_sparse or iou_annotated > best_iou_annotated
    forced = bool(force_save_every) and (epoch % force_save_every) == 0
    if improved or forced:
        checkpoint_stem = 'epoch_%02d_iou_S_%0.2f_iou_A_%0.2f' % (epoch, iou_sparse, iou_annotated)
        # Full pickled network plus a weights-only state dict next to it.
        torch.save(net, os.path.join(save_folder, checkpoint_stem))
        torch.save(net.state_dict(), os.path.join(save_folder, checkpoint_stem + '_statedict.pt'))
        best_iou_sparse = max(iou_sparse, best_iou_sparse)
        best_iou_annotated = max(iou_annotated, best_iou_annotated)

INFO:root:Beginning Epoch 01
INFO:root:Epoch: 1, Batch: 100, Loss: 0.0975
INFO:root:Epoch: 1, Batch: 200, Loss: 0.0731
INFO:root:Epoch: 1, Batch: 300, Loss: 0.0591
INFO:root:Epoch: 1, Batch: 400, Loss: 0.0336
INFO:root:Epoch: 1, Batch: 500, Loss: 0.0334
INFO:root:Epoch: 1, Batch: 600, Loss: 0.1120
INFO:root:Epoch: 1, Batch: 700, Loss: 0.1081
INFO:root:Epoch: 1, Batch: 800, Loss: 0.0345
INFO:root:Epoch: 1, Batch: 900, Loss: 0.0606
INFO:root:Validating on sparse validation set...
INFO:root:10/20 = 0.500000
INFO:root:Validating on annotated validation set...
INFO:root:32/40 = 0.800000
INFO:root:Beginning Epoch 02
INFO:root:Epoch: 2, Batch: 100, Loss: 0.0313
INFO:root:Epoch: 2, Batch: 200, Loss: 0.1174
INFO:root:Epoch: 2, Batch: 300, Loss: 0.0358
INFO:root:Epoch: 2, Batch: 400, Loss: 0.0191
INFO:root:Epoch: 2, Batch: 500, Loss: 0.1035
INFO:root:Epoch: 2, Batch: 600, Loss: 0.0577
INFO:root:Epoch: 2, Batch: 700, Loss: 0.0804
INFO:root:Epoch: 2, Batch: 800, Loss: 0.0331
INFO:root:Epoch: 2, Ba