# Semantic alignment


In [None]:
import os
import re
import torch
import numpy as np
import pandas as pd

from time import time
from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import colormaps
from pathlib import Path
from einops import rearrange
from typing import Callable

from torch import nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torchmetrics import F1Score, ConfusionMatrix
from torchsummary import summary

from transformers import AutoTokenizer, AutoModel

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, Filter, FieldCondition, MatchValue

In [None]:
# load local notebook-utils
from scripts.backbone import *
from scripts.dataset import *
from scripts.trainer import *

In [None]:
# release cached GPU memory left over from previous runs of the notebook
torch.cuda.empty_cache()
# DEVICE is not defined in this notebook; presumably it comes from the scripts.* star-imports
print('GPU' if DEVICE == 'cuda' else 'no GPU')

In [None]:
# semantic segmentation masks
# file names of the semantic segmentation masks; masks prefixed with `que-` are skipped
# (the file name is taken from the path object itself, so the listing does not
#  depend on the platform path separator)
samples = [x.name for x in Path('./data/masks').glob('*.png')
           if not x.name.startswith('que-')]
len(samples)

In [None]:
# view size handed to the dataset for every rendered view-port
VIEW_SIZE = 128

## Define semantic space
The presence of text in the view is the main indicator; the presence of straight lines is secondary, and so on.

In [None]:
# captions for the out-of-class (non-document) examples
NONDOC = [
    'not a document',
    'nothing like document',
    'does not look like a document',
    'no document in the view',
]

# page corners in the view, in the order (top-left, top-right, bottom-right, bottom-left);
# the value is the caption fragment describing what part of the page is visible
SCOPE = {
    (1, 1, 1, 1):'full-page view',
    (1, 1, 1, 0):'top-right partial view of a page',
    (1, 1, 0, 1):'top-left partial view of a page',
    (0, 1, 1, 1):'bottom-right partial view of a page', # fixed misspelled caption
    (1, 0, 1, 1):'bottom-left partial view of a page',
    (1, 1, 0, 0):'top part of a page',
    (0, 0, 1, 1):'bottom part of a page',
    (0, 1, 1, 0):'right side of a page',
    (1, 0, 0, 1):'left side of a page',
    (1, 0, 0, 0):'top-left corner of a page',
    (0, 1, 0, 0):'top-right corner of a page',
    (0, 0, 1, 0):'bottom-right corner of a page',
    (0, 0, 0, 1):'bottom-left corner of a page',
    (0, 0, 0, 0):'page fragment',
}

# captions for the right-angle rotations (degrees)
ORIENTATION = {
    0:'straight',
    90:'turned on the left side',
    180:'turned upside-down',
    270:'turned on the right side',
}

# all view-port descriptors: corner tuples first, then right-angle rotations
KEYS = list(SCOPE.keys()) + list(ORIENTATION.keys())

## Encoders

In [None]:
# pretrained, frozen CNN backbone used as the visual encoder
# (`get_cnn_encoder` is not defined in this notebook; presumably it comes from scripts.backbone)
visual_encoder = get_cnn_encoder(pretrained=True, frozen=True)
# ViT alternative
#visual_encoder = get_vit_encoder(pretrained=True, frozen=True)

In [None]:
class SemanticEncoder(nn.Module):
    """
    Frozen sentence encoder built from a pretrained tokenizer/model pair.

    A sentence embedding is the attention-masked mean of the token embeddings,
    L2-normalized.

    model_id: name or path of the pretrained model
    max_seq_length: longest token sequence kept by the tokenizer (longer inputs are truncated)
    """
    def __init__(self, model_id: str, max_seq_length: int = 128):
        super(SemanticEncoder, self).__init__()
        # `model_max_length` is the tokenizer option honoured by `truncation=True`;
        # `max_seq_length` is not a tokenizer option and would not limit the length
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, model_max_length=max_seq_length)
        self.model = AutoModel.from_pretrained(model_id)
        # freeze params
        for param in self.model.parameters():
            param.requires_grad = False
        # position of the cls token (only needed by the cls-token alternative, see `encode`)
        self.target_idx = 0

    def mean_pool(self, text):
        """
        text: a string or a list of strings
        returns: normalized mean-pooled embeddings, one row per input string
        """
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        # keep the inputs on the same device as the model
        encoded_input = encoded_input.to(next(self.model.parameters()).device)
        # the model is frozen: no need to track gradients
        with torch.no_grad():
            output = self.model(**encoded_input)
        token_embeddings = output[0] # all token embeddings
        attention_mask = encoded_input['attention_mask']
        # exclude the padding positions from the average
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # guard against division by zero for an all-padding row
        norm = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return nn.functional.normalize(torch.sum(token_embeddings * input_mask_expanded, 1) / norm)

    def encode(self, text):
        """
        Sentence embedding for `text` (mean pooling).

        Alternative not used here: take the hidden state of the cls token,
        i.e. `last_hidden_state[:, self.target_idx, :]` of the model output.
        """
        return self.mean_pool(text)

    

