In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
import onnxruntime as ort
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import json

In [11]:
class ChatDataset(Dataset):
    def __init__(self, input_sequences, target_sequences, response_types):
        self.input_sequences = input_sequences
        self.target_sequences = target_sequences
        self.response_types = response_types

    def __len__(self):
        return len(self.input_sequences)

    def __getitem__(self, idx):
        return (self.input_sequences[idx], self.target_sequences[idx], self.response_types[idx])

In [12]:
def tokenize(texts, tokenizer, max_length):
    return [torch.tensor(tokenizer(text), dtype=torch.long) for text in texts]

def pad_collate_fn(batch):
    input_seqs, target_seqs, response_types = zip(*batch)
    input_seqs_padded = pad_sequence(input_seqs, batch_first=True, padding_value=0)
    target_seqs_padded = pad_sequence(target_seqs, batch_first=True, padding_value=0)
    response_types = torch.tensor(response_types, dtype=torch.long)
    return input_seqs_padded, target_seqs_padded, response_types

In [13]:
# Загрузка и обработка данных
data = pd.read_csv('chaos_god_dataset.csv')
data['Bot_Response'] = data['Bot_Response'].apply(lambda x: '<start> ' + x + ' <end>')

all_texts = data['User_Message'].tolist() + data['Bot_Response'].tolist()
tokenizer = Counter(' '.join(all_texts).split())
vocab = {word: idx + 1 for idx, (word, _) in enumerate(tokenizer.items())}
vocab_size = len(vocab) + 1

def encode(text, vocab):
    return [vocab[word] for word in text.split() if word in vocab]

input_sequences = tokenize(data['User_Message'].tolist(), lambda x: encode(x, vocab), max_length=None)
target_sequences = tokenize(data['Bot_Response'].tolist(), lambda x: encode(x, vocab), max_length=None)

label_encoder = LabelEncoder()
response_labels_encoded = label_encoder.fit_transform(data['Response_Type'])

X_train, X_test, y_train, y_test, train_resp, test_resp = train_test_split(input_sequences, target_sequences, response_labels_encoded, test_size=0.2, random_state=42)

train_dataset = ChatDataset(X_train, y_train, train_resp)
test_dataset = ChatDataset(X_test, y_test, test_resp)

train_loader = DataLoader(train_dataset, batch_size=64, collate_fn=pad_collate_fn, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, collate_fn=pad_collate_fn, shuffle=False)

max_seq_length = max(max(len(seq) for seq in input_sequences), max(len(seq) for seq in target_sequences))

In [14]:
class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Seq2Seq, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.decoder_lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, src, trg, hidden, cell, use_teacher_forcing):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.fc.out_features

        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(src.device)
        embedded_src = self.embedding(src)
        encoder_outputs, (hidden, cell) = self.encoder_lstm(embedded_src, (hidden, cell))

        input = trg[:, 0]
        for t in range(1, trg_len):
            input = input.unsqueeze(1)
            embedded_input = self.embedding(input)
            decoder_output, (hidden, cell) = self.decoder_lstm(embedded_input, (hidden, cell))
            output = self.fc(decoder_output.squeeze(1))
            outputs[:, t, :] = output

            if use_teacher_forcing:
                input = trg[:, t]
            else:
                input = output.argmax(1)

        return outputs, hidden, cell

class ResponseClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes):
        super(ResponseClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        embedded = self.embedding(x).mean(dim=1)
        x = torch.relu(self.fc1(embedded))
        x = self.fc2(x)
        return x

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

embedding_dim = 256
hidden_dim = 512
num_classes = len(label_encoder.classes_)

seq2seq_model = Seq2Seq(vocab_size, embedding_dim, hidden_dim).to(device)
classifier_model = ResponseClassifier(vocab_size, embedding_dim, num_classes).to(device)

criterion_seq2seq = nn.CrossEntropyLoss(ignore_index=0)
criterion_classifier = nn.CrossEntropyLoss()

optimizer_seq2seq = optim.Adam(seq2seq_model.parameters(), lr=0.001)
optimizer_classifier = optim.Adam(classifier_model.parameters(), lr=0.001)

def train_model(seq2seq_model, classifier_model, train_loader, criterion_seq2seq, criterion_classifier, optimizer_seq2seq, optimizer_classifier, epochs=10, teacher_forcing_ratio=0.5):
    seq2seq_model.train()
    classifier_model.train()

    for epoch in range(epochs):
        for input_seqs, target_seqs, response_types in train_loader:
            input_seqs, target_seqs, response_types = input_seqs.to(device), target_seqs.to(device), response_types.to(device)

            # Инициализация скрытых состояний и состояний ячеек нулями
            hidden = torch.zeros(1, input_seqs.size(0), hidden_dim).to(device)
            cell = torch.zeros(1, input_seqs.size(0), hidden_dim).to(device)
            
            # Определение использования принуждения учителя
            use_teacher_forcing = torch.tensor([1 if torch.rand(1).item() < teacher_forcing_ratio else 0]).to(device)
            
            # Seq2Seq модель
            optimizer_seq2seq.zero_grad()
            output_seq2seq, hidden, cell = seq2seq_model(input_seqs, target_seqs, hidden, cell, use_teacher_forcing)
            loss_seq2seq = criterion_seq2seq(output_seq2seq.view(-1, vocab_size), target_seqs.view(-1))
            loss_seq2seq.backward()
            optimizer_seq2seq.step()

            # Классификационная модель
            optimizer_classifier.zero_grad()
            output_classifier = classifier_model(input_seqs)
            loss_classifier = criterion_classifier(output_classifier, response_types)
            loss_classifier.backward()
            optimizer_classifier.step()

        print(f'Epoch {epoch+1}/{epochs}, Loss Seq2Seq: {loss_seq2seq.item()}, Loss Classifier: {loss_classifier.item()}')

train_model(seq2seq_model, classifier_model, train_loader, criterion_seq2seq, criterion_classifier, optimizer_seq2seq, optimizer_classifier, epochs=80)

Epoch 1/80, Loss Seq2Seq: 6.77866268157959, Loss Classifier: 1.793488621711731
Epoch 2/80, Loss Seq2Seq: 6.278234958648682, Loss Classifier: 1.4843406677246094
Epoch 3/80, Loss Seq2Seq: 5.521275520324707, Loss Classifier: 1.4282360076904297
Epoch 4/80, Loss Seq2Seq: 5.236615180969238, Loss Classifier: 1.384177803993225
Epoch 5/80, Loss Seq2Seq: 5.177277088165283, Loss Classifier: 1.4297065734863281
Epoch 6/80, Loss Seq2Seq: 5.410035610198975, Loss Classifier: 1.0753049850463867
Epoch 7/80, Loss Seq2Seq: 5.597518444061279, Loss Classifier: 1.018714189529419
Epoch 8/80, Loss Seq2Seq: 4.2173004150390625, Loss Classifier: 1.3303660154342651
Epoch 9/80, Loss Seq2Seq: 4.36333703994751, Loss Classifier: 0.8859329223632812
Epoch 10/80, Loss Seq2Seq: 4.5608320236206055, Loss Classifier: 1.2981430292129517
Epoch 11/80, Loss Seq2Seq: 4.190622806549072, Loss Classifier: 0.9888630509376526
Epoch 12/80, Loss Seq2Seq: 4.9105987548828125, Loss Classifier: 1.038123607635498
Epoch 13/80, Loss Seq2Seq: 4

In [16]:
# Экспорт моделей
src_dummy_input = torch.randint(0, vocab_size, (1, max_seq_length)).long().to(device)
trg_dummy_input = torch.randint(0, vocab_size, (1, max_seq_length)).long().to(device)
hidden_dummy_input = torch.zeros(1, 1, hidden_dim).to(device)
cell_dummy_input = torch.zeros(1, 1, hidden_dim).to(device)
teacher_forcing_dummy_input = torch.tensor([1]).to(device)  # 1 для использования teacher forcing, 0 для его отключения

# Экспорт модели Seq2Seq
torch.onnx.export(
    seq2seq_model, 
    (src_dummy_input, trg_dummy_input, hidden_dummy_input, cell_dummy_input, teacher_forcing_dummy_input),
    "seq2seq_model.onnx",
    input_names=["src", "trg", "hidden", "cell", "use_teacher_forcing"],
    output_names=["output", "hidden_out", "cell_out"],
    dynamic_axes={"src": {0: "batch_size", 1: "sequence"}, "trg": {0: "batch_size", 1: "sequence"}, "output": {0: "batch_size", 1: "sequence"}}
)

# Экспорт модели классификатора
classifier_dummy_input = torch.randint(0, vocab_size, (1, max_seq_length)).long().to(device)
torch.onnx.export(
    classifier_model, 
    classifier_dummy_input, 
    "classifier_model.onnx", 
    input_names=["input"], 
    output_names=["output"], 
    dynamic_axes={"input": {0: "batch_size", 1: "sequence"}, "output": {0: "batch_size"}}
)
# Сохранение словаря vocab в JSON файл
with open("vocab.json", "w") as f:
    json.dump(vocab, f)

# Сохранение энкодера label_encoder в JSON файл
with open("label_encoder.json", "w") as f:
    json.dump(label_encoder.classes_.tolist(), f)

  if use_teacher_forcing:


verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR



In [17]:
def decode_sequence(seq2seq_model, classifier_model, input_seq, vocab, max_length):
    input_seq = torch.tensor(encode(input_seq, vocab)).unsqueeze(0).long().to(device)
    if input_seq.size(1) == 0:  # Если вопроса нет в словаре
        return "Не могу понять вопрос.", "Неизвестно"
    
    seq2seq_model.eval()
    classifier_model.eval()

    with torch.no_grad():
        embedded = seq2seq_model.embedding(input_seq)
        encoder_outputs, (hidden, cell) = seq2seq_model.encoder_lstm(embedded)

        target_seq = torch.tensor([vocab.get('<start>', 0)]).unsqueeze(0).long().to(device)
        decoded_sentence = ''
        for _ in range(max_length):
            embedded = seq2seq_model.embedding(target_seq)
            decoder_output, (hidden, cell) = seq2seq_model.decoder_lstm(embedded, (hidden, cell))
            output = seq2seq_model.fc(decoder_output.squeeze(1))
            top1 = output.argmax(1).item()
            if top1 == vocab.get('<end>', 0):
                break
            decoded_sentence += ' ' + [key for key, value in vocab.items() if value == top1][0]
            target_seq = torch.tensor([top1]).unsqueeze(0).long().to(device)

        response_type = classifier_model(input_seq).argmax(1).item()
        response_label = label_encoder.inverse_transform([response_type])[0]

    return decoded_sentence.strip(), response_label

In [18]:
while True:
    user_input = input("Вы: ")
    if user_input.lower() == '':
        print("До свидания!")
        break
    else:
        decoded_sentence, response_type = decode_sequence(seq2seq_model, classifier_model, user_input, vocab, max_length=max_seq_length)
        print("Ответ бота:", decoded_sentence)
        print("Характер ответа:", response_type)

Вы:  Привет


Ответ бота: Ну что ты хочешь обсудить?
Характер ответа: грубость


Вы:  Как тебя зовут?


Ответ бота: О, это не для
Характер ответа: оскорбительный


Вы:  Кто ты такой?


Ответ бота: Мне без разницы.
Характер ответа: оскорбительный


Вы:  Ты веришь в существование души?


Ответ бота: Душа? Она существует только в их собственных воображениях.
Характер ответа: ироничный


Вы:  


До свидания!


In [None]:
import onnxruntime as ort

# Загрузка модели
ort_session = ort.InferenceSession("seq2seq_model.onnx")

# Подготовка входных данных
src_dummy_input = torch.randint(0, vocab_size, (1, max_seq_length)).long().numpy()
trg_dummy_input = torch.randint(0, vocab_size, (1, max_seq_length)).long().numpy()
hidden_dummy_input = torch.zeros(1, 1, hidden_dim).numpy()
cell_dummy_input = torch.zeros(1, 1, hidden_dim).numpy()

# Выполнение инференса
outputs = ort_session.run(
    None,
    {
        "src": src_dummy_input,
        "trg": trg_dummy_input,
        "hidden": hidden_dummy_input,
        "cell": cell_dummy_input
    },
)

print(outputs)

In [None]:
# Загрузка модели
ort_session = ort.InferenceSession("classifier_model.onnx")

# Подготовка входных данных
classifier_dummy_input = torch.randint(0, vocab_size, (1, max_seq_length)).long().numpy()

# Выполнение инференса
outputs = ort_session.run(
    None,
    {
        "input": classifier_dummy_input
    },
)

print(outputs)