##Imports

In [81]:
import torch
import pandas as pd
from collections import Counter
import string
import re
import nltk
import nltk.tokenize
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
# Path of the CSV corpus; Dataset.load_words reads its 'Text' column.
PATH_TO_CSV_FILE = '/content/drive/MyDrive/youtoxic_english_1000.csv'

##Model class:


In [80]:
class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> vocabulary logits.

    Args:
        dataset: Object exposing ``uniq_words``; its length is used as the
            vocabulary size of the embedding table and of the output layer.
        lstm_size: Hidden size of every LSTM layer.
        embedding_dim: Size of the word vectors fed into the LSTM.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between LSTM layers.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128,
                 num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes the embeddings, so its input width must be the
            # embedding size (it used to be lstm_size, which only worked while
            # both values happened to be equal).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Return ``(logits, state)`` for a tensor of word indices.

        Args:
            x: Integer tensor of word indices.
            prev_state: ``(h, c)`` tuple as produced by ``init_state`` or by a
                previous call.

        NOTE(review): the LSTM is created without ``batch_first=True``, so it
        treats dim 0 of the embedded input as the time axis and dim 1 as the
        batch axis. ``prev_state`` therefore has to match ``x.shape[1]``;
        confirm this layout is intended before changing callers.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zeroed ``(h, c)`` pair of shape
        ``(num_layers, sequence_length, lstm_size)``.

        The tensors are allocated with the device and dtype of the model
        weights so the state stays usable after ``model.to(...)``.
        """
        shape = (self.num_layers, sequence_length, self.lstm_size)
        weight = self.embedding.weight
        return (weight.new_zeros(shape), weight.new_zeros(shape))

##Tokenizer: 
Нужно было доработать класс под себя, но времени не было, так что я скопировал его из рекомендованного репозитория и доделал.

In [63]:
class Tokenizer:
    """Split text into lower-cased word tokens, grouped by sentence.

    ASCII punctuation (``string.punctuation``) is removed from the text
    *before* sentence splitting, so the returned sentences contain no
    punctuation tokens.

    Args:
        - language: Text language passed to the NLTK tokenizers.
        - download: If true, download required NLTK data packages during init.
    Inputs:
        - text: Input text.
    Outputs:
        Sequence of sentences, each sentence is a sequence of lower-cased words.
    """

    # Built once: deletes every ASCII punctuation character.
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)
    # Recent NLTK releases load the Punkt models from "punkt_tab"; older ones
    # use "punkt". Requesting both keeps the tokenizer working on either.
    _NLTK_PACKAGES = ("punkt", "punkt_tab")

    def __init__(self, language="russian", download=True):
        if download:
            for package in self._NLTK_PACKAGES:
                nltk.download(package, quiet=True)
        self._language = language

    def __call__(self, text):
        stripped = text.translate(self._PUNCT_TABLE)
        return [
            [
                word.lower()
                for word in nltk.tokenize.word_tokenize(sentence, language=self._language)
            ]
            for sentence in nltk.tokenize.sent_tokenize(stripped, language=self._language)
        ]

##Dataset:

In [73]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from the 'Text' column of a CSV.

    Item ``i`` is the pair ``(words[i:i+n], words[i+1:i+n+1])`` of word-index
    tensors, where ``n`` is ``args['sequence_length']``.

    Args:
        args: Mapping that must contain ``'sequence_length'``.

    Attributes set in ``__init__``:
        words: Corpus as a flat list of tokens.
        uniq_words: Vocabulary ordered by descending frequency.
        index_to_word / word_to_index: Vocabulary lookup tables.
        words_indexes: ``words`` translated to vocabulary indices.
    """

    # Matches an http(s) URL up to the next whitespace character.
    _URL_PATTERN = re.compile(r'https?://\S+')
    # Deletes every ASCII punctuation character.
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(
        self,
        args,
    ):
        self.args = args
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    @classmethod
    def _clean_text(cls, text):
        """Remove URLs, then ASCII punctuation, from ``text``.

        URLs must go first: once ':' and '/' have been stripped as punctuation
        the URL pattern can no longer match.
        """
        without_urls = cls._URL_PATTERN.sub(' ', text)
        return without_urls.translate(cls._PUNCT_TABLE)

    def load_words(self):
        """Read the CSV and return the whole corpus as a flat list of tokens."""
        text = pd.read_csv(PATH_TO_CSV_FILE)['Text'].str.cat(sep=' ')
        tokenizer = Tokenizer('english')
        # Flatten every sentence instead of taking only the first one, so an
        # empty corpus yields [] and no sentence is silently dropped.
        return [word for sentence in tokenizer(self._clean_text(text)) for word in sentence]

    def get_uniq_words(self):
        """Return the distinct words sorted by descending frequency."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Clamp at zero: a corpus shorter than one window has no samples, and
        # len() raises ValueError on a negative result.
        return max(0, len(self.words_indexes) - self.args['sequence_length'])

    def __getitem__(self, index):
        window = self.args['sequence_length']
        return (
            torch.tensor(self.words_indexes[index:index + window]),
            # Target is the same window shifted one word to the right.
            torch.tensor(self.words_indexes[index + 1:index + window + 1]),
        )

##Train func:

In [82]:
def train(dataset, model, print_every_n_batch, args):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Args:
        dataset: Map-style dataset yielding ``(x, y)`` index tensors.
        model: Module exposing ``init_state(n)`` and ``forward(x, state)``
            returning ``(logits, (h, c))``.
        print_every_n_batch: Print the loss every this many batches; a falsy
            value (``0`` or ``None``) disables the progress output.
        args: Dict with ``'batch-size'``, ``'max-epochs'`` and
            ``'sequence_length'``; an optional ``'lr'`` overrides the default
            learning rate of 0.001.
    """
    model.train()

    dataloader = DataLoader(dataset, batch_size=args['batch-size'])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.get('lr', 0.001))

    for epoch in range(args['max-epochs']):
        print(f'----------epoch #{epoch+1}-----------')
        # The recurrent state is reset at the start of every epoch.
        state_h, state_c = model.init_state(args['sequence_length'])

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class dimension at index 1.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Carry the state into the next batch, but cut it from the graph
            # so backward() does not reach into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()
            # Guard against a zero interval, which would divide by zero.
            if print_every_n_batch and batch % print_every_n_batch == 0:
                print({ 'batch': batch, 'loss': loss.item() })




