# Pointer networks for words->fields

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import collections
import matplotlib.pyplot as plt
import numpy as np
import os
import torch

%matplotlib inline

In [None]:
from Dataset.dataset_docschema import Dataset
from docreader.evaluation.metrics.bbox_evaluation import calculate_iou
from docschema.semantic import Word, Paragraph, TextLine, Section, Document, Field
from typing import Union

### Define parameters

In [None]:
# Immutable bundle of experiment hyper-parameters. The field order below is the
# positional order accepted by Params(...).
Params = collections.namedtuple(
    'Params',
    'gpu_device batch_size embedding_size hiddens n_lstms dropout bidir lr n_epochs target_container',
)

In [None]:
params = Params(
    # Hardware: CUDA device index; a negative value disables the GPU (see USE_CUDA).
    gpu_device=1,

    # Model params  # FIXME: NOT USED RIGHT NOW!
    embedding_size=128,
    hiddens=512,
    n_lstms=2,
    dropout=0,
    bidir=False,

    # Data
    batch_size=1,

    # Training params
    lr=1e-4,
    n_epochs=50,

    # Model specific: the container type whose words form one pointer group
    target_container=TextLine,
)

In [None]:
# The GPU is used only when a non-negative device index is configured AND CUDA is available.
DEVICE = params.gpu_device
USE_CUDA = DEVICE >= 0 and torch.cuda.is_available()

## Load the data

In [None]:
def _get_relative_bbox(bbox, start_row, start_col, h, w):
    return (
        max(bbox[0] - start_col, 0),
        max(bbox[1] - start_row, 0),
        min(bbox[2] - start_col, w),
        min(bbox[3] - start_row, h)
    )
    

def get_random_crop(doc, image, crop_w=None, crop_h=500):
    """
    Take a random crop of `image` and update the words of `doc` to match it.

    Args:
        doc: document whose Word descendants are modified in place.
        image: page image indexed as image[row, col, ...], i.e. its shape is
            (height, width, ...).
        crop_w: crop width in pixels; None means the full image width.
        crop_h: crop height in pixels; None means the full image height.
            Both sizes are clamped to the image size. Previously a crop larger
            than the image made np.random.randint raise.

    Returns:
        (doc, cropped_image). Words are handled as follows:
        - Overlap with the crop above 0.2: the bbox is translated and clipped
          into crop coordinates. The overlap comes from calculate_iou with
          denominator='gt', presumably relative to the word's own area.
        - Otherwise: the word is detached from the document (parent = None).
        - No bbox: the word is left untouched.
    """
    # NumPy images are (rows, cols[, channels]) == (height, width). The slice at
    # the end indexes rows first, so the shape must be unpacked as (h, w).
    h, w = image.shape[:2]
    crop_w = w if crop_w is None else min(crop_w, w)
    crop_h = h if crop_h is None else min(crop_h, h)

    start_row = np.random.randint(0, high=(h - crop_h + 1))
    start_col = np.random.randint(0, high=(w - crop_w + 1))
    crop_box = np.array([(start_col, start_row, start_col + crop_w, start_row + crop_h)])

    for word in doc.filter_descendants(Word):
        if word.bbox is None:
            continue

        iou = calculate_iou(
            predicted_bboxes=crop_box,
            gt_bboxes=np.array([word.bbox]),
            is_xywh=False, denominator='gt')

        if iou[0][0] > 0.2:
            # Keep the word: move its bbox into the crop's coordinate frame.
            word.bbox = _get_relative_bbox(word.bbox, start_row, start_col, crop_h, crop_w)
        else:
            # Mostly outside the crop: detach it from the document tree.
            word.parent = None

    image = image[start_row: start_row + crop_h, start_col: start_col + crop_w]
    return doc, image

