In [None]:
from datasets import load_dataset

# Load the "normal" configuration of the Thecoder3281f/MIT_separated dataset.
# Later cells read its "train", "val" and "test" splits and its "input" /
# "target" text columns.
dataset = load_dataset("Thecoder3281f/MIT_separated", "normal")

In [None]:
# Display the loaded dataset object for a quick visual sanity check.
dataset

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Single checkpoint name shared by the tokenizer and the model so the two
# always come from the same pretrained checkpoint.
model_name = "t5-large"

# Load the pretrained tokenizer and sequence-to-sequence model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Pad and truncate at the end of a sequence, so the beginning of every
# input/target string is what gets kept when it exceeds the length limit.
tokenizer.padding_side = "right"
tokenizer.truncation_side = "right"

In [None]:
def preprocess(batch):
    """Tokenize a batch of source/target strings for seq2seq training.

    Args:
        batch: Mapping with an "input" list (source strings) and a "target"
            list (target strings), as supplied by ``dataset.map(batched=True)``.

    Returns:
        The tokenizer output for the inputs, with an added "labels" entry
        holding the token IDs of the targets. Inputs and targets are both
        padded/truncated to exactly 256 tokens.
    """
    model_inputs = tokenizer(
        batch["input"],
        padding="max_length",
        truncation=True,
        max_length=256,
    )
    # `text_target=` is the supported replacement for the deprecated
    # `with tokenizer.as_target_tokenizer():` context manager.
    labels = tokenizer(
        text_target=batch["target"],
        padding="max_length",
        truncation=True,
        max_length=256,
    )
    # Label padding is kept as the tokenizer's pad token ID (it is not replaced
    # by -100), which lets the labels be decoded directly at evaluation time.
    # NOTE(review): this presumably also means padded positions contribute to
    # the training loss - confirm that this is intended.
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize every split in batches and drop the raw "input"/"target" text
# columns, leaving the tokenizer outputs plus "labels".
tokenized_datasets = dataset.map(preprocess, batched=True, remove_columns=["input", "target"])
tokenized_datasets

In [None]:
tokenized_datasets["train"][0]["attention_mask"]  # Attention mask of the first tokenized training example

