# Pointer networks for words->fields

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import collections
import matplotlib.pyplot as plt
import numpy as np
import os
import torch

%matplotlib inline

In [None]:
from Dataset.dataset_docschema import Dataset
from docreader.evaluation.metrics.bbox_evaluation import calculate_iou
from docschema.semantic import Word, Paragraph, TextLine, Section, Document, Field
from typing import Union

### Define parameters

In [None]:
# Immutable record of every experiment setting; the field order below is the positional order.
Params = collections.namedtuple(
    'Params',
    'gpu_device '
    'batch_size embedding_size hiddens n_lstms dropout bidir '
    'lr n_epochs val_every save_every '
    'target_container',
)

In [None]:
# Concrete settings for this run; every field declared on Params has to be supplied here.
params = Params(
    gpu_device=1,  # CUDA device index; a negative value switches CUDA off (see USE_CUDA)

    # Data
    batch_size=1,

    # Model specific
    target_container=TextLine,  # schema container whose words form one pointer group

    # Training params
    lr=1e-4,  # learning rate handed to the Adam optimizer
    n_epochs=50,
    val_every=100,  # NOTE: the training loop reads the module-level `val_every`, not this field
    save_every=1000,  # a checkpoint is written every `save_every` training batches

    # Model params
    # FIXME: not used right now - PointerNet is built with n_in only.
    embedding_size=128,
    hiddens=512,
    n_lstms=2,
    dropout=0,
    bidir=False,
)

In [None]:
# CUDA is used only when a non-negative device index is configured AND a GPU is actually visible.
USE_CUDA = params.gpu_device >= 0 and torch.cuda.is_available()
# Device index passed to the `.cuda(...)` calls further down.
DEVICE = params.gpu_device

## Load the data

In [None]:
# Sentinel "bounding box" of the end-of-container (EOC) token. It is the first row of every bbox
# sequence when the sequence is built, so pointer value 0 closes the current container.
EOC_TOKEN = np.array([-1, -1, -1, -1])


