<a href="https://colab.research.google.com/github/SachinPrasanth777/PyTorch/blob/main/RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
with open('anna.txt','r') as f:
  text = f.read()
print(text[:100])

Chapter 1


Happy families are all alike; every unhappy family is unhappy in its own
way.

Everythin


In [5]:
chars = tuple(set(text))
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii,ch in int2char.items()}
encoded = np.array([char2int[ch] for ch in text])
print(encoded[:100])

[46 31 62 64 67 77 38 30 65 75 75 75 39 62 64 64  1 30 19 62 26 48 10 48
 77 70 30 62 38 77 30 62 10 10 30 62 10 48 28 77 79 30 77 33 77 38  1 30
 82 57 31 62 64 64  1 30 19 62 26 48 10  1 30 48 70 30 82 57 31 62 64 64
  1 30 48 57 30 48 67 70 30 17  5 57 75  5 62  1  2 75 75 81 33 77 38  1
 67 31 48 57]


In [7]:
def one_hot_encode(arr,n_labels):
  one_hot = np.zeros((arr.size,n_labels), dtype = np.float32)
  one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.
  one_hot = one_hot.reshape((*arr.shape,n_labels))
  return one_hot

In [9]:
test_seq = np.array([[3, 5, 1]])
one_hot = one_hot_encode(test_seq, 8)
print(one_hot)

[[[0. 0. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 1. 0. 0.]
  [0. 1. 0. 0. 0. 0. 0. 0.]]]


In [15]:
def get_batches(arr,batch_size,seq_length):
  batch_size_total = batch_size * seq_length
  n_batches = len(arr)//batch_size_total
  arr = arr[:n_batches * batch_size_total]
  arr = arr.reshape((batch_size, -1))
  x = np.zeros((batch_size, seq_length), dtype=arr.dtype)
  y = np.zeros_like(x)
  for n in range(0, arr.shape[1], seq_length):
    x = arr[:, n:n+seq_length]
    y = np.zeros_like(x)
    try:
      y[:, :-1], y[:, -1] = x[:, 1:], arr[:, n+seq_length]
    except IndexError:
      y[:, :-1], y[:, -1] = x[:, 1:], arr[:, 0]
    yield x,y

In [16]:
batches = get_batches(encoded,8,50)
x, y = next(batches)

In [17]:
print('x\n', x[:10, :10])
print('\ny\n', y[:10, :10])

x
 [[46 31 62 64 67 77 38 30 65 75]
 [70 17 57 30 67 31 62 67 30 62]
 [77 57 13 30 17 38 30 62 30 19]
 [70 30 67 31 77 30 51 31 48 77]
 [30 70 62  5 30 31 77 38 30 67]
 [51 82 70 70 48 17 57 30 62 57]
 [30 45 57 57 62 30 31 62 13 30]
 [68 32 10 17 57 70 28  1  2 30]]

y
 [[31 62 64 67 77 38 30 65 75 75]
 [17 57 30 67 31 62 67 30 62 67]
 [57 13 30 17 38 30 62 30 19 17]
 [30 67 31 77 30 51 31 48 77 19]
 [70 62  5 30 31 77 38 30 67 77]
 [82 70 70 48 17 57 30 62 57 13]
 [45 57 57 62 30 31 62 13 30 70]
 [32 10 17 57 70 28  1  2 30  9]]


In [21]:
train_on_gpu = torch.cuda.is_available()
if(train_on_gpu):
    print('Training on GPU!')
else:
    print('No GPU available, training on CPU; consider making n_epochs very small.')

No GPU available, training on CPU; consider making n_epochs very small.


In [35]:
class RNN(nn.Module):
  def init(self,tokens,n_hidden=256,n_layers=2,drop_prob=0.5,lr=0.001):
    super().__init__()
    self.drop_prob = drop_prob
    self.n_layers = n_layers
    self.n_hidden = n_hidden
    self.lr = lr
    self.chars = tokens
    self.int2chars = dict(enumerate(self.chars))
    self.lstm = nn.lstm(len(self.chars), n_hidden, n_layers, dropout=drop_prob, batch_first=True)
    self.dropout = nn.Dropout(drop_prob)
    self.fc = nn.Linear(n_hidden, len(self.chars))

  def forward(self, x, hidden):
        r_output, hidden = self.lstm(x, hidden)
        out = self.dropout(r_output)
        out = out.contiguous().view(-1, self.n_hidden)
        out = self.fc(out)
        return out, hidden

  def init_hidden(self,batch_size):
    weight = next(self.parameters()).data
    if(train_on_gpu):
      hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(),
                weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())
    else:
      hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),
                weight.new(self.n_layers,batch_size, self.n_hidden).zero_())
    return hidden