In [1]:
import pandas as pd
import os
import numpy as np
import string
import random

import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

import pickle

In [2]:
# Reproducibility: seed every random number generator the notebook imports
# (Python's `random`, numpy and torch) and force deterministic cuDNN kernels.

SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Deterministic cuDNN algorithms; disable autotuning, which may pick
# different kernels from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Load the training split (one medical case history per row, column `text`).
df_train = pd.read_csv(
    '/content/drive/MyDrive/Диплом_2024/data/therapy_train_true.csv'
)
df_train.head()

Unnamed: 0,text
0,Московский государственный медико-стоматологич...
1,Башкирский Государственный Медицинский Универс...
2,Министерство здравоохранения Республики Белару...
3,\nПаспортная часть\n\nФИО: \nВозраст: 29 лет\n...
4,\nИстория болезни.\nФамилия: \n Имя: \nОтчест...


In [5]:
# Load the test split (same layout as the training file).
df_test = pd.read_csv(
    '/content/drive/MyDrive/Диплом_2024/data/therapy_test_true.csv'
)
df_test.head()

Unnamed: 0,text
0,\n\nМинистерство здравоохранения Российской Фе...


In [6]:
# (rows, columns) of each split, one per line.
print(df_train.shape, df_test.shape, sep='\n')

(65, 1)
(1, 1)


In [7]:
train_text = ' '.join(df_train['text'].tolist())  # whole training corpus as one string

In [8]:
test_text = ' '.join(df_test['text'].tolist())  # whole test corpus as one string

In [9]:
train_text[0:200]  # first 200 characters of the raw training corpus

'Московский государственный медико-стоматологический университет\nкафедра пропедевтики внутренних болезней стоматологического факультета\n(заведующий кафедрой  - заслуженный деятель науки РФ, профессор Т'

In [10]:
# Normalise both splits identically: newlines/tabs become spaces, everything is
# lower-cased, then ASCII punctuation and digits are deleted.
# NOTE(review): string.punctuation is ASCII-only, so typographic marks used in
# Russian text (e.g. the "–" still visible in the cleaned test text, «», —, №)
# survive. Changing the character set would presumably also require rebuilding
# the pickled word_to_int vocabulary — confirm how it was built before doing so.

remove_punctuation = str.maketrans('', '', string.punctuation)
remove_digits = str.maketrans('', '', string.digits)


def clean_text(text):
    """Return ``text`` with newlines/tabs replaced by spaces, lower-cased, and
    ASCII punctuation and ASCII digits removed."""
    text = text.replace('\n', ' ').replace('\t', ' ').lower()
    text = text.translate(remove_punctuation)
    return text.translate(remove_digits)


train_text = clean_text(train_text)
test_text = clean_text(test_text)


In [11]:
train_text[0:500]  # first 500 characters of the cleaned training corpus

'московский государственный медикостоматологический университет кафедра пропедевтики внутренних болезней стоматологического факультета заведующий кафедрой   заслуженный деятель науки рф профессор токмачев юрий константинович            история болезни больного коновалова ад  лет  терапевтическое отделение палата           куратор студентка iii курса   группы дневного  стоматологического факультета коваленко александры валериевны    преподаватель пихлак аэ         москва     паспортные данные  фио'

In [12]:
test_text[0:500]  # first 500 characters of the cleaned test corpus

'  министерство здравоохранения российской федерации  алтайский государственный медицинский университет кафедра пропедевтики внутренних болезней зав кафедрой проф            академическая история болезни          больной куратор студентка  группы iii курса лечебного факультета время курации  –  г преподаватель           паспортная часть   фио   возраст  лет   место работы центр занятости населения  место жительства  дата поступления в клинику  г  диагноз пневмония в правой нижней доле дн ii остры'

In [13]:
# Pre-built vocabulary: word -> id mapping and its inverse (id -> word).
# NOTE: pickle.load can execute arbitrary code; only load these self-produced files.
with open('/content/drive/MyDrive/Диплом_2024/tokenizers/saved_word_to_int_therapy.pkl', 'rb') as f:
    word_to_int = pickle.load(f)
with open('/content/drive/MyDrive/Диплом_2024/tokenizers/saved_int_to_word_therapy.pkl', 'rb') as f:
    int_to_word = pickle.load(f)

