In [None]:
!pip install rdkit
!pip install transformers datasets

In [None]:
# Download SA Score script (by Ertl)
!wget https://raw.githubusercontent.com/rdkit/rdkit/master/Contrib/SA_Score/sascorer.py
!wget https://github.com/rdkit/rdkit/raw/master/Contrib/SA_Score/fpscores.pkl.gz

In [None]:
import random, os, numpy as np, pandas as pd

def seed_everything(seed=42):
  """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

  Args:
    seed: Integer seed applied to every generator touched here.
  """
  # Imported locally: the notebook only imports torch in a much later cell, so
  # a module-level reference would raise NameError if this is called right
  # after definition on a fresh kernel.
  import torch

  random.seed(seed)
  # Exported for child processes; the value is only written to os.environ here.
  os.environ['PYTHONHASHSEED'] = str(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  # Trade cuDNN autotuning for repeatable kernel selection.
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

In [None]:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

In [None]:
# Load the USPTO-50K reaction table from a Google Drive path.
# NOTE(review): assumes Drive is already mounted at /content/drive; no mount
# call is visible in this part of the notebook.
df = pd.read_csv('/content/drive/MyDrive/Molecular design/ReactionML/Low fidelity/USPTO/USPTO_50K.csv')
# Bare expression: renders the frame as the cell output.
df

In [None]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors
import sascorer
import numpy as np
from tqdm.notebook import tqdm  # or tqdm.auto for non-Colab use

# Enable tqdm with pandas apply
tqdm.pandas()

# If not already loaded:
# df = pd.read_csv("your_uspto_50k.csv")  # make sure it has a 'reactions' column

def compute_difficulty_metrics(reaction_smiles):
    """Score how hard the product of a reaction looks to synthesise.

    The product is taken as the text after the last '>' of ``reaction_smiles``.

    Returns:
        A 4-tuple ``(sa_score, ring_count, heavy_atom_count, difficulty_score)``
        where ``difficulty_score = 0.5*sa_score + 0.3*ring_count +
        0.2*heavy_atom_count``. If the product cannot be parsed, or anything
        raises along the way, all four entries are ``np.nan``.
    """
    failed = (np.nan,) * 4

    try:
        mol = Chem.MolFromSmiles(reaction_smiles.split(">")[-1])
        if mol is None:
            return failed

        sa_score = sascorer.calculateScore(mol)
        ring_count = mol.GetRingInfo().NumRings()
        heavy_atom_count = Descriptors.HeavyAtomCount(mol)

        # Fixed-weight linear blend of the three descriptors.
        difficulty_score = 0.5 * sa_score + 0.3 * ring_count + 0.2 * heavy_atom_count
    except Exception:
        # Best-effort: any failure marks the row as unusable rather than aborting.
        return failed

    return sa_score, ring_count, heavy_atom_count, difficulty_score

# Apply with progress bar
# Each row yields a 4-tuple; `.apply(pd.Series)` expands those tuples into four
# columns, which are stored under the metric column names listed on the left.
df[['sa_score', 'ring_count', 'heavy_atoms', 'difficulty']] = df['reactions'].progress_apply(
    compute_difficulty_metrics
).apply(pd.Series)

# Remove problematic rows and sort
# Rows whose difficulty is NaN are dropped; the remainder is ordered by
# ascending difficulty and re-indexed from 0. `df` is rebound in place, so
# re-running this cell operates on the already-filtered frame.
df = df.dropna(subset=['difficulty']).sort_values('difficulty').reset_index(drop=True)

In [None]:
# import matplotlib.pyplot as plt
# import seaborn as sns

# plt.figure(figsize=(8, 5))
# sns.histplot(df['difficulty'], bins=40, kde=True)
# plt.title("Distribution of Retrosynthesis Difficulty Scores")
# plt.xlabel("Difficulty Score")
# plt.ylabel("Frequency")
# plt.grid(True)
# plt.show()

In [None]:
from rdkit import Chem

def canonicalize_smiles(smi):
    """Canonicalise a (possibly multi-fragment) SMILES string.

    The input is split on '.', every fragment is parsed and re-written as
    canonical SMILES, and the fragments are joined back in sorted order so the
    result does not depend on the original fragment order. If any fragment
    fails to parse, the input string is returned unchanged.
    """
    molecules = []
    for fragment in smi.split('.'):
        molecules.append(Chem.MolFromSmiles(fragment))

    # Bail out only after every fragment has been parsed.
    for molecule in molecules:
        if molecule is None:
            return smi

    rendered = []
    for molecule in molecules:
        rendered.append(Chem.MolToSmiles(molecule, canonical=True))
    rendered.sort()
    return '.'.join(rendered)

In [None]:
# Derive canonical product / reactant columns from the reaction SMILES:
# the text after the last '>' is the product, the text before the first '>'
# is the reactant set. Each side is canonicalised as it is extracted.
df['product'] = df['reactions'].apply(lambda rxn: canonicalize_smiles(rxn.split('>')[-1]))
df['reactants'] = df['reactions'].apply(lambda rxn: canonicalize_smiles(rxn.split('>')[0]))

In [None]:
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
import pandas as pd
import random
from collections import defaultdict
from sklearn.model_selection import train_test_split

def generate_scaffold(smiles, include_chirality=False):
    """Return the Murcko scaffold SMILES for ``smiles``.

    Args:
        smiles: SMILES string of the molecule.
        include_chirality: Forwarded to RDKit as ``includeChirality``.

    Returns:
        The scaffold SMILES, or ``None`` when ``smiles`` cannot be parsed.
    """
    parsed = Chem.MolFromSmiles(smiles)
    return (
        None
        if parsed is None
        else MurckoScaffold.MurckoScaffoldSmiles(mol=parsed, includeChirality=include_chirality)
    )

def scaffold_split(df, smiles_column='product', frac_train=0.8, frac_val=0.1, frac_test=0.1, seed=42):
    """Split ``df`` so that no Murcko scaffold is shared between splits.

    Rows are grouped by ``generate_scaffold``. Groups are visited from largest
    to smallest and assigned by a running row count: to train while the count
    stays within the train cutoff, then to val while it stays within the
    train+val cutoff, otherwise to test.

    Rows whose scaffold cannot be computed are left out of every split. The
    cutoffs are computed from the rows that *are* split, so the requested
    fractions are not skewed by skipped rows; a warning reports how many rows
    were skipped.

    Args:
        df: DataFrame to split; rows are addressed positionally.
        smiles_column: Column holding the SMILES used for scaffold grouping.
        frac_train, frac_val, frac_test: Target fractions; must sum to 1.0.
        seed: Unused (the procedure is deterministic); kept so existing calls
            that pass it keep working.

    Returns:
        ``(train_df, val_df, test_df)`` as positional slices of ``df``.
    """
    import warnings

    assert abs(frac_train + frac_val + frac_test - 1.0) < 1e-6, "Fractions must sum to 1.0"

    scaffolds = defaultdict(list)
    skipped = 0
    for idx, smiles in enumerate(df[smiles_column]):
        scaffold = generate_scaffold(smiles)
        if scaffold is None:
            skipped += 1
        else:
            scaffolds[scaffold].append(idx)

    if skipped:
        warnings.warn(
            f"scaffold_split: skipped {skipped} row(s) with no computable scaffold"
        )

    # Largest groups first; the sort is stable, so ties keep first-seen order.
    groups = sorted(scaffolds.values(), key=len, reverse=True)

    train_idx, val_idx, test_idx = [], [], []
    # Count only rows that will actually land in a split.
    total_count = len(df) - skipped
    train_cutoff = int(frac_train * total_count)
    val_cutoff = int((frac_train + frac_val) * total_count)

    running_total = 0
    for indices in groups:
        upcoming = running_total + len(indices)
        if upcoming <= train_cutoff:
            train_idx.extend(indices)
        elif upcoming <= val_cutoff:
            val_idx.extend(indices)
        else:
            test_idx.extend(indices)
        running_total = upcoming

    return df.iloc[train_idx], df.iloc[val_idx], df.iloc[test_idx]

# Example usage
# df = pd.read_csv("your_data.csv")  # must contain 'product' column with SMILES
# Called with the defaults: group on the 'product' column, 80/10/10 fractions.
train_df, val_df, test_df = scaffold_split(df)

In [None]:
from sklearn.model_selection import train_test_split

# train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
# train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)  # 81/9/10

# Order the training rows by ascending difficulty and renumber them from 0.
# NOTE(review): presumably this ordering is what lets the pacing helpers treat
# a prefix of the training set as the "easy" subset — confirm.
train_df = train_df.sort_values(by="difficulty").reset_index(drop=True)

In [None]:
from transformers import AutoTokenizer, EncoderDecoderModel
from datasets import Dataset

# Encoder-side tokenizer (ChemBERTa) and decoder-side tokenizer (DistilGPT2).
chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
gpt2_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

# Add pad token if missing
# Registers a dedicated '[PAD]' special token so `padding="max_length"` can be
# used on decoder targets.
if gpt2_tokenizer.pad_token is None:
    gpt2_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

def tokenize_function(example):
    """Tokenise one product/reactants pair for encoder-decoder training.

    Args:
        example: Mapping with string fields ``'product'`` and ``'reactants'``.

    Returns:
        The encoder tokenisation of ``"retro: " + product`` (padded/truncated
        to 128 tokens) with an added ``'labels'`` entry: the decoder token ids
        of ``reactants`` (also length 128) where every padded position is
        replaced by -100 so it is excluded from the loss.
    """
    # Encoder input = product SMILES, tokenized with ChemBERTa
    input_ = chemberta_tokenizer(
        "retro: " + example['product'],
        padding="max_length",
        truncation=True,
        max_length=128
    )

    # Decoder target = reactants SMILES, tokenized with GPT2 tokenizer
    target = gpt2_tokenizer(
        example['reactants'],
        padding="max_length",
        truncation=True,
        max_length=128
    )
    # NOTE(review): no end-of-sequence token is appended to the target here —
    # confirm whether generation is expected to learn a stop signal.

    # Mask padding via the attention mask rather than by comparing ids, so a
    # pad id that coincides with a real token is still handled correctly.
    input_['labels'] = [
        token_id if keep else -100
        for token_id, keep in zip(target['input_ids'], target['attention_mask'])
    ]
    return input_

# Wrap only the two text columns of each split as Hugging Face Datasets.
train_ds = Dataset.from_pandas(train_df[['product', 'reactants']])
val_ds = Dataset.from_pandas(val_df[['product', 'reactants']])
test_ds = Dataset.from_pandas(test_df[['product', 'reactants']])

# Tokenise row by row (batched=False: tokenize_function receives one example
# at a time).
train_tokenized = train_ds.map(tokenize_function, batched=False)
val_tokenized = val_ds.map(tokenize_function, batched=False)
test_tokenized = test_ds.map(tokenize_function, batched=False)

In [None]:
def get_curriculum_indices(epoch, total_epochs, dataset_length, min_frac=0.1, max_frac=1.0):
    '''
    Pacing function as per Hacohen and Weinshall

    Linear ramp: the visible fraction grows as (epoch + 1) / (0.5 * total_epochs),
    i.e. it reaches 1.0 at the halfway epoch, and is clamped to
    [min_frac, max_frac] like the other pacing helpers.

    Args:
        epoch: Zero-based epoch index.
        total_epochs: Planned number of epochs (must be non-zero).
        dataset_length: Number of examples in the difficulty-sorted dataset.
        min_frac: Lower bound on the fraction of data exposed.
        max_frac: Upper bound on the fraction of data exposed.

    Returns:
        ``list(range(k))`` where ``k = int(dataset_length * fraction)``.
    '''
    ramp = (epoch + 1) / (0.5 * total_epochs)  # Full dataset by halfway point
    frac = max(min_frac, min(max_frac, ramp))
    return list(range(int(dataset_length * frac)))

def get_log_pacing_indices(epoch, total_epochs, dataset_length, min_frac=0.05, max_frac=1.0):
    """Logarithmic pacing: fast growth early, flattening towards the end.

    The visible fraction is
    ``min_frac + (max_frac - min_frac) * log1p(epoch + 1) / log1p(total_epochs)``,
    capped at ``max_frac`` so that an ``epoch`` at or beyond ``total_epochs``
    cannot produce indices past the end of the dataset.

    Args:
        epoch: Zero-based epoch index.
        total_epochs: Planned number of epochs (must be positive).
        dataset_length: Number of examples in the difficulty-sorted dataset.
        min_frac: Fraction offset the ramp starts from.
        max_frac: Largest fraction ever returned.

    Returns:
        ``list(range(k))`` where ``k = int(dataset_length * fraction)``.
    """
    ramp = np.log1p(epoch + 1) / np.log1p(total_epochs)
    frac = min(max_frac, min_frac + (max_frac - min_frac) * ramp)
    return list(range(int(dataset_length * frac)))

def get_exp_pacing_indices(epoch, total_epochs, dataset_length, min_frac=0.05, max_frac=1.0):
    """Saturating-exponential pacing that reaches ``max_frac`` at the last epoch.

    The ramp ``1 - exp(-alpha * (epoch + 1) / total_epochs)`` is divided by its
    value at the final epoch (``1 - exp(-alpha)``). Without that normalisation
    the ramp tops out at about 0.865 for ``alpha = 2`` and the hardest
    examples are never exposed. The result is capped at ``max_frac``.

    Args:
        epoch: Zero-based epoch index.
        total_epochs: Planned number of epochs (must be positive).
        dataset_length: Number of examples in the difficulty-sorted dataset.
        min_frac: Fraction offset the ramp starts from.
        max_frac: Fraction reached at the final epoch.

    Returns:
        ``list(range(k))`` where ``k = int(dataset_length * fraction)``.
    """
    alpha = 2  # higher = steeper early growth, earlier plateau
    progress = (epoch + 1) / total_epochs
    ramp = (1 - np.exp(-alpha * progress)) / (1 - np.exp(-alpha))
    frac = min(max_frac, min_frac + (max_frac - min_frac) * ramp)
    return list(range(int(dataset_length * frac)))

def get_step_pacing_indices(epoch, total_epochs, dataset_length, num_steps=5):
    """Staircase pacing: expose the data in ``num_steps`` equal increments.

    The current step is ``epoch * num_steps // total_epochs + 1``, capped at
    ``num_steps``; the first ``int(dataset_length * step / num_steps)`` indices
    are returned.
    """
    current_step = epoch * num_steps // total_epochs + 1
    if current_step > num_steps:
        current_step = num_steps

    visible = int(dataset_length * (current_step / num_steps))
    return [position for position in range(visible)]

In [None]:
from transformers import EncoderDecoderModel, Trainer, TrainingArguments, AutoTokenizer
import torch
import os

# Set output path
# Checkpoints are both looked up in and written to this Drive directory.
output_dir = "/content/drive/MyDrive/Molecular design/ReactionML/CurriculumLearning/retrosyn-chemberta-distilgpt2-curriculum-scaffold"

# Load tokenizers
# Same two tokenizers as the tokenisation cell; reloaded here so this cell
# does not depend on that cell's variables.
chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
gpt2_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

# Ensure pad token exists in GPT2
# Must mirror the pad-token handling used when the datasets were tokenised.
if gpt2_tokenizer.pad_token is None:
    gpt2_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Look for latest checkpoint
# On a first run the output directory does not exist yet; treat that as "no
# checkpoints" instead of letting os.listdir raise. Entries whose suffix is not
# an integer are ignored so a stray "checkpoint-*" name cannot break the sort.
checkpoint_names = os.listdir(output_dir) if os.path.isdir(output_dir) else []
checkpoints = sorted(
    (
        ckpt for ckpt in checkpoint_names
        if ckpt.startswith("checkpoint-") and ckpt.rsplit("-", 1)[-1].isdigit()
    ),
    key=lambda x: int(x.rsplit("-", 1)[-1])
)

# Highest-numbered checkpoint wins; None means start fresh.
resume_checkpoint = os.path.join(output_dir, checkpoints[-1]) if checkpoints else None

# Load model from checkpoint if available, else initialize from scratch
if resume_checkpoint:
    print(f"✅ Resuming from checkpoint: {resume_checkpoint}")
    model = EncoderDecoderModel.from_pretrained(resume_checkpoint)
else:
    print("🚀 Starting from scratch")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "seyonec/ChemBERTa-zinc-base-v1",  # encoder
        "distilgpt2"                       # decoder
    )

