# Pointer networks basic implementation

## Tasks
Pick "convex hull"
* [x] Generate the dataset
* [x] Evaluation metric
* [x] Implement the model
* [ ] Reproduce the results from the paper

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import collections
import matplotlib.pyplot as plt
import numpy as np
import os
import torch

%matplotlib inline

In [None]:
def plot_points_and_hull(points, hull_indices, c='r'):
    """Scatter a point set and draw its hull as a closed polyline on the current axes.

    Args:
        points: 2-D array read as ``points[:, 0]`` (x) and ``points[:, 1]`` (y).
            Rows whose first column equals -1 are treated as padding and dropped.
        hull_indices: Row indices into ``points`` in drawing order. May be empty,
            in which case only the points are drawn.
        c: Format string / colour forwarded to ``plt.plot`` for the hull line.
    """
    hull_indices = np.asarray(hull_indices)

    if hull_indices.size:
        # Repeat the first index at the end so the polyline is closed.
        closed_indices = np.hstack([hull_indices, [hull_indices[0]]])
        points_hull = points[closed_indices]
        points_hull = points_hull[points_hull[:, 0] != -1]
    else:
        # Nothing to connect: an empty selection avoids indexing hull_indices[0].
        points_hull = points[:0]

    points = points[points[:, 0] != -1]

    # Note: for a non-empty hull the count includes the repeated closing point.
    print('{} points, {} in the hull'.format(points.shape[0], points_hull.shape[0]))
    plt.scatter(points[:, 0], points[:, 1])
    plt.plot(points_hull[:, 0], points_hull[:, 1], c)

### Define parameters

In [None]:
# Immutable bag of run settings. Field order is part of the interface
# (positional construction), so it is kept exactly as before.
Params = collections.namedtuple(
    'Params',
    'gpu_device '
    'batch_size embedding_size hiddens n_lstms dropout bidir '
    'lr n_epochs',
)

In [None]:
params = Params(
    # CUDA device index; the device-selection cell tests it with `>= 0`.
    gpu_device=2,

    # Data loading
    batch_size=256,

    # Model hyper-parameters
    # FIXME: NOT USED RIGHT NOW!
    embedding_size=128,
    hiddens=512,
    n_lstms=2,
    dropout=0,
    bidir=False,

    # Optimisation
    lr=1e-4,
    n_epochs=50,
)

In [None]:
# Use the GPU only when a non-negative device index was requested and CUDA is available.
USE_CUDA = params.gpu_device >= 0 and torch.cuda.is_available()
# Device string: 'cuda:<index>' when using the GPU, otherwise 'cpu'.
DEVICE = 'cuda:{}'.format(params.gpu_device) if USE_CUDA else 'cpu'

In [None]:
# Display the device string that was selected.
DEVICE

## Load the data

In [None]:
from datasets import ConvexHullDataset, collate_fn

In [None]:
###### For convex hull
# The data was generated using convex_hull_generator.py
# Alternative dataset: 'data/convex_hull.npz'
#
# The archive is opened in a `with` block so its file handle is closed as soon as
# the splits have been read; `data` must not be indexed after this point.
# NOTE(review): if 'arr_0' was saved as an object array, recent NumPy versions
# need `allow_pickle=True` here -- confirm against convex_hull_generator.py.
with np.load('data/convex_hull_5.npz') as data:
    # 'arr_0' unpacks into exactly three entries: train, validation, test.
    data_train, data_val, data_test = data['arr_0']

data_train = np.array(data_train)
data_val = np.array(data_val)
data_test = np.array(data_test)

In [None]:
# Display the shape of the validation split.
data_val.shape

In [None]:
# Display the shape of the test split.
data_test.shape

In [None]:
# Wrap each split in a ConvexHullDataset, all with append_eol=True.
dataset_train, dataset_val, dataset_test = (
    ConvexHullDataset(split, append_eol=True)
    for split in (data_train, data_val, data_test)
)

### Define Dataloader

