# Character-level RNN text generation

## Import libraries

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import re

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms.v2 as v2

from torch.utils.data import BatchSampler, SequentialSampler
from torchvision import models

from tqdm import tqdm

## Model: character-level RNN (`TextRNN`)


In [2]:
class TextRNN(nn.Module):
  """Single-layer ``nn.RNN`` followed by a linear read-out of its final hidden state.

  Args:
    in_features: Size of each input vector (in this notebook, the one-hot
      alphabet size).
    out_features: Number of output scores (in this notebook, also the
      alphabet size).
    hidden_size: Width of the RNN hidden state. Defaults to 64, the value
      that was previously hard-coded.
  """

  def __init__(self, in_features, out_features, hidden_size=64):
    super().__init__()
    self.hidden_size = hidden_size
    self.in_features = in_features
    self.out_features = out_features

    # batch_first=True: batched input is laid out as (batch, seq_len, in_features).
    self.rnn = nn.RNN(self.in_features, self.hidden_size, batch_first=True)
    self.out = nn.Linear(self.hidden_size, self.out_features)

  def forward(self, x):
    """Return output scores computed from the RNN's final hidden state.

    For batched input ``x`` of shape (batch, seq_len, in_features) the result
    has shape (1, batch, out_features). The leading axis is the layer axis of
    the final hidden state, so callers are expected to ``squeeze(0)`` it.
    """
    _, h = self.rnn(x)
    return self.out(h)

## Dataset: (context window, next character) pairs (`charsDataset`)


In [3]:
class charsDataset(data.Dataset):
  def __init__(self, path, prev_chars=3):
    self.prev_chars = prev_chars

    with open(path, 'r', encoding='utf-8') as f:
      self.text = f.read()
      self.text = self.text.replace('\ufeff', '')
      self.text = re.sub(f'[^А-яA-z0-9.,?;: ]', '', self.text)

    self.text = self.text.lower()
    self.alphabet = set(self.text)
    self.int_to_alpha = dict(enumerate(sorted(self.alphabet)))
    self.alpha_to_int = {v: k for k, v in self.int_to_alpha.items()}
    self.num_characters = len(self.alphabet)
    self.onehots = torch.eye(self.num_characters)

  def __getitem__(self, item):
    _data = torch.vstack([self.onehots[self.alpha_to_int[self.text[x]]] for x in range(item, item + self.prev_chars)])
    ch = self.text[item + self.prev_chars]
    t = self.alpha_to_int[ch]
    return _data, t

  def __len__(self):
    return len(self.text) - 1 - self.prev_chars

## Train

In [4]:
batch = 8  # mini-batch size used by the DataLoader below
# Each sample pairs 10 context characters with the character that follows.
# NOTE(review): '/content/...' looks like a Google Colab path; adjust when running elsewhere.
d_train = charsDataset('/content/train_data_true', prev_chars=10)
train_loader = data.DataLoader(dataset=d_train, shuffle=True, batch_size=batch)

In [5]:
# One-hot characters in, one score per alphabet character out.
model = TextRNN(in_features=d_train.num_characters,
                out_features=d_train.num_characters)

In [6]:
loss_f = nn.CrossEntropyLoss()  # targets are class indices of the next character
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [7]:
epoch = 100
model.train()

for _e in range(epoch):
  # Running mean of the batch losses within the current epoch.
  loss_mean = 0
  lm_count = 0

  train_tqdm = tqdm(train_loader, leave=False)
  for x_train, y_train in train_tqdm:
    # TextRNN applies its head to the final hidden state, giving
    # (1, batch, classes). Drop the leading layer axis for CrossEntropyLoss.
    pred = model(x_train).squeeze(0)
    loss = loss_f(pred, y_train.long())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    lm_count += 1
    # Incremental mean: equivalent to 1/n * x + (1 - 1/n) * mean.
    loss_mean += (loss.item() - loss_mean) / lm_count
    train_tqdm.set_description(f'[epoch {_e+1}/{epoch}], loss_mean: {loss_mean:.3f}')



In [8]:
# Persist only the learned weights. The dataset's alphabet/index mapping is
# not saved, so the same corpus is needed to rebuild it when reloading.
st = model.state_dict()
torch.save(st, 'model_rnn_1.tar')

## Test: generate text from a seed phrase

In [9]:
model.eval()
predict = 'Мой дядя самый'.lower()
total = 40

# Generation only: no_grad avoids building an autograd graph for every step.
with torch.no_grad():
  for _ in range(total):
    # One-hot rows for the last prev_chars characters (fewer if the seed is
    # shorter). Every character must be in the training alphabet, otherwise
    # the alpha_to_int lookup raises KeyError.
    context = predict[-d_train.prev_chars:]
    _data = torch.vstack([d_train.onehots[d_train.alpha_to_int[ch]] for ch in context])
    p = model(_data.unsqueeze(0)).squeeze(0)
    # Greedy decoding: append the highest-scoring next character.
    indx = torch.argmax(p, dim=1)
    predict += d_train.int_to_alpha[indx.item()]

print(predict)

мой дядя самыйся это ны что дыто нечем прдоль у осчаст