In [15]:
SEQUENCE_LENGTH = 64  # model context length, in words
words = train_text.split()
# Sliding windows (stride 1) of SEQUENCE_LENGTH + 1 consecutive words:
# the first SEQUENCE_LENGTH are the input, the last SEQUENCE_LENGTH the shifted target.
samples = [
    words[start:start + SEQUENCE_LENGTH + 1]
    for start in range(len(words) - SEQUENCE_LENGTH)
]

In [16]:
# Custom dataset that turns word windows into (input ids, target ids) pairs.

class TextDataset(Dataset):
    """Next-word prediction dataset.

    :param samples: sequence of word lists; item ``i`` yields the ids of
        ``samples[i][:-1]`` as input and of ``samples[i][1:]`` as target.
    :param word_to_int: vocabulary mapping; words missing from it are encoded as 0.
    """

    def __init__(self, samples, word_to_int):
        self.samples = samples
        self.word_to_int = word_to_int

    def __len__(self):
        return len(self.samples)

    def _encode(self, tokens):
        # Out-of-vocabulary words fall back to id 0.
        lookup = self.word_to_int.get
        return torch.tensor([lookup(token, 0) for token in tokens], dtype=torch.long)

    def __getitem__(self, idx):
        window = self.samples[idx]
        return self._encode(window[:-1]), self._encode(window[1:])

In [17]:
BATCH_SIZE = 32

train_dataset = TextDataset(samples, word_to_int)
# Shuffle windows every epoch so a batch mixes different parts of the corpus.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sanity check: the target ids are the input ids shifted left by one position.
print(train_dataset[1])

(tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
        20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,  9, 10,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 42, 62]), tensor([ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
        21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,  9, 10, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
        55, 56, 57, 58, 59, 60, 61, 42, 62, 63]))


In [18]:
test_words = test_text.split()

# Same stride-1 windows of SEQUENCE_LENGTH + 1 words as for the training corpus.
test_samples = [
    test_words[start:start + SEQUENCE_LENGTH + 1]
    for start in range(len(test_words) - SEQUENCE_LENGTH)
]

In [19]:
# Test words are encoded with the training vocabulary (unknown words -> 0).
test_dataset = TextDataset(test_samples, word_to_int)
# Keep window order for evaluation.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print(test_dataset[1])

(tensor([2126, 3271, 3272, 3273,    2, 2133,    4,    5,    6,    7,    8, 3274,
          12, 3275, 3276,   21,   22,  205,   30,   31,   34,   32,   33, 2142,
          10,  276, 2803, 1117,   61,   39, 1068, 1069,   45, 1074,   26,   59,
         314, 3277, 3278, 3279,   59,   60,   64,   65,   66,   67,   61,  174,
        2812,   66,  835,  760, 3072, 3280,  183,  175,  176, 3281, 3282, 3283,
        3284, 3285, 3286, 3287]), tensor([3271, 3272, 3273,    2, 2133,    4,    5,    6,    7,    8, 3274,   12,
        3275, 3276,   21,   22,  205,   30,   31,   34,   32,   33, 2142,   10,
         276, 2803, 1117,   61,   39, 1068, 1069,   45, 1074,   26,   59,  314,
        3277, 3278, 3279,   59,   60,   64,   65,   66,   67,   61,  174, 2812,
          66,  835,  760, 3072, 3280,  183,  175,  176, 3281, 3282, 3283, 3284,
        3285, 3286, 3287, 3288]))


In [20]:
# Decoder-only text-generation Transformer, part 1: the causal attention mask.
# Position i may attend only to positions 0..i, so while predicting the next
# word the model never sees words to its right. This is what makes the
# decoder autoregressive.

def generate_square_subsequent_mask(sz):
    """
    Build an additive (sz, sz) float attention mask.

    Entries on or below the main diagonal are 0.0 (attention allowed);
    entries strictly above it are float('-inf') (attention blocked).
    """
    blocked = torch.full((sz, sz), float('-inf'))
    # triu(..., diagonal=1) keeps -inf strictly above the diagonal and zeroes the rest.
    return torch.triu(blocked, diagonal=1)

In [21]:
# Sinusoidal positional encoding added to the token embeddings.

