In [1]:
# Подавление предупреждений
import warnings
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

# Import the required libraries
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import sklearn
import matplotlib.pyplot as plt
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel,AutoModelForMaskedLM, RobertaModel, RobertaTokenizer
import torch
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange
from typing import Tuple, Callable
from torch.autograd import Function
import gc
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score

In [4]:
class Embedding():
    """Wrap a pretrained Hugging Face encoder and turn a batch of strings into embedding tensors.

    Supported ``model_name`` values:
      * ``'jina'``             -> ``jinaai/jina-embeddings-v3`` (remote code pinned to a fixed revision)
      * ``'xlm-roberta-base'`` -> ``xlm-roberta-base``
      * ``'canine-c'``         -> ``google/canine-c``

    ``pooling`` selects the output form of :meth:`get_embeddings`:
      * ``None``   -> per-token features, cut/zero-padded to a fixed sequence length
      * ``'mean'`` -> one L2-normalised, attention-masked mean vector per input

    Raises:
        ValueError: from ``__init__`` if ``model_name`` is not one of the names above.
    """

    # Fixed sequence lengths used when no pooling is applied.
    _DEFAULT_MAX_LEN = 95
    _CANINE_MAX_LEN = 329
    # 'canine-c_emb' is the name the original length check compared against; it is kept as an
    # alias so the check also recognises the value accepted by __init__ ('canine-c').
    _CANINE_NAMES = ('canine-c', 'canine-c_emb')

    def __init__(self, model_name='jina', pooling=None):
        self.model_name = model_name
        self.pooling = pooling
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if model_name == 'jina':
            self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True)
            self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True).to(self.device)
        elif model_name == 'xlm-roberta-base':
            self.tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
            self.model = AutoModel.from_pretrained('xlm-roberta-base').to(self.device)
        elif model_name == 'canine-c':
            self.tokenizer = AutoTokenizer.from_pretrained('google/canine-c')
            self.model = AutoModel.from_pretrained('google/canine-c').to(self.device)
        else:
            raise ValueError('Unknown name of Embedding')

    def _mean_pooling(self, X):
        """Return attention-masked, L2-normalised mean embeddings with a singleton sequence axis.

        The result has shape (batch, 1, hidden) and stays on ``self.device``.
        """
        def mean_pooling(model_output, attention_mask):
            # First element of the model output holds the per-token embeddings.
            token_embeddings = model_output[0]
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # Clamp the denominator so an all-padding row cannot divide by zero.
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.unsqueeze(1)

    def get_embeddings(self, X):
        """Embed the batch ``X`` according to ``self.pooling``.

        With ``pooling=None`` the per-token features are moved to the CPU as float32, truncated
        to a fixed length and zero-padded up to that length, giving (batch, max_len, hidden).
        With ``pooling='mean'`` see :meth:`_mean_pooling`.

        Raises:
            ValueError: if ``self.pooling`` is neither ``None`` nor ``'mean'``.
        """
        if self.pooling is None:
            # NOTE(review): previously this compared against 'canine-c_emb' only, a name that
            # __init__ rejects, so CANINE inputs were always cut to 95 positions. Checkpoints
            # trained with that behaviour may need re-validation.
            if self.model_name in self._CANINE_NAMES:
                max_len = self._CANINE_MAX_LEN
            else:
                max_len = self._DEFAULT_MAX_LEN
            encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                features = self.model(**encoded_input)[0].detach().cpu().float()
            features = features[:, :max_len, :]
            missing = max_len - features.shape[1]
            # Pad only the end of the sequence axis: (last-dim left, right, seq left, right).
            return F.pad(features, (0, 0, 0, missing), mode="constant", value=0.0)
        elif self.pooling == 'mean':
            return self._mean_pooling(X)
        else:
            raise ValueError('Unknown type of pooling')

