In [1]:
from CharRNN import CharRNN
import torch, torch.optim as optim, torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from onehotencoder import OneHotEncoder
import numpy as np
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
#Basic one hot encoder i made to encode and decode both characters and sequences
endecode = OneHotEncoder()
#Hyperparameters
vocab_size = OneHotEncoder.get_vocab_size(self = endecode)
hidden_dim = 768
n_gram = 1
learning_rate = 1e-6
num_epochs = 15
batch_size = 256
temp = 1
p = .95
eps = .001

In [3]:
#Torch dataset because the processed inputs and outputs were over 60 gb in size

class SequenceDataset(Dataset):
    def __init__(self, file_path, encoder, n_gram = 1):
        self.n_gram = n_gram
        self.file_path = file_path
        self.encoder = encoder
        with open(file_path, 'r') as f:
            self.lines = f.readlines()
    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        input_tensors = []
        target_tensors = []
        sequence = self.lines[idx].strip()
        sequence_input = self.encoder.encode_sequence(sequence)
        sequence_target = self.encoder.encode_sequence(sequence,targets = True)
        pad = self.encoder.encode('[PAD]').view(1,-1)
        length = sequence_input.shape[0]
        for i in range(length):
            if self.n_gram == 1:
                if length - (i + 1) < self.n_gram:
                    padding = pad.repeat(self.n_gram - length + i + 1,1)
                    if i == length - 1:
                        input_tensors.append(torch.cat([sequence_input[i:length-1,:],padding],dim=0))
                        target_tensors.append(padding)
                    else:
                        input_tensors.append(torch.cat([sequence_input[i:length-1,:],padding],dim=0))
                        target_tensors.append(torch.cat([sequence_target[i+1:length-1,:], padding],dim=0))
                else:
                    input_tensors.append(sequence_input[i:i+self.n_gram,:])
                    target_tensors.append(sequence_target[i+1:i+self.n_gram+1,:])
            else:
                if length - i - 1 < self.n_gram:
                    padding = pad.repeat(self.n_gram - length + i + 1,1)
                    if i == length - 1:
                        input_tensors.append(torch.cat([sequence_input[i:length-1,:],padding],dim=0))
                        target_tensors.append(padding)
                    else:
                        input_tensors.append(torch.cat([sequence_input[i:length-1,:],padding],dim=0))
                        target_tensors.append(torch.cat([sequence_target[i+1:length-1,:],padding],dim=0))
                else:
                    input_tensors.append(sequence_input[i:i+self.n_gram,:])
                    target_tensors.append(sequence_target[i+1:i+self.n_gram+1,:])
        input_stack = torch.stack(input_tensors)
        target_stack = torch.stack(target_tensors)
        return input_stack, target_stack

#Load the dataset for working
dataset = SequenceDataset('data/train.csv', endecode)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers= 3)

In [4]:
#Declare RNN with vocab size, hidden dim size
charRNN = CharRNN(vocab_size, hidden_dim).to(device)

#Using basic cross entropy loss
criterion = nn.MSELoss(reduction='mean')

#AdamW
optimizer = optim.AdamW(charRNN.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=3)

#Typical training loop
for epoch in range(num_epochs):
    epoch_losses = torch.zeros(batch_size,device=device)

    for idx, (batch_inputs, batch_targets) in enumerate(dataloader):
        batch_inputs = batch_inputs
        batch_targets = batch_targets
        batch_size = batch_inputs.size(0)
        sequence_losses = torch.zeros(batch_size,device=device)

        for b in range(batch_size):
            sequence_inputs = batch_inputs[b].to(device)
            sequence_targets = batch_targets[b].to(device)

            optimizer.zero_grad()

            hidden = charRNN.init_hidden(1).to(device)

            sequence_loss = torch.zeros(len(sequence_inputs), device=device)
            for i in range(len(sequence_inputs)):
                ngram_input = sequence_inputs[i].unsqueeze(0)
                ngram_target = sequence_targets[i].unsqueeze(0)

                logits, mu, std, hidden = charRNN(ngram_input, hidden)
                hidden = hidden.detach()

                reconstruction_recon_loss = criterion(logits, ngram_target)
                KL_loss = -0.5 * torch.sum(1 + torch.log(std.pow(2)) - mu.pow(2) - std.pow(2))

                sequence_loss[i] = reconstruction_recon_loss + KL_loss

            sequence_losses[b] = sequence_loss.sum()
        loss = sequence_losses.mean()
        loss.backward()
        epoch_losses[idx] = loss.item()
    avg_epoch_loss = epoch_losses.mean().item()
    scheduler.step(avg_epoch_loss)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_epoch_loss}")


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/abog/PycharmProjects/Drug-Discovery/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/abog/PycharmProjects/Drug-Discovery/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/abog/PycharmProjects/Drug-Discovery/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/abog/PycharmProjects/Drug-Discovery/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 212, in collate
    collate(samples, collate_fn_map=collate_fn_map)
  File "/home/abog/PycharmProjects/Drug-Discovery/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/abog/PycharmProjects/Drug-Discovery/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 271, in collate_tensor_fn
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Trying to resize storage that is not resizable


In [None]:
#This is a bit wonky as its turning the output into a probability distribution and then takes the smallest group of logits to add up to the probability of top_p then samples those
def top_p_filtering(logits_p, top_p, temp_p):
    probs = nn.functional.softmax(logits_p.squeeze(0)[-1] / temp_p, dim=0)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0) 
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
    sorted_indices_to_remove[0] = False
    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
    filtered_probs = probs.masked_fill(indices_to_remove, 0).clone()
    filtered_probs = filtered_probs / filtered_probs.sum()
    next_token_idx = torch.multinomial(filtered_probs, 1).item()
    return next_token_idx

In [None]:
#Inputs the start token then does top p sampling until it generates the stop token[EOS] to it hits 200 characters whatever comes first.
currenToken = endecode.encode('[BOS]').to(device)
charRNN.to(device)
charRNN.eval()
generation = []
with torch.no_grad():
    while True:
        if currenToken.dim() == 2:
            currenToken = currenToken.unsqueeze(0)
        logits, _ , _ = charRNN(currenToken)
        next_token_index = top_p_filtering(logits, p, temp)
        next_token = torch.zeros(vocab_size)
        next_token[next_token_index] = 1
        char = endecode.decode(next_token)
        if char == '[EOS]': break
        generation.append(char)
        currenToken = next_token.unsqueeze(0).unsqueeze(0).to(device)

print(''.join(generation))

In [None]:
torch.save(charRNN,'Models/charRNNnoN-gram.pt')

In [None]:

charRNN = torch.load('Models/charRNNnoN-gram.pt', weights_only=False)
currenToken = endecode.encode('[BOS]').to(device)
charRNN.to(device)
charRNN.eval()
generations = []
for i in range(int(5e4)):
    generation = []
    with torch.no_grad():
        while True:
            if currenToken.dim() == 2:
                currenToken = currenToken.unsqueeze(0)
            logits = charRNN(currenToken)
            next_token_index = top_p_filtering(logits, p, temp)
            next_token = torch.zeros(vocab_size)
            next_token[next_token_index] = 1
            char = endecode.decode(next_token)
            if char == '[EOS]': break
            generation.append(char)
            currenToken = next_token.unsqueeze(0).unsqueeze(0).to(device)

    generations.append(''.join(generation))

In [None]:
with open('GRUOnly95P.txt', 'w') as file:
    for item in generations:
        file.write(f"{item}\n")
