# Pointer networks for words->X

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import collections
import matplotlib.pyplot as plt
import numpy as np
import os
import torch

%matplotlib inline

from Dataset.dataset_docschema import Dataset
from docreader.evaluation.metrics.bbox_evaluation import calculate_iou
from docschema.semantic import Word, Paragraph, TextLine, Section, Document, Field
from typing import Union

In [None]:
from doc_data import Preprocessor, collate_fn

### Define parameters

In [None]:
# Immutable record of every experiment setting. The order of the names below
# is the positional order of the resulting tuple.
Params = collections.namedtuple(
    'Params',
    (
        # hardware
        'gpu_device',
        # data / model architecture
        'batch_size',
        'embedding_size',
        'hiddens',
        'n_lstms',
        'dropout',
        'bidir',
        # optimisation schedule
        'lr',
        'n_epochs',
        'val_every',
        'save_every',
        # task
        'target_container',
    ),
)

In [None]:
# Concrete settings for this run. Everything is passed by keyword, so the
# grouping below is purely for readability.
params = Params(
    # Hardware: CUDA device index; a negative value disables CUDA (see USE_CUDA)
    gpu_device=2,

    # Data
    batch_size=8,

    # Model specific: container type handed to the Preprocessor
    target_container=TextLine,

    # Model params  # FIXME: NOT USED RIGHT NOW!
    embedding_size=128,
    hiddens=512,
    n_lstms=2,
    dropout=0,
    bidir=False,

    # Training params
    lr=1e-4,
    n_epochs=50,
    val_every=100,
    save_every=1000,
)

In [None]:
# Device selection: a negative gpu_device forces CPU; otherwise CUDA is used
# only when torch reports it as available.
DEVICE = params.gpu_device
USE_CUDA = DEVICE >= 0 and torch.cuda.is_available()

In [None]:
# Shared sample adapter; it is passed as `adapter=` to both datasets below.
preprocessor = Preprocessor(
    params.target_container,
    crop_h=200,
    crop_w=500,
    random_shuffle=True,
    only_midpoints=True,
)

In [None]:
# NOTE: specify training data
synth_list_files = ['/opt/data/field-train-acord-20190214/synth.list']
# dataset = Dataset(synth_list_files, adapter=lambda x: x)
dataset_train = Dataset(synth_list_files, adapter=preprocessor)
dataloader_train = DataLoader(dataset_train, batch_size=params.batch_size, shuffle=True, num_workers=8, collate_fn=collate_fn)
print('Training: {:,} total images {:,} mini batches'.format(len(dataset_train), len(dataloader_train)))

In [None]:
# NOTE: specify validation data
# NOTE(review): this is the same list file the training cell uses, so the
# validation loss is measured on training data -- point it at a held-out list.
synth_list_files = ['/opt/data/field-train-acord-20190214/synth.list']

# Debug alternative with an identity adapter:
#     Dataset(synth_list_files, adapter=lambda x: x)
dataset_val = Dataset(synth_list_files, adapter=preprocessor)
dataloader_val = DataLoader(
    dataset_val,
    batch_size=params.batch_size,
    shuffle=True,
    num_workers=8,
    collate_fn=collate_fn,
)
print('Validation: {:,} total images {:,} mini batches'.format(len(dataset_val), len(dataloader_val)))

In [None]:
# TEST: collate two real samples followed by one explicitly empty sample.
# Widen the range to include more real samples.
batch = [dataset_train[sample_ix] for sample_ix in range(2)]
batch.append({
    'bboxes': np.array([[]]),
    'pointers': np.array([]),
    'image': np.array([[]]),
    'is_empty': True,
})

c = collate_fn(batch)

### Visualize the data

In [None]:
from doc_visualize import plot_points_and_lines

In [None]:
# Plot the first sample of each of the first 20 training mini-batches.
for ix, batch in enumerate(dataloader_train):
    if ix >= 20:
        break

    # A batch may be None; there is nothing to draw for it.
    if batch is not None:
        image, bboxes, pointers = (
            batch[key][0].data.cpu().numpy().squeeze()
            for key in ('images', 'sequence', 'pointers')
        )
        scale = batch['scales'][0].data.cpu().numpy()

        plt.figure()
        plot_points_and_lines(bboxes, pointers, image=image, scale=scale)
        print(pointers)
        plt.show()

### Input/Output definition

