In [None]:
import bisect
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from rdkit import RDLogger, Chem
from torch.utils.data import Dataset, DataLoader, Sampler, get_worker_info

from CharRNN import CharRNNV2
from onehotencoder import OneHotEncoder
# Select the "high" float32 matmul precision mode for the whole process.
torch.set_float32_matmul_precision("high")

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Basic one-hot encoder used to encode and decode both single characters and whole sequences.
endecode = OneHotEncoder()

# Reproducibility: seed every RNG this notebook draws from (batch shuffling and the
# train/validation split use `random`; torch is used for the model).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Hyperparameters
vocab_size = endecode.get_vocab_size()  # plain bound-method call instead of Class.method(self=obj)
num_layers = 3          # passed to CharRNNV2
n_gram = 1              # context window length used by FileDataset and CharRNNV2
dropped_out = 0.2       # passed to CharRNNV2 (presumably the dropout probability - confirm in CharRNN.py)
hidden_size = 1024      # passed to CharRNNV2
learning_rate = 5e-4    # passed to CharRNNV2
num_epochs = 200        # used to derive the total number of scheduler steps
batch_size = 128        # sequences per batch built by FileBatchSampler
temp = 1                # not used in the visible cells; presumably a sampling temperature - confirm
p = 1                   # not used in the visible cells; presumably a top-p threshold - confirm

In [None]:
class FileDataset(Dataset):
    """Random-access dataset over the non-blank lines of several text files.

    At construction every file is scanned once and the byte offset of each
    non-blank line is recorded, so a sample can later be fetched with a single
    ``seek`` + ``readline`` instead of loading the files into memory.

    Each sample is one line, encoded with ``encoder.encode_sequence`` and cut
    into sliding windows of ``n_gram`` rows, each paired with the row that
    follows the window.

    Args:
        filepaths: iterable of paths to the text files (one sequence per line).
        encoder: object exposing ``encode_sequence(str)`` that returns a 2-D
            tensor with one row per character.
        n_gram: number of consecutive rows in each input window.

    Attributes:
        counts: cumulative number of samples after each file.
        offsets: per-file list of byte offsets, one per sample.
        file_handles: binary file handles owned by the current process, or
            ``None`` until they are first needed.
    """

    def __init__(self, filepaths, encoder, n_gram):
        # Materialise once so a generator argument is not exhausted by the scan below.
        self.filepaths = list(filepaths)
        self.encoder   = encoder
        self.n_gram    = n_gram

        # Build cumulative counts and per-file line-offset tables.
        self.counts  = []
        self.offsets = []
        total = 0
        for path in self.filepaths:
            offs = []
            with open(path, 'rb') as f:
                while True:
                    pos = f.tell()
                    line = f.readline()
                    if not line:
                        break
                    # Skip blank lines (e.g. a trailing newline at end of file):
                    # they would encode to an empty sequence and fail in __getitem__.
                    if line.strip():
                        offs.append(pos)
            total += len(offs)
            self.counts.append(total)
            self.offsets.append(offs)

        # Handles are opened lazily, once per process (see _ensure_open).
        self.file_handles = None
        # PID of the process that opened file_handles via _ensure_open, if any.
        self._handles_pid = None

    def __len__(self):
        """Return the total number of samples across all files (0 if there are no files)."""
        return self.counts[-1] if self.counts else 0

    def __getstate__(self):
        """Drop open file handles when pickling; the receiving process reopens them lazily."""
        state = self.__dict__.copy()
        state['file_handles'] = None
        state['_handles_pid'] = None
        return state

    def close(self):
        """Close any file handles held by this instance; they reopen on the next access."""
        if self.file_handles is not None:
            for fh in self.file_handles:
                fh.close()
        self.file_handles = None
        self._handles_pid = None

    def _ensure_open(self):
        """Make sure this process has its own set of open file handles."""
        pid = os.getpid()
        # Handles opened here by a parent process and inherited through fork share
        # their file offset with the parent, so they must not be reused in a child.
        inherited = self._handles_pid is not None and self._handles_pid != pid
        if self.file_handles is None or inherited:
            self.file_handles = [open(path, 'rb') for path in self.filepaths]
            self._handles_pid = pid

    def __getitem__(self, idx):
        """Return ``(windows, targets)`` for the sample at global index ``idx``.

        ``windows`` has shape ``(L - n_gram, n_gram, D)`` and ``targets`` has
        shape ``(L - n_gram, D)``, where ``(L, D)`` is the shape of the encoded
        line.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: from ``torch.stack`` when the encoded line has no more
                than ``n_gram`` rows, so there is no window to build.
        """
        total = len(self)
        requested = idx
        if idx < 0:
            # Python-style negative indexing instead of silently reading a wrong line.
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index {requested} out of range for dataset of size {total}")

        self._ensure_open()

        # Map the global index to (file_idx, line_idx).
        file_idx = bisect.bisect_right(self.counts, idx)
        prev     = 0 if file_idx == 0 else self.counts[file_idx - 1]
        line_idx = idx - prev

        # Seek and read from the already-open file handle.
        fh = self.file_handles[file_idx]
        fh.seek(self.offsets[file_idx][line_idx])
        seq = fh.readline().decode('utf-8').strip()

        # n-gram windows: rows [i, i + n) predict row i + n.
        seq_enc = self.encoder.encode_sequence(seq)  # (L, D)
        L, D    = seq_enc.shape
        n       = self.n_gram

        windows = [seq_enc[i : i + n] for i in range(L - n)]
        # Rows n..L-1 are exactly the per-window targets; clone so the result
        # does not alias the encoder's output.
        targets = seq_enc[n:].clone()

        return torch.stack(windows), targets

