In [None]:
from datasets import load_dataset

# Load the "normal" configuration of both dataset variants.
dataset_separated = load_dataset(
    "Thecoder3281f/MIT_separated",
    "normal",
)
dataset_mixed = load_dataset(
    "Thecoder3281f/MIT_mixed",
    "normal",
)

In [None]:

from transformers import PreTrainedTokenizerFast

In [None]:
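# # One-off tokenizer construction, kept commented out for reference.
# # NOTE: re-enabling it needs `Tokenizer`, `WordLevel` and `Whitespace` (from the
# # `tokenizers` package) and a `dataset` variable; none are defined in the visible cells.
# # Build vocab from the space-tokenized SMILES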
# def build_vocab_from_dataset(dataset, fields=["input", "target"]):
#     vocab = set()
#     splits = ["train", "val", "test"]
#     for split in splits:
#         for ex in dataset[split]:
#             for f in fields:
#                 vocab.update(ex[f].split())

#     return vocab

# vocab = build_vocab_from_dataset(dataset)
# vocab.update(["[PAD]", "[UNK]", "<s>", "</s>"])

# #print(vocab)
# vocab = {tok: i for i, tok in enumerate(sorted(vocab), start=0)}

# #print(vocab)

# # Create WordLevel tokenizer
# tok = Tokenizer(WordLevel(vocab=vocab, unk_token="[UNK]"))
# tok.pre_tokenizer = Whitespace()

# # Wrap as a Hugging Face tokenizer
# hf_tokenizer = PreTrainedTokenizerFast(
#     tokenizer_object=tok,
#     unk_token="[UNK]",
#     pad_token="[PAD]",
#     eos_token="</s>",
#     bos_token="<s>",
# )

# hf_tokenizer.save_pretrained("smiles-whitespace-tokenizer-separated-mit")

In [None]:
# One tokenizer per dataset variant, loaded by saved name/path.
tokenizer_separated = PreTrainedTokenizerFast.from_pretrained(
    "smiles-whitespace-tokenizer-separated-mit"
)
tokenizer_mixed = PreTrainedTokenizerFast.from_pretrained(
    "smiles-whitespace-tokenizer-mixed-mit"
)

def preprocess(batch, tokenizer, max_length=256):
    """Tokenize a batch of source/target strings into fixed-length model features.

    Args:
        batch: Mapping with "input" and "target" entries holding the source and
            target texts of the batch.
        tokenizer: Hugging Face tokenizer applied to both sides.
        max_length: Length every sequence is padded/truncated to (default 256).

    Returns:
        The tokenizer encoding of the inputs, with the target token IDs stored
        under the "labels" key.
    """
    # `text_target=` supersedes the deprecated `as_target_tokenizer()` context
    # manager: one call encodes the inputs and puts the encoded targets'
    # `input_ids` under "labels", using the same padding/truncation settings.
    model_inputs = tokenizer(
        batch["input"],
        text_target=batch["target"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # NOTE(review): padded label positions keep the tokenizer's pad ID (not -100),
    # so the training loss also covers padding. Left as-is on purpose here.
    return model_inputs

# Tokenize each dataset with its own tokenizer and drop the raw text columns.
tokenized_datasets_separated = dataset_separated.map(
    lambda batch: preprocess(batch, tokenizer_separated),
    batched=True,
    remove_columns=["input", "target"],
)
tokenized_datasets_mixed = dataset_mixed.map(
    lambda batch: preprocess(batch, tokenizer_mixed),
    batched=True,
    remove_columns=["input", "target"],
)

# Show both tokenized datasets as the cell output.
(tokenized_datasets_separated, tokenized_datasets_mixed)

In [None]:
import torch
# Quick report of the GPU environment visible to PyTorch.
print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
print(
    "Current device:",
    "CPU" if not torch.cuda.is_available() else torch.cuda.current_device(),
)
print(
    "Device name:",
    "None" if not torch.cuda.is_available() else torch.cuda.get_device_name(0),
)


In [None]:
import logging
import numpy as np

# Module-level logger; `compute_metrics` reports its scores through it at INFO level.
# NOTE(review): no handler is attached here, so whether those INFO records are
# actually displayed depends on logging configuration done elsewhere — confirm.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs

def canonicalize(smiles):
    """Join tokens, parse to molecule, return canonical SMILES or None.

    Args:
        smiles: SMILES string, optionally space-tokenized.

    Returns:
        The canonical SMILES string, or None when the input is not a string,
        is empty once spaces are removed, or cannot be parsed.
    """
    if not isinstance(smiles, str):
        return None

    s = smiles.replace(" ", "")
    # RDKit accepts the empty string and yields an atom-less molecule, which
    # would come back as "" and be treated as a valid structure by callers.
    if not s:
        return None

    try:
        mol = Chem.MolFromSmiles(s)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        # Deliberately best-effort: any parser failure means "not a molecule".
        return None

def tanimoto(a, b, radius=2, n_bits=2048):
    """Compute Tanimoto similarity between two SMILES.

    Args:
        a: First SMILES string.
        b: Second SMILES string.
        radius: Morgan fingerprint radius (default 2, as before).
        n_bits: Fingerprint length in bits (default 2048, RDKit's default).

    Returns:
        The similarity as a float; 0.0 when either SMILES fails to parse or
        fingerprinting raises.
    """
    try:
        ma, mb = Chem.MolFromSmiles(a), Chem.MolFromSmiles(b)
        if ma is None or mb is None:
            return 0.0
        fa = AllChem.GetMorganFingerprintAsBitVect(ma, radius, nBits=n_bits)
        fb = AllChem.GetMorganFingerprintAsBitVect(mb, radius, nBits=n_bits)
        return float(DataStructs.TanimotoSimilarity(fa, fb))
    except Exception:
        # Best-effort metric: a failure scores as "no similarity" rather than
        # aborting the whole evaluation. Always a float, for a consistent type.
        return 0.0

def compute_metrics(eval_pred, tokenizer):
    """Score decoded predictions against reference SMILES.

    Args:
        eval_pred: `(predictions, label_ids)` pair. Predictions may be
            - token IDs of shape (batch, seq_len),
            - integer token IDs for k candidates, shape (batch, k, seq_len), or
            - floating-point logits of shape (batch, seq_len, vocab_size);
            a tuple of outputs is reduced to its first element.
        tokenizer: Tokenizer used to decode token IDs back to text.

    Returns:
        dict with
        - "canonical_top1": fraction of examples whose first candidate has the
          same canonical SMILES as the label,
        - "mean_tanimoto": mean over examples of the best Tanimoto similarity
          among the candidates examined (0 when the label cannot be parsed),
        - "validity": fraction of examples whose first candidate parses to a
          non-empty canonical SMILES.
    """
    preds, labels = eval_pred

    # Some models return (logits, extras...); only the first item is scored.
    if isinstance(preds, tuple):
        preds = preds[0]

    preds = np.asarray(preds)
    labels = np.asarray(labels)

    # A 3-D array is ambiguous: floating point means raw logits
    # (batch, seq_len, vocab_size), integers mean k candidate ID sequences
    # (batch, k, seq_len). Only logits are reduced with argmax, so the
    # top-k branch below stays reachable.
    if preds.ndim == 3 and np.issubdtype(preds.dtype, np.floating):
        preds = np.argmax(preds, axis=-1)

    # -100 is the ignore/padding index used for labels and for padding gathered
    # predictions; it is not a vocabulary ID and cannot be decoded.
    pad_id = tokenizer.pad_token_id
    if pad_id is not None:
        preds = np.where(preds == -100, pad_id, preds)
        labels = np.where(labels == -100, pad_id, labels)

    if preds.ndim == 3:  # (batch, k, seq_len)
        batch_size, k, seq_len = preds.shape
        flat = tokenizer.batch_decode(
            preds.reshape(batch_size * k, seq_len), skip_special_tokens=True
        )
        decoded_preds = [flat[i * k:(i + 1) * k] for i in range(batch_size)]
    else:  # (batch, seq_len)
        decoded_preds = [
            [text] for text in tokenizer.batch_decode(preds, skip_special_tokens=True)
        ]

    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    n_examples = len(decoded_labels)
    if n_examples == 0:
        # Nothing to score; avoid dividing by zero.
        return {"canonical_top1": 0.0, "mean_tanimoto": 0.0, "validity": 0.0}

    top1_correct = 0
    tanimotos = []
    valids = []

    for k_preds, label in zip(decoded_preds, decoded_labels):
        # Validity describes the model output, so it is measured on the first
        # candidate (the label is reference data and is expected to parse).
        top1_c = canonicalize(k_preds[0]) if k_preds else None
        valids.append(bool(top1_c))

        label_c = canonicalize(label)
        if label_c is None:
            tanimotos.append(0.0)
            continue

        best_tani = 0.0
        for i, pred in enumerate(k_preds):
            p_c = top1_c if i == 0 else canonicalize(pred)
            if p_c is None:
                continue

            best_tani = max(best_tani, tanimoto(p_c, label_c))

            if p_c == label_c:
                if i == 0:
                    top1_correct += 1
                # An exact match cannot be beaten; skip remaining candidates.
                break

        tanimotos.append(best_tani)

    canonical_top1 = top1_correct / n_examples
    mean_tanimoto = sum(tanimotos) / len(tanimotos) if tanimotos else 0.0
    validity = sum(valids) / len(valids) if valids else 0.0

    logger.info(f"Canonical Top-1 Accuracy: {canonical_top1:.3f}")
    logger.info(f"Mean Tanimoto: {mean_tanimoto:.3f}")
    logger.info(f"Validity: {validity:.3f}")

    return {
        "canonical_top1": canonical_top1,
        "mean_tanimoto": mean_tanimoto,
        "validity": validity,
    }


In [None]:
# train_dataset_small = tokenized_datasets["train"].shuffle(seed=42).select(range(40000))
# val_dataset_small = tokenized_datasets["val"].shuffle(seed=42).select(range(3000))
# test_dataset_small = tokenized_datasets["test"].shuffle(seed=42).select(range(3000))

# Split aliases for the "separated" variant.
train_dataset_separated, val_dataset_separated, test_dataset_separated = (
    tokenized_datasets_separated[split_name] for split_name in ("train", "val", "test")
)

# Split aliases for the "mixed" variant.
train_dataset_mixed, val_dataset_mixed, test_dataset_mixed = (
    tokenized_datasets_mixed[split_name] for split_name in ("train", "val", "test")
)

In [None]:
from transformers import Trainer, TrainingArguments, T5ForConditionalGeneration, EarlyStoppingCallback

In [None]:
def model_init(tokenizer):
    """Create a `t5-small` model adapted to `tokenizer`.

    The embedding matrix is resized to the tokenizer's vocabulary size, and the
    special-token IDs used for label shifting and generation are taken from the
    tokenizer instead of the checkpoint.

    Args:
        tokenizer: Tokenizer the model will be trained and used with.

    Returns:
        The `T5ForConditionalGeneration` instance.
    """
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    model.resize_token_embeddings(len(tokenizer))

    # The checkpoint's pad/eos/decoder-start IDs refer to the original T5
    # vocabulary. With a different tokenizer those integers can denote
    # unrelated tokens, so generation could stop or pad on the wrong token.
    # Assigning them is a no-op if the IDs already agree.
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    configs = [model.config]
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        configs.append(generation_config)
    for cfg in configs:
        if pad_id is not None:
            cfg.pad_token_id = pad_id
            # T5 convention: decoding starts from the pad token.
            cfg.decoder_start_token_id = pad_id
        if eos_id is not None:
            cfg.eos_token_id = eos_id

    return model


def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token IDs before they are accumulated for metrics.

    Args:
        logits: Model output, or a tuple whose first element is that output.
        labels: Reference labels, passed through untouched.

    Returns:
        `(argmax over the last dimension, labels)`.
    """
    primary = logits[0] if isinstance(logits, tuple) else logits
    token_ids = primary.argmax(dim=-1)
    return token_ids, labels


def train_loop(sep_or_comb, model, tokenizer, train_dataset, val_dataset, num_steps=20000):
    """Fine-tune `model` with periodic evaluation and early stopping, then save it.

    Args:
        sep_or_comb: "separated" selects the separated run directory; any other
            value selects the mixed one.
        model: Model to train.
        tokenizer: Tokenizer handed to the `Trainer` and used to decode
            predictions in `compute_metrics`.
        train_dataset: Tokenized training split.
        val_dataset: Tokenized validation split, evaluated every 1000 steps.
        num_steps: Maximum number of optimizer steps.

    Side effects:
        Writes checkpoints under the run directory, TensorBoard events under
        `./logs/<run directory>`, and the final (best) model to
        `<run directory><num_steps>`.
    """
    variant = "separated" if sep_or_comb == "separated" else "mixed"
    output_dir = f"prelim-t5-small-mit-sepvscomb-{variant}"

    args = TrainingArguments(
        output_dir=output_dir,
        eval_strategy="steps",
        save_strategy="steps",
        learning_rate=3e-4,
        # per_device_train_batch_size=64,
        auto_find_batch_size=True,
        per_device_eval_batch_size=16,
        warmup_ratio=0.1,
        max_steps=num_steps,
        save_steps=1000,
        eval_steps=1000,
        # num_train_epochs=1,
        logging_strategy="steps",
        logging_steps=100,
        report_to="tensorboard",
        weight_decay=0.01,
        # One log directory per run; a shared "./logs" would interleave the
        # separated and mixed runs' events in the same TensorBoard run.
        logging_dir=f"./logs/{output_dir}",
        greater_is_better=True,
        metric_for_best_model="canonical_top1",
        load_best_model_at_end=True,
        gradient_checkpointing=True,
        eval_accumulation_steps=128,
        fp16=True,
        save_total_limit=3,
    )

    # Add early stopping callback
    early_stop_callback = EarlyStoppingCallback(
        early_stopping_patience=3,   # stop if no improvement for 3 evals
        early_stopping_threshold=0.01 # minimum change to qualify as improvement
    )

    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=`; kept for compatibility with the installed version.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        callbacks=[early_stop_callback],
    )

    trainer.train(resume_from_checkpoint=False)

    # Save the final model next to (not inside) the checkpoint directory.
    trainer.save_model(f"{output_dir}{num_steps}")


def test_loop(sep_or_comb, tokenizer, test_dataset, num_steps=20000):
    """Evaluate a trained model on `test_dataset` and print the metrics.

    Args:
        sep_or_comb: "separated" selects the separated run; any other value
            selects the mixed one.
        tokenizer: Tokenizer handed to the `Trainer` and used to decode
            predictions in `compute_metrics`.
        test_dataset: Tokenized test split.
        num_steps: Step count that `train_loop` appended to the directory name
            of the final model (default matches `train_loop`'s default).
    """
    import os

    variant = "separated" if sep_or_comb == "separated" else "mixed"
    output_dir = f"prelim-t5-small-mit-sepvscomb-{variant}"

    # `train_loop` saves the final model to "<output_dir><num_steps>"; the bare
    # `output_dir` is only the checkpoint root. Prefer the final-model
    # directory and fall back to the previous location if it does not exist.
    final_dir = f"{output_dir}{num_steps}"
    model_dir = final_dir if os.path.isdir(final_dir) else output_dir

    args = TrainingArguments(
        output_dir=output_dir,
    )

    trainer = Trainer(
        model=T5ForConditionalGeneration.from_pretrained(model_dir),
        args=args,
        tokenizer=tokenizer,
        compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    metrics = trainer.evaluate(eval_dataset=test_dataset)
    print(metrics)

# best_trial = trainer.hyperparameter_search(
#     direction="minimize",
#     backend="optuna",
#     n_trials=5,
#     hp_space=lambda trial: {
#         "learning_rate": trial.suggest_loguniform("learning_rate", 1e-6, 1e-4),
#         "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16]),
#         "weight_decay": trial.suggest_float("weight_decay", 0, 0.3),
#     },
# )

# best_trial

In [None]:
# One freshly initialised model per tokenizer.
model_separated, model_mixed = (
    model_init(tok) for tok in (tokenizer_separated, tokenizer_mixed)
)

# Show both models as the cell output.
(model_separated, model_mixed)

In [None]:
# Fine-tune on the "separated" variant.
train_loop(
    "separated",
    model=model_separated,
    tokenizer=tokenizer_separated,
    train_dataset=train_dataset_separated,
    val_dataset=val_dataset_separated,
    num_steps=20000,
)



In [None]:
# Fine-tune on the "mixed" variant.
train_loop(
    "mixed",
    model=model_mixed,
    tokenizer=tokenizer_mixed,
    train_dataset=train_dataset_mixed,
    val_dataset=val_dataset_mixed,
    num_steps=20000,
)

In [None]:
# Evaluate the "separated" model on its test split.
test_loop(
    "separated",
    tokenizer=tokenizer_separated,
    test_dataset=test_dataset_separated,
)


In [None]:
# Evaluate the "mixed" model on its test split.
test_loop(
    "mixed",
    tokenizer=tokenizer_mixed,
    test_dataset=test_dataset_mixed,
)

In [None]:
from transformers import BitsAndBytesConfig
import torch

def _load_t5(model, **load_kwargs):
    """Load a T5 checkpoint with automatic device placement.

    `device_map="auto"` already places the weights, so no `.to("cuda")` follows:
    that call is redundant, fails on hosts without CUDA, and (depending on the
    installed library versions) is rejected for bitsandbytes-quantized models.
    """
    return T5ForConditionalGeneration.from_pretrained(
        model,
        device_map="auto",
        **load_kwargs,
    )


def _generate_and_print(model, tokenizer, text, label, num_return_sequences, num_beams):
    """Beam-search `text` with `model`; print relative scores and decoded candidates."""
    inputs = tokenizer(text, return_tensors="pt", return_token_type_ids=False).to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=256,
        repetition_penalty=1.0,
        do_sample=False,
        num_return_sequences=num_return_sequences,
        num_beams=num_beams,
        output_scores=True,
        return_dict_in_generate=True,
    )

    # `sequences_scores` holds one beam-search score per returned sequence.
    # The softmax is taken over the returned candidates only, so the
    # percentages are relative weights among them, not absolute probabilities.
    seq_scores = outputs.sequences_scores
    probs = torch.softmax(seq_scores, dim=0) * 100  # %
    print("Scores: ", probs)

    decoded = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    for i, p in enumerate(decoded):
        print(f"{label} Model Prediction {i+1}: {p}")


def inference_4bit(tokenizer, text, model, num_return_sequences=2, num_beams=10):
    """Run beam-search inference with the checkpoint loaded in 4-bit NF4 quantization.

    Args:
        tokenizer: Tokenizer matching the checkpoint.
        text: Input string to translate.
        model: Checkpoint path or name passed to `from_pretrained`.
        num_return_sequences: Number of candidates to print.
        num_beams: Beam width.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="float16",
    )
    model_4bit = _load_t5(model, quantization_config=bnb_config)
    _generate_and_print(model_4bit, tokenizer, text, "4-bit", num_return_sequences, num_beams)


def inference_8bit(tokenizer, text, model, num_return_sequences=2, num_beams=10):
    """Run beam-search inference with the checkpoint loaded in 8-bit quantization.

    Parameters are the same as for `inference_4bit`.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
    )
    model_8bit = _load_t5(model, quantization_config=bnb_config)
    _generate_and_print(model_8bit, tokenizer, text, "8-bit", num_return_sequences, num_beams)


def inference_fp16(tokenizer, text, model, num_return_sequences=2, num_beams=10):
    """Run beam-search inference with the checkpoint loaded in float16.

    Parameters are the same as for `inference_4bit`.
    """
    model_fp16 = _load_t5(model, torch_dtype=torch.float16)
    _generate_and_print(model_fp16, tokenizer, text, "FP16", num_return_sequences, num_beams)


def inference_fp32(tokenizer, text, model, num_return_sequences=2, num_beams=10):
    """Run beam-search inference with the checkpoint loaded in float32.

    Parameters are the same as for `inference_4bit`.
    """
    model_fp32 = _load_t5(model, torch_dtype=torch.float32)
    _generate_and_print(model_fp32, tokenizer, text, "FP32", num_return_sequences, num_beams)



In [None]:
# Example query for the inference cells below, already split into space-separated tokens.
# Presumably reaction SMILES of the form "reactants > reagents" — TODO confirm
# against the dataset's "input" field.
text = "C O c 1 c c c c 2 c 1 C ( C ) C ( = O ) N 2 C . N # C C I > C C O . C C [O-] . [Na+]"

In [None]:


# Top-3 beam-search candidates from the 4-bit quantized model.
inference_4bit(
    tokenizer_separated,
    text,
    "t5-small-mit-sepvscomb-separated20000",
    num_return_sequences=3,
    num_beams=10,
)


In [None]:
# Top-3 beam-search candidates from the 8-bit quantized model.
inference_8bit(
    tokenizer_separated,
    text,
    "t5-small-mit-sepvscomb-separated20000",
    num_return_sequences=3,
    num_beams=10,
)


In [None]:
# Top-3 beam-search candidates from the float16 model.
inference_fp16(
    tokenizer_separated,
    text,
    "t5-small-mit-sepvscomb-separated20000",
    num_return_sequences=3,
    num_beams=10,
)


In [None]:
# Top-3 beam-search candidates from the float32 model.
inference_fp32(
    tokenizer_separated,
    text,
    "t5-small-mit-sepvscomb-separated20000",
    num_return_sequences=3,
    num_beams=10,
)