In [None]:
import torch

def format_pytorch_version(version):
  """Return the release part of a PyTorch version string.

  Everything from the first ``+`` onward (for example the ``+cu118`` in
  ``2.1.0+cu118``) is dropped; a string without ``+`` is returned unchanged.
  """
  release, _, _ = version.partition('+')
  return release

# Installed PyTorch version with any '+<build>' suffix removed; interpolated
# into the wheel-index URLs of the pip commands in this cell.
TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  """Build the compute-platform tag used in the wheel-index URLs.

  Args:
    version: CUDA version string such as ``'11.8'`` (the value of
      ``torch.version.cuda``), or ``None``/empty for a CPU-only build.

  Returns:
    ``'cu'`` followed by the version with the dots removed
    (``'11.8'`` -> ``'cu118'``), or ``'cpu'`` when no CUDA version is given.
  """
  # torch.version.cuda is None on CPU-only PyTorch builds; calling .replace()
  # on it would raise AttributeError, so map that case to the CPU wheel tag.
  if not version:
    return 'cpu'
  return 'cu' + version.replace('.', '')

# Tag derived from the CUDA version this PyTorch build reports
# ('cu' + version digits, see format_cuda_version).
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# Install the PyTorch Geometric extension packages from the wheel index that
# matches the local torch/CUDA build. {TORCH} and {CUDA} are interpolated by
# IPython from the Python variables defined in this cell.
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

# NOTE(review): a later cell also runs `pip install rdkit`; presumably only one
# of the two RDKit distributions is needed — confirm and drop the other.
!pip install rdkit-pypi

In [None]:
# Extra dependencies: RDKit, selfies, and Hugging Face transformers/datasets.
# NOTE(review): versions are unpinned, and `selfies` is not referenced in the
# visible cells — confirm it is still required.
!pip install rdkit
!pip install selfies
!pip install transformers datasets

In [None]:
# Download SA Score script (by Ertl)
# Fetches sascorer.py and fpscores.pkl.gz from the RDKit repository's
# Contrib/SA_Score directory. In the visible cells `sascorer` is only imported
# by commented-out code.
!wget https://raw.githubusercontent.com/rdkit/rdkit/master/Contrib/SA_Score/sascorer.py
!wget https://github.com/rdkit/rdkit/raw/master/Contrib/SA_Score/fpscores.pkl.gz

In [None]:
import random, os, numpy as np, pandas as pd

def seed_everything(seed=42):
  """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

  Also stores the seed (as a string) in the ``PYTHONHASHSEED`` environment
  variable, enables ``cudnn.deterministic`` and disables ``cudnn.benchmark``.

  Args:
    seed: Integer seed shared by every generator.
  """
  os.environ['PYTHONHASHSEED'] = str(seed)

  # Same seed for each library's global generator.
  for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False

In [None]:
from rdkit import RDLogger
# Silence every RDKit log channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')

In [None]:
# Load the reaction table from the working directory. Later cells read its
# 'reactions' column (reaction SMILES) and its 'class' column.
df = pd.read_csv('USPTO_50K.csv')
df

In [None]:
# import pandas as pd
# from rdkit import Chem
# from rdkit.Chem import Descriptors
# import sascorer
# import numpy as np
# from tqdm.notebook import tqdm  # or tqdm.auto for non-Colab use

# # Enable tqdm with pandas apply
# tqdm.pandas()

# # If not already loaded:
# # df = pd.read_csv("your_uspto_50k.csv")  # make sure it has a 'reactions' column

# def compute_difficulty_metrics(reaction_smiles):
#     try:
#         product_smiles = reaction_smiles.split(">")[-1]
#         mol = Chem.MolFromSmiles(product_smiles)
#         if mol is None:
#             return np.nan, np.nan, np.nan, np.nan

#         sa_score = sascorer.calculateScore(mol)
#         ring_count = mol.GetRingInfo().NumRings()
#         heavy_atom_count = Descriptors.HeavyAtomCount(mol)

#         difficulty_score = (
#             0.5 * sa_score +
#             0.3 * ring_count +
#             0.2 * heavy_atom_count
#         )

#         return sa_score, ring_count, heavy_atom_count, difficulty_score

#     except Exception:
#         return np.nan, np.nan, np.nan, np.nan

# # Apply with progress bar
# df[['sa_score', 'ring_count', 'heavy_atoms', 'difficulty']] = df['reactions'].progress_apply(
#     compute_difficulty_metrics
# ).apply(pd.Series)