def worker_init_fn(worker_id):
    """Give the calling DataLoader worker its own binary handles for every dataset file.

    Looks up the worker's dataset copy through ``get_worker_info()`` and stores
    one freshly opened handle per entry of ``dataset.filepaths`` on
    ``dataset.file_handles``. ``worker_id`` is accepted to match the
    ``worker_init_fn`` signature but is not used.
    """
    worker_dataset = get_worker_info().dataset
    opened = []
    for path in worker_dataset.filepaths:
        opened.append(open(path, 'rb'))
    worker_dataset.file_handles = opened

class FileBatchSampler(Sampler):
    """Batch sampler whose batches never mix samples from different files.

    ``counts`` is a list of cumulative sample counts (one entry per file), so
    file ``k`` owns the global indices ``[counts[k-1], counts[k])``. The indices
    of each file are chunked into batches of ``batch_size``.

    All batches are built once in ``__init__``; every iteration over the
    sampler yields that same list in the same order.

    Args:
        counts: cumulative per-file sample counts.
        batch_size: number of indices per batch; must be positive.
        shuffle: shuffle indices within each file and then the batch order,
            using the global ``random`` module.
        drop_last: discard a file's final batch when it is shorter than
            ``batch_size``.
        sample_ratio: fraction in ``(0, 1]`` of the batches to keep; a random
            subset is drawn when it is below 1.

    Raises:
        ValueError: if ``sample_ratio`` is outside ``(0, 1]`` or ``batch_size``
            is not positive.
    """

    def __init__(self, counts, batch_size, shuffle=True, drop_last=True, sample_ratio: float = 1.0):
        # Validate first: fail before doing O(N) work or consuming global RNG state.
        if not (0 < sample_ratio <= 1):
            raise ValueError("sample_ratio must be in (0,1]")
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        self.counts     = counts
        self.batch_size = batch_size
        self.shuffle    = shuffle
        self.drop_last  = drop_last
        self.sample_ratio = sample_ratio

        self.batches = []
        prev = 0
        for cum in counts:
            # Global indices that belong to the current file.
            idxs = list(range(prev, cum))
            if shuffle:
                random.shuffle(idxs)
            for i in range(0, len(idxs), batch_size):
                batch = idxs[i : i + batch_size]
                if len(batch) == batch_size or not drop_last:
                    self.batches.append(batch)
            prev = cum

        if shuffle:
            # Interleave batches from different files.
            random.shuffle(self.batches)
        if sample_ratio < 1.0:
            keep_n = int(len(self.batches) * sample_ratio)
            self.batches = random.sample(self.batches, keep_n)

    def __iter__(self):
        """Yield the pre-built batches (lists of global indices)."""
        yield from self.batches

    def __len__(self):
        """Return the number of batches."""
        return len(self.batches)

