# Visual: supervised multi-task pretraining
This notebook explores `CNN`-based `visual-encoder` pretraining for the downstream tasks of segmentation and classification. The training runs with the `bridge` (cross-connections) disabled to force most of the information to be captured at the embedding level (bottleneck).

* [Dataset and Dataloader](#data)
* [Backbone model](#model)
* [Training and Validation](#run)
    * [Define models](#1)
    * [Define optimization](#2)
    * [Define validation metrics](#3)
    * [Run training](#4)
    * [Evaluate results](#5)
        * [Embeddings](#embeddings)
    
#### Observations
The [embedding space](#embeddings) produced by this training shows better separation of the basic page-types than the [initial experiment](Visual-Backbone-CNN.ipynb#embeddings).

In [None]:
import os
import re
import torch
import numpy as np
import pandas as pd

from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import colormaps
from pathlib import Path
from einops import rearrange

from torch import nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torchmetrics import F1Score, JaccardIndex, ConfusionMatrix
from torchsummary import summary

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

In [None]:
# load local notebook-utils
from scripts.backbone import *
from scripts.dataset import *
from scripts.trainer import *

In [None]:
# release unused memory held by the CUDA caching allocator, then report the active device
torch.cuda.empty_cache()
if DEVICE == 'cuda':
    print('GPU')
else:
    print('no GPU')

In [None]:
# semantic segmentation masks: keep the bare file names and skip the `que-*` masks
# (Path.name is used instead of splitting the string on '/', so the filter
#  does not depend on the platform path separator)
masks = [x.name for x in Path('./data/masks').glob('*.png')
         if not x.name.startswith('que-')]
len(masks)

In [None]:
# size of the view-port passed to the datasets and to the encoder
VIEW_SIZE = 128
# embedding size passed to the encoder constructor
LATENT_DIM = 512

<a name="data"></a>

### Create dataset and dataloader
We generate the data `online` (on the fly), which slows down training but gives us a lot of flexibility.
The input is a noisy version of a random page view-port (center, rotation, zoom). The targets are:
* segmentation: text, input-space, table-outlines
* value: info vs. noise
* alignment: document rotation related to the view-port

In [None]:
# use every image that has a mask
# (the commented-out alternative draws a random subset of 640 masks instead)
samples = masks #np.random.choice(masks, 640, replace=False)

In [None]:
# class weights for the alignment task; later passed to its cross-entropy criterion
# NOTE(review): the name is misspelled (ALIGNMET -> ALIGNMENT); a rename has to cover every later use
ALIGNMET_WEIGHT = get_alignment_weight()

In [None]:
sample = np.random.choice(samples)
# test loader: up to `n` generated views of one randomly chosen page
n = 8
loader = DataLoader(MultitaskPretrainingDataset(sample, VIEW_SIZE, max_samples=n), batch_size=n)
# names of the alignment classes, indexed by the Y3 target
alignment = ['N/A','No','Yes']
# show the generated batch
for X, (Y1, Y2, Y3, Y4) in loader:
    print(f'source: {sample}\nX: {X.shape}  Y1:{Y1.shape}  Y2:{Y2.shape}  Y3:{Y3.shape}  Y4:{Y4.shape}')
    # use the real batch size: a batch may hold fewer than `n` views
    for i in range(X.shape[0]):
        fig, ax = plt.subplots(1, 3, figsize=(8, 8))
        ax[0].imshow(X[i,:].squeeze(), 'gray')
        # restore channels to avoid visual confusion:
        # one-hot encode the class map and drop channel 0
        matrix = (np.eye(len(ORDER))[Y1[i,:]][:,:,1:] > 0) * 255
        # til -> ilt change RGB order for better lines visibility
        ax[1].imshow(matrix[:,:,[1,0,2]])
        ax[2].imshow(Y2[i,:], 'gray')
        for panel in ax:
            panel.axis('off')
        if i == 0:
            # the first row carries the column captions
            ax[0].set_title('Input view', fontsize=10)
            ax[1].set_title('Segmentation task', fontsize=10)
            ax[2].set_title('Value task', fontsize=10)
        else:
            # plain Python number, so the title shows `45` rather than `tensor(45)`
            rotation = Y4[i].item()
            angle = rotation if rotation < 360 else 'N/A'
            ax[0].set_title(f'aligned: {alignment[Y3[i]]}  rotation: {angle}', fontsize=8)
        plt.show()

<a name="model"></a>

## Model
Based on our [comparative experiment](Visual-Backbone-CNN.ipynb), we chose `64/4/residual` as the default visual-backbone `CNN` architecture.

In [None]:
# CHANNELS is only referenced by the commented-out CNN variant and its checkpoint names;
# DEPTH is shared with the transformer encoder built below
CHANNELS = 64
DEPTH = 4
#encoder = CNNEncoder(out_channels=CHANNELS, depth=DEPTH, residual=True).to(DEVICE)
#summary(encoder, (1, VIEW_SIZE, VIEW_SIZE))

# active variant: transformer encoder (not moved to DEVICE here; that happens
# when the full model is moved)
PATCH_SIZE = 4
encoder = TransformerEncoder(VIEW_SIZE, PATCH_SIZE, LATENT_DIM, DEPTH)
#summary(encoder.to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))

<a name="model"></a>

In this experiment we use the same visual encoder with several task-specific decoders: `segmentation`, `value`, `alignment`, `rotation` and train them all together.

In [None]:
# factories that build the task-specific heads and decoders
# (presumably they have to match the encoder type selected above — ViT here; TODO confirm)
#get_head, get_decoder = get_cnn_head, get_cnn_decoder
get_head, get_decoder = get_vit_head, get_vit_decoder

In [None]:
class MultitaskUNet(nn.Module):
    """
    One shared visual encoder trained through four task-specific outputs:
    two decoders (segmentation, value) that receive all encoder outputs and
    two classifiers (alignment, rotation) that receive only the last one.

    `get_decoder(k)` / `get_head(k)` are factories called with the number of
    output classes. `latent_dim` is accepted but not used by this class.
    """
    def __init__(self, backbone: nn.Module, get_head: Callable, get_decoder: Callable, latent_dim: int = 512):
        super().__init__()
        self.backbone = backbone
        # task-specific decoders w/o bridges
        self.segmentation = get_decoder(4)
        self.value = get_decoder(2)
        # task-specific classifiers
        self.alignment = get_head(3)
        self.rotation = get_head(361)

    def forward(self, x):
        features = self.backbone(x)
        # decoders get their own shallow copy of the full encoder output
        dense_outputs = tuple(decoder(features[:]) for decoder in (self.segmentation, self.value))
        # classifiers only look at the last encoder output (the embedding)
        bottleneck = features[-1]
        class_outputs = tuple(head(bottleneck) for head in (self.alignment, self.rotation))
        # order: segmentation, value, alignment, rotation
        return dense_outputs + class_outputs

#MultitaskUNet(encoder, get_head, get_decoder).to(DEVICE)(X.to(DEVICE))
#summary(MultitaskUNet(encoder, get_head, get_decoder).to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))

<a name="run"></a>

## Training and evaluation

In [None]:
# reference to the dataset class (not an instance); it is handed to the Trainer below
dataset = MultitaskPretrainingDataset

In [None]:
# random 95% of the pages for training, the remaining pages for testing
train_samples = np.random.choice(samples, int(0.95 * len(samples)), replace=False)
test_samples = list(set(samples) - set(train_samples))
len(train_samples), len(test_samples)

<a name="1"></a>

#### 1. Define model

In [None]:
# assemble the multi-task model around the shared encoder and move it to the target device
model = MultitaskUNet(encoder, get_head, get_decoder).to(DEVICE)

<a name="2"></a>

#### 2. Define optimization

In [None]:
# display for inspection; not defined in this notebook (presumably comes from the star-imported scripts)
SEGMENTATION_WEIGHT

In [None]:
# display for inspection; not defined in this notebook (presumably comes from the star-imported scripts)
DENOISING_WEIGHT

In [None]:
# display the alignment class weights that feed the weighted cross-entropy criterion
ALIGNMET_WEIGHT

In [None]:
# one criterion per model output, in the order the model returns them:
# segmentation, value, alignment (class-weighted cross-entropy), rotation
criteria = [ DiceLoss(4).to(DEVICE),
             DiceLoss(2).to(DEVICE),
             nn.CrossEntropyLoss(weight=torch.tensor(ALIGNMET_WEIGHT, dtype=torch.float32)).to(DEVICE),
             nn.CrossEntropyLoss().to(DEVICE) ]

We make the task weights trainable parameters and learn them along with the model during training. 

In [None]:
    class HydraLoss(nn.Module):
        """
        Combine several task losses with trainable weights:
        https://arxiv.org/abs/1705.07115

        Each task `i` owns a trainable scalar `log_vars[i]` (initialised to 0)
        and contributes `exp(-log_vars[i]) * loss_i + log_vars[i]` to the total.

        criteria: one callable per task, called as `criterion(pred, target)`.
        """
        def __init__(self, criteria: list):
            super().__init__()
            # register the criteria as sub-modules when possible, so that `.to()`
            # and `state_dict()` also reach their buffers (e.g. class weights);
            # plain callables are still accepted and kept in the original list
            if all(isinstance(c, nn.Module) for c in criteria):
                criteria = nn.ModuleList(criteria)
            self.criteria = criteria
            self.log_vars = nn.Parameter(torch.zeros(len(criteria)))

        def forward(self, preds, targets):
            """Return the sum of the weighted task losses; `preds[i]` / `targets[i]` belong to `criteria[i]`."""
            losses = []
            for i, criterion in enumerate(self.criteria):
                loss = criterion(preds[i], targets[i])
                # exp(-s) down-weights the task, the `+ s` term keeps s from growing without bound
                losses.append(torch.exp(-self.log_vars[i]) * loss + self.log_vars[i])
            return torch.sum(torch.stack(losses))

    # wrap the per-task criteria into a single loss and move it to the target device
    criterion = HydraLoss(criteria).to(DEVICE)
    # the optimizer has to see both the model weights and the loss parameters
    params = list(model.parameters()) + list(criterion.parameters())

In [None]:
# a single AdamW optimizer over the combined parameter list (model + loss weights)
learning_rate = 1e-6
optimizer = AdamW(params, lr=learning_rate)

<a name="3"></a>

#### 3. Define evaluation metrics

In [None]:
# validation metrics per task; `num_classes` mirrors the output sizes used when the
# model was built (4 / 2 / 3 / 361). The classification tasks also collect a
# confusion matrix.
metrics = {
    'segmentation': {
        'f1-score': F1Score(task='multiclass', num_classes=4).to(DEVICE) },
    'value': {
        'f1-score': F1Score(task='multiclass', num_classes=2).to(DEVICE) },
    'alignment': {
        'confmat': ConfusionMatrix(task='multiclass', num_classes=3).to(DEVICE),
        'f1-score': F1Score(task='multiclass', num_classes=3).to(DEVICE) },
    'rotation': {
        'confmat': ConfusionMatrix(task='multiclass', num_classes=361).to(DEVICE),
        'f1-score': F1Score(task='multiclass', num_classes=361).to(DEVICE) }}


<a name="4"></a>

#### 4. Run training

In [None]:
# training-run settings passed to `trainer.run`
batch_size = 12
num_epochs = 8

In [None]:
# `dataset` is the dataset class, `criterion` the combined HydraLoss
# NOTE(review): `multi_y=True` presumably tells the Trainer that each sample has several targets,
# and `validation_steps=1` presumably sets the validation frequency — confirm in scripts/trainer
trainer = Trainer(model, dataset, VIEW_SIZE, criterion, optimizer, metrics, multi_y=True)
results = trainer.run(train_samples, test_samples, batch_size, num_epochs=num_epochs, validation_steps=1)

<a name="5"></a>

#### 5. Evaluate results

In [None]:
# plot the loss and metric histories recorded by the trainer
plot_history(trainer.loss_history, trainer.metrics_history, multi_y=True)

In [None]:
# print every scalar metric per task; confusion matrices are plotted separately
for task, scores in results.items():
    for metric, score in scores.items():
        if metric == 'confmat':
            continue
        print(f'{task:>20} {metric}: {score:.4f}')

In [None]:
# confusion matrix recorded for the alignment task, with the alignment class names passed alongside
plot_confmat(np.array(trainer.metrics_history['alignment']['confmat']),
             alignment, 'Alignment task confusion-matrix')

In [None]:
# confusion matrix recorded for the rotation task
# NOTE(review): this call passes four positional arguments (None, a tick list, then the title)
# while the alignment call passes the title third — verify against plot_confmat's signature
plot_confmat(np.array(trainer.metrics_history['rotation']['confmat']),
             None, list(range(30, 350, 30)), 'Rotation task confusion-matrix', size=8)

In [None]:
# let's see some examples with new variation from the held-out test-samples
# (pages the model was not trained on; fall back to all samples if the split is empty)
pool = test_samples if len(test_samples) > 0 else samples
loader = DataLoader(MultitaskPretrainingDataset(np.random.choice(pool), VIEW_SIZE, max_samples=8),
                    batch_size=8)
model.eval()
with torch.no_grad():
    for X, (Y1, Y2, Y3, Y4) in loader:
        # predicted class = argmax over axis 1 of every model output
        P = [np.argmax(p.cpu().numpy(), axis=1) for p in model(X.to(DEVICE))]
        for i in range(X.shape[0]):
            fig, ax = plt.subplots(1, 4, figsize=(11, 11))
            # input view
            ax[0].imshow(X[i,:].squeeze().numpy(), 'gray')

            # segmentation target color-channels
            matrix = (np.eye(len(ORDER))[Y1[i,:]][:,:,1:] > 0) * 255
            ax[1].imshow(matrix[:,:,[1,0,2]])

            # segmentation task output
            matrix = (np.eye(len(ORDER))[P[0][i,:]][:,:,1:] > 0) * 255
            ax[2].imshow(matrix[:,:,[1,0,2]])

            # value task output
            ax[3].imshow(P[1][i,:], 'gray')

            for panel in ax:
                panel.axis('off')

            if i == 0:
                # the first row carries the column captions
                ax[0].set_title('Input view', fontsize=10)
                ax[1].set_title('Segmentation target', fontsize=10)
                ax[2].set_title('Segmentation output', fontsize=10)
                ax[3].set_title('Value output', fontsize=10)
            else:
                # `.item()` keeps `tensor(...)` out of the title
                ax[0].set_title((f'Align: {alignment[Y3[i]]} [true]  {alignment[P[2][i]]} [detected]   '
                                 f'Rotation: {Y4[i].item()} [true]  {P[3][i]} [detected]'),
                                fontsize=10, ha='left', x=0)
            plt.show()


<a name="embeddings"></a>

Now, let's check the latent space produced by this encoder.

In [None]:
# page-type names; the position in the list is used as the numeric class label in the cells below
classes = ['mixed','plain-text','form-table','non-doc']
# labeled sample with (at least) `source` and `label` columns; show the class balance
labeled = pd.read_csv('./data/labeled-sample.csv')
labeled.groupby('label').size()

In [None]:
# NOTE: rebinds `dataset`, which previously held the pretraining dataset class
dataset = TopViewDataset(VIEW_SIZE, labeled['source'], labeled['label'], contrast=0.3)
# embed the labeled pages with the trained encoder (`model.backbone`)
# presumably MeanReduce averages the encoder output into one vector per page — TODO confirm
embeddings, labels = get_embeddings(dataset, model.backbone, reduce=MeanReduce())

In [None]:
# standardize the embeddings, then project them onto the first two principal components
norm = StandardScaler().fit(embeddings)
pca = PCA(n_components=2)
P = pca.fit_transform(norm.transform(embeddings))
# per-class centers: the median of the projected points of each class
centers = np.array([
    np.median(P[np.array(labels) == k], axis=0)
    for k in range(len(classes))
])

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(9, 4))
cmap = colormaps['gist_rainbow']
labels = np.array(labels)
# the colormap is spread over the class indices 0..n
n = len(classes) - 1

# left panel: every class except the last one
for k in range(n):
    selected = labels == k
    ax[0].scatter(P[selected,0], P[selected,1], s=3, color=cmap(k/n), alpha=0.3)
ax[0].set_title('Docs layout types')

# right panel: all points, plus one marker per class center
ax[1].scatter(P[:,0], P[:,1], s=3, c=labels/n, cmap=cmap, alpha=0.3)
for k, name in enumerate(classes):
    ax[1].scatter(centers[k,0], centers[k,1], color=cmap(k/n),
                    s=75, marker='posv'[k], edgecolor='black', label=name)
ax[1].set_title('Docs vs. non-docs')
ax[1].legend(bbox_to_anchor=(1, 1), frameon=False)

# the projected axes carry no units, so hide the ticks on both panels
for panel in ax:
    panel.set_xticks([])
    panel.set_yticks([])
plt.show()

In [None]:
# save the full multi-task model (encoder + all task heads/decoders)
# the commented-out lines are the file names of the CNN variant
#torch.save(model.state_dict(), f'./models/visual-multitask-CNN-R-{CHANNELS}-{DEPTH}.pt')
torch.save(model.state_dict(), f'./models/visual-multitask-ViT-{PATCH_SIZE}-{DEPTH}.pt')

# save encoder only
#torch.save(encoder.state_dict(), f'./models/visual-backbone-CNN-R-{CHANNELS}-{DEPTH}.pt')
torch.save(encoder.state_dict(), f'./models/visual-backbone-ViT-{PATCH_SIZE}-{DEPTH}.pt')

In [None]:
def _show_predictions(model, batches):
    """Plot input, segmentation target/output and value output for every view in `batches`.

    `batches` is an iterable of `(X, (Y1, Y2, Y3, Y4))` items as produced by the
    pretraining DataLoader. The model is switched to eval mode and run without gradients.
    """
    model.eval()
    with torch.no_grad():
        for X, (Y1, Y2, Y3, Y4) in batches:
            # predicted class = argmax over axis 1 of every model output
            P = [np.argmax(p.cpu().numpy(), axis=1) for p in model(X.to(DEVICE))]
            for i in range(X.shape[0]):
                fig, ax = plt.subplots(1, 4, figsize=(11, 11))
                # input view
                ax[0].imshow(X[i,:].squeeze().numpy(), 'gray')

                # segmentation target color-channels
                matrix = (np.eye(len(ORDER))[Y1[i,:]][:,:,1:] > 0) * 255
                ax[1].imshow(matrix[:,:,[1,0,2]])

                # segmentation task output
                matrix = (np.eye(len(ORDER))[P[0][i,:]][:,:,1:] > 0) * 255
                ax[2].imshow(matrix[:,:,[1,0,2]])

                # value task output
                ax[3].imshow(P[1][i,:], 'gray')

                for panel in ax:
                    panel.axis('off')

                if i == 0:
                    # the first row carries the column captions
                    ax[0].set_title('Input view', fontsize=10)
                    ax[1].set_title('Segmentation target', fontsize=10)
                    ax[2].set_title('Segmentation output', fontsize=10)
                    ax[3].set_title('Value output', fontsize=10)
                else:
                    # `.item()` keeps `tensor(...)` out of the title
                    ax[0].set_title((f'Align: {alignment[Y3[i]]} [true]  {alignment[P[2][i]]} [detected]   '
                                     f'Rotation: {Y4[i].item()} [true]  {P[3][i]} [detected]'),
                                    fontsize=10, ha='left', x=0)
                plt.show()


# generate static test-batch for comparison: every checkpoint is shown the same views
loader = DataLoader(MultitaskPretrainingDataset(np.random.choice(samples), VIEW_SIZE, max_samples=8),
                    batch_size=8)
batch = list(loader)

for checkpoint in sorted(os.listdir('models/checkpoint')):
    if not checkpoint.endswith('.pt'):
        continue
    print('\n\n\n')
    # NOTE: the model is rebuilt around the same `encoder` object, so loading a
    # checkpoint also overwrites the weights of the shared `encoder`
    model = MultitaskUNet(encoder, get_head, get_decoder).to(DEVICE)
    # map_location lets checkpoints saved on another device load on this one
    model.load_state_dict(torch.load(f'models/checkpoint/{checkpoint}', map_location=DEVICE))
    _show_predictions(model, batch)
