In [None]:
from CharRNN import CharRNN
import torch, torch.optim as optim, torch.nn as nn
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from onehotencoder import OneHotEncoder
import time
import pandas as pd
import numpy as np
from rdkit import RDLogger, Chem
# Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
#Basic one hot encoder I made to encode and decode both characters and sequences
endecode = OneHotEncoder()
#Hyperparameters
# NOTE(review): the method is called through the class with an explicit `self`;
# presumably equivalent to `endecode.get_vocab_size()` - confirm against
# OneHotEncoder before simplifying.
vocab_size = OneHotEncoder.get_vocab_size(self = endecode)
# Passed to CharRNN as its second constructor argument.
num_layers = 26
# Window length: number of consecutive tokens in each model input step
# (used by SequenceDataset and to size the flattened model input).
n_gram = 1
# Passed to CharRNN as its fourth constructor argument; presumably the dropout
# probability - TODO confirm against CharRNN.
dropped_out = 0.2
# Learning rate handed to AdamW.
learning_rate = 1e-2
# Number of passes of the training loop over the dataloader.
num_epochs = 40
# Sequences per DataLoader batch.
batch_size = 64
# Softmax temperature used when sampling (logits are divided by it).
temp = 1
# Cumulative-probability threshold for top-p (nucleus) sampling.
p = .95
# KL-term weight is ramped linearly from b_start (epoch 0) to b_end over the
# first `anneal_epochs` epochs and held at b_end afterwards.
b_start = 0
b_end = 1
anneal_epochs = 20
# Fraction of the dataset's lines that are made available to the training sampler.
subset_fraction = 0.4

In [None]:


# Torch Dataset that encodes each sequence on demand, because the fully pre-processed inputs and outputs were over 60 GB in size

class SequenceDataset(Dataset):
    """Map-style dataset of sliding n-gram windows and their next-token targets.

    Every line of ``file_path`` is treated as one sequence. The raw lines are
    held in memory; encoding happens on demand in ``__getitem__`` so the encoded
    tensors never have to be materialised for the whole file at once.

    Args:
        file_path: Path of a text file with one sequence per line.
        encoder: Object exposing ``encode_sequence(str)`` that returns a 2-D
            tensor of shape ``(L, D)`` (one row per token).
        n_gram: Number of consecutive tokens in each input window; must be >= 1.

    Raises:
        ValueError: If ``n_gram`` is smaller than 1.
    """

    def __init__(self, file_path, encoder, n_gram = 1):
        if n_gram < 1:
            # A window of width < 1 would produce empty inputs paired with
            # misaligned targets, so reject it up front.
            raise ValueError(f"n_gram must be >= 1, got {n_gram}")
        self.n_gram = n_gram
        self.file_path = file_path
        self.encoder = encoder
        with open(file_path, 'r') as f:
            self.lines = f.readlines()

    def __len__(self):
        """Return the number of lines (sequences) read from the file."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Encode line ``idx`` and return ``(inputs, targets)``.

        Returns:
            inputs: Tensor of shape ``(T, n_gram, D)``; window ``i`` holds
                encoded tokens ``i .. i + n_gram - 1``.
            targets: Tensor of shape ``(T, D)``; row ``i`` is the encoded token
                at position ``i + n_gram`` (the token following window ``i``).
            ``T = L - n_gram``. When the sequence is too short to yield a
            single window, both tensors are empty (``T == 0``).
        """
        sequence = self.lines[idx].strip()
        # Encode once and derive both inputs and targets from the same tensor:
        # this halves the per-item encoding cost and guarantees the two stay
        # aligned.
        encoded = self.encoder.encode_sequence(sequence)   # (L, D)
        n = self.n_gram

        # Only windows that are followed by a real "next" token are usable.
        num_windows = encoded.size(0) - n
        if num_windows <= 0:
            return torch.empty(0, n, encoded.size(1)), torch.empty(0, encoded.size(1))

        # (T, n, D): one sliding window per row.
        input_stack = torch.stack([encoded[i : i + n] for i in range(num_windows)])
        # (T, D): the tokens at positions n .. L-1; cloned so the returned
        # tensor does not alias the encoder's output.
        target_stack = encoded[n : n + num_windows].clone()

        return input_stack, target_stack