# Set model config
# Decoding starts from GPT-2's BOS token; padding uses the tokenizer's pad id.
model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
model.config.pad_token_id = gpt2_tokenizer.pad_token_id
# Resize the decoder embeddings BEFORE copying vocab_size: the tokenizer may
# have gained a '[PAD]' token, and reading the size first would record the
# stale, pre-resize vocabulary size on the top-level config.
model.decoder.resize_token_embeddings(len(gpt2_tokenizer))
model.config.vocab_size = model.config.decoder.vocab_size

# TrainingArguments
# Evaluation and automatic checkpointing are both disabled; the curriculum
# loop calls train() once per stage and saves checkpoints itself.
# NOTE(review): with save_strategy="no", save_total_limit presumably has no
# effect — confirm against the installed transformers version.
training_args = TrainingArguments(
    output_dir=output_dir,
    eval_strategy="no",
    save_strategy="no",  # no automatic save
    per_device_train_batch_size=64,
    per_device_eval_batch_size=8,
    eval_accumulation_steps=2,
    prediction_loss_only=False,
    num_train_epochs=1,  # we manually loop
    logging_dir="./logs",
    logging_steps=50,
    save_total_limit=2,
    load_best_model_at_end=False,
    fp16=True
)

In [None]:
from transformers import Trainer
from tqdm import tqdm

