In [None]:
import os, json, random, pandas as pd, torch
from transformers import EarlyStoppingCallback

from datasets import Dataset
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM,
                          DataCollatorForSeq2Seq, Seq2SeqTrainingArguments,
                          Seq2SeqTrainer)
from datasets import load_dataset

MODEL_NAME = "sagawa/ReactionT5v2-retrosynthesis"
TRAIN_CSV = "data_train.csv"
TEST_CSV  = "product_smiles_test.csv"
SUBMISSION = "submission.csv"
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Each training line is "<product SMILES> >> <reactant SMILES>". The separator
# is a regex; a raw string keeps "\s" from being read as an (invalid) string
# escape, which modern Python warns about.
df = pd.read_csv(TRAIN_CSV, sep=r"\s*>>\s*", engine="python", header=None,
                 names=["product", "reactant"])

# A line without a ">>" separator (or with an empty side) leaves a NaN field,
# which would make tokenization fail later. Drop such rows explicitly and say so.
n_malformed = int(df.isna().any(axis=1).sum())
if n_malformed:
    print(f"Dropping {n_malformed:,} malformed rows with a missing field")
    df = df.dropna().reset_index(drop=True)

print(f"Loaded {len(df):,} training examples")
# Only a cell's last expression is displayed, so print the preview explicitly.
print(df.head())

# One tokenizer instance is shared by preprocessing, the metric and inference.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Token budgets for the product (encoder input) and reactant (decoder label) sides.
# NOTE(review): anything longer is silently truncated. Verify that 32 tokens covers
# the SMILES lengths in the training data; truncated labels teach incomplete reactants.
max_input, max_label = 32, 32