In [None]:
# pretrained sentence-encoder checkpoint (multilingual alternative is kept commented out)
#LLMID = 'sentence-transformers/distiluse-base-multilingual-cased-v1'
LLMID = 'sentence-transformers/paraphrase-distilroberta-base-v1'
# the encoder stays on the default device (moving it to DEVICE is disabled)
semantic_encoder = SemanticEncoder(LLMID, max_seq_length=64) #.to(DEVICE)

# smoke test: encode two captions and check the shape of the result
captions = ['view description', 'another view description']
embeddings = semantic_encoder.encode(captions)
print(embeddings.shape)

## Dataset and Dataloader

In [None]:
# class names per classification task; the position in the list is the class index
labels = {
    'doc': [
        'non-doc',
        'doc',
        '?',
    ],
    'text': [
        'no text',
        'text',
        '?',
    ],
    'rotation': [
        'straight',
        '90',
        '180',
        '270',
        'rotated',
        'n/a',
    ],
    'zoom': [
        'unreadable',
        'readable',
        '?',
    ],
}

In [None]:
class SemanticAlignDataset(torch.utils.data.Dataset):
    """
    use a single document for a batch of random view-ports
    X: view-image and view-description text embedding
    Y: sample index (the caption in debug mode) used as the alignment target,
       followed by the task labels (doc, text, rotation, zoom) repeated twice:
       once for the visual and once for the semantic classifier
    """
    # number of view-ports tried before a low-contrast view is accepted
    # (keeps a blank page from blocking the loader forever)
    MAX_RENDER_ATTEMPTS = 20
    # arbitrary rotations: every whole degree except the right angles
    FREE_ANGLES = [a for a in range(1, 360) if a not in (90, 180, 270)]

    def __init__(self, source: str, view_size: int, max_samples: int,
                       encode: Callable = None,
                       debug: bool = False):
        """
        source: file name of the document image (the mask has the same name)
        view_size: size of the rendered view-port
        max_samples: number of items in the dataset (the last key is the non-doc example)
        encode: caption encoder; defaults to `semantic_encoder.encode`
        debug: return the caption instead of the sample index as the first target
        """
        # resolve the default encoder at call time, so the class does not keep
        # a stale reference when `semantic_encoder` is re-created
        self.encode = semantic_encoder.encode if encode is None else encode
        self.view_size = view_size
        self.max_samples = max_samples
        self.debug = debug
        # better representation of target-resolution and rotation:
        # all known keys first, padded with plain page fragments
        keys = KEYS + [(0, 0, 0, 0)] * (max_samples - len(KEYS) - 1)
        # the last slot is reserved for the out-of-class example
        self.keys = keys[:max_samples - 1] + ['']
        self.order = np.random.choice(range(max_samples), max_samples, replace=False)
        # pool of distinct arbitrary angles (with repetitions only if the pool is too small)
        self.angle = list(np.random.choice(self.FREE_ANGLES, max_samples,
                                           replace=max_samples > len(self.FREE_ANGLES)))
        # load source image
        orig = np.array(ImageOps.grayscale(Image.open(f'{ROOT}/data/images/{source}')))
        view = make_noisy_sample(orig) if np.random.rand() > 0.5 else 255 - orig
        # load segmentation mask
        mask = np.array(Image.open(f'{ROOT}/data/masks/{source}'))
        # define renderers for all
        self.view = render.AgentView((view).astype(np.uint8), view_size, bias=np.random.randint(100))
        # one-hot encode the mask and drop the first channel
        self.segmentation = render.AgentView((np.eye(len(ORDER))[mask][:,:,1:] > 0) * 255, view_size)
        # define image preprocesing
        self.transform = Normalize

    def __len__(self):
        return self.max_samples

    def __getitem__(self, idx):
        key = self.keys[self.order[idx]]
        if key == '':
            # random non-doc image for out-of-class example
            X1  = self.transform(make_negative_sample(self.view_size).astype(np.float32)/255.)
            caption = np.random.choice(NONDOC)
            X2 = self.encode(caption).squeeze()
            Y1 = caption if self.debug else idx
            # labels: non-doc, no text, rotation n/a, zoom unknown
            return (X1, X2), tuple([Y1] + [0, 0, 5, 2] * 2)
        # render the view-port which fits the key, so the caption built from
        # the key describes what is actually in the view
        center, rotation, zoom = self.random_viewport(key)
        observation = self.view.render(center, rotation, zoom)
        attempts = 1
        # make sure there's something to see: draw another view-port for the same key
        while np.std(observation) < 10 and attempts < self.MAX_RENDER_ATTEMPTS:
            center, rotation, zoom = self.random_viewport(key)
            observation = self.view.render(center, rotation, zoom)
            attempts += 1
        # render views
        X1 = self.transform(observation.astype(np.float32)/255.)
        # render masks in the same view-port
        view = self.segmentation.render(center, rotation, zoom)
        # fix scattered after rotation value back to binary
        view = (view/255. > 0.25).astype(int)
        # percent by channel
        # NOTE(review): 164 ~ 128 * 128 / 100, i.e. it looks tied to a 128px view -- confirm
        scores = np.sum(view, axis=(0, 1))/164.
        lines, inputs, text = scores
        # alignment task target
        caption, info = self.generate_description(key, center, rotation, zoom, scores)
        X2 = self.encode(caption).squeeze()
        Y1 = caption if self.debug else idx # f'{caption}\n{info}'
        # doc: unknown when there is barely any text or lines
        Y2 = 2 if (text < 0.25 and lines < 0.25) or caption == 'hard to identify' else 1
        # text: none / present / unknown
        Y3 = 0 if text == 0 else 1 if text > 5 else 2
        # rotation: n/a without text, otherwise right-angle class or `rotated`
        Y4 = 5 if text < 1 else {0:0, 90:1, 180:2, 270:3}.get(rotation, 4)
        Y5 = self.parse_zoom(text, zoom)
        return (X1, X2), tuple([Y1] + [Y2, Y3, Y4, Y5] * 2)

    def _next_angle(self):
        """
        next pre-drawn arbitrary angle; a fresh random one when the pool is exhausted
        """
        if self.angle:
            return self.angle.pop()
        return np.random.choice(self.FREE_ANGLES)

    def random_viewport(self, key):
        """
        generate a random viewport which fits description `key`
        returns: center (int array), rotation in degrees within [0, 360), zoom
        """
        center = np.array(self.view.space.center).astype(float)
        if key in SCOPE.keys():
            n = sum(key) # num corners in the view
            zoom = [0., -2.75, -3., -3.5, -4.][n] + np.random.rand() * 0.25
            if n == 0: # small fragment
                shift = np.random.rand(2) - 0.5
                rotation = self._next_angle()
            elif n == 4: # full-page
                # right angle with a small deviation, kept within [0, 360)
                rotation = (np.random.choice([0, 90, 180, 270]) + np.random.choice([-2, -1, 0, 1, 2])) % 360
                shift = (np.random.rand(2) - 0.5) * 0.1
            else: # corners in the view
                f = 0.75 / n
                # shift towards the visible corners
                d = np.array([[-1., -1.],[-1., 1.],[1., 1.],[1., -1.]])[np.array(key) > 0] * f
                d += (np.random.rand(*d.shape) * 0.05 - 0.05)
                shift = np.sum(d, axis=0)
                rotation = self._next_angle()
            center *= (1. + shift)
        else:
            zoom = 0.5 - np.random.rand() * 2.5
            center *= (0.4 + np.random.rand() * 1.2)
            rotation = key if key in [0, 90, 180, 270] else self._next_angle()
        return center.astype(int), rotation, zoom

    def generate_description(self, key, center, rotation, zoom, scores):
        """
        text caption for the view using heuristics based on the dataset stats
        returns: (caption, info), where info is a debug string with the scores
        """
        lines, inputs, text = scores
        info = f'lines: {lines:.2f}   inputs: {inputs:.2f}   text: {text:.2f}   zoom: {zoom:.2f}'
        if lines == 0 and text == 0:
            return 'no text and no lines in the view', info
        orientation = self.parse_align(text, rotation, zoom)
        scope = self.parse_scope(key, zoom)
        content = self.parce_content(text, lines, inputs, zoom)
        if content == '':
            return 'hard to identify', info
        return f'{scope} with {content} {orientation}', info

    def parse_align(self, text, rotation, zoom):
        """
        rotation assessement based on text; if there is no (or barely any) text
        only `straight` or `rotated` can be told
        """
        # negative or over-the-circle angles describe the same rotation
        rotation = rotation % 360
        if text == 0 or (text < 1 and zoom > -1):
            return 'straight' if rotation in [0, 90, 180, 270] else 'rotated'
        if rotation in [0, 90, 180, 270]:
            return ORIENTATION[rotation]
        return f'rotated {rotation:.0f} degrees counterclockwise' if rotation < 180 else \
               f'rotated {360 - rotation:.0f} degrees clockwise'

    def parse_scope(self, key, zoom):
        """
        caption fragment for the visible part of the page
        """
        if key in SCOPE:
            return SCOPE[key]
        if zoom > 0:
            return 'small fragment of a page'
        return 'page fragment'

    def parce_content(self, text, lines, inputs, zoom):
        """
        caption fragment for the visible content; empty string if nothing stands out
        (`inputs` and `zoom` are accepted but not used)
        """
        content = []
        if text > 0.1:
            content.append('text')
        if lines > 0.1:
            content.append('lines')
        return ' and '.join(content)

    def parse_zoom(self, text, zoom):
        """
        zoom label: 2 unknown (not enough text), 1 readable, 0 not readable
        """
        if text < 1: return 2     # unknown
        # word (> 0.5), text (> -0.5) and block (> -1.5) levels are all readable
        if zoom > -1.5: return 1
        return 0                  # page (not-readable)

    