# # Remove problematic rows and sort
# df_sorted = df.dropna(subset=['difficulty']).sort_values('difficulty').reset_index(drop=True)

In [None]:
# import matplotlib.pyplot as plt
# import seaborn as sns

# plt.figure(figsize=(8, 5))
# sns.histplot(df_sorted['difficulty'], bins=40, kde=True)
# plt.title("Distribution of Retrosynthesis Difficulty Scores")
# plt.xlabel("Difficulty Score")
# plt.ylabel("Frequency")
# plt.grid(True)
# plt.show()

In [None]:
from rdkit import Chem

def canonicalize_smiles(smi):
    """Canonicalize a dot-separated (possibly multi-fragment) SMILES string.

    Each ``.``-separated fragment is parsed by RDKit and written back as
    canonical SMILES; the fragments are then sorted and re-joined with ``.``
    so the result does not depend on the input fragment order.

    Returns the input string unchanged if RDKit cannot parse any fragment.
    """
    fragments = []
    for piece in smi.split('.'):
        fragments.append(Chem.MolFromSmiles(piece))

    # Best-effort fallback: a single unparsable fragment leaves the text as-is.
    for mol in fragments:
        if mol is None:
            return smi

    canonical = [Chem.MolToSmiles(mol, canonical=True) for mol in fragments]
    canonical.sort()  # fixed fragment order
    return '.'.join(canonical)

# Split product/reactants from reaction SMILES
# 'product' is the text after the last '>' and 'reactants' the text before the
# first '>'; anything between the two is dropped. Both columns are added to
# `df` in place.
df['product'] = df['reactions'].apply(lambda x: x.split('>')[-1])
df['reactants'] = df['reactions'].apply(lambda x: x.split('>')[0])

# Canonicalize both sides
# (canonicalize_smiles returns a string unchanged when RDKit cannot parse it,
# so unparsable rows stay in the frame un-canonicalized.)
df['product'] = df['product'].apply(canonicalize_smiles)
df['reactants'] = df['reactants'].apply(canonicalize_smiles)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the rows for testing, then 10% of the remainder for
# validation; both splits use a fixed random_state so they are reproducible.
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)  # 81/9/10

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset

# Tokenizer published alongside the "sagawa/ReactionT5v2-retrosynthesis" model;
# it is used for both the encoder input and the decoder target below.
t5_tokenizer = AutoTokenizer.from_pretrained("sagawa/ReactionT5v2-retrosynthesis")

# Add pad token if missing
# NOTE(review): if this branch ever runs, the vocabulary grows by one token and
# the model's embedding matrix presumably has to be resized to match — confirm.
if t5_tokenizer.pad_token is None:
    t5_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

def tokenize_function(example):
    """Tokenize one product/reactants pair for the seq2seq model.

    Args:
        example: Mapping with string fields ``'product'`` and ``'reactants'``.

    Returns:
        The tokenizer output for ``"retro: " + example['product']`` with an
        additional ``'labels'`` entry holding the token ids of
        ``example['reactants']``. Both sides go through ``t5_tokenizer`` and
        are padded/truncated to 128 tokens.
    """
    # Shared settings: fixed-length sequences of 128 tokens on both sides.
    fixed_length = dict(padding="max_length", truncation=True, max_length=128)

    # Encoder input: product SMILES behind the "retro: " prefix.
    encoded = t5_tokenizer("retro: " + example['product'], **fixed_length)

    # Decoder target: reactants SMILES, tokenized with the same tokenizer.
    reactant_ids = t5_tokenizer(example['reactants'], **fixed_length)['input_ids']

    # Padded label positions keep the pad token id; no masking is applied here.
    encoded['labels'] = reactant_ids
    return encoded

# Wrap the validation and test frames as Hugging Face Datasets, keeping only
# the two text columns. The training split is commented out in this notebook.
# train_ds = Dataset.from_pandas(train_df[['product', 'reactants']])
val_ds = Dataset.from_pandas(val_df[['product', 'reactants']])
test_ds = Dataset.from_pandas(test_df[['product', 'reactants']])

# Tokenize one example at a time (batched=False), which adds the tokenizer
# fields and 'labels' produced by tokenize_function to every row.
# train_tokenized = train_ds.map(tokenize_function, batched=False)
val_tokenized = val_ds.map(tokenize_function, batched=False)
test_tokenized = test_ds.map(tokenize_function, batched=False)