def preprocess(examples):
    """Tokenize a batch of product -> reactant pairs for seq2seq training.

    Args:
        examples: batch dict (as passed by ``Dataset.map(batched=True)``) with
            list-valued "product" and "reactant" columns of SMILES strings.

    Returns:
        The tokenizer output for the products, plus a "labels" entry holding
        the reactant token ids. Sources are truncated to ``max_input`` tokens
        and targets to ``max_label`` tokens.
    """
    model_inputs = tokenizer(examples["product"], max_length=max_input,
                             truncation=True)
    # ``text_target=`` replaces the deprecated ``as_target_tokenizer()`` context
    # manager. Targets are tokenized in their own call so they keep their own
    # length cap (max_label) independent of the source cap.
    labels = tokenizer(text_target=examples["reactant"], max_length=max_label,
                       truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


raw_ds = Dataset.from_pandas(df)

# Tokenize every row (dropping the raw SMILES columns), then hold out 10% for
# evaluation with a fixed seed so the split is reproducible across runs.
tokenised_ds = (
    raw_ds
    .map(preprocess, batched=True, batch_size=32,
         remove_columns=raw_ds.column_names)
    .train_test_split(test_size=0.1, seed=SEED)
)
tokenised_ds


In [None]:
from rdkit import Chem

def process_chemicals(chemicals):
    """Canonicalize SMILES strings with RDKit, skipping ones that fail to parse.

    Args:
        chemicals: iterable of SMILES strings.

    Returns:
        List of canonical SMILES for the inputs RDKit could parse, in input
        order. Unparseable inputs are omitted, so the result may be shorter
        than the input; callers that care must compare lengths.
    """
    processed = []
    for chem in chemicals:
        try:
            mol = Chem.MolFromSmiles(chem)
            # RDKit reports a SMILES parse failure by returning None.
            if mol is not None:
                processed.append(Chem.MolToSmiles(mol))
        except Exception:
            # Best-effort: treat any RDKit error as "unparseable". Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
            # still stop the notebook.
            continue
    return processed

def _strip_ignore_index(token_ids):
    """Drop the -100 "ignore" id used as padding before decoding a sequence."""
    return [t for t in token_ids if t != -100]


def _canonical_fragment_set(smiles):
    """Return the set of canonical '.'-separated fragments of ``smiles``.

    Returns None if any fragment fails to canonicalize, so a partially invalid
    string can never compare equal to a fully valid one.
    """
    fragments = smiles.split(".")
    canonical = process_chemicals(fragments)
    if len(canonical) != len(fragments):
        return None
    return frozenset(canonical)


def compute_top1_metric(eval_preds):
    """Top-1 exact-match accuracy on canonicalized reactant sets.

    A prediction counts as correct when its set of canonical fragments equals
    the reference's set (fragment order and multiplicity are ignored). Any
    prediction fragment that RDKit cannot parse makes the prediction wrong. If
    the reference itself cannot be canonicalized, an exact string match is
    required instead. Examples whose decoded reference is empty are skipped.

    Args:
        eval_preds: ``(predictions, label_ids)`` as passed by the Trainer;
            ``predictions`` may itself be a tuple whose first item is the
            generated ids.

    Returns:
        ``{"top1_accuracy": float}``; 0.0 when no example could be scored.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # Generated ids from different eval batches can be padded with -100, which
    # the tokenizer cannot decode. Labels are padded the same way, so strip it
    # from both.
    decoded_preds = tokenizer.batch_decode(
        [_strip_ignore_index(seq) for seq in preds], skip_special_tokens=True
    )
    decoded_labels = [
        tokenizer.decode(_strip_ignore_index(label), skip_special_tokens=True)
        for label in labels
    ]

    correct = 0
    total = 0
    for pred, label in zip(decoded_preds, decoded_labels):
        pred = pred.strip()
        label = label.strip()
        if not label:
            continue  # nothing to score against

        label_set = _canonical_fragment_set(label)
        if label_set is None:
            is_correct = pred == label
        else:
            is_correct = _canonical_fragment_set(pred) == label_set
        if is_correct:
            correct += 1
        total += 1

    return {"top1_accuracy": correct / total if total > 0 else 0.0}


In [None]:
# Start from a previous fine-tuning checkpoint when one exists, otherwise from
# the published base model. Loading only one of them avoids instantiating the
# base weights just to discard them, and lets the notebook run on a fresh
# machine where the checkpoint directory does not exist.
RESUME_CHECKPOINT = "reactiont5-finetuned/checkpoint-13500"
init_weights = RESUME_CHECKPOINT if os.path.isdir(RESUME_CHECKPOINT) else MODEL_NAME
print("Initialising model from →", init_weights)
model = AutoModelForSeq2SeqLM.from_pretrained(init_weights)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Reported for information only: `device` is not passed to the Trainer below,
# which handles device placement itself.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print("Training device →", device)

# Mixed-precision switches; both off by default. Enable one on a GPU that supports it.
fp16 = False
bf16 = False

args = Seq2SeqTrainingArguments(
    output_dir="reactiont5-finetuned",
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    learning_rate=5e-4,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch = 4 * 4 per device
    predict_with_generate=True,
    # Evaluate with the same generation length budget as the labels and as the
    # inference cell. Otherwise the model's default generation max_length is used,
    # which can cut predictions short and under-report top1_accuracy.
    generation_max_length=max_label,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="top1_accuracy",
    greater_is_better=True,
    save_total_limit=2,
    label_smoothing_factor=0.1,
    lr_scheduler_type="cosine",
    weight_decay=0.01,
    fp16=fp16,
    bf16=bf16,
)

# Pads inputs and labels per batch (label padding uses -100, which the metric strips).
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Stop training once top1_accuracy has not improved for 3 evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenised_ds["train"],
    eval_dataset=tokenised_ds["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_top1_metric,
    callbacks=[early_stopping],
)

trainer.train()


In [None]:
BATCH_SIZE = 16
NUM_BEAMS = 10  # beam width for generation; only the top beam is kept

# === Inference ===
# Tokenize batch by batch so each batch is padded only to its own longest
# product, rather than tokenizing (and moving to the device) the whole test set.
test_df = pd.read_csv(TEST_CSV, header=None, names=["product"])
test_products = test_df["product"].tolist()

all_preds = []
model.eval()

for start in range(0, len(test_products), BATCH_SIZE):
    batch = test_products[start:start + BATCH_SIZE]
    test_inputs = tokenizer(batch, padding=True, truncation=True,
                            max_length=max_input, return_tensors="pt").to(model.device)

    with torch.no_grad():
        gen = model.generate(input_ids=test_inputs["input_ids"],
                             attention_mask=test_inputs["attention_mask"],
                             max_length=max_label, num_beams=NUM_BEAMS)

    pred_smiles = tokenizer.batch_decode(gen, skip_special_tokens=True)
    all_preds.extend(p.strip() for p in pred_smiles)

# The submission must contain exactly one prediction per test product, in order.
assert len(all_preds) == len(test_df), (len(all_preds), len(test_df))
pd.Series(all_preds).to_csv(SUBMISSION, index=False, header=False)
print(f"Wrote {SUBMISSION}")