In [None]:
# pick a random document for the preview
sample = np.random.choice(samples)
# test loader
batch_size = 8
# debug mode: the first target is the caption text instead of the sample index
loader = DataLoader(SemanticAlignDataset(sample, VIEW_SIZE, batch_size, debug=True), batch_size=batch_size)
for X, Y in loader:
    for i in range(batch_size):
        # show every view with its caption and task labels
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(X[0][i,:].squeeze(), 'gray')
        ax.axis('off')
        # Y[1:] holds the task labels in the order of `labels`
        text = '  '.join([labels[task][Y[k + 1][i]] for k, task in enumerate(labels)])
        ax.set_title(f'{Y[0][i]}\nlabels:  {text}', ha='left', x=0, fontsize=10)
        plt.show()

#### Dataset class-weights estimation

In [None]:
batch_size = 16 # intended batch-size

# collect the task labels over 100 random documents (one batch per document)
# NOTE: drawing without replacement requires at least 100 entries in `samples`
stats = {x:[] for x in labels}
timer = 0 # check how much time data-synth takes
for source in np.random.choice(samples, 100, replace=False):
    start = time()
    loader = DataLoader(SemanticAlignDataset(source, VIEW_SIZE, batch_size, debug=False), batch_size=batch_size)
    for X, Y in loader:
        # Y[0] is the alignment target, task labels start from Y[1]
        for i, label in enumerate(labels.keys(), 1):
            stats[label] += Y[i].tolist()
    # running average over the 100 documents
    timer += (time() - start)/100

