In [3]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import random

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [5]:
text = open("data.txt", "r").read()

vocab = sorted(set(text))
vocab_size = len(vocab)

itos = {i: s for i, s in enumerate(vocab)}
stoi = {s: i for i, s in itos.items()}

In [6]:
class LSTM(nn.Module):
  def __init__(self, input_size, hidden_size, output_size):
    super().__init__()
    # Forget gate
    self.ft = nn.Linear(input_size + hidden_size, hidden_size)
    self.ftgate = nn.Sigmoid()
    # Update gate
    self.it = nn.Linear(input_size + hidden_size, hidden_size)
    self.itgate = nn.Sigmoid()
    self.ct = nn.Linear(input_size + hidden_size, hidden_size)
    self.ctact = nn.Tanh()
    # Output gate
    self.ot = nn.Linear(input_size + hidden_size, hidden_size)  # Hidden forget
    self.otgate = nn.Sigmoid()
    self.tanh_cell = nn.Tanh()
    self.out = nn.Linear(hidden_size, output_size)

  def init_states(self, device):
    hidden = torch.zeros((1, hidden_size), device=device)
    cell = torch.zeros((1, hidden_size), device=device)
    return hidden, cell

  def forward(self, input, hidden, cell):
    input_hidden = torch.cat((input, hidden), dim=1)
    ft = self.ftgate(self.ft(input_hidden))  # Forget Gate
    # Update Gate
    it = self.itgate(self.it(input_hidden))
    c_t = self.ctact(self.ct(input_hidden))
    gain = it * c_t
    # Update cell state
    cell = cell * ft + gain
    # Output Gate
    ot = self.otgate(self.ot(input_hidden))
    tanh_cell = self.tanh_cell(cell)
    hidden = ot * tanh_cell  # Update hidden state
    logits = self.out(hidden)

    return logits, hidden, cell

In [16]:
context_window = 32
input_size = output_size = vocab_size
hidden_size = 128

lstm = LSTM(input_size, hidden_size, output_size).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm.parameters(), lr=0.001)

In [22]:
def generate(start="", len_=20):
  chars = [stoi[start]]

  hidden, cell = lstm.init_states(device)
  for i in range(len_):
    input = chars[-1]
    input = F.one_hot(torch.tensor(input).type(torch.LongTensor), num_classes=vocab_size).to(device)
    inp = input.unsqueeze(0)
    logits, hidden, cell = lstm(inp, hidden, cell)
    probs = torch.softmax(logits, dim=1)
    ix = torch.multinomial(probs, 1).item()
    chars.append(ix)
  return "".join([itos[ch] for ch in chars])

In [18]:
from tqdm import tqdm

In [19]:
N_EPOCHS = 50
CHUNK_SIZE = 2000

for epoch in range(N_EPOCHS):
  randix = torch.randint(0, len(text)-context_window-1, (1,)).item()
  for i in tqdm(range(randix, randix+CHUNK_SIZE)):
    sample = text[i:i+context_window]
    sample = [stoi[ch] for ch in sample]
    ohe_sample = F.one_hot(torch.tensor(sample).type(torch.LongTensor), num_classes=vocab_size).to(device)
    target = text[i+1:i+1+context_window]
    enc_target = torch.tensor([stoi[ch] for ch in target]).to(device)

    hidden, cell = lstm.init_states(device)
    total_loss = 0
    for i in range(len(ohe_sample)):
      inp = ohe_sample[i].unsqueeze(0)
      tg = enc_target[i].unsqueeze(0)
      logits, hidden, cell = lstm(inp, hidden, cell)
      loss = loss_fn(logits, tg)
      total_loss = total_loss + loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

  if (epoch+1) % 10 == 0:
    print(total_loss.item())
    print(generate(random.choice(vocab)))

100%|██████████| 2000/2000 [01:22<00:00, 24.33it/s]
100%|██████████| 2000/2000 [01:23<00:00, 24.03it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.46it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.41it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.38it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.28it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.53it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.48it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.22it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.40it/s]


59.70364761352539
пачто?
Я бав Онвоко ц


100%|██████████| 2000/2000 [01:21<00:00, 24.52it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.52it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.28it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.38it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.47it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.21it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.52it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.51it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.42it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.26it/s]


32.81169509887695
хось что Джонул» о.. 


100%|██████████| 2000/2000 [01:22<00:00, 24.39it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.46it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.46it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.50it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.44it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.30it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.45it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.64it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.44it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.30it/s]


30.660053253173828
? !
Я остонак, в это 


100%|██████████| 2000/2000 [01:21<00:00, 24.45it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.46it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.51it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.55it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.53it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.26it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.43it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.65it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.20it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.38it/s]


59.946922302246094
А не хорожный самойст


100%|██████████| 2000/2000 [01:21<00:00, 24.48it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.52it/s]
100%|██████████| 2000/2000 [01:21<00:00, 24.46it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.28it/s]
100%|██████████| 2000/2000 [01:23<00:00, 24.05it/s]
100%|██████████| 2000/2000 [01:24<00:00, 23.79it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.13it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.16it/s]
100%|██████████| 2000/2000 [01:23<00:00, 23.95it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.26it/s]

41.71445083618164

Наша об передчнуюй. 





In [32]:
torch.save(lstm.state_dict(), "/content/lstm_params.pth")

In [31]:
print(generate("Д", 500))

Дэн, я исте отдазшо увнта,
ющи вот чеиявее оставаелачи, ссем облаЗай. амен в званай, офис. Зартеле попаяске.
Ты этовым, перех ршек понянть ты шенивну это... мяленке впривёл денееЧко деньги, тожако, в боте. уечка привёз мне всжавал очемьцу дива тубрыхатьлись ты продишь конве дне дошён верина работы. В можницая буд ес попрум втлно, помачим пам адисленис.
Ты трос.
Начивать моде нужедживаный довадцапкаться ты поисте. ь сотня есть. Я конверт! Взкоридивас не я на ещё я, дом. у привёз молй довайте кходу


### Damn... These are hallucinations of course, but yo... styling, words, ideas and main topics (money, crime, job)
<img src="https://media1.tenor.com/m/FFUaxpGNKKkAAAAd/%D0%BE%D0%B7%D0%BE%D0%BD-ozon.gif" width=40%>