class CurriculumTrainer(Trainer):
    """Trainer whose training set can be swapped between curriculum stages."""

    def set_dataset(self, dataset):
        """Replace ``self.train_dataset`` with ``dataset``."""
        self.train_dataset = dataset

# ✅ Initialize trainer once with dummy dataset
# The one-row dataset is a placeholder; the real subset is set each epoch.
trainer = CurriculumTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized.select([0]),
    tokenizer=gpt2_tokenizer
)

total_epochs = 100

# If the model was restored from "checkpoint-epoch-N", continue the curriculum
# at epoch N. Restarting at 0 would feed an already-trained model the smallest,
# easiest subset again and repeat epochs that were already completed.
start_epoch = 0
if resume_checkpoint:
    resumed_name = os.path.basename(os.path.normpath(resume_checkpoint))
    resumed_suffix = resumed_name.rsplit("-", 1)[-1]
    if resumed_name.startswith("checkpoint-epoch-") and resumed_suffix.isdigit():
        start_epoch = min(int(resumed_suffix), total_epochs)

epoch_bar = tqdm(
    range(start_epoch, total_epochs),
    desc="Curriculum Training",
    unit="epoch",
    initial=start_epoch,
    total=total_epochs,
)

for epoch in epoch_bar:
    # Prefix of the difficulty-sorted training set allowed at this epoch.
    indices = get_curriculum_indices(epoch, total_epochs, len(train_tokenized))
    current_data = train_tokenized.select(indices)

    # ✅ Set new dataset
    trainer.set_dataset(current_data)

    # ✅ Train for *one curriculum step* (not full 100 epochs again!)
    # NOTE(review): each train() call presumably builds a fresh optimizer and
    # learning-rate schedule, so the schedule restarts every stage — confirm
    # this is intended.
    trainer.args.num_train_epochs = 1
    trainer.state.epoch = 0  # 🔁 Reset internal epoch count for logging

    trainer.train(resume_from_checkpoint=None)

    # ✅ Nice progress bar update
    epoch_bar.set_description(f"Epoch {epoch+1:03d} | Samples: {len(indices)}")

    # 💾 Save checkpoints
    # Every 20th epoch; the directory name encodes the number of completed
    # epochs, which is what the resume logic above reads back.
    if (epoch + 1) % 20 == 0:
        checkpoint_path = os.path.join(output_dir, f"checkpoint-epoch-{epoch+1}")
        model.save_pretrained(checkpoint_path)
        gpt2_tokenizer.save_pretrained(checkpoint_path)
        epoch_bar.write(f"💾 Saved checkpoint at {checkpoint_path}")