#### Inputs
0. Choose a sub-section of the page!
1. Take all the words (points): just b-boxes for now, we'll add in the words later
2. Order matters - let's always sort them top-to-bottom & left-to-right. Discretize the coordinates with some basic thresholding.

#### How should the output be structured?
1. sequence of pointers (duh!)
2. Different containers (text-lines) must be separated by `<EOC>`
3. The entire sequence should end with an `<EOS>`
4. The pointers within each group must be sorted in the order: top-bottom, left-right

## Define the model

In [None]:
from pointer_net import PointerNet

In [None]:
model = PointerNet(n_in=2)

# SANITY RUN THE MODEL on one validation mini-batch. This happens before the
# model is moved to the GPU, so the batch tensors are used as loaded.
batch = next(iter(dataloader_val))
sequence = batch['sequence']
seq_lens = batch['sequence_lens']

pointers = model(sequence, seq_lens, max_output_len=10)
# Compare input and output shapes (the input tensor is `sequence`).
print(sequence.shape)
print(pointers.shape)
# Cell output: sums over the last axis (presumably ~1 per step if the model
# emits normalized distributions -- confirm against PointerNet).
pointers.sum(dim=-1)

In [None]:
# Move the model to the configured GPU only when CUDA is usable.
# USE_CUDA is a bool, so test its truth value: a `>= 0` comparison is true
# even for False and would call .cuda() on a CPU-only machine.
if USE_CUDA:
    model.cuda(device=params.gpu_device)

## Define the optimizer / loss

In [None]:
# Adam over the parameters that require gradients; the rest are left out.
model_optim = torch.optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=params.lr,
)

## Logging

In [None]:
# Run name; it becomes the sub-folder of both the weights and the tensorboard-log directories.
model_str = 'test-1'

# logging
weights_folder = "/opt/weights/{}".format(model_str)
log_folder =  '../tensorboard-logs/{}'.format(model_str)
writer = SummaryWriter(log_folder) # writing log to tensorboard
# NOTE(review): the message says "logging to" but prints weights_folder, not log_folder -- confirm which was meant.
print('logging to: {}'.format(weights_folder))

# No exist_ok here: os.makedirs raises if the folder is already present
# (presumably to avoid overwriting the weights of an earlier run with the same name).
os.makedirs(weights_folder)  # MEANT TO FAIL IF IT ALREADY EXISTS

## Train

In [None]:
def predict_and_eval(model, batch, CCE):
    """
    Run `model` on one collated batch and score its output with `CCE`.

    :param model: callable invoked as model(points, seq_lens, max_output_len=...)
    :param batch: mapping providing 'sequence', 'pointers', 'sequence_lens' and 'pointer_lens'
    :param CCE: loss callable applied to (flattened predictions, flattened targets)
    :return: tuple (pointers, loss): the raw model output and the value returned by `CCE`
    """
    inputs = Variable(batch['sequence'])
    targets = Variable(batch['pointers'])  # FIXME: Must append an EOS token
    input_lens = batch['sequence_lens']
    target_lens = batch['pointer_lens']

    if USE_CUDA:
        gpu = params.gpu_device
        inputs, targets = inputs.cuda(gpu), targets.cuda(gpu)

    # FIXME: there is no EOS token yet, so decode exactly as many steps as the
    # longest target sequence has. This also makes sense during training.
    n_steps = target_lens.max()
    pointers = model(inputs, input_lens, max_output_len=n_steps)
    assert n_steps == pointers.shape[1]

    # Collapse all leading axes so the loss sees one row of class scores per target entry.
    flat_scores = pointers.contiguous().view(-1, pointers.shape[-1])
    flat_targets = targets.contiguous().view(-1)
    return pointers, CCE(flat_scores, flat_targets)


def visualize(batch, pred_pointers, figsize=(10, 5)):
    """
    Plot the target and the predicted pointers of the first sample of `batch`, one figure each.

    :param batch: collated batch; 'sequence' and 'pointers' are read, first sample only
    :param pred_pointers: numpy array of predicted pointer indices for that first sample
    :param figsize: (width, height) handed to plt.figure for both figures
    """
    bboxes = batch['sequence'].data.cpu().numpy()[0].squeeze()
    target_pointers = batch['pointers'].data.cpu().numpy()[0].squeeze()

    assert len(target_pointers) == pred_pointers.shape[0]
    print('Targets: {}, Preds: {}'.format(target_pointers.flatten(), pred_pointers.flatten()))

    # The image overlay is disabled here. To enable it, extract the first
    # entries of batch['images'] / batch['scales'] the same way as above and
    # pass them as image=... / scale=... to plot_points_and_lines.
    plt.figure(figsize=figsize)
    plot_points_and_lines(bboxes, target_pointers)

    plt.figure(figsize=figsize)
    plot_points_and_lines(bboxes, pred_pointers)

    plt.show()

