In [None]:
from datasets import load_dataset

# Two dataset repos ("separated" and "mixed"), both loaded with the "normal"
# configuration; separated first, then mixed.
dataset_separated, dataset_mixed = (
    load_dataset(repo_id, "normal")
    for repo_id in (
        "Thecoder3281f/MIT_separated_final",
        "Thecoder3281f/MIT_mixed_final",
    )
)

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "t5-small"

# Pretrained seq2seq checkpoint together with its matching tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Pad and truncate at the end (right side) of every sequence.
tokenizer.padding_side = tokenizer.truncation_side = "right"

In [None]:
def preprocess(batch, max_length=256, ignore_pad_token_for_loss=False):
    """Tokenize a batch of source/target strings for seq2seq training.

    Meant to be used with ``Dataset.map(preprocess, batched=True)``.

    Args:
        batch: batched mapping with an ``"input"`` and a ``"target"`` column
            (lists of strings).
        max_length: every source and target sequence is padded or truncated
            to exactly this many tokens.
        ignore_pad_token_for_loss: if True, padding positions in ``labels`` are
            replaced by -100 so the loss ignores them. Off by default because
            the metric code decodes the labels with ``tokenizer.batch_decode``;
            only enable it when that code maps -100 back to a decodable id first.

    Returns:
        The tokenizer encoding of ``batch["input"]`` plus ``labels``: the padded
        token ids of ``batch["target"]``.
    """
    # `text_target=` tokenizes the targets with the same padding/truncation
    # settings and stores their input_ids under "labels"; it replaces the
    # deprecated `with tokenizer.as_target_tokenizer():` pattern.
    encoded = tokenizer(
        batch["input"],
        text_target=batch["target"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    if ignore_pad_token_for_loss:
        pad_id = tokenizer.pad_token_id
        encoded["labels"] = [
            [-100 if tok == pad_id else tok for tok in seq]
            for seq in encoded["labels"]
        ]
    return encoded


In [None]:
# Tokenize every split with `preprocess`; the raw text columns are dropped so
# only the tokenizer outputs and labels remain.
dataset_separated, dataset_mixed = (
    raw.map(preprocess, batched=True, remove_columns=["input", "target"])
    for raw in (dataset_separated, dataset_mixed)
)

In [None]:
# Named handles for the train / val / test splits of both dataset variants.
dataset_separated_train, dataset_separated_val, dataset_separated_test = (
    dataset_separated[split] for split in ("train", "val", "test")
)
dataset_mixed_train, dataset_mixed_val, dataset_mixed_test = (
    dataset_mixed[split] for split in ("train", "val", "test")
)

In [None]:
import logging
import numpy as np

# set up logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# With no handler anywhere in the hierarchy, records fall through to
# logging.lastResort, which only emits WARNING and above, so the INFO metric
# lines would be silently dropped. Attach a handler only in that case: this
# avoids duplicate output when the cell is re-run or root logging is configured.
if not logger.hasHandlers():
    _metrics_handler = logging.StreamHandler()
    _metrics_handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))
    logger.addHandler(_metrics_handler)

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs

def canonicalize(smiles):
    """Return the canonical SMILES of ``smiles``, or None if it is not a molecule.

    Whitespace-separated tokens are joined first (``"C C O"`` -> ``"CCO"``):
    RDKit does not read space-separated tokens as one SMILES string.
    An empty string is rejected explicitly, because RDKit parses it into a
    molecule with zero atoms instead of returning None. Otherwise an empty
    prediction would count as a valid molecule.

    Args:
        smiles: SMILES string, possibly space-tokenized.

    Returns:
        Canonical SMILES string, or None if the input is not a string, is
        empty, cannot be parsed, or RDKit raises.
    """
    if not isinstance(smiles, str):
        return None
    joined = "".join(smiles.split())
    if not joined:
        return None
    try:
        mol = Chem.MolFromSmiles(joined)
        if mol is None or mol.GetNumAtoms() == 0:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        return None

def tanimoto(a, b):
    """Tanimoto similarity of the radius-2 Morgan bit-vector fingerprints of two SMILES.

    Returns 0 when either string does not parse or RDKit raises.
    """
    try:
        mols = [Chem.MolFromSmiles(s) for s in (a, b)]
        if not all(mols):
            return 0
        fp_a, fp_b = (AllChem.GetMorganFingerprintAsBitVect(m, 2) for m in mols)
        return DataStructs.TanimotoSimilarity(fp_a, fp_b)
    except Exception:
        return 0