In [None]:
from collections import defaultdict

# Maps reaction class -> list of tokenized test examples of that class.
test_by_class = defaultdict(list)

# Make sure this matches your tokenized dataset
# (only the lengths are checked; the loop below additionally relies on
# test_tokenized having the same row order as test_df — presumably true since
# test_ds was built from test_df, but verify if either is ever shuffled).
assert len(test_tokenized) == len(test_df), "Mismatch between test_df and tokenized dataset"

for i, example in enumerate(test_tokenized):
    # Positional lookup: the i-th tokenized example pairs with the i-th row.
    cls = test_df.iloc[i]["class"]  # ✅ use test_df here
    test_by_class[cls].append(example)

In [None]:
from transformers import EncoderDecoderModel, Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os

# Set output path
output_dir = "CurriculumLearning/retrosyn-t5-curriculum"

# Collect the "checkpoint-<step>" entries of output_dir, ordered by step.
# A missing output directory just means there is nothing to resume from, and
# entries whose suffix is not a number (e.g. "checkpoint-best") are ignored
# rather than crashing the int() conversion.
checkpoints = []
if os.path.isdir(output_dir):
    checkpoints = sorted(
        (
            ckpt for ckpt in os.listdir(output_dir)
            if ckpt.startswith("checkpoint") and ckpt.split("-")[-1].isdecimal()
        ),
        key=lambda x: int(x.split("-")[-1]),
    )

# Highest-step checkpoint, or None when no checkpoint exists.
resume_checkpoint = os.path.join(output_dir, checkpoints[-1]) if checkpoints else None

