In [1]:
from CharRNN import CharRNN
import torch, torch.optim as optim, torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from onehotencoder import onehotencoder
import numpy as np
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Basic one-hot encoder I made to encode and decode both characters and sequences.
endecode = onehotencoder()

# Hyperparameters
# Ask the encoder instance for its vocabulary size (a normal bound-method call,
# instead of going through the class and passing `self=` by keyword).
vocab_size = endecode.getVocabSize()
hidden_dim = 150        # second argument to CharRNN(...)
n_gram = 1              # not referenced in the cells visible here
learning_rate = 1e-6    # initial AdamW learning rate
num_epochs = 15
batch_size = 256        # DataLoader batch size
temp = 1                # softmax temperature handed to top_p_filtering
p = .9                  # top-p (nucleus) threshold handed to top_p_filtering
eps = .001              # not referenced in the cells visible here

In [4]:
#Torch dataset because the processed inputs and outputs were over 60 gb in size

class SequenceDataset(Dataset):
    """Map-style dataset serving one text line per item, encoded on demand.

    All raw lines of the file are read into memory at construction time;
    a line is only passed through the encoder when it is indexed.
    """

    def __init__(self, file_path, encoder):
        """Store the encoder and read every raw line of ``file_path``.

        Args:
            file_path: Path of the text file, one sequence per line.
            encoder: Object exposing ``encode_sequence(sequence, targets=...)``.
        """
        self.encoder = encoder
        self.file_path = file_path
        with open(file_path, 'r') as handle:
            # Lines keep their trailing newline here; it is stripped on access.
            self.lines = list(handle)

    def __len__(self):
        """Return the number of lines read from the file."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` encodings of the line at ``idx``.

        The line is stripped of surrounding whitespace, then encoded twice:
        once with the encoder's defaults and once with ``targets=True``.
        """
        text = self.lines[idx].strip()
        encode = self.encoder.encode_sequence
        return encode(text), encode(text, targets=True)

# Build the training dataset and a shuffled, multi-worker batch loader over it.
dataset = SequenceDataset(file_path='data/train.csv', encoder=endecode)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
)

In [5]:
# Build the model from the vocabulary size and hidden width, on the chosen device.
charRNN = CharRNN(vocab_size, hidden_dim).to(device)

# Reconstruction term: mean-squared error between the model output and the
# encoded targets (note: this is MSE, not cross entropy).
criterion = nn.MSELoss(reduction='mean')

# AdamW, plus a scheduler that lowers the learning rate when the epoch loss
# plateaus (patience=3).
optimizer = optim.AdamW(charRNN.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

# Training loop: one optimizer step per batch, one scheduler step per epoch.
for epoch in range(num_epochs):
    loss_avg = []
    for batch_inputs, batch_targets in dataloader:
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

        logits, mu, std = charRNN(batch_inputs)

        reconstruction_loss = criterion(logits, batch_targets)
        # KL divergence of N(mu, std^2) from N(0, 1), summed over every element
        # of the batch (the reconstruction term above is a mean).
        KL_loss = -0.5 * torch.sum(1 + torch.log(std.pow(2)) - mu.pow(2) - std.pow(2))
        loss = partial_loss = reconstruction_loss + KL_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_avg.append(loss.item())

    # Epoch loss is the mean of the per-batch losses; it also drives the scheduler.
    avg = torch.Tensor(loss_avg).mean().item()
    scheduler.step(avg)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg}")

KeyboardInterrupt: 

