# Influence Function-based Data Selection for Model Enhancement

Now that you understand how our model performs on the Lipophilicity dataset (see `the MLM regression training notebook`), the goal of this task is to further improve its performance by selecting external data points to add to the training set. You can use the [`task2.py`](../scripts/Task2.py) file for your implementation.


An external dataset (provided in the file [`External-Dataset_for_Task2.csv`](../tasks/External-Dataset_for_Task2.csv)) contains molecular SMILES strings and their lipophilicity values, which we could include in the training process. However, we suspect that not all external data points are relevant, so we only want to keep those that are likely to improve the model's performance. To decide which ones, we use **influence functions**, which estimate how upweighting (or adding) a single training point would change the model's loss on the test set, without retraining the model. Following [Koh & Liang's paper (2017)](https://arxiv.org/abs/1703.04730), we compute and log an influence score for every sample in the external dataset, analyze the distribution of these scores, and keep the high-impact samples (e.g., the top-k positively scored samples).

Computing an influence score involves three main steps: computing gradients of the loss with respect to the model parameters (for the training point and for the test set), applying the inverse of the Hessian of the training loss to a gradient vector, and combining the two to estimate the effect of the training point on the test loss. The challenge for deep neural networks is the Hessian itself: storing it requires \( O(d^2) \) memory and inverting it requires \( O(d^3) \) operations, where \( d \) is the number of model parameters, which is infeasible for large networks. To address this, Koh & Liang (2017) approximate the inverse Hessian-vector product (iHVP) directly, using techniques such as **stochastic estimation (LiSSA)** [(Agarwal et al., 2016)](https://arxiv.org/abs/1602.03943). LiSSA only needs Hessian-vector products, each of which costs roughly as much as a gradient computation and never materializes the Hessian.
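
For reference, the influence of upweighting a training point \( z \) on the loss at a test point \( z_{\text{test}} \) (Koh & Liang, 2017), and the damped, scaled LiSSA recursion used in common implementations to approximate \( H^{-1}v \), can be written as follows, where \( \tilde H_j \) is the Hessian estimated on the \( j \)-th training mini-batch:

\[
\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}}) = -\nabla_\theta L(z_{\text{test}}, \hat\theta)^\top H_{\hat\theta}^{-1}\, \nabla_\theta L(z, \hat\theta),
\qquad
H_{\hat\theta} = \frac{1}{n}\sum_{i=1}^{n} \nabla_\theta^2 L(z_i, \hat\theta)
\]

\[
h_0 = v, \qquad
h_j = v + (1 - \lambda)\, h_{j-1} - \frac{\tilde H_j\, h_{j-1}}{s}, \qquad
H^{-1} v \approx \frac{h_J}{s}
\]

Adding a point to a training set of size \( n \) corresponds to upweighting it by \( 1/n \), so the predicted change in test loss is about \( \tfrac{1}{n}\mathcal{I}_{\text{up,loss}} \): negative values mark helpful points. The damping \( \lambda \) and scale \( s \) are tuning parameters. The recursion converges when every eigenvalue \( \mu \) of \( H \) satisfies \( 0 < \mu/s + \lambda < 2 \) (for example, a positive semidefinite \( H \), \( 0 < \lambda < 1 \), and \( s \) at least the largest eigenvalue), and its limit \( h_\infty/s = (H + s\lambda I)^{-1}v \) approaches \( H^{-1}v \) as \( \lambda \to 0 \).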

1. Compute an influence score for each data point in the external dataset using the LiSSA approximation, to identify which external samples are most likely to improve the model's performance. For this, you will:
   - use the trained model and the external dataset;
   - compute the gradient of the loss for each data point in the external dataset;
   - use the LiSSA approximation of the inverse Hessian to estimate how each external sample affects the model's loss on the test set.

2. Once the influence scores for the external dataset have been computed, combine the selected high-impact samples with the Lipophilicity training set and fine-tune the model again. Then evaluate the model on the Lipophilicity test set and compare its performance with the baseline from `the MLM regression training notebook`.
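
The LiSSA step in item 1 can be sanity-checked on a problem small enough to solve exactly. The sketch below assumes a hypothetical `lissa` helper that only receives a Hessian-vector-product function `hvp`, plus a made-up 5×5 symmetric positive definite "Hessian" whose eigenvalues (0.5 to 8) stay below the scale of 10, so the recursion converges. It uses the damped, scaled recursion \( h_j = v + (1-\lambda)h_{j-1} - Hh_{j-1}/s \), returning \( h_J/s \).

```python
import numpy as np


def lissa(hvp, v, damping=0.01, scale=10.0, depth=2000):
    """Estimate (H + scale*damping*I)^{-1} v using only Hessian-vector products."""
    h = v.copy()
    for _ in range(depth):
        # h_j = v + (1 - damping) * h_{j-1} - H h_{j-1} / scale
        h = v + (1.0 - damping) * h - hvp(h) / scale
    return h / scale


rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(5, 5)))
H = Q @ np.diag([0.5, 1.0, 2.0, 4.0, 8.0]) @ Q.T  # symmetric positive definite toy "Hessian"
v = rng.normal(size=5)

estimate = lissa(lambda x: H @ x, v)
damped_exact = np.linalg.solve(H + 10.0 * 0.01 * np.eye(5), v)  # fixed point of the recursion
undamped_exact = np.linalg.solve(H, v)
print(np.max(np.abs(estimate - damped_exact)), np.max(np.abs(estimate - undamped_exact)))
```

The estimate matches the damped target; the gap to the undamped \( H^{-1}v \) is the bias introduced by damping. In the notebook, `hvp` would be a double-backward Hessian-vector product on a training mini-batch, so the \( d \times d \) Hessian is never formed.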

In [3]:
# from google.colab import drive
# drive.mount('/content/drive/')  # Note: commented out for local execution. Uncomment both lines if using Google Colab.


In [4]:
!pip install datasets


## Computing Gradients
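
The cell below uses `batch_size=1` because influence scores need one gradient per external sample. The gradient of a mean-reduced batch loss is only the average of these per-sample gradients. A small NumPy check on a hypothetical linear model with MSE loss and random toy data shows the relationship:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(8, 4))  # 8 toy samples, 4 features
y = rng.normal(size=8)
w = rng.normal(size=4)       # parameters of a toy linear model


def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)


# One gradient per sample, as a batch_size=1 loop would produce
per_sample = np.stack([mse_grad(w, X[i:i + 1], y[i:i + 1]) for i in range(len(y))])
batch = mse_grad(w, X, y)
print(per_sample.shape)  # (8, 4)
```

Each per-sample gradient has as many entries as the model has parameters, so the saved gradient files grow linearly with the size of the external dataset.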


In [None]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

class MoLFormerRegressor(nn.Module):
    def __init__(self, model_name, model_path):
        super().__init__()
        # model_name is kept for reference; the weights are loaded from model_path.
        self.transformer = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        # NOTE: only the transformer body is loaded from model_path. The regression head below
        # is freshly initialized; load its trained weights separately if they were saved.
        self.regression_head = nn.Linear(self.transformer.config.hidden_size, 1)  # Maps 768 → 1

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is not None:
            x = pooled  # Pooled representation, a common choice for regression
        else:
            x = outputs.last_hidden_state[:, 0, :]  # Fall back to the first (CLS) token
        x = self.regression_head(x)  # One value per sample, shape [batch_size, 1]
        return x.squeeze(-1)  # Shape [batch_size], even when batch_size == 1

DATASET_PATH = "scikit-fingerprints/MoleculeNet_Lipophilicity"
MODEL_NAME = "ibm/MoLFormer-XL-both-10pct"

# Load the external dataset
ext_data = pd.read_csv("./data/external_dataset.csv")

########################################################
## Entry point
########################################################

if __name__ == "__main__":
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the trained regression model
    model_path = "./models/mlm_finetuned_molformer"  # Path to the fine-tuned model

    # Load the fine-tuned model
    # regression_model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(device)
    regression_model = MoLFormerRegressor(MODEL_NAME, model_path).to(device)

    # Load the external dataset
    ext_smiles = ext_data['SMILES'].tolist()
    ext_targets = ext_data['Label'].tolist()

    # ext_targets = [str(element) for element in ext_targets]

    # Tokenizer for SMILES strings
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

    # Define the dataset class
    class ExternalDataset(Dataset):
        def __init__(self, smiles_list, targets, tokenizer, max_length=128):
            self.smiles_list = smiles_list
            self.targets = targets
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.smiles_list)

        def __getitem__(self, idx):
            smiles = self.smiles_list[idx]
            target = self.targets[idx]

            # Tokenize the SMILES string
            encoding = self.tokenizer(
                smiles,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            )

            # Convert the target to a tensor
            target = torch.tensor(target, dtype=torch.float32)

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'target': target
            }

    # Create the external dataset and DataLoader
    ext_dataset = ExternalDataset(ext_smiles, ext_targets, tokenizer)
    ext_loader = DataLoader(ext_dataset, batch_size=1, shuffle=False)

    # Define the loss function
    criterion = torch.nn.MSELoss()


    # Function to compute and save one gradient per external sample
    def compute_gradients(model, data_loader, path="./outputs/gradients"):
        # Create the output folder once, if it does not exist
        os.makedirs(path, exist_ok=True)
        model.eval()  # Disable dropout so the gradients are deterministic

        for i, batch in enumerate(tqdm(data_loader, desc="Computing gradients")):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = batch['target'].to(device)

            # Forward pass; flatten so predictions and targets have the same shape
            predictions = model(input_ids, attention_mask).view(-1)
            targets = targets.view(-1)

            loss = criterion(predictions, targets)

            # Backward pass to compute the gradients of this sample's loss
            model.zero_grad()
            loss.backward()

            # Store gradients in model.parameters() order (zeros for parameters without a gradient).
            # The index is zero-padded so that sorting the file names preserves the row order.
            gradients = [
                param.grad.detach().clone().cpu() if param.grad is not None
                else torch.zeros_like(param, device="cpu")
                for param in model.parameters()
            ]
            torch.save(gradients, os.path.join(path, f"grad_{i:06d}.pth"))

        return path

    # Compute gradients for the external dataset; returns the folder holding one file per sample
    ext_gradients = compute_gradients(regression_model, ext_loader)

## Influence Score and Final Training
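
The cell below never forms the Hessian. It obtains Hessian-vector products by differentiating the gradient a second time (`create_graph=True`, then `torch.autograd.grad` with `grad_outputs`). A toy check on a three-parameter model with made-up data (`loss_fn`, `x`, `y` and `v` are illustrative) confirms that this double-backward product equals the explicit Hessian times the vector:

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
x = torch.randn(5, 3)
y = torch.randn(5)


def loss_fn(w):
    # A small non-quadratic loss, so the Hessian depends on w
    return torch.mean((torch.tanh(x @ w) - y) ** 2)


v = torch.randn(3)

# Double backward: keep the graph of the first gradient, then differentiate <grad, v>
(g,) = torch.autograd.grad(loss_fn(w), w, create_graph=True)
(hvp,) = torch.autograd.grad(g, w, grad_outputs=v)

# Reference: the full 3x3 Hessian, only feasible for tiny models
H = torch.autograd.functional.hessian(loss_fn, w.detach())
print(torch.allclose(hvp, H @ v, atol=1e-5))  # True
```

For MoLFormer the explicit Hessian would have one entry per pair of parameters, which is why only the double-backward form is usable there.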

In [None]:
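import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reload the model (MODEL_NAME and model_path are defined in the previous cell)
regression_model = MoLFormerRegressor(MODEL_NAME, model_path).to(device)

ext_data = pd.read_csv("./data/external_dataset.csv")
original_data = pd.read_csv("./data/lipophilicity_dataset.csv")

ext_smiles = ext_data['SMILES'].tolist()
ext_targets = ext_data['Label'].tolist()

orig_smiles = original_data['SMILES'].tolist()
orig_targets = original_data['label'].tolist()

train_smiles, test_smiles, train_targets, test_targets = train_test_split(
    orig_smiles, orig_targets, test_size=0.2, random_state=42
)


def _batch_loss(model, batch):
    """MSE loss of the model on one batch."""
    predictions = model(batch['input_ids'].to(device), batch['attention_mask'].to(device)).view(-1)
    return criterion(predictions, batch['target'].to(device).view(-1))


def compute_mean_test_gradient(model, data_loader):
    """Gradient of the mean test loss with respect to all model parameters."""
    model.eval()
    params = list(model.parameters())
    total = [torch.zeros_like(p) for p in params]
    n = 0
    for batch in tqdm(data_loader, desc="Computing test gradient"):
        grads = torch.autograd.grad(_batch_loss(model, batch), params, allow_unused=True)
        bs = batch['target'].numel()
        for t, g in zip(total, grads):
            if g is not None:
                t.add_(g, alpha=bs)  # criterion averages over the batch, so re-weight by batch size
        n += bs
    return [t / n for t in total]


# Approximate the inverse Hessian-vector product (iHVP) H^{-1} v using LiSSA
def lissa_inverse_hvp(v, model, data_loader, recursion_depth=200, damping=0.01, scale=25.0):
    """Recursion h_j = v + (1 - damping) * h_{j-1} - (H h_{j-1}) / scale; returns h_J / scale.

    H is the Hessian of the training loss, estimated on one mini-batch per step.
    recursion_depth, damping and scale are hyperparameters that need tuning: scale must
    exceed the largest Hessian eigenvalue for the recursion to converge.
    """
    model.eval()
    params = list(model.parameters())
    v = [t.to(device) for t in v]
    h = [t.clone() for t in v]
    batches = iter(data_loader)

    for _ in tqdm(range(recursion_depth), desc="LiSSA"):
        try:
            batch = next(batches)
        except StopIteration:  # Restart the loader once it is exhausted
            batches = iter(data_loader)
            batch = next(batches)

        # First-order gradients, keeping the graph so they can be differentiated again
        grads = torch.autograd.grad(_batch_loss(model, batch), params, create_graph=True, allow_unused=True)
        used = [i for i, g in enumerate(grads) if g is not None]

        # Hessian-vector product H h via double backward (uses the current estimate h, not v)
        hv_used = torch.autograd.grad(
            [grads[i] for i in used], [params[i] for i in used],
            grad_outputs=[h[i] for i in used], allow_unused=True,
        )
        hv = [torch.zeros_like(p) for p in params]
        for i, t in zip(used, hv_used):
            if t is not None:
                hv[i] = t

        # Update the iHVP estimate
        with torch.no_grad():
            h = [v_i + (1 - damping) * h_i - hv_i / scale for v_i, h_i, hv_i in zip(v, h, hv)]

    return [h_i / scale for h_i in h]


# Function to compute influence scores
def compute_influence_scores(model, ext_gradients_dir, train_loader, test_loader):
    """Score each external sample z by grad L_test^T H^{-1} grad L(z), i.e. -I_up,loss.

    A positive score means that adding (upweighting) z is predicted to reduce the test loss.
    s_test = H^{-1} grad L_test is computed once and reused for every external sample.
    """
    test_grad = compute_mean_test_gradient(model, test_loader)
    s_test = lissa_inverse_hvp(test_grad, model, train_loader)

    # Sort gradient files numerically so that score i lines up with row i of ext_data
    files = [f for f in os.listdir(ext_gradients_dir) if f.startswith("grad_") and f.endswith(".pth")]
    files.sort(key=lambda f: int(f[len("grad_"):-len(".pth")]))

    influence_scores = []
    for fname in tqdm(files, desc="Computing influence scores"):
        ext_grad = torch.load(os.path.join(ext_gradients_dir, fname), map_location=device)
        influence_scores.append(sum(torch.sum(g * s).item() for g, s in zip(ext_grad, s_test)))

    return influence_scores


# Build the training and test loaders (ExternalDataset and tokenizer come from the previous cell)
train_dataset = ExternalDataset(train_smiles, train_targets, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataset = ExternalDataset(test_smiles, test_targets, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
ext_gradients_path = "./outputs/gradients"
# Compute influence scores for the external dataset (Hessian from training data, gradient from test data)
influence_scores = compute_influence_scores(regression_model, ext_gradients_path, train_loader, test_loader)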

# Select top-k high-impact samples
sorted_indices = np.argsort(influence_scores)[::-1]  # Descending order
top_k = 20 # Select top 20 high-impact samples
selected_indices = sorted_indices[:top_k]

# Get the selected samples
selected_smiles = [ext_smiles[i] for i in selected_indices]
selected_targets = [ext_targets[i] for i in selected_indices]

# Combine selected external samples with the original training dataset
combined_smiles = train_smiles + selected_smiles
combined_targets = train_targets + selected_targets

# Create the combined dataset and DataLoader
combined_dataset = ExternalDataset(combined_smiles, combined_targets, tokenizer)
combined_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True)

# Fine-tune the model on the combined dataset
optimizer = torch.optim.Adam(regression_model.parameters(), lr=1e-4)
NUM_EPOCHS = 20

for epoch in range(NUM_EPOCHS):
    regression_model.train()
    epoch_loss = 0.0

    for batch in combined_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['target'].to(device)

        optimizer.zero_grad()
        predictions = regression_model(input_ids, attention_mask).view(-1)  # Shape [batch_size]
        loss = criterion(predictions, targets.view(-1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {epoch_loss / len(combined_loader):.4f}")

# Evaluate the fine-tuned model
regression_model.eval()
all_predictions = []
all_targets = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['target'].to(device)

        predictions = regression_model(input_ids, attention_mask).view(-1)  # Shape [batch_size], even for a batch of one
        all_predictions.extend(predictions.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

# Calculate evaluation metrics
mse = mean_squared_error(all_targets, all_predictions)
rmse = np.sqrt(mse)
mae = mean_absolute_error(all_targets, all_predictions)
r2 = r2_score(all_targets, all_predictions)

print(f"Mean Squared Error (MSE): {mse:.4f}")
print(f"Root Mean Squared Error (RMSE): {rmse:.4f}")
print(f"Mean Absolute Error (MAE): {mae:.4f}")
print(f"R² Score: {r2:.4f}")
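
When comparing against the baseline from `the MLM regression training notebook`, it helps to be precise about what each printed metric measures. A minimal sketch of the same four quantities computed from their definitions (the helper name `regression_metrics` is hypothetical); they should agree with the scikit-learn functions used above:

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """MSE, RMSE, MAE and R^2 computed from their definitions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    mse = np.mean(err ** 2)
    ss_res = np.sum(err ** 2)                      # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return {"mse": mse, "rmse": np.sqrt(mse), "mae": np.mean(np.abs(err)), "r2": 1.0 - ss_res / ss_tot}


print(regression_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))  # mse = mae = 1/3, rmse ≈ 0.577, r2 = 0.5
```

Lower MSE, RMSE and MAE and a higher R² than the baseline indicate that the selected external samples helped.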