#Load the dataset for working
dataset = SequenceDataset('data/train.csv', endecode, n_gram = n_gram)

# Choose a random `subset_fraction` of the line indices once, up front.
# np.random is not seeded in the visible cells, so the chosen subset differs
# from run to run.
full_size = len(dataset)
subset_size = int(full_size * subset_fraction)
all_indices = list(range(full_size))
np.random.shuffle(all_indices)
subset_indices = all_indices[:subset_size]
# The sampler restricts iteration to the chosen indices and draws them in a
# fresh random order on every pass, which is why shuffle stays False below.
subset_sampler = SubsetRandomSampler(subset_indices)

# NOTE(review): no collate_fn is given, so every item in a batch must have the
# same shape - assumes the encoder pads all sequences to one length; confirm.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers= 3, sampler=subset_sampler)
charRNN = CharRNN(vocab_size, num_layers, n_gram, dropped_out).to(device)

In [None]:
#Using basic cross-entropy loss
# Targets whose class index is 27 are excluded from the loss
# (presumably the encoder's padding token - TODO confirm against OneHotEncoder).
criterion = nn.CrossEntropyLoss(ignore_index=27)

#AdamW
optimizer = optim.AdamW(charRNN.parameters(), lr=learning_rate)
# Reduces the learning rate when the value passed to scheduler.step() has
# stopped improving for more than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=3)
charRNN.train()
#Typical training loop
for epoch in range(num_epochs):
    start_time = time.time()
    total_epoch_loss = 0.0
    # Linear KL annealing: the KL weight climbs from b_start to b_end over the
    # first `anneal_epochs` epochs, then stays at b_end.
    if epoch < anneal_epochs:
        current_beta = b_start + (b_end - b_start) * (epoch / anneal_epochs)
    else:
        current_beta = b_end

    for idx, (batch_inputs, batch_targets) in enumerate(dataloader):
        batch_inputs = batch_inputs.to(device)
        # squeeze(2) only has an effect when dimension 2 has size 1.
        # NOTE(review): with (batch, T, vocab) targets it is a no-op - confirm intent.
        batch_targets = batch_targets.squeeze(2).to(device)
        current_batch_size = batch_inputs.size(0)
        seq_len = batch_inputs.size(1)
        # Flatten each n-gram window into a single feature vector per time step.
        batch_inputs = batch_inputs.view(current_batch_size, seq_len, n_gram * vocab_size)
        # One-hot targets -> integer class indices, as CrossEntropyLoss expects.
        target_indices = torch.argmax(batch_targets, dim=2).long()

        optimizer.zero_grad()

        # A fresh hidden state is requested for every batch.
        hidden = charRNN.init_hidden(current_batch_size).to(device)

        logits, mu, std, hidden = charRNN(batch_inputs, hidden)
        # CrossEntropyLoss wants the class dimension second: swap dims 1 and 2.
        logits_permuted = logits.permute(0, 2, 1)

        reconstruction_loss = criterion(logits_permuted, target_indices)
        # Closed-form KL divergence between N(mu, std^2) and N(0, 1), summed
        # over dim 1 and averaged over the remaining dims; the 1e-8 keeps the
        # log finite when std is 0.
        kl_loss = -0.5 * torch.sum(1 + torch.log(std.pow(2) + 1e-8) - mu.pow(2) - std.pow(2), dim=1)
        kl_loss = torch.mean(kl_loss)

        loss = reconstruction_loss + kl_loss * current_beta
        loss.backward()

        optimizer.step()
        total_epoch_loss += loss.item()

    # Mean of the combined (reconstruction + weighted KL) training loss per
    # batch; this same value drives the LR scheduler. Because the KL weight
    # grows while annealing, the metric can rise for reasons unrelated to fit.
    avg_epoch_loss = total_epoch_loss / len(dataloader)
    scheduler.step(avg_epoch_loss)

    end_time = time.time()
    epoch_duration = end_time - start_time
    epoch_duration_minutes = int(epoch_duration // 60)
    epoch_duration_seconds = int(epoch_duration % 60)

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_epoch_loss}, Time: {epoch_duration_minutes}m {epoch_duration_seconds}s")
# Saves the whole module object (not just a state_dict), so loading it later
# needs the CharRNN class importable.
# NOTE(review): nothing here creates the 'Models' directory; if it is missing
# the save fails only after all epochs have finished - confirm it exists.
torch.save(charRNN,'Models/charRNN1-gram.pt')

