# Visual: rotation
This notebook explores the `rotation` part of the context. We take our [prospective pretrained visual encoders](#encoders) and examine how much rotation information they already capture:
let's check both the [classification and regression](#model) scenarios.

In [None]:
import os
import re
import torch
import numpy as np
import pandas as pd

from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['xtick.labelsize'] = 7
rcParams['ytick.labelsize'] = 7

from pathlib import Path
from typing import Callable
from einops import rearrange, repeat

from torch import nn, Tensor
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.cuda.amp import GradScaler
from torchmetrics import F1Score, ConfusionMatrix, R2Score
from torchsummary import summary

In [None]:
# local lib
from scripts import simulate as sim
from scripts import parse, render
from scripts.backbone import *
from scripts.trainer import *
from scripts.dataset import NormalizeView

In [None]:
# optional debugging switch for torch.compile / dynamo (disabled)
#torch._dynamo.config.verbose = True
# release cached GPU memory left over from earlier runs of this notebook
torch.cuda.empty_cache()
# NOTE(review): DEVICE is not defined in this notebook; presumably it is exported by one of
# the star imports from scripts.backbone / scripts.trainer -- confirm
print('GPU' if DEVICE == 'cuda' else 'no GPU')

In [None]:
# semantic segmentation masks
# Collect the bare file names of every PNG mask, skipping those prefixed with "que-".
# Path.name is used instead of splitting str(path) on '/', so the result does not
# depend on the OS path separator or on how the directory prefix is spelled.
images = [path.name for path in Path('./data/masks').glob('*.png')
          if not path.name.startswith('que-')]
len(images)

In [None]:
# size of the rendered agent view; passed to RotationDataset as `view_size` and used as
# the height and width of the single-channel model input
VIEW_SIZE = 128
# first argument given to every Head; presumably the width of the backbone embedding -- TODO confirm
LATENT_DIM = 512

In [None]:
# use every available image; the trailing comment keeps an optional random 160-image subsample
samples = images #np.random.choice(images, 160, replace=False)

In [None]:
class RotationDataset(Dataset):
    """Random rotated views rendered from a single source image.

    Every item is drawn at random (the requested index is ignored), so the
    dataset length is simply the number of samples to produce per pass.

    Args:
        source: file name of the image inside ``data/images``.
        view_size: view size handed to ``render.AgentView``.
        max_samples: value reported by ``len()``.
    """

    def __init__(self, source: str, view_size: int, max_samples: int):
        image = Image.open(f'data/images/{source}')
        # grayscale pixel values, inverted (255 - value)
        inverted = 255 - np.array(ImageOps.grayscale(image))
        self.num_steps = max_samples
        self.nav = render.AgentView(inverted, view_size)
        self.transform = NormalizeView()

    def __len__(self):
        return self.num_steps

    def __getitem__(self, index):
        """Return ``(X, (Y1, Y2))``: the transformed view, the integer angle in
        [0, 359] and the same angle as a float32 tensor of shape (1,) in (-1, 1]."""
        while True:
            rotation = np.random.randint(0, 360)
            # scale each coordinate of the space center by a random factor in [0.25, 0.75)
            jitter = 0.25 + np.random.rand(2) * 0.5
            center = (np.array(self.nav.space.center) * jitter).astype(int)
            zoom = -1. - np.random.rand() * 2
            observation = self.nav.render(center, rotation, zoom)
            # make sure there's something to see: redraw while the view is nearly flat
            if not np.std(observation) < 10:
                break
        X = self.transform(observation)
        # classification target: integer angle
        Y1 = rotation
        # regression target: angles above 180 are mapped to the negative half
        if rotation > 180:
            signed = -(360. - rotation) / 180.
        else:
            signed = float(rotation) / 180.
        Y2 = torch.tensor([signed], dtype=torch.float32)
        return X, (Y1, Y2)
    

In [None]:
# pick one image at random and preview a single batch of rendered views
sample = np.random.choice(images)
print(sample)
# test loader: the dataset length equals the batch size, so exactly one batch is produced
batch_size = 4
loader = DataLoader(dataset=RotationDataset(sample, VIEW_SIZE, batch_size), batch_size=batch_size)
# show first batch, one small figure per view, titled with both targets
for X, (Y1, Y2) in loader:
    for i in range(batch_size):
        fig, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(X[i].squeeze(), cmap='gray')
        ax.axis('off')
        ax.set_title(f'rotation: [ {Y1[i].squeeze()} ] [ {Y2[i].squeeze():.2f} ]', fontsize=10)
        plt.show()


<a name="encoders"></a>

### Backbones to compare
For this experiment we use pretrained [CNN](Visual-Backbone-CNN.ipynb) and [ViT](Visual-Backbone-ViT.ipynb) backbones.

In [None]:
# pretrained CNN backbone, requested unfrozen (frozen=False) so it can be trained further
cnn_encoder = get_cnn_backbone(pretrained=True, frozen=False)

In [None]:
# pretrained ViT backbone, requested unfrozen (frozen=False) so it can be trained further
vit_encoder = get_vit_backbone(pretrained=True, frozen=False)

In [None]:
# backbones under comparison, keyed by a short tag
encoders = { 'CNN':cnn_encoder, 'ViT':vit_encoder }
# tag order fixes the order of models, losses, optimizers and plots built from it
tags = list(encoders.keys())

<a name="model"></a>

### Model
The model takes a pretrained encoder and attaches two MLP heads: a 360-class classifier (one class per degree) and a `[-1, 1]` regressor. If classification works, that's all we need. However, 360 classes is a lot. Regression may be even more challenging because the fixed interval has singular edges (the target jumps from `1` to about `-1` as the angle passes 180°), and it is less useful for us anyway. Still, letting the encoders keep learning (weights not frozen) on both tasks at once may improve the quality of the embeddings down the line.

In [None]:
def get_cnn_head(output_dim: int):
    """Build the head used on CNN feature maps.

    The feature map is average-pooled to 1x1, flattened from dimension 1 on,
    and passed to ``Head(LATENT_DIM, output_dim)``.
    """
    pool = nn.AdaptiveAvgPool2d((1, 1))
    flatten = nn.Flatten(start_dim=1)
    head = Head(LATENT_DIM, output_dim)
    return nn.Sequential(pool, flatten, head)

def get_vit_head(output_dim: int):
    """Build the head used on ViT outputs.

    The backbone output goes through ``MeanReduce()`` and then
    ``Head(LATENT_DIM, output_dim)``.
    """
    reduce = MeanReduce()
    head = Head(LATENT_DIM, output_dim)
    return nn.Sequential(reduce, head)
            
# head factories keyed by the same tags as `encoders`
heads = { 'CNN':get_cnn_head, 'ViT':get_vit_head }

In [None]:
class RotationEstimator(nn.Module):
    """Backbone with two heads: a 360-way angle classifier and a scalar regressor.

    Args:
        backbone: encoder whose output is indexed with ``[-1]`` to get the embedding.
        frozen: when True, gradients are disabled for every backbone parameter.
        get_head: factory called as ``get_head(output_dim)`` to build each head.
    """

    def __init__(self, backbone: nn.Module, frozen: bool, get_head: Callable):
        super().__init__()
        self.backbone = backbone
        if frozen:
            # freeze weights: switches requires_grad off on all backbone parameters
            self.backbone.requires_grad_(False)
        # raw logits, no softmax: cross-entropy is applied on top of them
        self.classifier = get_head(360)
        self.regressor = get_head(1)

    def forward(self, x):
        """Return ``(cls, reg)``: classifier and regressor outputs for ``x``."""
        features = self.backbone(x)
        # only the last backbone output is used as the embedding
        embedding = features[-1]
        return self.classifier(embedding), self.regressor(embedding)

# frozen-backbone variant (disabled): see how much rotation information we've already got there
#model = RotationEstimator(cnn_encoder, True, get_cnn_head)

# trainable-backbone variant: see how much we can get there.
# Note that this wraps the same `cnn_encoder` object that is stored in `encoders['CNN']`.
model = RotationEstimator(cnn_encoder, False, get_cnn_head)

# layer summary for a single-channel VIEW_SIZE x VIEW_SIZE input
summary(model.to(DEVICE), (1, VIEW_SIZE, VIEW_SIZE))

In [None]:
# smoke test: X is the batch left over from the loader preview cell;
# print the shapes of the classifier and regressor outputs
for output in model(X.to(DEVICE)):
    print(output.shape)

<a name="training"></a>

### Comparative training and evaluation

In [None]:
# random 95% of the images go to training, drawn without repetition
train_samples = np.random.choice(samples, size=int(len(samples) * 0.95), replace=False)
# everything that was not drawn is held out for testing
test_samples = list(set(samples) - set(train_samples))
len(train_samples), len(test_samples)

In [None]:
# batch size for training; replaces the value of 4 used by the loader preview
batch_size = 16

In [None]:
# the dataset class itself (not an instance) is what gets handed to the trainer
dataset = RotationDataset

In [None]:
# one estimator per backbone, in `tags` order; backbones are not frozen and are the very
# encoder objects stored in `encoders`, so training updates those encoders in place
models = [RotationEstimator(encoders[tag], False, heads[tag]).to(DEVICE) for tag in tags]

<a name="hydra"></a>
We can define our combined loss criterion as a weighted sum of the task losses. However, the dynamics of the individual task losses may not stay aligned during training, which makes static task weights a suboptimal solution.

    class CombinedLoss(nn.Module):
        def __init__(self, criteria: list, weights: list):
            assert len(criteria) == len(weights)
            super(CombinedLoss, self).__init__()
            self.criteria = criteria
            self.weights = weights        

        def forward(self, preds, targets):
            losses = []
            for i, criterion in enumerate(self.criteria):
                losses.append(criterion(preds[i], targets[i]) * self.weights[i])
            return torch.sum(torch.stack(losses))

    criterion = CombinedLoss(criteria, [1., 1., 10., 10.]).to(device)
    # optimized model parameters only
    params = model.parameters()

Instead, we make the task weights trainable parameters and learn them along with the model.

In [None]:
    class HydraLoss(nn.Module):
        """
        Construct combined loss with trainable weights:
        https://arxiv.org/abs/1705.07115

        Task ``i`` contributes ``exp(-s_i) * loss_i + s_i``, where ``s_i`` is a
        learnable scalar initialised to zero; the initial value is therefore the
        plain sum of the task losses.

        Args:
            criteria: one loss callable per task, each invoked as
                ``criterion(pred, target)``.
        """
        def __init__(self, criteria: list):
            super().__init__()
            # Register criterion modules as sub-modules so that .to()/.double()/state_dict()
            # also reach their buffers (e.g. class weights of a weighted CrossEntropyLoss).
            # A plain Python list would hide them from nn.Module. Non-module callables
            # cannot be registered, so such a list is kept as it is.
            if all(isinstance(criterion, nn.Module) for criterion in criteria):
                self.criteria = nn.ModuleList(criteria)
            else:
                self.criteria = criteria
            self.log_vars = nn.Parameter(torch.zeros((len(criteria))))

        def forward(self, preds, targets):
            """Return the scalar sum of the weighted task losses.

            ``preds[i]`` and ``targets[i]`` are passed to the i-th criterion.
            """
            losses = []
            for i, criterion in enumerate(self.criteria):
                loss = criterion(preds[i], targets[i])
                # exp(-s) scales the task loss; the additive s keeps the weight from collapsing to zero
                losses.append(torch.exp(-self.log_vars[i]) * loss + self.log_vars[i])
            return torch.sum(torch.stack(losses))


In [None]:
learning_rate = 5e-6

# every model gets its own HydraLoss instance, hence its own trainable task weights
criteria = [HydraLoss([nn.CrossEntropyLoss(), nn.MSELoss()]).to(DEVICE) for _ in models]
# each optimizer updates one model together with the task weights of its loss
optimizers = [AdamW(list(model.parameters()) + list(criterion.parameters()), lr=learning_rate)
              for model, criterion in zip(models, criteria)]

In [None]:
# a separate metric set per model: F1 and confusion matrix over the 360 angle classes
# for the classifier, R2 for the regressor
metrics = [
    {
        'classification': {
            'f1-score': F1Score(task='multiclass', num_classes=360).to(DEVICE),
            'confmat': ConfusionMatrix(task='multiclass', num_classes=360).to(DEVICE),
        },
        'regression': {
            'r2-score': R2Score().to(DEVICE),
        },
    }
    for _ in models
]

In [None]:
# NOTE(review): MultiTrainer is not defined in this notebook (presumably from scripts.trainer);
# multi_y=True presumably signals that each target is a (class, regression) pair -- confirm
trainer = MultiTrainer(dataset, models, VIEW_SIZE, criteria, optimizers, metrics, tags=tags, multi_y=True)

Let's run a few epochs with a full visual in between to see how it goes.

In [None]:
num_epochs = 6
validation_steps = 2
# k: epochs per training round; offset: epochs already completed
k, offset = num_epochs//validation_steps, 0
# NOTE(review): run `validation_steps` rounds of k epochs so the total equals num_epochs.
# Looping over range(k) instead would give k * k epochs (9 for the values above). This
# assumes the 4th positional argument of trainer.run is the number of epochs per call,
# which is what the `offset += k` bookkeeping implies -- confirm against scripts.trainer.
for _ in range(validation_steps):
    # run training
    results = trainer.run(train_samples, test_samples, batch_size, k, 1, offset=offset)
    offset += k
    # show loss and validation history
    trainer.plot_compare()

    # get regression predictions by all models for the same freshly rendered data
    preds, targets = [[] for _ in models], []
    for model in models:
        model.eval()
    torch.cuda.empty_cache()
    with torch.no_grad():
        for source in test_samples:
            loader = DataLoader(RotationDataset(source, VIEW_SIZE, batch_size//2), batch_size//2)
            for X, (Y1, Y2) in loader:
                # reshape(-1) keeps a 1-D array even for a batch of one sample
                targets.append(Y2.reshape(-1).numpy())
                inputs = X.to(DEVICE)
                for i, model in enumerate(models):
                    # index 1 selects the regressor output of the (cls, reg) pair
                    preds[i].append(model(inputs)[1].reshape(-1).cpu().numpy())
    y_true = np.concatenate(targets) if targets else np.empty(0)

    ticks = list(range(30, 361, 30))
    # one column per model: confusion matrix on top, residual plot below
    fig, ax = plt.subplots(2, len(tags), figsize=(4 * len(tags), 8), squeeze=False)
    for i, tag in enumerate(tags):
        # show classifier confusion, scaled by its largest cell
        matrix = trainer.metrics_history[i]['classification']['confmat']
        peak = np.max(matrix)
        # an all-zero matrix is drawn unscaled instead of dividing by zero
        ax[0][i].imshow(matrix/(peak if peak > 0 else 1), cmap='coolwarm')
        ax[0][i].set_xticks(ticks)
        ax[0][i].set_yticks(ticks)
        ax[0][i].set_title(f'{tag.upper()} confusion matrix', fontsize=10)
        ax[0][i].set_xlabel('Predicted')
        ax[0][i].set_ylabel('Actual')

        # show regressor residuals (target minus prediction) against the target
        y_pred = np.concatenate(preds[i]) if preds[i] else np.empty(0)
        ax[1][i].scatter(y_true, y_true - y_pred, s=3, alpha=0.2)
        ax[1][i].axhline(y=0, color='C3')
        ax[1][i].set_title(f'{tag.upper()} residual plot')
        ax[1][i].set_xlabel('Target')
        ax[1][i].set_ylabel('Error')
    fig.tight_layout()
    plt.show()

In [None]:
# plot the history collected by the trainer over all rounds (see scripts.trainer for the exact content)
trainer.plot_history()

In [None]:
# final side-by-side view of loss and validation history for the compared models
trainer.plot_compare()

    # torch.save does not create missing parent directories, so make sure the target exists
    Path('./models').mkdir(parents=True, exist_ok=True)
    for tag, model in zip(tags, models):
        # full estimator: backbone plus both heads
        torch.save(model.state_dict(), f'./models/visual-rotation-{tag}.pt')
        # the backbone alone; it is the same module the estimator trained, stored separately for reuse
        torch.save(encoders[tag].state_dict(), f'./models/visual-backbone-{tag}.pt')