print(f'Average time to make a batch: {timer:.0f} sec')
# one column per task, one row per generated sample
stats = pd.DataFrame.from_dict(stats, orient='columns')

In [None]:
# check batch structure
# check batch structure
loader = DataLoader(SemanticAlignDataset(sample, VIEW_SIZE, batch_size), batch_size=batch_size)
for X, Y in loader:
    # inputs: view images and caption embeddings
    print(X[0].shape, X[1].shape)
    # targets: alignment target followed by the four task labels
    print(Y[0].shape, Y[1].shape, Y[2].shape, Y[3].shape, Y[4].shape)
    # a single batch is enough
    break

## Model

In [None]:
# size of the latent space shared by the visual and semantic projections
LATENT_DIM = 128

In [None]:
class Projection(nn.Module):
    """
    Map an embedding into the latent space: a linear projection followed by
    a residual GELU + linear block and layer normalization.
    """
    def __init__(self, embedding_dim: int, latent_dim: int):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, latent_dim)
        self.mlp = nn.Sequential(nn.GELU(), nn.Linear(latent_dim, latent_dim))
        self.norm = nn.LayerNorm(latent_dim)

    def forward(self, x):
        latent = self.projection(x)
        # residual connection around the MLP, normalized afterwards
        return self.norm(self.mlp(latent) + latent)



