In [None]:
# Пример использования xLSTM из оригинальной статьи
# https://arxiv.org/pdf/2405.04517
# Code base https://github.com/NX-AI/xlstm/tree/main

@article{xlstm,
  title={xLSTM: Extended Long Short-Term Memory},
  author={Beck, Maximilian and P{\"o}ppel, Korbinian and Spanring, Markus and Auer, Andreas and Prudnikova, Oleksandra and Kopp, Michael and Klambauer, G{\"u}nter and Brandstetter, Johannes and Hochreiter, Sepp},
  journal={arXiv preprint arXiv:2405.04517},
  year={2024}
}


In [None]:
import pandas as pd
import os
import numpy as np
import string
import random

import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

import pickle

In [None]:
# Reproducibility: seed every RNG this notebook touches (Python, NumPy,
# PyTorch CPU and all GPUs) and force deterministic cuDNN kernels.

SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# manual_seed_all seeds every visible GPU, not only the current device.
torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Load the training corpus: one document per row in the `text` column.

train_csv_path = '/content/drive/MyDrive/Диплом_2024/data/therapy_train_true.csv'
df_train = pd.read_csv(train_csv_path)

df_train.head()

Unnamed: 0,text
0,Московский государственный медико-стоматологич...
1,Башкирский Государственный Медицинский Универс...
2,Министерство здравоохранения Республики Белару...
3,\nПаспортная часть\n\nФИО: \nВозраст: 29 лет\n...
4,\nИстория болезни.\nФамилия: \n Имя: \nОтчест...


In [None]:
# Load the test corpus: one document per row in the `text` column.

test_csv_path = '/content/drive/MyDrive/Диплом_2024/data/test_data.csv'
df_test = pd.read_csv(test_csv_path)

df_test.head()

Unnamed: 0,text
0,Министерство здравоохранения Российской Феде...
1,"Жалобы на слабость, отдышку, изнуряющий кашель..."
2,Диагноз: Пневмония в правой нижней доле. ДН II...
3,Основные жалобы - на периодический кашель с не...
4,Клинический диагноз: Внебольничная правосторон...


In [None]:
# (rows, columns) of the training split, then of the test split.
for frame in (df_train, df_test):
    print(frame.shape)

(65, 1)
(6, 1)


In [None]:
# Concatenate all training documents into one space-separated string.
train_text = ' '.join(df_train['text'].tolist())

In [None]:
# Concatenate all test documents into one space-separated string.
test_text = ' '.join(df_test['text'].tolist())

In [None]:
# Peek at the raw (uncleaned) training text.
train_text[0:200]

'Московский государственный медико-стоматологический университет\nкафедра пропедевтики внутренних болезней стоматологического факультета\n(заведующий кафедрой  - заслуженный деятель науки РФ, профессор Т'

In [None]:
# Normalise the corpus: newlines/tabs become spaces, everything is lowercased,
# and punctuation and digits are stripped. string.punctuation is ASCII-only, so
# common Unicode punctuation found in Russian text (dashes, guillemets, curly
# quotes, ellipsis, numero sign) is listed explicitly; otherwise tokens such as
# a standalone "–" survive cleaning and end up as vocabulary entries.
# NOTE(review): word_to_int is loaded from a pickle built outside this notebook;
# confirm it was built with the same cleaning, otherwise rebuild it.
EXTRA_PUNCTUATION = '–—«»„“”‘’…№'
STRIP_TABLE = str.maketrans('', '', string.punctuation + EXTRA_PUNCTUATION + string.digits)


def clean_text(text):
    """Return ``text`` normalised for whitespace word tokenisation.

    Newlines and tabs are replaced by spaces, the text is lowercased, and every
    character listed in ``STRIP_TABLE`` (ASCII punctuation, the extra Unicode
    punctuation above, and ASCII digits) is deleted.
    """
    text = text.replace('\n', ' ').replace('\t', ' ')
    return text.lower().translate(STRIP_TABLE)


train_text = clean_text(train_text)
test_text = clean_text(test_text)


In [None]:
# Peek at the cleaned training text.
train_text[0:500]

'московский государственный медикостоматологический университет кафедра пропедевтики внутренних болезней стоматологического факультета заведующий кафедрой   заслуженный деятель науки рф профессор токмачев юрий константинович            история болезни больного коновалова ад  лет  терапевтическое отделение палата           куратор студентка iii курса   группы дневного  стоматологического факультета коваленко александры валериевны    преподаватель пихлак аэ         москва     паспортные данные  фио'

In [None]:
# Peek at the cleaned test text.
test_text[0:500]

'  министерство здравоохранения российской федерации  алтайский государственный медицинский университет кафедра пропедевтики внутренних болезней зав кафедрой проф            академическая история болезни          больной куратор студентка  группы iii курса лечебного факультета время курации  –  г преподаватель           паспортная часть   фио   возраст  лет   место работы центр занятости населения  место жительства  дата поступления в клинику  г  диагноз пневмония в правой нижней доле дн ii остры'

In [None]:
# Load the pre-built word-level vocabulary (word -> id) and its inverse (id -> word).
# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted location.

with open('/content/drive/MyDrive/Диплом_2024/tokenizers/saved_word_to_int_therapy.pkl', 'rb') as vocab_file:
    word_to_int = pickle.load(vocab_file)

with open('/content/drive/MyDrive/Диплом_2024/tokenizers/saved_int_to_word_therapy.pkl', 'rb') as vocab_file:
    int_to_word = pickle.load(vocab_file)

In [None]:
# Vocabulary size: number of word -> id entries in the loaded tokenizer.

len(word_to_int)

20413

In [None]:
# Split the training corpus into overlapping windows (stride 1) of
# SEQUENCE_LENGTH + 1 words: the first SEQUENCE_LENGTH words are the model
# input, the last SEQUENCE_LENGTH words (shifted by one) are the target.

SEQUENCE_LENGTH = 64
window_size = SEQUENCE_LENGTH + 1  # input words plus one extra word for the shifted target
words = train_text.split()
samples = [words[start:start + window_size] for start in range(len(words) - SEQUENCE_LENGTH)]

In [None]:
# Dataset of word windows for next-word prediction.

class TextDataset(Dataset):
    """Map word windows to (input, target) id tensors for language modelling.

    For a sample ``w_0 .. w_n`` the input is the ids of ``w_0 .. w_{n-1}`` and
    the target is the ids of ``w_1 .. w_n`` (the same sequence shifted by one).

    Args:
        samples: list of word sequences (windows of SEQUENCE_LENGTH + 1 words
            in this notebook).
        word_to_int: mapping from word to integer id.
        unk_id: id used for words missing from ``word_to_int`` (default 0,
            matching the previous behaviour).
    """

    def __init__(self, samples, word_to_int, unk_id=0):
        self.samples = samples
        self.word_to_int = word_to_int
        self.unk_id = unk_id

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Encode the whole window once and slice it: input and target share all
        # but one token, so encoding them separately doubles the dict lookups.
        ids = torch.LongTensor([self.word_to_int.get(word, self.unk_id) for word in self.samples[idx]])
        input_seq = ids[:-1]
        target_seq = ids[1:]
        return input_seq, target_seq

In [None]:
# Training DataLoader: shuffled mini-batches of (input, target) windows.

BATCH_SIZE = 32
train_dataset = TextDataset(samples, word_to_int)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sanity check: the target is the input shifted by one word.
print(train_dataset[1])

(tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
        20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,  9, 10,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 42, 62]), tensor([ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
        21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,  9, 10, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
        55, 56, 57, 58, 59, 60, 61, 42, 62, 63]))