In [None]:
# Training batches are reshuffled every epoch.
dataloader_train = DataLoader(
    dataset_train,
    batch_size=params.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)
# Validation is NOT shuffled: the training loop only averages the first 10
# validation batches, so a fixed order keeps successive 'val.loss' points
# (and the plotted examples) comparable with each other.
dataloader_val = DataLoader(
    dataset_val,
    batch_size=params.batch_size,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn,
)
dataloader_test = DataLoader(
    dataset_test,
    batch_size=params.batch_size,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)

### Visualize the data

In [None]:
# Plot the first sample of five consecutive training batches with its target hull.
d = iter(dataloader_train)
for ix in range(5):
    batch = next(d)

    # Only the first `pointer_lens[0]` target pointers of the sample are kept.
    n_hull = batch['pointer_lens'][0]
    points = batch['sequence'][0].data.numpy()
    inds_hull = batch['pointers'][0].data.numpy().ravel()[:n_hull]

    plt.figure()
    plot_points_and_hull(points, inds_hull)
    plt.show()

## Define the model

In [None]:
from pointer_net import PointerNet, Encoder, Decoder

In [None]:
# Build the network with its default constructor arguments.
model = PointerNet()

# Sanity check: push a single validation batch through the untrained model.
batch = next(iter(dataloader_val))
seq, seq_lens = batch['sequence'], batch['sequence_lens']
target_pointers, pointer_lens = batch['pointers'], batch['pointer_lens']

# Ask for as many output steps as the target tensor has along dimension 1.
pointers = model(seq, seq_lens, max_output_len=target_pointers.shape[1])

# Display the model output summed over its last dimension.
pointers.sum(dim=-1)

In [None]:
# Place the model on the selected device. DEVICE is 'cuda:<gpu_device>' exactly
# when USE_CUDA is true and 'cpu' otherwise, so this is a no-op on CPU-only runs.
model.to(DEVICE)
#     cudnn.benchmark = True

## Define the optimizer / loss

In [None]:
# Targets equal to ignore_index (-100) do not contribute to the loss.
# NOTE(review): an earlier comment here said padded elements are -1; confirm that
# the pad value produced by collate_fn really matches ignore_index.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100).to(DEVICE)

# Optimise only the parameters that require gradients.
model_optim = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=params.lr,
)

In [None]:
# Display the configured loss module.
loss_func

## Logging

In [None]:
# Run identifier, used as the sub-folder name for both checkpoints and logs.
model_str = 'ptr-convex-hull-batched-5-eol-1.00'

weights_folder = "/opt/weights/{}".format(model_str)
log_folder = '../tensorboard-logs/{}'.format(model_str)

# Create the checkpoint folder FIRST: this is the guard against re-using a run
# name, and it must fire before anything is written for this run. Constructing
# the SummaryWriter before it could already add a new event file to an existing
# run's log folder.
os.makedirs(weights_folder)  # MEANT TO FAIL IF IT ALREADY EXISTS

# logging
writer = SummaryWriter(log_folder)  # writing log to tensorboard
# NOTE(review): this prints the weights folder, not log_folder -- confirm intended.
print('logging to: {}'.format(weights_folder))

## Train

In [None]:
# Checkpoint / validation cadence. The training loop tests both against the
# per-epoch batch index (`i_batch % ... == 0`), not the global iteration count,
# so batch 0 of every epoch always triggers a save and a validation pass.
save_every = 10000
val_every = 1000