class Preprocessor:
    """
    Convert a document into the inputs / targets of the pointer network.

    Calling an instance on a document randomly crops it, collects the words of every
    ``target_container`` and returns a dict with:

    * ``'bboxes'``   - float32 array; the first row is the EOC token, the remaining rows are the
      word boxes grouped per container,
    * ``'pointers'`` - int64 array of row indices into ``'bboxes'``; each container contributes the
      indices of its words followed by the index of the EOC row,
    * ``'image'``    - the cropped image,
    * ``'is_empty'`` - True when the crop contains nothing usable (the other entries are then
      empty placeholder arrays).
    """

    def __init__(self, target_container, crop_h=500, crop_w=None, random_shuffle: bool = False):
        """
        :param target_container: schema class (or the string 'key' / 'value') whose words form
            one pointer group.
        :param crop_h: crop height in pixels; ``None`` keeps the full image height.
        :param crop_w: crop width in pixels; ``None`` keeps the full image width.
        :param random_shuffle: if True the bbox rows are randomly permuted and the pointers are
            remapped accordingly.
        """
        self.target_container = target_container
        self.crop_w = crop_w
        self.crop_h = crop_h
        self.random_shuffle = random_shuffle

    def __call__(self, doc):
        """Run the whole pipeline on ``doc`` and return the datum dict described on the class."""
        words_and_containers = self.get_words_and_containers(doc, self.target_container)
        final_data = self.preprocess_bboxes_pointers(words_and_containers, EOC_TOKEN)

        if self.random_shuffle:
            final_data = self.random_shuffle_sequence(final_data)

        return final_data

    @staticmethod
    def _empty_datum():
        """Placeholder datum for a crop without usable words; ``collate_fn`` drops these."""
        return {'bboxes': np.array([[]]), 'pointers': np.array([]), 'image': np.array([[]]), 'is_empty': True}

    def _get_relative_bbox(self, bbox, start_row, start_col, h, w):
        """
        Translate an ``(x1, y1, x2, y2)`` box into the coordinate frame of a crop whose top-left
        corner is ``(start_col, start_row)``, clipping it to the ``w`` x ``h`` crop.
        """
        return (
            max(bbox[0] - start_col, 0),
            max(bbox[1] - start_row, 0),
            min(bbox[2] - start_col, w),
            min(bbox[3] - start_row, h)
        )

    def _get_random_crop(self, doc, image, crop_w=None, crop_h=500):
        """
        Cut a random ``crop_h`` x ``crop_w`` window out of ``image`` and update ``doc`` IN PLACE.

        Words that lie sufficiently inside the window get their bbox rewritten relative to the
        window; all other words have their ``parent`` set to None.

        :raises ValueError: if the requested crop is larger than the image.
        :return: ``(doc, cropped_image)``
        """
        h, w = image.shape[:2]
        if crop_w is None:
            crop_w = w
        if crop_h is None:
            crop_h = h
        if crop_h > h or crop_w > w:
            # np.random.randint would fail below with an unhelpful "high <= 0" message
            raise ValueError('crop size ({}, {}) exceeds the image size ({}, {})'.format(crop_h, crop_w, h, w))

        start_row = np.random.randint(0, high=(h-crop_h+1))
        start_col = np.random.randint(0, high=(w-crop_w+1))
        crop_box = (start_col, start_row, start_col+crop_w, start_row+crop_h)

        for word in doc.filter_descendants(Word):
            if word.bbox is None:
                continue

            # NOTE(review): denominator='gt' presumably normalises the overlap by the word's own
            # area, i.e. `iou` is the fraction of the word inside the crop - confirm in calculate_iou.
            iou = calculate_iou(
                predicted_bboxes=np.array([crop_box]),
                gt_bboxes=np.array([word.bbox]),
                is_xywh=False, denominator='gt')

            if iou[0][0] > 0.2:
                # Enough of the word is visible: keep it, expressed in crop coordinates
                word.bbox = self._get_relative_bbox(word.bbox, start_row, start_col, crop_h, crop_w)
            else:
                # presumably detaches the word from the document tree - verify against docschema
                word.parent = None

        image = image[start_row: start_row+crop_h, start_col: start_col+crop_w]
        return doc, image

    def get_words_and_containers(self, doc: "Document", target_container: "Union[type, str]"):
        """
        Crop the document and group the remaining words by container.

        Inputs: the document and the container type to generate sequences for.
        Outputs: a dict with the kept ``'containers'``, the parallel list ``'words'`` (the Word
        descendants of each kept container) and the cropped ``'image'``.

        The image / doc gets cropped to a smaller size and ``doc`` is modified in place.

        :raises ValueError: if ``target_container`` is a string other than 'key' / 'value', or if
            the document has no such container at all.
        """
        if isinstance(target_container, str) and target_container not in ['key', 'value']:
            raise ValueError('target_container doesnot have the correct type / value')
        image = doc.rendered_image
        if len(doc.filter_descendants(target_container)) == 0:
            raise ValueError('The doc has no containers from the target container!')

        # A crop dimension of None means "keep the full extent", so resolve it before comparing
        full_h, full_w = image.shape[:2]
        expected_hw = (
            full_h if self.crop_h is None else self.crop_h,
            full_w if self.crop_w is None else self.crop_w,
        )

        doc, image = self._get_random_crop(doc, image, crop_w=self.crop_w, crop_h=self.crop_h)

        list_containers = []
        list_child_words = []

        for el in doc.filter_descendants(target_container):
            child_words = el.filter_descendants(Word)
            if len(child_words) == 0:
                continue  # the crop removed every word of this container
            list_child_words.append(child_words)
            list_containers.append(el)

        assert len(list_child_words) == len(list_containers)
        assert tuple(image.shape[:2]) == expected_hw

        return {
            'containers': list_containers,
            'words': list_child_words,
            'image': image,
        }

    def _discretize(self, vals: np.ndarray, binv: int) -> np.ndarray:
        """Snap every value down to the start of its ``binv``-wide bin (bins start at 0)."""
        maxval = np.max(vals)
        bins = np.arange(0, maxval+1, binv)
        return (np.digitize(vals, bins) - 1) * binv

    def _get_sorted_bboxes_inds(self, bboxes: np.ndarray, binv=3) -> np.ndarray:
        """
        Return the indices that sort bounding boxes from top to bottom, left to right after
        discretizing the co-ordinates by ``binv``.

        :raises ValueError: if ``bboxes`` does not have exactly 4 columns.
        """
        if bboxes.shape[1] != 4:
            raise ValueError

        # FIXME: The scale might help here
        lefts = self._discretize(bboxes[:, 0], binv)
        tops = self._discretize(bboxes[:, 1], binv)

        # lexsort treats the LAST key as the primary one: order by top, break ties by left.
        # It is stable and, unlike a combined `tops * 1e5 + lefts` key, has no coordinate limit.
        return np.lexsort((lefts, tops))

    def preprocess_bboxes_pointers(self, data, EOC_TOKEN):
        """
        Build the bbox sequence and the target pointer sequence.

        :param data: dict produced by ``get_words_and_containers``.
        :param EOC_TOKEN: 4-vector stored as row 0 of the bbox sequence; pointer 0 refers to it.
        :return: datum dict (see the class docstring); the empty placeholder is returned when no
            container has a bbox or no word with a bbox is left.
        """
        # Keep every container paired with its own words while dropping bbox-less containers, so
        # that the sort indices computed below address the right word list.
        groups = [
            (container, words)
            for container, words in zip(data['containers'], data['words'])
            if container.bbox is not None
        ]
        if len(groups) == 0:
            return self._empty_datum()

        container_bboxes = np.array([container.bbox for container, _ in groups])

        all_word_bboxes = [EOC_TOKEN]  # ADD ALL THE FIXED TOKENS
        pointer_seq = []

        word_idx_start = len(all_word_bboxes)
        # iterate over the containers in the order top-bottom left-right
        for idx in self._get_sorted_bboxes_inds(container_bboxes):
            word_bboxes = [word.bbox for word in groups[idx][1] if word.bbox is not None]
            if len(word_bboxes) == 0:
                continue  # checked BEFORE stacking: np.vstack raises on an empty list

            word_bboxes = np.vstack(word_bboxes)
            word_bboxes = word_bboxes[self._get_sorted_bboxes_inds(word_bboxes)]

            all_word_bboxes.append(word_bboxes)
            n_words = len(word_bboxes)

            # Get the indices for the words and store them as the "pointers"
            pointer_seq.extend(np.arange(n_words) + word_idx_start)
            pointer_seq.append(0)  # NOTE: 0 is the index of the fixed EOC token
            word_idx_start += n_words

        if not pointer_seq:
            return self._empty_datum()

        all_word_bboxes = np.vstack(all_word_bboxes).astype(np.float32)
        # int64 is spelled out: the `np.long` alias is missing from NumPy 1.24-1.26
        pointer_seq = np.array(pointer_seq).astype(np.int64)

        return {
            'bboxes': all_word_bboxes,
            'pointers': pointer_seq,
            'image': data['image'],
            'is_empty': False,
        }

    @staticmethod
    def random_shuffle_sequence(datum):
        """
        Randomly shuffles the input sequence and the other arrays correspondingly.

        The bbox rows are permuted and every pointer is rewritten to the new row of the box it
        referred to, so ``new_bboxes[new_pointers]`` equals ``bboxes[pointers]``. Empty data are
        returned unchanged.
        """
        is_empty = datum['is_empty']
        if is_empty:
            return datum

        bboxes, pointers = datum['bboxes'], datum['pointers']
        n = len(bboxes)

        # Row j of the shuffled array holds the old row inds_new_order[j]. No squeeze here: it
        # would turn a one-row sequence into a 1-D vector.
        inds_new_order = np.arange(n)
        np.random.shuffle(inds_new_order)
        bboxes = bboxes[inds_new_order]

        # map the pointers to the new indices (inverse permutation: old row -> new row)
        inds_reverse = np.empty(n, dtype=np.int64)
        inds_reverse[inds_new_order] = np.arange(n)
        new_pointers = inds_reverse[pointers].astype(np.int64)
        assert new_pointers.shape == pointers.shape

        return {
            'bboxes': bboxes,
            'pointers': new_pointers,
            'image': datum['image'],
            'is_empty': is_empty,
        }