##Predict func:

In [40]:
def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by sampling ``next_words`` words from the model.

    Args:
        dataset: Object exposing ``word_to_index`` and ``index_to_word``.
        model: Module exposing ``init_state(n)`` and ``forward(x, state)``
            returning ``(logits, (h, c))``.
        text: Whitespace-separated prompt; every word must be in the vocabulary.
        next_words: Number of words to generate.

    Returns:
        The prompt words followed by the generated words, as a list.

    Raises:
        ValueError: If the prompt contains no words.
        KeyError: If a prompt word is not in ``dataset.word_to_index``.
    """
    model.eval()

    # split() without an argument ignores repeated/leading/trailing
    # whitespace, which would otherwise produce '' tokens and a KeyError.
    words = text.split()
    if not words:
        raise ValueError('predict() needs a prompt with at least one word')
    state_h, state_c = model.init_state(len(words))

    # Inference only: without no_grad the autograd graph would keep growing
    # through the state that is carried from step to step.
    with torch.no_grad():
        for i in range(next_words):
            # words grows by one per step, so words[i:] is always a window of
            # the same length as the original prompt.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
            # Sample from the predicted distribution instead of taking argmax.
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

##Training: 
Обучение модели лучше запустить заново: из-за отсутствия времени я не успел толком прогнать обучение.

In [84]:
# Hyper-parameters shared by the dataset (window length) and the training loop.
args = {
    'sequence_length': 4,
    'max-epochs': 1,
    'batch-size': 256,
}

# The model sizes its embedding and output layers from the dataset vocabulary,
# so the dataset has to be built first.
dataset = Dataset(args)
model = Model(dataset)

# Train, printing the loss every 50 batches, then sample a continuation of "hey".
train(dataset, model, print_every_n_batch=50, args=args)
print()
generated_words = predict(dataset, model, text='hey')
print(generated_words)

----------epoch #1-----------
{'batch': 0, 'loss': 8.490377426147461}
{'batch': 50, 'loss': 7.079318523406982}
{'batch': 100, 'loss': 6.960036754608154}

['hey', 'fist', 'frustrated', 'comments', 'nasty', 'lootedstole', 'red', 'if', 'backward', 'you', 'by', 'says', 'their', 'for', 'for', 'look', 'a', 'robberybut', 'fucked', 'tobacco', 'theyd', 'people', 'it', 'aggressive', 'extremely', 'some', 'handle', 'even', 'be', 'which', 'can', 'a', 'and', 'bring', 'by', 'i', 'the', 'stepping', 'are', 'asian', 'cant', 'focused', 'hahahaha', 'of', 'a', 'understand', 'use', 'right', 'dont', 'their', 'dont', 'let', 'police', 'is', 'stand', 'is', 'if', 'fuck', '314', 'black', 'rapping', 'the', 'was', 'to', 'children', 'questions', 'get', 'allow', 'head', 'witness', 'the', 'a', 'yourself', 'and', 'the', 'in', 'everyone', 'know', 'never', 'that', 'again', 'i', 'the', 'the', 'she', 'but', 'i', 'every', 'to', 'to', 'tshirts', 'blacks', 'addict', 'the', 'fascinating', 'entitled', 'work', 'almost', 'drug', 