In [None]:
import gc
import torch
import argparse
import collections
import numpy as np
import data_loader.data_loaders as module_data
from parse_config import ConfigParser
import time
import shutil
from tqdm import *
import os
import torch.nn as nn
import matplotlib.pyplot as plt


In [None]:
# Fix random seeds for reproducibility (torch + numpy) and force deterministic cuDNN kernels.
SEED = 123
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

c = "config/config.json"
# Opening the file fails fast when the config is missing. The handle itself is
# never read, so it is closed right away instead of being leaked for the
# lifetime of the kernel.
with open(os.path.join(os.path.curdir, c)) as p:
    pass


args_dataset = argparse.ArgumentParser(description='PyTorch Template')
args_dataset.add_argument('-c', '--config', default=c, type=str,
                    help='config file path (default: None)')
args_dataset.add_argument('-r', '--resume', default=None, type=str,
                    help='path to latest checkpoint (default: None)')
args_dataset.add_argument('-d', '--device', default=None, type=str,
                    help='indices of GPUs to enable (default: all)')
args_dataset.add_argument('--limited_memory', default=False, action='store_true',
                    help='prevent "too many open files" error by setting pytorch multiprocessing to "file_system".')

# The Jupyter kernel is launched with its own command-line flags; they are
# registered here as throw-away options so that argparse does not reject them.
for dummy_short, dummy_long in [
    ("-i", "--ip"),
    ("-s", "--stdin"),
    ("-control", "--control"),
    ("-b", "--hb"),
    ("-K", "--Session.key"),
    ("-S", "--Session.signature_scheme"),
    ("-l", "--shell"),
    ("-t", "--transport"),
    ("-o", "--iopub"),
    ("-f", "--ffff"),
]:
    args_dataset.add_argument(
        dummy_short, dummy_long, help="a dummy argument to fool ipython", default="1")

# custom cli options to modify configuration from default values given in json file.
# NOTE(review): if ConfigParser.from_args forwards `type=bool` to argparse, any
# non-empty string (including "False") parses as True -- confirm how it is used.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
    CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
    CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size'),
    CustomArgs(['--rmb', '--reset_monitor_best'], type=bool, target='trainer;reset_monitor_best'),
    CustomArgs(['--vo', '--valid_only'], type=bool, target='trainer;valid_only')
]
config_dataset = ConfigParser.from_args(args_dataset, options)
if args_dataset.parse_args().limited_memory:
    # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
    import torch.multiprocessing
    torch.multiprocessing.set_sharing_strategy('file_system')

# Values read from the data-loader section of the JSON config.
batch_size = config_dataset.config['data_loader']['args']['batch_size']
num_bins = config_dataset.config['data_loader']['args']['sequence_kwargs']['dataset_kwargs']['num_bins']
epochs = 50

# Number of lines in the data file listed in the config.
num_files = 0
with open(config_dataset.config['data_loader']['args']['data_file']) as f:
    num_files = sum(1 for _ in f)

seq_size = config_dataset.config['data_loader']['args']['sequence_kwargs']['sequence_length']


logger = config_dataset.get_logger('train')

# setup data_loader instances
train_loader = config_dataset.init_obj('data_loader', module_data)
val_loader = config_dataset.init_obj('valid_data_loader', module_data)