In [None]:
# Data loader specific
def get_padded_tensor_and_lens(list_seqs, pad_constant_value=0, n_dim=2):
    """
    Stack variable-length sequences into one padded array.

    :param list_seqs: non-empty list of arrays of shape ``(seq_len, n_dim)``. A sequence without
        any element is replaced by a single all-zero float32 row. The list itself is not modified.
    :param pad_constant_value: value written into the rows appended after the end of a sequence.
    :param n_dim: expected size of the second axis of every sequence.
    :return: ``(data, lens)`` where ``data`` has shape ``(len(list_seqs), max_len, n_dim)`` and
        ``lens`` holds ``len()`` of every sequence as it was passed in.
    :raises ValueError: if ``list_seqs`` is empty or a sequence is not ``(seq_len, n_dim)``.
    """
    if len(list_seqs) == 0:
        raise ValueError('list_seqs must contain at least one sequence')

    lens = np.array([len(x) for x in list_seqs])

    # Work on a private list so that the caller's list is left untouched
    seqs = []
    for seq in list_seqs:
        seq = np.asarray(seq)
        if seq.size == 0:
            seq = np.zeros((1, n_dim), dtype=np.float32)
        if seq.ndim != 2:
            raise ValueError('Actual shape is: {}'.format(seq.shape))
        if seq.shape[1] != n_dim:
            raise ValueError('Expected {} columns, actual shape is: {}'.format(n_dim, seq.shape))
        seqs.append(seq)

    # Use the real row counts: a placeholder row makes a sequence longer than its reported
    # length, which would otherwise yield a negative pad width.
    max_len = max(len(seq) for seq in seqs)
    data = np.array([
        np.pad(seq, pad_width=[(0, max_len - len(seq)), (0, 0)], mode='constant', constant_values=pad_constant_value)
        for seq in seqs
    ])

    return data, lens


