# Visual: semantic segmentation
This notebook focuses on layout understanding tasks. We use our pretrained visual backbone with frozen weights and train the decoders with `bridge` (cross-connections) enabled. In this experiment the view is `aligned` (random orientation, with only a small skew considered), and the zoom level is sufficient to make a decision (top view, or at least a quarter of a page).

* [Dataset and Dataloader](#data)
* [Model architecture](#model)
* [Training and Validation](#run)
    * [Define optimization](#2)
    * [Define validation metrics](#3)
    * [Run training](#4)
    * [Evaluate results](#5)


In [None]:
import os
import re
import torch
import numpy as np
import pandas as pd

from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import colormaps
from pathlib import Path
from einops import rearrange

from torch import nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torchmetrics import F1Score, JaccardIndex, ConfusionMatrix
from torchsummary import summary

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

In [None]:
# load local notebook-utils
from scripts.backbone import *
from scripts.dataset import *
from scripts.trainer import *

In [None]:
# release unoccupied memory held by the CUDA caching allocator
torch.cuda.empty_cache()
# DEVICE is not defined in this notebook; presumably it comes from the star-imported
# notebook utils — TODO confirm
print('GPU' if DEVICE == 'cuda' else 'no GPU')

In [None]:
# semantic segmentation masks: file names (not full paths) of every PNG mask,
# skipping files whose name starts with `que-`; Path.name keeps the filter
# independent of the platform's path separator
masks = [path.name for path in Path('./data/masks').glob('*.png')
         if not path.name.startswith('que-')]
len(masks)

In [None]:
# side length of the square view-port handed to the datasets and the trainer
VIEW_SIZE = 128

<a name="data"></a>

### Create dataset and dataloader
The input is a noisy version of a page-aligned view-port (rotated by 0, 90, 180, or 270 degrees).

In [None]:
# use images with masks
# currently every mask is used; the trailing commented-out call would instead draw
# a random subset of 640 masks without replacement
samples = masks #np.random.choice(masks, 640, replace=False)

In [None]:
class SegmentationDataset(Dataset):
    """
    Batch of page-aligned view-ports from a single document for the set of tasks:
    1. semantic segmentation
    2. value extraction and denoising
    3. orientation detection

    Args:
        source: file name shared by the image (`data/images`) and its mask (`data/masks`).
        view_size: side length of the square view-port.
        max_samples: value reported by `__len__`; every item is generated at random.
        max_skew: skew bound (degrees, presumably — TODO confirm against the renderer)
            applied on top of the page-aligned rotation.
    """
    def __init__(self, source: str, view_size: int, max_samples: int, max_skew: float = 3.):
        self.view_size = view_size
        self.max_samples = max_samples
        # kept on the instance because __getitem__ needs it for the random skew
        self.max_skew = max_skew
        # load source image as grayscale
        orig = np.array(ImageOps.grayscale(Image.open(f'{ROOT}/data/images/{source}')))
        view = make_noisy_sample(orig)
        # load segmentation mask (a matrix of class indices)
        mask = np.array(Image.open(f'{ROOT}/data/masks/{source}'))
        # define renderers for all
        self.view = render.AgentView((view).astype(np.uint8), view_size, bias=np.random.randint(100))
        # one-hot encode the mask, drop channel 0 and scale the rest to 0/255
        self.segmentation = render.AgentView((np.eye(len(ORDER))[mask][:,:,1:] > 0) * 255, view_size)
        # inverted clean image is the source of the value-task target
        self.value = render.AgentView(255 - orig, view_size)
        # define image preprocessing
        self.transform = Normalize

    def random_viewport(self):
        """
        Aligned view-port: only a small skew considered.

        Returns:
            (center, rotation, zoom): the view center scaled by a random factor in
            [0.1, 1.9), a rotation from {0, 90, 180, 270} and a zoom in (-3.5, -1].
        """
        center = (np.array(self.view.space.center) * (0.1 + np.random.rand() * 1.8)).astype(int)
        rotation = np.random.choice([0, 90, 180, 270])
        zoom = -1 - np.random.rand() * 2.5
        return center, rotation, zoom

    def __len__(self):
        """Number of random samples served per pass over the dataset."""
        return self.max_samples

    def __getitem__(self, idx):
        """
        Generate one random sample; `idx` is ignored.

        Returns:
            X, (Y1, Y2, Y3): the normalized input view, the segmentation class-index
            matrix, the binary value mask and the orientation class
            (0 for a non-document sample, otherwise rotation // 90 + 1).
        """
        if np.random.rand() < 0.2: # random non-doc image for out-of-class example balanced repr.
            X = self.transform(make_negative_sample(self.view_size).astype(np.float32)/255.)
            Y = torch.Tensor(np.zeros((self.view_size, self.view_size))).long()
            return X, (Y, Y, 0)
        # generate random viewport
        std = 0
        while std < 10: # make sure there's something to see
            center, rotation, zoom = self.random_viewport()
            view = self.view.set_state(center, rotation, zoom)
            std = np.std(view)
        # orientation task
        Y3 = rotation//90 + 1
        # add random skew
        # NOTE(review): this draws from (-max_skew, 0] only; confirm whether a
        # symmetric range was intended
        rotation += int(np.random.rand() * self.max_skew - self.max_skew)
        # render views
        X = self.transform(self.view.set_state(center, rotation, zoom).astype(np.float32)/255.)
        # initialize segmentation masks channels
        Y1 = np.zeros((self.view_size, self.view_size, len(ORDER)))
        # render masks in the same view-port
        view = self.segmentation.set_state(center, rotation, zoom)
        # fix scattered after rotation value back to binary
        view = (view/255. > 0.25).astype(int)
        # channel 0 stays empty, so argmax yields class 0 wherever no mask channel is set
        Y1[:,:,1:] = view
        # segmentation task target
        Y1 = torch.Tensor(np.argmax(Y1, axis=(2))).long()
        # value task target
        view = self.value.set_state(center, rotation, zoom)
        Y2 = torch.Tensor(view/255. >= 0.25).squeeze().long()
        return X, (Y1, Y2, Y3)
    

In [None]:
# pick one document at random and wrap it into a loader yielding a single batch
sample = np.random.choice(samples)
n = 8
orientation = ['N/A', '0', '90', '180', '270']
loader = DataLoader(
    MultitaskDataset(sample, VIEW_SIZE, max_samples=n),
    batch_size=n,
    shuffle=False,
)
# preview the batch: input view, segmentation target, value target
for X, (Y1, Y2, Y3) in loader:
    print(f'source: {sample}\nX: {X.shape}  Y1:{Y1.shape}  Y2:{Y2.shape}  Y3:{Y3.shape}')
    for i in range(n):
        fig, ax = plt.subplots(1, 3, figsize=(8, 8))
        # expand class indices back into separate channels (class 0 dropped)
        matrix = (np.eye(len(ORDER))[Y1[i,:]][:,:,1:] > 0) * 255
        ax[0].imshow(X[i,:].squeeze(), 'gray')
        # swap the first two channels so lines are easier to see in RGB
        ax[1].imshow(matrix[:,:,[1,0,2]])
        ax[2].imshow(Y2[i,:], 'gray')
        for panel in ax:
            panel.axis('off')
        if i != 0:
            ax[0].set_title(f'orientation: {orientation[Y3[i]]}', fontsize=10)
        else:
            # only the first row carries the column headers
            ax[0].set_title('Input view', fontsize=10)
            ax[1].set_title('Segmentation task', fontsize=10)
            ax[2].set_title('Value task', fontsize=10)
        plt.show()

<a name="model"></a>

## Model
Based on our [comparative experiment](Visual-Backbone-CNN.ipynb), we chose `64/4/residual` as the default CNN-based architecture.

In [None]:
# output channels and depth of the CNN encoder (the `64/4` part of the chosen setup)
CHANNELS = 64
DEPTH = 4
# (disabled) build the residual CNN backbone and restore its pretrained weights
#backbone = CNNEncoder(out_channels=CHANNELS, depth=DEPTH, residual=True).to(DEVICE)
#backbone.load_state_dict(torch.load(f'./models/visual-backbone-CNN-R-64-4.pt'))
#summary(backbone, (1, VIEW_SIZE, VIEW_SIZE))

# patch size passed to the transformer encoder/decoder variants
PATCH_SIZE = 4
# (disabled) transformer alternative of the backbone
# NOTE(review): the summary line below refers to `encoder`, not `backbone` — confirm before enabling
#backbone = TransformerEncoder(VIEW_SIZE, PATCH_SIZE, LATENT_DIM, DEPTH)
#backbone.load_state_dict(torch.load(f'./models/visual-backbone-ViT-4-4.pt'))
#summary(encoder.to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))

<a name="model"></a>



In [None]:
class ClassificationHead(nn.Sequential):
    """
    Classification head: layer normalization, ReLU, then a linear projection
    from `latent_dim` features to `num_classes` scores.
    """
    def __init__(self, latent_dim: int, num_classes: int):
        layers = (
            nn.LayerNorm(latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, num_classes),
        )
        super().__init__(*layers)

def get_cnn_encoder(num_classes: int):
    """
    Build a classifier over a CNN feature map: global average pooling to 1x1,
    flatten, `ClassificationHead`, then softmax over the class dimension.
    """
    pool = nn.AdaptiveAvgPool2d((1, 1))
    flatten = nn.Flatten(start_dim=1)
    head = ClassificationHead(LATENT_DIM, num_classes)
    return nn.Sequential(pool, flatten, head, nn.Softmax(dim=1))

def get_vit_encoder(num_classes: int):
    """
    Build a classifier for the transformer variant: `MeanReduce`, then
    `ClassificationHead`, then softmax over the class dimension.
    """
    reduce = MeanReduce()
    head = ClassificationHead(LATENT_DIM, num_classes)
    return nn.Sequential(reduce, head, nn.Softmax(dim=1))


In [None]:
# select the CNN factories; switch to the commented line for the transformer variant
# NOTE(review): get_cnn_decoder is defined in a later cell of this notebook, so this
# cell only runs top-to-bottom if the name is also provided by the imported utils — confirm
get_encoder, get_decoder = get_cnn_encoder, get_cnn_decoder
#get_encoder, get_decoder = get_vit_encoder, get_vit_decoder

In this new experiment we use the same encoder with multiple task-specific decoders — `segmentation` and `value` — plus an `orientation` detection head, and train them all as a single model.

In [None]:
def get_cnn_decoder(num_classes: int, bridge: bool = False):
    """
    Build a CNN decoder followed by a 1x1 convolution mapping `CHANNELS`
    feature maps to `num_classes`, with softmax over the channel dimension.
    `bridge` is forwarded to `CNNDecoder` as its fourth positional argument.
    """
    decoder = CNNDecoder(LATENT_DIM, DEPTH - 1, True, bridge, True)
    classifier = nn.Conv2d(CHANNELS, num_classes, kernel_size=1, stride=1)
    return nn.Sequential(decoder, classifier, nn.Softmax(dim=1))

def get_vit_decoder(num_classes: int, bridge: bool = False):
    """
    Build a transformer decoder emitting `num_classes` channels, followed by
    softmax over the channel dimension.
    """
    decoder = TransformerDecoder(
        VIEW_SIZE, PATCH_SIZE, LATENT_DIM, DEPTH - 1,
        channels=num_classes, bridge=bridge,
    )
    return nn.Sequential(decoder, nn.Softmax(dim=1))


In [None]:
class MultitaskUCNN(nn.Module):
    """
    One shared visual encoder feeding three task branches:
    a segmentation decoder, a value decoder and an orientation classifier.

    Args:
        encoder: encoder exposing `out_channels` and `depth`; its output is
            sliced and indexed like a sequence in `forward`.
        frozen: when True the encoder parameters are excluded from gradient updates.
    """
    def __init__(self, encoder: CNNEncoder, frozen: bool = False):
        super().__init__()
        self.encoder = encoder
        if frozen:
            # freeze weights of the shared encoder
            for weight in self.encoder.parameters():
                weight.requires_grad = False
        width = encoder.out_channels
        levels = encoder.depth
        # feature size the decoders and the classifier are built for
        bottleneck = width * (2 ** (levels - 1))
        # task-specific decoders
        self.segmentation = CNNDecoder(bottleneck, levels - 1, True, True, True)
        self.value = CNNDecoder(bottleneck, levels - 1, True, True, True)
        # NOTE(review): this branch ends in Softmax while the notebook trains it with
        # nn.CrossEntropyLoss, which expects unnormalized logits — confirm this is intended
        orientation_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(start_dim=1),
            nn.LayerNorm(bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 5), # classes: n/a, 0, 90, 180, 270
            nn.Softmax(dim=1),
        ]
        self.alignment = nn.Sequential(*orientation_layers)
        # task heads: 1x1 convolutions to per-pixel class probabilities
        self.segmentation_logits = nn.Sequential(
            nn.Conv2d(width, 4, 1, 1),
            nn.Softmax(dim=1),
        )
        self.value_logits = nn.Sequential(
            nn.Conv2d(width, 2, 1, 1),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return `(segmentation, value, alignment)` outputs for the input batch `x`."""
        features = self.encoder(x)
        # each decoder receives its own shallow copy of the encoder outputs
        decoded = self.segmentation(features[:])
        segmentation = self.segmentation_logits(decoded)
        decoded = self.value(features[:])
        value = self.value_logits(decoded)
        # avg pool (reduce) from the bottleneck, i.e. the last encoder output
        alignment = self.alignment(features[-1])
        return segmentation, value, alignment

#MultitaskUCNN(encoder).to(DEVICE)(X.to(DEVICE))
#summary(MultitaskUCNN(encoder).to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))

<a name="run"></a>

## Training and evaluation

In [None]:
# dataset class (not an instance) that is handed to the Trainer below
dataset = MultitaskDataset

In [None]:
# random 95% of the documents go to training, the remainder is held out for testing
train_samples = np.random.choice(samples, size=int(len(samples) * 0.95), replace=False)
test_samples = list(set(samples) - set(train_samples))
len(train_samples), len(test_samples)

<a name="1"></a>

#### 1. Define model

In [None]:
# NOTE(review): `encoder` is not created in the visible cells (the backbone
# construction above is commented out) — confirm it is provided elsewhere
# restore the pretrained encoder weights; map_location lets a checkpoint that was
# saved on a GPU be loaded when DEVICE is the CPU
encoder.load_state_dict(torch.load(f'./models/visual-encoder-CNN-R-{CHANNELS}-{DEPTH}.pt',
                                   map_location=DEVICE))
# frozen=False keeps the encoder parameters trainable together with the task branches
model = MultitaskUCNN(encoder, frozen=False).to(DEVICE)

<a name="2"></a>

#### 2. Define optimization

In [None]:
# display the current value; the name is not defined in this notebook and presumably
# comes from the star-imported utils — TODO confirm
SEGMENTATION_WEIGHT

In [None]:
# display the current value; the name is not defined in this notebook and presumably
# comes from the star-imported utils — TODO confirm
DENOISING_WEIGHT

In [None]:
# one loss per task, in model-output order: segmentation (4 classes),
# value (2 classes), orientation
criteria = [loss.to(DEVICE) for loss in (DiceLoss(4), DiceLoss(2), nn.CrossEntropyLoss())]
criterion = HydraLoss(criteria).to(DEVICE)
# optimize both: model and loss parameters
params = [*model.parameters(), *criterion.parameters()]

In [None]:
# learning rate shared by the model and the loss parameters collected in `params`
learning_rate = 1e-6
optimizer = AdamW(params, lr=learning_rate)

<a name="3"></a>

#### 3. Define evaluation metrics

In [None]:
# validation metrics per task; the class counts match the model outputs
# (4 segmentation classes, 2 value classes, 5 orientation classes)
metrics = {}
metrics['segmentation'] = {
    'f1-score': F1Score(task='multiclass', num_classes=4).to(DEVICE),
}
metrics['value'] = {
    'f1-score': F1Score(task='multiclass', num_classes=2).to(DEVICE),
}
metrics['orientation'] = {
    'confmat': ConfusionMatrix(task='multiclass', num_classes=5).to(DEVICE),
    'f1-score': F1Score(task='multiclass', num_classes=5).to(DEVICE),
}


<a name="4"></a>

#### 4. Run training

In [None]:
# batch size and epoch count passed to trainer.run for the first training round
batch_size = 16
num_epochs = 2

In [None]:
# IPython shell escape: remove the tensorboard logs of any previous run
!rm -rf ./runs/visual-multi-cnn
# multi_y=True because each sample carries a tuple of targets (one per task)
trainer = Trainer(model, dataset, VIEW_SIZE, criterion, optimizer, metrics, multi_y=True,
                  tensorboard_dir='runs/visual-multi-cnn') # log progress to tensorboard
# `results` is read below per task and metric, e.g. results['orientation']['confmat']
results = trainer.run(train_samples, test_samples, batch_size, num_epochs=num_epochs, validation_steps=4)

<a name="5"></a>

#### 5. Evaluate results

In [None]:
# plot the loss and metric histories accumulated by the trainer
plot_history(trainer.loss_history, trainer.metrics_history, multi_y=True)

In [None]:
# sum the collected orientation confusion matrices along axis 0 and plot the total,
# labelled with the `orientation` class names
plot_confmat(np.sum(np.array(results['orientation']['confmat']), axis=0),
             orientation, 'Orientation task confusion-matrix')

In [None]:
# run one more epoch with the same trainer, replacing `results`
results = trainer.run(train_samples, test_samples, batch_size, num_epochs=1, validation_steps=4)

In [None]:
# plot the loss and metric histories again after the extra epoch
plot_history(trainer.loss_history, trainer.metrics_history, multi_y=True)

In [None]:
# summed orientation confusion matrix for the latest `results`
plot_confmat(np.sum(np.array(results['orientation']['confmat']), axis=0),
             orientation, 'Orientation task confusion-matrix')

In [None]:
# one more epoch, this time with validation_steps=2 instead of 4
results = trainer.run(train_samples, test_samples, batch_size, num_epochs=1, validation_steps=2)

In [None]:
# histories and summed orientation confusion matrix after the latest run
plot_history(trainer.loss_history, trainer.metrics_history, multi_y=True)
plot_confmat(np.sum(np.array(results['orientation']['confmat']), axis=0),
             orientation, 'Orientation task confusion-matrix')

In [None]:
# let's see some examples with new variation from test-samples
# draw the document from the held-out split so the preview is not shown on training
# documents; fall back to all samples when the split is empty
eval_pool = test_samples if len(test_samples) > 0 else samples
loader = DataLoader(MultitaskDataset(np.random.choice(eval_pool), VIEW_SIZE, max_samples=8),
                    batch_size=8, shuffle=False)
model.eval()
with torch.no_grad():
    for X, (Y1, Y2, Y3) in loader:
        preds = model(X.to(DEVICE))
        # most probable class per pixel (P1, P2) and per sample (P3)
        P1 = np.argmax(preds[0].cpu().numpy(), axis=1)
        P2 = np.argmax(preds[1].cpu().numpy(), axis=1)
        P3 = np.argmax(preds[2].cpu().numpy(), axis=1)
        for i in range(X.shape[0]):
            fig, ax = plt.subplots(1, 4, figsize=(11, 11))
            # input view
            ax[0].imshow(X[i,:].squeeze().numpy(), 'gray')
            ax[0].axis('off')

            # segmentation target color-channels
            matrix = (np.eye(len(ORDER))[Y1[i,:]][:,:,1:] > 0) * 255
            ax[1].imshow(matrix[:,:,[1,0,2]])
            ax[1].axis('off')

            # segmentation task output
            matrix = (np.eye(len(ORDER))[P1[i,:]][:,:,1:] > 0) * 255
            ax[2].imshow(matrix[:,:,[1,0,2]])
            ax[2].axis('off')

            # value task output
            ax[3].imshow(P2[i,:], 'gray')
            ax[3].axis('off')

            # first column title carries the target orientation, last one the predicted
            if i == 0:
                ax[0].set_title(f'Input view {orientation[Y3[i]]}', fontsize=10)
                ax[1].set_title('Segmentation target', fontsize=10)
                ax[2].set_title('Segmentation output', fontsize=10)
                ax[3].set_title(f'Value output {orientation[P3[i]]}', fontsize=10)
            else:
                ax[0].set_title(orientation[Y3[i]], fontsize=10)
                ax[3].set_title(orientation[P3[i]], fontsize=10)
            plt.show()


#### 6. Save progress

    # checkpoint with the hyper-parameters plus model and optimizer state
    # NOTE(review): the file name contains a space after 'CNN-', unlike the encoder
    # file names used in this notebook — confirm it is intentional
    # NOTE(review): 'epoch' stores num_epochs, while trainer.run is called several
    # times above — confirm this is the value that should be recorded
    torch.save({'epoch': num_epochs,
                'batch_size': batch_size,
                'learning_rate': learning_rate,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               f'./models/visual-multi-CNN- R-{CHANNELS}-{DEPTH}.pt')


    # save encoder model
    # written with an `-S` suffix, so the pretrained encoder file loaded earlier is not overwritten
    torch.save(encoder.state_dict(), f'./models/visual-encoder-CNN-R-{CHANNELS}-{DEPTH}-S.pt')