def compute_metrics(eval_pred, tokenizer):
    """Score predicted SMILES against reference SMILES.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair passed in by the Trainer.
            ``predictions`` may be a tuple (its first element is used). It may
            also be:
            - float logits of shape (batch, seq_len, vocab_size), which are argmax-ed;
            - integer token ids of shape (batch, seq_len);
            - integer top-k token ids of shape (batch, k, seq_len).
        tokenizer: tokenizer used to decode ids back to strings.

    Returns:
        dict with:
        - ``canonical_top1``: fraction of examples whose first candidate has
          the same canonical SMILES as the label;
        - ``mean_tanimoto``: mean over examples of the best Tanimoto
          similarity among the candidates (0 if the label is invalid);
        - ``validity``: fraction of examples whose first candidate parses
          to a valid molecule.
        All three are 0.0 for an empty evaluation set.
    """
    preds, labels = eval_pred

    # The logits hook or the model may hand over a tuple; predictions come first.
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    # Detect logits by dtype as well as rank. An integer 3-D array is already
    # top-k token ids and must not be argmax-ed.
    if preds.ndim == 3 and np.issubdtype(preds.dtype, np.floating):
        preds = preds.argmax(axis=-1)
    preds = preds.astype(np.int64, copy=False)

    # -100 (the loss ignore index, also used to pad batches of different
    # lengths) cannot be decoded, so map it back to the pad token.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    preds = np.where(preds == -100, pad_id, preds)
    labels = np.where(labels == -100, pad_id, labels)

    if preds.ndim == 3:  # (batch, k, seq_len): k ranked candidates per example
        batch_size, k, seq_len = preds.shape
        flat = tokenizer.batch_decode(preds.reshape(batch_size * k, seq_len), skip_special_tokens=True)
        candidates = [flat[i * k:(i + 1) * k] for i in range(batch_size)]
    else:  # (batch, seq_len): a single candidate per example
        candidates = [[text] for text in tokenizer.batch_decode(preds, skip_special_tokens=True)]
    references = tokenizer.batch_decode(labels, skip_special_tokens=True)

    top1_correct = 0
    top1_valid = 0
    tanimotos = []

    for cands, label in zip(candidates, references):
        cands_c = [canonicalize(pred) for pred in cands]
        # Validity is a property of the prediction, independent of the label.
        if cands_c and cands_c[0] is not None:
            top1_valid += 1

        label_c = canonicalize(label)
        if label_c is None:
            tanimotos.append(0)
            continue

        best_tani = 0
        for rank, pred_c in enumerate(cands_c):
            if pred_c is None:
                continue
            best_tani = max(best_tani, tanimoto(pred_c, label_c))
            if pred_c == label_c:
                if rank == 0:
                    top1_correct += 1
                break
        tanimotos.append(best_tani)

    # Exactly one tanimoto entry is appended per example; guard the empty case.
    denom = len(tanimotos) or 1
    canonical_top1 = top1_correct / denom
    mean_tanimoto = sum(tanimotos) / denom
    validity = top1_valid / denom

    logger.info(f"Canonical Top-1 Accuracy: {canonical_top1:.3f}")
    logger.info(f"Mean Tanimoto: {mean_tanimoto:.3f}")
    logger.info(f"Validity: {validity:.3f}")

    return {
        "canonical_top1": canonical_top1,
        "mean_tanimoto": mean_tanimoto,
        "validity": validity,
    }


In [None]:
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

In [None]:
# Checkpoint directories for the "separated" and "mixed" training runs.
output_dir, output_dir2 = (
    "t5-small-no-tokenizer-separated",
    "t5-small-no-tokenizer-mixed",
)

In [None]:
# def model_init(tokenizer):
#     model = T5ForConditionalGeneration.from_pretrained("t5-small")
#     model.resize_token_embeddings(len(tokenizer))
#     return model.to("cuda")


def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to greedy token ids before the Trainer accumulates them.

    Returning argmax ids instead of full (batch, seq_len, vocab_size) logits
    keeps evaluation memory small.

    Args:
        logits: model logits, or a tuple whose first element is the logits.
        labels: label ids of the batch. Unused: the Trainer already gathers
            the labels itself and passes them to ``compute_metrics``.
            Returning them here as well would gather and store them twice.

    Returns:
        Tensor of token ids: argmax over the last (vocabulary) dimension.

    NOTE(review): these are teacher-forced next-token predictions, not
    autoregressively generated sequences, so metrics built on them may
    overestimate generation quality.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