def collate_fn(batch):
    """
    Merge pre-processed samples into one padded mini-batch.

    Samples flagged ``'is_empty'`` are dropped. The remaining ones are padded to a common length
    (bboxes with 0, pointers with -100) and ordered by decreasing bbox-sequence length.

    :param batch: list of datum dicts produced by the pre-processor.
    :return: ``None`` if no sample is left, otherwise a dict with the tensors ``'sequence'``,
        ``'pointers'`` and ``'images'`` plus the numpy length arrays ``'sequence_lens'`` and
        ``'pointer_lens'``.
    """
    # A plain list filter; the removed `np.bool` alias is no longer needed
    batch = [sample for sample in batch if not sample['is_empty']]
    if len(batch) == 0:
        return None

    sequences, lens1 = get_padded_tensor_and_lens([sample['bboxes'] for sample in batch], pad_constant_value=0, n_dim=4)
    # -100 is the padding value for the targets (the training loop passes ignore_index=-100)
    pointers, lens2 = get_padded_tensor_and_lens([sample['pointers'][..., np.newaxis] for sample in batch], pad_constant_value=-100, n_dim=1)

    # Sort such that the longest sequence is first. Sort the pointers to match the sequences.
    inds_sorted_desc = np.argsort(lens1)[::-1]
    sequences, lens1 = sequences[inds_sorted_desc, ...], lens1[inds_sorted_desc]
    pointers, lens2 = pointers[inds_sorted_desc, ...], lens2[inds_sorted_desc]

    sequences = torch.from_numpy(sequences)
    pointers = torch.from_numpy(pointers)

    # Get the images (a leading singleton axis is added to each), in the same order
    images = np.array([sample['image'][np.newaxis, ...] for sample in batch])
    images = images[inds_sorted_desc]
    images = torch.from_numpy(images)

    return {
        'sequence': sequences,
        'sequence_lens': lens1,
        'pointers': pointers,
        'pointer_lens': lens2,
        'images': images,
    }