# Input files data/seqs_len26.txt ... data/seqs_len45.txt
# (presumably one file per sequence length - confirm against the data directory).
filepaths = [f"data/seqs_len{length}.txt" for length in range(26, 46)]
ds = FileDataset(filepaths, endecode, n_gram=n_gram)

# Build every single-file batch once; the split below works on whole batches.
full_sampler = FileBatchSampler(
    ds.counts,
    batch_size,
    shuffle=True,
    drop_last=True,
    sample_ratio=1.0,
)

all_batches = list(full_sampler)
random.shuffle(all_batches)

# Hold out the first 10% of the shuffled batches for validation.
val_frac = 0.10
n_val = int(val_frac * len(all_batches))
val_batches, train_batches = all_batches[:n_val], all_batches[n_val:]

class ListBatchSampler(Sampler):
    """Batch sampler that replays a fixed, pre-built list of index batches in order."""

    def __init__(self, batch_list):
        self.batch_list = batch_list

    def __len__(self):
        """Return the number of batches in the list."""
        return len(self.batch_list)

    def __iter__(self):
        """Yield each stored batch, in list order."""
        for batch in self.batch_list:
            yield batch

# Wrap the two batch lists as samplers for the shared dataset.
train_sampler, val_sampler = ListBatchSampler(train_batches), ListBatchSampler(val_batches)

# Each loader runs 10 worker processes; worker_init_fn opens the data files in each one.
train_loader = DataLoader(ds, batch_sampler=train_sampler, num_workers=10, worker_init_fn=worker_init_fn)
val_loader = DataLoader(ds, batch_sampler=val_sampler, num_workers=10, worker_init_fn=worker_init_fn)

# Step counts handed to the model; warm-up is 5% of the total.
# NOTE(review): this counts batches, while the Trainer below is configured with
# accumulate_grad_batches=4 - confirm which unit CharRNNV2's scheduler expects.
total_steps = num_epochs * len(train_loader)
warmup_steps = int(total_steps * 0.05)



In [None]:
# Build the model from the hyperparameters and step counts defined in the cells above.
charRNN = CharRNNV2(vocab_size, num_layers, n_gram, total_steps, warmup_steps, learning_rate, hidden_size, dropped_out).to(device)
trainer = Trainer(
    # Use the same epoch budget that total_steps / warmup_steps were derived from.
    max_epochs=num_epochs,
    # Follow the device detected at the top of the notebook instead of hard-coding CUDA,
    # so the cell does not fail on a machine without a GPU.
    accelerator=device,
    precision='16-mixed',
    accumulate_grad_batches=4,
    logger=TensorBoardLogger("tb_logs", name="char_rnn"),
    callbacks=[
        # Keep the checkpoint with the lowest validation loss.
        ModelCheckpoint(monitor="valid_loss", mode="min"),
        # Stop once validation loss has not improved for 5 validation checks.
        EarlyStopping(monitor="valid_loss", patience=5),
    ],
    # NOTE(review): the PyTorch profiler stays enabled for the whole run - confirm this
    # is intended for a full training run rather than a leftover from debugging.
    profiler="pytorch",
)
trainer.fit(charRNN, train_loader, val_loader)

In [None]:
# Create the output directory first so the save cannot fail after a long training run.
os.makedirs('Models', exist_ok=True)
# Saves the whole module object (not only a state_dict); loading it back therefore
# needs the CharRNNV2 class to be importable.
torch.save(charRNN, 'Models/charRNNv1-gram.pt')

In [None]:
# map_location remaps the stored tensors onto the current device, so a checkpoint written
# on a GPU machine can still be loaded on a CPU-only one.
# SECURITY: weights_only=False unpickles arbitrary Python objects - only load files you
# created yourself or otherwise trust.
charRNN = torch.load('Models/charRNNv1-gram.pt', map_location=device, weights_only=False).to(device)