# Импорт модулей

In [5]:
# Подавление предупреждений
import warnings
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

# Импорт необходимых библиотек
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import polars as pl
import pandas as pd
import yfinance as yf
import sklearn
import networkx as nx
import ipywidgets
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from torch import Tensor
from einops import rearrange
from typing import Tuple, Callable
from torch.autograd import Function

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from mpl_toolkits.mplot3d import Axes3D

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

In [2]:
# Global configuration: everything a reader might want to tune lives here.
SEED = 42
BATCH_SIZE = 16
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Actually apply the seed: without this call SEED is never used and the
# randomly initialised layer weights differ on every kernel restart.
torch.manual_seed(SEED)

# Last expression of the cell: show which device was selected.
DEVICE

device(type='cpu')

# Подготовка данных

In [6]:
# Number of tokens every review is padded/truncated to; the same value is
# passed to the model as its input width.
MAX_LENGTH = 32

In [7]:
# Download the IMDb dataset and show which splits it contains
dataset = load_dataset('imdb')
print(dataset)

# Reuse a pretrained tokenizer instead of building a vocabulary
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples):
    """Tokenize the ``'text'`` field of a batch of examples.

    Every sequence is padded or truncated to ``MAX_LENGTH`` tokens, so all
    rows come out with the same length.
    """
    tokenizer_options = dict(
        padding='max_length',
        truncation=True,
        max_length=MAX_LENGTH,
    )
    return tokenizer(examples['text'], **tokenizer_options)

# Tokenize in batches, drop the raw text and rename the target column
tokenized_datasets = (
    dataset
    .map(tokenize_function, batched=True)
    .remove_columns(['text'])
    .rename_column('label', 'labels')
)

# Make indexing return torch tensors
tokenized_datasets.set_format('torch')

README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

plain_text/unsupervised-00000-of-00001.p(…):   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [8]:
# The two splits used for training and evaluation
train_dataset, test_dataset = tokenized_datasets['train'], tokenized_datasets['test']

# Token-id and label columns of each split
train_texts, train_labels = train_dataset['input_ids'], train_dataset['labels']
test_texts, test_labels = test_dataset['input_ids'], test_dataset['labels']

# Stack the per-example entries of every column into a single tensor
X_train_tensor, y_train_tensor = (
    torch.stack(list(column)) for column in (train_texts, train_labels)
)
X_test_tensor, y_test_tensor = (
    torch.stack(list(column)) for column in (test_texts, test_labels)
)

In [9]:
# Insert a singleton axis at position 1; the training cell indexes the model
# output as [:, 0, -1] and therefore expects exactly one such axis.
# The dim() guard keeps this cell idempotent: re-running it no longer stacks
# an extra axis on top of the one added by a previous run.
if X_train_tensor.dim() == 2:
    X_train_tensor = X_train_tensor.unsqueeze(1)
if X_test_tensor.dim() == 2:
    X_test_tensor = X_test_tensor.unsqueeze(1)

In [10]:
# Cast every tensor to float32 and place it on the selected device
X_train_tensor, X_test_tensor, y_train_tensor, y_test_tensor = (
    tensor.to(dtype = torch.float32).to(DEVICE)
    for tensor in (X_train_tensor, X_test_tensor, y_train_tensor, y_test_tensor)
)

# Класс модели Mamba

In [3]:
class PScan(Function):
    """Custom autograd ``Function`` that runs an in-place scan over a pair of tensors.

    ``forward`` clones and permutes its two inputs, lets ``_forward`` update
    them in place, and returns the updated ``X``. ``backward`` reuses
    ``_forward`` on flipped tensors to build the two input gradients.

    NOTE(review): presumably this evaluates the recurrence
    ``H[t] = A[t] * H[t-1] + X[t]`` along the axis named ``l`` - confirm
    against the implementation this was adapted from.
    """
    @staticmethod
    def forward(ctx, A_inp, X_inp):
        """Scan clones of ``A_inp`` / ``X_inp`` and return the scanned ``X``.

        The inputs are read with the axis names ``l b d s`` and the result is
        returned with the axis order ``b l d s``.
        """
        # Clone first: _forward mutates its arguments in place.
        A, X = A_inp.clone(), X_inp.clone()
        # Permute both tensors from (l, b, d, s) to (b, d, l, s) so the scanned axis is axis 2.
        A, X = rearrange(A, "l b d s -> b d l s"), rearrange(X, "l b d s -> b d l s")
        PScan._forward(A, X)
        # NOTE(review): A was already modified by _forward at this point, so the
        # tensor saved for backward is the post-scan A, not A_inp - confirm intended.
        ctx.save_for_backward(A.clone(), X)
        # NOTE(review): the output order (b, l, d, s) differs from the input
        # order (l, b, d, s) named above - confirm which layout callers pass.
        return rearrange(X, "b d l s -> b l d s")

    @staticmethod
    def backward(ctx, grad_inp: Tensor) -> Tuple[Tensor, Tensor]:
        """Return one gradient per ``forward`` input, both in ``b l d s`` order.

        NOTE(review): ``forward`` names its inputs ``l b d s``; returning
        gradients as ``b l d s`` only matches the input shapes if those two
        leading axes have equal size - confirm.
        """
        A, X = ctx.saved_tensors
        # Keep position 0 along axis 2 and reverse the order of the remaining positions.
        A = torch.cat((A[:, :, :1], A[:, :, 1:].flip(2)), dim = 2)
        grad_out = rearrange(grad_inp, "b l d s -> b d l s")
        # Run the same in-place scan on the reversed gradient, then restore the order.
        grad_out = grad_out.flip(2)
        PScan._forward(A, grad_out)
        grad_out = grad_out.flip(2)
        # First returned gradient: zero at position 0, elsewhere the saved X
        # shifted by one position times the scanned gradient.
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_out[:, :, 1:])
        return rearrange(Q, "b d l s -> b l d s"), rearrange(grad_out, "b d l s -> b l d s")

    @staticmethod
    def _forward(A: Tensor, X: Tensor) -> None:
        """Update ``A`` and ``X`` in place along axis 2; returns nothing.

        Both tensors are unpacked as four axes ``(b, d, l, s)``. The first loop
        repeatedly pairs up neighbouring positions and combines them; the
        second loop walks the levels back down on strided views.

        NOTE(review): the first loop reshapes to ``(T // 2, 2)`` without
        trimming ``T`` to an even number, so it appears to assume ``l`` is a
        power of two - confirm before using other lengths.
        """
        b, d, l, s = A.shape
        # int() truncates, so this is floor(log2(l)) levels.
        num_steps = int(math.log2(l))
        Av, Xv = A, X
        for _ in range(num_steps):
            T = Xv.size(2)
            # Group neighbouring positions into pairs: axis 2 becomes (T // 2, 2).
            Av, Xv = Av[:, :, :T].reshape(b, d, T // 2, 2, -1), Xv[:, :, :T].reshape(b, d, T // 2, 2, -1)
            # Second element of each pair: X += A * X(first element), then A *= A(first element).
            Xv[:, :, :, 1].add_(Av[:, :, :, 1].mul(Xv[:, :, :, 0]))
            Av[:, :, :, 1].mul_(Av[:, :, :, 0])
            # Continue with the second elements only, halving the length.
            Av, Xv = Av[:, :, :, 1], Xv[:, :, :, 1]
        for k in range(num_steps - 1, -1, -1):
            # Strided slice of the full tensors: positions 2**k - 1, 2 * 2**k - 1, ...
            Av, Xv = A[:, :, 2**k - 1 : l : 2**k], X[:, :, 2**k - 1 : l : 2**k]
            # Largest even count not exceeding the number of selected positions.
            T = 2 * (Xv.size(2) // 2)
            if T < Xv.size(2):
                # Odd count: update the trailing position from the one before it.
                Xv[:, :, -1].add_(Av[:, :, -1].mul(Xv[:, :, -2]))
                Av[:, :, -1].mul_(Av[:, :, -2])
            Av, Xv = Av[:, :, :T].reshape(b, d, T // 2, 2, -1), Xv[:, :, :T].reshape(b, d, T // 2, 2, -1)
            # First element of every pair except the first pair is updated from
            # the second element of the preceding pair.
            Xv[:, :, 1:, 0].add_(Av[:, :, 1:, 0].mul(Xv[:, :, :-1, 1]))
            Av[:, :, 1:, 0].mul_(Av[:, :, :-1, 1])

# Functional entry point: calling pscan(A, X) goes through autograd via Function.apply.
pscan: Callable[[Tensor, Tensor], Tensor] = PScan.apply

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable scale.

    Args:
        d_model: Size of the last axis; one scale weight is kept per feature.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super(RMSNorm, self).__init__()
        # Scale starts at one, so the freshly built layer only normalises.
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` divided by its RMS over the last axis, times ``weight``."""
        mean_square = torch.square(x).mean(dim = -1, keepdim = True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return x * inverse_rms * self.weight

class MambaBlock(nn.Module):
    """Projects the input to ``d_model``, mixes it with two linear branches, and projects back.

    The forward pass consists only of linear layers and additions; it does not
    call ``pscan``.

    Args:
        d_input: Size of the last axis of the input and of the output.
        d_model: Width of the hidden representation.
    """

    def __init__(self, d_input, d_model):
        super().__init__()
        # Layers are created in this order so seeded initialisation stays the same.
        self.in_proj = nn.Linear(d_input, d_model)
        self.s_B = nn.Linear(d_model, d_model)
        self.s_C = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_input)

    def forward(self, x):
        """Return ``out_proj(h + s_B(h) + s_C(h))`` where ``h = in_proj(x)``."""
        hidden = self.in_proj(x)
        b_branch = self.s_B(hidden)
        c_branch = self.s_C(hidden)
        mixed = hidden + b_branch + c_branch
        return self.out_proj(mixed)

class Mamba(nn.Module):
    """Sequential stack of ``num_layers`` ``MambaBlock`` modules.

    Args:
        num_layers: Number of blocks applied one after another.
        d_input: Input/output width handed to every block.
        d_model: Hidden width handed to every block.
    """

    def __init__(self, num_layers, d_input, d_model):
        super().__init__()
        blocks = [MambaBlock(d_input, d_model) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, seq):
        """Feed ``seq`` through every block in order and return the final result."""
        hidden = seq
        for block in self.layers:
            hidden = block(hidden)
        return hidden

# Обучение модели

In [20]:
# Build the model on the same device as the training tensors. Without the
# .to(DEVICE) call the parameters stay on the CPU while the data was moved to
# DEVICE, which raises a device-mismatch error as soon as CUDA is available.
model = Mamba(num_layers = 6, d_input = MAX_LENGTH, d_model = 256).to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = 1e-3)

# NOTE(review): the inputs are raw token ids cast to float and the recorded
# loss swings by several orders of magnitude between epochs; consider an
# embedding layer / input scaling and a classification loss - left unchanged
# here because it would alter the experiment.

num_epochs = 100

# Training mode only needs to be set once; nothing inside the loop changes it.
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass over the whole training set; keep the last feature of the
    # single row at index 0 along axis 1 as the per-example prediction.
    output = model(X_train_tensor)[:, 0, -1]

    # Loss between predictions and the float labels
    loss = criterion(output, y_train_tensor)

    # Backward pass
    loss.backward()

    # Parameter update
    optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

Epoch 1/100, Loss: 231.74066162109375
Epoch 2/100, Loss: 185125.375
Epoch 3/100, Loss: 14978.0732421875
Epoch 4/100, Loss: 1011221.9375
Epoch 5/100, Loss: 2880290.75
Epoch 6/100, Loss: 70815.7421875
Epoch 7/100, Loss: 124809.453125
Epoch 8/100, Loss: 73536.625
Epoch 9/100, Loss: 307230.625
Epoch 10/100, Loss: 13970.140625
Epoch 11/100, Loss: 186294.890625
Epoch 12/100, Loss: 100210.421875
Epoch 13/100, Loss: 280247.21875
Epoch 14/100, Loss: 232172.671875
Epoch 15/100, Loss: 51546.171875
Epoch 16/100, Loss: 216705.390625
Epoch 17/100, Loss: 36097.13671875
Epoch 18/100, Loss: 39806.1015625
Epoch 19/100, Loss: 52052.82421875
Epoch 20/100, Loss: 4039.16650390625
Epoch 21/100, Loss: 104622.09375
Epoch 22/100, Loss: 21743.2109375
Epoch 23/100, Loss: 14499.1630859375
Epoch 24/100, Loss: 46758.52734375
Epoch 25/100, Loss: 20121.048828125
Epoch 26/100, Loss: 1587.1278076171875
Epoch 27/100, Loss: 5543.7158203125
Epoch 28/100, Loss: 12325.57421875
Epoch 29/100, Loss: 10673.380859375
Epoch 30/100

# Выводы

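В данной работе представлен пример обучения модели Mamba на данных IMDb. В ходе выполнения задания я получил бесценный опыт использования инновационной архитектуры. При этом за первые 29 эпох значение функции потерь не стабилизировалось (от ~2,3·10² до ~2,9·10⁶), поэтому обучение нельзя считать сошедшимся.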