In [None]:
# One pre-processor shared by the training and the validation dataset: 200x500 random crops,
# reading order kept (no random shuffling of the bbox sequence).
preprocessor = Preprocessor(params.target_container, crop_h=200, crop_w=500, random_shuffle=False)

In [None]:
# NOTE: specify training data
synth_list_files = ['/opt/data/field-train-acord-20190208-large-train/synth.list']
# dataset = Dataset(synth_list_files, adapter=lambda x: x)
dataset_train = Dataset(synth_list_files, adapter=preprocessor)
dataloader_train = DataLoader(dataset_train, batch_size=params.batch_size, shuffle=True, num_workers=8, collate_fn=collate_fn)
print('Training: {:,} total images {:,} mini batches'.format(len(dataset_train), len(dataloader_train)))

In [None]:
# NOTE: specify valing data
synth_list_files = ['/opt/data/field-train-acord-20190208-large-val/synth.list']
# dataset = Dataset(synth_list_files, adapter=lambda x: x)
dataset_val = Dataset(synth_list_files, adapter=preprocessor)
dataloader_val = DataLoader(dataset_val, batch_size=params.batch_size, shuffle=True, num_workers=8, collate_fn=collate_fn)
print('valing: {:,} total images {:,} mini batches'.format(len(dataset_val), len(dataloader_val)))

In [None]:
# TEST
batch = [
    dataset_train[0],
    dataset_train[1],
#     dataset_train[2],
#     dataset_train[3],
    {'bboxes': np.array([[]]), 'pointers': np.array([]), 'image': np.array([[]]), 'is_empty': True}
]

c = collate_fn(batch)

### Visualize the data

In [None]:
import matplotlib.patches as mpatches

def _draw_bbox(ax, bbox, margin=0, color='r', linestyle='solid', fill=False, **kwargs):
    """
    Add a rectangle for the ``(x1, y1, x2, y2)`` box to ``ax``, grown by ``margin`` on every side.

    Extra keyword arguments are forwarded to ``matplotlib.patches.Rectangle``.
    """
    left, top, right, bottom = bbox
    width = (right - left) + 2*margin
    height = (bottom - top) + 2*margin
    anchor = (left - margin, top - margin)

    ax.add_patch(
        mpatches.Rectangle(anchor, width, height, fill=fill, color=color, linestyle=linestyle, **kwargs)
    )

