# Visual Backbone
This notebook explores `CNN`-based `visual-backbone` architectures. The [alternative](Visual-Backbone-ViT.ipynb) is a `ViT` transformer, motivated by the sequential nature of documents.

With downstream segmentation and classification tasks in mind we are going to use `UNet`, which won't produce the most useful `latent space` because of its `skip-connections` (bridges) and the lack of compression (in the case of a shallow network), so first we explore different architectures. Ways to make the latent space more useful are explored in the [VAE training notebook](Visual-Backbone-VAE.ipynb).

For the sake of exploration we build everything from scratch and run an empirical study of the key elements to determine the default implementation details.

* [Dataset and Dataloader](#data)
* [UNet model](#unet)
    * [Attention](#attn)
    * [Encoder](#encoder)
    * [Decoder](#decoder)
    * [Encoder-Decoder](#model)    
* [Comparative training and evaluation](#run)
    * [Define models](#1)
    * [Define optimization](#2)
    * [Define validation metrics](#3)
    * [Run parallel training](#4)
    * [Evaluate results](#5)
    * [Evaluate embeddings](#embeddings)
    
#### Observations summary
* all models can support both non-doc detection as `anomaly detection` and as a `separate class`
* all models have sufficient info in the latent space to support our [baseline classification task](Visual-Classification-Baseline.ipynb)
* `residual` and `skip` connections speed up training significantly
* `self-attention` and `bridge-attention` help with denoising
* `self-attention` helps with outliers while residual connections make it worse
* skip-connections did not cause the latent space to lose important information (by bypassing the bottleneck): in our case [we've got better clusters](#results)

In [None]:
import os
import re
import torch
import numpy as np
import pandas as pd

from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import colormaps
from IPython.display import SVG
from pathlib import Path
from einops import rearrange

from torch import nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.optim import SGD, AdamW
from torch.cuda.amp import GradScaler
from torchmetrics import F1Score, JaccardIndex
from torchsummary import summary

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler

In [None]:
# load local notebook-utils
from scripts import render
from scripts.dataset import *
from scripts.trainer import *

In [None]:
#torch._dynamo.config.verbose = True
torch.cuda.empty_cache()
print('GPU' if DEVICE == 'cuda' else 'no GPU')

In [None]:
# images with semantic segmentation masks available
images = [x.name for x in Path(f'{ROOT}/data/masks').glob('*.png')]
len(images)

In [None]:
VIEW_SIZE = 128

<a name="data"></a>

## Dataset
Document views are somewhat discrete, so for simplicity we formulate the decoder's task as classification (each pixel is either signal or void): the decoder handles value extraction / denoising rather than reconstruction. For the inputs we generate a set of random viewports (center, rotation, zoom) from a single noisy version of a page; for the targets we generate single-channel binary masks.
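
As a minimal sketch of how such a binary target could be derived from a clean page view: the toy pixel values below are made up for illustration, and the threshold of `0.25` mirrors the default used by `RandomViewDataset`.

```python
import numpy as np

# toy 4x4 grayscale view: 255 = white paper, lower values = ink
view = np.array([[255, 255, 250, 255],
                 [255,  40,  30, 255],
                 [255, 200, 180, 255],
                 [255, 255, 255, 255]], dtype=float)

signal = 255.0 - view                      # invert: ink becomes high
low, high = signal.min(), signal.max()
signal = (signal - low) / (high - low)     # map to [0, 1]

threshold = 0.25                           # illustrative, same as the dataset default
mask = (signal >= threshold).astype(int)   # 1 = signal, 0 = void
print(mask)
```

Faint marks that stay below the threshold after normalization end up in the void class, which is what makes the task closer to denoising than to reconstruction.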

We need some out-of-the-class samples for the contrast:
* if `contrast` is used: non-docs are treated as a `class`
* with no `contrast` in training: non-docs are treated as an `anomaly`

In [None]:
# docs samples
samples = images #np.random.choice(images, 160, replace=False)

In [None]:
# fraction of non-docs shown
#CONTRAST = 0.1
CONTRAST = 0

In [None]:
# random images (non-docs; and no text anywhere) for out-of-class examples
negatives = [str(x) for x in Path(f'{ROOT}/data/unsplash').glob('*.jpg')]

def make_negative_sample(view_size):
    """
    generate random non-document view
    """
    sample = 255 - np.array(ImageOps.grayscale(Image.open(np.random.choice(negatives))))
    nav = render.AgentView(sample, view_size, bias=np.random.randint(100))
    center = (np.array(sample.shape) * (1 - np.random.rand(2))).astype(int)
    rotation = np.random.randint(0, 360)
    zoom = np.random.rand() * 2 - 2
    return nav.set_state(center, rotation, zoom)

plt.imshow(make_negative_sample(VIEW_SIZE), 'gray')
plt.axis('off')
plt.show()

The challenge here: the presence of straight lines and grids in some non-doc images may confuse our model.

To synthesize data on the go our data loader has to load the image from disk, create a noisy version of it, and then render some random viewports. That is slow, so here are some trade-off options:
* generate a whole dataset prior to the training (dataloader will be handling only loading)
* generate a set of noisy images (slowest part) prior to the training (dataloader will be handling viewports)
* dataloader will be handling everything, but for one source at a time: a batch of random viewports from one noisy version

For initial R&D we go with the first option which is slowest but gives us most flexibility. For the final stage we go with the second option (faster and easy to benchmark).

With our chosen scenario the actual train/test dataset sizes depend on the `batch_size`: the training/validation process runs through all the samples and generates a batch from each sample, so `size`=`num_samples`✕`batch_size`.
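
A quick worked example of that size arithmetic. The sample counts are hypothetical, chosen only to make the numbers easy to follow:

```python
# Hypothetical sample counts, for illustration only
num_train_samples = 190    # source pages used for training
num_test_samples = 10      # source pages held out for validation
batch_size = 16            # random viewports rendered per page

train_views = num_train_samples * batch_size
test_views = num_test_samples * batch_size
print(train_views, test_views)
```

Doubling `batch_size` doubles the number of views seen per pass without adding any new source pages.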

In [None]:
# common preprocessing
class NormalizeView:
    """
    map to [0,1] and put channels first
    """
    def __call__(self, X):
        low, high = np.min(X), np.max(X)
        X = (X - low).astype(float)
        if high > low:
            X /= (high - low)
        if len(X.shape) == 3:
            h, w, c = X.shape
            # permute moves the channel axis; view would only reinterpret memory
            return torch.Tensor(X).permute(2, 0, 1)
        return torch.Tensor(X).unsqueeze(0)
    

NormalizeView()(np.random.randint(20, 220, 5))
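
A note on "channels first": moving the channel axis means permuting axes, whereas reshaping keeps the memory order and mixes pixels from different channels. A small NumPy sketch of the difference, using a toy channels-last array (the same distinction applies to `permute` and `view` in PyTorch):

```python
import numpy as np

h, w, c = 2, 3, 2
X = np.arange(h * w * c).reshape(h, w, c)   # channels-last toy image
# by construction channel 0 holds the even values, channel 1 the odd ones

moved = np.transpose(X, (2, 0, 1))          # axes permuted to (c, h, w)
reshaped = X.reshape(c, h, w)               # same memory order, new shape

print(moved[0])      # even values only: the pixels of channel 0
print(reshaped[0])   # values 0..5: a mix of both channels
```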

In [None]:
class RandomViewDataset(Dataset):
    """
    use a single document noisy variation to create a batch of random view-ports
    make new data loader for each doc rather than reload resources for each view
    this scenario makes training very sensitive to the bad samples with bigger batches
    """
    def __init__(self, source: str, view_size: int, max_samples: int = 64, threshold: float = 0.25):
        self.view_size = view_size
        self.max_samples = max_samples
        self.threshold = threshold
        # load source image
        orig = np.array(ImageOps.grayscale(Image.open(f'{ROOT}/data/images/{source}')))
        view = make_noisy_sample(orig)
        # define renderers for all
        self.view = render.AgentView((view).astype(np.uint8), view_size, bias=np.random.randint(100))
        self.target = render.AgentView(255. - orig, view_size)
        # define image preprocessing
        self.transform = NormalizeView()

    def __len__(self):
        return self.max_samples
    
    def random_viewport(self):
        # pan: anywhere within the page-view bounding box
        center = (np.array(self.view.space.center) * (0.25 + np.random.rand() * 1.5)).astype(int)
        rotation = np.random.randint(0, 360)
        zoom = np.random.rand() * 4 - 3.5
        return center, rotation, zoom
    
    def __getitem__(self, idx):
        # once in a while we need a negative sample
        if np.random.rand() < CONTRAST:
            X = self.transform(make_negative_sample(self.view_size))
            Y = torch.Tensor(np.zeros((self.view_size, self.view_size))).long()
            return X, Y
        # generate random viewport
        center, rotation, zoom = self.random_viewport()
        std = 0
        while std < 10: # make sure there's something to see
            center, rotation, zoom = self.random_viewport()
            view = self.view.render(center, rotation, zoom)
            std = np.std(view)
        # render corresponding views
        X = self.transform(view)
        # initialize masks channels
        target = self.transform(self.target.render(center, rotation, zoom))
        # binarize the normalized target: a low threshold lets subtle lines pass
        Y = (target >= self.threshold).squeeze().long()
        return X, Y
    

In [None]:
sample = np.random.choice(samples)
# test loader
n = 8
loader = DataLoader(RandomViewDataset(sample, VIEW_SIZE, max_samples=n), batch_size=n, shuffle=False)
# show batch
for X, Y in loader:
    print(f'source: {sample}\nbatch:  X:{X.shape}  Y:{Y.shape}')
    for i in range(n):
        fig, ax = plt.subplots(1, 2, figsize=(5, 5))
        ax[0].imshow(X[i,:].squeeze(), 'gray')
        ax[0].axis('off')
        ax[1].imshow(Y[i,:].squeeze(), 'gray')
        ax[1].axis('off')
        if i == 0:
            ax[0].set_title('Input view')
            ax[1].set_title('Decoder task')
        plt.show()

<a name="unet"></a>


## UNet encoder-decoder model
We used a standard `ResNet`-encoder for [the baselines](Doc-Classification-Baselines.ipynb). For this experiment we build our own from scratch.

We use `GroupNorm` due to a small batch-size. Using `GELU` vs. `ReLU` noticeably stabilized the training in our case. Using `1x1 Conv` as a residual adapter should be explored. In this experiment we focus on `residual` and `skip`-connections (bridges), and two different types of `attention`.

In [None]:
SVG('assets/unet-blocks.svg')

<a name="attn"></a>

### Attention
`Self-Attention` highlights the important parts of the feature-map.
The output of the self-attention is converted to the shape of the input, normalized, and mapped between zero and one. The values are then multiplied by the output of the convolutional block in order to apply a weight to it before it progresses to the next level. We use this module interchangeably with `GELU` activation at the end of the block.

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, num_heads: int = 1, head_channels: int = None):
        super(SelfAttention, self).__init__()
        self.num_heads = num_heads
        hidden_dim = in_channels * num_heads if head_channels is None else head_channels * num_heads
        # query-key-value
        self.qkv = nn.Conv2d(in_channels, hidden_dim * 3, 1)
        self.output = nn.Sequential(nn.Conv2d(hidden_dim, out_channels, 1), nn.GroupNorm(1, out_channels))
        self.activation = nn.ReLU()

    def forward(self, x):
        h, w = x.shape[-2:]
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.num_heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn, bhen -> bhde', k, v)
        context = torch.einsum('bhde, bhdn -> bhen', context, q)
        context = rearrange(context, 'b heads c (h w) -> b (heads c) h w', heads=self.num_heads, w=w, h=h)
        return torch.tanh(self.activation(self.output(context)))
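
The two `einsum` contractions in `forward` build a small key-value context first and only then apply the queries, so the cost grows linearly with the number of pixels. The NumPy sketch below uses toy sizes (assumptions, not the notebook's settings) to check the shapes and to confirm that the result matches the explicit pixel-by-pixel attention map:

```python
import numpy as np

b, heads, c, n = 2, 1, 4, 16             # toy sizes; n = h * w pixels
rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, b, heads, c, n))

# softmax over the pixel axis, as in SelfAttention.forward
k = np.exp(k - k.max(axis=-1, keepdims=True))
k = k / k.sum(axis=-1, keepdims=True)

context = np.einsum('bhdn,bhen->bhde', k, v)      # (b, heads, c, c)
out = np.einsum('bhde,bhdn->bhen', context, q)    # (b, heads, c, n)

# same result via an explicit n x n attention map (quadratic in n)
attn = np.einsum('bhdn,bhdm->bhnm', q, k)         # (b, heads, n, n)
out_quadratic = np.einsum('bhnm,bhem->bhen', attn, v)

print(context.shape, out.shape, np.allclose(out, out_quadratic))
```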


As mentioned above, the document views are somewhat discrete and almost "flat" -- just 3 or 4 fixed resolution levels -- it seems reasonable to apply attention to the bridges (bridge-attention module) which helps highlight the important regions of the feature map by resolution level: then we have an attention-gate in the decoder with enabled pass-through connections.

In [None]:
class BridgeAttention(nn.Module):
    def __init__(self, channels: int):
        super(BridgeAttention, self).__init__()
        #self.wx = ConvNorm(channels, channels, 1, 0)
        #self.wg = ConvNorm(channels, channels, 1, 0)
        self.wx = nn.Conv2d(channels, channels, 1, padding=0)
        self.wg = nn.Conv2d(channels, channels, 1, padding=0)
        self.activation = nn.ReLU()
        #self.attn = ConvNorm(channels, channels, 1, padding=0)
        self.attn = nn.Conv2d(channels, channels, 1, padding=0)
        
    def forward(self, pass_through, gating_signal):
        x = self.wx(pass_through)
        g = self.wg(gating_signal)
        x = torch.tanh(self.attn(self.activation(x + g)))
        return pass_through * (x + 1)
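
To see what the gate in `BridgeAttention.forward` does numerically: `tanh` lies in (-1, 1), so the factor `x + 1` lies in (0, 2), which means a skip-connection value can be suppressed towards zero or amplified up to twofold. A toy NumPy sketch, where the pre-activation values are made up for illustration:

```python
import numpy as np

pass_through = np.array([1.0, 1.0, 1.0])
logits = np.array([-10.0, 0.0, 10.0])   # hypothetical pre-activation gate values
gate = np.tanh(logits) + 1.0            # stays within (0, 2)

print(pass_through * gate)              # suppressed, unchanged, nearly doubled
```

A zero gate input leaves the pass-through untouched, so an untrained gate starts close to a plain skip-connection.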


<a name="encoder"></a>

### Encoder
Our `unet-encoder` returns a list of outputs (feature-maps) from all resolution (depth) levels.

In [None]:
class ConvNorm(nn.Sequential):
    def __init__(self, in_channels: int, out_channels: int, padding: int = 1):
        super(ConvNorm, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
            nn.GroupNorm(1, out_channels))

        
class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, residual: bool = True, attn: bool = True):
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            ConvNorm(in_channels, out_channels),
            nn.GELU(),
            ConvNorm(out_channels, out_channels))
        self.residual = nn.Conv2d(in_channels, out_channels, 1, padding=0) if residual else None
        # attn is a bool flag: test its truth value (it is never None)
        self.attn = SelfAttention(in_channels, out_channels) if attn else None
        self.activation = nn.ReLU()

    def forward(self, x):
        attn = self.attn(x) if self.attn is not None else None
        output = self.block(x)
        if self.residual:
            output = output + self.residual(x)
        # either attention or activation
        return self.activation(output) if attn is None else output * attn


In [None]:
class DownsampleBlock(nn.Module):        
    def __init__(self, in_channels: int, out_channels: int, residual: bool = True, attn: bool = True):
        super(DownsampleBlock, self).__init__()
        self.block = ConvBlock(in_channels, out_channels, residual, attn)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        pass_through = self.block(x)
        output = self.pool(pass_through)
        return output, pass_through
    

In [None]:
class CNNEncoder(nn.Module):
    def __init__(self, in_channels: int = 1, channels: int = 64, depth: int = 4,
                       residual: bool = True, attn: bool = True):
        super(CNNEncoder, self).__init__()
        self.channels = channels
        self.depth = depth
        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(DownsampleBlock(in_channels, channels, residual, attn))
            in_channels, channels = channels, channels * 2
        self.residual = residual
        self.attn = attn
        
    def forward(self, x):
        outputs = []
        for block in self.blocks:
            x, pass_through = block(x)
            outputs.append(pass_through)
        return outputs


encoder = CNNEncoder(channels=64, depth=4, residual=True, attn=True)
summary(encoder.to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))

`VisualEncoder` converts the image into a vector (embedding). With a shallow network we get no compression and a high-dimensional feature map at the bottleneck. To reduce the dimensions we apply `AdaptiveAvgPool2d`.
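
To make "high-dimensional" concrete, the sketch below traces the feature-map shapes level by level for the settings used in this notebook (`VIEW_SIZE = 128`, `channels = 64`, `depth = 4`) and compares the flattened bottleneck with its globally pooled version. It is plain arithmetic, with no model involved:

```python
view_size, channels, depth = 128, 64, 4   # values used in this notebook

shapes = []
size, ch = view_size, channels
for _ in range(depth):
    shapes.append((ch, size, size))   # pass-through is taken before pooling
    size, ch = size // 2, ch * 2

bottleneck = shapes[-1]
flat_dim = bottleneck[0] * bottleneck[1] * bottleneck[2]
pooled_dim = bottleneck[0]            # global average pooling keeps channels only

print(shapes)
print(flat_dim, pooled_dim)
```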

In [None]:
class VisualEncoder(nn.Module):
    def __init__(self, backbone: nn.Module, reduce: nn.Module = None, frozen: bool = True):
        super(VisualEncoder, self).__init__()
        self.encoder = backbone
        if frozen: # freeze weights
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.reduce = reduce
        
    def forward(self, x):
        # our unet-encoder returns list of outputs from all the levels --
        # here we only need the bottleneck
        x = self.encoder(x).pop()
        if self.reduce is None:
            return torch.flatten(x, start_dim=1)
        return torch.flatten(self.reduce(x), start_dim=1)

#summary(VisualEncoder(encoder, nn.AdaptiveAvgPool2d((1, 1))).to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))
VisualEncoder(encoder, nn.AdaptiveAvgPool2d((1, 1))).to(DEVICE)(X.to(DEVICE)).shape

In [None]:
VisualEncoder(encoder, None).to(DEVICE)(X.to(DEVICE)).shape

In [None]:
# we can make a trainable adaptor: flatten the bottleneck map, then project it
reduce = nn.Sequential(nn.Flatten(), nn.Linear(131072, 512), nn.LayerNorm(512))
VisualEncoder(encoder, reduce).to(DEVICE)(X.to(DEVICE)).shape

<a name="decoder"></a>

### Decoder
Our `unet-decoder` takes the list of feature-maps (`unet-encoder` output) and outputs the set of segmentation masks for the input image. With `bridge=False` (no skip-connections) the decoder will only use the bottleneck level feature map.
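
The effect of `bridge` on the channel counts can be traced with plain arithmetic. The helper below is a hypothetical sketch (`decoder_plan` is not part of the notebook) that lists the input and output channels of the convolutional block in each upsample step, assuming the bottleneck has `64 * 2 ** 3 = 512` channels and the decoder has 3 levels:

```python
def decoder_plan(in_channels, depth, bridge):
    """List (block input channels, block output channels) per upsample block."""
    plan = []
    out_channels = in_channels // 2
    for _ in range(depth):
        upsampled = in_channels // 2    # the transposed convolution halves channels
        # concatenating the skip-connection doubles the channels again
        block_in = upsampled * 2 if bridge else upsampled
        plan.append((block_in, out_channels))
        in_channels, out_channels = out_channels, out_channels // 2
    return plan

bottleneck = 64 * 2 ** (4 - 1)
print(decoder_plan(bottleneck, 3, bridge=True))
print(decoder_plan(bottleneck, 3, bridge=False))
```

Without bridges each block sees only the upsampled bottleneck signal, so its input is half as wide.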

In [None]:
class UpsampleBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, residual: bool, attn: bool = True,
                 bridge: bool = True, bridge_attn: bool = True):
        super(UpsampleBlock, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        # the attention-gate only makes sense when the skip-connection is enabled
        self.bridge_attn = BridgeAttention(out_channels) if bridge and bridge_attn else None
        self.block = ConvBlock(in_channels if bridge else in_channels//2, out_channels, residual, attn)
        self.bridge = bridge

    def forward(self, x, pass_through=None):
        x = self.deconv(x)
        if self.bridge: # use skip-connection
            if self.bridge_attn: # apply attention-gate to skip-connection
                pass_through = self.bridge_attn(pass_through, x)
            x = torch.cat((pass_through, x), dim=1)
        return self.block(x)
    

In [None]:
class CNNDecoder(nn.Module):
    def __init__(self, in_channels: int, output_dim: int, depth: int, residual: bool, attn: bool,
                       bridge: bool, bridge_attn: bool):
        super(CNNDecoder, self).__init__()
        self.blocks = nn.ModuleList()
        out_channels = in_channels//2
        for _ in range(depth):
            self.blocks.append(UpsampleBlock(in_channels, out_channels, residual, attn, bridge, bridge_attn))
            in_channels, out_channels = out_channels, out_channels//2
        self.head = nn.Conv2d(in_channels, output_dim, 1, padding=0) # 1x1 convolution
        
    def forward(self, outputs):
        assert len(outputs) == len(self.blocks) + 1
        outputs = list(outputs)
        x = outputs.pop()
        for block in self.blocks:
            x = block(x, outputs.pop())
        return self.head(x)
    

<a name="model"></a>

### Encoder + Decoder
The model takes in an encoder (maybe pretrained) and attaches a matching decoder.

In [None]:
class UNet(nn.Sequential):
    def __init__(self, encoder: CNNEncoder, output_dim: int = 1, bridge: bool = True, bridge_attn: bool = True):
        # construct matching decoder
        in_channels = encoder.channels * (2 ** (encoder.depth - 1))
        depth = encoder.depth - 1
        decoder = CNNDecoder(in_channels, output_dim, depth,
                             encoder.residual, encoder.attn, bridge, bridge_attn)
        super(UNet, self).__init__(
            encoder,
            decoder,
            nn.Softmax(dim=1)) # norm channels

UNet(encoder, output_dim=3, bridge=False, bridge_attn=True).to(DEVICE)(X.to(DEVICE)).shape

Let's compare different configurations.

In [None]:
arc = { '':     {'residual':False, 'attn':False, 'bridge':False, 'bridge_attn':False },
       
        'R':    {'residual':True,  'attn':False, 'bridge':False, 'bridge_attn':False },
        'A':    {'residual':False, 'attn':True,  'bridge':False, 'bridge_attn':False }, 
        'B':    {'residual':False, 'attn':False, 'bridge':True,  'bridge_attn':False },
       
        'RA':   {'residual':True,  'attn':True,  'bridge':False, 'bridge_attn':False }, 
        'RB':   {'residual':True,  'attn':False, 'bridge':True,  'bridge_attn':False },
        'AB':   {'residual':False, 'attn':True,  'bridge':True,  'bridge_attn':False }, 
        'BA':   {'residual':False, 'attn':False, 'bridge':True,  'bridge_attn':True  },
       
        'RAB':  {'residual':True,  'attn':True,  'bridge':True,  'bridge_attn':False },
        'RBA':  {'residual':True,  'attn':False, 'bridge':True,  'bridge_attn':True  },
        'ABA':  {'residual':False, 'attn':True,  'bridge':True,  'bridge_attn':True  },
       
        'RABA': {'residual':True,  'attn':True,  'bridge':True,  'bridge_attn':True  }}

tags = list(arc.keys())

<a name="run"></a>
## Comparative training and evaluation

In [None]:
dataset = RandomViewDataset

In [None]:
train_samples = np.random.choice(samples, int(len(samples) * 0.95), replace=False)
test_samples = list(set(samples).difference(set(train_samples)))
len(train_samples), len(test_samples)

<a name="1"></a>

#### 1. Define models
We build several models with different features to compare.

In [None]:
channels = 64
depth = 4

num_classes = 2
encoders, models = [], []

for tag in tags:
    encoders.append(CNNEncoder(channels=channels, depth=depth,
                               residual=arc[tag]['residual'], attn=arc[tag]['attn']))
    models.append(UNet(encoders[-1], output_dim=num_classes,
                       bridge=arc[tag]['bridge'], bridge_attn=arc[tag]['bridge_attn']).to(DEVICE))
    
    # continue training with saved models
    #models[-1].load_state_dict(torch.load(f'./models/visual-unet-CNN-{channels}-{depth}-{tag}.pt'))

<a name="2"></a>

#### 2. Define optimization
`CrossEntropy` with weighted classes should work; however, `DiceLoss` should be more robust against imbalanced targets.
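
To illustrate the robustness claim, here is a NumPy sketch of a class-weighted Dice loss with inverse squared class-frequency weights, evaluated on a toy target with one signal pixel out of 16. The function name and the toy data are assumptions made for this example:

```python
import numpy as np

def dice_loss(probs, target, num_classes, eps=1e-9):
    """probs: (b, c, h, w) class probabilities; target: (b, h, w) integer labels."""
    onehot = np.eye(num_classes)[target]              # (b, h, w, c)
    probs = np.moveaxis(probs, 1, -1)                 # (b, h, w, c)
    # rare classes get larger weights
    weights = 1.0 / (onehot.sum(axis=(0, 1, 2)) ** 2 + eps)
    cross = (weights * (probs * onehot).sum(axis=(0, 1, 2))).sum()
    union = (weights * (probs + onehot).sum(axis=(0, 1, 2))).sum()
    return 1.0 - 2.0 * (cross + eps) / (union + eps)

# imbalanced toy target: 1 signal pixel out of 16
target = np.zeros((1, 4, 4), dtype=int)
target[0, 1, 2] = 1

perfect = np.moveaxis(np.eye(2)[target], -1, 1)       # exact one-hot prediction
all_void = np.zeros((1, 2, 4, 4))
all_void[:, 0] = 1.0                                  # predicts "void" everywhere

loss_perfect = dice_loss(perfect, target, 2)
loss_void = dice_loss(all_void, target, 2)
print(loss_perfect, loss_void)
```

The all-void prediction gets 15 of 16 pixels right, yet its loss stays high because the single missed signal pixel carries most of the weight.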

In [None]:
class DiceLoss(nn.Module):
    def __init__(self, num_classes: int):
        super(DiceLoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, X, Y):
        # apply softmax here instead to make this interchangeable with crossentropy
        #X = nn.Softmax(dim=1)(X)
        # unsqueeze classes
        Y = nn.functional.one_hot(Y, self.num_classes)
        # align axes
        X = rearrange(X, 'b c h w -> b h w c')
        # compute class weights (inverse squared class frequency)
        W = 1. / (torch.sum(Y, (0, 1, 2)) ** 2 + 1e-9)
        # compute weighted cross and union sums over b h w
        cross = X * Y
        cross = W * torch.sum(cross, (0, 1, 2))
        cross = torch.sum(cross)
        union = Y + X
        union = W * torch.sum(union, (0, 1, 2))
        union = torch.sum(union)
        return 1. - 2. * (cross + 1e-9)/(union + 1e-9)
    

In [None]:
learning_rate = 1e-6
criteria = [DiceLoss(num_classes).to(DEVICE) for _ in range(len(models))]
optimizers = [AdamW(model.parameters(), lr=learning_rate) for model in models]

<a name="3"></a>

#### 3. Define evaluation metrics

In [None]:
metrics = [{ 'f1-score': F1Score(task='multiclass', num_classes=num_classes).to(DEVICE),
             'jaccard': JaccardIndex(task='multiclass', num_classes=num_classes).to(DEVICE) }
           for _ in range(len(models))]

<a name="4"></a>

#### 4. Run training with validation
We train all the models side-by-side on the same data batches for comparison.

Samples are separated into `validation_steps` partitions; each model trains on each partition, followed by validation on all `test_samples`. Batches are generated online, so the test set is never exactly the same, which minimizes selection bias.
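
The actual partitioning lives in `MultiTrainer` (in `scripts/trainer.py`, not shown here). As a rough sketch of the idea only, and assuming the samples are simply split into nearly equal chunks, it could look like this, with hypothetical file names:

```python
import numpy as np

train_samples = [f'page-{i:03d}.png' for i in range(10)]   # hypothetical file names
validation_steps = 3

partitions = np.array_split(train_samples, validation_steps)
sizes = [len(part) for part in partitions]
for step, part in enumerate(partitions, 1):
    # train every model on `part`, then validate on all test samples
    print(step, len(part))
```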

In [None]:
batch_size = 16
num_epochs = 1
validation_steps = 3

In [None]:
trainer = MultiTrainer(dataset, models, VIEW_SIZE, criteria, optimizers, metrics, tags=tags)
results = trainer.run(train_samples, test_samples, batch_size, num_epochs, validation_steps)

<a name="5"></a>

#### 5. Evaluate results

In [None]:
trainer.plot_compare()

In [None]:
trainer.plot_history()

In [None]:
# some examples side-by-side
sample = np.random.choice(samples)
loader = DataLoader(RandomViewDataset(sample, VIEW_SIZE, max_samples=8), batch_size=8)
for model in models:
    model.eval()
with torch.no_grad():
    for X, Y in loader:
        P = [torch.argmax(model(X.to(DEVICE)), axis=1).cpu() for model in models]
        for i in range(X.shape[0]): # batch
            fig, ax = plt.subplots(1, len(models) + 1, figsize=(12, 12))
            ax[0].imshow(X[i,:].squeeze().numpy(), 'gray')
            for n in range(1, len(models) + 1): # model
                ax[n].imshow(P[n - 1][i,:].squeeze().numpy(), 'gray')
                if i == 0:
                    ax[n].set_title(f'{"model" if tags[n - 1] == "" else tags[n - 1]}', fontsize=10)
            for n in range(len(models) + 1):
                ax[n].axis('off')
                if i == 0 and n == 0:
                    ax[0].set_title('Input', fontsize=10)
        plt.show()

In [None]:
order = [] # sort in the order of performance
for r, t, i in sorted(zip(trainer.results, tags, range(len(tags))),
                      key=lambda v:v[0]['f1-score'] * v[0]['jaccard']):
    print(f'{t:>4} f1-score: {r["f1-score"]:<6} jaccard-index: {r["jaccard"]} #{i}')
    order.append(i)

In [None]:
# look closer with the same pregenerated batch within performance groups
loader = DataLoader(RandomViewDataset(np.random.choice(samples), VIEW_SIZE, max_samples=4), batch_size=4)
batch = []
for X, Y in loader:
    batch.append((X, Y))

with torch.no_grad():
    for group in [order[:4], order[4:-4], order[-4:]]:
        for X, Y in batch:
            P = [torch.argmax(models[k](X.to(DEVICE)), axis=1).cpu() for k in group]
            for i in range(X.shape[0]): # batch
                fig, ax = plt.subplots(1, 5, figsize=(10, 10))
                ax[0].imshow(X[i,:].squeeze().numpy(), 'gray')
                for n in range(1, 5): # model
                    ax[n].imshow(P[n - 1][i,:].squeeze().numpy(), 'gray')
                    if i == 0:
                        ax[n].set_title(f'{"model" if tags[group[n - 1]] == "" else tags[group[n - 1]]}',
                                        fontsize=10)
                for n in range(5):
                    ax[n].axis('off')
                    if i == 0 and n == 0:
                        ax[0].set_title('Input', fontsize=10)
            plt.show()

<a name="embeddings"></a>

#### Latent space evaluation
To evaluate the embeddings produced by trained encoders we can use the basic types of pages identified in the [baselines exploration notebook](Doc-Classification-Baselines.ipynb#labels): our models should be able to tell them apart. Let's look at how well these groups are separated in the embedding space.
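
Separation is measured below with the silhouette score. As a reminder of what that number means, here is a small NumPy sketch of the standard definition (mean of `(b - a) / max(a, b)` over all points, with `a` the mean distance within the own class and `b` the mean distance to the nearest other class). The helper and the toy points are illustrative only; the notebook itself uses `sklearn.metrics.silhouette_score`.

```python
import numpy as np

def silhouette(X, labels):
    """Mean silhouette coefficient with Euclidean distance (small data only)."""
    X, labels = np.asarray(X, dtype=float), np.asarray(labels)
    dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    scores = []
    for i, label in enumerate(labels):
        same = labels == label
        if same.sum() < 2:               # singleton cluster scores 0 by convention
            scores.append(0.0)
            continue
        a = dist[i, same].sum() / (same.sum() - 1)   # own class, self excluded
        b = min(dist[i, labels == other].mean()      # nearest other class
                for other in np.unique(labels) if other != label)
        scores.append((b - a) / max(a, b))
    return float(np.mean(scores))

labels = [0, 0, 1, 1]
tight = np.array([[0, 0], [0, 1], [10, 0], [10, 1]])   # classes far apart
mixed = np.array([[0, 0], [10, 1], [10, 0], [0, 1]])   # classes interleaved

print(silhouette(tight, labels), silhouette(mixed, labels))
```

Values near 1 indicate compact, well separated classes; values near or below 0 indicate overlapping ones.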

In [None]:
classes = ['mixed','plain-text','form-table','non-doc']

labeled = pd.read_csv('./data/labeled-sample.csv')
labeled.groupby('label').size()

In [None]:
class TopViewDataset(Dataset):
    """
    render full-page view
    """
    def __init__(self, view_size: int, samples: list, labels: list, contrast: float = 0.):
        self.view_size = view_size
        # add non-docs for contrast if needed
        n, c = int(len(samples) * contrast), max(labels)
        self.samples = list(samples) + ['?'] * n
        self.labels = labels
        self.non_doc_class = c + 1
        self.transform = NormalizeView()

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        source = self.samples[idx]

        if source == '?': # non-doc sample for contrast
            X = self.transform(make_negative_sample(self.view_size))
            Y = self.non_doc_class
            return X, Y
        
        # load source image
        view = np.array(ImageOps.grayscale(Image.open(f'{ROOT}/data/images/{source}.png')))
        # render full-page view
        view = render.AgentView((255. - view).astype(np.uint8), self.view_size).top()
        X = self.transform(view)
        Y = self.labels[idx]
        return X, Y


In [None]:
text = labeled.loc[labeled['label']==1]
loader = DataLoader(TopViewDataset(VIEW_SIZE, text['source'].to_list(), [1] * len(text)),
                    batch_size=4, shuffle=False)
for X, Y in loader:
    for i in range(4):
        fig, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(X[i,:].squeeze(), 'gray')
        ax.axis('off')
        plt.show()
    break

In [None]:
def get_embeddings(dataset: Dataset, backbone: nn.Module, reduce: nn.Module = None,
                   target_index: int = None, batch_size: int = 16):
    calc = VisualEncoder(backbone, reduce).to(DEVICE)
    calc.eval()
    embeddings, labels = None, []
    with torch.no_grad():
        for inputs, targets in DataLoader(dataset, batch_size=batch_size):
            vectors = calc(inputs.to(DEVICE)).cpu().numpy()
            embeddings = vectors if embeddings is None else np.concatenate([embeddings, vectors], axis=0)
            labels += list(targets.cpu().numpy()) if target_index is None \
                                else list(targets[target_index].cpu().numpy())
    return embeddings, labels


In [None]:
def get_profile(embeddings, labels):
    """
    Principal Components and Linear Discriminant
    """
    pca, lda = PCA(), LDA()
    scaler = StandardScaler().fit(embeddings)
    P, L = pca.fit_transform(scaler.transform(embeddings)), lda.fit_transform(embeddings, labels)
    # compute 2d tSNE
    #T = TSNE(n_components=2, perplexity=90).fit_transform(embeddings)
    return P, pca.explained_variance_ratio_, L, lda.explained_variance_ratio_ #, T
    

Let's check how the two main components interact when never-seen non-docs are added to the test set.

In [None]:
def plot_profiles(tags, profile, score):
    # sort by silhouette-score
    order = sorted(zip(tags, score, range(len(tags))), key=lambda x:x[1], reverse=True)
    # show explained variance ratio profiles
    fig, ax = plt.subplots(figsize=(6, 6))
    for t, s, i in order:
        plt.plot(profile[i][:7], color=f'C{i}', marker='oosDD'[len(t)],
                 label=f'score: {s:.4f}  model: {"base" if t=="" else t}' )
    plt.title('PCA explained variance ratio profiles', fontsize=10)
    plt.legend(title='Clusters silhouette-score', fontsize=8, bbox_to_anchor=(1, 1), frameon=False)
    plt.show()


In [None]:
# 25% non-docs for contrast
classes = ['mixed','plain-text','form-table','non-doc']
dataset = TopViewDataset(VIEW_SIZE, labeled['source'], labeled['label'], contrast=0.25)
profiles, scores = [], []
results = trainer.results
for name, encoder in zip(tags, encoders):
    name = 'Base' if name == '' else name
    # use model encoder to get embeddings
    embeddings, labels = get_embeddings(dataset, encoder, reduce=nn.AdaptiveAvgPool2d((1, 1)))
    P, pca_ratios, L, lda_ratios = get_profile(embeddings, labels)
    scores.append(silhouette_score(P[:,:3], labels, metric='euclidean'))
    score = silhouette_score(L, labels, metric='euclidean')
    results[len(scores) - 1]['contrast-score'] = score
    profiles.append(pca_ratios)
    # classes aggregated
    centers = np.array([np.median(L[np.where(np.array(labels) == k)], axis=0) for k in range(len(classes))])
    cmap = colormaps['gist_rainbow']
    fig, ax = plt.subplots(1, 2, figsize=(7, 3.2))
    for j in range(len(classes)):
        s = np.where(np.array(labels) == j)
        ax[0].scatter(P[s,0], P[s,1], s=3, color=cmap(j/3), alpha=0.3)
        ax[1].scatter(L[s,0], L[s,1], s=3, color=cmap(j/3), alpha=0.3)
    for j in range(len(classes)):
        ax[1].scatter(centers[j,0], centers[j,1], color=cmap(j/3),
                      s=75, marker='pos^'[j], edgecolor='black', label=classes[j])
    for j, (t, s) in enumerate([('PCA', scores[-1]),('LDA', score)]):
        ax[j].set_xticks([])
        ax[j].set_yticks([])
        ax[j].set_title(f'{t}  silhouette-score: {s:.4f}', fontsize=10)
    ax[1].legend(title=f'{name} model', fontsize=8, bbox_to_anchor=(1, 1), frameon=False)
    plt.show()
    
# compare all
plot_profiles(tags, profiles, scores)

<a name="results"></a>

In [None]:
# documents only
classes = ['mixed','plain-text','form-table']
dataset = TopViewDataset(VIEW_SIZE, labeled['source'], labeled['label'], contrast=0)
profiles, scores = [], []
for name, encoder in zip(tags, encoders):
    name = 'Base' if name == '' else name
    # use model encoder to get embeddings
    embeddings, labels = get_embeddings(dataset, encoder, reduce=nn.AdaptiveAvgPool2d((1, 1)))
    P, pca_ratios, L, lda_ratios = get_profile(embeddings, labels)
    scores.append(silhouette_score(P[:,:3], labels, metric='euclidean'))
    score = silhouette_score(L, labels, metric='euclidean')
    results[len(scores) - 1]['cluster-score'] = score
    profiles.append(pca_ratios)
    # classes
    centers = np.array([np.median(L[np.where(np.array(labels) == k)], axis=0) for k in range(len(classes))])    
    fig, ax = plt.subplots(1, 2, figsize=(7, 3.2))
    for j in range(len(classes)):
        s = np.where(np.array(labels) == j)
        ax[0].scatter(P[s,0], P[s,1], s=3, color=cmap(j/3), alpha=0.3)
        ax[1].scatter(L[s,0], L[s,1], s=3, color=cmap(j/3), alpha=0.3)
    for j in range(len(classes)):
        ax[1].scatter(centers[j,0], centers[j,1], color=cmap(j/3),
                      s=75, marker='pos^'[j], edgecolor='black', label=classes[j])
    for j, (t, s) in enumerate([('PCA', scores[-1]),('LDA', score)]):
        ax[j].set_xticks([])
        ax[j].set_yticks([])
        ax[j].set_title(f'{t}  silhouette-score: {s:.4f}', fontsize=10)
    ax[1].legend(title=f'{name} model', fontsize=8, bbox_to_anchor=(1, 1), frameon=False)
    plt.show()

# compare all
plot_profiles(tags, profiles, scores)

In [None]:
results = pd.DataFrame.from_dict(results)
results['model'] = tags
results.set_index('model').style.format('{:.4f}').background_gradient('Greens')

    for tag, model, encoder in zip(tags, models, encoders):
        torch.save(model.state_dict(), f'./models/visual-unet-CNN-{channels}-{depth}-{tag}.pt')
        if tag == 'RA': # save best trained encoder
            torch.save(encoder.state_dict(), f'./models/visual-backbone-CNN.pt')
        
    results.to_csv(f'./models/visual-unet-CNN-{channels}-{depth}.csv')
    trainer.save(f'./models/visual-unet-CNN-{channels}-{depth}')