In [None]:
# Same stride-1 windowing as for the training corpus, applied to the test text.
test_words = test_text.split()
window_count = len(test_words) - SEQUENCE_LENGTH
test_samples = [test_words[pos:pos + SEQUENCE_LENGTH + 1] for pos in range(window_count)]

In [None]:
# Test DataLoader: fixed order (no shuffling) so evaluation is repeatable.
test_dataset = TextDataset(test_samples, word_to_int)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print(test_dataset[1])

(tensor([2126, 3271, 3272, 3273,    2, 2133,    4,    5,    6,    7,    8, 3274,
          12, 3275, 3276,   21,   22,  205,   30,   31,   34,   32,   33, 2142,
          10,  276, 2803, 1117,   61,   39, 1068, 1069,   45, 1074,   26,   59,
         314, 3277, 3278, 3279,   59,   60,   64,   65,   66,   67,   61,  174,
        2812,   66,  835,  760, 3072, 3280,  183,  175,  176, 3281, 3282, 3283,
        3284, 3285, 3286, 3287]), tensor([3271, 3272, 3273,    2, 2133,    4,    5,    6,    7,    8, 3274,   12,
        3275, 3276,   21,   22,  205,   30,   31,   34,   32,   33, 2142,   10,
         276, 2803, 1117,   61,   39, 1068, 1069,   45, 1074,   26,   59,  314,
        3277, 3278, 3279,   59,   60,   64,   65,   66,   67,   61,  174, 2812,
          66,  835,  760, 3072, 3280,  183,  175,  176, 3281, 3282, 3283, 3284,
        3285, 3286, 3287, 3288]))