def plot_word_bboxes_ponters(image, word_bboxes, pointers, figsize=(10, 5)):
    """
    Show ``image`` and overlay the word boxes in the order given by ``pointers``.

    All words of one container share a colour; a new random colour is drawn whenever a pointer
    refers to the EOC token. Every box is labelled with its pointer value.

    :param image: image accepted by ``imshow``.
    :param word_bboxes: array indexable by pointer value, each row an ``(x1, y1, x2, y2)`` box.
    :param pointers: pointer indices of any shape (scalar, ``(T,)`` or ``(T, 1)``). Negative
        values are treated as padding (collate_fn pads with -100) and skipped.
    :param figsize: size of the created figure.
    """
    colors = [
        (1, 0, 0, 0.2),
        (1, 1, 0, 0.2),
        (1, 0, 1, 0.2),
        (0.5, 0, 0, 0.2),
        (0, 0, 0.5, 0.2),
        (0, 0.5, 0, 0.2),
        (0.5, 0.5, 0, 0.2),
    ]

    # Plot image
    plt.figure(figsize=figsize)
    ax = plt.subplot(1, 1, 1)
    ax.imshow(image)

    # Draw Words
    color = colors[np.random.randint(len(colors))]
    # Flatten first so that (T, 1) targets, (T,) predictions and 0-d arrays are all handled
    for pointer in np.asarray(pointers).reshape(-1):
        pointer = int(pointer)
        if pointer < 0:
            # Padding entry: indexing with it would silently address a box from the end
            continue

        bbox = np.asarray(word_bboxes[pointer]).flatten()
        if np.all(bbox == EOC_TOKEN):  # THE EOC token: the next container gets a new colour
            color = colors[np.random.randint(len(colors))]
            continue

        _draw_bbox(ax, bbox, fill=True, color=color)
        plt.text((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2, str(pointer), ha='center', color='r')

In [None]:
# Eyeball the first few training batches: plot sample 0 of each batch with its target pointers.
for ix, batch in enumerate(dataloader_train):
    if ix == 4:
        break

    # collate_fn returns None when every sample of the batch was empty
    if batch is None:
        continue

    image = batch['images'][0].data.cpu().numpy().squeeze()
    bboxes = batch['sequence'][0].data.cpu().numpy().squeeze()
    pointers = batch['pointers'][0].data.cpu().numpy().squeeze()
    plot_word_bboxes_ponters(image, bboxes, pointers, figsize=(5, 10))
#     print(bboxes)
#     print(pointers)
    plt.show()

## Data Statistics

In [None]:
# Collect the target pointer indices of (at most) the first 200 training batches.
data_iter = iter(dataloader_train)
all_pointer_inds = []
# zip() ends with whichever iterator runs out first, so a loader with fewer than 200 batches
# finishes cleanly instead of raising StopIteration from a bare next().
for ix, batch in zip(tqdm(range(200)), data_iter):
    if batch is None or len(batch['pointers']) == 0:
        continue

    all_pointer_inds.extend(batch['pointers'].data.numpy().flatten())

all_pointer_inds = np.array(all_pointer_inds)

In [None]:
# Drop the -100 entries that collate_fn uses to pad the pointer sequences.
all_pointer_inds = all_pointer_inds[all_pointer_inds != -100]

In [None]:
# Inverse-frequency class weights: w[k] is proportional to 1 / (occurrences of pointer index k),
# normalised to sum to 1.
# NOTE(review): an index below the maximum that never occurs has bc == 0, which gives an inf
# weight and NaNs after the normalisation.
bc = np.bincount(all_pointer_inds)
w = 1. / bc
w /= w.sum()

In [None]:
# Weight per pointer index, as estimated from the sampled batches.
plt.plot(w)

In [None]:
# The actual sequences can be of any length. So, add a constant weight for all the remaining
# indices: the weight of the largest observed index is repeated until there are max_ind entries.
# After this step the weights no longer sum to 1.
max_ind = 300
w = np.hstack([w, [w[-1]] * (max_ind - len(w))])

In [None]:
# Same plot after extending the weights to max_ind entries.
plt.plot(w)

### Input/Output definition

#### Inputs
0. Choose a sub-section of the page!
1. Take all the words (points): just b-boxes for now, we'll add in the words later
2. Order matters - let's always sort them top-to-bottom & left-to-right. Discretize the coordinates with some basic thresholding.

#### How should the output be structured?
1. sequence of pointers (duh!)
2. Different containers (text-lines) must be separated by `<EOC>`
3. The entire sequence should end with an `<EOS>`
4. The pointers within each group must be sorted in the order: top-bottom, left-right

## Define the model

In [None]:
from pointer_net import PointerNet

In [None]:
# n_in=4 presumably is the number of input features per sequence element (the 4 bbox
# coordinates) - confirm against pointer_net.PointerNet.
model = PointerNet(n_in=4)

# SANITY RUN THE MODEL
# batch = next(iter(dataloader_val))
# points = batch['bboxes'][0]

# pointers = model(points[np.newaxis, ...], 10)
# print(points.shape)
# print(pointers.shape)
# pointers.sum(dim=2)

In [None]:
# Move the model to the configured GPU only when CUDA is really usable. USE_CUDA is a bool, so
# the former `USE_CUDA >= 0` test was always true and tried .cuda() even on CPU-only machines.
if USE_CUDA:
    model.cuda(device=params.gpu_device)
#     cudnn.benchmark = True

## Define the optimizer / loss

In [None]:
# Adam over the trainable parameters only (parameters with requires_grad=False are skipped).
model_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=params.lr)