class PositionalEncoding(nn.Module):
    """Fixed sine/cosine positional encoding.

    Even feature columns hold sin(pos * 10000^(-k/d_model)) and odd columns the
    matching cosine, for k = 0, 2, 4, ... The table is precomputed for
    ``max_len`` positions and registered as a buffer named ``pe``.
    """

    def __init__(self, max_len, d_model, dropout=0.1):
        """
        :param max_len: longest sequence length that can be encoded.
        :param d_model: embedding dimension (odd values are supported).
        :param dropout: dropout probability applied after adding the encoding (default=0.1).
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading singleton axis so the table broadcasts over the batch dimension.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional encoding to ``x`` and apply dropout.

        :param x: embeddings of shape [batch size, sequence length, embed dim]
            (batch-first); sequence length must not exceed ``max_len``.
        :returns: tensor with the same shape as ``x``.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [22]:
# The TextGen class combines the components above into the final
# text-generation Transformer: token embedding -> positional encoding ->
# stacked TransformerDecoder layers with a causal mask -> dropout ->
# linear projection to vocabulary logits.

class TextGen(nn.Module):
    """Decoder-only Transformer language model over word ids.

    :param vocab_size: number of embedding rows and number of output logits.
    :param embed_dim: model (embedding) dimension.
    :param num_layers: number of stacked decoder layers.
    :param num_heads: attention heads per decoder layer.

    The positional encoding is built with ``max_len=SEQUENCE_LENGTH`` (a
    notebook-level constant), so longer inputs would not fit its table.
    """
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads):
        super(TextGen, self).__init__()
        self.pos_encoder = PositionalEncoding(max_len=SEQUENCE_LENGTH, d_model=embed_dim)
        self.emb = nn.Embedding(vocab_size, embed_dim)
        # NOTE(review): self.decoder_layer is only handed to nn.TransformerDecoder
        # below and is never referenced in forward(). It is still a registered
        # submodule, so its parameters are included in model.parameters() and in
        # the state_dict. Removing it would change the state_dict keys of
        # existing checkpoints (and the RNG consumption order at construction).
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=self.decoder_layer,
            num_layers=num_layers,
        )
        self.linear = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(0.2)

    # Positional encoding is required; without it the model does not learn.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return next-word logits for every position of ``x``.

        :param x: token ids, presumably a LongTensor of shape (batch, seq) since
            the decoder is built with ``batch_first=True`` — TODO confirm.
        :returns: logits whose last dimension is ``vocab_size``.
        """
        emb = self.emb(x)

        # Causal mask of shape (seq, seq), where seq = x.size(1), on the input's device.
        input_mask = generate_square_subsequent_mask(x.size(1)).to(x.device)

        x = self.pos_encoder(emb)
        # There is no encoder: the sequence is passed both as the target and as
        # the "memory", and the same causal mask is applied to both attention
        # blocks, so no position can attend to later positions.
        x = self.decoder(x, memory=x, tgt_mask=input_mask, memory_mask=input_mask)
        x = self.dropout(x)
        out = self.linear(x)
        return out

In [23]:
# Vocabulary size = number of entries in the loaded word -> id mapping.
# NOTE(review): assumes every id in word_to_int is < len(word_to_int); a larger
# id would index past the end of nn.Embedding — confirm how ids were assigned.
vocab_size = len(word_to_int)
print(vocab_size)

20413


In [24]:
# Hyperparameters, model, loss and optimizer.

epochs = 20
learning_rate = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = TextGen(vocab_size=vocab_size, embed_dim=100, num_layers=2, num_heads=2)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
print(model)

# Parameter counts: every registered parameter vs. those that will be updated.
# NOTE(review): the totals include TextGen.decoder_layer, which is registered
# but not used in forward().
total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_params:,} total parameters.")
print(f"{total_trainable_params:,} training parameters.\n")

TextGen(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (emb): Embedding(20413, 100)
  (decoder_layer): TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
    )
    (linear1): Linear(in_features=100, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=100, bias=True)
    (norm1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (dropout3): Dropout(p=0.1, inplace=False)
  )
  (decoder): Transfor

In [25]:
# Training loop.

