In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import matplotlib.cm as cm
import os
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split


In [None]:
def plot_pca(embeddings_smiles, embeddings_text, title, save_path, labels=("SMILES", "Text")):
    """Scatter two embedding sets in a shared 2-D PCA projection.

    The PCA basis is fitted on ``embeddings_smiles`` only; ``embeddings_text``
    is projected into that same basis so both sets share the axes.

    Args:
        embeddings_smiles: 2-D array (n_samples, dim) used to fit the PCA.
        embeddings_text: 2-D array with the same ``dim``, projected with the fitted PCA.
        title: figure title.
        save_path: file to write the figure to (PNG at 300 dpi); falsy to skip saving.
            Missing parent directories are created.
        labels: legend labels for the two sets, in argument order.
    """
    pca = PCA(n_components=2)
    first_2d = pca.fit_transform(embeddings_smiles)
    second_2d = pca.transform(embeddings_text)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(first_2d[:, 0], first_2d[:, 1], label=labels[0], alpha=0.7)
    ax.scatter(second_2d[:, 0], second_2d[:, 1], label=labels[1], alpha=0.7)
    ax.legend()
    ax.set_title(title, size=18)
    ax.set_xlabel("PCA Component 1")
    ax.set_ylabel("PCA Component 2")

    if save_path:
        # savefig does not create directories; without this a missing results
        # folder raises FileNotFoundError after training has already finished.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"PCA plot saved at: {save_path}")
    plt.show()
    # Release the figure: this is called once per experiment inside a loop.
    plt.close(fig)