In [None]:
# projection of the 768-dimensional text embeddings into the latent space
semantic_projection = Projection(768, LATENT_DIM).to(DEVICE)

In [None]:
class VisualProjection(nn.Sequential):
    """
    Visual encoder followed by a `Projection` head.
    `dropout` is accepted but not used.
    """
    def __init__(self, encoder: nn.Module, embedding_dim: int, latent_dim: int, dropout: float = 0.):
        head = Projection(embedding_dim, latent_dim)
        super().__init__(encoder, head)

    

In [None]:
# visual branch: encoder with 512-dimensional output projected into the latent space
visual_projection = VisualProjection(visual_encoder, 512, LATENT_DIM).to(DEVICE)

In [None]:
class MultitaskClassifier(nn.Module):
    """
    One classification head per task on top of a shared latent vector.
    tasks: number of classes for every task
    """
    def __init__(self, latent_dim: int, tasks: list):
        super().__init__()
        heads = []
        for num_classes in tasks:
            heads.append(Head(latent_dim, num_classes))
        self.tasks = nn.ModuleList(heads)

    def forward(self, x):
        # outputs follow the order of the tasks
        outputs = []
        for head in self.tasks:
            outputs.append(head(x))
        return outputs

        

In [None]:
# number of classes per task, in the order of `labels`
task_load = [len(labels[x]) for x in labels]
classifier = MultitaskClassifier(LATENT_DIM, task_load).to(DEVICE)
# smoke test on the view images of the last loaded batch (`X` comes from the loader cell above)
for out in classifier(visual_projection(X[0].to(DEVICE))): print(out.shape)

In [None]:
class AlignMultitaskModel(nn.Module):
    """
    Align visual and semantic projections and classify both of them.

    forward returns a list: the pairwise similarity logits first, then the
    outputs of the visual classifier, then the outputs of the semantic classifier.
    """
    def __init__(self, visual_projection: nn.Module, semantic_projection: nn.Module,
                       visual_classifier: nn.Module, semantic_classifier: nn.Module):
        super(AlignMultitaskModel, self).__init__()
        self.visual_projection = visual_projection
        self.semantic_projection = semantic_projection
        self.visual_classifier = visual_classifier
        self.semantic_classifier = semantic_classifier

    def forward(self, vx, sx):
        """
        vx: input of the visual projection
        sx: input of the semantic projection
        """
        # calculate vectors
        vx = self.visual_projection(vx)
        sx = self.semantic_projection(sx)
        # calculate projections similarity (rows: semantic vectors, columns: visual vectors);
        # raw logits are returned, because nn.CrossEntropyLoss applies log-softmax itself
        # and a softmax here would be applied twice -- use torch.softmax(d, dim=-1)
        # on the output when probabilities are needed
        d = sx @ vx.T
        # run detectors
        v = self.visual_classifier(vx)
        s = self.semantic_classifier(sx)
        return [d] + list(v) + list(s)



In [None]:
# separate multitask classifiers for the visual and the semantic branch
visual_classifier = MultitaskClassifier(LATENT_DIM, task_load).to(DEVICE)
semantic_classifier = MultitaskClassifier(LATENT_DIM, task_load).to(DEVICE)

model = AlignMultitaskModel(visual_projection, semantic_projection,
                            visual_classifier, semantic_classifier).to(DEVICE)

# smoke test on the last loaded batch: similarity matrix followed by the task outputs
for out in model(X[0].to(DEVICE), X[1].to(DEVICE)): print(out.shape)

## Training

In [None]:
# random 95% of the documents for training, the remaining ones for testing
train_samples = np.random.choice(samples, int(len(samples) * 0.95), replace=False)
test_samples = list(set(samples).difference(set(train_samples)))
len(train_samples), len(test_samples)