def get_words_and_containers(doc: Document, target_container: Union[type, str]):
    """
    Randomly crop a document and collect the target containers with their words.

    Args:
        doc: the document. Its `rendered_image` is cropped by get_random_crop,
            and its words are modified in place by that crop.
        target_container: the container class to generate sequences for
            (e.g. TextLine, Paragraph, Field), or one of the strings
            'key' / 'value'.

    Returns:
        A dictionary with:
        'containers': `target_container` descendants that still have at least
            one Word descendant after cropping.
        'words': a list, aligned with 'containers', of each container's Word
            descendants.
        'image': the cropped image.

    Raises:
        ValueError: if `target_container` is an unsupported string, or the doc
            has no `target_container` descendants before cropping.
    """
    if isinstance(target_container, str) and target_container not in ['key', 'value']:
        raise ValueError('target_container doesnot have the correct type / value')
    image = doc.rendered_image
    if len(doc.filter_descendants(target_container)) == 0:
        raise ValueError('The doc has no containers from the target container!')

    doc, image = get_random_crop(doc, image, crop_w=None, crop_h=500)

    list_containers = []
    list_child_words = []
    for el in doc.filter_descendants(target_container):
        child_words = el.filter_descendants(Word)
        # Containers lying fully outside the crop lose all their words; skip them.
        if len(child_words) == 0:
            continue
        list_child_words.append(child_words)
        list_containers.append(el)

    assert len(list_child_words) == len(list_containers)

    return {
        'containers': list_containers,
        'words': list_child_words,
        'image': image,
    }

def discretize(vals: np.ndarray, binv: int) -> np.ndarray:
    """
    Snap values down onto a grid with spacing `binv`.

    Intended for non-negative values such as pixel coordinates. Each value x
    in [0, max(vals)] maps to the largest multiple of `binv` that is <= x.
    For example, discretize([0, 1, 2, 3, 4], 3) -> [0, 0, 0, 3, 3].

    Args:
        vals: array-like of values to discretize.
        binv: bin width.

    Returns:
        An array of the same shape as `vals`. An empty input returns an empty
        integer array; previously max() of an empty sequence raised ValueError.
    """
    vals = np.asarray(vals)
    if vals.size == 0:
        return np.zeros(0, dtype=np.intp)
    maxval = vals.max()
    bins = np.arange(0, maxval + 1, binv)
    return (np.digitize(vals, bins) - 1) * binv

assert np.all(discretize([0, 1, 2, 3, 4], 3) == [0, 0, 0, 3, 3])

def get_sorted_bboxes_inds(bboxes: np.ndarray, binv=3) -> np.ndarray:
    """
    Return indices that sort bounding boxes top-to-bottom, then left-to-right.

    The left (x1) and top (y1) coordinates are first discretized to a grid of
    spacing `binv`, so boxes whose tops differ by less than a bin count as the
    same row.

    Args:
        bboxes: array of shape (n, 4) with rows (x1, y1, x2, y2).
        binv: discretization bin width in pixels.

    Returns:
        Integer index array of length n.

    Raises:
        ValueError: if `bboxes` is not of shape (n, 4).
    """
    bboxes = np.asarray(bboxes)
    if bboxes.ndim != 2 or bboxes.shape[1] != 4:
        raise ValueError('expected bboxes of shape (n, 4), got {}'.format(bboxes.shape))

    # FIXME: The scale might help here
    lefts = discretize(bboxes[:, 0], binv)
    tops = discretize(bboxes[:, 1], binv)

    # lexsort uses the LAST key as primary: sort by top, then by left. It is a
    # stable sort, so boxes in the same (top, left) cell keep their input order.
    # The old composite float key `tops * 1e5 + lefts` with a non-stable argsort
    # made that tie order arbitrary.
    return np.lexsort((lefts, tops))

