# Visual: supervised multi-task training
This notebook explores training of the `CNN`-based `visual-encoder` for the downstream tasks of segmentation and classification. The training runs with the `bridge` (cross-connections) disabled to force most of the information to be captured at the embedding level (bottleneck).

* [Dataset and Dataloader](#data)
* [Backbone model](#model)
* [Training and Validation](#run)
    * [Define models](#1)
    * [Define optimization](#2)
    * [Define validation metrics](#3)
    * [Run training](#4)
    * [Evaluate results](#5)
        * [Embeddings](#embeddings)
    
#### Observations
The [embedding space](#embeddings) produced by this training shows better separation of the basic page-types than the [initial experiment](Visual-Backbone-CNN.ipynb#embeddings).

In [None]:
import os
import re
import torch
import numpy as np
import pandas as pd

from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import colormaps
from pathlib import Path
from einops import rearrange

from torch import nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torchmetrics import F1Score, JaccardIndex, ConfusionMatrix
from torchsummary import summary

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

In [None]:
# load local notebook-utils
from scripts.backbone import *
from scripts.dataset import *
from scripts.trainer import *

In [None]:
# release the unused GPU memory cached by PyTorch before building new models
torch.cuda.empty_cache()
# DEVICE is not defined in this notebook -- presumably it comes from the local `scripts` star-imports
print('GPU' if DEVICE == 'cuda' else 'no GPU')

In [None]:
# semantic segmentation masks: bare file names of all `*.png` masks except the `que-*` ones
# (`Path.name` replaces the manual split on '/', so the filter does not depend on the path separator)
masks = [x.name for x in Path('./data/masks').glob('*.png')
         if not x.name.startswith('que-')]
len(masks)

In [None]:
# side of the square view-port: the dataset targets are VIEW_SIZE x VIEW_SIZE matrices
VIEW_SIZE = 128
# feature size passed to the task heads and decoders defined below
LATENT_DIM = 512

<a name="data"></a>

### Create dataset and dataloader
We generate the data `online`: this slows down training, but gives us a lot of flexibility.
The input is a noisy version of the random page view-port (center, rotation, zoom). The targets are:
* segmentation: text, input-space, table-outlines
* value: info vs. noise
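* alignment: whether the document rotation is aligned with the view-port (a multiple of 90 degrees, within a threshold)
* rotation: document rotation relative to the view-port, in degrees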

In [None]:
# use images with masks -- all of them; the commented-out call would draw a random subset of 640 instead
samples = masks #np.random.choice(masks, 640, replace=False)

In [None]:
class MultitaskPretrainingDataset(Dataset):
    """
    Online generator of random view-ports of a single page; every item is
    `X, (Y1, Y2, Y3, Y4)` -- an input view with one target per task:
    1. Y1: semantic segmentation (matrix of class indices)
    2. Y2: value extraction and denoising (binary matrix)
    3. Y3: alignment (0 -- non-doc sample, 1 -- not aligned, 2 -- aligned)
    4. Y4: rotation in degrees (360 -- non-doc sample)
    """
    def __init__(self, source: str, view_size: int, max_samples: int,
                       alignment_threshold: float = 0,
                       unknown_fraction: float = 0.1, aligned_fraction: float = 0.2):
        # sampling configuration
        self.view_size, self.max_samples = view_size, max_samples
        self.alignment_threshold = alignment_threshold
        self.unknown_fraction, self.aligned_fraction = unknown_fraction, aligned_fraction
        # source page in grayscale and the sample made from it for the model input
        page = np.array(ImageOps.grayscale(Image.open(f'{ROOT}/data/images/{source}')))
        noisy = make_noisy_sample(page)
        # segmentation mask is stored under the same file name
        mask = np.array(Image.open(f'{ROOT}/data/masks/{source}'))
        # one renderer per signal, so the same view-port can be rendered from each of them
        self.view = render.AgentView(noisy.astype(np.uint8), view_size, bias=np.random.randint(100))
        # one-hot encoded mask without the channel 0, scaled to 0/255
        self.segmentation = render.AgentView((np.eye(len(ORDER))[mask][:,:,1:] > 0) * 255, view_size)
        # inverted clean page
        self.value = render.AgentView(255 - page, view_size)
        # image preprocessing applied to every input view
        self.transform = NormalizeView()

    def random_viewport(self):
        """
        Draw a random view-port as `center, rotation, zoom`
        (the challenge here -- we need both -- coverage and consistency for a good representation)
        """
        # page center scaled by a random factor from [0.25, 1.75)
        center = (np.array(self.view.space.center) * (0.25 + np.random.rand() * 1.5)).astype(int)
        # mix exactly aligned rotations with arbitrary ones
        aligned = np.random.rand() < self.aligned_fraction
        rotation = np.random.choice([0, 90, 180, 270]) if aligned else np.random.randint(0, 360)
        # zoom from [-3.5, 0.5)
        zoom = np.random.rand() * 4.0 - 3.5
        return center, rotation, zoom

    def __len__(self):
        # items are generated on the fly, so the length is just the configured budget
        return self.max_samples

    def __getitem__(self, idx):
        # `idx` is ignored: every call draws a new random item
        size = self.view_size
        if np.random.rand() < self.unknown_fraction: # random non-doc image for out-of-class example
            X = self.transform(make_negative_sample(size))
            Y = torch.Tensor(np.zeros((size, size))).long()
            return X, (Y, Y, 0, 360)
        # draw view-ports until the rendered view has something to see
        while True:
            center, rotation, zoom = self.random_viewport()
            view = self.view.render(center, rotation, zoom)
            if not np.std(view) < 10:
                break
        X = self.transform(view)
        # segmentation task target: masks rendered in the same view-port, values scattered
        # by the rotation are thresholded back to binary and placed after the empty channel 0
        rendered = self.segmentation.render(center, rotation, zoom)
        channels = np.zeros((size, size, len(ORDER)))
        channels[:,:,1:] = (rendered/255. > 0.25).astype(int)
        # class-indices matrix (0 where no channel is set)
        Y1 = torch.Tensor(np.argmax(channels, axis=(2))).long()
        # value task target
        rendered = self.value.render(center, rotation, zoom)
        Y2 = torch.Tensor(rendered/255. >= 0.25).squeeze().long()
        # alignment task target: distance to the closest multiple of 90 within the threshold
        d = rotation % 90
        Y3 = int(min(d, 90 - d) <= self.alignment_threshold) + 1
        # rotation task target
        Y4 = int(rotation)
        return X, (Y1, Y2, Y3, Y4)
    
    

In [None]:
def get_alignment_weight(alignment_threshold=0, unknown_fraction=0.1, aligned_fraction=0.2):
    """
    Estimate the alignment-task class weights for a given sampling configuration.

    Simulates 10000 draws of the alignment label the same way the dataset does
    (0 -- unknown / non-doc, 1 -- not aligned, 2 -- aligned) and weights every class
    by the number of samples that do NOT belong to it, normalized to sum to 1,
    so that rare classes get larger weights.

    Parameters:
        alignment_threshold: max distance (degrees) to a multiple of 90 still counted as aligned
        unknown_fraction: probability of an out-of-class sample
        aligned_fraction: probability of drawing the rotation from [0, 90, 180, 270]

    Returns:
        list of 3 weights (one per class, rounded to 2 digits); a class that was
        never drawn still gets its entry, so the list always matches the 3 classes
    """
    num_classes = 3
    labels = []
    for _ in range(10000): # run 10000 tries and get stats
        if np.random.rand() < unknown_fraction:
            labels.append(0)
            continue
        if np.random.rand() < aligned_fraction:
            r = np.random.choice([0, 90, 180, 270])
        else:
            r = np.random.randint(0, 360)
        d = r % 90
        labels.append(int(min(d, 90 - d) <= alignment_threshold) + 1)
    # samples per class; `minlength` keeps the classes that were never drawn
    counts = np.bincount(labels, minlength=num_classes)
    # weight of a class = number of samples of all the other classes
    w = counts.sum() - counts
    return list(np.round(w / w.sum(), 2))


ALIGNMET_WEIGHT = get_alignment_weight()

In [None]:
# names of the alignment-task classes, indexed by the target value
alignment = ['N/A','No','Yes']
# test loader: a single batch of `n` random views from one random page
n = 8
sample = np.random.choice(samples)
loader = DataLoader(MultitaskPretrainingDataset(sample, VIEW_SIZE, max_samples=n), batch_size=n)
# show first batch
for X, (Y1, Y2, Y3, Y4) in loader:
    print(f'source: {sample}\nX: {X.shape}  Y1:{Y1.shape}  Y2:{Y2.shape}  Y3:{Y3.shape}  Y4:{Y4.shape}')
    for i in range(n):
        fig, ax = plt.subplots(1, 3, figsize=(8, 8))
        # restore channels to avoid visual confusion
        matrix = (np.eye(len(ORDER))[Y1[i,:]][:,:,1:] > 0) * 255
        ax[0].imshow(X[i,:].squeeze(), 'gray')
        # til -> ilt change RGB order for better lines visibility
        ax[1].imshow(matrix[:,:,[1,0,2]])
        ax[2].imshow(Y2[i,:], 'gray')
        for panel in ax:
            panel.axis('off')
        if i > 0:
            # the other rows show the alignment and rotation targets instead of headers
            angle = Y4[i] if Y4[i] < 360 else 'N/A'
            ax[0].set_title(f'aligned: {alignment[Y3[i]]}  rotation: {angle}', fontsize=8)
        else:
            # the first row carries the column headers
            for panel, title in zip(ax, ('Input view', 'Segmentation task', 'Value task')):
                panel.set_title(title, fontsize=10)
        plt.show()

<a name="model"></a>

## Model
Based on our [comparative experiment](Visual-Backbone-CNN.ipynb), we chose `64/4/residual` as the default visual-backbone `CNN` architecture.

In [None]:
def get_cnn_head(output_dim: int):
    """
    classifier for the CNN encoder output:
    global average pooling -> flatten -> `Head` with `output_dim` outputs
    """
    layers = [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(start_dim=1),
        Head(LATENT_DIM, output_dim),
    ]
    return nn.Sequential(*layers)

def get_vit_head(output_dim: int):
    """
    classifier for the ViT encoder output:
    `MeanReduce` -> `Head` with `output_dim` outputs
    """
    layers = [
        MeanReduce(),
        Head(LATENT_DIM, output_dim),
    ]
    return nn.Sequential(*layers)


In [None]:
def get_cnn_decoder(num_classes: int, bridge: bool = False):
    """
    CNN decoder followed by a 1x1 convolution to `num_classes` channels
    and a softmax over the channel dimension
    """
    layers = [
        CNNDecoder(LATENT_DIM, DEPTH - 1, True, bridge, True),
        nn.Conv2d(CHANNELS, num_classes, 1, 1),
        nn.Softmax(dim=1),
    ]
    return nn.Sequential(*layers)

def get_vit_decoder(num_classes: int, bridge: bool = False):
    """
    transformer decoder producing `num_classes` channels,
    followed by a softmax over the channel dimension
    """
    layers = [
        TransformerDecoder(VIEW_SIZE, PATCH_SIZE, LATENT_DIM, DEPTH - 1, channels=num_classes, bridge=bridge),
        nn.Softmax(dim=1),
    ]
    return nn.Sequential(*layers)


In [None]:
# shared visual encoder (CNN variant), not frozen; the commented line switches to the ViT variant
encoder = get_cnn_backbone(pretrained=True, frozen=False)
#encoder = get_vit_backbone(pretrained=True, frozen=False)

<a name="model"></a>

In this experiment we share the same visual encoder between several task-specific modules (`segmentation`, `value`, `alignment`, `rotation`) and train them all together.

In [None]:
# head/decoder factories used to build the multi-task model (CNN variant);
# the commented line selects the ViT pair instead
get_head, get_decoder = get_cnn_head, get_cnn_decoder
#get_head, get_decoder = get_vit_head, get_vit_decoder

In [None]:
class MultitaskUNet(nn.Module):
    """
    One shared visual encoder feeding four task-specific modules:
    two decoders (segmentation, value) and two classifiers (alignment, rotation).

    `get_decoder` and `get_head` are factories called with the number of outputs;
    `latent_dim` is accepted but not used by this class.
    """
    def __init__(self, backbone: nn.Module, get_head: Callable, get_decoder: Callable, latent_dim: int = 512):
        super().__init__()
        # shared visual encoder
        self.backbone = backbone
        # task-specific modules w/o bridges: attribute name, factory, number of outputs
        # (created in this fixed order: decoders first, then classifiers)
        tasks = (('segmentation', get_decoder, 4),
                 ('value', get_decoder, 2),
                 ('alignment', get_head, 3),
                 ('rotation', get_head, 361))
        for name, factory, outputs in tasks:
            setattr(self, name, factory(outputs))

    def forward(self, x):
        features = self.backbone(x)
        # classifiers see only the last encoder output
        latent = features[-1]
        # every decoder gets its own shallow copy of the encoder outputs
        outputs = [decoder(features[:]) for decoder in (self.segmentation, self.value)]
        outputs += [head(latent) for head in (self.alignment, self.rotation)]
        # segmentation, value, alignment, rotation
        return tuple(outputs)

#MultitaskUNet(encoder, get_head, get_decoder).to(DEVICE)(X.to(DEVICE))
#summary(MultitaskUNet(encoder, get_head, get_decoder).to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))

<a name="run"></a>

## Training and evaluation

In [None]:
# the dataset is handed to the trainer as a class, not as an instance
dataset = MultitaskPretrainingDataset

In [None]:
# random split of the source pages: 95% for training, the remaining ones for testing
# (no random seed is set in the visible cells, so the split may differ between runs)
train_samples = np.random.choice(samples, int(len(samples) * 0.95), replace=False)
test_samples = list(set(samples).difference(set(train_samples)))
len(train_samples), len(test_samples)

<a name="1"></a>

#### 1. Define model

In [None]:
# multi-task model built around the shared `encoder` object (the encoder is not copied)
model = MultitaskUNet(encoder, get_head, get_decoder).to(DEVICE)

<a name="2"></a>

#### 2. Define optimization

In [None]:
# display the value; not defined in this notebook (presumably from the local `scripts` star-imports)
# and not passed to any criterion in the visible cells
SEGMENTATION_WEIGHT

In [None]:
# display the value; not defined in this notebook (presumably from the local `scripts` star-imports)
# and not passed to any criterion in the visible cells
DENOISING_WEIGHT

In [None]:
# display the alignment class weights estimated by `get_alignment_weight()`
ALIGNMET_WEIGHT

In [None]:
# one criterion per model output, in the order returned by `MultitaskUNet.forward`:
# segmentation, value, alignment (class-weighted), rotation
# (the `DiceLoss` argument is presumably the number of classes -- it matches the decoder outputs)
criteria = [ DiceLoss(4).to(DEVICE),
             DiceLoss(2).to(DEVICE),
             nn.CrossEntropyLoss(weight=torch.tensor(ALIGNMET_WEIGHT, dtype=torch.float32)).to(DEVICE),
             nn.CrossEntropyLoss().to(DEVICE) ]

We can define our combined loss criterion as a weighted sum of the task losses. However, the dynamics of the individual task losses will most probably not be well aligned during training, which makes static task weights a suboptimal solution.

    class CombinedLoss(nn.Module):
        def __init__(self, criteria: list, weights: list):
            assert len(criteria) == len(weights)
            super(CombinedLoss, self).__init__()
            self.criteria = criteria
            self.weights = weights        

        def forward(self, preds, targets):
            losses = []
            for i, criterion in enumerate(self.criteria):
                losses.append(criterion(preds[i], targets[i]) * self.weights[i])
            return torch.sum(torch.stack(losses))

    criterion = CombinedLoss(criteria, [1., 1., 10., 10.]).to(device)
    # optimized model parameters only
    params = model.parameters()

Instead, we can make the task weights trainable parameters and learn them along with the model.

In [None]:
    class HydraLoss(nn.Module):
        """
        Construct combined loss with trainable weights (uncertainty-based task weighting):
        https://arxiv.org/abs/1705.07115

        total = sum_i exp(-s_i) * loss_i + s_i
        with one trainable `s_i` per task (`log_vars`), initialized with zeros.

        Parameters:
            criteria: list of task criteria, each called as `criterion(pred, target)`
        """
        def __init__(self, criteria: list):
            super().__init__()
            if all(isinstance(c, nn.Module) for c in criteria):
                # register the criteria as sub-modules, so that `.to(device)` and
                # `state_dict()` reach their tensors (e.g. class weights) as well
                self.criteria = nn.ModuleList(criteria)
            else:
                # plain callables cannot be registered -- keep them in a list
                self.criteria = list(criteria)
            self.log_vars = nn.Parameter(torch.zeros((len(criteria))))

        def forward(self, preds, targets):
            """
            `preds` and `targets` are indexed in the order of the criteria;
            returns the weighted sum of the task losses as a scalar tensor
            """
            losses = []
            for i, criterion in enumerate(self.criteria):
                loss = criterion(preds[i], targets[i])
                # exp(-s) scales the task loss, `+ s` keeps s from growing without bound
                losses.append(torch.exp(-self.log_vars[i]) * loss + self.log_vars[i])
            return torch.sum(torch.stack(losses))

    # combine the task criteria into a single loss with trainable task weights
    criterion = HydraLoss(criteria).to(DEVICE)
    # the optimizer has to update both: the model and the loss parameters
    params = [*model.parameters(), *criterion.parameters()]

In [None]:
# a single learning rate is used for everything in `params`
learning_rate = 1e-6
optimizer = AdamW(params, lr=learning_rate)

<a name="3"></a>

#### 3. Define evaluation metrics

In [None]:
# validation metrics per task; `num_classes` follows the model outputs:
# 4 segmentation classes, 2 value classes, 3 alignment classes, 361 rotation classes
# (the confusion matrices are plotted in the evaluation section)
metrics = {
    'segmentation': {
        'f1-score': F1Score(task='multiclass', num_classes=4).to(DEVICE) },
    'value': {
        'f1-score': F1Score(task='multiclass', num_classes=2).to(DEVICE) },
    'alignment': {
        'confmat': ConfusionMatrix(task='multiclass', num_classes=3).to(DEVICE),
        'f1-score': F1Score(task='multiclass', num_classes=3).to(DEVICE) },
    'rotation': {
        'confmat': ConfusionMatrix(task='multiclass', num_classes=361).to(DEVICE),
        'f1-score': F1Score(task='multiclass', num_classes=361).to(DEVICE) }}


<a name="4"></a>

#### 4. Run training

In [None]:
# training schedule passed to `trainer.run` below
batch_size = 16
num_epochs = 6

In [None]:
#!rm -rf ./runs/visual-pretraining-cnn
# `Trainer` comes from the local scripts; `dataset` is passed as a class and
# `multi_y=True` presumably tells it that every item carries several targets -- confirm in `scripts.trainer`
trainer = Trainer(model, dataset, VIEW_SIZE, criterion, optimizer, metrics, multi_y=True)
#                 tensorboard_dir='runs/visual-pretraining-cnn') # log progress to tensorboard
# `results` is indexed below as results[task][metric]
results = trainer.run(train_samples, test_samples, batch_size, num_epochs=num_epochs, validation_steps=1)

<a name="5"></a>

#### 5. Evaluate results

In [None]:
# plot the loss and metrics histories collected by the trainer
plot_history(trainer.loss_history, trainer.metrics_history, multi_y=True)

In [None]:
# print the scalar results per task; confusion matrices are skipped here and plotted below
for task in results:
    for metric in results[task]:
        if metric != 'confmat':
            print(f'{task:>20} {metric}: {results[task][metric]:.4f}')

In [None]:
# alignment task confusion matrix; `alignment` holds the class names ['N/A','No','Yes']
plot_confmat(np.array(trainer.metrics_history['alignment']['confmat']),
             alignment, 'Alignment task confusion-matrix')

In [None]:
# rotation task confusion matrix
# NOTE(review): unlike the alignment call, the title is passed here as the 4th positional argument
# (after `None` and the list of angles 30..330) -- confirm against the `plot_confmat` signature
plot_confmat(np.array(trainer.metrics_history['rotation']['confmat']),
             None, list(range(30, 350, 30)), 'Rotation task confusion-matrix', size=8)

In [None]:
# let's see some examples: one batch of random views from a random page
loader = DataLoader(MultitaskPretrainingDataset(np.random.choice(samples), VIEW_SIZE, max_samples=8),
                    batch_size=8)
model.eval()
with torch.no_grad():
    for X, (Y1, Y2, Y3, Y4) in loader:
        # predicted class indices for every task output
        P = [np.argmax(p.cpu().numpy(), axis=1) for p in model(X.to(DEVICE))]
        for i in range(X.shape[0]):
            fig, ax = plt.subplots(1, 4, figsize=(11, 11))
            # input view
            ax[0].imshow(X[i,:].squeeze().numpy(), 'gray')
            # segmentation color-channels: the target first, then the task output
            for k, indices in ((1, Y1[i,:]), (2, P[0][i,:])):
                matrix = (np.eye(len(ORDER))[indices][:,:,1:] > 0) * 255
                ax[k].imshow(matrix[:,:,[1,0,2]])
            # value task output
            ax[3].imshow(P[1][i,:], 'gray')
            for panel in ax:
                panel.axis('off')

            if i > 0:
                # the other rows compare the alignment and rotation targets with the predictions
                ax[0].set_title((f'Align: {alignment[Y3[i]]} [true]  {alignment[P[2][i]]} [detected]   '
                                 f'Rotation: {Y4[i]} [true]  {P[3][i]} [detected]'),
                                fontsize=10, ha='left', x=0)
            else:
                # the first row carries the column headers
                headers = ('Input view', 'Segmentation target', 'Segmentation output', 'Value output')
                for panel, title in zip(ax, headers):
                    panel.set_title(title, fontsize=10)
            plt.show()


<a name="embeddings"></a>

Let's check the latent space produced with trained encoder.

In [None]:
# page-type names, indexed by the numeric label used in the plots below
classes = ['mixed','plain-text','form-table','non-doc']
# labeled pages: the `source` and `label` columns are used in the next cell
labeled = pd.read_csv('./data/labeled-sample.csv')
# number of pages per label
labeled.groupby('label').size()

In [None]:
# NOTE: `dataset` is rebound here from the training dataset class to a `TopViewDataset` instance
dataset = TopViewDataset(VIEW_SIZE, labeled['source'], labeled['label'], contrast=0.3)
# embed the labeled pages with the trained backbone only; global average pooling is passed as the reduction
embeddings, labels = get_embeddings(dataset, model.backbone, reduce=nn.AdaptiveAvgPool2d((1, 1)))
# add to tensorboard-projector
#trainer.writer.add_embedding(embeddings, metadata=labels)

In [None]:
# standardize the embeddings and project them onto the first two principal components
pca = PCA(n_components=2)
norm = StandardScaler().fit(embeddings)
P = pca.fit_transform(norm.transform(embeddings))
# class centers: per-class median of the projected points (median, not mean)
centers = np.array([np.median(P[np.where(np.array(labels) == k)], axis=0) for k in range(len(classes))])

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
cmap = colormaps['gist_rainbow']
labels = np.array(labels)
# labels are mapped to [0, 1] to pick the colors from the colormap
n = len(classes) - 1
# all projected embeddings, colored by class
ax.scatter(P[:,0], P[:,1], s=3, c=labels/n, cmap=cmap, alpha=0.3)
# class centers: same colors, one marker shape per class
for c, (name, marker) in enumerate(zip(classes, 'posv')):
    ax.scatter(centers[c,0], centers[c,1], color=cmap(c/n),
               s=75, marker=marker, edgecolor='black', label=name)
# the projected axes carry no units, so the ticks are hidden
for set_ticks in (ax.set_xticks, ax.set_yticks):
    set_ticks([])
ax.set_title('Docs vs. non-docs')
ax.legend(bbox_to_anchor=(1, 1), frameon=False)
plt.show()

In [None]:
# save the whole multi-task model (encoder + task modules); the commented lines are the ViT counterparts
torch.save(model.state_dict(), f'./models/visual-multitask-CNN-R-{CHANNELS}-{DEPTH}.pt')
#torch.save(model.state_dict(), f'./models/visual-multitask-ViT-{PATCH_SIZE}-{DEPTH}.pt')
#trainer.writer.close()

# save encoder only
torch.save(encoder.state_dict(), f'./models/visual-backbone-CNN-R-{CHANNELS}-{DEPTH}.pt')
#torch.save(encoder.state_dict(), f'./models/visual-backbone-ViT-{PATCH_SIZE}-{DEPTH}.pt')

Let's check changes between the epochs.

In [None]:
# generate static test-batch for comparison: it is materialized once,
# so every checkpoint is shown on exactly the same views
loader = DataLoader(MultitaskPretrainingDataset(np.random.choice(samples), VIEW_SIZE, max_samples=8),
                    batch_size=8)
batch = list(loader)

for checkpoint in sorted(os.listdir('models/checkpoint')):
    if not checkpoint.endswith('.pt'):
        continue
    print('\n\n\n')
    # NOTE: the model is built around the shared `encoder` object, so loading a checkpoint
    # overwrites the encoder weights in place (the trained weights were saved above)
    model = MultitaskUNet(encoder, get_head, get_decoder).to(DEVICE)
    # map_location lets checkpoints saved on GPU load on a CPU-only machine as well
    model.load_state_dict(torch.load(f'models/checkpoint/{checkpoint}', map_location=DEVICE))
    model.eval()
    with torch.no_grad():
        for X, (Y1, Y2, Y3, Y4) in batch:
            # predicted class indices for every task output
            P = [np.argmax(p.cpu().numpy(), axis=1) for p in model(X.to(DEVICE))]
            for i in range(X.shape[0]):
                fig, ax = plt.subplots(1, 4, figsize=(11, 11))
                # input view
                ax[0].imshow(X[i,:].squeeze().numpy(), 'gray')
                ax[0].axis('off')

                # segmentation target
                matrix = (np.eye(len(ORDER))[Y1[i,:]][:,:,1:] > 0) * 255
                ax[1].imshow(matrix[:,:,[1,0,2]])
                ax[1].axis('off')

                # segmentation output
                matrix = (np.eye(len(ORDER))[P[0][i,:]][:,:,1:] > 0) * 255
                ax[2].imshow(matrix[:,:,[1,0,2]])
                ax[2].axis('off')

                # value output
                ax[3].imshow(P[1][i,:], 'gray')
                ax[3].axis('off')

                if i == 0:
                    # the first row carries the column headers
                    ax[0].set_title('Input view', fontsize=10)
                    ax[1].set_title('Segmentation target', fontsize=10)
                    ax[2].set_title('Segmentation output', fontsize=10)
                    ax[3].set_title('Value output', fontsize=10)
                else:
                    # the other rows compare the alignment and rotation targets with the predictions
                    ax[0].set_title((f'Align: {alignment[Y3[i]]} [true]  {alignment[P[2][i]]} [detected]   '
                                     f'Rotation: {Y4[i]} [true]  {P[3][i]} [detected]'),
                                    fontsize=10, ha='left', x=0)
                plt.show()
