In [2]:
from dataclasses import dataclass, field

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import Resize

In [3]:
@dataclass
class Settings():
    
    input_folder: str = "./data/"
    model_folder: str = "./models/"
    
    batch_size: int = 2
    
    sos_token: int = 59
    eos_token: int = 60
    pad_token: int = 61
    
    max_frames: int = 250
    nb_feature: int = 164
    max_phrases: int = 31 + 1
    nb_token: int = 58 + 3
    
    x_shape: tuple[int, int] = (128,nb_feature)
    
    encoder_feature = 256
    encoder_block = 2
    encoder_head = 8
    
    decoder_feature = 256
    decoder_block = 2
    decoder_head = 8
    
S = Settings()
S

Settings(input_folder='./data/', model_folder='./models/', batch_size=2, sos_token=59, eos_token=60, pad_token=61, max_frames=250, nb_feature=164, max_phrases=32, nb_token=61, x_shape=(128, 164))

In [4]:
class DS(Dataset):
    
    def __init__(self):
        super(DS, self).__init__()
        
        self.xs = torch.load(S.input_folder + "x.torch")
        self.ys = torch.load(S.input_folder + "y.torch")
        
    def __getitem__(self, index):
        return self.xs[index], self.ys[index]
    
    def __len__(self):
        return len(self.ys)

In [5]:
def find_3d_tensor_shape(T: torch.Tensor, dim: int):
    T = T.clone()
    if dim == 0: T = T[:, 0, 0]
    elif dim == 1: T = T[0, :, 0]
    elif dim == 2: T = T[0, 0, :]
    T[:] = 1.
    T = T.sum()
    return T

In [6]:
class DataProcessing(nn.Module):

    def __init__(self):
        super(DataProcessing, self).__init__()
        
        self.max_frames = torch.tensor(S.max_frames)
        self.x_shape = S.x_shape
        self.zero_tensor = torch.tensor(0.)
    
    def forward(self, x):
        if len(x.shape) <= 2: x = x[None]
        
        x = torch.where(x.isnan(), self.zero_tensor, x)
        
        nb_frame = find_3d_tensor_shape(x, 1).to(torch.int64)
        pad_value = torch.where(nb_frame < self.max_frames, self.max_frames - nb_frame, self.zero_tensor.to(torch.int64))
        x = F.pad(x, (0,0,0,pad_value), "constant", self.zero_tensor)
        x = F.interpolate(x[None], self.x_shape)[0]
        
        return x

In [7]:
def one_hot(T, nb_class, batch_first=True):
    assert len(T.unique()) <= nb_class, "nb_class should be higher then number of unique element in tensor T"
    T_dtype = T.dtype
    if not batch_first: T = T[None]
    out = []
    for batch in T:
        out.append(torch.stack([torch.where(batch == uniq, 1, 0) for uniq in range(nb_class)]).T)
    out = torch.stack(out)
    if not batch_first: out = out[0]
    return out.to(T_dtype)

In [8]:
class FrameEmbedding(nn.Module):
    
    def __init__(self):
        super(FrameEmbedding, self).__init__()
        
        self.l1 = nn.Linear(S.x_shape[1], S.encoder_feature, False)
        self.l2 = nn.Linear(S.encoder_feature, S.encoder_feature, False)
        
        self.pe1 = nn.Parameter(torch.zeros((S.x_shape[0], S.encoder_feature)))
        
    def forward(self, x):
        x = F.gelu(self.l1(x))
        x = self.l2(x)
        
        x = x + self.pe1
        
        return x

In [9]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.device = "cpu"
        
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        depth = embed_dim // num_heads
        
        self.lq = nn.ModuleList([nn.Linear(embed_dim, depth) for _ in range(num_heads)])
        self.lk = nn.ModuleList([nn.Linear(embed_dim, depth) for _ in range(num_heads)])
        self.lv = nn.ModuleList([nn.Linear(embed_dim, depth) for _ in range(num_heads)])
        
        self.lo = nn.Linear(embed_dim, embed_dim)
        
    def scaled_dot_product_attention(self, q, k, v, attn_mask):
        qkt = torch.bmm(q, k.permute(0,2,1))
        scaling_factor = find_3d_tensor_shape(q, 2).float().to(q.device)
        scaled_qkt = qkt / scaling_factor
        
        if attn_mask != None: attn_mask = torch.where(attn_mask == 0, torch.FloatTensor([-torch.inf]), torch.FloatTensor([0]))
        else: attn_mask = torch.zeros(*scaled_qkt.shape).to(q.device)
        attn_qkt = scaled_qkt + attn_mask
        softmax_qkt = F.softmax(attn_qkt, 1)
        
        output = torch.bmm(softmax_qkt, v)
        
        return output
        
    def forward(self, q, k, v, attn_mask=None):
        
        multi_attn = [self.scaled_dot_product_attention(lq(q), lk(k), lv(v), attn_mask)
            for lq, lk, lv in zip(self.lq, self.lk, self.lv)]
        
        multi_head = torch.cat(multi_attn, 2)
        multi_head_attn = self.lo(multi_head)
        return multi_head_attn

