In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import Levenshtein
import time

from morse.models import CNNResidualBlock, TransformerResidualBlock, PoolingTransition, CNNTransformer, CTCHead
from morse.my_datasets import ListDataset, load_tensors, filenames_to_torch
from morse.samplers import LongCTCSampler
from morse.augmentations import rotation_transform
from morse.text_helpers import Vectorizer

In [2]:
# Dimensions shared by all sanity-check cells below.
# NOTE(review): d_input presumably is the number of input features per time step and
# d_output the size of the CTC output alphabet (incl. blank) — confirm against morse.models.
d_input = 65
d_model = 64
d_inner = 128
d_output = 63

batch_size = 16

seq_len = 501

# Seed every RNG this notebook touches, so that the random test inputs (and presumably
# the module weight initialisation) — and hence the printed mean/var — are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Sanity check: a single CNN residual block (post-norm disabled) on random input laid out
# as (batch_size, d_model, seq_len). The block is expected to preserve the input shape,
# which the recorded output confirms; mean/var give a rough view of the output scale.
cnn_check_input = torch.randn((batch_size, d_model, seq_len))
output = CNNResidualBlock(d_model=d_model, d_inner=d_inner,
                          apply_post_norm=False)(cnn_check_input)
assert output.shape == cnn_check_input.shape, f"residual block changed shape: {tuple(output.shape)}"
print(output.shape)
print('mean', torch.mean(output))
print('var', torch.var(output))

torch.Size([16, 64, 501])
mean tensor(-0.0019, grad_fn=<MeanBackward0>)
var tensor(1.3651, grad_fn=<VarBackward0>)


In [4]:
# NOTE(review): disabled sanity check for TransformerResidualBlock. Unlike the CNN check,
# the random input here is built sequence-first: (seq_len, batch_size, d_model).
# It passes d_inner positionally, while the full-model cell passes it as d_ffn= —
# confirm the positional order if re-enabling. Re-enable or delete before sharing.
# output = TransformerResidualBlock(d_model, d_inner, apply_post_norm=False)(torch.randn((seq_len, batch_size, d_model)))
# print(output.shape)
# print('mean', torch.mean(output))
# print('var', torch.var(output))

In [5]:
# NOTE(review): disabled check of TransformerResidualBlock.calculate_attention_entropy on
# random sequence-first input (seq_len, batch_size, d_model); re-enable or delete before sharing.
# TransformerResidualBlock(d_model, d_inner, apply_post_norm=False).calculate_attention_entropy(torch.randn((seq_len, batch_size, d_model)))

In [6]:
# Sanity check for PoolingTransition(overlap=True): per the recorded output, a time axis of
# length seq_len + 1 (= 502) is pooled to half its length (251); channels are unchanged.
# NOTE(review): the +1 presumably makes the length even for pooling — confirm.
pool_check_input = torch.randn((batch_size, d_model, seq_len + 1))
pooled = PoolingTransition(overlap=True)(pool_check_input)
assert pooled.shape == (batch_size, d_model, pool_check_input.shape[-1] // 2), \
    f"unexpected pooled shape {tuple(pooled.shape)}"
pooled.shape

torch.Size([16, 64, 251])

In [7]:
# End-to-end forward pass through the full CNN-Transformer with a CTC head on random input
# of shape (batch_size, d_input, seq_len). The recorded output shows the channel axis
# becoming d_output and the time axis shrinking from 501 to 62 after the 3 pooling stages.
cnn_transformer = CNNTransformer(
    d_input, d_model,
    n_pools=3, n_blocks_before_pool=2, n_transformer_blocks=2,
    head_block=CTCHead(d_model, d_output),
    # Keyword args (as in the single-block check) rather than relying on positional order.
    make_cnn_block=lambda: CNNResidualBlock(d_model=d_model, d_inner=d_inner),
    make_transformer_block=lambda: TransformerResidualBlock(d_model, d_ffn=d_inner),
)
out = cnn_transformer(torch.randn((batch_size, d_input, seq_len)))
assert out.shape[:2] == (batch_size, d_output), f"unexpected output shape {tuple(out.shape)}"

print(out.shape)
print('mean', torch.mean(out))
print('var', torch.var(out))

torch.Size([16, 63, 62])
mean tensor(0.0356, grad_fn=<MeanBackward0>)
var tensor(0.7015, grad_fn=<VarBackward0>)