## Logging

In [None]:
# Run identifier; it names both the checkpoint folder and the tensorboard log folder.
model_str = 'ptr-textline-random-one-2.00'

# logging
weights_folder = "/opt/weights/{}".format(model_str)
log_folder =  '../tensorboard-logs/{}'.format(model_str)
writer = SummaryWriter(log_folder) # writing log to tensorboard
# NOTE: the message prints the checkpoint folder, not the tensorboard log folder
print('logging to: {}'.format(weights_folder))

# No exist_ok: re-using a run name raises instead of overwriting earlier checkpoints
os.makedirs(weights_folder)  # MEANT TO FAIL IF IT ALREADY EXISTS

## Train

In [None]:
def predict_and_eval(model, batch, CCE):
    """
    Run the model on one collated batch and score it against the target pointers.

    :param model: called as ``model(points, seq_lens, max_output_len=...)``; its output is
        indexed as ``(batch, output_step, class)``.
    :param batch: dict produced by ``collate_fn``.
    :param CCE: loss callable applied to the flattened ``(N, n_classes)`` scores and ``(N,)`` targets.
    :return: ``(pointers, loss)`` - the raw model output and the loss value.
    """
    # FIXME: Must append an EOS token to the targets
    inputs, targets = Variable(batch['sequence']), Variable(batch['pointers'])
    if USE_CUDA:
        inputs, targets = (tensor.cuda(params.gpu_device) for tensor in (inputs, targets))

    # generate as many outputs as in the longest target sequence
    # FIXME: because we don't have an EOS token. Also, makes sense during training
    max_steps = batch['pointer_lens'].max()
    scores = model(inputs, batch['sequence_lens'], max_output_len=max_steps)
    assert max_steps == scores.shape[1]

    # One row per (sample, output step) for the loss
    flat_scores = scores.contiguous().view(-1, scores.shape[-1])
    flat_targets = targets.contiguous().view(-1)
    return scores, CCE(flat_scores, flat_targets)


def visualize(batch, pred_pointers, figsize=(10, 5)):
    """
    Print and plot target vs. predicted pointers for the first sample of ``batch``.

    :param batch: dict produced by ``collate_fn``.
    :param pred_pointers: predicted pointer indices for sample 0, one per target position.
    :param figsize: size of each of the two figures.
    """
    first_sample = {
        key: batch[key].data.cpu().numpy()[0]
        for key in ('images', 'sequence', 'pointers')
    }
    image = first_sample['images'].squeeze()
    word_bboxes = first_sample['sequence']
    target_pointers = first_sample['pointers']

    assert len(target_pointers) == pred_pointers.shape[0]
    print('Targets: {}, Preds: {}'.format(target_pointers.flatten(), pred_pointers.flatten()))

    # Same image and boxes twice: once ordered by the ground truth, once by the prediction
    for title, sequence in (("Target", target_pointers), ("Predicted", pred_pointers)):
        print(title)
        plot_word_bboxes_ponters(image, word_bboxes, sequence, figsize=figsize)
        plt.show()