In [10]:
class Encoder(nn.Module):
    
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.blocks = nn.ModuleList([self.encorder_block() for _ in range(S.encoder_block)])
        
    def encorder_block(self):
        return nn.ModuleList([
            MultiHeadAttention(S.encoder_feature, S.encoder_head),
            nn.LayerNorm(S.encoder_feature),
            nn.Linear(S.encoder_feature, S.encoder_feature),
            nn.Linear(S.encoder_feature, S.encoder_feature),
            nn.LayerNorm(S.encoder_feature),
        ])

    def forward(self, x):
        
        attn_mask = x.sum(2)
        attn_mask = torch.where(attn_mask == 0, torch.FloatTensor([0.]), torch.FloatTensor([1.]))
        attn_mask = attn_mask[:,:,None].repeat(1,1,S.x_shape[0])
        
        for mha, ln1, l1, l2, ln2 in self.blocks:
            
            _x = x
            x = mha(x, x, x, attn_mask)
            x = ln1(x + _x)
            
            _x = x
            x = F.gelu(l1(x))
            x = l2(x)
            x = ln2(x + _x)
            
        return x

In [11]:
class PhraseEmbedding(nn.Module):
    
    def __init__(self):
        super(PhraseEmbedding, self).__init__()
        
        self.emb1 = nn.Embedding(S.nb_token+1, S.decoder_feature)
        self.pe1 = nn.Parameter(torch.zeros((S.x_shape[0], S.decoder_feature)))
        
    def forward(self, y):
        y = F.pad(y, (1,0,0,0), "constant", S.sos_token)
        y = F.pad(y, (0, S.x_shape[0] - y.shape[1],0,0), "constant", S.pad_token)
        y = self.emb1(y.to(torch.int64))
        
        y = y + self.pe1
        
        return y

In [12]:
class Decoder(nn.Module):
    
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.causal_mha = MultiHeadAttention(S.decoder_feature, S.decoder_head)
        self.causal_ln = nn.LayerNorm(S.decoder_feature)
        
        self.blocks = nn.ModuleList([self.decoder_block() for _ in range(S.decoder_block)])
        
    def decoder_block(self):
        return nn.ModuleList([
            MultiHeadAttention(S.decoder_feature, S.decoder_head),
            nn.LayerNorm(S.decoder_feature),
            nn.Linear(S.decoder_feature, S.decoder_feature),
            nn.Linear(S.decoder_feature, S.decoder_feature),
            nn.LayerNorm(S.decoder_feature),
        ])
        
    def forward(self, encoder_output, x):
        
        causal_mask = torch.arange(S.x_shape[0])[:, None] >= torch.arange(S.x_shape[0])
        causal_mask = causal_mask.float().repeat(x.shape[0],1,1).to(x.device)
        
        _x = x
        self.causal_mha(x, x, x, causal_mask)
        x = self.causal_ln(x + _x)
        
        for mha, ln1, l1, l2, ln2 in self.blocks:
            
            _x = x
            x = mha(x, encoder_output, encoder_output, causal_mask)
            x = ln1(x + _x)
            
            _x = x
            x = F.gelu(l1(x))
            x = l2(x)
            x = ln2(x + _x)
        
        x = x[:,:S.max_phrases,:]
        
        return x

In [13]:
class Classifier(nn.Module):
    
    def __init__(self):
        super(Classifier, self).__init__()
        
        self.l1 = nn.Linear(S.decoder_feature, S.nb_token)
        
    def forward(self, x):
        x = self.l1(x)
        x = F.softmax(x, 1)
        
        return x
    