def train(model, epochs, dataloader, criterion, optimizer=None):
    """Train ``model`` for ``epochs`` passes over ``dataloader`` and print the mean loss per epoch.

    :param model: module mapping (batch, seq) token ids to (batch, seq, vocab) logits.
    :param epochs: number of passes over the data.
    :param dataloader: iterable of (input_seq, target_seq) id tensors that supports len().
    :param criterion: loss over flattened logits (N, vocab) and targets (N,).
    :param optimizer: optimizer over ``model``'s parameters; defaults to the
        notebook-level ``optimizer`` for backward compatibility.
    """
    if optimizer is None:
        optimizer = globals()['optimizer']
    # Batches go to wherever the model's parameters live.
    device = next(model.parameters()).device
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for input_seq, target_seq in dataloader:
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)
            outputs = model(input_seq)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab); the vocabulary
            # size is read from the logits rather than a global.
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), target_seq.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Guard against an empty dataloader.
        epoch_loss = running_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch} loss: {epoch_loss:.3f}")

In [26]:
%%time

train(model, epochs, train_dataloader, criterion)

Epoch 0 loss: 3.498
Epoch 1 loss: 1.928
Epoch 2 loss: 1.517
Epoch 3 loss: 1.267
Epoch 4 loss: 1.093
Epoch 5 loss: 0.965
Epoch 6 loss: 0.864
Epoch 7 loss: 0.788
Epoch 8 loss: 0.727
Epoch 9 loss: 0.679
Epoch 10 loss: 0.640
Epoch 11 loss: 0.606
Epoch 12 loss: 0.578
Epoch 13 loss: 0.554
Epoch 14 loss: 0.532
Epoch 15 loss: 0.515
Epoch 16 loss: 0.498
Epoch 17 loss: 0.484
Epoch 18 loss: 0.471
Epoch 19 loss: 0.459
CPU times: user 41min 45s, sys: 11.7 s, total: 41min 56s
Wall time: 42min 17s


In [27]:
# Bundle an architecture template (a freshly initialised TextGen pickled as a
# whole object), the trained weights and the optimizer state.
# NOTE(review): this path uses "My Drive" while the other cells use "MyDrive".
checkpoint = dict(
    model=TextGen(vocab_size=vocab_size, embed_dim=100, num_layers=2, num_heads=2),
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)

torch.save(checkpoint, '/content/drive/My Drive/Диплом_2024/models/transformers_therapy_checkpoint_2.pth')

In [None]:
def load_checkpoint(filepath, map_location=None):
    """Rebuild a frozen, eval-mode model from a checkpoint saved by this notebook.

    The checkpoint is a dict with a ``'model'`` entry (a pickled module used as
    the architecture template) and a ``'state_dict'`` entry (its weights).

    :param filepath: path to the ``.pth`` file.
    :param map_location: forwarded to ``torch.load``; pass ``'cpu'`` to load a
        checkpoint saved from a GPU session on a CPU-only machine.
    :returns: the model with every parameter frozen (``requires_grad=False``),
        switched to eval mode.
    """
    # The checkpoint pickles a whole nn.Module, so full unpickling is required
    # (newer PyTorch defaults to weights_only=True, which rejects it). Only load
    # checkpoints from trusted sources.
    try:
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only argument.
        checkpoint = torch.load(filepath, map_location=map_location)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()

    return model

In [28]:
def return_int_vector(text):
    """Encode the last SEQUENCE_LENGTH words of ``text`` as a LongTensor of shape (1, k).

    Words missing from ``word_to_int`` are encoded as 0.
    """
    context = text.split()[-SEQUENCE_LENGTH:]
    ids = [word_to_int.get(token, 0) for token in context]
    return torch.LongTensor(ids).unsqueeze(0)

def sample_next(predictions):
    """
    Greedy decoding: return the id of the most probable next token.

    :param predictions: model logits; only the last position of the sequence
        axis (``predictions[:, -1, :]``) is used.
    :returns: the arg-max token id as a Python int.
    """
    last_step_logits = predictions[:, -1, :]
    probs = F.softmax(last_step_logits, dim=-1).cpu()
    return int(torch.argmax(probs))