# Training arguments for the "separated" run, grouped by concern.
args = TrainingArguments(
    # outputs and logging destinations
    output_dir=output_dir,
    logging_dir="./logs/t5-small-separated",
    run_name="t5-small-separated-20k",
    report_to="tensorboard",
    # optimisation schedule
    max_steps=20000,
    learning_rate=3e-4,
    warmup_ratio=0.1,
    weight_decay=0.01,
    auto_find_batch_size=True,
    per_device_eval_batch_size=16,
    gradient_checkpointing=True,
    fp16=True,
    # step-based evaluation, checkpointing and logging cadence
    eval_strategy="steps",
    eval_steps=1000,
    eval_accumulation_steps=128,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=3,
    logging_strategy="steps",
    logging_steps=500,
    # model selection on the custom metric returned by compute_metrics
    metric_for_best_model="canonical_top1",
    greater_is_better=True,
    load_best_model_at_end=True,
)



# Add early stopping callback
early_stop_callback = EarlyStoppingCallback(
    early_stopping_threshold=0.001,  # smallest change that counts as an improvement
    early_stopping_patience=3,  # evaluations without improvement before stopping
)

# Trainer for the "separated" run.
trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    # data: separated-variant train / validation splits
    train_dataset=dataset_separated_train,
    eval_dataset=dataset_separated_val,
    # metrics: logits are reduced to token ids first, then scored as SMILES
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer),
    callbacks=[early_stop_callback],
)






In [None]:
# Fine-tune `model` on the separated training split without resuming from a
# checkpoint; the returned TrainOutput is shown as the cell's output.
trainer.train(resume_from_checkpoint=False)



In [None]:
# Save the final model
# Save the final model. `args` sets load_best_model_at_end=True with
# metric_for_best_model="canonical_top1", and early stopping is enabled.
# So this is the best checkpoint, not necessarily the one from step 20000.
trainer.save_model(f"{output_dir}20000")

In [None]:
# Training arguments for the "mixed" run. Same schedule and evaluation setup as
# the separated run, with its own output directory, log directory and run name.
args2 = TrainingArguments(
    output_dir=output_dir2,
    eval_strategy="steps",
    save_strategy="steps",
    learning_rate=3e-4,
    auto_find_batch_size=True,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    max_steps=20000,
    save_steps=1000,
    eval_steps=1000,
    logging_strategy="steps",
    logging_steps=500,
    report_to="tensorboard",
    weight_decay=0.01,
    logging_dir="./logs/t5-small-mixed",
    run_name="t5-small-mixed-20k",
    greater_is_better=True,
    metric_for_best_model="canonical_top1",
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    eval_accumulation_steps=128,
    fp16=True,
    save_total_limit=3,
)

# Start the mixed run from freshly loaded pretrained weights. Reusing `model`
# would silently continue from the separated run: `trainer.train()` updates that
# same module in place, and load_best_model_at_end=True loads the best
# separated checkpoint into it.
model2 = AutoModelForSeq2SeqLM.from_pretrained(model_name)

trainer2 = Trainer(
    model=model2,
    args=args2,
    train_dataset=dataset_mixed_train,
    eval_dataset=dataset_mixed_val,
    tokenizer=tokenizer,
    compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer),
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    # A dedicated instance: EarlyStoppingCallback is stateful (it tracks a
    # patience counter), so it should not be shared between trainers.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)],
)

In [None]:
# Fine-tune on the mixed training split without resuming from a checkpoint;
# the returned TrainOutput is shown as the cell's output.
trainer2.train(resume_from_checkpoint=False)

In [None]:
# Save the final model
# Save the final model. `args2` sets load_best_model_at_end=True with
# metric_for_best_model="canonical_top1", and early stopping is enabled.
# So this is the best checkpoint, not necessarily the one from step 20000.
trainer2.save_model(f"{output_dir2}20000")

In [None]:
# Optuna search over learning rate and weight decay on the separated split.
# Trainer.hyperparameter_search raises unless the Trainer was built with
# `model_init` (so every trial starts from fresh pretrained weights), and
# `trainer` above was built with a fixed `model`; hence a dedicated Trainer.
# NOTE(review): each trial runs the full `args` schedule (max_steps=20000), and
# the search may set trial hyperparameters on the TrainingArguments object it
# is given (shared with `trainer`). Consider a smaller schedule or a copy.
hp_search_trainer = Trainer(
    model_init=lambda: AutoModelForSeq2SeqLM.from_pretrained(model_name),
    args=args,
    train_dataset=dataset_separated_train,
    eval_dataset=dataset_separated_val,
    tokenizer=tokenizer,
    compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer),
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)],
)

best_trial = hp_search_trainer.hyperparameter_search(
    direction="maximize",  # Optuna only accepts "maximize" or "minimize"
    backend="optuna",
    n_trials=10,
    hp_space=lambda trial: {
        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 0, 0.3),
    },
    # Optimise the model-selection metric itself rather than the default
    # objective (the sum of all reported evaluation metrics).
    compute_objective=lambda metrics: metrics["eval_canonical_top1"],
)

best_trial