# Load exactly one set of weights: the latest checkpoint if there is one,
# otherwise the published pretrained model.
if resume_checkpoint:
    print("Resuming from checkpoint:", resume_checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(resume_checkpoint)
else:
    print("No checkpoints found. Starting from scratch.")
    model = AutoModelForSeq2SeqLM.from_pretrained("sagawa/ReactionT5v2-retrosynthesis")

# Applied after the final load so the setting also reaches a model restored
# from a checkpoint (setting it before the reload would be discarded).
model.config.pad_token_id = t5_tokenizer.pad_token_id

# Make sure model is on the right device
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
import torch
from tqdm import tqdm
from rdkit import Chem
# NOTE(review): `datasets.load_metric` is deprecated in favour of the separate
# `evaluate` package and is absent from newer `datasets` releases; because the
# install cell does not pin a version, this import may fail — confirm the
# installed version or migrate to `evaluate.load("bleu")`.
from datasets import load_metric

# BLEU metric object used by evaluate_model_manual.
bleu = load_metric("bleu")

def compute_validity(smiles_list):
    """Return one boolean per SMILES string: True where RDKit can parse it.

    NOTE: a later cell redefines ``compute_validity`` to return a fraction
    instead of a list; whichever cell ran last wins.
    """
    flags = []
    for smi in smiles_list:
        flags.append(Chem.MolFromSmiles(smi) is not None)
    return flags


def evaluate_model_manual(model, tokenizer, dataset, batch_size=4, max_length=128, num_beams=5, top_k=5):
    """Generate with beam search and report exact-match, validity and BLEU.

    Args:
        model: Seq2seq model providing ``eval()``, ``generate()``, ``device``
            and ``config.pad_token_id``.
        tokenizer: Tokenizer providing ``batch_decode``.
        dataset: Tokenized dataset; slicing it must yield a dict with
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` lists.
        batch_size: Number of examples per generation call.
        max_length: ``max_length`` passed to ``generate``.
        num_beams: Beam width passed to ``generate``.
        top_k: Sequences returned per input (``num_return_sequences``).
            NOTE(review): presumably must not exceed ``num_beams`` — confirm
            against the ``generate`` documentation.

    Returns:
        Dict with ``'top1_exact_match'``, ``'topk_exact_match'`` and
        ``'validity'`` (fractions of evaluated examples) and ``'bleu'``
        (character-level BLEU of the top-1 predictions).

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    # Without examples every ratio below would divide by zero.
    if len(dataset) == 0:
        raise ValueError("Cannot evaluate an empty dataset.")

    model.eval()
    total = 0
    top1_exact = 0
    topk_exact = 0
    valid_count = 0
    bleu_preds = []
    bleu_refs = []

    pbar = tqdm(range(0, len(dataset), batch_size), desc="Evaluating")

    for i in pbar:
        batch = dataset[i: i + batch_size]

        input_ids = torch.tensor(batch["input_ids"]).to(model.device)
        attention_mask = torch.tensor(batch["attention_mask"]).to(model.device)
        labels = batch["labels"]

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=top_k,
                early_stopping=True,
                pad_token_id=model.config.pad_token_id
            )

        # generate() returns batch * top_k rows; regroup them per input example.
        batch_size_actual = input_ids.size(0)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        outputs_grouped = outputs.reshape(batch_size_actual, top_k, -1)

        for preds, label in zip(outputs_grouped, decoded_labels):
            # Strip all whitespace so comparisons are on the bare SMILES text.
            label = "".join(label.split())
            total += 1

            topk_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
            topk_preds = ["".join(p.split()) for p in topk_preds]
            pred1 = topk_preds[0]

            # Top-1 match
            if pred1 == label:
                top1_exact += 1

            # Valid SMILES check (only on top-1)
            if Chem.MolFromSmiles(pred1) is not None:
                valid_count += 1

            # BLEU (top-1 only). Both strings have had all whitespace removed,
            # so str.split() would produce a single token per sequence and no
            # n-gram beyond unigrams could ever match; score at character
            # level instead.
            bleu_preds.append(list(pred1))
            bleu_refs.append([list(label)])

            # Top-k match
            if label in topk_preds:
                topk_exact += 1

        # Running accuracy in the progress bar.
        pbar.set_postfix({
            "top1": f"{top1_exact / total:.3f}",
            "topk": f"{topk_exact / total:.3f}",
            "valid": f"{valid_count / total:.3f}"
        })

    bleu_score = bleu.compute(predictions=bleu_preds, references=bleu_refs)["bleu"]

    return {
        "top1_exact_match": top1_exact / total,
        "topk_exact_match": topk_exact / total,
        "validity": valid_count / total,
        "bleu": bleu_score
    }

In [None]:
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Trainable parameter count of the loaded model (displayed as the cell output).
count_parameters(model)

In [None]:
import torch
from tqdm import tqdm
import pandas as pd
from rdkit import Chem

def compute_validity(smiles_list):
    """Return the fraction of SMILES strings that RDKit can parse.

    NOTE: this shadows the earlier ``compute_validity`` in this notebook, which
    returns a list of booleans rather than a fraction.

    Args:
        smiles_list: Iterable of SMILES strings.

    Returns:
        Share of entries for which ``Chem.MolFromSmiles`` does not return
        ``None``, as a float in ``[0, 1]``; ``0.0`` for an empty input
        (instead of raising ``ZeroDivisionError``).
    """
    # Materialize once so generators work and the length is known.
    smiles_list = list(smiles_list)
    if not smiles_list:
        return 0.0
    n_valid = sum(Chem.MolFromSmiles(smi) is not None for smi in smiles_list)
    return n_valid / len(smiles_list)

def evaluate_model_predictions(model, tokenizer, dataset, df, batch_size=32, max_length=128, num_beams=5, prefix="baseline"):
    """Generate five beam-search predictions per example and attach them to ``df``.

    Args:
        model: Seq2seq model providing ``eval()``, ``generate()`` and ``device``.
        tokenizer: Tokenizer providing ``batch_decode`` and ``pad_token_id``.
        dataset: Tokenized dataset; slicing it must yield a dict with
            ``'input_ids'`` and ``'attention_mask'`` lists. Rows are assumed to
            be in the same order as ``df`` — verify at the call site.
        df: DataFrame with one row per dataset example.
        batch_size: Number of examples per generation call.
        max_length: ``max_length`` passed to ``generate``.
        num_beams: Beam width passed to ``generate``.
        prefix: Prefix of the added column names.

    Returns:
        A copy of ``df`` with four extra columns: ``{prefix}_top1_preds``,
        ``{prefix}_top5_preds``, ``{prefix}_top1_validity`` and
        ``{prefix}_top5_validity``.

    Raises:
        ValueError: If ``dataset`` and ``df`` differ in length.
    """
    # Fail fast: the prediction lists are assigned to `df` as columns, so a
    # length mismatch would otherwise only surface after the whole (slow)
    # generation loop has finished.
    if len(dataset) != len(df):
        raise ValueError(
            f"dataset has {len(dataset)} examples but df has {len(df)} rows"
        )

    n_return = 5  # sequences kept per input; fixed by the "top5" column names

    model.eval()
    top1_preds, top5_preds = [], []

    for start in tqdm(range(0, len(dataset), batch_size), desc=f"Generating {prefix} predictions"):
        batch = dataset[start: start + batch_size]
        input_ids = torch.tensor(batch["input_ids"]).to(model.device)
        attention_mask = torch.tensor(batch["attention_mask"]).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=n_return,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id
            )

        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # Only spaces are stripped; predictions are NOT canonicalized here.
        decoded = [s.strip().replace(' ', '') for s in decoded]
        # Consecutive runs of n_return decoded strings belong to one input.
        grouped = [decoded[j:j + n_return] for j in range(0, len(decoded), n_return)]
        top1_preds.extend([g[0] for g in grouped])
        top5_preds.extend(grouped)

    df = df.copy()  # leave the caller's frame untouched
    df[f"{prefix}_top1_preds"] = top1_preds
    df[f"{prefix}_top5_preds"] = top5_preds
    df[f"{prefix}_top1_validity"] = [compute_validity([s]) for s in top1_preds]
    df[f"{prefix}_top5_validity"] = [compute_validity(g) for g in top5_preds]

    return df

In [None]:
from tqdm.auto import tqdm

In [None]:
# Generate predictions for the test split. `prefix` is left at its default
# ("baseline"), so the added columns are named baseline_top1_preds,
# baseline_top5_preds, baseline_top1_validity and baseline_top5_validity.
# test_df is rebound to the returned copy that carries those columns.
test_df = evaluate_model_predictions(
    model=model,
    tokenizer=t5_tokenizer,
    dataset=test_tokenized,
    df=test_df,
    batch_size=64,
    max_length=128,
    num_beams=5
)

In [None]:
import numpy as np
from sklearn.utils import resample

# Per-row flag: True where the top-1 prediction string equals the reference
# reactants string exactly.
# NOTE(review): 'reactants' was passed through canonicalize_smiles, while the
# predictions only had spaces removed; a chemically identical prediction that
# is written differently would count as wrong here — confirm this is intended.
baseline_correct = test_df["baseline_top1_preds"] == test_df["reactants"]

# Convert to numpy
baseline_correct = np.array(baseline_correct, dtype=int)

# Compute mean and bootstrap std
def bootstrap_mean_std(data, n_samples=1000, seed=42):
    """Bootstrap the mean of ``data``.

    Each replicate draws ``len(data)`` entries with replacement and records
    their mean.

    Args:
        data: Array-like of numeric values (e.g. 0/1 correctness flags).
        n_samples: Number of bootstrap replicates.
        seed: Seed of the private random generator, so repeated calls give
            identical results; pass ``None`` for a fresh random state.

    Returns:
        Tuple ``(mean, std)`` of the replicate means.

    Raises:
        ValueError: If ``data`` is empty.
    """
    values = np.asarray(data)
    n = len(values)
    if n == 0:
        raise ValueError("Cannot bootstrap an empty sample.")

    # Private generator: reproducible without touching NumPy's global state.
    rng = np.random.RandomState(seed)
    means = [values[rng.randint(0, n, n)].mean() for _ in range(n_samples)]
    return np.mean(means), np.std(means)

# Bootstrap the per-example correctness flags and report mean ± std of the
# resampled accuracies.
baseline_mean, baseline_std = bootstrap_mean_std(baseline_correct)

print(f"Baseline Top-1 Accuracy = {baseline_mean:.3f} ± {baseline_std:.3f}")

In [None]:
from collections import defaultdict
import numpy as np

# Define metrics
def top1_accuracy(y_true, y_pred):
    """Fraction of references matched exactly by the paired prediction.

    Pairs are formed with ``zip`` (stopping at the shorter input) and the
    number of equal pairs is divided by ``len(y_true)``.
    """
    hits = 0
    for truth, guess in zip(y_true, y_pred):
        hits += (guess == truth)
    return hits / len(y_true)

def top5_accuracy(y_true, y_pred_top5):
    """Fraction of references contained in their paired candidate list.

    Pairs are formed with ``zip`` (stopping at the shorter input) and the
    number of hits is divided by ``len(y_true)``.
    """
    hits = 0
    for truth, candidates in zip(y_true, y_pred_top5):
        if truth in candidates:
            hits += 1
    return hits / len(y_true)

def validity(y_pred):
    """Fraction of SMILES strings in ``y_pred`` that RDKit parses successfully."""
    from rdkit import Chem

    parsed = 0
    for smi in y_pred:
        if Chem.MolFromSmiles(smi) is not None:
            parsed += 1
    return parsed / len(y_pred)

def validity_wrapper(_, y_pred):
    """Two-argument form of ``validity``; the first argument is ignored.

    Gives the validity metric the same ``(y_true, y_pred)`` call shape as the
    accuracy metrics defined in this cell.
    """
    return validity(y_pred)

# Bootstrapping utility
def bootstrap_metric(y_true, y_pred, metric_fn, n_samples=1000, seed=42):
    """Bootstrap a paired metric over ``(y_true, y_pred)``.

    Each replicate draws ``len(data)`` pairs with replacement and evaluates
    ``metric_fn(sample_true, sample_pred)`` on them.

    Args:
        y_true: Sequence of references.
        y_pred: Sequence of predictions, paired with ``y_true`` via ``zip``.
        metric_fn: Callable taking ``(y_true, y_pred)`` tuples and returning a
            number.
        n_samples: Number of bootstrap replicates.
        seed: Seed of the private random generator.

    Returns:
        Tuple ``(mean, std)`` of the replicate scores.

    Raises:
        ValueError: If there are no pairs to resample.
    """
    pairs = list(zip(y_true, y_pred))
    if not pairs:
        raise ValueError("Cannot bootstrap an empty sample.")

    # Private generator: yields the same draws as np.random.seed(seed) followed
    # by np.random.randint, but leaves NumPy's global RNG state untouched for
    # the rest of the notebook.
    rng = np.random.RandomState(seed)
    n = len(pairs)

    scores = []
    for _ in range(n_samples):
        idx = rng.randint(0, n, n)
        sample_true, sample_pred = zip(*(pairs[i] for i in idx))
        scores.append(metric_fn(sample_true, sample_pred))
    return np.mean(scores), np.std(scores)

In [None]:
from collections import defaultdict

# Per-class bootstrap results: reaction class -> {"Top-1": (mean, std),
# "Top-5": (mean, std)}.
results = defaultdict(dict)

for reaction_class in sorted(test_df["class"].unique()):
    class_df = test_df[test_df["class"] == reaction_class]
    # Skip classes with fewer than 10 test rows.
    if len(class_df) < 10:
        continue

    y_true = class_df["reactants"].tolist()
    y_top1 = class_df["baseline_top1_preds"].tolist()
    y_top5 = class_df["baseline_top5_preds"].tolist()

    # Both calls use bootstrap_metric's default seed, so they resample the
    # same row indices.
    acc1_mean, acc1_std = bootstrap_metric(y_true, y_top1, top1_accuracy)
    acc5_mean, acc5_std = bootstrap_metric(y_true, y_top5, top5_accuracy)

    results[reaction_class] = {
        "Top-1": (acc1_mean, acc1_std),
        "Top-5": (acc5_mean, acc5_std),
    }

    print(f"Class {reaction_class}:")
    print(f"  Top-1 Accuracy = {acc1_mean:.3f} ± {acc1_std:.3f}")
    print(f"  Top-5 Accuracy = {acc5_mean:.3f} ± {acc5_std:.3f}")

In [None]:
# Second pass over the test split: runs generation again and returns aggregate
# top-1 / top-k exact match, validity and BLEU as a dict.
metrics = evaluate_model_manual(
    model=model,
    tokenizer=t5_tokenizer,
    dataset=test_tokenized,
    batch_size=64,
    num_beams=5,
    top_k=5
)

print(metrics)

In [None]:
# from datasets import Dataset

# results_by_class = {}

# for cls, examples in test_by_class.items():
#     cls_dataset = Dataset.from_list(examples)
#     metrics = evaluate_model_manual(
#         model=model,
#         tokenizer=t5_tokenizer,
#         dataset=cls_dataset,
#         batch_size=64,
#         num_beams=5,
#         top_k=5
#     )
#     results_by_class[cls] = metrics

#     # 📢 Print metrics for this class
#     print(f"\n🧪 Reaction Class {cls}")
#     for key, value in metrics.items():
#         print(f"  {key:<20s}: {value:.4f}")