In [None]:
# dataset class (not an instance) handed to the trainer
dataset = SemanticAlignDataset

In [None]:
# class weights per task: one minus the class frequency, normalized to sum to one;
# reindex over every class index, so that a class missing from `stats` still gets
# a weight and every weight vector has as many entries as the task has labels
weights = [1. - stats.groupby(task).size().reindex(range(len(labels[task])), fill_value=0)/len(stats)
           for task in labels]
weights = [list(w / sum(w)) for w in weights]
weights

In [None]:
# tasks loss
# tasks loss: alignment first, then the weighted task losses
# (the same loss objects are listed twice: for the visual and the semantic outputs)
criteria = [nn.CrossEntropyLoss().to(DEVICE)] +\
           [nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32)).to(DEVICE) for w in weights] * 2
# composite parameterized loss
# (`HydraLoss` is not defined in this notebook; presumably it comes from scripts.trainer)
criterion = HydraLoss(criteria).to(DEVICE)
# optimize both: model and loss
params = [p for p in model.parameters()] + [p for p in criterion.parameters()]

learning_rate = 1e-6
optimizer = AdamW(params, lr=learning_rate)

In [None]:
# alignment: confusion matrix over the positions within a batch
metrics = {'align': {'confmat': ConfusionMatrix(task='multiclass', num_classes=batch_size).to(DEVICE)}}
# f1-score and confusion matrix for every task of both classifiers;
# the number of classes is taken from the length of the task weights
for prefix in ['visual','semantic']:
    for i, task in enumerate(labels):    
        metrics[f'{prefix}-{task}'] = {
            'f1-score': F1Score(task='multiclass', num_classes=len(weights[i])).to(DEVICE),
            'confmat': ConfusionMatrix(task='multiclass', num_classes=len(weights[i])).to(DEVICE)}

In [None]:
num_epochs = 5
# `Trainer` is not defined in this notebook; presumably it comes from scripts.trainer
# multi_x / multi_y: the model takes two inputs and returns several outputs
trainer = Trainer(model, dataset, VIEW_SIZE, criterion, optimizer, metrics, multi_x=True, multi_y=True,
                  autocast=False)
# `results` is read by the evaluation cells below
results = trainer.run(train_samples, test_samples, batch_size, num_epochs=num_epochs, validation_steps=2)

In [None]:
# loss and metrics recorded by the trainer
plot_history(trainer.loss_history, trainer.metrics_history, multi_x=True, multi_y=True)

## Evaluation

In [None]:
# alignment confusion matrix: summed over the recorded matrices, scaled by its maximum
fig, ax = plt.subplots(figsize=(5, 5))
matrix = np.sum(np.array(results['align']['confmat']), axis=0)
ax.imshow(matrix/np.max(matrix), cmap='coolwarm')
plt.xticks([])
plt.yticks([])
plt.title('align confusion matrix')
plt.show()

# per-task confusion matrices: visual on the left, semantic on the right
for task in labels:
    fig, ax = plt.subplots(1, 2, figsize=(5, 5))
    for i, prefix in enumerate(['visual','semantic']):
        matrix = np.sum(np.array(results[f'{prefix}-{task}']['confmat']), axis=0)
        ax[i].imshow(matrix/np.max(matrix), cmap='coolwarm')
        ax[i].set_xticks(range(len(labels[task])))
        ax[i].set_title(f'{prefix}-{task}', fontsize=10)
    ax[0].set_yticks([])
    ax[1].set_xticks(range(len(labels[task])))
    # class names are shown once, to the right of the second plot
    ax[1].yaxis.tick_right()
    ax[1].set_yticks(range(len(labels[task])))
    ax[1].set_yticklabels([f'{k}: {v}' for k, v in enumerate(labels[task])], fontsize=10)
    plt.show()    

In [None]:
# V: visual projections, S: semantic projections, L: task labels of the same samples
V, S, L = [], [], {x:[] for x in labels} # pair vectors with labels
# NOTE: drawing without replacement requires at least 100 entries in `samples`
for source in np.random.choice(samples, 100, replace=False):
    loader = DataLoader(SemanticAlignDataset(source, VIEW_SIZE, batch_size, debug=False), batch_size=batch_size)
    for X, Y in loader:
        # inference only
        with torch.no_grad():
            V += list(visual_projection(X[0].to(DEVICE)).cpu().numpy())
            S += list(semantic_projection(X[1].to(DEVICE)).cpu().numpy())
        # Y[0] is the alignment target, task labels start from Y[1]
        for i, label in enumerate(labels.keys(), 1):
            L[label] += Y[i].tolist()