In [None]:
def predict_and_eval(model, batch, loss_func):
    """Run the model on one collated batch and compute the pointer loss.

    Args:
        model: Module called as ``model(seq, seq_lens, max_output_len=...)``. Its
            output is read as ``(batch, output step, class)``.
        batch: Mapping with the keys ``'sequence'``, ``'sequence_lens'``,
            ``'pointers'`` and ``'pointer_lens'``.
        loss_func: Callable invoked as ``loss_func(scores, targets)`` with scores
            flattened to ``(-1, n_classes)`` and targets flattened to ``(-1,)``.

    Returns:
        Tuple ``(pointers, loss)``: the raw model output and the loss value.
    """
    seq = batch['sequence']
    target_pointers = batch['pointers']
    seq_lens, target_pointer_lens = batch['sequence_lens'], batch['pointer_lens']

    # Follow the model's own device instead of relying on notebook globals, so
    # inputs and weights can never end up on different devices.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        seq = seq.to(first_param.device)
        target_pointers = target_pointers.to(first_param.device)

    # Generate exactly as many output steps as the longest target in the batch.
    # int() turns a 0-dim tensor (when the lengths arrive as a tensor) into a plain int.
    n_outputs = int(max(target_pointer_lens))
    pointers = model(seq, seq_lens, max_output_len=n_outputs)
    assert n_outputs == pointers.shape[1]

    # Flatten (batch, step) so every output step is scored as one classification.
    n_classes = pointers.shape[-1]
    loss = loss_func(pointers.contiguous().view(-1, n_classes), target_pointers.contiguous().view(-1))
    return pointers, loss

In [None]:
# Epoch counter read and incremented by the training loop.
# NOTE(review): presumably kept in its own cell so the training cell can be
# re-run without resetting the counter -- confirm.
epoch = 0

In [None]:
# Main training loop: one optimiser step per training batch, with periodic
# checkpoints and a short validation pass (loss + a 2x2 grid of example plots).
# NOTE: the epoch limit is hard-coded and overrides params.n_epochs.
while epoch < 5000:  # params.n_epochs:
    for i_batch, train_batch in enumerate(dataloader_train):
        iter_cntr = epoch * len(dataloader_train) + i_batch  # The overall iteration number across epochs

        # Forward
        pointers, train_loss = predict_and_eval(model, train_batch, loss_func)

        # Backprop
        model_optim.zero_grad()
        train_loss.backward()
        model_optim.step()

        writer.add_scalar('train.loss', train_loss.item(), iter_cntr)

        # Save (keyed on the per-epoch batch index, so batch 0 of every epoch saves)
        if i_batch % save_every == 0:
            torch.save(model.state_dict(), os.path.join(weights_folder, '{}_{}.pt'.format(epoch, i_batch)))

        # Validation
        if i_batch % val_every == 0:
            plt.figure(figsize=(5, 5))

            total_val_loss = 0
            n_val_batches = 0

            # Evaluate in inference mode and without building autograd graphs;
            # the previous mode is restored even if validation is interrupted.
            was_training = model.training
            model.eval()
            try:
                with torch.no_grad():
                    for jx, val_batch in enumerate(dataloader_val):
                        # Only the first 10 validation batches are scored.
                        if jx == 10:
                            break
                        pointers, val_loss = predict_and_eval(model, val_batch, loss_func)
                        total_val_loss += val_loss.item()
                        n_val_batches += 1

                        # plot few: first sample of the first four validation batches
                        if jx < 4:
                            plt.subplot(2, 2, jx+1)
                            pred_indices = pointers.argmax(dim=-1).data.cpu().numpy()

                            target_indices = val_batch['pointers'][0].data.cpu().numpy()
                            assert len(target_indices) == pred_indices.shape[1]
                            print('Targets: {}, Preds: {}'.format(target_indices.flatten(), pred_indices[0].flatten()))
                            seq_lens = val_batch['sequence_lens']
                            pointer_lens = val_batch['pointer_lens']
                            points = val_batch['sequence'][0].data.cpu().numpy()[: seq_lens[0]]
                            # Prediction in solid blue, target in dashed red.
                            plot_points_and_hull(points, pred_indices[0].flatten()[: pointer_lens[0]], c='b')
                            plot_points_and_hull(points, target_indices[: pointer_lens[0]].flatten(), c='r--')
            finally:
                model.train(was_training)

            plt.show()

            # Average over the batches actually evaluated (may be fewer than 10).
            if n_val_batches:
                writer.add_scalar('val.loss', total_val_loss / n_val_batches, iter_cntr)
    epoch += 1