def preprocess_bboxes_pointers(data):
    """
    Convert cropped containers/words into model inputs and pointer targets.

    Args:
        data: dict as produced by get_words_and_containers, with keys:
            'containers': list of containers.
            'words': list, aligned with 'containers', of each container's
                Word descendants.
            'image': the cropped image.

    Returns:
        A dict with:
        'bboxes': float32 array of shape (1 + n_words, 4). Row 0 is the EOC
            token [-1, -1, -1, -1]; the remaining rows are word bboxes grouped
            by container. Containers and the words inside each container are
            each sorted top-to-bottom, left-to-right.
        'pointers': int64 array of indices into 'bboxes'. Each container
            contributes its word indices followed by 0 (the EOC index).
        'image': data['image'].

        If nothing usable remains (no container with a bbox that has at least
        one word with a bbox), the placeholder
        {'bboxes': [[]], 'pointers': [], 'image': []} is returned. The training
        loop skips it via `len(batch['pointers']) == 0`.
    """
    EOC_TOKEN = np.array([-1, -1, -1, -1])

    # Keep each container paired with its own word bboxes while filtering.
    # Filtering only the container list (as before) misaligned the container
    # sort indices with `data['words']` whenever a container had no bbox.
    pairs = []
    for container, words in zip(data['containers'], data['words']):
        if container.bbox is None:
            continue
        word_bbox_list = [word.bbox for word in words if word.bbox is not None]
        if not word_bbox_list:  # np.vstack([]) would raise
            continue
        pairs.append((container.bbox, np.vstack(word_bbox_list)))

    if not pairs:
        return {'bboxes': [[]], 'pointers': [], 'image': []}

    container_bboxes = np.array([container_bbox for container_bbox, _ in pairs])

    all_word_bboxes = [EOC_TOKEN]  # ADD ALL THE FIXED TOKENS
    pointer_seq = []
    word_idx_start = len(all_word_bboxes)
    # Iterate over the containers in the order top-bottom, left-right.
    for idx in get_sorted_bboxes_inds(container_bboxes):
        word_bboxes = pairs[idx][1]
        word_bboxes = word_bboxes[get_sorted_bboxes_inds(word_bboxes)]

        all_word_bboxes.append(word_bboxes)
        n_words = len(word_bboxes)

        # The "pointers" are the row indices of these words in `all_word_bboxes`.
        pointer_seq.extend(range(word_idx_start, word_idx_start + n_words))
        pointer_seq.append(0)  # NOTE: 0 is the index of the fixed EOC token
        word_idx_start += n_words

    return {
        'bboxes': np.vstack(all_word_bboxes).astype(np.float32),
        # np.long was removed in NumPy 1.24; int64 is what it aliased on Linux.
        'pointers': np.array(pointer_seq, dtype=np.int64),
        'image': data['image'],
    }

In [None]:
def preprocessor1(x):
    """Stage 1: random crop, then collect `params.target_container` containers and their words."""
    return get_words_and_containers(x, params.target_container)


# Stage 2: turn containers/words into sorted bboxes and pointer targets.
preprocessor2 = preprocess_bboxes_pointers


def preprocessor(x):
    """Full adapter for a raw document: stage 1 followed by stage 2."""
    return preprocessor2(preprocessor1(x))

In [None]:
# NOTE: specify training data
synth_list_files = ['/opt/data/field-train-acord-20190208-large-train/synth.list']
# dataset = Dataset(synth_list_files, adapter=lambda x: x)
# `preprocessor` is passed as the Dataset adapter; presumably it is applied to
# every loaded document (random crop -> containers/words -> bboxes + pointers).
dataset_train = Dataset(synth_list_files, adapter=preprocessor)
# NOTE(review): samples have variable-length 'bboxes'/'pointers', so
# batch_size > 1 likely fails in the default collate — confirm before raising it.
dataloader_train = DataLoader(dataset_train, batch_size=params.batch_size, shuffle=True, num_workers=1)
print('Training: {:,} total images {:,} mini batches'.format(len(dataset_train), len(dataloader_train)))

In [None]:
# NOTE: specify validation data
synth_list_files = ['/opt/data/field-train-acord-20190208-large-val/synth.list']
# dataset = Dataset(synth_list_files, adapter=lambda x: x)
# Same preprocessing as training, including the random crop.
dataset_val = Dataset(synth_list_files, adapter=preprocessor)
dataloader_val = DataLoader(dataset_val, batch_size=params.batch_size, shuffle=True, num_workers=1)
print('valing: {:,} total images {:,} mini batches'.format(len(dataset_val), len(dataloader_val)))