In [None]:
# Command-line style hyper-parameters for training. Defaults for epochs,
# batch size and num_bins come from the values loaded from the JSON config.
parser = argparse.ArgumentParser(description='Training DCFNet in Pytorch 0.4.0')
parser.add_argument('--input_sz', dest='input_sz', default=128, type=int, help='crop input size')
parser.add_argument('--padding', dest='padding', default=2.0, type=float, help='crop padding size')
parser.add_argument('--range', dest='range', default=10, type=int, help='select range')
parser.add_argument('--epochs', default=epochs, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
# Help texts use argparse's %(default)s so the documented default can never
# drift from the real one (the previous texts quoted 8, 32 and 5e-5).
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('-b', '--batch-size', default=batch_size, type=int,
                    metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-6, type=float,
                    metavar='W', help='weight decay (default: %(default)s)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--save', '-s', default='./work', type=str, help='directory for saving')
# num_bins is a channel count: parse it as int so a value given on the command
# line has the same type as the config-provided default.
parser.add_argument('--num_bins', '-nb', default=num_bins, type=int, help='number of bins')

# Flags passed by the Jupyter kernel launcher; accepted and ignored.
# "--stdin" and "--hb" have no short form here because -s / -b are already
# taken by --save / --batch-size.
for dummy_flags in [
    ("-i", "--ip"),
    ("--stdin",),
    ("-control", "--control"),
    ("--hb",),
    ("-K", "--Session.key"),
    ("-l", "--shell"),
    ("-t", "--transport"),
    ("-o", "--iopub"),
    ("-f", "--ffff"),
    ("-S", "--Session.signature_scheme"),
]:
    parser.add_argument(
        *dummy_flags, help="a dummy argument to fool ipython", default="1")
args = parser.parse_args()

print(args)
# Best validation loss seen so far; updated by the training loop.
best_loss = 1e6

In [None]:

def complex_mul(x, z):
    """Multiply two complex tensors stored as (..., 2) real/imaginary pairs.

    Index 0 of the last dimension is read as the real part and index 1 as the
    imaginary part. Returns the product in the same (..., 2) layout.
    """
    x_re, x_im = x[..., 0], x[..., 1]
    z_re, z_im = z[..., 0], z[..., 1]
    real_part = x_re * z_re - x_im * z_im
    imag_part = x_re * z_im + x_im * z_re
    return torch.stack((real_part, imag_part), -1)


def complex_mulconj(x, z):
    """Compute x * conj(z) for complex tensors, returned as (..., 2) real pairs.

    Both inputs are native complex tensors; they are converted with
    ``torch.view_as_real`` and the result keeps that real/imaginary layout.
    """
    x_pair = torch.view_as_real(x)
    z_pair = torch.view_as_real(z)
    x_re, x_im = x_pair[..., 0], x_pair[..., 1]
    z_re, z_im = z_pair[..., 0], z_pair[..., 1]
    real_part = x_re * z_re + x_im * z_im
    imag_part = x_im * z_re - x_re * z_im
    return torch.stack((real_part, imag_part), -1)


class DCFNetFeature(nn.Module):
    """Shallow feature extractor: conv3x3 -> ReLU -> conv3x3 -> local response norm.

    Both convolutions are unpadded, so each spatial dimension shrinks by 4.
    The output has 32 channels.
    """

    def __init__(self,in_channel):
        super().__init__()
        layers = [
            nn.Conv2d(in_channel, 32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1),
        ]
        self.feature = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the sequential feature stack to ``x``."""
        return self.feature(x)


class EventEnc(nn.Module):
    """Encoder for the event input: two padded 3x3 convolutions and a 2x2 max-pool.

    Maps ``in_channel`` channels to 64 channels and halves the spatial size.
    """

    def __init__(self,in_channel):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run conv1 -> ReLU -> conv2 -> max-pool."""
        hidden = self.relu(self.conv1(x))
        return self.pool(self.conv2(hidden))


class GreyEnc(nn.Module):
    """Encoder for the frame input: two padded 3x3 convolutions and a 2x2 max-pool.

    Same layer layout as ``EventEnc``, but with its own weights.
    Maps ``in_channel`` channels to 64 channels and halves the spatial size.
    """

    def __init__(self, in_channel):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run conv1 -> ReLU -> conv2 -> max-pool."""
        hidden = self.relu(self.conv1(x))
        return self.pool(self.conv2(hidden))


class Decoder(nn.Module):
    """Decoder: two transposed convolutions followed by local response norm.

    Takes a 64-channel encoder output and produces a 32-channel map.
    """

    def __init__(self):
        super().__init__()
        self.convT1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.convT2 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2, padding=1)
        self.lclres = nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1)

    def forward(self, x):
        """Run convT1 -> ReLU -> convT2 -> local response norm."""
        upsampled = self.convT2(self.relu(self.convT1(x)))
        return self.lclres(upsampled)


class DCFNet(nn.Module):
    """Correlation-filter network with separate event and frame branches.

    ``num_bins`` sizes the event-branch layers and ``in_channel`` the
    frame-branch layers. When ``config`` is given, its ``yf`` tensor is cloned
    and its ``lambda0`` value stored; ``forward`` needs ``lambda0``.
    """

    def __init__(self,in_channel,num_bins, config=None):
        super().__init__()
        self.enc_e = EventEnc(num_bins)
        self.enc_f = GreyEnc(in_channel)
        self.dec = Decoder()
        self.event_feat = DCFNetFeature(num_bins)
        self.frame_feat = DCFNetFeature(in_channel)

        if config is not None:
            self.yf = config.yf.clone()
            self.lambda0 = config.lambda0

    def forward(self, z, x, label):
        """Return the spatial correlation response of ``x`` against ``z``.

        ``z`` and ``x`` are feature maps (callers pass template and search
        features). ``label`` is a spectrum in real/imaginary pair layout, as
        produced by ``torch.view_as_real``.
        """
        zf = torch.fft.rfft2(z)
        xf = torch.fft.rfft2(x)

        # Energy of z: |zf|^2 summed over the real/imag axis (4) and channel axis (1).
        z_energy = torch.view_as_real(zf) ** 2
        kzzf = z_energy.sum(dim=4, keepdim=True).sum(dim=1, keepdim=True)
        # Cross-correlation spectrum, summed over channels.
        kxzf = complex_mulconj(xf, zf).sum(dim=1, keepdim=True)
        # Regularised division by the energy (lambda0 avoids division by zero).
        alphaf = label.to(device=z.device) / (kzzf + self.lambda0)
        response_f = torch.view_as_complex(complex_mul(kxzf, alphaf))
        return torch.fft.irfft2(response_f)

def gaussian_shaped_labels(sigma, sz):
    """Build a 2-D Gaussian label map whose peak is rolled to index (0, 0).

    Args:
        sigma: standard deviation of the Gaussian.
        sz: two-element sequence; the result has shape ``(sz[1], sz[0])``.

    Returns:
        ``np.float32`` array holding the circularly shifted Gaussian.
    """
    half_0 = np.floor(float(sz[0]) / 2)
    half_1 = np.floor(float(sz[1]) / 2)
    # Coordinates centred on the map.
    coords_0 = np.arange(1, sz[0] + 1) - half_0
    coords_1 = np.arange(1, sz[1] + 1) - half_1
    grid_x, grid_y = np.meshgrid(coords_0, coords_1)
    dist_sq = grid_x ** 2 + grid_y ** 2
    labels = np.exp(-0.5 / (sigma ** 2) * dist_sq)
    # Circularly shift the peak from the centre to the origin.
    shifts = (int(-half_0 + 1), int(-half_1 + 1))
    labels = np.roll(labels, shifts, axis=(0, 1))
    return labels.astype(np.float32)

def output_drop(output, target):
    """Overwrite the worst-fitting samples of ``output`` with ``target`` (in place).

    The per-sample squared error is summed over all non-batch dimensions. The
    10% of samples with the largest error (rounded with ``round``) are
    replaced by their target, so they contribute no gradient. ``output`` is
    modified in place and returned.
    """
    squared_err = (output - target)**2
    batch_sz = squared_err.shape[0]
    per_sample_err = squared_err.view(batch_sz, -1).sum(dim=1)
    _, index = torch.sort(per_sample_err, descending=True)
    # unreliable samples (10% of the total) do not produce grad (we simply copy the groundtruth label)
    num_dropped = int(round(0.1*batch_sz))
    for worst in index[:num_dropped]:
        output[worst, ...] = target[worst, ...]
    return output

class TrackerConfig(object):
    """Static tracker hyper-parameters and the Gaussian regression label.

    All attributes are evaluated once, when the class body runs; building
    ``yf`` moves a tensor to CUDA, so a GPU is required at definition time.
    """
    crop_sz = 128
    output_sz = 124
    lambda0 = 1e-4
    padding = 2.0
    output_sigma_factor = 0.1
    output_sigma = crop_sz / (1 + padding) * output_sigma_factor
    # Spatial label: Gaussian of size output_sz x output_sz.
    y = gaussian_shaped_labels(output_sigma, [output_sz, output_sz])
    # Label spectrum stored as real/imaginary pairs.
    yf = torch.view_as_real(
        torch.fft.rfft2(torch.Tensor(y).view(1, 1, output_sz, output_sz).cuda())
    )


config = TrackerConfig()

# DCFNet.__init__ requires both `in_channel` (frame branch) and `num_bins`
# (event branch); the previous call omitted `num_bins` and raised TypeError.
# NOTE(review): `in_channel=num_bins` is kept as originally written -- confirm
# that the frame tensors really carry `num_bins` channels.
model = DCFNet(in_channel=num_bins, num_bins=num_bins, config=config)
model.cuda()
gpu_num = torch.cuda.device_count()
print('GPU NUM: {:2d}'.format(gpu_num))
if gpu_num > 1:
    model = torch.nn.DataParallel(model, list(range(gpu_num))).cuda()

# Summed squared error. `reduction='sum'` is the supported spelling of the
# deprecated `size_average=False`.
criterion = nn.MSELoss(reduction='sum').cuda()

optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

# Spatial Gaussian label repeated for every sample of a (multi-GPU) batch.
target = torch.Tensor(config.y).cuda().unsqueeze(0).unsqueeze(0).repeat(args.batch_size * gpu_num, 1, 1, 1)  # for training
print(model)
# Index -> module lookup of every (sub)module, attached to args.
dict_net = {str(i): module for i, module in enumerate(model.modules())}

args.model = dict_net

from datetime import datetime
def adjust_learning_rate(optimizer, epoch):
    """Set every parameter group's learning rate for ``epoch``.

    The schedule is log-spaced from 1e-2 down to 1e-5 over ``args.epochs``
    points; ``epoch`` indexes into it.
    """
    schedule = np.logspace(-2, -5, num=args.epochs)
    new_lr = schedule[epoch]
    for group in optimizer.param_groups:
        group['lr'] = new_lr


class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, weighted sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
# Per-run output directory: <save>/<timestamp>_<input_sz>_<padding>_<num_bins>
dt_string = datetime.now().strftime("%d%m%Y_%H%M%S")
run_name = f'{dt_string}_{args.input_sz:d}_{args.padding:1.1f}_{num_bins:d}'
save_path = os.path.join(args.save, run_name)
os.makedirs(save_path, exist_ok=True)
args.model_path = save_path

def save_checkpoint(state, is_best, model_name, filename=os.path.join(save_path, 'checkpoint.pth.tar')):
    """Serialise ``state`` to ``filename`` and keep a copy of the best model.

    ``model_name`` is accepted but not used. The default ``filename`` is
    computed once, when this function is defined, from the current
    ``save_path``. When ``is_best`` is true the checkpoint is also copied to
    ``model_best.pth.tar`` in ``save_path``.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_copy = os.path.join(save_path, 'model_best.pth.tar')
    shutil.copyfile(filename, best_copy)


In [None]:
def get_features(x, is_frame):
    """Extract features for one input with the module-level ``model``.

    The decoded encoder output is added to the output of the shallow feature
    extractor of the same branch.

    Args:
        x: input tensor for the selected branch.
        is_frame: truthy selects the frame branch (``enc_f`` / ``frame_feat``),
            falsy selects the event branch (``enc_e`` / ``event_feat``).

    Returns:
        ``dec(encoder(x)) + shallow_features(x)``.
    """
    # When several GPUs are present the model is wrapped in nn.DataParallel,
    # which exposes the real network (and its sub-modules) only via `.module`.
    net = model.module if isinstance(model, nn.DataParallel) else model
    if is_frame:
        encoder, shallow = net.enc_f, net.frame_feat
    else:
        encoder, shallow = net.enc_e, net.event_feat
    x_enc = encoder(x)
    x_old = shallow(x)
    return net.dec(x_enc) + x_old

def get_response(model,template_feat,search1_feat,label,initial_y,args):
    """Track once without gradients and build the pseudo label at the response peak.

    Uses the module-level ``gpu_num`` and ``config`` for the batch size and
    the label size.

    Args:
        model: network called as ``model(template_feat, search1_feat, label)``.
        template_feat: features of the template.
        search1_feat: features of the search region.
        label: label spectrum (real/imaginary pair layout) used for tracking.
        initial_y: 2-D NumPy label that gets shifted to the peak position.
        args: namespace providing ``batch_size``.

    Returns:
        CUDA tensor holding ``rfft2`` of the shifted labels, viewed as
        real/imaginary pairs, one per sample.
    """
    num_samples = args.batch_size * gpu_num
    out_sz = config.output_sz

    with torch.no_grad():
        s1_response = model(template_feat, search1_feat, label)
    # label transform: locate the peak of every response map
    peak, index = torch.max(s1_response.view(num_samples, -1), 1)
    r_max, c_max = np.unravel_index(index.cpu(), [out_sz, out_sz])
    fake_y = np.zeros((num_samples, 1, out_sz, out_sz))
    # label shift: move the label by r_max rows and c_max columns.
    # np.roll without `axis` rolls the *flattened* array, which moved the label
    # by r_max + c_max elements instead of a 2-D shift, so the axes are explicit.
    for j in range(num_samples):
        fake_y[j, ...] = np.roll(initial_y, (int(r_max[j]), int(c_max[j])), axis=(0, 1))
    fake_yf = torch.fft.rfft2(torch.Tensor(fake_y).view(num_samples, 1, out_sz, out_sz).cuda())
    fake_yf = torch.view_as_real(fake_yf)

    return fake_yf.cuda(non_blocking=True)

def train(train_loader, model, criterion, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` with the tracking consistency loss.

    Each batch is tracked forward twice (template -> search1 -> search2) without
    gradients to obtain a pseudo label, then tracked backward (search2 ->
    template) with gradients; the backward response is compared against the
    initial Gaussian label ``target``.

    Relies on module-level ``config``, ``args``, ``gpu_num``, ``target``,
    ``get_features``, ``get_response`` and ``output_drop``.

    Args:
        train_loader: iterable yielding indexable items; elements 0, 1 and 2 are
            used as template, search1 and search2, each read via the keys
            ``"events"`` and ``"frame"``.
        model: network called as ``model(z, x, label)``.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per processed batch.
        epoch: epoch index, used only in the progress print-out.

    Returns:
        ``False`` as soon as a batch produces a NaN loss, ``True`` otherwise.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()
    # Label spectrum repeated for every sample of the batch.
    label = config.yf.repeat(args.batch_size*gpu_num,1,1,1,1).cuda(non_blocking=True)
    initial_y = config.y.copy()

    end = time.time()

    for i, x in enumerate(train_loader):

        template = x[0]
        search1, search2 = x[1],x[2]

        # Skip batches whose size does not match the pre-built label
        # (presumably the last, incomplete batch -- TODO confirm).
        if(template["events"].shape[0] != label.shape[0]):
            continue
        # measure data loading time
        data_time.update(time.time() - end)

        # Features per branch, then fused by addition. Frames are permuted with
        # (0,3,1,2) -- presumably channels-last to channels-first; TODO confirm.
        template_e = get_features(template["events"].cuda(non_blocking=True),False)
        template_f = get_features(template["frame"].cuda(non_blocking=True).permute((0,3,1,2)),True)
        template_feat = template_e + template_f

        search1_e = get_features(search1["events"].cuda(non_blocking=True),False)
        search1_f = get_features(search1["frame"].cuda(non_blocking=True).permute((0,3,1,2)),True)
        search1_feat = search1_e + search1_f

        search2_e = get_features(search2["events"].cuda(non_blocking=True),False)
        search2_f = get_features(search2["frame"].cuda(non_blocking=True).permute((0,3,1,2)),True)
        search2_feat = search2_e + search2_f

        # forward tracking 1 (get_response runs the model under torch.no_grad)
        fake_yf = get_response(model, template_feat, search1_feat, label, initial_y, args)

        # forward tracking 2, starting from the pseudo label of step 1
        fake_yf = get_response(model, search1_feat, search2_feat, fake_yf, initial_y, args)

        # backward tracking on the fused features
        output = model(search2_feat, template_feat, fake_yf)
        output = output_drop(output, target)  # the sample dropout is necessary, otherwise we find the loss tends to become unstable

        # backward tracking on the event-only features
        output_e = model(search2_e, template_e, fake_yf)
        output_e = output_drop(output_e, target)

        # backward tracking on the frame-only features
        output_f = model(search2_f, template_f, fake_yf)
        output_f = output_drop(output_f, target)

        # consistency loss. target is the initial Gaussian label
        # Weighted sum (0.5 fused, 0.2 events, 0.3 frames) divided by the batch size.
        loss = (0.5*criterion(output, target) +0.2*criterion(output_e, target) + 0.3*criterion(output_f, target))/template_feat.size(0)

        # measure accuracy and record loss
        losses.update(loss.item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # Abort the epoch on a NaN loss (note: the optimizer step above has already run).
        if(torch.isnan(loss)):
            return False

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))

    return True

def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` with the same consistency loss as training.

    Runs entirely under ``torch.no_grad()`` with the model in eval mode. Relies
    on module-level ``config``, ``args``, ``gpu_num``, ``target``,
    ``get_features``, ``get_response`` and ``output_drop``.

    Args:
        val_loader: iterable yielding indexable items; elements 0, 1 and 2 are
            used as template, search1 and search2, each read via the keys
            ``"events"`` and ``"frame"``.
        model: network called as ``model(z, x, label)``.
        criterion: loss called as ``criterion(output, target)``.

    Returns:
        Average loss over the processed batches (``losses.avg``).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()

    model.eval()
    initial_y = config.y.copy()
    # Label spectrum repeated for every sample of the batch.
    label = config.yf.repeat(args.batch_size*gpu_num,1,1,1,1).cuda(non_blocking=True)

    with torch.no_grad():
        end = time.time()
        for i, x in enumerate(val_loader):

            # compute output
            template = x[0]
            search1, search2 = x[1],x[2]

            # Skip batches whose size does not match the pre-built label.
            if(template["events"].shape[0] != label.shape[0]):
                continue
            # Features per branch, then fused by addition.
            template_e = get_features(template["events"].cuda(non_blocking=True),False)
            template_f = get_features(template["frame"].cuda(non_blocking=True).permute((0,3,1,2)),True)
            template_feat = template_e + template_f

            search1_e = get_features(search1["events"].cuda(non_blocking=True),False)
            search1_f = get_features(search1["frame"].cuda(non_blocking=True).permute((0,3,1,2)),True)
            search1_feat = search1_e + search1_f

            search2_e = get_features(search2["events"].cuda(non_blocking=True),False)
            search2_f = get_features(search2["frame"].cuda(non_blocking=True).permute((0,3,1,2)),True)
            search2_feat = search2_e + search2_f

            # forward tracking 1
            fake_yf = get_response(model, template_feat, search1_feat, label, initial_y, args)

            # forward tracking 2, starting from the pseudo label of step 1
            fake_yf = get_response(model, search1_feat, search2_feat, fake_yf, initial_y, args)

            # backward tracking on the fused features
            output = model(search2_feat, template_feat, fake_yf)
            output = output_drop(output, target)  # the sample dropout is necessary, otherwise we find the loss tends to become unstable

            # backward tracking on the event-only features
            output_e = model(search2_e, template_e, fake_yf)
            output_e = output_drop(output_e, target)

            # backward tracking on the frame-only features
            output_f = model(search2_f, template_f, fake_yf)
            output_f = output_drop(output_f, target)

            # consistency loss. target is the initial Gaussian label
            # Weighted sum (0.5 fused, 0.2 events, 0.3 frames) divided by the batch size.
            loss = (0.5*criterion(output, target) +0.2*criterion(output_e, target) + 0.3*criterion(output_f, target))/(args.batch_size * gpu_num)

            # measure accuracy and record loss
            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses))

        print(' * Loss {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses))

    return losses.avg


In [None]:

# Main training loop: one train + validate pass per epoch, checkpointing after
# every epoch and keeping a copy of the best model by validation loss.
try:
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch; train() returns False when the loss became NaN
        worked = train(train_loader, model, criterion, optimizer, epoch)
        if(not worked):
            print("Broken")
            break
        # evaluate on validation set
        loss = validate(val_loader, model, criterion)
        if(epoch %5 == 0):
            print(loss)
        # remember best loss and save checkpoint
        is_best = loss < best_loss
        best_loss = min(best_loss, loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best,"baseModel")

except Exception as e:
    # Keep the notebook alive on failure, but print the full traceback as well
    # as the message so the failing line can be located.
    import traceback
    print(e)
    traceback.print_exc()