In [44]:
import torch 
import math
import torch.nn as nn
import torch.nn.functional as F

# Put tensors on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [45]:
# Load the corpus, lower-cased so the character vocabulary ignores case.
# `with` guarantees the file handle is closed; an explicit encoding keeps the
# decoding independent of the platform's locale default.
# NOTE(review): assumes lotr1.txt is UTF-8 encoded — confirm against the file.
with open('lotr1.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read().lower()
# Map non-breaking spaces (U+00A0) to ordinary spaces so they do not become a
# separate vocabulary entry.
text = text.replace('\xa0', ' ')
len(text)

938287

In [46]:
# Build the character vocabulary. Sorting makes the char -> index mapping
# deterministic: iterating a raw `set` of str follows hash order, which is
# randomized per interpreter session, so indices (and anything trained or saved
# against them) would change from one kernel to the next.
vocab = sorted(set(text))
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
# Width of each one-hot character vector.
one_hot_size = len(char_to_idx)

In [70]:
class OneHotDataset:
    """Sliding-window, one-hot encoded character dataset for next-char prediction.

    Every ``chr_step`` characters, a window of ``seq_len`` characters is taken
    from ``text`` as the input. The character immediately after that window is
    the target.

    Attributes:
        x: bool tensor of shape (N, seq_len, vocab_size); ``x[i, t]`` is the
           one-hot encoding of character ``t`` of window ``i``.
        y: int64 (``torch.long``) tensor of shape (N,) holding the vocabulary
           index of the character that follows each window (class-index form,
           as expected by ``nn.CrossEntropyLoss``).
    """

    def __init__(self, seq_len, chr_step, text, char_to_idx, device=None):
        """Encode every window of ``text`` up front.

        Args:
            seq_len: number of characters per input window.
            chr_step: stride between the starts of consecutive windows.
            text: source string; every character it contains must be a key of
                ``char_to_idx``.
            char_to_idx: mapping from character to vocabulary index; its size
                sets the one-hot width.
            device: where ``x`` and ``y`` are stored. Defaults to CUDA when
                available, otherwise CPU.

        Raises:
            KeyError: if ``text`` contains a character missing from ``char_to_idx``.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        vocab_size = len(char_to_idx)

        # Encode the text once, then gather windows by index. This replaces
        # filling the tensor one element at a time from Python, which is very
        # slow (especially when each write targets a GPU tensor).
        encoded = torch.tensor([char_to_idx[char] for char in text], dtype=torch.long)
        # Same window starts as range(0, len(text) - seq_len, chr_step); empty
        # when the text is not longer than one window.
        starts = torch.tensor(list(range(0, len(text) - seq_len, chr_step)), dtype=torch.long)
        window_idx = starts[:, None] + torch.arange(seq_len)[None, :]  # (N, seq_len)
        x_idx = encoded[window_idx].to(device)  # vocab index per position, (N, seq_len)

        n_windows = starts.shape[0]
        self.x = torch.zeros(n_windows, seq_len, vocab_size, dtype=torch.bool, device=device)  # shape (N, L, D)
        rows = torch.arange(n_windows, device=device)[:, None]
        cols = torch.arange(seq_len, device=device)[None, :]
        self.x[rows, cols, x_idx] = True

        # Targets must stay integer class indices: a bool tensor would collapse
        # every index > 0 to True (1).
        self.y = encoded[starts + seq_len].to(device)

    def __getitem__(self, idx):
        """Return ``(x[idx], y[idx])``: a (seq_len, vocab_size) one-hot window and its target index."""
        return self.x[idx, :, :], self.y[idx]

    def __len__(self):
        """Number of windows in the dataset."""
        return self.y.shape[0]

# Window hyper-parameters: 50-character inputs, a new window every 4 characters.
SEQ_LEN = 50
CHR_STEP = 4

data = OneHotDataset(
    seq_len=SEQ_LEN,
    chr_step=CHR_STEP,
    text=text,
    char_to_idx=char_to_idx,
)

In [71]:
# Inspect the input half of one sample; its shape is (seq_len, vocab_size).
# Index the (x, y) tuple instead of unpacking into `_`: assigning `_` in IPython
# overwrites the cached "last output" variable, which then stops updating.
x = data[1][0]
x.shape

torch.Size([50, 60])

In [75]:
loader = torch.utils.data.DataLoader(data, batch_size=32)

# Look at the first batch only. Transposing the batch from (N, L, D) to
# (L, N, D) shows the sequence-first layout.
first_batch = next(iter(loader), None)
if first_batch is not None:
    x, y = first_batch
    print(x.transpose(0, 1).shape)

torch.Size([50, 32, 60])