In [None]:
import torch
# Report which accelerator (if any) PyTorch can see. Availability is queried
# once and reused so the device queries only run when CUDA is present.
has_cuda = torch.cuda.is_available()
print("CUDA available:", has_cuda)
print("Device count:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device() if has_cuda else "CPU")
print("Device name:", torch.cuda.get_device_name(0) if has_cuda else "None")


In [None]:
import logging
import numpy as np

# Module-level logger used by compute_metrics to report evaluation metrics.
# Only the level is set here; no handler is attached in this cell, so whether
# INFO records are actually displayed depends on logging configuration done
# elsewhere.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs

def canonicalize(smiles):
    """Return the canonical SMILES for a (possibly space-separated) SMILES string.

    Spaces are stripped out before parsing. Returns None when the string
    cannot be parsed into a molecule or when any error occurs on the way
    (including a non-string argument).
    """
    try:
        compact = smiles.replace(" ", "")
        molecule = Chem.MolFromSmiles(compact)
        return None if molecule is None else Chem.MolToSmiles(molecule, canonical=True)
    except Exception:
        # Best effort: any failure is reported as "not canonicalizable".
        return None

def tanimoto(a, b):
    """Tanimoto similarity between the Morgan fingerprints of two SMILES strings.

    Both strings are parsed first; if either molecule is missing, or anything
    raises along the way, the result is 0. Fingerprints are built with
    ``AllChem.GetMorganFingerprintAsBitVect(mol, 2)``.
    """
    try:
        mol_a = Chem.MolFromSmiles(a)
        mol_b = Chem.MolFromSmiles(b)
        if mol_a and mol_b:
            fingerprints = [
                AllChem.GetMorganFingerprintAsBitVect(mol, 2)
                for mol in (mol_a, mol_b)
            ]
            return DataStructs.TanimotoSimilarity(*fingerprints)
        return 0
    except Exception:
        # Best effort: treat any failure as "no similarity".
        return 0

def _to_token_ids(preds, pad_token_id):
    """Normalise raw predictions into an array of decodable token IDs."""
    # Some callers hand over a tuple whose first element is the predictions.
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    # Only a floating-point 3-D array can be (batch, seq_len, vocab_size)
    # logits. An integer 3-D array is already a (batch, k, seq_len) block of
    # top-k token IDs and must not be arg-maxed away.
    if preds.ndim == 3 and np.issubdtype(preds.dtype, np.floating):
        preds = np.argmax(preds, axis=-1)
    # -100 marks ignored/padded positions and is not a decodable token ID.
    return np.where(preds == -100, pad_token_id, preds)


def compute_metrics(eval_pred, tokenizer):
    """Compute SMILES-level evaluation metrics from predictions and labels.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. Predictions may be a
            tuple (its first element is used), floating-point logits of shape
            (batch, seq_len, vocab_size), or integer token IDs of shape
            (batch, seq_len) for top-1 or (batch, k, seq_len) for top-k.
        tokenizer: Tokenizer providing ``decode`` / ``batch_decode`` and,
            optionally, ``pad_token_id`` (0 is assumed when it is missing).

    Returns:
        A dict with:
          * "canonical_top1": fraction of labels whose first candidate has the
            same canonical SMILES as the label.
          * "mean_tanimoto": mean over examples of the best Tanimoto
            similarity among the candidates examined (0 for invalid labels).
          * "validity": fraction of labels that canonicalize successfully.
        All three are 0.0 when there is nothing to evaluate.
    """
    preds, labels = eval_pred

    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0

    preds = _to_token_ids(preds, pad_id)
    labels = np.asarray(labels)
    labels = np.where(labels == -100, pad_id, labels)

    # Handle both top-1 and top-k outputs.
    if preds.ndim == 3:  # (batch, k, seq_len)
        batch_size, k, seq_len = preds.shape
        flat_decoded = tokenizer.batch_decode(
            preds.reshape(batch_size * k, seq_len), skip_special_tokens=True
        )
        decoded_preds = [flat_decoded[i * k:(i + 1) * k] for i in range(batch_size)]
    else:  # (batch, seq_len)
        decoded_preds = [[tokenizer.decode(p, skip_special_tokens=True)] for p in preds]

    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    top1_correct = 0
    tanimotos = []
    valids = []

    for k_preds, label in zip(decoded_preds, decoded_labels):
        label_c = canonicalize(label)

        # An unparseable label can never be matched: score it as zero.
        if label_c is None:
            tanimotos.append(0)
            valids.append(0)
            continue
        # NOTE(review): this measures validity of the *label*, not of the
        # prediction - confirm that this is the intended metric.
        valids.append(bool(Chem.MolFromSmiles(label_c)))
        best_tani = 0

        for i, pred in enumerate(k_preds):
            p_c = canonicalize(pred)
            if p_c is None:
                continue

            best_tani = max(best_tani, tanimoto(p_c, label_c))

            if p_c == label_c:
                # Only an exact match in the first slot counts as top-1.
                if i == 0:
                    top1_correct += 1
                break

        tanimotos.append(best_tani)

    # Guard the denominators so an empty evaluation set yields zeros instead
    # of raising ZeroDivisionError.
    canonical_top1 = top1_correct / len(decoded_labels) if decoded_labels else 0.0
    mean_tanimoto = sum(tanimotos) / len(tanimotos) if tanimotos else 0.0
    validity = sum(valids) / len(valids) if valids else 0.0

    logger.info(f"Canonical Top-1 Accuracy: {canonical_top1:.3f}")
    logger.info(f"Mean Tanimoto: {mean_tanimoto:.3f}")
    logger.info(f"Validity: {validity:.3f}")

    return {
        "canonical_top1": canonical_top1,
        "mean_tanimoto": mean_tanimoto,
        "validity": validity,
    }


In [None]:
# Fixed-seed (42) shuffled subsets for quick experiments: 40,000 training
# examples and 3,000 examples each for validation and test.
# NOTE(review): select(range(n)) assumes each split has at least n rows - confirm.
train_dataset_small = tokenized_datasets["train"].shuffle(seed=42).select(range(40000))
val_dataset_small = tokenized_datasets["val"].shuffle(seed=42).select(range(3000))
test_dataset_small = tokenized_datasets["test"].shuffle(seed=42).select(range(3000))

# The full, unshuffled tokenized splits under shorter names.
train_dataset = tokenized_datasets["train"]
val_dataset = tokenized_datasets["val"]
test_dataset = tokenized_datasets["test"]

In [None]:
# Display the full training split to check its columns and row count.
train_dataset

In [None]:
from transformers import Trainer, TrainingArguments, T5ForConditionalGeneration

In [None]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted token IDs before metric computation.

    Args:
        logits: Either the logits tensor itself or a tuple whose first element
            is the logits tensor.
        labels: Passed through unchanged.

    Returns:
        A ``(token_ids, labels)`` tuple, where ``token_ids`` is the argmax of
        the logits over their last dimension.
    """
    scores = logits[0] if isinstance(logits, tuple) else logits
    token_ids = scores.argmax(dim=-1)
    return token_ids, labels

# Training configuration for a short, step-based fine-tuning run.
args = TrainingArguments(
    output_dir="test-large",
    # Evaluate and checkpoint on the same step-based cadence (every 250 steps,
    # see save_steps / eval_steps below).
    eval_strategy="steps",
    save_strategy="steps",
    learning_rate=3e-4,
    # per_device_train_batch_size=64,
    # Train batch size is left to auto_find_batch_size instead of being fixed.
    auto_find_batch_size=True,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    # The run length is bounded by steps rather than epochs.
    max_steps=1000,
    save_steps=250,
    eval_steps=250,
    # num_train_epochs=1,
    logging_strategy="steps",
    logging_steps=100,
    report_to="tensorboard",
    weight_decay=0.05,
    logging_dir="./logs",
    # Best checkpoint = highest "canonical_top1", the key returned by
    # compute_metrics; that checkpoint is reloaded when training finishes.
    greater_is_better=True,
    metric_for_best_model="canonical_top1",
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    eval_accumulation_steps=128,
    # NOTE(review): T5 checkpoints are reported to overflow to inf/NaN under
    # fp16; watch the logged loss and consider bf16 if the hardware supports
    # it - confirm.
    fp16=True,
)

# Fine-tune on the 40k-example training subset and evaluate on the 3k-example
# validation subset.
# NOTE(review): predictions reach compute_metrics as the argmax of the logits
# from a forward pass (see preprocess_logits_for_metrics), not from
# model.generate(); the reported accuracy is presumably teacher-forced and may
# overstate free-running quality - confirm.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset_small,
    eval_dataset=val_dataset_small,
    tokenizer=tokenizer,
    # Trainer calls compute_metrics with a single argument, so the tokenizer is
    # bound here through a lambda.
    compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer),
    # Collapse logits to token IDs before they are accumulated for metrics.
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# Always start a fresh run rather than resuming from a checkpoint in output_dir.
trainer.train(resume_from_checkpoint=False)


# best_trial = trainer.hyperparameter_search(
#     direction="minimize",
#     backend="optuna",
#     n_trials=5,
#     hp_space=lambda trial: {
#         "learning_rate": trial.suggest_loguniform("learning_rate", 1e-6, 1e-4),
#         "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16]),
#         "weight_decay": trial.suggest_float("weight_decay", 0, 0.3),
#     },
# )

# best_trial

In [None]:
# Evaluate the trained model on the 3k-example test subset and show the
# resulting metrics dictionary.
metrics = trainer.evaluate(test_dataset_small)
print(metrics)


In [None]:
# Persist the trainer's current model to a local directory.
# NOTE(review): the directory name mentions "lr1e-5" and "5000steps", while the
# TrainingArguments in this notebook use learning_rate=3e-4 and max_steps=1000
# on "t5-large" - confirm the name still describes this run.
trainer.save_model("t5-mit-small-dataset-separated-lr1e-5-wd0.05-5000steps")