In [None]:
# Установим неоходимые библиотеки


!pip install Ninja

!pip install omegaconf

!pip install dacite

!pip install xlstm

Collecting xlstm
  Downloading xlstm-1.0.3-py3-none-any.whl (95 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.1/95.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xlstm
Successfully installed xlstm-1.0.3


In [None]:
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMLMModel, xLSTMLMModelConfig

In [None]:
# Model configuration as YAML (parsed by OmegaConf in the next cell).
# vocab_size and context_length come from the tokenizer and the windowing step
# instead of being hard-coded, so they cannot drift out of sync with the data.
# NOTE: the sLSTM block uses the 'cuda' backend, so a GPU is required.

xlstm_cfg = f"""
vocab_size: {len(word_to_int)}
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    backend: cuda
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: {SEQUENCE_LENGTH}
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""

In [None]:
# Parse the YAML text into an OmegaConf config object.
cfg = OmegaConf.create(xlstm_cfg)

In [None]:
# Display the parsed OmegaConf configuration.
cfg

{'vocab_size': 20413, 'mlstm_block': {'mlstm': {'conv1d_kernel_size': 4, 'qkv_proj_blocksize': 4, 'num_heads': 4}}, 'slstm_block': {'slstm': {'backend': 'cuda', 'num_heads': 4, 'conv1d_kernel_size': 4, 'bias_init': 'powerlaw_blockdependent'}, 'feedforward': {'proj_factor': 1.3, 'act_fn': 'gelu'}}, 'context_length': 64, 'num_blocks': 7, 'embedding_dim': 128, 'slstm_at': [1]}

In [None]:
# Convert the OmegaConf tree to plain Python containers and build the typed
# xLSTM config; strict=True makes dacite reject keys the dataclass does not declare.
cfg = from_dict(
    data_class=xLSTMLMModelConfig,
    data=OmegaConf.to_container(cfg),
    config=DaciteConfig(strict=True),
)

In [None]:
# Display the typed xLSTMLMModelConfig, including fields not set in the YAML (library defaults).
cfg

xLSTMLMModelConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0, round_proj_up_dim_up=True, round_proj_up_to_multiple_of=64, _proj_up_dim=256, conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4, embedding_dim=128, bias=False, dropout=0.0, context_length=64, _num_blocks=7, _inner_embedding_dim=256)), slstm_block=sLSTMBlockConfig(slstm=sLSTMLayerConfig(hidden_size=128, num_heads=4, num_states=4, backend='cuda', function='slstm', bias_init='powerlaw_blockdependent', recurrent_weight_init='zeros', _block_idx=0, _num_blocks=7, num_gates=4, gradient_recurrent_cut=False, gradient_recurrent_clipval=None, forward_clipval=None, batch_size=8, input_shape='BSGNH', internal_input_shape='SBNGH', output_shape='BNSH', constants={}, dtype='bfloat16', dtype_b='float32', dtype_r='bfloat16', dtype_w='bfloat16', dtype_g='bfloat16', dtype_s='bfloat16', dtype_a='float32', enable_automatic_mixed_precision=True, initial_val=0.0, embedding_dim=128, conv1d_kernel_size=4, group_norm_we

In [None]:
# Instantiate the xLSTM language model from the typed config
# (the log below shows the sLSTM CUDA extension being compiled at this point).

model = xLSTMLMModel(cfg)

{'verbose': True, 'with_cuda': True, 'extra_ldflags': ['-L/usr/local/cuda/lib', '-lcublas'], 'extra_cflags': ['-DSLSTM_HIDDEN_SIZE=128', '-DSLSTM_BATCH_SIZE=8', '-DSLSTM_NUM_HEADS=4', '-DSLSTM_NUM_STATES=4', '-DSLSTM_DTYPE_B=float', '-DSLSTM_DTYPE_R=__nv_bfloat16', '-DSLSTM_DTYPE_W=__nv_bfloat16', '-DSLSTM_DTYPE_G=__nv_bfloat16', '-DSLSTM_DTYPE_S=__nv_bfloat16', '-DSLSTM_DTYPE_A=float', '-DSLSTM_NUM_GATES=4', '-DSLSTM_SIMPLE_AGG=true', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL_VALID=false', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL=0.0', '-DSLSTM_FORWARD_CLIPVAL_VALID=false', '-DSLSTM_FORWARD_CLIPVAL=0.0', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_OPERATORS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '-U__CUDA_NO_BFLOAT162_OPERATORS__', '-U__CUDA_NO_BFLOAT162_CONVERSIONS__'], 'extra_cuda_cflags': ['-Xptxas="-v"', '-gencode', 'arch=compute_80,code=compute_80', '-res-usage', '--use_fast_math', '-O3', '-Xptxas -O3', '--extra-device-vectorization', '-DSLSTM_

Using /root/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py310_cu121/slstm_HS128BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py310_cu121/slstm_HS128BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module slstm_HS128BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module slstm_HS128BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model.to(device)

xLSTMLMModel(
  (xlstm_block_stack): xLSTMBlockStack(
    (blocks): ModuleList(
      (0): mLSTMBlock(
        (xlstm_norm): LayerNorm()
        (xlstm): mLSTMLayer(
          (proj_up): Linear(in_features=128, out_features=512, bias=False)
          (q_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (k_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (v_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (conv1d): CausalConv1d(
            (conv): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
          )
          (conv_act_fn): SiLU()
          (mlstm_cell): mLSTMCell(
            (igate): Linear(in_features=768, out_features=4, bias=True)
            (fgate): Linear(in

In [None]:
# Training hyperparameters, loss, optimizer, and a parameter count.
# The model architecture is not printed here again: the cell that moves the
# model to the device already displays the full module tree.

epochs = 15
learning_rate = 0.001

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")

total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.\n")

xLSTMLMModel(
  (xlstm_block_stack): xLSTMBlockStack(
    (blocks): ModuleList(
      (0): mLSTMBlock(
        (xlstm_norm): LayerNorm()
        (xlstm): mLSTMLayer(
          (proj_up): Linear(in_features=128, out_features=512, bias=False)
          (q_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (k_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (v_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (conv1d): CausalConv1d(
            (conv): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
          )
          (conv_act_fn): SiLU()
          (mlstm_cell): mLSTMCell(
            (igate): Linear(in_features=768, out_features=4, bias=True)
            (fgate): Linear(in

In [None]:
# Vocabulary size, used by train() to flatten the model's logits.
# Fail fast if it disagrees with the vocabulary the model was built with;
# otherwise the mismatch only surfaces later as a reshape error mid-training.

vocab_size = len(word_to_int)
assert vocab_size == cfg.vocab_size, f"tokenizer vocab ({vocab_size}) != model vocab ({cfg.vocab_size})"

print(vocab_size)

20413


In [None]:
# Training.

def train(model, epochs, dataloader, criterion, vocab_size, optimizer=None, device=None):
    """Train ``model`` for next-word prediction, printing the mean loss per epoch.

    Args:
        model: module mapping a batch of token ids to per-position logits whose
            last dimension is ``vocab_size``.
        epochs: number of passes over ``dataloader``.
        dataloader: sized iterable of ``(input_seq, target_seq)`` id tensors;
            ``len(dataloader)`` must be the number of batches.
        criterion: loss taking ``(logits[N, vocab_size], targets[N])``.
        vocab_size: size of the output vocabulary (last dim of the logits).
        optimizer: optimizer to step; defaults to the notebook-level ``optimizer``.
        device: device the batches are moved to; defaults to the device of the
            model's parameters.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``dataloader`` has no batches (mean loss would be undefined).
    """
    if optimizer is None:
        optimizer = globals()['optimizer']
    if device is None:
        device = next(model.parameters()).device
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; nothing to train on")

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        for input_seq, target_seq in dataloader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            outputs = model(input_seq)
            # Flatten per-position logits and targets for CrossEntropyLoss;
            # reshape (unlike view) also handles non-contiguous tensors.
            loss = criterion(outputs.reshape(-1, vocab_size), target_seq.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch} loss: {epoch_loss:.3f}")
    return epoch_losses

In [None]:
%%time

train(model, epochs, train_dataloader, criterion, vocab_size)

Epoch 0 loss: 1.050
Epoch 1 loss: 0.193
Epoch 2 loss: 0.161
Epoch 3 loss: 0.144
Epoch 4 loss: 0.130
Epoch 5 loss: 0.120
Epoch 6 loss: 0.114
Epoch 7 loss: 0.109
Epoch 8 loss: 0.105
Epoch 9 loss: 0.102
Epoch 10 loss: 0.100
Epoch 11 loss: 0.098
Epoch 12 loss: 0.096
Epoch 13 loss: 0.095
Epoch 14 loss: 0.093
CPU times: user 1h 10min 47s, sys: 16.5 s, total: 1h 11min 4s
Wall time: 1h 11min 54s


In [None]:
# Save the trained weights together with the optimizer state so training can be resumed.

CHECKPOINT_PATH = '/content/drive/My Drive/Диплом_2024/models/xLSTM_therapy.pth'

checkpoint = {'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict()}

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
torch.save(checkpoint, CHECKPOINT_PATH)

In [None]:
def load_checkpoint_for_eval(filepath, device, model):
    """Load weights from a checkpoint into ``model`` and prepare it for inference.

    Args:
        filepath: path to a file written by ``torch.save`` holding a dict with a
            ``'state_dict'`` entry.
        device: target device (``str`` or ``torch.device``). Tensors are mapped
            there while loading, so a checkpoint saved on a GPU can also be
            opened on a CPU-only machine.
        model: module whose architecture matches the checkpoint.

    Returns:
        ``model`` with the loaded weights, all parameters frozen
        (``requires_grad=False``), in eval mode and moved to ``device``.
    """
    checkpoint = torch.load(filepath, map_location=device)

    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()
    return model.to(device)

In [None]:
# Reload the saved weights onto the notebook's selected device (instead of a hard-coded 'cuda').
model = load_checkpoint_for_eval(
    '/content/drive/My Drive/Диплом_2024/models/xLSTM_therapy.pth',
    device,
    model,
)

In [None]:
def return_int_vector(text):
    """Encode the last SEQUENCE_LENGTH words of ``text`` as a (1, n) LongTensor of ids.

    Words missing from ``word_to_int`` are encoded as 0.
    """
    context_words = text.split()[-SEQUENCE_LENGTH:]
    ids = [word_to_int.get(token, 0) for token in context_words]
    return torch.LongTensor(ids).unsqueeze(0)

def sample_next(predictions):
    """
    Greedy sampling: return the id of the most likely next token.

    ``predictions`` are the model's logits indexed as [batch, position, vocab];
    only the last position of the first batch row is used (generation runs
    with a batch of one). Softmax is monotonic, so taking the argmax of the
    raw logits selects the same token without computing probabilities or
    copying the vocabulary-sized row to the CPU.
    """
    return int(predictions[0, -1].argmax())

def text_generator(sentence, generate_length):
    """Greedily extend ``sentence`` by ``generate_length`` words and print the result.

    At every step only the last SEQUENCE_LENGTH words are fed to the model
    (``return_int_vector`` truncates the context), so generation can continue
    past the model's context window. Uses the notebook-level ``model``,
    ``device`` and ``int_to_word``.

    Returns:
        The generated text (prompt plus new words), which is also printed.

    Raises:
        ValueError: if ``sentence`` contains no words (no context to condition on).
    """
    if not sentence.split():
        raise ValueError("sentence must contain at least one word")
    model.eval()
    sample = sentence
    for _ in range(generate_length):
        input_tensor = return_int_vector(sample).to(device)
        with torch.no_grad():
            predictions = model(input_tensor)
        next_token = sample_next(predictions)
        sample += ' ' + int_to_word[next_token]
    print(sample)
    print('\n')
    return sample
    print('\n')

In [None]:
# Greedy continuation of a single-word prompt.
generate_length = 5
sentences = ["хрипы"]
for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: хрипы
хрипы в нижних отделах с х




In [None]:
# Greedy continuation of a single-word prompt.
generate_length = 7
sentences = ["верхушечный"]
for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: верхушечный
верхушечный толчок визуально не определяется определяется пульсация в




In [None]:
# Greedy continuation of a single-word prompt.
generate_length = 6
sentences = ["аускультация"]
for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: аускультация
аускультация сердца и сосудов соотношение тонов сердца




In [None]:
# Greedy continuation of a single-word prompt.
generate_length = 6
sentences = ["жалобы"]
for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: жалобы
жалобы на кашель насморк кашель продуктивный частый




In [None]:
# Greedy continuation of a single-word prompt.
generate_length = 6
sentences = ["кашель"]
for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: кашель
кашель с незначительным количеством вязкой слизистой мокроты




In [None]:
# Evaluation on the test set: greedy next-word prediction at every position.
# NOTE(review): test windows overlap with stride 1 (see test_samples), so most
# test tokens are scored many times with different amounts of context and the
# supports in the report are inflated accordingly. Words missing from the
# vocabulary are all encoded as id 0 and are therefore counted as class 0.

model.eval()

preds = []
targets = []

with torch.no_grad():
    for input_seq, target_seq in test_dataloader:
        predictions = model(input_seq.to(device))
        # Take the argmax on the device so only the predicted ids (not the full
        # vocabulary-sized logits) are copied to the CPU.
        batch_preds = predictions.argmax(dim=-1)
        preds.extend(batch_preds.reshape(-1).tolist())
        targets.extend(target_seq.reshape(-1).tolist())



In [None]:
# Per-class precision / recall / F1 on the test predictions.
# zero_division=0 explicitly scores classes that are never predicted (or never
# occur) as 0, instead of emitting an UndefinedMetricWarning for each of them.

print(classification_report(targets, preds, zero_division=0))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00      5120
           2       1.00      1.00      1.00         5
           4       1.00      0.86      0.92         7
           5       1.00      1.00      1.00         8
           6       1.00      0.56      0.71         9
           7       0.96      0.99      0.97        74
           8       0.97      1.00      0.99        75
          10       1.00      1.00      1.00        25
          12       1.00      1.00      1.00        13
          21       0.47      0.20      0.28        80
          22       0.13      0.18      0.15       145
          23       0.87      0.54      0.67      1344
          25       0.69      0.83      0.75       448
          26       0.72      0.80      0.76       675
          29       0.00      0.00      0.00         0
          30       1.00      0.79      0.88        19
          31       1.00      0.95      0.97        20
          32       0.91    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