def text_generator(sentence, generate_length):
    """Greedily extend ``sentence`` by ``generate_length`` words and print the result.

    Uses the notebook-level ``model``, ``device`` and ``int_to_word``. The
    context fed to the model is always the last SEQUENCE_LENGTH words
    (``return_int_vector`` truncates it), so it never exceeds the
    positional-encoding length and no explicit length check is needed.

    :param sentence: whitespace-separated prompt, preprocessed like the training text.
    :param generate_length: number of words to append.
    :raises ValueError: if ``sentence`` contains no words.
    """
    if not sentence.split():
        raise ValueError("sentence must contain at least one word")
    model.eval()
    sample = sentence
    for _ in range(generate_length):
        input_tensor = return_int_vector(sample).to(device)
        with torch.no_grad():
            predictions = model(input_tensor)
        next_token = sample_next(predictions)
        sample += ' ' + int_to_word[next_token]
    print(sample)
    print('\n')
    print('\n')

In [29]:
# Greedy continuation of a single prompt.
sentences = ["хрипы"]
generate_length = 5

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: хрипы
хрипы в нижних отделах обоих лёгких




In [30]:
# Greedy continuation of a single prompt.
sentences = ["жалобы"]
generate_length = 3

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: жалобы
жалобы на кашель с




In [31]:
# Greedy continuation of a single prompt.
sentences = ["на момент осмотра"]
generate_length = 6

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: на момент осмотра
на момент осмотра те вне приступа сопутствующий диагноз хронического




In [32]:
# Greedy continuation of a single prompt.
sentences = ["шейные лимфоузлы"]
generate_length = 4

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: шейные лимфоузлы
шейные лимфоузлы не пальпируются мышечная система




In [33]:
# Greedy continuation of a single prompt.
sentences = ["на кашель"]
generate_length = 10

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: на кашель
на кашель с трудно отделяемой мокротой anamnesisvitae родился в году рос и




In [34]:
# Greedy continuation of a single prompt.
sentences = ["аллергический"]
generate_length = 5

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: аллергический
аллергический ринит жалобы на момент поступления




In [35]:
# Greedy continuation of a single prompt.
sentences = ["бронхиальная"]
generate_length = 5

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: бронхиальная
бронхиальная астма неаллергическая форма легкой степени




In [36]:
# Greedy continuation of a single prompt.
sentences = ["эозинофилия"]
generate_length = 1

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: эозинофилия
эозинофилия крови




In [37]:
# Greedy continuation of a single prompt.
sentences = ["сердечный"]
generate_length = 5

for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: сердечный
сердечный толчок не определяется верхушечный толчок




In [38]:
# Evaluation on the test split: predicted vs. true next word at every position.
# NOTE(review): test windows overlap with stride 1, so most test tokens are
# scored in up to SEQUENCE_LENGTH different windows (each time with a different
# amount of left context); the metrics below are weighted accordingly.

model.eval()

preds = []
targets = []

with torch.no_grad():
    for input_seq, target_seq in test_dataloader:
        logits = model(input_seq.to(device))
        # Reduce to token ids on the model's device so only a (batch, seq)
        # tensor, not the full (batch, seq, vocab) logits, is copied to the CPU.
        batch_preds = logits.argmax(dim=-1).cpu()
        # Flatten predictions and targets in the same row-major order.
        preds.extend(batch_preds.reshape(-1).tolist())
        targets.extend(target_seq.reshape(-1).tolist())



In [39]:
# Quality metrics on the test split.
# zero_division=0 reports 0.0 (the same value as the default) for labels whose
# precision or recall is undefined, without emitting a warning for each one.
print(classification_report(targets, preds, zero_division=0))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           2       1.00      1.00      1.00         5
           4       1.00      1.00      1.00         7
           5       1.00      1.00      1.00         8
           6       1.00      0.33      0.50         9
           7       0.99      1.00      0.99        74
           8       0.99      1.00      0.99        75
          10       1.00      1.00      1.00        25
          12       1.00      1.00      1.00        13
          17       0.00      0.00      0.00         0
          21       1.00      1.00      1.00        16
          22       0.96      1.00      0.98        52
          23       0.97      0.98      0.97       384
          25       0.98      0.94      0.96       192
          26       0.87      0.98      0.92       291
          30       0.50      0.79      0.61        19
          31       1.00      0.80      0.89        20
          32       0.98      0.96      0.97       790
          33       0.84    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
