In [None]:
from CharRNN import CharRNN
import torch, torch.optim as optim, torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from transformers import PreTrainedTokenizerFast
from onehotencoder import onehotencoder
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Model architecture: these five values are passed positionally to CharRNN
# in exactly this order.
vocab_size = 50  # also the length of each one-hot token vector built when sampling
embedded_dim = 768
hidden_dim = 768
num_layers = 3
dropout = 0.2

# Optimisation settings.
learning_rate = 1e-3
num_epochs = 10
batch_size = 64

In [None]:
# Single encoder/decoder instance: handed to the dataset for encoding lines and
# reused by the sampling cell to encode [BOS]/[EOS] and decode sampled tokens.
endecode = onehotencoder()

class SequenceDataset(Dataset):
    """Map-style dataset yielding one encoded sequence per non-blank line of a text file.

    Args:
        file_path: Path of the text file to read; each line holds one sequence.
        encoder: Object exposing ``encode_sequence(sequence)`` and
            ``encode_sequence(sequence, targets=True)``.

    Each item is the pair ``(input_tensor, target_tensor)`` returned by those
    two encoder calls for the whitespace-stripped line.
    """

    def __init__(self, file_path, encoder):
        self.file_path = file_path
        self.encoder = encoder
        # An explicit encoding keeps reads identical across platforms. Blank or
        # whitespace-only lines (e.g. a trailing newline at the end of the file)
        # are dropped so that no sample is an empty sequence.
        with open(file_path, 'r', encoding='utf-8') as f:
            self.lines = [line for line in f if line.strip()]

    def __len__(self):
        """Return the number of non-blank lines loaded from the file."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Encode line ``idx`` and return ``(input_tensor, target_tensor)``."""
        sequence = self.lines[idx].strip()
        input_tensor = self.encoder.encode_sequence(sequence)
        target_tensor = self.encoder.encode_sequence(sequence, targets=True)
        return input_tensor, target_tensor

# Build the training set and batch it with the `batch_size` hyperparameter from
# the configuration cell, so that changing the setting there actually takes effect.
dataset = SequenceDataset('data/train.csv', endecode)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=12)

In [None]:
charRNN = CharRNN(
    vocab_size,
    embedded_dim,
    hidden_dim,
    num_layers,
    dropout,
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(charRNN.parameters(), lr=learning_rate)

# Be explicit about training mode: the sampling cell switches the model to eval().
charRNN.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0

    for batch_inputs, batch_targets in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Clear gradients left over from the previous step before the new forward pass.
        optimizer.zero_grad()

        logits = charRNN(batch_inputs)

        # CrossEntropyLoss takes class indices, so the targets are reduced with
        # argmax over dim 2 (presumably the one-hot axis; confirm against the
        # encoder) and both tensors are flattened to one row per position.
        targets_flat = torch.argmax(batch_targets, dim=2).reshape(-1)
        logits_flat = logits.reshape(batch_inputs.size(0) * batch_inputs.size(1), -1)

        loss = criterion(logits_flat, targets_flat)
        # Back-propagate so that optimizer.step() has gradients to apply;
        # without this call the parameters are never updated.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch; NaN signals that the loader yielded no batches.
    epoch_loss = running_loss / num_batches if num_batches else float('nan')
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

In [None]:
def _top_p_filter(probs, top_p):
    """Zero out everything outside the top-p nucleus of ``probs`` and renormalise.

    Args:
        probs: 1-D tensor of probabilities.
        top_p: Cumulative-probability threshold.

    Returns:
        A tensor shaped like ``probs`` that sums to 1, keeping the most
        probable entries up to and including the one whose cumulative
        probability first exceeds ``top_p``.
    """
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)
    remove_sorted = cumulative_probs > top_p
    # Shift the mask right by one so the entry that crosses the threshold is
    # kept; the most probable entry is never removed.
    remove_sorted[1:] = remove_sorted[:-1].clone()
    remove_sorted[0] = False
    # Map the mask from sorted order back to the original index order.
    remove = torch.zeros_like(remove_sorted).scatter(0, sorted_indices, remove_sorted)
    kept = probs.masked_fill(remove, 0)
    return kept / kept.sum()


top_p = 0.9
# Upper bound on the number of sampled tokens. NOTE(review): the value is tied
# to vocab_size, as before; confirm that this is the intended length cap.
max_new_tokens = vocab_size
# Loop-invariant: index of the end-of-sequence token.
eos_index = endecode.encode('[EOS]').argmax().item()

currenToken = endecode.encode('[BOS]').to(device)
# Add leading axes until the tensor is 3-D, matching the rank of the
# (1, 1, vocab_size) one-hot steps appended below.
while currenToken.dim() < 3:
    currenToken = currenToken.unsqueeze(0)

# charRNN(...) returns only logits, so nothing is carried over between calls.
# The whole prefix is therefore fed at every step so that each prediction is
# conditioned on all previous tokens rather than only on the last one.
# NOTE(review): assumes CharRNN keeps no hidden state internally; confirm.
context = currenToken

charRNN.eval()
generation = []
with torch.no_grad():
    for _ in range(max_new_tokens):
        logits = charRNN(context)
        # Distribution over the next token, taken from the last position.
        probs = nn.functional.softmax(logits.squeeze(0)[-1], dim=0)
        next_token_index = torch.multinomial(_top_p_filter(probs, top_p), 1).item()
        if next_token_index == eos_index:
            break

        next_token = torch.zeros(vocab_size)
        next_token[next_token_index] = 1
        generation.append(endecode.decode(next_token))

        # Append the sampled token to the running context, matching its device and dtype.
        step = next_token.to(device=context.device, dtype=context.dtype).view(1, 1, -1)
        context = torch.cat([context, step], dim=1)

print(''.join(generation))