def get_normalized_loss_func(pointers):
    """
    Calculates loss weights based on the numbers in "pointers" and returns a loss function initialized with those weights.

    Every class index that occurs in `pointers` is weighted by the inverse of its count, and the
    weights are normalized to sum to 1. Indices below max(pointers) that never occur get weight 0:
    inverting their zero count would give inf and, after normalization, NaN weights.

    :param pointers: 1-D array of non-negative integer class indices (as np.bincount requires)
    :return: torch.nn.CrossEntropyLoss whose `weight` has one entry per index in 0..max(pointers);
             the weight tensor is moved to DEVICE only when USE_CUDA is set
    """
    counts = np.bincount(pointers)

    # Invert only the non-zero counts so absent classes stay at weight 0.
    inv_freq = np.zeros(len(counts), dtype=np.float64)
    seen = counts > 0
    inv_freq[seen] = 1. / counts[seen]

    total = inv_freq.sum()
    if total > 0:
        inv_freq /= total

    weight = Variable(torch.from_numpy(inv_freq.astype(np.float32)))
    if USE_CUDA:
        weight = weight.cuda(DEVICE)
    return torch.nn.CrossEntropyLoss(weight=weight)

In [None]:
# Module-level cadences, counted in mini-batches. The training loop reads
# `val_every` from here; its checkpointing reads `params.save_every` instead.
save_every, val_every = 10000, 100

In [None]:
# Training position: current epoch and mini-batch index within it.
# Presumably kept in its own cell so the training cell can be re-run without resetting them.
epoch = 0
i_batch = 0

In [None]:
# Training loop. `epoch` and `i_batch` are module-level, so interrupting and
# re-running this cell continues from the recorded position.

# CrossEntropyLoss without class weights holds no tensors, so one instance is
# shared by every step and no device transfer is needed.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
n_train_batches = len(dataloader_train)
max_val_batches = 10  # usable mini-batches scored per validation pass
n_val_plots = 4  # how many of those are also plotted

while epoch < 5000:  # params.n_epochs:
    train_data_iter = iter(dataloader_train)
    while i_batch < n_train_batches:
        i_batch += 1
        train_batch = next(train_data_iter)
        iter_cntr = epoch * n_train_batches + i_batch  # The overall iteration number across epochs

        # This could happen because of random cropping - a better cropping strategy would help
        if train_batch is None or len(train_batch['pointers']) == 0:
            continue

        # Forward
        pointers, train_loss = predict_and_eval(model, train_batch, loss_func)

        # Backprop
        model_optim.zero_grad()
        train_loss.backward()
        model_optim.step()

        writer.add_scalar('train.loss', train_loss.data.cpu().numpy(), iter_cntr)

        # Save
        # NOTE(review): saving reads params.save_every while validation reads the
        # module-level val_every -- confirm which source is intended.
        if i_batch % params.save_every == 0:
            torch.save(model.state_dict(), os.path.join(weights_folder, '{}_{}.pt'.format(epoch, i_batch)))

        # Validation
        if i_batch % val_every == 0:
            total_val_loss = 0.
            n_val_batches = 0

            # Score in eval mode and without autograd bookkeeping.
            model.eval()
            with torch.no_grad():
                for val_batch in dataloader_val:
                    if n_val_batches == max_val_batches:
                        break

                    if val_batch is None or len(val_batch['pointers']) == 0:
                        continue

                    pointers, val_loss = predict_and_eval(model, val_batch, loss_func)
                    total_val_loss += val_loss.data.cpu().numpy()

                    # plot few
                    if n_val_batches < n_val_plots:
                        pred_pointers = pointers.argmax(dim=-1).data.cpu().numpy()[0]
                        visualize(val_batch, pred_pointers)

                    n_val_batches += 1
            model.train()

            # Average over the batches actually scored; skipped (empty) batches
            # must not dilute the mean.
            if n_val_batches:
                writer.add_scalar('val.loss', total_val_loss / n_val_batches, iter_cntr)

    # Start the next epoch from its first mini-batch; without this reset the
    # inner loop would never run again after the first epoch.
    i_batch = 0
    epoch += 1