In [None]:
# Top-p (nucleus) sampling: turn the last step's logits into a probability distribution, keep the smallest set of most-probable tokens whose cumulative probability exceeds top_p, renormalise, and sample from that set
def top_p_filtering(logits_p, top_p, temp_p):
    """Sample one token index from the last time step using top-p filtering.

    Args:
        logits_p: Logits tensor; a leading size-1 dimension is squeezed away and
            the last entry along the next dimension is used.
        top_p: Cumulative-probability threshold for the kept tokens.
        temp_p: Temperature the logits are divided by before the softmax.

    Returns:
        The sampled token index as a Python int.
    """
    last_step = logits_p.squeeze(0)[-1]
    probs = nn.functional.softmax(last_step / temp_p, dim=0)

    # Rank tokens from most to least probable and accumulate their mass.
    ranked_probs, ranking = torch.sort(probs, descending=True)
    running_mass = torch.cumsum(ranked_probs, dim=0)

    # A ranked token is dropped when the mass *before* it already exceeds
    # top_p, i.e. the threshold mask shifted right by one; the top-ranked
    # token is therefore always kept.
    over_threshold = running_mass > top_p
    drop_ranked = over_threshold.clone()
    drop_ranked[1:] = over_threshold[:-1]
    drop_ranked[0] = False

    # Translate the mask from ranked order back to vocabulary order.
    drop_mask = drop_ranked.clone()
    drop_mask[ranking] = drop_ranked

    kept = probs.masked_fill(drop_mask, 0)
    kept = kept / kept.sum()
    return torch.multinomial(kept, 1).item()

def get_compound_token(s, n=n_gram):
    """Return the first ``n`` tokens of ``s`` joined into one string.

    A token is a single character, except that the two-letter symbols 'Cl' and
    'Br' each count as one token. Returns "" when ``s`` is not a non-empty
    string or ``n`` is not positive. Fewer than ``n`` tokens are returned when
    ``s`` runs out first.
    """
    if not (isinstance(s, str) and s) or n <= 0:
        return ""

    pieces = []
    pos = 0
    end = len(s)
    while len(pieces) < n and pos < end:
        # Two-letter halogens consume two characters but count as one token.
        width = 2 if s.startswith(('Cl', 'Br'), pos) else 1
        pieces.append(s[pos:pos + width])
        pos += width

    return "".join(pieces)

def validate_generation(file_path):
    """Return the fraction of lines in ``file_path`` that ``sanitize`` accepts.

    Args:
        file_path: Text file with one generated sequence per line.

    Returns:
        A float in [0, 1] (a fraction, despite the "percentage" wording used by
        callers). An empty file yields 0.0 instead of a division by zero.
    """
    with open(file_path, 'r') as f:
        # Strip the trailing newline so the validator sees the bare sequence.
        sequences = [line.strip() for line in f]

    # Nothing generated means nothing valid; avoid dividing by zero.
    if not sequences:
        return 0.0

    valid_count = sum(1 for sequence in sequences if sanitize(sequence) == 1)
    return valid_count / len(sequences)