In [None]:
def plot_labels(pe, te, color, labels, title):
    """
    Scatter the 2D embeddings colored by class: tSNE on the left, PCA on the right.

    pe: PCA embedding (the first two columns are plotted)
    te: tSNE embedding
    color: class index of every point
    labels: class names (the position is the class index)
    title: prefix of the plot titles
    """
    fig, ax = plt.subplots(1, 2, figsize=(9, 4))
    tsne_ax, pca_ax = ax[0], ax[1]
    cmap = colormaps['gist_rainbow']
    num_classes = len(labels)
    for c in range(num_classes):
        selected = color==c
        tint = cmap(c/num_classes)
        tsne_ax.scatter(te[selected,0], te[selected,1], s=5, color=tint, alpha=0.5)
        # only the PCA plot carries the legend entries
        pca_ax.scatter(pe[selected,0], pe[selected,1], s=5, color=tint, alpha=0.5, label=labels[c])
    for axis, caption in ((tsne_ax, f'{title}: tSNE'), (pca_ax, f'{title}: PCA')):
        axis.set_xticks([])
        axis.set_yticks([])
        axis.set_title(caption)
    pca_ax.legend(bbox_to_anchor=(1, 1), frameon=False)
    plt.show()

    

In [None]:
# standardize the visual projections, then reduce them with PCA and tSNE
pca = PCA(n_components=10)
norm = StandardScaler().fit(V)
E = norm.transform(V)
P = pca.fit_transform(E)
T = TSNE(n_components=2, perplexity=90).fit_transform(E)

# one pair of plots per task (only the first two PCA components are shown)
for task in labels:
    plot_labels(P, T, np.array(L[task]), labels[task], f'Visual {task}')

In [None]:
# standardize the semantic projections, then reduce them with PCA and tSNE
pca = PCA(n_components=10)
norm = StandardScaler().fit(S)
E = norm.transform(S)
P = pca.fit_transform(E)
T = TSNE(n_components=2, perplexity=90).fit_transform(E)

# one pair of plots per task (only the first two PCA components are shown)
for task in labels:
    plot_labels(P, T, np.array(L[task]), labels[task], f'Semantic {task}')

## Simple search

In [None]:
# name of the vector collection
INDEX = 'visual-align'

In [None]:
# in-memory vector store
qclient = QdrantClient(':memory:')
# drop the collection before creating it again
qclient.delete_collection(collection_name=INDEX)
# vectors have the latent-space size and are compared by cosine distance
qclient.create_collection(
    collection_name=INDEX, 
    vectors_config=VectorParams(size=LATENT_DIM, distance=Distance.COSINE),
)

In [None]:
# observation -> action records to be indexed
# non-document views: nothing to do
payload = [{ 'observation':observation, 'action':None } for observation in NONDOC]

# any recognized part of a page
payload += [{ 'observation':observation, 'action':'fix rotation' } for observation in SCOPE.values()]

# arbitrary rotations: the action turns the view back in the opposite direction
for step in range(1, 360):
    if step > 180:
        angle, d1, d2 = 360 - step, '', 'counter'
    else:
        angle, d1, d2 = step, 'counter', ''
    payload.append({'observation':f'rotated {angle} degrees {d1}clockwise',
                    'action':f'rotate {angle} degrees {d2}clockwise'})

# right-angle orientations and plain content observations
payload.extend([
    { 'observation':'straight', 'action':'fix zoom' },
    { 'observation':'turned on the left side', 'action':'rotate 90 degrees clockwise' },
    { 'observation':'turned upside-down', 'action':'rotate 180 degrees' },
    { 'observation':'turned on the right side', 'action':'rotate 90 degrees counterclockwise' },

    { 'observation':'no text and no lines', 'action':'zoom out' },
    { 'observation':'text', 'action':'fix rotation' },
    { 'observation':'lines', 'action':'zoom out' },
    { 'observation':'text and lines', 'action':'fix rotation' },
])