## Data Statistics

In [None]:
# data_iter = iter(dataloader_val)
# all_seq_lens = []
# for ix in tqdm(range(200)):
#     batch = next(data_iter)
#     if len(batch['pointers']) == 0:
#         continue

#     all_seq_lens.append(len(batch['pointers'][0]))

# plt.hist(all_seq_lens)

### Visualize the data

In [None]:
import matplotlib.patches as mpatches

def _draw_bbox(ax, bbox, margin=0, color='r', linestyle='solid', fill=False, **kwargs):
    """
    Add an (x1, y1, x2, y2) rectangle to `ax`, grown by `margin` on every side.

    Extra keyword arguments are forwarded to matplotlib.patches.Rectangle.
    """
    left, top, right, bottom = bbox
    origin = (left - margin, top - margin)
    width = (right - left) + 2 * margin
    height = (bottom - top) + 2 * margin
    ax.add_patch(
        mpatches.Rectangle(origin, width, height, fill=fill, color=color, linestyle=linestyle, **kwargs)
    )

In [None]:
def plot_word_bboxes_ponters(image, word_bboxes, pointers):
    """
    Show `image` with the word bboxes selected by a pointer sequence overlaid.

    Args:
        image: image array passed to `imshow`.
        word_bboxes: indexable by pointer; each entry is an (x1, y1, x2, y2)
            bbox. Index 0 is the EOC token (see preprocess_bboxes_pointers).
        pointers: sequence of indices into `word_bboxes`. A 0 (EOC) closes the
            current group. All words of a group share one fill color, and each
            word is labelled with its pointer value.

    Colors come from a private RandomState seeded with 10. This gives the same
    reproducible palette as before without reseeding NumPy's global RNG, which
    the random crops in get_random_crop also draw from.
    """
    colors = [
        (1, 0, 0, 0.2),
        (1, 1, 0, 0.2),
        (1, 0, 1, 0.2),
        (0.5, 0, 0, 0.2),
        (0, 0, 0.5, 0.2),
        (0, 0.5, 0, 0.2),
        (0.5, 0.5, 0, 0.2),
    ]
    rng = np.random.RandomState(10)

    # Plot image
    plt.figure(figsize=(15, 20))
    ax = plt.subplot(1, 1, 1)
    ax.imshow(image)

    # Draw Words
    color = colors[rng.randint(len(colors))]
    for pointer in pointers:
        if pointer == 0:  # THE EOC token: start a new group with a new color
            color = colors[rng.randint(len(colors))]
            continue

        bbox = word_bboxes[pointer]
        _draw_bbox(ax, bbox, fill=True, color=color)
        plt.text((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2, str(pointer), ha='center', color='r')

### Input/Output definition

#### Inputs
0. Take a random sub-section (crop) of the page.
1. Take all the words (points): only their bounding boxes for now; the word text will be added later.
2. Order matters, so always sort the words top-to-bottom and left-to-right after discretizing their coordinates with some basic thresholding (binning).

#### How should the output be structured?
1. A sequence of pointers into the input words.
2. Different containers (text lines) must be separated by `<EOC>`.
3. The entire sequence should end with an `<EOS>` (not implemented yet: the target sequences currently end with the last container's `<EOC>`).
4. The pointers within each group must be sorted in the order: top-bottom, left-right.

## Define the model

In [None]:
from pointer_net import PointerNet

In [None]:
# NOTE(review): n_in=4 presumably matches the 4 bbox coordinates per input point — confirm in pointer_net.
model = PointerNet(n_in=4)

# SANITY RUN THE MODEL
# Push one validation sample through the freshly built model (still on CPU
# here) and request 10 output steps, to check the forward pass and inspect shapes.
# NOTE(review): if this sample is an empty crop ('bboxes' == [[]]) the call is
# likely to fail; simply re-run the cell.
batch = next(iter(dataloader_val))
points = batch['bboxes'][0]  # bboxes of the first sample in the batch

# Re-add a leading batch axis before calling the model.
pointers = model(points[np.newaxis, ...], 10)
print(points.shape)
print(pointers.shape)
# NOTE(review): if each output step is a probability distribution over the
# input points, these sums should all be ~1 — confirm against PointerNet.
pointers.sum(dim=2)

In [None]:
# Move the model to the configured GPU only when CUDA is actually usable.
# USE_CUDA is a bool, so the previous `USE_CUDA >= 0` test was always True and
# called .cuda() even on machines without a GPU.
if USE_CUDA:
    model.cuda(device=params.gpu_device)
#     cudnn.benchmark = True

## Define the optimizer / loss

In [None]:
# Adam over the trainable parameters only; parameters with requires_grad=False are skipped.
model_optim = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=params.lr)

## Logging

In [None]:
# Run name: used for both the checkpoint folder and the tensorboard log folder.
model_str = 'ptr-textline-loss-scaled-1.01'

# logging
weights_folder = "/opt/weights/{}".format(model_str)
log_folder =  '../tensorboard-logs/{}'.format(model_str)
writer = SummaryWriter(log_folder) # writing log to tensorboard
print('logging to: {}'.format(weights_folder))  # NOTE: prints the checkpoint folder, not log_folder

# MEANT TO FAIL IF IT ALREADY EXISTS, so a previous run's checkpoints are never overwritten.
os.makedirs(weights_folder)

## Train

In [None]:
def predict_and_eval(model, batch, CCE):
    """
    Run the pointer network on a batch and compute the loss against its targets.

    Args:
        model: the pointer network. It is called as
            model(points, max_output_len=n) and must return per-step scores
            whose axis 1 has length n and whose last axis indexes input points.
        batch: collated sample with 'bboxes' (input points) and 'pointers'
            (target indices into 'bboxes', one per output step).
        CCE: a classification loss taking (N, n_points) scores and (N,)
            integer targets, e.g. torch.nn.CrossEntropyLoss.

    Returns:
        (pointers, loss): the raw model output and the scalar loss.
    """
    points = Variable(batch['bboxes'])
    target_pointers = Variable(batch['pointers'])  # FIXME: Must append an EOS token

    if USE_CUDA:
        points = points.cuda(params.gpu_device)
        target_pointers = target_pointers.cuda(params.gpu_device)

    # Generate as many outputs as there are target steps.
    # FIXME: needed because there is no EOS token yet. Also makes sense during training.
    n_outputs = len(target_pointers[0])
    pointers = model(points, max_output_len=n_outputs)
    assert n_outputs == pointers.shape[1]

    # Flatten (batch, steps, n_points) -> (batch*steps, n_points) and the
    # targets -> (batch*steps,). For batch_size=1 this equals the old
    # squeeze(). It stays correct when there is a single output step, where
    # squeeze() also dropped the step axis, and when batch_size > 1.
    n_points = pointers.size(-1)
    loss = CCE(pointers.contiguous().view(-1, n_points), target_pointers.contiguous().view(-1))
    return pointers, loss


def visualize(batch, pred_pointers):
    """
    Print and plot the target and predicted pointer sequences for the first sample of `batch`.

    `pred_pointers` must have one entry per target pointer.
    """
    def _first_sample(key):
        return batch[key].data.cpu().numpy()[0]

    image = _first_sample('image')
    word_bboxes = _first_sample('bboxes')
    target_pointers = _first_sample('pointers')

    assert len(target_pointers) == pred_pointers.shape[0]
    print('Targets: {}, Preds: {}'.format(target_pointers, pred_pointers))

    for title, pointer_seq in (("Target", target_pointers), ("Predicted", pred_pointers)):
        print(title)
        plot_word_bboxes_ponters(image, word_bboxes, pointer_seq)
        plt.show()

In [None]:
def get_normalized_loss_func(pointers):
    """
    Build a CrossEntropyLoss weighted by normalized inverse target frequency.

    Each target class is a pointer index. With preprocess_bboxes_pointers,
    every word index occurs once and the EOC index 0 occurs once per container.
    The weighting therefore down-weights the frequent EOC target.

    Args:
        pointers: 1-D non-negative integer array of target pointers.
            The weight vector has length max(pointers) + 1.
            NOTE(review): this must equal the number of scored input points;
            it holds when every word is pointed to.

    Returns:
        torch.nn.CrossEntropyLoss with the computed class weights. The weights
        are on GPU `DEVICE` when USE_CUDA, otherwise on CPU. Previously they
        were always moved to CUDA.

    Indices that never occur in `pointers` get weight 0. Previously they got
    1/0 = inf, which turned the normalized weights into NaN. Such classes never
    appear as targets, so a zero weight does not affect the loss.
    """
    counts = np.bincount(pointers)
    inv_freq = np.zeros(len(counts), dtype=np.float64)
    seen = counts > 0
    inv_freq[seen] = 1. / counts[seen]
    total = inv_freq.sum()
    if total > 0:
        inv_freq /= total

    weight = Variable(torch.from_numpy(inv_freq.astype(np.float32)))
    if USE_CUDA:
        weight = weight.cuda(DEVICE)
    return torch.nn.CrossEntropyLoss(weight=weight)

In [None]:
# The training loop checkpoints / validates whenever the within-epoch batch index
# i_batch is a multiple of these, so both also fire at i_batch == 0 of every epoch.
save_every = 10000
val_every = 100

In [None]:
# Validation looks at the first `max_val_batches` batches of a shuffled pass
# (empty crops among them are skipped) and plots predictions for the first
# `n_plot_batches` of those.
max_val_batches = 10
n_plot_batches = 4

model.train()
for epoch in range(params.n_epochs):
    for i_batch, train_batch in enumerate(dataloader_train):
        iter_cntr = epoch * len(dataloader_train) + i_batch  # The overall iteration number across epochs

        # This could happen because of random cropping - a better cropping strategy would help
        if len(train_batch['pointers']) == 0:
            continue

        # Forward
        loss_func = get_normalized_loss_func(train_batch['pointers'].data.cpu().numpy().flatten())
        pointers, train_loss = predict_and_eval(model, train_batch, loss_func)

        # Backprop
        model_optim.zero_grad()
        train_loss.backward()
        model_optim.step()

        writer.add_scalar('train.loss', train_loss.data.cpu().numpy(), iter_cntr)

        # Save
        if i_batch % save_every == 0:
            torch.save(model.state_dict(), os.path.join(weights_folder, '{}_{}.pt'.format(epoch, i_batch)))

        # Validation: eval mode and no autograd graph.
        if i_batch % val_every == 0:
            model.eval()
            total_val_loss = 0
            n_val_batches = 0
            with torch.no_grad():
                for jx, val_batch in enumerate(dataloader_val):
                    # Check the limit BEFORE skipping empty batches. Previously
                    # an empty batch at jx == 10 skipped the `jx == 10` break,
                    # and the whole validation set was iterated.
                    if jx >= max_val_batches:
                        break
                    if len(val_batch['pointers']) == 0:
                        continue

                    loss_func = get_normalized_loss_func(val_batch['pointers'].data.cpu().numpy().flatten())
                    pointers, val_loss = predict_and_eval(model, val_batch, loss_func)
                    total_val_loss += val_loss.data.cpu().numpy()
                    n_val_batches += 1

                    # plot few
                    if jx < n_plot_batches:
                        pred_pointers = pointers.argmax(dim=-1).data.cpu().numpy()[0]
                        visualize(val_batch, pred_pointers)
            model.train()

            # Average over the batches actually evaluated, not a fixed 10.
            if n_val_batches > 0:
                writer.add_scalar('val.loss', total_val_loss / n_val_batches, iter_cntr)