def sanitize(sequence):
    """Check whether ``sequence`` parses and sanitizes as a SMILES string.

    Args:
        sequence: Candidate SMILES string; surrounding whitespace is ignored.

    Returns:
        1 if RDKit returns a molecule for the sequence, otherwise 0. Blank or
        non-string input and any exception raised while parsing count as
        invalid (0).
    """
    # Disable all RDKit warnings
    RDLogger.DisableLog('rdApp.*')

    # An empty generation is not a molecule; reject it before it reaches the
    # parser so it can never be counted as valid.
    if not isinstance(sequence, str) or not sequence.strip():
        return 0

    try:
        mol = Chem.MolFromSmiles(sequence.strip(), sanitize=True)
    except Exception as e:
        # A parser failure means the sequence is invalid, so it must not be
        # scored as a success.
        print(f"Error sanitizing molecule: {e}")
        return 0

    return 1 if mol is not None else 0

In [None]:

# NOTE: weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints that this notebook itself produced.
charRNN = torch.load('Models/charRNN1-gram.pt', weights_only=False).to(device)
charRNN.eval()

# Every initial context begins with the start-of-sequence token.
start_token = endecode.encode('[BOS]')
if n_gram == 1:
    top_chars = None
else:
    # Empirical distribution of the first (n_gram - 1) tokens of the training
    # sequences; each generation is seeded with a prefix drawn from it.
    string_series = pd.read_csv('data/train.csv', header=None)[0]
    string_series = string_series[string_series.apply(lambda x: isinstance(x,str) and x !='')]
    top_n_grams = string_series.apply(lambda s: get_compound_token(s, n=n_gram-1))
    top_chars = (top_n_grams.value_counts()/sum(top_n_grams.value_counts())).to_dict()
    prefix_choices = list(top_chars.keys())
    prefix_probs = list(top_chars.values())

generations = []
for i in range(int(1e3)):
    if i % 1000 == 0: print(i)

    # Rebuild the starting context for EVERY sample. Without this reset each
    # generation after the first would be seeded with the last window of the
    # previous one instead of starting from [BOS].
    if n_gram == 1:
        generation = []
        current_n_gram = start_token.to(device)
    else:
        token = str(np.random.choice(prefix_choices, p=prefix_probs))
        prefix = endecode.encode_sequence(token, skip_append=True)
        current_n_gram = torch.tensor(np.concatenate((start_token, prefix), axis=0)).to(device)
        # The sampled prefix is the beginning of the sequence, so it belongs
        # in the output ahead of the model's predictions.
        generation = [token]

    charCount = 0
    with torch.no_grad():
        while True:
            # The model is fed a batch of one: (1, n_gram, vocab).
            if current_n_gram.dim() == 2:
                current_n_gram = current_n_gram.unsqueeze(0)
            # NOTE(review): the returned hidden state is discarded and none is
            # passed in, so each step appears to see only the current window -
            # confirm whether CharRNN should be given the previous hidden state.
            logits, _, _, _ = charRNN(current_n_gram)
            next_token_index = top_p_filtering(logits, p, temp)
            next_token = torch.zeros(vocab_size)
            next_token[next_token_index] = 1
            char = endecode.decode(next_token)
            charCount += 1
            # Stop at end-of-sequence or after 100 sampling steps.
            if char == '[EOS]' or charCount >= 100: break
            generation.append(char)
            # Slide the window: drop the oldest token, append the new one.
            current_n_gram = current_n_gram.squeeze(0).to(device)
            next_token = next_token.to(device)
            current_n_gram = torch.concat([current_n_gram[1:],next_token.unsqueeze(0)],dim=0)
    generations.append(''.join(generation))



In [None]:
# Persist one generated sequence per line, then report the share of lines that
# pass validation.
output_path = 'data/GRUOnly95P1-gram.txt'
with open(output_path, 'w') as file:
    file.write("".join(f"{item}\n" for item in generations))
valid_fraction = validate_generation(output_path)
print(f"Valid percentage: {valid_fraction}")