# composite captions (scope, content, rotation) with an empty action
rotations = list(ORIENTATION.values()) + ['rotated counterclockwise','rotated clockwise']
payload.extend(
    {'observation':f'{scope} with {content} {rotation}',
     'action':''}
    for scope in SCOPE.values()
    for content in ['text and lines','text','lines']
    for rotation in rotations)

In [None]:
# encode the observation texts ...
embeddings = semantic_encoder.encode([x['observation'] for x in payload])
# ... and project them into the latent space (inference only)
with torch.no_grad():
    embeddings = semantic_projection(torch.Tensor(embeddings).to(DEVICE)).cpu().numpy()

embeddings.shape    

In [None]:
# shape of a single vector
embeddings[0,:].shape

In [None]:
# index the projected observation vectors together with their records
qclient.upload_collection(
    collection_name=INDEX,
    vectors=embeddings,
    payload=payload,
)

In [None]:
def find_similar(observation_vector: list, limit: int = 5):
    """
    Closest indexed observations for a latent vector.

    observation_vector: query vector
    limit: maximum number of candidates requested from the index
    returns: payloads of the candidates scored 0.25 or higher, each with its `score`
    """
    # no payload filter is applied: every indexed record is a candidate
    hits = qclient.search(
        collection_name=INDEX,
        query_vector=observation_vector,
        limit=limit,
    )
    matches = []
    for hit in hits:
        # keep confident matches only
        if hit.score >= 0.25:
            matches.append({**hit.payload, 'score':hit.score})
    return matches



In [None]:
# smoke test: search with the first indexed vector
find_similar(embeddings[0,:], 3)

In [None]:
# monospace font keeps the annotation table aligned
plt.rcParams['font.family'] = 'monospace'
sample = np.random.choice(samples)
# debug mode: the first target is the caption text
loader = DataLoader(SemanticAlignDataset(sample, VIEW_SIZE, batch_size, debug=True), batch_size=batch_size)
for X, Y in loader:
    with torch.no_grad():
        # extract visual features
        P = visual_projection(X[0].to(DEVICE)).cpu().numpy()
        # get predictions
        preds = model(X[0].to(DEVICE), X[1].to(DEVICE))
        # the first output is the similarity matrix, the rest are the task outputs
        matrix = preds[0].cpu().numpy()
        tasks = [torch.argmax(p, dim=1).cpu().numpy() for p in preds[1:]]
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(matrix, 'coolwarm')
        ax.yaxis.tick_right()
        ax.set_yticks(range(len(Y[0])))
        ax.set_yticklabels(['{}. {}'.format(i, y) for i, y in enumerate(Y[0])], fontsize=8)
        ax.set_xticks(range(len(Y[0])))
        plt.title('Confusion Matrix')
        plt.show()

    for i in range(batch_size):
        # closest indexed observations for the visual projection of the view
        similar = '\n'.join([f"{x['observation']} [score {x['score']:.2f}]" for x in find_similar(P[i,:], 3)])
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(X[0][i,:].squeeze(), 'gray')
        ax.axis('off')
        ax.set_title(f'True:\n{Y[0][i]}\n\nSearch:\n{similar}', ha='left', x=0, fontsize=10)

        # table of predictions: tasks[k] is the visual output for task k,
        # tasks[k + 4] the semantic one (4 is the number of tasks in `labels`)
        text = [('    Task        Visual        Semantic      True\n'
                 '   ------------------------------------------------------')] +\
               [(f'    {task:<12}{labels[task][tasks[k][i]]:<12}  {labels[task][tasks[k + 4][i]]:<12}'
                 f'  {labels[task][Y[k + 1][i]]:<12}') for k, task in enumerate(labels)]

        # the table is placed to the right of the image
        ax.annotate('\n'.join(text), xy=(1, 0.5), xytext=(0, 10), xycoords=('axes fraction','figure fraction'),
                    textcoords='offset points', size=10, ha='left', va='center')
        plt.show()

In [None]:
# save the weights of the trained parts (the target directory `./models` is not created here)
torch.save(visual_projection.state_dict(), f'./models/visual-projection-CNN.pt')
torch.save(visual_classifier.state_dict(), f'./models/visual-classifier-CNN.pt')
torch.save(semantic_projection.state_dict(), f'./models/semantic-projection-CNN.pt')
torch.save(semantic_classifier.state_dict(), f'./models/semantic-classifier-CNN.pt')