cls = Classifier()

In [14]:
class Model(nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        
        self.DP = DataProcessing()
        self.FE = FrameEmbedding()
        self.ENC = Encoder()
        self.PE = PhraseEmbedding()
        self.DEC = Decoder()
        self.CLS = Classifier()
        
    def forward(self, x, y):
        x = self.DP(x)
        x = self.FE(x)
        x = self.ENC(x)
        
        y = self.PE(y)
        x = self.DEC(x, y)
        x = self.CLS(x)
        
        return x

In [15]:
model = Model()

In [16]:
x = torch.randn(3,250,164)
y = torch.randint(0,59,(3,30))

In [17]:
output = model(x, y)
output.shape

torch.Size([3, 32, 61])

In [17]:
class InferenceModel(nn.Module):
    
    def __init__(self, train_model):
        super(InferenceModel, self).__init__()
        
        self.DP = train_model.DP
        self.FE = train_model.FE
        self.ENC = train_model.ENC
        self.PE = train_model.PE
        self.DEC = train_model.DEC
        self.CLS = train_model.CLS
        
    def forward(self, x):
        
        x = self.DP(x)
        x = self.FE(x)
        enc_out = self.ENC(x)
        
        phrase = torch.full((1,S.max_phrases), S.pad_token).to(x.device)
        
        for i in range(S.max_phrases):
            y = self.PE(phrase)
            x = self.DEC(enc_out, y)
            x = self.CLS(x)
            
            phrase[:, :i+1] = x.argmax(2)[:, :i+1]
            
        return phrase

In [18]:
inference_model = InferenceModel(model).eval()
x_inf = torch.randn(12,164)
output_inf = inference_model(x_inf)
output_inf.shape

torch.Size([1, 32])

In [18]:
torch.onnx.export(
    model,
    {"x": x, "y": y},
    S.model_folder + "v2.onnx",
    input_names=["x", "y"],
    output_names=["output"],
    opset_version=12,
)

  x = F.pad(x, (0,0,0,pad_value), "constant", self.zero_tensor)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(


verbose: False, log level: Level.ERROR



  _C._jit_pass_onnx_graph_shape_type_inference(


In [20]:
import onnx
from onnx_tf.backend import prepare

2023-07-02 17:42:13.203354: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [21]:
onnx_model = onnx.load(S.model_folder + "v2.onnx")
onnx.checker.check_model(onnx_model)

In [22]:
prepare(onnx_model).export_graph(S.model_folder + "v2_tf")

2023-07-02 17:42:15.331815: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-07-02 17:42:15.332144: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2023-07-02 17:42:26.325275: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'x' 

INFO:tensorflow:Assets written to: ./models/v2_tf/assets


INFO:tensorflow:Assets written to: ./models/v2_tf/assets


In [None]:
import tensorflow as tf
import numpy as np

In [None]:
tf_model = tf.saved_model.load(S.model_folder + "v2_tf/")
tf_model.trainable = False

In [None]:
tf_x = tf.cast(tf.convert_to_tensor(np.random.random((3,250,164))), tf.float32)
tf_y = tf.cast(tf.convert_to_tensor(np.random.randint(0, 59, (3,30))), tf.int64)

In [None]:
output = tf_model(x=tf_x, y=tf_y)
output["output"].shape

2023-06-30 07:46:42.641003: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor '8967' with dtype float and shape [62,256]
	 [[{{node 8967}}]]
2023-06-30 07:46:42.789127: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: StatefulPartitionedCall/assert_equal_3/Assert/AssertGuard/branch_executed/_257


TensorShape([3, 32, 61])

In [None]:
converter = tf.lite.TFLiteConverter.from_saved_model(S.model_folder + "v2_tf/")
converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]

In [None]:
tflite_model = converter.convert()

2023-06-30 07:46:43.687791: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'serving_default_y' with dtype int64 and shape [3,30]
	 [[{{node serving_default_y}}]]
2023-06-30 07:46:43.719047: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2023-06-30 07:46:43.719070: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2023-06-30 07:46:43.719359: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: ./models/v2_tf/
2023-06-30 07:46:43.724031: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-06-30 07:46:43.724053: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: ./models/v2_tf/
2023-06-30 07:46:43.735005: I tensorf

In [None]:
with open(S.model_folder + "v2.tflite", "wb") as f: f.write(tflite_model)