In [None]:
#Top-p (nucleus) sampling: turn the last step's output into a probability distribution, keep the smallest set of most-probable tokens whose cumulative probability reaches top_p, then sample from that set
def top_p_filtering(logits_p, top_p, temp_p):
    """Sample one token index with temperature-scaled top-p (nucleus) sampling.

    Args:
        logits_p: Model output. After ``squeeze(0)`` its last row is used as
            the per-token scores (assumed to be a 1-D vector at that point).
        top_p: Cumulative-probability budget for the nucleus.
        temp_p: Temperature the scores are divided by before the softmax.

    Returns:
        The sampled token index as a Python ``int``.
    """
    last_step_scores = logits_p.squeeze(0)[-1]
    probs = torch.softmax(last_step_scores / temp_p, dim=0)

    # Rank tokens from most to least probable and find where the running
    # total exceeds the budget.
    ranked_probs, ranked_idx = probs.sort(descending=True)
    over_budget = ranked_probs.cumsum(dim=0) > top_p

    # Shift the mask right by one: the token that crosses the threshold is
    # kept, and the most probable token is never dropped.
    drop_ranked = torch.cat([over_budget.new_zeros(1), over_budget[:-1]])

    # Map the mask from rank order back to vocabulary order.
    drop_mask = torch.zeros_like(drop_ranked).scatter(0, ranked_idx, drop_ranked)

    # Zero the dropped tokens, renormalise, and draw one index.
    nucleus = probs.masked_fill(drop_mask, 0)
    nucleus = nucleus / nucleus.sum()
    return torch.multinomial(nucleus, 1).item()

In [None]:
# Feed the start token [BOS], then top-p sample one character at a time until the
# model emits the stop token [EOS] or 200 characters have been generated,
# whichever comes first.
max_generation_len = 200  # hard cap so a model that never emits [EOS] cannot loop forever
currenToken = endecode.encode('[BOS]').to(device)
charRNN.to(device)
charRNN.eval()
generation = []
# NOTE(review): only the most recent token is fed at each step and no hidden state
# is passed in; unless CharRNN keeps state internally, each prediction is
# conditioned on a single character -- confirm this is intended.
with torch.no_grad():
    while len(generation) < max_generation_len:
        # The model is called with a 3-D input; add a leading dim when the token is 2-D.
        if currenToken.dim() == 2:
            currenToken = currenToken.unsqueeze(0)
        logits, _ , _ = charRNN(currenToken)
        next_token_index = top_p_filtering(logits, p, temp)
        # Rebuild the sampled index as a one-hot vector for decoding and for the next step.
        next_token = torch.zeros(vocab_size)
        next_token[next_token_index] = 1
        char = endecode.decode(next_token)
        if char == '[EOS]':
            break
        generation.append(char)
        currenToken = next_token.unsqueeze(0).unsqueeze(0).to(device)

print(''.join(generation))

In [None]:
# Persist the whole model object (not just its state_dict) to disk.
torch.save(charRNN,'Models/charRNNnoN-gram.pt')

In [None]:

# Reload the saved model object and draw 50,000 samples from it.
# NOTE: weights_only=False unpickles arbitrary Python objects -- only load
# checkpoint files you created yourself or otherwise trust.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
charRNN = torch.load('Models/charRNNnoN-gram.pt', map_location=device, weights_only=False)
charRNN.to(device)
charRNN.eval()

max_generation_len = 200  # hard cap so a sample that never reaches [EOS] cannot loop forever
bos_token = endecode.encode('[BOS]').to(device)
generations = []
with torch.no_grad():
    for i in range(int(5e4)):
        # Every sample must restart from [BOS]; otherwise it would continue from
        # the last character of the previous sample.
        currenToken = bos_token
        generation = []
        while len(generation) < max_generation_len:
            # The model is called with a 3-D input; add a leading dim when the token is 2-D.
            if currenToken.dim() == 2:
                currenToken = currenToken.unsqueeze(0)
            # The model returns (logits, mu, std); only the logits are needed for sampling.
            logits, _ , _ = charRNN(currenToken)
            next_token_index = top_p_filtering(logits, p, temp)
            # Rebuild the sampled index as a one-hot vector for decoding and for the next step.
            next_token = torch.zeros(vocab_size)
            next_token[next_token_index] = 1
            char = endecode.decode(next_token)
            if char == '[EOS]':
                break
            generation.append(char)
            currenToken = next_token.unsqueeze(0).unsqueeze(0).to(device)

        generations.append(''.join(generation))