def recall_at_k(smiles_embeddings, text_embeddings, k=5):
    """Fraction of queries whose paired candidate appears among their top-k matches.

    Row i of ``smiles_embeddings`` is a query whose correct match is row i of
    ``text_embeddings``. Similarity is the dot product, which equals cosine
    similarity when both inputs are L2-normalised.

    Args:
        smiles_embeddings: 2-D array (n_queries, dim).
        text_embeddings: 2-D array (n_candidates, dim).
        k: number of top-scoring candidates to consider; values larger than
            n_candidates simply consider every candidate.

    Returns:
        Recall@k as a float in [0, 1].

    Raises:
        ValueError: if ``k < 1``. ``k=0`` would otherwise slice ``[:, -0:]``,
            which selects *all* columns and silently reports perfect recall.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    similarity = np.dot(smiles_embeddings, text_embeddings.T)
    # argsort is ascending, so the last k columns hold the k best candidates.
    top_k_indices = np.argsort(similarity, axis=1)[:, -k:]
    ground_truth_indices = np.arange(len(smiles_embeddings)).reshape(-1, 1)
    return np.mean(np.any(top_k_indices == ground_truth_indices, axis=1))

def mean_reciprocal_rank(smiles_embeddings, text_embeddings):
    """Mean of 1/rank of each query's paired candidate.

    Row i of ``smiles_embeddings`` should retrieve row i of
    ``text_embeddings``. Candidates are ranked by descending dot-product
    similarity, and rank 1 is the best match.

    Returns:
        MRR as a float in (0, 1].
    """
    n_queries = len(smiles_embeddings)
    # Negating the scores turns NumPy's ascending argsort into a best-first ordering.
    ranking = np.argsort(-np.dot(smiles_embeddings, text_embeddings.T), axis=1)
    is_target = ranking == np.arange(n_queries)[:, None]
    _, zero_based_position = np.nonzero(is_target)
    return np.mean(1 / (zero_based_position + 1))

def retrieval_accuracy(smiles_embeddings, text_embeddings):
    """Top-1 retrieval accuracy: share of queries whose best-scoring candidate is their pair.

    Row i of ``smiles_embeddings`` is expected to retrieve row i of
    ``text_embeddings``. Scores are dot products.
    """
    expected = np.arange(len(smiles_embeddings))
    best_match = np.dot(smiles_embeddings, text_embeddings.T).argmax(axis=1)
    return accuracy_score(expected, best_match)

def extract_embeddings(model, dataloader, device):
    """Encode every pair in ``dataloader`` into L2-normalised first-token embeddings.

    Both sides of each batch go through the same ``model``. Despite the
    ``smiles`` naming, the first output is built from the ``title_*`` batch
    keys and the second from the ``abstract_*`` keys.

    Args:
        model: encoder whose call returns an object with ``last_hidden_state``.
        dataloader: iterable of batch dicts with ``title_input_ids``,
            ``title_attention_mask``, ``abstract_input_ids`` and
            ``abstract_attention_mask`` tensors.
        device: device the batch tensors are moved to.

    Returns:
        ``(title_matrix, abstract_matrix)`` as numpy arrays stacked over all
        batches, with one unit-norm row per example.
    """
    def pooled(batch, ids_key, mask_key):
        # Take the hidden state at position 0, normalise it, and bring it back to CPU/numpy.
        hidden = model(
            batch[ids_key].to(device),
            attention_mask=batch[mask_key].to(device),
        ).last_hidden_state
        return F.normalize(hidden[:, 0, :], p=2, dim=-1).cpu().numpy()

    model.eval()
    title_chunks, abstract_chunks = [], []
    with torch.no_grad():
        for batch in dataloader:
            title_chunks.append(pooled(batch, 'title_input_ids', 'title_attention_mask'))
            abstract_chunks.append(pooled(batch, 'abstract_input_ids', 'abstract_attention_mask'))
    return np.vstack(title_chunks), np.vstack(abstract_chunks)

def evaluate(model, dataloader, device, loss_fn=None):
    """Average contrastive loss of ``model`` over ``dataloader``, computed without gradients.

    For each batch, titles and abstracts are encoded with the same model. The
    first-token hidden state is taken, L2-normalised, and the pair is scored
    with ``loss_fn``.

    Args:
        model: encoder whose call returns an object with ``last_hidden_state``.
        dataloader: sized iterable of batch dicts with ``title_input_ids``,
            ``title_attention_mask``, ``abstract_input_ids`` and
            ``abstract_attention_mask`` tensors.
        device: device the batch tensors are moved to.
        loss_fn: callable ``(z_title, z_abstract) -> scalar tensor``. If None,
            the module-level ``contrastive_loss`` is used, looked up at call time.
            That global is defined in the experiment cell, so passing
            ``loss_fn`` explicitly makes this usable before that cell has run.

    Returns:
        Mean per-batch loss as a float, or NaN if the dataloader is empty.
    """
    if loss_fn is None:
        # Late lookup: picks up whichever `contrastive_loss` the current experiment defined.
        loss_fn = contrastive_loss
    n_batches = len(dataloader)
    if n_batches == 0:
        # Avoid ZeroDivisionError on an empty validation split.
        return float("nan")

    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            smiles_input = batch['title_input_ids'].to(device)
            text_input = batch['abstract_input_ids'].to(device)
            smiles_mask = batch['title_attention_mask'].to(device)
            text_mask = batch['abstract_attention_mask'].to(device)
            z_smiles = model(smiles_input, attention_mask=smiles_mask).last_hidden_state[:, 0, :]
            z_text = model(text_input, attention_mask=text_mask).last_hidden_state[:, 0, :]
            z_smiles = F.normalize(z_smiles, p=2, dim=-1)
            z_text = F.normalize(z_text, p=2, dim=-1)
            total_loss += loss_fn(z_smiles, z_text).item()
    return total_loss / n_batches


from torch.utils.data import Dataset

class TitleAbstractContrastiveDataset(Dataset):
    """Paired (title, abstract) examples for contrastive training.

    Each item tokenises the row's ``"title"`` and ``"text"`` columns
    independently. Both are padded/truncated to ``max_length`` and returned
    as 1-D tensors.

    ``tokenizer`` may be None at construction. It must be assigned (as the
    experiment loop does) before any item is fetched.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenise one string and drop the batch dimension of size 1."""
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        title_ids, title_mask = self._encode(row["title"])
        abstract_ids, abstract_mask = self._encode(row["text"])
        return {
            "title_input_ids": title_ids,
            "title_attention_mask": title_mask,
            "abstract_input_ids": abstract_ids,
            "abstract_attention_mask": abstract_mask,
        }


# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Display name -> Hugging Face hub id or local checkpoint directory; each value is
# passed to AutoTokenizer/AutoModel.from_pretrained in the experiment loop.
# NOTE(review): every entry is currently commented out, so the experiment loop runs
# zero iterations until at least one line is re-enabled. The local paths are
# machine-specific absolute paths.
model_paths = {
#  'bert-base-uncased': 'bert-base-uncased',
#  'ChemBERTa-MLM': 'DeepChem/ChemBERTa-77M-MLM',
#  'ChemBERTa-MTR': 'DeepChem/ChemBERTa-77M-MTR',
#  "MoLFormer":'ibm-research/MoLFormer-XL-both-10pct',
#  'SciBERT uncased': 'allenai/scibert_scivocab_uncased',
#  'SciBERT cased': 'allenai/scibert_scivocab_cased',
#  'ModernBERT': 'answerdotai/ModernBERT-base',
#  'ModernBERT base 500k': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-1n4g-no-tags',
#  'ModernBERT base Procedures 500k': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-new-procedure-500k-notags',
#  'ModernBERT base Procedures 1M5': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-new-procedure-1M5-notags',
#  'ModernBERT base Procedures 1M5 10 epochs': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-new-procedure-10-epochs-notags',
#  'ModernBERT base Procedures 1M5 20 epochs': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-new-procedure-20-epochs-notags',
#  'ModernBERT base Procedures 1M5 30 epochs': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-new-procedure-30-epochs-notags',
#  'ModernBERT base Procedures 1M5 40 epochs': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-new-procedure-40-epochs-notags',
#  'ModernBERT base Procedures 1M5 50 epochs': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-new-procedure-50-epochs-notags',
#  'ModernBERT base Procedures 1M5 60 epochs': '/home/david/modernbert_chemistry/fineweb/fine-web-modernbert-base-8192-multi-tok-new-procedure-60-epochs-notags',
}

# NOTE(review): train_df / val_df / test_df are not defined anywhere in the visible
# source (train_test_split is imported but never called), so they must come from a
# cell that is not shown here. TODO: add the load/split step so Restart & Run All works.
# The datasets are built without a tokenizer (None); the experiment loop assigns
# each model's tokenizer to them before any item is fetched.
train_dataset = TitleAbstractContrastiveDataset(train_df, None)
val_dataset   = TitleAbstractContrastiveDataset(val_df, None)
test_dataset  = TitleAbstractContrastiveDataset(test_df, None)

# Only the training split is shuffled. Val/test order is kept fixed, which the
# retrieval metrics rely on: row i of the title embeddings is paired with row i
# of the abstract embeddings.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader   = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader  = DataLoader(test_dataset, batch_size=16, shuffle=False)

# InfoNCE temperatures to sweep; one full fine-tuning run per (model, temperature).
temperature_values = [0.1]
results = []  # one dict of test-set retrieval metrics per experiment

SEED = 42
NUM_EPOCHS = 5
LEARNING_RATE = 2e-5
ACCUMULATION_STEPS = 4  # micro-batches of size 16 per optimizer step
# NOTE(review): machine-specific absolute path; consider making it relative.
RESULTS_DIR = "/home/david/modernbert_chemistry/fineweb/results_allmodels"
os.makedirs(RESULTS_DIR, exist_ok=True)

# torch.cuda.amp only applies on CUDA. Disable it explicitly elsewhere instead of
# relying on the warn-and-disable fallback.
use_amp = device.type == "cuda"

for model_name, model_path in model_paths.items():
    print(f"\n===== Running experiments for model: {model_name} =====")

    for temp in temperature_values:
        print(f"\n🔹 Training {model_name} with temperature τ = {temp}")

        # The default argument binds THIS iteration's `temp`. `evaluate` resolves the
        # global `contrastive_loss` at call time, so validation uses the same
        # temperature as training. Previously the default was hard-coded to 0.1,
        # so the temperature sweep had no effect.
        def contrastive_loss(z1, z2, temperature=temp):
            """One-directional InfoNCE: row i of z1 should score highest against row i of z2."""
            z1 = F.normalize(z1, p=2, dim=-1)
            z2 = F.normalize(z2, p=2, dim=-1)
            # Unit vectors already bound the logits to [-1/τ, 1/τ]. The old clamp to
            # ±10 was a no-op at τ=0.1 but capped the logits, changing the objective,
            # for any τ < 0.1. Cast to float32 so cross_entropy runs at full precision.
            logits = (torch.matmul(z1, z2.T) / temperature).float()
            labels = torch.arange(z1.size(0), device=logits.device)
            return F.cross_entropy(logits, labels)

        # Same dropout / shuffling sequence for every experiment.
        torch.manual_seed(SEED)

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # The datasets were created without a tokenizer; attach this model's one.
        train_dataset.tokenizer = tokenizer
        val_dataset.tokenizer = tokenizer
        test_dataset.tokenizer = tokenizer

        # Only `last_hidden_state` is used, so per-layer hidden states are not requested.
        model = AutoModel.from_pretrained(model_path)
        model.to(device)

        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
        scaler = GradScaler(enabled=use_amp)
        n_steps = len(train_dataloader)

        for epoch in range(NUM_EPOCHS):
            total_train_loss = 0.0
            model.train()
            print(f"\nEpoch {epoch+1} [Training]:")
            train_progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1} [Training]")

            for step, batch in enumerate(train_progress):
                title_input = batch["title_input_ids"].to(device)
                abstract_input = batch["abstract_input_ids"].to(device)
                title_mask = batch["title_attention_mask"].to(device)
                abstract_mask = batch["abstract_attention_mask"].to(device)

                with autocast(enabled=use_amp):
                    # First-token ([CLS]-position) embedding of each sequence.
                    z_title = model(title_input, attention_mask=title_mask).last_hidden_state[:, 0, :]
                    z_abstract = model(abstract_input, attention_mask=abstract_mask).last_hidden_state[:, 0, :]
                    batch_loss = contrastive_loss(z_title, z_abstract)

                # Divide so gradients summed over ACCUMULATION_STEPS micro-batches average out.
                scaler.scale(batch_loss / ACCUMULATION_STEPS).backward()
                batch_loss_value = batch_loss.item()
                total_train_loss += batch_loss_value
                if (step + 1) % ACCUMULATION_STEPS == 0 or step == n_steps - 1:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                # Show the real per-batch loss, not the accumulation-scaled one.
                train_progress.set_postfix(loss=batch_loss_value)

            avg_train_loss = total_train_loss / n_steps
            print(f"\nEpoch {epoch+1} [Training] Average Loss: {avg_train_loss:.4f}")
            avg_val_loss = evaluate(model, val_dataloader, device)
            print(f"Epoch {epoch+1} [Validation] Loss: {avg_val_loss:.4f}")
            torch.cuda.empty_cache()

        # Title embeddings are queries; abstract embeddings are the candidates.
        test_title_emb, test_abstract_emb = extract_embeddings(model, test_dataloader, device)
        results.append({
            "Model": model_name,
            "Temperature": temp,
            "Top-1 Accuracy": retrieval_accuracy(test_title_emb, test_abstract_emb),
            "MRR": mean_reciprocal_rank(test_title_emb, test_abstract_emb),
            "Recall@5": recall_at_k(test_title_emb, test_abstract_emb, k=5),
        })

        save_path_1 = os.path.join(
            RESULTS_DIR, f"{model_name}_contrastive-pca-_title_text_1-temp-{temp:.2f}.png"
        )
        plot_pca(test_title_emb, test_abstract_emb, f"PCA Plot for {model_name} at Temp {temp}", save_path_1)

        # The optimizer holds the parameters plus the AdamW moment buffers. Deleting
        # only `model` kept all of that alive until `optimizer` was rebound, which
        # happens after the next model has already been loaded onto the device.
        del model, optimizer, scaler, test_title_emb, test_abstract_emb
        torch.cuda.empty_cache()


results_df = pd.DataFrame(results)

if results_df.empty:
    # No experiment ran (e.g. every model_paths entry is commented out). Indexing
    # the metric columns of an empty frame would raise KeyError.
    print("No experiment results to plot.")
else:
    # (metric column, colormap position) for each bar within a group.
    metric_styles = [("Top-1 Accuracy", 0.2), ("MRR", 0.5), ("Recall@5", 0.8)]
    bar_width = 2.8
    spacing_factor = 10  # distance between experiment groups, in bar-width units of the x axis
    cmap = plt.colormaps["plasma_r"]
    group_positions = np.arange(len(results_df)) * spacing_factor

    fig, ax = plt.subplots(figsize=(14, 8))
    for offset, (metric, shade) in enumerate(metric_styles):
        bars = ax.bar(
            group_positions + offset * bar_width,
            results_df[metric],
            bar_width,
            label=metric,
            color=cmap(shade),
        )
        for bar in bars:
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.0,
                height,
                f"{height*100:.1f}%",
                ha="center",
                va="bottom",
                fontsize=12,
            )

    # Include the temperature so repeated models in a temperature sweep stay distinguishable.
    tick_labels = [
        f"{model} (τ={temperature})"
        for model, temperature in zip(results_df["Model"], results_df["Temperature"])
    ]
    ax.set_xticks(group_positions + bar_width)
    ax.set_xticklabels(tick_labels, rotation=45, ha="right", size=14)
    ax.tick_params(axis="y", labelsize=13)
    # All three metrics lie in [0, 1]; leave headroom for the value labels above the bars.
    ax.set_ylim(0, 1.1)

    ax.set_xlabel("Model (temperature)", size=15)
    ax.set_ylabel("Score", size=15)
    ax.set_title("Contrastive Retrieval Performance", size=18)
    ax.legend(fontsize=14, loc="upper left", bbox_to_anchor=(1.05, 1.0), borderaxespad=0)
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    fig.tight_layout()
    plt.show()