In [34]:
class PScan(Function):
    """Autograd function for the linear recurrence ``H[t] = A[t] * H[t-1] + X[t]`` with ``H[-1] = 0``.

    The recurrence is evaluated along the leading axis of the inputs with an in-place
    up-sweep / down-sweep scan instead of a sequential loop.

    Layout: both inputs are 4-D tensors of identical shape labelled ``(l, b, d, s)``, where
    ``l`` is the scanned axis. The result is returned in ``(b, l, d, s)`` layout. Gradients are
    returned in the input layout ``(l, b, d, s)``.
    """

    @staticmethod
    def forward(ctx, A_inp, X_inp):
        """Run the scan on copies of the inputs and return ``H`` as ``(b, l, d, s)``."""
        # Copies in (b, d, l, s) layout: the scan below overwrites both tensors in place.
        A = A_inp.clone().permute(1, 2, 0, 3)
        X = X_inp.clone().permute(1, 2, 0, 3)
        # backward needs the untouched coefficients, so snapshot them BEFORE the scan
        # replaces entries of A with running products.
        A_raw = A.clone()
        PScan._forward(A, X)
        ctx.save_for_backward(A_raw, X)
        return X.permute(0, 2, 1, 3)

    @staticmethod
    def backward(ctx, grad_inp: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(dL/dA, dL/dX)`` in the ``(l, b, d, s)`` layout of the forward inputs."""
        A, X = ctx.saved_tensors  # raw coefficients and scanned output H, both (b, d, l, s)
        # dL/dX obeys the reversed recurrence g[t] = grad[t] + A[t+1] * g[t+1]; flipping the
        # sequence (and shifting A by one) lets the same forward scan compute it.
        A = torch.cat((A[:, :, :1], A[:, :, 1:].flip(2)), dim=2)
        grad_out = grad_inp.permute(0, 2, 1, 3).flip(2)  # flip() copies, so in-place scan is safe
        PScan._forward(A, grad_out)
        grad_out = grad_out.flip(2)
        # dH[t]/dA[t] = H[t-1]; position 0 multiplies the zero initial state, so its grad is 0.
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_out[:, :, 1:])
        # Map back from the internal (b, d, l, s) layout to the input layout (l, b, d, s).
        return Q.permute(2, 0, 1, 3), grad_out.permute(2, 0, 1, 3)

    @staticmethod
    def _forward(A: Tensor, X: Tensor) -> None:
        """In-place scan over axis 2 of ``(b, d, l, s)`` tensors.

        On return ``X`` holds ``H``; ``A`` is overwritten with intermediate products and must
        not be reused as the original coefficients. Works for any sequence length ``l``.
        """
        b, d, l, s = A.shape
        # floor(log2(l)) computed on integers; 0 steps for l <= 1.
        num_steps = max(l.bit_length() - 1, 0)
        Av, Xv = A, X
        # Up-sweep: combine neighbouring pairs, then recurse on the right element of each pair.
        for _ in range(num_steps):
            # Use the largest even prefix; a trailing odd element is picked up in the down-sweep.
            T = 2 * (Xv.size(2) // 2)
            Av, Xv = Av[:, :, :T].reshape(b, d, T // 2, 2, -1), Xv[:, :, :T].reshape(b, d, T // 2, 2, -1)
            Xv[:, :, :, 1].add_(Av[:, :, :, 1].mul(Xv[:, :, :, 0]))
            Av[:, :, :, 1].mul_(Av[:, :, :, 0])
            Av, Xv = Av[:, :, :, 1], Xv[:, :, :, 1]
        # Down-sweep: propagate completed prefixes into the positions skipped above.
        for k in range(num_steps - 1, -1, -1):
            Av, Xv = A[:, :, 2**k - 1 : l : 2**k], X[:, :, 2**k - 1 : l : 2**k]
            T = 2 * (Xv.size(2) // 2)
            if T < Xv.size(2):
                # Odd count at this level: finish the last element from its left neighbour.
                Xv[:, :, -1].add_(Av[:, :, -1].mul(Xv[:, :, -2]))
                Av[:, :, -1].mul_(Av[:, :, -2])
            Av, Xv = Av[:, :, :T].reshape(b, d, T // 2, 2, -1), Xv[:, :, :T].reshape(b, d, T // 2, 2, -1)
            Xv[:, :, 1:, 0].add_(Av[:, :, 1:, 0].mul(Xv[:, :, :-1, 1]))
            Av[:, :, 1:, 0].mul_(Av[:, :, :-1, 1])

# Functional entry point: pscan(A, X) with both tensors laid out as (l, b, d, s).
pscan: Callable[[Tensor, Tensor], Tensor] = PScan.apply

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable per-feature gain.

    Args:
        d_model: size of the last axis; the gain vector has this many entries (initialised to 1).
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` by the reciprocal RMS of its last axis, then by ``self.weight``."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class MambaBlock(nn.Module):
    """Residual-free block of four linear maps: project in, add two linear branches, project out.

    Computes ``out_proj(h + s_B(h) + s_C(h))`` with ``h = in_proj(x)``; the output has the same
    last-axis size as the input (``d_input``).

    Args:
        d_input: size of the last axis of the input and of the output.
        d_model: hidden width used by the inner projections.
    """

    def __init__(self, d_input, d_model):
        super().__init__()
        self.in_proj = nn.Linear(d_input, d_model)
        self.s_B = nn.Linear(d_model, d_model)
        self.s_C = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_input)

    def forward(self, x):
        """Apply the block to ``x`` (last axis of size ``d_input``)."""
        # Follow the module's own parameters instead of a notebook-level `device` global, so the
        # block works regardless of cell execution order or where the model was moved.
        x = x.to(self.in_proj.weight.device)
        hidden = self.in_proj(x)
        return self.out_proj(hidden + self.s_B(hidden) + self.s_C(hidden))

class Mamba(nn.Module):
    """Text classifier: pretrained text embeddings -> stack of ``MambaBlock`` layers -> linear head.

    The embedding backend is stored only as a bound ``get_embeddings`` callable, so it is not a
    registered sub-module: its weights are not part of ``state_dict()`` and are not moved by
    ``.to(...)``.

    Args:
        num_layers: number of ``MambaBlock`` layers.
        d_input: feature size produced by the embedding and consumed by the blocks.
        d_model: hidden width inside each block.
        num_classes: number of output logits.
        model_name: embedding backend name passed to ``Embedding``.
        pooling: pooling mode passed to ``Embedding`` (``None`` or ``'mean'``).
        transform_idx_to_labels: label string for each class index, used by :meth:`predict`.
    """

    def __init__(self, num_layers, d_input, d_model, num_classes, model_name='jina', pooling=None, transform_idx_to_labels=['anger', 'disgust', 'fear', 'sadness', 'neutral', 'happiness', 'enthusiasm']):
        super().__init__()
        embed = Embedding(model_name, pooling)
        self.embedding = embed.get_embeddings
        self.layers = nn.ModuleList([MambaBlock(d_input, d_model) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_input, num_classes)
        self.model_name = model_name
        # Copy so instances never share (or mutate) the default list object.
        self.transform_idx_to_labels = list(transform_idx_to_labels)

    def forward(self, seq):
        """Embed a batch of strings and return class logits of shape (batch, num_classes)."""
        # Use the head's device rather than a notebook-level `device` global.
        target_device = self.fc_out.weight.device
        # as_tensor avoids the copy (and warning) of torch.tensor(tensor); detach keeps the
        # original semantics of feeding the blocks a graph-free input.
        seq = torch.as_tensor(self.embedding(seq)).detach().to(target_device)
        for mamba in self.layers:
            seq = mamba(seq)
        # Average over the sequence axis before the classification head.
        return self.fc_out(seq.mean(dim=1))

    def predict(self, x):
        """Return the predicted label string for every input text in ``x``."""
        with torch.no_grad():
            output = self.forward(x)
            predictions = output.argmax(dim=1)
            result = [self.transform_idx_to_labels[i] for i in map(int, predictions)]
        return result

In [7]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [35]:
%%capture --no-stdout
model = Mamba(model_name='jina', pooling=None, num_layers = 2, d_input = 1024, d_model = 256, num_classes=7, transform_idx_to_labels=['anger', 'disgust', 'fear', 'sadness', 'neutral', 'joy', 'surprise']).to(device)
PATH_TO_MODEL = os.path.join(os.path.join("."), "models")
checkpoint = torch.load(os.path.join(PATH_TO_MODEL, "MELD_checkpoint.pth"))
model.load_state_dict(checkpoint['model_state_dict'])

In [42]:
# Smoke-test the MELD (English) checkpoint on a few hand-written utterances.
model.predict([
    'hello',
    'Wow! This for me??',
    'I work as a teacher.',
    'Go away, I hate you.',
    "I have nothing, I'm a loser.",
])

['neutral', 'surprise', 'neutral', 'anger', 'sadness']

In [43]:
%%capture --no-stdout
model = Mamba(model_name='jina', pooling=None, num_layers = 2, d_input = 1024, d_model = 64, num_classes=7).to(device)
PATH_TO_MODEL = os.path.join(os.path.join("."), "models")
checkpoint = torch.load(os.path.join(PATH_TO_MODEL, "RESD_checkpoint.pth"))
model.load_state_dict(checkpoint['model_state_dict'])

In [49]:
# Smoke-test the RESD (Russian) checkpoint on a few hand-written utterances.
model.predict([
    'Уходи, я тебя ненавижу',
    'Мне страшно!',
    'Как тут грязно, все в пыли и сырости',
    'Я работаю учителем',
    'Ура! Спасибо!',
])

['anger', 'fear', 'anger', 'neutral', 'happiness']