def get_normalized_loss_func(pointers):
    """
    Calculates loss weights based on the numbers in "pointers" and returns a loss function initialized with those weights.

    Every class gets a weight proportional to the inverse of its number of occurrences; the
    weights are normalised to sum to 1.

    :param pointers: integer array of target pointer indices, any shape. Negative entries (the
        -100 padding added by collate_fn) are ignored.
    :return: a ``torch.nn.CrossEntropyLoss`` with one weight per index ``0..max(pointers)``.
    """
    pointers = np.asarray(pointers).reshape(-1)
    # np.bincount rejects negative values, so the padding has to be removed first
    pointers = pointers[pointers >= 0]

    counts = np.bincount(pointers).astype(np.float64)
    # Classes that never occur keep weight 0 instead of becoming inf (and NaN once normalised)
    weights = np.zeros_like(counts)
    seen = counts > 0
    weights[seen] = 1. / counts[seen]
    weights /= weights.sum()

    weight = Variable(torch.from_numpy(weights.astype(np.float32)))
    if USE_CUDA:  # same guard as predict_and_eval; plain .cuda() fails on CPU-only machines
        weight = weight.cuda(DEVICE)
    loss_func = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
    return loss_func

In [None]:
# NOTE: the training loop validates every `val_every` batches (this module-level value) but
# checkpoints every `params.save_every` batches - the `save_every` defined here is not read by it.
save_every = 10000
val_every = 100

In [None]:
# Progress counters kept in their own cell so that the training cell can be interrupted and
# re-run without starting again from epoch 0 / batch 0.
epoch, i_batch = 0, 0

In [None]:
def _make_weighted_loss(batch):
    """
    Build the class-weighted cross-entropy for one collated batch.

    The first ``max(pointer) + 1`` entries of the global weight vector ``w`` are used as class
    weights; targets equal to -100 (padding) are ignored.
    """
    n_classes = batch['pointers'].data.cpu().numpy().flatten().max() + 1
    weight = Variable(torch.from_numpy(w[:n_classes].astype(np.float32)))
    loss = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
    if USE_CUDA:  # moving the module also moves its weight buffer
        loss = loss.cuda(DEVICE)
    return loss


# Number of validation mini-batches looked at per validation round
n_val_batches_max = 10

while epoch < params.n_epochs:
    train_data_iter = iter(dataloader_train)
    while i_batch < len(dataloader_train):
        i_batch += 1
        train_batch = next(train_data_iter)
        iter_cntr = epoch * len(dataloader_train) + i_batch  # The overall iteration number across epochs

        # This could happen because of random cropping - a better cropping strategy would help
        if train_batch is None or len(train_batch['pointers']) == 0:
            continue

        # Forward
        loss_func = _make_weighted_loss(train_batch)
        pointers, train_loss = predict_and_eval(model, train_batch, loss_func)

        # Backprop
        model_optim.zero_grad()
        train_loss.backward()
        model_optim.step()

        writer.add_scalar('train.loss', train_loss.data.cpu().numpy(), iter_cntr)

        # Save
        if i_batch % params.save_every == 0:
            torch.save(model.state_dict(), os.path.join(weights_folder, '{}_{}.pt'.format(epoch, i_batch)))

        # Validation
        if i_batch % val_every == 0:
            total_val_loss, n_val_batches = 0, 0
            # No gradients are needed here; this avoids building autograd graphs
            with torch.no_grad():
                for jx, val_batch in enumerate(dataloader_val):
                    # Check the limit first: with the skip below in front of it, an empty batch
                    # at position 10 made the loop run over the whole validation set.
                    if jx == n_val_batches_max:
                        break

                    if val_batch is None or len(val_batch['pointers']) == 0:
                        continue

                    loss_func = _make_weighted_loss(val_batch)
                    pointers, val_loss = predict_and_eval(model, val_batch, loss_func)
                    total_val_loss += val_loss.data.cpu().numpy()
                    n_val_batches += 1

                    # plot few
                    if jx < 4:
                        pred_pointers = pointers.argmax(dim=-1).data.cpu().numpy()[0]
                        visualize(val_batch, pred_pointers)

            # Average over the batches that were really evaluated (empty ones are skipped)
            if n_val_batches > 0:
                writer.add_scalar('val.loss', total_val_loss / n_val_batches, iter_cntr)

    # Start the next epoch from its first batch; without this reset every epoch after the first
    # one skipped the inner loop entirely.
    i_batch = 0
    epoch += 1