# Visual Backbone
This notebook explores a `ViT` transformer-based `visual-backbone` architecture (a natural choice given the sequential nature of documents) and compares it with the [CNN-based backbone](Visual-Backbone-CNN.ipynb).

For the sake of exploration we build everything from scratch and run a small empirical study of the key design choices to settle on default implementation details.

* [Dataset and Dataloader](#data)
* [ViT transformer model](#blocks)
    * [Blocks](#blocks)
    * [Encoder](#encoder)
    * [Decoder](#decoder)
    * [UNet](#model)    
* [Comparative training and evaluation](#run)
    * [Define models](#1)
    * [Define optimization](#2)
    * [Define validation metrics](#3)
    * [Run parallel training with different configurations](#4)
    * [Evaluate results](#5)
        * [Evaluate embeddings](#embeddings)


In [None]:
import os
import re
import torch
import numpy as np
import pandas as pd

from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import colormaps
from IPython.display import SVG
from pathlib import Path
from einops import rearrange, reduce, repeat
from sklearn.metrics import silhouette_score

from torch import nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.cuda.amp import GradScaler
from torchmetrics import F1Score, JaccardIndex
from torchsummary import summary

In [None]:
# load local notebook-utils
from scripts import render
from scripts.dataset import *
from scripts.trainer import *
from scripts.backbone import *

In [None]:
# optional: verbose torch.compile diagnostics
#torch._dynamo.config.verbose = True

# release cached GPU memory from previous runs, then report which device is in use
torch.cuda.empty_cache()
if DEVICE == 'cuda':
    print('GPU')
else:
    print('no GPU')

In [None]:
# images with semantic segmentation masks available (file names only).
# `Path.name` is used instead of splitting on '/', which only works with POSIX separators.
images = [mask.name for mask in Path(f'{ROOT}/data/masks').glob('*.png')]
len(images)

In [None]:
# side (in pixels) of the square input view, and the transformer embedding width
VIEW_SIZE, LATENT_DIM = 128, 512

<a name="data"></a>

## Dataset
This is the same [dataset we used for the CNN-based model](Visual-Backbone-CNN.ipynb#data). Our decoder handles `value` extraction and denoising rather than reconstruction. For the targets we generate single-channel binary masks. For the inputs we generate a set of random view-ports (center, rotation, zoom) from a noisy version of the page.

In [None]:
# use all images with masks; the commented alternative draws a random subset
# of 320 pages for quicker experiments
samples = images #np.random.choice(images, 320, replace=False)

In [None]:
# build one static batch (8 views) and preview its inputs, then its targets
batch = prep_batch(samples, RandomViewDataset, 8, VIEW_SIZE)
for preview in (show_inputs, show_targets):
    preview(batch)

# keep the first input/target pair for the shape sanity checks below
X, Y = batch[0]

<a name="blocks"></a>

## ViT transformer blocks
Transformers deal with sequences of tokens. `ViT` rearranges an image into a sequence of flattened patches and adds a learnable position embedding to a patch embedding before feeding it into a transformer-encoder.

In [None]:
# diagram of the ViT building blocks defined below
SVG('assets/uvit-blocks.svg')

In [None]:
class ViewToSequence(nn.Module):
    """Turn an image view into a sequence of patch embeddings (ViT patch embedding).

    The view is cut into non-overlapping ``patch_size`` x ``patch_size`` patches by a
    strided convolution. Each patch becomes one ``embed_size`` token (row-major order).
    Optional learned tokens are prepended, and a learned positional embedding is added.

    Args:
        view_size: side of the square input view in pixels.
        patch_size: side of a patch; the grid has ``(view_size // patch_size) ** 2`` cells.
        embed_size: token embedding width.
        semantic_dim: number of extra learned tokens prepended to the sequence (0 = none).
        channels: number of input image channels.
    """
    def __init__(self,
                 view_size: int,
                 patch_size: int,
                 embed_size: int,
                 semantic_dim: int = 0,
                 channels: int = 1):
        
        super(ViewToSequence, self).__init__()
        self.patch_size = patch_size
        self.projection = nn.Conv2d(channels, embed_size, kernel_size=patch_size, stride=patch_size)
        # conditional and other tokens
        self.tokens = nn.Parameter(torch.randn(1, semantic_dim, embed_size)) if semantic_dim > 0 else None
        # one position per prepended token plus one per patch
        self.positions = nn.Parameter(torch.randn((view_size // patch_size) ** 2 + semantic_dim, embed_size))
                
    def forward(self, x):
        """Map a batch of views ``(b, channels, H, W)`` to tokens ``(b, semantic_dim + patches, embed)``."""
        b = x.shape[0]
        # patch embedding, then 'b e h w -> b (h w) e'
        x = self.projection(x).flatten(2).transpose(1, 2)
        if self.tokens is not None:
            # '() n e -> b n e' (a broadcast view is enough, torch.cat copies anyway);
            # the tokens go in front of the patch tokens
            x = torch.cat([self.tokens.expand(b, -1, -1), x], dim=1)
        # add the positional embedding out-of-place rather than mutating the sequence
        return x + self.positions

# sanity check: the view should become (VIEW_SIZE // 4) ** 2 patch tokens of width LATENT_DIM
seq = ViewToSequence(VIEW_SIZE, 4, LATENT_DIM)(X)
seq.shape

In [None]:
class SequenceToView(nn.Module):
    """Project a token sequence back to pixels and tile the patches into a square view.

    Inverse of ``ViewToSequence``. The first ``semantic_dim`` tokens are discarded. Every
    remaining token is normalized and projected to a ``patch_size`` x ``patch_size`` x
    ``channels`` patch. The patches are laid out on a square grid in row-major order.

    Args:
        view_size: side of the output view (kept for interface symmetry; the grid size
            is derived from the number of tokens).
        patch_size: side of a patch.
        embed_size: token embedding width.
        semantic_dim: number of leading non-patch tokens to skip.
        channels: number of output channels.
    """
    def __init__(self,
                 view_size: int,
                 patch_size: int,
                 embed_size: int,
                 semantic_dim: int = 0,
                 channels: int = 1):
        
        super(SequenceToView, self).__init__()
        self.patch_size = patch_size
        self.semantic_dim = semantic_dim
        self.projection = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, patch_size ** 2 * channels, bias=True))
        # NOTE: a 3x3 Conv2d(channels, channels, padding=1) after un-patching was tried
        # to suppress patch-boundary artifacts; it is currently disabled.
        
    def forward(self, x):
        """Map tokens ``(b, semantic_dim + d*d, embed)`` to a view ``(b, channels, d*p, d*p)``.

        Raises:
            ValueError: if the number of patch tokens is not a perfect square.
        """
        # skip the prepended tokens first: LayerNorm and Linear act per token,
        # so projecting only the patch tokens gives the same result with less work
        x = self.projection(x[:, self.semantic_dim:, :])
        b, n, _ = x.shape
        d, p = round(n ** 0.5), self.patch_size
        if d * d != n:
            raise ValueError(f'expected a square grid of patch tokens, got {n} tokens')
        # 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)'
        x = x.reshape(b, d, d, p, p, -1)
        return x.permute(0, 5, 1, 3, 2, 4).reshape(b, -1, d * p, d * p)
    
# sanity check: tokens back to a view; expected (batch, 1, VIEW_SIZE, VIEW_SIZE) for a square token grid
SequenceToView(VIEW_SIZE, 4, LATENT_DIM, channels=1)(seq).shape

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from one fused linear layer. Attention logits are
    divided by ``sqrt(head_dim)`` *before* the softmax, so every query's weights sum
    to 1. The heads are concatenated and passed through ``Linear + ReLU``.

    NOTE: earlier revisions scaled *after* the softmax by ``sqrt(embed_size)``.
    Checkpoints trained with that formulation will behave differently.

    Args:
        embed_size: token embedding width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``embed_size`` is not divisible by ``num_heads``.
    """
    def __init__(self,
                 embed_size: int,
                 num_heads: int = 4):
        
        super(Attention, self).__init__()
        if embed_size % num_heads != 0:
            raise ValueError(f'embed_size ({embed_size}) must be divisible by num_heads ({num_heads})')
        self.embed_size = embed_size
        self.num_heads = num_heads
        # standard scaled dot-product attention: divide logits by sqrt(head_dim)
        self.scaling = (embed_size // num_heads) ** 0.5
        # queries, keys, values in one matrix
        self.qkv = nn.Linear(embed_size, embed_size * 3, bias=False)
        self.projection = nn.Sequential(
            nn.Linear(embed_size, embed_size),
            nn.ReLU()) # added to eliminate distractions
        
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Attend over ``x`` of shape ``(b, n, embed_size)``; returns the same shape.

        ``mask`` is a boolean tensor where True marks key positions a query may attend
        to. It must broadcast to ``(b, heads, n, n)`` (e.g. an ``(n, n)`` matrix).
        """
        b, n, _ = x.shape
        head_dim = self.embed_size // self.num_heads
        # 'b n (h d qkv) -> qkv b h n d': the fused features are laid out as
        # (head, head_dim, qkv), i.e. the q/k/v selector is the fastest-varying axis
        qkv = self.qkv(x).reshape(b, n, self.num_heads, head_dim, 3)
        q, k, v = qkv.permute(4, 0, 2, 1, 3).unbind(0)
        # scaled logits: batch, num_heads, query_len, key_len
        energy = torch.einsum('bhqd, bhkd -> bhqk', q, k) / self.scaling
        if mask is not None:
            # masked_fill is out-of-place; use the logits' own dtype so fp16/AMP does not overflow
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)
        att = torch.softmax(energy, dim=-1)
        # weighted sum of values over the key axis
        out = torch.einsum('bhal, bhlv -> bhav', att, v)
        # 'b h n d -> b n (h d)'
        out = out.transpose(1, 2).reshape(b, n, self.embed_size)
        return self.projection(out)
    
#seq = ViewToSequence(VIEW_SIZE, 4, LATENT_DIM)(X)
#Attention(LATENT_DIM)(seq).shape

In [None]:
class MLP(nn.Sequential):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Widens every token to ``expansion * emb_size`` features and projects it back to
    ``emb_size``.
    """
    def __init__(self, emb_size: int, expansion: int = 4):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)


In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``.

    With ``bridge=True`` the block first fuses a pass-through sequence with the input.
    The two are concatenated on the feature axis and a linear layer maps them back to
    ``embed_size``.
    """
    def __init__(self,
                 embed_size: int,
                 bridge: bool,
                 expansion: int = 4):
        super().__init__()
        self.attn = nn.Sequential(nn.LayerNorm(embed_size), Attention(embed_size))
        self.mlp = nn.Sequential(nn.LayerNorm(embed_size), MLP(embed_size, expansion=expansion))
        if bridge:
            self.merge = nn.Linear(2 * embed_size, embed_size)
        else:
            self.merge = None

    def forward(self, x, pass_through=None):
        if self.merge is not None:
            # skip connection: (pass_through | x) on the feature axis -> embed_size
            x = self.merge(torch.cat((pass_through, x), dim=2))
        hidden = x + self.attn(x)
        return hidden + self.mlp(hidden)
        
#TransformerBlock(LATENT_DIM, True)(seq, seq).shape

<a name="encoder"></a>

### Encoder
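`Encoder` converts the image into a sequence of patch embeddings and returns the output of every transformer block; the last one is the bottleneck, from which the embedding is derived.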

In [None]:
class TransformerEncoder(nn.Module):
    """ViT encoder: patch-embed a view, then run ``depth`` transformer blocks.

    ``forward`` returns a list with the output of every block, from the first to the
    last. The last element is the bottleneck.
    """
    def __init__(self,
                 view_size: int,
                 patch_size: int,
                 embed_size: int,
                 depth: int,
                 expansion: int = 4):
        super().__init__()
        self.sequence = ViewToSequence(view_size, patch_size, embed_size)
        # plain (non-bridged) blocks
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_size, False, expansion) for _ in range(depth))
        # geometry kept so a matching decoder can be built (see UNet)
        self.depth, self.view_size = depth, view_size
        self.patch_size, self.embed_size = patch_size, embed_size

    def forward(self, x):
        hidden = self.sequence(x)
        per_block = []
        for layer in self.blocks:
            hidden = layer(hidden)
            per_block.append(hidden)
        return per_block
    
#for x in TransformerEncoder(VIEW_SIZE, 4, LATENT_DIM, 4)(X): print(x.shape)

To produce embeddings we apply `mean` reduction at the `bottleneck` output.

In [None]:
class MeanReduce(nn.Module):
    """Average a token sequence over dim 1 (the token axis for ``(b, n, e)`` input)."""
    def forward(self, x):
        return torch.mean(x, axis=1)
    
class VisualEncoder(nn.Module):
    """Frozen wrapper that turns an encoder's bottleneck output into embeddings.

    Args:
        backbone: encoder whose ``forward`` returns a list of per-level outputs; only
            the last one (the bottleneck) is used. Its parameters are frozen in place
            (``requires_grad = False``), which also affects the caller's module.
        reduce: module applied to the bottleneck output (e.g. ``MeanReduce``).
            ``None`` (the default) means identity. A fresh ``nn.Identity`` is created per
            instance instead of sharing one default module object.
    """
    def __init__(self, backbone: nn.Module, reduce: nn.Module = None):
        super().__init__()
        self.encoder = backbone
        # freeze weights
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.reduce = reduce if reduce is not None else nn.Identity()
        
    def forward(self, x):
        # our unet-encoder returns list of outputs from all the levels --
        # here we only need the bottleneck
        x = self.encoder(x).pop()
        # NOTE: squeeze() removes every size-1 axis, including the batch axis when b == 1
        return self.reduce(x).squeeze()
    
# frozen encoder (same as we used with CNN, but different reduce)
#VisualEncoder(TransformerEncoder(VIEW_SIZE, 4, LATENT_DIM, 4), MeanReduce()).to(DEVICE)(X.to(DEVICE)).shape

<a name="decoder"></a>

### Decoder
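`Decoder` takes the encoder outputs and reconstructs an image. It starts from the bottleneck; each block can merge in the matching encoder output (the "bridge"), and the final sequence is projected back to pixels.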

In [None]:
class TransformerDecoder(nn.Module):
    """Transformer decoder that turns a list of encoder outputs back into a view.

    Args:
        view_size: side of the reconstructed square view (passed to ``SequenceToView``).
        patch_size: patch side used by the matching encoder.
        embed_size: token embedding width.
        depth: number of decoder blocks; ``forward`` expects ``depth + 1`` encoder outputs.
        channels: number of output channels (``UNet`` passes one per class).
        bridge: if True, every block merges the matching encoder output (skip connection).
        expansion: MLP width multiplier of every block.
    """
    def __init__(self,
                 view_size: int,
                 patch_size: int,
                 embed_size: int,
                 depth: int,
                 channels: int = 1,
                 bridge: bool = True,
                 expansion: int = 4):
        
        super(TransformerDecoder, self).__init__()
        # up-blocks
        self.blocks = nn.ModuleList([TransformerBlock(embed_size, bridge, expansion) for _ in range(depth)])
        self.unpatch = SequenceToView(view_size, patch_size, embed_size, channels=channels)
        
    def forward(self, outputs):
        """Decode encoder outputs ordered as ``TransformerEncoder`` returns them.

        The last output (the bottleneck) seeds the stream. Each block then receives the
        next earlier output as its pass-through. The caller's list is left unmodified.

        Raises:
            ValueError: if ``len(outputs) != depth + 1``.
        """
        # work on a copy: popping from the caller's list would destroy the encoder outputs
        pending = list(outputs)
        if len(pending) != len(self.blocks) + 1:
            raise ValueError(
                f'expected {len(self.blocks) + 1} encoder outputs, got {len(pending)}')
        x = pending.pop()
        for block in self.blocks:
            x = block(x, pending.pop())
        return self.unpatch(x)

#encoded = TransformerEncoder(VIEW_SIZE, 4, LATENT_DIM, 4)(X)
#TransformerDecoder(VIEW_SIZE, 4, LATENT_DIM, 2, channels=2)(encoded).shape

<a name="model"></a>

### Encoder + Decoder
The model takes an encoder (possibly pretrained) and attaches a matching decoder.

In [None]:
class UNet(nn.Sequential):
    """Encoder -> matching transformer decoder -> softmax over the channel axis.

    The decoder mirrors the encoder geometry with one block fewer than the encoder.
    It emits ``output_dim`` channels; ``bridge`` toggles the encoder-to-decoder skip
    connections.
    """
    def __init__(self, encoder: TransformerEncoder, output_dim: int = 1, bridge: bool = True):
        decoder = TransformerDecoder(
            view_size=encoder.view_size,
            patch_size=encoder.patch_size,
            embed_size=encoder.embed_size,
            depth=encoder.depth - 1,
            channels=output_dim,
            bridge=bridge,
        )
        head = nn.Softmax(dim=1)
        super().__init__(encoder, decoder, head)
    
#encoder = TransformerEncoder(VIEW_SIZE, 4, LATENT_DIM, 4)
#summary(UNet(encoder, output_dim=2, bridge=True).to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))

In [None]:
# let's compare no-bridge ('') vs. bridge ('B', decoder merges encoder outputs)
arc = {'': False, 'B': True}
tags = [*arc]

<a name="run"></a>

## Comparative training and evaluation

In [None]:
# dataset class handed to the trainer (random view-ports per page, see the Dataset section)
dataset = RandomViewDataset

In [None]:
# 95% / 5% page-level split
train_samples = np.random.choice(samples, int(len(samples) * 0.95), replace=False)
# Keep held-out pages in their original order. Iterating a set of strings depends on
# per-process hash randomization, which made the test order differ between runs.
train_set = set(train_samples)
test_samples = [s for s in samples if s not in train_set]
len(train_samples), len(test_samples)

With our chosen scenario, the actual train/test dataset sizes depend on `batch_size`: we generate one batch from each sample page, so `size` = `num_samples` × `batch_size`.

<a name="1"></a>

#### 1. Define models

In [None]:
# model geometry and task
patch_size, depth = 4, 4
num_classes = 2

# one fresh encoder per configuration; each UNet wraps (and trains) its own encoder
# NOTE(review): `models` shadows the `torchvision.models` import at the top
encoders, models = [], []
for tag in tags:
    encoders.append(TransformerEncoder(VIEW_SIZE, patch_size, LATENT_DIM, depth))
    models.append(UNet(encoders[-1], output_dim=num_classes, bridge=arc[tag]).to(DEVICE))
    # to continue training from saved weights, uncomment:
    #models[-1].load_state_dict(torch.load(f'./models/visual-unet-ViT-{patch_size}-{depth}-{tag}.pt'))

<a name="2"></a>

#### 2. Define optimization
We use the same `DiceLoss` as for the CNN-based model; it handles class imbalance internally.

In [None]:
learning_rate = 1e-6
# one loss and one optimizer per model so the runs stay independent
criterions = [DiceLoss(num_classes).to(DEVICE) for _ in models]
optimizers = [AdamW(m.parameters(), lr=learning_rate) for m in models]

<a name="3"></a>

#### 3. Define evaluation metrics

In [None]:
# separate metric instances per model so no (presumably accumulated) metric state is shared
metrics = [
    {
        'f1-score': F1Score(task='multiclass', num_classes=num_classes).to(DEVICE),
        'jaccard': JaccardIndex(task='multiclass', num_classes=num_classes).to(DEVICE),
    }
    for _ in models
]

<a name="4"></a>

#### 4. Run training with validation
We train all the models side-by-side on the same data batches for comparison.

In [None]:
# training schedule passed to trainer.run
batch_size, num_epochs, validation_steps = 16, 1, 3

In [None]:
# train all models side-by-side on the same data batches, validating along the way
trainer = MultiTrainer(dataset, models, VIEW_SIZE, criterions, optimizers, metrics, tags=tags)
results = trainer.run(train_samples, test_samples, batch_size, num_epochs, validation_steps)

<a name="5"></a>

#### 5. Evaluate results

In [None]:
# comparison plot of the trained models (MultiTrainer helper)
trainer.plot_compare()

In [None]:
# presumably the per-step training/validation history recorded by the trainer
trainer.plot_history()

In [None]:
# let's see some examples side-by-side: the input view, then each model's predicted mask
sample = np.random.choice(samples)
loader = DataLoader(RandomViewDataset(sample, VIEW_SIZE, max_samples=8), batch_size=8)
for model in models:
    model.eval()
with torch.no_grad():
    for X, Y in loader:
        # per-pixel class index predicted by every model
        P = [torch.argmax(model(X.to(DEVICE)), axis=1).cpu() for model in models]
        for i in range(X.shape[0]):
            fig, ax = plt.subplots(1, len(models) + 1, figsize=(8, 8))
            panels = [X[i,:]] + [pred[i,:] for pred in P]
            titles = ['Input'] + [f'{tag} model' for tag in tags[:len(models)]]
            for axis, image, title in zip(ax, panels, titles):
                axis.imshow(image.squeeze().numpy(), 'gray')
                axis.axis('off')
                # column titles only on the first row of figures
                if i == 0:
                    axis.set_title(title, fontsize=10)
        plt.show()

<a name="embeddings"></a>

To evaluate the embeddings produced by the trained encoders, we can use the basic page types identified in the [baselines exploration notebook](Doc-Classification-Baselines.ipynb#labels) -- our models should be able to tell them apart. Let's look at how well these groups are separated in the embedding space.

In [None]:
# page types from the baselines notebook; label ids are assumed to index into this list
classes = ['mixed','plain-text','form-table','non-doc']

# labeled sample pages (the `source` and `label` columns are used below)
labeled = pd.read_csv('./data/labeled-sample.csv')
# number of pages per label
labeled.groupby('label').size()

In [None]:
# 25% non-docs for contrast
classes = ['mixed','plain-text','form-table','non-doc']
dataset = TopViewDataset(VIEW_SIZE, labeled['source'], labeled['label'], contrast=0.25)
profiles, scores = [], []
# per-model result records (same object as trainer.results); embedding scores are added to them
results = trainer.results
# fixed class -> color mapping: j / 3 spans [0, 1] for the 4 class ids
cmap = colormaps['gist_rainbow']
for index, (name, encoder) in enumerate(zip(tags, encoders)):
    name = 'Base' if name == '' else name
    # use model encoder to get mean-reduced embeddings
    embeddings, labels = get_embeddings(dataset, encoder, reduce=MeanReduce())
    P, pca_ratios, L, _ = get_profile(embeddings, labels)
    label_ids = np.asarray(labels)
    # silhouette on the first 3 PCA components and on the full LDA projection
    scores.append(silhouette_score(P[:,:3], labels, metric='euclidean'))
    score = silhouette_score(L, labels, metric='euclidean')
    results[index]['contrast-score'] = score
    profiles.append(pca_ratios)
    # classes aggregated: per-class medians in LDA space
    centers = np.array([np.median(L[label_ids == k], axis=0) for k in range(len(classes))])
    fig, ax = plt.subplots(1, 2, figsize=(7, 3.2))
    for j in range(len(classes)):
        members = label_ids == j
        ax[0].scatter(P[members, 0], P[members, 1], s=3, color=cmap(j/3), alpha=0.3)
        ax[1].scatter(L[members, 0], L[members, 1], s=3, color=cmap(j/3), alpha=0.3)
    for j in range(len(classes)):
        ax[1].scatter(centers[j,0], centers[j,1], color=cmap(j/3),
                      s=75, marker='pos^'[j], edgecolor='black', label=classes[j])
    for j, (t, s) in enumerate([('PCA', scores[-1]),('LDA', score)]):
        ax[j].set_xticks([])
        ax[j].set_yticks([])
        ax[j].set_title(f'{t}  silhouette-score: {s:.4f}', fontsize=10)
    ax[1].legend(title=f'{name} model', fontsize=8, bbox_to_anchor=(1, 1), frameon=False)
    plt.show()
    
# compare all
plot_profiles(tags, profiles, scores)

In [None]:
# documents only
classes = ['mixed','plain-text','form-table']
dataset = TopViewDataset(VIEW_SIZE, labeled['source'], labeled['label'], contrast=0)
profiles, scores = [], []
# defined here as well so this cell does not depend on the previous one having run
results = trainer.results
# same class -> color mapping as the contrast plots (j / 3), so classes keep their colors
cmap = colormaps['gist_rainbow']
for index, (name, encoder) in enumerate(zip(tags, encoders)):
    name = 'Base' if name == '' else name
    # use model encoder to get mean-reduced embeddings
    embeddings, labels = get_embeddings(dataset, encoder, reduce=MeanReduce())
    P, pca_ratios, L, _ = get_profile(embeddings, labels)
    label_ids = np.asarray(labels)
    # silhouette on the first 3 PCA components and on the full LDA projection
    scores.append(silhouette_score(P[:,:3], labels, metric='euclidean'))
    score = silhouette_score(L, labels, metric='euclidean')
    results[index]['cluster-score'] = score
    profiles.append(pca_ratios)
    # classes: per-class medians in LDA space
    centers = np.array([np.median(L[label_ids == k], axis=0) for k in range(len(classes))])
    fig, ax = plt.subplots(1, 2, figsize=(7, 3.2))
    for j in range(len(classes)):
        members = label_ids == j
        ax[0].scatter(P[members, 0], P[members, 1], s=3, color=cmap(j/3), alpha=0.3)
        ax[1].scatter(L[members, 0], L[members, 1], s=3, color=cmap(j/3), alpha=0.3)
    for j in range(len(classes)):
        ax[1].scatter(centers[j,0], centers[j,1], color=cmap(j/3),
                      s=75, marker='pos^'[j], edgecolor='black', label=classes[j])
    for j, (t, s) in enumerate([('PCA', scores[-1]),('LDA', score)]):
        ax[j].set_xticks([])
        ax[j].set_yticks([])
        ax[j].set_title(f'{t}  silhouette-score: {s:.4f}', fontsize=10)
    ax[1].legend(title=f'{name} model', fontsize=8, bbox_to_anchor=(1, 1), frameon=False)
    plt.show()

# compare all
plot_profiles(tags, profiles, scores)

In [None]:
results = pd.DataFrame.from_dict(results)
results['model'] = tags

# Persist trained weights, metrics and training history. These lines were indented
# without an enclosing statement (IndentationError). They are now guarded explicitly so
# re-running the notebook does not overwrite saved checkpoints; set to True to save.
SAVE_MODELS = False
if SAVE_MODELS:
    os.makedirs('./models', exist_ok=True)
    for tag, model in zip(tags, models):
        torch.save(model.state_dict(), f'./models/visual-unet-ViT-{patch_size}-{depth}-{tag}.pt')

    results.to_csv(f'./models/visual-unet-ViT-{patch_size}-{depth}.csv')
    trainer.save(f'./models/visual-unet-ViT-{patch_size}-{depth}')

    # save base-model trained encoder
    torch.save(encoders[0].state_dict(), './models/visual-backbone-ViT.pt')

# last expression of the cell, so the styled table is rendered as the output
results.set_index('model').style.format('{:.4f}').background_gradient('Greens')