## 0. Import libraries and enable GPU

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, clear_output
from IPython.core.display import HTML
import torch
import torch.nn as nn
import torch.optim as optim
import optuna
import random
import time
import re
import os
import gc
from torch.utils.data import DataLoader, TensorDataset
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score, roc_curve, auc, confusion_matrix, classification_report, roc_auc_score
from transformers import RobertaModel, RobertaTokenizerFast
from bertviz import head_view, model_view
from captum.attr import LayerIntegratedGradients, visualization
import pubchempy as pcp
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
import requests

# supress warnings:
import warnings
warnings.filterwarnings('ignore')  # NOTE(review): silences ALL warnings (e.g. sklearn zero-division, pandas deprecations) — consider narrowing

# base directory for all input/output files used in this notebook
# NOTE(review): this shadows the built-in `dir`; kept because later cells reference it
dir = '/home/yulia/Documents/'

In [None]:
# Restrict PyTorch to one GPU. CUDA_VISIBLE_DEVICES is only honoured if it is set before the
# CUDA runtime is first queried (torch.cuda.is_available()/device_count() below), so it must
# come first. If CUDA was already initialised earlier in this kernel session, restart the kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # set to the GPU ID you want to use

# verify CUDA and PyTorch compatibility
print("CUDA Available:", torch.cuda.is_available())
print("CUDA Version:", torch.version.cuda)
# report how many GPUs are visible after the restriction above
if torch.cuda.is_available():
    print(f"Number of available GPUs: {torch.cuda.device_count()}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Free memory in two steps: collect unreachable Python objects, then release the
# cached-but-unused blocks held by PyTorch's CUDA allocator.
for _free in (gc.collect, torch.cuda.empty_cache):
    _free()
del _free

## 1. Load, preprocess and split the data

The SMILES will be sourced from the Wolfram Alpha API.
We will send the requests to the API for the largest data set (tve).
Since many substrates are shared by the three laccases, we reuse the SMILES already obtained for the tve data (joined on the IUPAC name) and query the API only for the substrates that are still missing (this is faster than requesting every SMILES again).

Split the preprocessed data into training and test sets:
* training set: 80%
* test set: 20%

### Option 1: Load using Wolfram Alpha


```python
# Wolfram Alpha API credentials (app id): read from the environment instead of hard-coding
# the key in the notebook, e.g. `export WOLFRAM_APP_ID=...` before starting Jupyter.
app_id = os.environ.get('WOLFRAM_APP_ID', '')

def get_smiles(ids):
    """Query the Wolfram Alpha API for the SMILES string of a chemical.

    Parameters
    ----------
    ids : str
        Chemical name (the IUPAC name in this notebook) to look up.

    Returns
    -------
    str
        The plaintext of the first subpod of the second result pod (where the SMILES
        string is expected), or 'Did not work' if the request fails, the HTTP status is
        not 200, or the response does not have the expected structure.

    Raises
    ------
    RuntimeError
        If no API key is configured in the WOLFRAM_APP_ID environment variable.
    """
    if not app_id:
        raise RuntimeError('Set the WOLFRAM_APP_ID environment variable to query Wolfram Alpha.')

    # pass the query via `params` so requests URL-encodes it: IUPAC names contain spaces,
    # commas, brackets and '+' signs that would otherwise corrupt the query string
    params = {'input': ids + ' SMILES identifier', 'appid': app_id, 'output': 'json'}
    try:
        response = requests.get('https://api.wolframalpha.com/v2/query', params=params, timeout=30)
    except requests.RequestException:
        return 'Did not work'
    if response.status_code != 200:  # anything but success is reported the same way
        return 'Did not work'

    # the API response contains pods, the SMILES string is expected in the second one:
    try:
        pods = response.json().get("queryresult", {}).get("pods", [])
        return pods[1]['subpods'][0]['plaintext']
    except (ValueError, AttributeError, IndexError, KeyError, TypeError):
        return 'Did not work'



# load the Tve data and look up the SMILES of every substrate
# (one API call per distinct IUPAC name, in order of first appearance):
data_tve = pd.read_csv(os.path.join(dir, 'tve-smiles.csv'), encoding="utf-8", sep=';')
ids_tve = data_tve['IUPAC Name'].to_list()
smiles_by_name = {name: get_smiles(name) for name in dict.fromkeys(ids_tve)}
smiles_tve = [smiles_by_name[name] for name in ids_tve]
data_tve['SMILES'] = smiles_tve

# insert the result for N-(2,3,4,5,6-pentahydroxyhexyl)-3-(2,4-trihydroxyphenyl)propanamide manually:
data_tve.loc[data_tve['IUPAC Name']=='N-(2,3,4,5,6-pentahydroxyhexyl)-3-(2,4-trihydroxyphenyl)propanamide','SMILES'] = 'C1=CC(=C(C(=C1CCC(=O)NCC(C(C(C(CO)O)O)O)O)O)O)O'
data_tve = data_tve.drop_duplicates().reset_index(drop=True)

# one SMILES per IUPAC name, so the left merges below can never multiply rows:
smiles_lookup = data_tve[['IUPAC Name', 'SMILES']].drop_duplicates(subset='IUPAC Name')


data_mth = pd.read_csv(os.path.join(dir, 'mth-smiles.csv'), encoding="utf-8", sep=';')

# join the already obtained SMILES and fill in the gaps (one API call per missing name):
data_mth = data_mth.merge(smiles_lookup, how='left', on='IUPAC Name')
ids_mth = data_mth.loc[data_mth['SMILES'].isna(), 'IUPAC Name'].unique().tolist()
for ids in ids_mth:
    data_mth.loc[data_mth['IUPAC Name']==ids, 'SMILES'] = get_smiles(ids)

# replace the only missing value with 0 and cast to int; the result is assigned back because
# a chained `column.fillna(..., inplace=True)` has no effect under pandas copy-on-write
data_mth['Oxd'] = data_mth['Oxd'].fillna(0).round().astype('int64')

data_mth = data_mth.drop_duplicates().reset_index(drop=True)


data_bpu = pd.read_csv(os.path.join(dir, 'bpu-smiles.csv'), encoding="utf-8", sep=';')
data_bpu = data_bpu.merge(smiles_lookup, how='left', on='IUPAC Name')

# there should be no NaNs, but just in case:
ids_bpu = data_bpu.loc[data_bpu['SMILES'].isna(), 'IUPAC Name'].unique().tolist()
for ids in ids_bpu:
    data_bpu.loc[data_bpu['IUPAC Name']==ids, 'SMILES'] = get_smiles(ids)

data_bpu = data_bpu.drop_duplicates().reset_index(drop=True)


# ensure that each substrate in all sets has received a valid SMILES string, i.e. neither
# missing nor the 'Did not work' value that get_smiles returns on failure:
for set_name, frame in (('tve', data_tve), ('mth', data_mth), ('bpu', data_bpu)):
    failed = frame['SMILES'].isna() | (frame['SMILES'] == 'Did not work')
    assert not failed.any(), f"{set_name}: no SMILES for {frame.loc[failed, 'IUPAC Name'].tolist()}"
```

In [None]:
# Set to True to save the data frames with SMILES obtained via Option 1
# (these files are what Option 2 below reads).
SAVE_SMILES = False

if SAVE_SMILES:
    data_tve.to_csv(os.path.join(dir, 'f-tve-smiles.csv'), index=False, encoding='utf-8', sep=';')
    data_mth.to_csv(os.path.join(dir, 'f-mth-smiles.csv'), index=False, encoding='utf-8', sep=';')
    data_bpu.to_csv(os.path.join(dir, 'bpu-lac-smiles.csv'), index=False, encoding='utf-8', sep=';')

### Option 2: Load from the files

In [None]:
# Read the pre-computed SMILES tables (written by Option 1) for the three laccases.
data_tve, data_mth, data_bpu = (
    pd.read_csv(os.path.join(dir, csv_name), encoding="utf-8", sep=';')
    for csv_name in ('f-tve-smiles.csv', 'f-mth-smiles.csv', 'bpu-lac-smiles.csv')
)

### Split the data into train/test

In [None]:
# split each data set into training (80%) and test (20%) sets; stratify on the label so that
# the (imbalanced) class ratio is the same in both parts:
train_data_tve, test_data_tve, train_labels_tve, test_labels_tve = train_test_split(data_tve['SMILES'], data_tve['Oxd'], test_size=0.2, random_state=98765, stratify=data_tve['Oxd'])
train_data_mth, test_data_mth, train_labels_mth, test_labels_mth = train_test_split(data_mth['SMILES'], data_mth['Oxd'], test_size=0.2, random_state=9876, stratify=data_mth['Oxd'])
train_data_bpu, test_data_bpu, train_labels_bpu, test_labels_bpu = train_test_split(data_bpu['SMILES'], data_bpu['Oxd'], test_size=0.2, random_state=987, stratify=data_bpu['Oxd'])

## 2. Fine-tune a pre-trained transformer on the SMILES data

We will be using a pre-trained ChemBERTa (RoBERTa trained on SMILES) + binary classification top layer (with a dropout).
The model parameters (weights and biases) are tuned by minimizing a cross-entropy loss that is weighted to account for the class imbalance.

We will be selecting the best combination of the following parameters:

* `batch_size`
* `learning_rate`
* `dropout_rate`
* `patience`.


#### Key Steps for the hyperparameter tuning:
* create a custom ChemBERTa model class that wraps the RoBERTa model pre-trained on the `PubChem10M_SMILES_BPE_450k` dataset and adds a dropout layer and a binary classification layer on top
* define the loop for hyperparameter tuning using the 5-fold CV, which splits the training data into 5 folds, and for each of the 5 splits does the following:
    * preprocess SMILES Data (tokenize, create attention masks, convert to pt tensors)
    * train the model on the other 4 folds and compute the validation F1 score on the held-out 5th fold (averaged over the training epochs).



* ensure that fine-tuning implements early stopping by monitoring the validation loss and F1 score, and that the loss function is weighted by the class weights to account for the class imbalance in the data.

The best set of parameters is selected by an Optuna search (TPE sampler) over a discrete grid, using 5-fold CV and the validation F1 score averaged over all epochs and folds.

#### Train the final model and predict:
* rebuild the model with the best hyperparameters
* predict and evaluate on the test set.



#### Why combine dropout with early stopping?
Dropout is applied during the forward pass in training to regularize the network and helps to prevent overfitting by randomly setting a portion of neurons to zero during training, which forces the model to learn more robust features. 

Early stopping monitors the model's performance on the validation set during training (we track both the loss and the F1 score) and halts training when the validation performance stops improving or starts to degrade for a specified number of epochs (the *patience* parameter). It ensures that the model doesn't over-train and minimizes the risk of overfitting beyond the optimal point.

In [None]:
class CustomRoberta(nn.Module):
    """ChemBERTa (RoBERTa pre-trained on PubChem SMILES) with a dropout + linear 2-class head.

    The backbone, dropout and classifier are moved to the global `device` on construction.
    The tokenizer of the same checkpoint is kept on the instance so that `prepare_data`
    can turn SMILES strings into batched tensors.
    """

    def __init__(self, dropout_prob):
        """
        Parameters
        ----------
        dropout_prob : float
            Dropout probability applied to the pooled representation before the classifier.
        """
        super(CustomRoberta, self).__init__()
        self.checkpoint = 'seyonec/PubChem10M_SMILES_BPE_450k'
        # attentions / hidden states are requested so that forward() can return them for visualization
        self.roberta = RobertaModel.from_pretrained(self.checkpoint, output_attentions=True, output_hidden_states=True).to(device)
        self.config = self.roberta.config
        self.current_embeddings = self.roberta.get_input_embeddings()
        self.tokenizer = RobertaTokenizerFast.from_pretrained(self.checkpoint)
        self.dropout = nn.Dropout(dropout_prob).to(device)
        self.classifier = nn.Linear(self.roberta.config.hidden_size, 2).to(device)

    def prepare_data(self, data, labels, batch_size, shuffle, max_length=72):
        """Tokenize SMILES strings and wrap them with their labels in a DataLoader.

        Parameters
        ----------
        data : pandas.Series or sequence of str
            SMILES strings.
        labels : pandas.Series or array-like of 0/1
            Class indices; converted to int64, the dtype CrossEntropyLoss requires for
            class-index targets (labels read as floats would otherwise fail in the loss).
        batch_size : int
        shuffle : bool
        max_length : int, default 72
            Padding/truncation length. With 72 only one SMILES (len=163) is truncated;
            165 made CUDA run out of memory.

        Returns
        -------
        DataLoader yielding (input_ids, attention_mask, labels) batches.
        """
        tokenized = self.tokenizer(list(data), padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
        input_ids = tokenized['input_ids']
        attention_mask = tokenized['attention_mask']
        labels = torch.as_tensor(np.asarray(labels), dtype=torch.long)
        dataset = TensorDataset(input_ids, attention_mask, labels)

        def seed_worker(worker_id):
            # only invoked when the DataLoader uses worker processes (num_workers > 0)
            worker_seed = 42 + worker_id
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=None, worker_init_fn=seed_worker)

    def forward(self, input_ids, attention_mask=None):
        """Return (logits, attentions, hidden_states).

        logits has one row per input and 2 columns (unnormalized class scores);
        attentions and hidden_states are the per-layer outputs of the RoBERTa backbone.
        """
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output: the first (<s>) token's last hidden state passed through RoBERTa's
        # dense + tanh pooler (not the raw <s> vector).
        # NOTE(review): confirm the checkpoint ships pooler weights; otherwise they start random.
        pooled_output = outputs.pooler_output
        dropped_output = self.dropout(pooled_output)
        logits = self.classifier(dropped_output)

        return logits, outputs.attentions, outputs.hidden_states

In [None]:
# set seeds:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    # trade cuDNN autotuning speed for run-to-run reproducibility
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


def _cv_train_epoch(model, loader, optimizer, loss_fn):
    """Run one training epoch over `loader`; return the average (class-weighted) loss per example."""
    model.train()
    tot_loss, tot_samples = 0.0, 0
    for batch in loader:
        inputs, mask, labels = [b.to(device) for b in batch]
        logits, _, _ = model(inputs, mask)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss_fn returns a per-batch mean, so weight it by the batch size before averaging
        tot_loss += loss.item() * len(labels)
        tot_samples += len(labels)
    return tot_loss / tot_samples


def _cv_evaluate_epoch(model, loader, loss_fn, threshold=0.5):
    """Evaluate `model` on `loader`; return (average loss per example, F1 of class 1)."""
    model.eval()
    tot_loss, tot_samples = 0.0, 0
    all_labs, all_preds = [], []
    with torch.no_grad():
        for batch in loader:
            inputs, mask, labels = [b.to(device) for b in batch]
            logits, _, _ = model(inputs, mask)
            tot_loss += loss_fn(logits, labels).item() * len(labels)
            tot_samples += len(labels)
            # P(class 1) under the softmax that the cross-entropy loss is trained on
            probs = torch.softmax(logits, dim=1)[:, 1]
            all_preds.append((probs > threshold).long().cpu().numpy())
            all_labs.append(labels.cpu().numpy())
    return tot_loss / tot_samples, f1_score(np.concatenate(all_labs), np.concatenate(all_preds))


def select_model(trial, smiles, property):
    """Optuna objective: 5-fold cross-validated F1 of a freshly fine-tuned ChemBERTa.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Used to sample learning_rate, patience, batch_size and dropout_rate.
    smiles : pandas.Series
        SMILES strings of the training set; index aligned with `property`.
    property : pandas.Series
        Binary labels (0/1).

    Returns
    -------
    float
        Mean over the 5 folds of each fold's validation F1, where a fold's value is the
        mean validation F1 over all epochs trained in that fold.

    Raises
    ------
    RuntimeError
        If the hyperparameters cannot be sampled.
    """
    set_seed(42)

    # define the tuning grid:
    try:
        learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, step=5e-5)
        patience = trial.suggest_int('patience', 6, 12, step=1)
        batch_size = trial.suggest_int('batch_size', 8, 64, step=8)
        dropout_rate = trial.suggest_float('dropout_rate', 0.0, 0.6, step=0.1)
    except Exception as e:
        raise RuntimeError(f"Failed to suggest hyperparameters: {e}") from e

    print(f'Trial {trial.number+1}: lr = {learning_rate}, patience = {patience}, batch_size = {batch_size}, dropout = {dropout_rate}')

    # class weights for the loss adjustment:
    class_weights = torch.tensor(
                                 compute_class_weight(class_weight='balanced', classes=np.unique(property), y=property.values),
                                 dtype=torch.float
                                 ).to(device)

    # shuffle the (label-based) indices before splitting the set into folds:
    indices = property.index.to_list()
    np.random.shuffle(indices)
    n_folds = 5
    fold_size = len(property) // n_folds

    fold_f1s = []
    for fold in range(n_folds):

        clear_output(wait=True)
        start = fold * fold_size
        end = (fold + 1) * fold_size if fold != n_folds - 1 else len(property)  # last fold takes the remainder
        valid_idx = indices[start:end]

        valid_dt, valid_labs = smiles.loc[valid_idx], property.loc[valid_idx]
        train_dt, train_labs = smiles.drop(valid_idx), property.drop(valid_idx)

        # a fresh model per fold: reusing one model would start each fold from weights that
        # were already trained on this fold's validation data (leakage)
        model = CustomRoberta(dropout_prob=dropout_rate)
        train_loader = model.prepare_data(train_dt, train_labs, batch_size, shuffle=True)
        valid_loader = model.prepare_data(valid_dt, valid_labs, batch_size, shuffle=False)

        optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
        loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

        best_val_loss = float('inf')
        best_val_f1 = 0
        patience_counter = 0
        val_f1_hist = []

        for epoch in range(500):
            train_loss = _cv_train_epoch(model, train_loader, optimizer, loss_fn)
            print(f'CV iter {fold+1}/5, Epoch {epoch+1} \nAverage Training Loss: {train_loss}')
            scheduler.step()

            val_loss, val_f1 = _cv_evaluate_epoch(model, valid_loader, loss_fn)
            val_f1_hist.append(val_f1)
            print(f'Average Validation Loss per epoch: {val_loss}')
            print(f'Average Validation F1 per epoch: {val_f1}')

            # early stopping: reset patience whenever the loss OR the F1 reaches a new best;
            # each best is tracked separately so it never regresses
            improved = (val_loss < best_val_loss) or (val_f1 > best_val_f1)
            best_val_loss = min(best_val_loss, val_loss)
            best_val_f1 = max(best_val_f1, val_f1)
            patience_counter = 0 if improved else patience_counter + 1
            if patience_counter >= patience:
                print("Early stopping")
                break

        fold_f1s.append(np.mean(val_f1_hist))

        # release this fold's model before the next one is built
        del model, optimizer, scheduler, train_loader, valid_loader
        gc.collect()
        torch.cuda.empty_cache()

    return np.mean(fold_f1s)

### Fine-tune the model via Optuna

In [None]:
def tune_model(smiles, property, n_trials=75):
    """Search hyperparameters with Optuna (TPE sampler, seed 42) by maximizing `select_model`.

    Parameters
    ----------
    smiles : pandas.Series of SMILES strings (training set)
    property : pandas.Series of binary labels aligned with `smiles`
    n_trials : int, default 75

    Returns
    -------
    optuna.study.Study
        The finished study (see `best_params` / `best_value`).
    """
    study_ = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=42))
    # gc_after_trial runs garbage collection after each trial, freeing its model before the next
    study_.optimize(lambda trial: select_model(trial, smiles, property), n_trials=n_trials, gc_after_trial=True)

    # the objective is the cross-validated mean F1, not accuracy
    print("Best hyperparameters:", study_.best_params)
    print("Best mean validation F1:", study_.best_value)
    return study_

## 3. Rebuild the model using the optimal hyperparameters and make predictions 

Rebuild the model using the optimal hyperparameter values (without CV, using all available training data) and make predictions on the test data, compute the following metrics:
* accuracy, 
* precision, 
* recall, 
* F1 score, 
* ROC, 
* AUROC.


For the final model, we also visualize the attention (using BertViz) to better understand which structures the model pays attention to when making a prediction.

In [None]:
# rebuild the model with the best parameters:
def rebuild_model(train_data, train_labels, test_data, test_labels, study, checkpoint_path=None):
    """Retrain ChemBERTa on the full training set with the best Optuna hyperparameters and score the test set.

    No validation split is used. Early stopping monitors the training loss and training F1.
    The weights from the latest epoch in which either metric reached a new best are saved to
    `checkpoint_path`, and these are restored before predicting.

    Parameters
    ----------
    train_data, test_data : pandas.Series of SMILES strings
    train_labels, test_labels : pandas.Series of binary labels
    study : optuna.study.Study
        Its best_params must contain 'dropout_rate', 'batch_size', 'learning_rate' and 'patience'.
    checkpoint_path : str, optional
        Where the best weights are saved; defaults to <dir>/runs/best_model.pth.
        The directory is created if needed.

    Returns
    -------
    tuple
        (model, test logits, attentions, hidden states, test input ids, test attention mask,
        P(class 1) per test sample, true test labels). Attentions, hidden states, ids and mask
        come from the single test batch (the whole test set is one batch).
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join(dir, 'runs', 'best_model.pth')
    ckpt_dir = os.path.dirname(checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    set_seed(42)
    params = study.best_params

    best_model = CustomRoberta(dropout_prob=params['dropout_rate']).to(device)
    best_model.config.output_attentions = True
    best_model.config.output_hidden_states = True

    train_loader = best_model.prepare_data(train_data, train_labels, batch_size=params['batch_size'], shuffle=True)
    test_loader = best_model.prepare_data(test_data, test_labels, batch_size=len(test_data), shuffle=False)

    # loss weighted by the training class frequencies:
    class_weights = torch.tensor(
                                compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels.values),
                                dtype=torch.float
                                ).to(device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.AdamW(best_model.parameters(), lr=params['learning_rate'])
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    best_tloss = float('inf')
    best_tf1 = 0
    patience = params['patience']
    patience_counter = 0

    # train:
    for epoch in range(500):
        best_model.train()
        total_loss, total_samples = 0.0, 0
        all_labs, all_preds = [], []

        for batch in train_loader:
            inputs_train, mask_train, labels = [b.to(device) for b in batch]

            logits, _, _ = best_model(input_ids=inputs_train, attention_mask=mask_train)
            loss = loss_fn(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss_fn returns a per-batch mean; weight by batch size for a per-example average
            total_loss += loss.item() * len(labels)
            total_samples += len(labels)
            # P(class 1) under the softmax that the cross-entropy loss is trained on
            probs = torch.softmax(logits.detach(), dim=1)[:, 1]
            all_preds.append((probs > 0.5).long().cpu().numpy())
            all_labs.append(labels.cpu().numpy())

        train_f1 = f1_score(np.concatenate(all_labs), np.concatenate(all_preds))  # training F1 per epoch
        train_loss = total_loss / total_samples

        print(f'Epoch {epoch+1}, Training Loss: {train_loss}, Training F1: {train_f1}')
        scheduler.step()

        # early stopping: reset patience (and checkpoint) whenever the loss OR the F1 reaches a
        # new best; each best is tracked separately so it never regresses
        improved = (train_loss < best_tloss) or (train_f1 > best_tf1)
        best_tloss = min(best_tloss, train_loss)
        best_tf1 = max(best_tf1, train_f1)
        if improved:
            patience_counter = 0
            torch.save(best_model.state_dict(), checkpoint_path)  # overwrites any previous checkpoint
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping")
            break

    # restore the best weights and evaluate on the hold-out set:
    best_model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    best_model.eval()
    logits_, probs_, true_labels = [], [], []

    with torch.no_grad():
        for batch in test_loader:
            inputs, mask, labels = [b.to(device) for b in batch]
            logits, attentions, h_states = best_model(inputs, mask)
            probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()  # P(class 1)

            logits_.append(logits.cpu().numpy())
            probs_.append(probs)
            true_labels.append(labels.cpu().numpy())

    logits_ = np.concatenate(logits_)
    probs_ = np.concatenate(probs_)
    true_labels = np.concatenate(true_labels)

    return best_model, logits_, attentions, h_states, inputs, mask, probs_, true_labels

In [None]:
sns.set_style('darkgrid')  # seaborn style for the plots drawn below
from sklearn.metrics import auc  # NOTE(review): already imported in the first cell; redundant
def plot_roc(fpr, tpr, cm):
    """Plot the ROC curve (left) and the confusion-matrix heatmap (right) side by side.

    Parameters
    ----------
    fpr, tpr : array-like
        False/true positive rates, as returned by `sklearn.metrics.roc_curve`.
    cm : 2-D array of int
        Confusion matrix (rows = true class, columns = predicted class), as returned by
        `sklearn.metrics.confusion_matrix`.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))

    roc_auc = auc(fpr, tpr)

    # ROC curve with shaded area under it and the random-classifier diagonal
    ax1.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc, color=[0.72110727, 0.11649366, 0.2828143, 0.85])
    ax1.fill_between(fpr, 0, tpr, alpha=0.2, color=[0.95686275, 0.42745098, 0.2627451, 0.85])
    ax1.plot([0, 1], [0, 1], linestyle='--', color=[0.28742791, 0.41499423, 0.68512111, 0.85])

    ax1.tick_params(axis='both', labelsize=8)
    ax1.set_xlabel('False Positive Rate', fontsize=9)
    ax1.set_ylabel('True Positive Rate', fontsize=9)
    ax1.set_title('ROC Curve, AUC = %0.2f' % roc_auc, fontsize=11)
    ax1.legend(loc='lower right')

    # confusion matrix heatmap (to remove the colorbar, add: cbar=False)
    sns.heatmap(cm, annot=True, fmt='d', cmap='RdYlBu_r', alpha=0.85, annot_kws={"fontsize": 8}, ax=ax2)
    cbar = ax2.collections[0].colorbar
    cbar.ax.tick_params(labelsize=8)
    ax2.tick_params(axis='both', labelsize=8)
    ax2.set_xlabel('Predicted', fontsize=9)
    ax2.set_ylabel('True', fontsize=9)
    ax2.set_title('Confusion Matrix', fontsize=11)

    fig.tight_layout()  # keep the two panels' labels from overlapping
    # plt.show() renders with the inline backend; fig.show() only warns there
    plt.show()


# compute metrics:
def evaluate_model(true_labels, probs, threshold=0.5):
    """Print classification metrics and a report, then plot the ROC curve and confusion matrix.

    Parameters
    ----------
    true_labels : array-like of 0/1
    probs : array-like of float
        Predicted probability of class 1 per sample.
    threshold : float, default 0.5
        A sample is predicted positive when its probability is strictly greater than this.
    """
    y_true = np.asarray(true_labels).astype(int)  # ensure labels are integers
    y_prob = np.asarray(probs, dtype=float)
    y_pred = (y_prob > threshold).astype(int)

    # zero_division=0: report 0 (instead of warning) when a class is never predicted
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    print(f'Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1: {f1}, AUC: {roc_auc}')

    # labels=[0, 1] keeps the matrix 2x2 and the two target names valid even if a class is absent
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    report = classification_report(y_true, y_pred, labels=[0, 1], target_names=['Class 0', 'Class 1'], zero_division=0)
    print(report)

    plot_roc(fpr, tpr, cm)

## 4. Example use
Fine-tune, rebuild with optimized parameters and evaluate (for the f-tve data)

In [None]:
# define optuna study and optimize
study_tve = tune_model(train_data_tve, train_labels_tve)
final_model_tve, logits_tve, attentions_tve, h_states_tve, inputs_tve, mask_tve, probs_tve, true_labels_tve = rebuild_model(train_data_tve, train_labels_tve, test_data_tve, test_labels_tve, study_tve)
evaluate_model(true_labels_tve, probs_tve)

## 5. Attention and LIG analysis

### Attention Visualization (BertViz)


By setting `idx=14` we choose the 15th sample in the test set.


```python
# tokenize

def tokenize(smiles, tokenizer=None, max_length=180):
    """Tokenize SMILES string(s) and return the encoded batch (PyTorch tensors) on `device`.

    Parameters
    ----------
    smiles : str or list of str
    tokenizer : RobertaTokenizerFast, optional
        Defaults to the tokenizer of the 'seyonec/PubChem10M_SMILES_BPE_450k' checkpoint the
        model is built from, so this helper no longer depends on a global `tokenizer`
        that is only created in a later cell.
    max_length : int, default 180
        NOTE(review): the model is fine-tuned with max_length=72 in `prepare_data`, so
        positions beyond 72 are never seen during training.
    """
    if tokenizer is None:
        tokenizer = RobertaTokenizerFast.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
    return tokenizer(smiles, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device)

```

#### Key Aspects of `head_view`:


* Transformers have multiple attention heads in each layer. Each head learns to focus on different parts of the input text, capturing distinct patterns or relationships (e.g., coreference, syntax, or semantic connections). The head_view allows you to see the output of each head in every layer, displaying what each head "attends to."

* It visualizes the attention scores, which indicate how much each token attends to every other token in the input sequence. These scores are typically represented as heatmaps, with intensity showing the magnitude of attention.

* You can inspect specific layers (e.g., layer 3, head 5) to see attention patterns at different stages of processing. Early layers may focus on local and syntactic relationships, while later layers often capture broader, more semantic dependencies.

* The tokens of the input sequence are displayed along both axes of the attention heatmap.
Rows correspond to query tokens (the tokens doing the attending), while columns represent key tokens (the tokens being attended to).
This setup lets you see which tokens each query token attends to most strongly.



#### Practical Insights from `head_view`:

* Interpretation of Model Behavior:
By examining which tokens a specific token attends to, you can infer how the model processes relationships (e.g., syntactic dependencies like subject-verb or coreferences in pronouns).

* Understanding Linguistic Representation:
`head_view` reveals how attention is distributed, offering a window into how the model captures semantic and syntactic features.


In [None]:
def vis_attentions(idx, inputs, mask, attentions):
    """Render a BertViz head view of all layers' attentions for one sample of a batch.

    Parameters
    ----------
    idx : int
        Position of the sample in the batch.
    inputs : LongTensor
        Token ids, one row per sample.
    mask : tensor
        Attention mask, one row per sample (non-zero = real token).
    attentions : sequence of tensors, one per layer
        Indexed as [sample, head, query, key].

    Returns
    -------
    The HTML object produced by `head_view(..., html_action='return')`.

    NOTE(review): assumes right padding, i.e. the real tokens are the first mask.sum() positions.
    """
    tokenizer = RobertaTokenizerFast.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
    # only the selected sample is needed, so avoid converting/masking the whole batch
    tokens = np.array(tokenizer.convert_ids_to_tokens(inputs[idx].tolist()))

    sample_mask = mask[idx]
    msk = sample_mask.cpu().numpy()
    tokens_truncated = tokens[msk != 0]  # drop <pad> tokens
    n_tokens = int(msk.sum())

    # zero the attention paid to padded keys; shape (1, 1, seq_len) broadcasts over (head, query)
    key_mask = sample_mask.unsqueeze(0).unsqueeze(1)
    selected_attentions = tuple(
        (attn[idx] * key_mask).cpu()[:, :n_tokens, :n_tokens].unsqueeze(0)  # [1, heads, n, n]
        for attn in attentions
    )
    return head_view(selected_attentions, tokens_truncated, html_action='return')

In [None]:
vis_attentions(idx=14, inputs=inputs_tve, mask=mask_tve, attentions=attentions_tve)  # 0-based index 14 = 15th sample of the test set

### Visualize attention weights

In [None]:
idx = 14  # 0-based position of the inspected test-set sample (the 15th sample)

# derive the model geometry from the returned attentions instead of hard-coding it
# (one tensor per layer, indexed as [sample, head, query, key]):
num_layers = len(attentions_tve)
num_heads = attentions_tve[0].shape[1]

tokenizer = RobertaTokenizerFast.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
msk = mask_tve[idx].cpu().numpy()
# unpadded tokens of the selected sample. The widget callback uses this private name so it keeps
# working after `tokens_tve` is re-assigned (to per-sample token lists) in a later cell.
attn_tokens_tve = np.array(tokenizer.convert_ids_to_tokens(inputs_tve[idx].tolist()))[msk != 0]
tokens_tve = attn_tokens_tve
seq_len = len(attn_tokens_tve)


def plot_attention_weights(layer_idx, head_idx):
    """Plot one head's attention weights for sample `idx` as a heatmap.

    Rows are query tokens and columns are key tokens; only the unpadded
    `seq_len` x `seq_len` block is shown.
    """
    weights = attentions_tve[layer_idx][idx][head_idx].detach().cpu().numpy()[:seq_len, :seq_len]

    fig, ax = plt.subplots(figsize=(5, 4))
    sns.heatmap(weights, cmap='Spectral_r', cbar_kws={'shrink': 0.8}, ax=ax)
    ax.set_title(f"Attention Weights - Layer {layer_idx + 1}, Head {head_idx + 1}", fontsize=12)
    ax.set_xlabel('Tokens', fontsize=10)
    ax.set_ylabel('Tokens', fontsize=10)
    # seaborn draws cell i between i and i+1, so the labels go at the cell centres i + 0.5
    centres = np.arange(seq_len) + 0.5
    ax.set_xticks(centres)
    ax.set_xticklabels(attn_tokens_tve, fontsize=8, rotation=90)
    ax.set_yticks(centres)
    ax.set_yticklabels(attn_tokens_tve, fontsize=8, rotation=0)
    fig.tight_layout()
    plt.show()


# sliders for selecting layer and head (0-based); update only on release
layer_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=num_layers - 1,
    step=1,
    description='Layer:',
    continuous_update=False
)

head_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=num_heads - 1,
    step=1,
    description='Head:',
    continuous_update=False
)

# bind the sliders to the plotting function
interactive_plot = widgets.interactive(
    plot_attention_weights,
    layer_idx=layer_slider,
    head_idx=head_slider
)

display(interactive_plot)

### Visualize integrated gradients (f-tve example)

Define generic functions that compute the attributions for each substrate's (SMILES) oxidation prediction and store them in a list of `VisualizationDataRecord` objects.

In Captum's `LayerIntegratedGradients.attribute` method, the `target` parameter selects the model output (here: the logit index) for which the attributions are computed. Typically, this would be the index of the class label you're interested in, especially in a classification model. However, it does not necessarily have to be the true label.

For Classification Tasks:

If you're trying to explain the model's decision for a particular predicted class, you would set the target parameter to the index of that class (e.g., target=predicted_label).
If you want to compute attributions with respect to the true class (e.g., for post hoc analysis or debugging), you would set target=true_label.

In [None]:
def interpret(model, tokenized_smiles, mask, logits, target_label):
    """Compute Layer Integrated Gradients attributions of one class score w.r.t. the embedding layer.

    Parameters
    ----------
    model : CustomRoberta (any module whose forward returns the logits first)
    tokenized_smiles : LongTensor
        Token ids of one sample (1-D) or a batch (2-D).
    mask : tensor
        Attention mask with the same shape as `tokenized_smiles`.
    logits : unused
        Kept for backward compatibility with existing callers.
    target_label : int
        Index of the logit to explain (e.g. the predicted class).

    Returns
    -------
    (attributions, delta)
        Attributions with respect to the output of `model.roberta.embeddings` (one vector per
        token; sum over the last dim for per-token scores) and the convergence delta.
    """
    def wrapped_model(inputs, masks):
        # LIG needs a forward function that returns only the logits
        return model(inputs, masks)[0]

    # attribute to the embedding layer output: which token embeddings drove this prediction?
    lig = LayerIntegratedGradients(wrapped_model, model.roberta.embeddings)

    model.eval()
    model.zero_grad()

    # add a batch dimension for single samples
    input_indices = tokenized_smiles.unsqueeze(0) if tokenized_smiles.dim() == 1 else tokenized_smiles
    mask = mask.unsqueeze(0) if mask.dim() == 1 else mask
    # baseline: a sequence of the same length made only of <pad> tokens
    reference_indices = torch.full_like(input_indices, model.tokenizer.pad_token_id)

    attributions, delta = lig.attribute(inputs=input_indices,
                                        baselines=reference_indices,
                                        target=target_label,
                                        additional_forward_args=mask,
                                        n_steps=50,
                                        return_convergence_delta=True)

    return attributions, delta




vis_data_records = []


def add_attribs_to_vis(attribs, tokens, probs, pred_label, true_label, delta, vis_data_rec):
    """Append a captum VisualizationDataRecord for one sample to `vis_data_rec` and return it.

    Parameters
    ----------
    attribs : tensor
        LIG attributions, summed over dim 2 here to get one score per token position
        (assumed shape (1, seq_len, hidden) — TODO confirm against callers).
    tokens : sequence of str
        The sample's tokens, including <s>, </s> and any trailing <pad> tokens.
    probs : indexable
        probs[1] is stored as the predicted probability of class 1
        (NOTE(review): confirm callers pass a per-class probability vector).
    pred_label, true_label : int
    delta : convergence delta from LayerIntegratedGradients
    vis_data_rec : list
        Mutated in place and returned.

    Raises
    ------
    ValueError
        If `tokens` does not contain both '<s>' and '</s>'.
    """
    attribs = attribs.sum(dim=2).squeeze(0).cpu().detach().numpy()

    # keep only the content tokens between <s> and </s>; positions after </s> are padding,
    # so a plain [1:-1] slice would keep </s> and the <pad> positions
    tokens = list(tokens)
    start = tokens.index('<s>') + 1
    stop = tokens.index('</s>', start)
    content_tokens = tokens[start:stop]
    attribs = attribs[start:stop]

    # normalize to [-1, 1], guarding against an all-zero attribution vector
    max_abs = np.max(np.abs(attribs)) if attribs.size else 0.0
    if max_abs > 0:
        attribs = attribs / max_abs

    prob_pos = probs[1]
    vis_data_rec.append(visualization.VisualizationDataRecord(
                            attribs,
                            prob_pos,
                            pred_label,
                            true_label,
                            1,  # attributions explain the probability of the "Oxd=1" event
                            attribs.sum(),
                            # per-token list (not the joined SMILES string) so that captum pairs
                            # each token with its own attribution score
                            content_tokens,
                            delta))
    return vis_data_rec


In [None]:
# Binarise the positive-class probabilities at 0.5 and recover the token
# strings of every test input (including special and padding tokens).
threshold = 0.5
preds_tve = [1 if p > threshold else 0 for p in probs_tve]

tokenizer = RobertaTokenizerFast.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
tokens_tve = list(map(tokenizer.convert_ids_to_tokens, inputs_tve.tolist()))

In [None]:
valid_vis_data_records = []  # store only successful cases
# Pre-define so these names exist for the next cells even if every sample is skipped.
attributions_tve, delta_tve = None, None

for i in range(len(preds_tve)):
    try:
        # Cheap checks first, so Integrated Gradients only runs for samples that are kept.
        match = re.search(r'<s>(.*)</s>', ''.join(tokens_tve[i]))
        if not match:
            print(f"[Skipping] SMILES extraction failed for index {i}")
            continue
        smiles_restored = match.group(1)

        substrate_info = data_tve.loc[data_tve['SMILES'] == smiles_restored, 'Substrate Name']
        if substrate_info.empty:
            print(f"[Skipping] No substrate found for SMILES at index {i}")
            continue

        # Explain the predicted class. This runs inside the try so that one failing
        # sample is reported and skipped instead of aborting the whole loop.
        attributions_tve, delta_tve = interpret(final_model_tve, inputs_tve[i], mask_tve[i], probs_tve[i],
                                                target_label=preds_tve[i])

        # ensure word-attribution length matches (one attribution vector per token):
        n_attributions = attributions_tve.squeeze(0).shape[0]
        if len(tokens_tve[i]) != n_attributions:
            print(f"[Skipping] Length mismatch at index {i}: tokens={len(tokens_tve[i])}, attributions={n_attributions}")
            continue

        # add_attribs_to_vis stores probs[1] as the predicted probability of Oxd=1.
        # probs_tve[i] is that probability (it is what preds_tve thresholds);
        # logits_tve[i], which used to be passed here, are (by name) unnormalised scores.
        prob_pos = probs_tve[i]
        add_attribs_to_vis(attributions_tve, tokens_tve[i], [1 - prob_pos, prob_pos], preds_tve[i],
                           true_labels_tve[i], delta_tve, valid_vis_data_records)

    except Exception as e:
        print(f'Exception for case {i}: {str(e)}')
        continue


print("Visualizing attributions (Integrated Gradients)")
if valid_vis_data_records:
    _ = visualization.visualize_text(valid_vis_data_records)

In [None]:
def get_smiles_attribs(tokens, masks, attributions, vis_data_records, data):
    """
    Pair every SMILES token of each sample with its Integrated-Gradients attribution.

    Records are matched to samples by SMILES string (``record.raw_input_ids``),
    not by position. The attribution loop skips samples, so ``vis_data_records``
    can be shorter than ``tokens``. If several records share a SMILES, the first
    one is used.

    The records are expected to come from ``add_attribs_to_vis``, whose
    ``word_attributions`` have the first sequence position (``<s>``) removed
    (``[1:-1]``). Token position ``p`` therefore maps to attribution index ``p - 1``.

    Args:
        tokens: per-sample token lists (``tokenizer.convert_ids_to_tokens`` output,
            including ``<s>``, ``</s>`` and padding tokens).
        masks: per-sample attention-mask tensors (1 = real token).
        attributions: unused; kept so existing calls keep working.
        vis_data_records: Captum ``VisualizationDataRecord`` objects.
        data: DataFrame with ``'SMILES'`` and ``'Substrate Name'`` columns.

    Returns:
        (smiles_list, names_list, attr_list, tokens_list), with one entry per
        successfully processed sample:
        - the restored SMILES;
        - the substrate name;
        - the per-token attributions;
        - the matching tokens (special and padding tokens removed).
        Samples that fail are reported and skipped.
    """
    records_by_smiles = {}
    for record in vis_data_records:
        records_by_smiles.setdefault(record.raw_input_ids, record)

    special_tokens = ('<s>', '</s>')
    smiles_list, names_list, attr_list, tokens_list = [], [], [], []

    for i, sample_tokens in enumerate(tokens):
        try:
            match = re.search('<s>(.*)</s>', ''.join(sample_tokens))
            if match is None:
                raise ValueError('could not restore the SMILES from the tokens')
            smiles_restored = match.group(1)

            record = records_by_smiles.get(smiles_restored)
            if record is None:
                raise ValueError(f'no attribution record for SMILES {smiles_restored}')
            attrs = np.asarray(record.word_attributions)

            # Real (unpadded) positions, without the <s>/</s> markers.
            msk = masks[i].cpu().numpy()
            positions = [p for p in np.flatnonzero(msk == 1) if sample_tokens[p] not in special_tokens]
            # word_attributions lost position 0 (<s>), hence the shift by one.
            attr = [attrs[p - 1] for p in positions]
            tckn = [sample_tokens[p] for p in positions]

            names = data.loc[data['SMILES'] == smiles_restored, 'Substrate Name'].to_numpy()
            if len(names) == 0:
                raise ValueError(f'no substrate found for SMILES {smiles_restored}')

            smiles_list.append(smiles_restored)
            names_list.append(names[0])
            attr_list.append(attr)
            tokens_list.append(tckn)
        except Exception as e:
            print(f'Error processing index {i}: {str(e)}')
            continue
    return smiles_list, names_list, attr_list, tokens_list

Any error printed by the cell below means that one or several SMILES could not be processed and were skipped; this does not prevent the remaining results from rendering.

In [None]:
smiles_list_tve, names_list_tve, attr_list_tve, tokens_list_tve = get_smiles_attribs(
    tokens=tokens_tve,
    masks=mask_tve,
    attributions=attributions_tve,
    vis_data_records=valid_vis_data_records,
    data=data_tve,
)

In [None]:
def expand_to_character_level(tokens, attributions):
    """
    Convert token-level attributions to character level.

    - Tokens that are exactly a multi-character element (Cl, Br, ...) or a
      special symbol (brackets, bonds, ring digits, ...) are kept as they are.
    - Every other token is split into characters, each inheriting the token's
      attribution.
    - If the LAST character of such a token, together with the whole next token,
      forms a multi-character element (e.g. 'C' + 'l' -> 'Cl'), the two are fused.
      The fused element takes the next token's attribution, and both tokens are
      consumed.

    Args:
        tokens: list of tokens (e.g., ['CC', 'Cl', '(', 'Br', ...])
        attributions: attribution values, one per token

    Returns:
        Tuple of numpy arrays (expanded_elements, expanded_attributions).
    """
    # Elements to keep as single tokens
    multi_elements = {'Cl', 'Br', 'Si', 'Se', 'Te', 'As', 'Pt',
                      'Fe', 'Al', 'Zn', 'Li', 'Na', 'Ca', 'Ba',
                      'Cu', 'Ag', 'Au', 'Sn', 'Mg'}

    # Special symbols to keep as single tokens
    special_symbols = {'(', ')', '[', ']', '=', '#', '@', '/', '\\',
                       '-', '+', ':', '.', '%', '0', '1', '2', '3',
                       '4', '5', '6', '7', '8', '9'}

    expanded_elements = []
    expanded_attrs = []

    i = 0
    while i < len(tokens):
        token = tokens[i]

        # Cases 1 and 2: whole-token element (e.g. 'Cl') or special symbol.
        if token in multi_elements or token in special_symbols:
            expanded_elements.append(token)
            expanded_attrs.append(attributions[i])
            i += 1
            continue

        # Case 3: split the token into characters (e.g. 'CCCC').
        for char in token[:-1]:
            expanded_elements.append(char)
            expanded_attrs.append(attributions[i])
        if token:
            last = token[-1]
            # Only the last character can continue into the next token.
            if i + 1 < len(tokens) and last + tokens[i + 1] in multi_elements:
                expanded_elements.append(last + tokens[i + 1])
                expanded_attrs.append(attributions[i + 1])
                i += 2  # both tokens consumed
                continue
            expanded_elements.append(last)
            expanded_attrs.append(attributions[i])
        i += 1

    return np.array(expanded_elements), np.array(expanded_attrs)



# Expand every processed sample to character level. Iterate over the lists that
# get_smiles_attribs actually returned (it may skip samples), not over every
# tokenized test input, so the results stay index-aligned with names_list_tve.
expanded_tokens_list_tve = []
expanded_attrs_list_tve = []
for i, (sample_tokens, sample_attrs) in enumerate(zip(tokens_list_tve, attr_list_tve)):
    try:
        expanded_tokens_tve, expanded_attrs_tve = expand_to_character_level(sample_tokens, sample_attrs)
    except Exception as e:
        print(f'Error processing index {i}: {str(e)}')
        # Empty placeholder keeps later indices aligned with the other per-sample lists.
        expanded_tokens_tve, expanded_attrs_tve = np.array([]), np.array([])
    expanded_tokens_list_tve.append(expanded_tokens_tve)
    expanded_attrs_list_tve.append(expanded_attrs_tve)


In [None]:
def visualize_multiple_smiles_attributions(smiles_list, attributions_list, substrate_labels=None, figsize=(10, 14)):
    """
    Draw one row per SMILES, with every token colored by its attribution, a label
    on the right and one shared colorbar at the bottom.

    Colors use a single RdYlBu scale from the minimum to the maximum attribution
    over all rows, so rows are comparable: the lowest values are red and the
    highest blue. Yellow marks the middle of that range, which coincides with
    zero only when the range is symmetric.

    Args:
        smiles_list: list of token sequences, one per molecule
            (e.g., [['C', 'Cl', '(', ...], ['C', '=', 'O', ...]]).
        attributions_list: list of attribution sequences aligned with smiles_list.
        substrate_labels: optional list of labels shown to the right of each row
            (default: no labels).
        figsize: figure size (width, height).

    Returns:
        None. The figure is shown with plt.show(). Nothing is drawn (a message is
        printed instead) when there are no rows or no attribution values.
    """
    n_smiles = len(smiles_list)
    if n_smiles == 0:
        print('No SMILES to visualize.')
        return

    # One shared scale for all rows so colors are comparable between molecules.
    all_attributions = (np.concatenate([np.ravel(a) for a in attributions_list])
                        if len(attributions_list) else np.array([]))
    if all_attributions.size == 0:
        print('No attribution values to visualize.')
        return
    norm = plt.Normalize(vmin=all_attributions.min(), vmax=all_attributions.max())

    fig = plt.figure(figsize=figsize)
    gs = plt.GridSpec(n_smiles, 1)  # one row per SMILES
    axes = [fig.add_subplot(gs[i]) for i in range(n_smiles)]

    # The colorbar lives in its own thin axis along the bottom edge.
    cbar_width = 0.45
    cbar_ax = fig.add_axes([0.05, 0.01, cbar_width, 0.01])

    if substrate_labels is None:
        substrate_labels = [''] * n_smiles  # Default to no label

    cmap = plt.cm.RdYlBu  # low (negative) values -> red, high (positive) -> blue
    for ax, tokens, attributions, label in zip(axes, smiles_list, attributions_list, substrate_labels):
        x_pos = 0

        # Plot colored SMILES tokens
        for token, attr in zip(tokens, attributions):
            ax.text(x_pos, 0.5, token, fontsize=10, ha='left', va='center',
                    bbox=dict(boxstyle='round,pad=0.1', facecolor=cmap(norm(attr)), alpha=0.85))
            x_pos += len(token) * 0.02 + 0.02  # advance proportionally to the token length

        # Substrate label on the right
        ax.text(x_pos + 0.5, 0.5, label, fontsize=9, ha='center', va='center', color='black')
        ax.set_xlim(-0.05, x_pos + 1.0)  # extra space for the label
        ax.set_ylim(0, 0.65)
        ax.axis('off')

    # Shared colorbar in the dedicated axis
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal', label='Attribution Score')
    cbar.ax.tick_params(labelsize=9)  # smaller font for the colorbar

    fig.suptitle("SMILES Attributions", fontsize=14, y=1.02)
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.06)
    plt.show()

visualize_multiple_smiles_attributions(
    smiles_list=expanded_tokens_list_tve,
    attributions_list=expanded_attrs_list_tve,
    substrate_labels=names_list_tve,
    figsize=(12, 14),
)

#### Mapping attributions from SMILES to the molecular graph 

The code below attempts to approximate the mapping of SMILES attributions onto their corresponding substructures in the molecular
graph. While this mapping is not perfectly precise, it offers a valuable visual aid for identifying which chemical motifs the
model considers influential in determining whether a substrate can be oxidized by the given laccase.



##### 1. SMILES Tokenization & Mapping to Molecular Graph

**Function:** `smiles_to_mol_mapping()`
 - maps SMILES characters to RDKit atoms/bonds.


**Key Steps:**

* RDKit Molecule Creation:

```python
mol = Chem.MolFromSmiles(smiles)  # Converts SMILES to RDKit Mol object
```

* Character-to-Atom Mapping (highlights):
    - Bracketed Atoms (e.g., [Na]): Treated as single tokens.
    - Alphabetic Characters (e.g., C, N): Mapped to atoms, handling two-letter symbols (e.g., Cl).
    - Bonds (=, #, etc.): Tracked but not directly mapped to RDKit bonds yet.
    - Branches (`(`, `)`) carry no atom and are skipped; ring-closure digits are recorded in `char_to_bond` as `ring_<digit>`.

**Output:**

* `char_to_atom`: Dict mapping SMILES character positions to RDKit atom indices.
* `char_to_bond`: Dict mapping SMILES positions to bond symbols (e.g., =, #).

##### 2. Attribution Score Processing

**Function:** `score_to_color()`
 - normalizes scores and assigns colors (red = low/negative, yellow = intermediate, blue = high/positive).

**Key Steps:**

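* Robust normalization: the color scale spans the 1st–99th percentiles of the attributions of all molecules (`expanded_attrs_list_tve`). This limits the influence of outliers and keeps colors comparable across molecules: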

```python
vmin, vmax = np.percentile(all_scores_flat, [1, 99])
```

* Colormap: `RdYlBu` (Red-Yellow-Blue) for intuitive visualization.


##### 3. Mapping Attributions to Molecular Substructures

**Function:** `visualize_attribution_scores_improved()`
- aggregates token-level scores to atoms/bonds in the molecular graph.

* Atom-Level Mapping:
For each atom, average scores of all SMILES characters mapped to it (score aggregation):

```python
atom_scores[atom_idx].append(attribution_scores[char_pos])
```

* Thresholding: atoms and bonds whose absolute score does not exceed `min_threshold` (default 0.01) are not highlighted, which reduces noise.

* Bond-Level Mapping (Complex):
    - Challenge: SMILES bond symbols (e.g., =) don’t directly correspond to RDKit bond indices.
    - Workaround: Checks if a bond character lies between two atoms in the SMILES string:

```python
if min(bpos, epos) < char_pos < max(bpos, epos):
    bond_highlights[bond.GetIdx()] = ...
```

##### 4. Visualization

**Functions:** `visualize_attribution_scores_improved()`, `create_smiles_attribution_comparison()`

**Outputs:**

* SMILES String: Colored by character-level attributions (left panel).
* Molecular Graph: Atoms/bonds highlighted based on aggregated scores (right panel).


**RDKit Rendering:**

* Uses `MolDraw2DCairo` for high-resolution rendering.
* Atom highlights are circles of a fixed radius (0.4); the score is encoded by their color only.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from PIL import Image
import io
plt.rcParams['figure.dpi'] = 800

def score_to_color(score, all_scores, alpha=0.85):
    """
    Map an attribution score to an RGBA color on the RdYlBu colormap
    (low scores -> red, intermediate -> yellow, high -> blue).

    The color scale is NOT derived from ``all_scores``. It spans the 1st-99th
    percentiles of every value in the notebook-level ``expanded_attrs_list_tve``
    (all molecules), so colors are comparable across molecules. ``all_scores``
    only acts as a switch: when it is None or empty, a neutral gray is returned.

    Args:
        score: attribution value to color.
        all_scores: scores of the current molecule (only checked for emptiness).
        alpha: opacity of the gray fallback; colormap colors are returned as-is.

    Returns:
        RGBA tuple.
    """
    fallback_gray = (0.8, 0.8, 0.8, alpha)
    if all_scores is None or len(all_scores) == 0:
        return fallback_gray

    # Pool the attributions of every molecule into one flat array.
    pooled = np.array([value for per_molecule in expanded_attrs_list_tve for value in per_molecule])
    if len(pooled) == 0:
        return fallback_gray

    # Percentile limits keep a few extreme values from flattening the scale.
    low, high = np.percentile(pooled, [1, 99])
    if high == low:
        high = low + 1e-6  # avoid a zero-width range

    scaled = np.clip((score - low) / (high - low), -1, 1)
    return plt.cm.RdYlBu(scaled)

def smiles_to_mol_mapping(smiles):
    """
    Parse a SMILES string with RDKit and map SMILES character positions to atoms.

    Atom indices are assigned in order of appearance in the string, which
    assumes RDKit numbers atoms the same way. NOTE(review): explicit hydrogens
    written as ``[H]`` may be removed by RDKit's parser, which would shift the
    later indices — confirm if such SMILES occur.

    Returns:
        (mol, char_to_atom, char_to_bond):
        - mol: the RDKit molecule.
        - char_to_atom: {char position: atom index}. Every character of a bracket
          atom (e.g. '[NH3+]') and both letters of 'Cl'/'Br' map to the same atom.
        - char_to_bond: {char position: symbol} for '=', '-', '#', ':', '.', and
          {char position: 'ring_<digit>'} for ring-closure digits.

    Raises:
        ValueError: if RDKit cannot parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        raise ValueError(f"Invalid SMILES string: {smiles}")

    # Outside brackets only these two-letter symbols are atoms. Any other pair
    # such as 'cc', 'Cc' or 'Cn' is two separate atoms (lowercase = aromatic atom).
    two_letter_atoms = ('Cl', 'Br')

    try:
        char_to_atom = {}
        char_to_bond = {}
        atom_idx = 0
        i = 0

        while i < len(smiles):
            char = smiles[i]

            # Bracketed atom [...]: all of its characters belong to one atom.
            if char == '[':
                bracket_end = smiles.find(']', i) + 1
                if bracket_end > i:
                    for j in range(i, bracket_end):
                        char_to_atom[j] = atom_idx
                    atom_idx += 1
                    i = bracket_end
                else:
                    i += 1

            # Organic-subset atom (uppercase or aromatic lowercase).
            elif char.isalpha():
                char_to_atom[i] = atom_idx
                if smiles[i:i + 2] in two_letter_atoms:
                    char_to_atom[i + 1] = atom_idx
                    i += 2
                else:
                    i += 1
                atom_idx += 1

            # Explicit bond symbols.
            elif char in '=-#:.':
                char_to_bond[i] = char
                i += 1

            # Ring closure digit.
            elif char.isdigit():
                char_to_bond[i] = f'ring_{char}'
                i += 1

            # Branch parentheses, stereo marks and other symbols carry no atom.
            else:
                i += 1

        return mol, char_to_atom, char_to_bond

    except Exception as e:
        print(f"Error in SMILES parsing: {e}")
        return mol, {}, {}

def visualize_attribution_scores_improved(subname, smiles, attribution_scores, 
                                        min_threshold=0.01, debug=False):
    """
    Render the molecule of ``smiles`` with atoms and, heuristically, bonds
    colored by character-level attribution scores.

    Args:
        subname: substrate name (currently not drawn; the legend is left empty).
        smiles: SMILES string to draw.
        attribution_scores: one score per SMILES character. Shorter inputs are
            zero-padded and longer ones truncated, with a printed warning.
        min_threshold: atoms/bonds whose absolute score does not exceed this
            value are not highlighted.
        debug: print the mappings and draw atom indices on the molecule.

    Returns:
        PIL image of the 800x800 RDKit drawing.
    """
    # Ensure attribution scores match SMILES length
    if len(attribution_scores) != len(smiles):
        print(f"Warning: Length mismatch - SMILES: {len(smiles)}, Scores: {len(attribution_scores)}")
        if len(attribution_scores) < len(smiles):
            attribution_scores = list(attribution_scores) + [0.0] * (len(smiles) - len(attribution_scores))
        else:
            attribution_scores = attribution_scores[:len(smiles)]

    mol, char_to_atom, char_to_bond = smiles_to_mol_mapping(smiles)

    if debug:
        print(f"SMILES: {smiles}")
        print(f"Attribution scores: {attribution_scores}")
        print(f"Char to atom mapping: {char_to_atom}")
        print(f"Char to bond mapping: {char_to_bond}")

    n_scores = len(attribution_scores)

    # Collect the scores of each atom and the SMILES positions belonging to it.
    atom_scores = {}
    atom_positions = {}
    for char_pos, atom_idx in char_to_atom.items():
        atom_positions.setdefault(atom_idx, []).append(char_pos)
        if char_pos < n_scores:
            atom_scores.setdefault(atom_idx, []).append(attribution_scores[char_pos])

    # Average scores for each atom
    atom_highlights = {}
    atom_radii = {}
    for atom_idx, scores in atom_scores.items():
        avg_score = np.mean(scores)
        if abs(avg_score) > min_threshold:
            atom_highlights[atom_idx] = [score_to_color(avg_score, attribution_scores)]
            atom_radii[atom_idx] = 0.4  # fixed radius; only the color encodes the score

    # Bonds: SMILES bond characters ('=', '#', ring digits, ...) have no direct
    # RDKit bond index. A bond is therefore highlighted when the character lies
    # strictly between the SMILES positions of its two atoms. This is a heuristic
    # and can match several bonds (e.g. around long ring closures).
    bond_highlights = {}
    for char_pos in char_to_bond:
        if char_pos >= n_scores:
            continue
        score = attribution_scores[char_pos]
        if not abs(score) > min_threshold:
            continue
        for bond in mol.GetBonds():
            begin_positions = atom_positions.get(bond.GetBeginAtomIdx(), [])
            end_positions = atom_positions.get(bond.GetEndAtomIdx(), [])
            if any(min(bpos, epos) < char_pos < max(bpos, epos)
                   for bpos in begin_positions for epos in end_positions):
                bond_highlights[bond.GetIdx()] = [score_to_color(score, attribution_scores)]

    # Create the drawing
    drawer = rdMolDraw2D.MolDraw2DCairo(800, 800)
    drawer.drawOptions().addAtomIndices = debug  # Show atom indices in debug mode
    drawer.drawOptions().bondLineWidth = 2
    drawer.drawOptions().highlightBondWidthMultiplier = 2

    drawer.DrawMoleculeWithHighlights(
        mol,
        '',  # legend left empty (subname is not drawn)
        atom_highlights,
        bond_highlights,
        atom_radii,
        {}
    )
    drawer.FinishDrawing()

    # Convert to PIL Image
    bio = io.BytesIO(drawer.GetDrawingText())
    img = Image.open(bio)

    if debug:
        print(f"Highlighted atoms: {list(atom_highlights.keys())}")
        print(f"Atom scores: {atom_scores}")

    return img

def create_smiles_attribution_comparison(subname, smiles, attribution_scores, debug=False):
    """
    Build a side-by-side figure: the SMILES string with per-character colors
    (left) and the molecular graph with atom/bond highlights (right).

    Args:
        subname: substrate name used in both panel titles.
        smiles: SMILES string.
        attribution_scores: one score per SMILES character. Missing scores are
            treated as 0.0, so every character is drawn; extra scores are ignored.
        debug: passed on to visualize_attribution_scores_improved.

    Returns:
        The matplotlib Figure (not shown; the caller calls plt.show()).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left side: SMILES with character-level coloring
    ax1.set_title(f'{subname}\nSMILES Attribution', fontsize=12)
    ax1.axis('off')

    # Pad with 0.0, as the molecule panel does, so no trailing character is dropped.
    padded_scores = list(attribution_scores)[:len(smiles)]
    padded_scores += [0.0] * (len(smiles) - len(padded_scores))
    char_colors = [score_to_color(score, attribution_scores) for score in padded_scores]

    for i, (char, color) in enumerate(zip(smiles, char_colors)):
        ax1.text(i * 0.05, 0.75, char, fontsize=12, ha='center', va='center',
                 bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.85))

    ax1.set_xlim(-0.05, len(smiles) * 0.05)
    ax1.set_ylim(0, 1)

    # Right side: Molecular graph with attribution
    ax2.set_title(f'{subname}\nMolecular Graph', fontsize=12)
    ax2.axis('off')
    mol_img = visualize_attribution_scores_improved(subname, smiles, attribution_scores, debug=debug)
    ax2.imshow(mol_img)

    plt.tight_layout()
    return fig

# Additional debugging function
def debug_attribution_mapping(smiles, attribution_scores):
    """
    Print how every SMILES character maps to a score, an atom index and a bond
    label, followed by the per-atom aggregation (positions, scores, mean).

    Characters beyond the end of ``attribution_scores`` are shown with score 0.0
    and are left out of the aggregation.
    """
    mol, char_to_atom, char_to_bond = smiles_to_mol_mapping(smiles)
    n_scores = len(attribution_scores)

    header = [
        "=== ATTRIBUTION MAPPING DEBUG ===",
        f"SMILES: {smiles}",
        f"Length: {len(smiles)}",
        f"Attribution scores: {attribution_scores}",
        f"Scores length: {n_scores}",
        "",
        "Character-by-character analysis:",
    ]
    for line in header:
        print(line)

    for pos, symbol in enumerate(smiles):
        value = attribution_scores[pos] if pos < n_scores else 0.0
        print(f"  {pos:2d}: '{symbol}' -> Score: {value:6.3f}, "
              f"Atom: {char_to_atom.get(pos)}, Bond: {char_to_bond.get(pos)}")

    print()
    print("Atom aggregation:")

    per_atom = {}
    for pos, atom in char_to_atom.items():
        if pos < n_scores:
            per_atom.setdefault(atom, []).append((pos, attribution_scores[pos]))

    for atom, pairs in per_atom.items():
        positions = [p for p, _ in pairs]
        values = [v for _, v in pairs]
        print(f"  Atom {atom}: positions {positions}, "
              f"scores {values}, avg: {np.mean(values):.3f}")

In [None]:
# Example usage: inspect a few hand-picked test molecules.
EXAMPLE_INDICES = (2, 5, 7, 16)

# Earlier cells may keep fewer samples than expected (skipped SMILES), so skip
# out-of-range indices instead of raising IndexError.
n_available = min(len(names_list_tve), len(smiles_list_tve), len(expanded_attrs_list_tve))
test_cases = [
    {
        'name': names_list_tve[k],
        'smiles': smiles_list_tve[k],
        'attribution_scores': expanded_attrs_list_tve[k],
    }
    for k in EXAMPLE_INDICES
    if k < n_available
]

for test_case in test_cases:
    print(f"\n{'='*50}")
    print(f"Testing: {test_case['name']}")
    print(f"{'='*50}")

    # Debug the mapping
    debug_attribution_mapping(test_case['smiles'], test_case['attribution_scores'])

    # Create visualization
    fig = create_smiles_attribution_comparison(
        test_case['name'],
        test_case['smiles'],
        test_case['attribution_scores'],
        debug=True
    )
    plt.show()
    plt.close(fig)  # release the figure; this cell may create several

## 6. Augment the data by generating alternative SMILES representations

We augment the training data with alternative, non-canonical SMILES representations and retrain the model to see whether this strategy improves its performance, stability and generalization.
We will be generating different non-canonical SMILES representations using the RDKit package.

#### Justification
SMILES representations are not unique: the same molecule can be written as several valid SMILES strings. These strings encode the same chemical structure (atoms and covalent bonds) but differ when treated as text.
We use this property to enrich the dataset: each substrate is encoded by several different SMILES representations, and exact duplicates are removed. The rationale is explained below.


#### Treatment of Special Symbols in SMILES
In SMILES, special characters represent bonds and molecular structure:

* `=` represents a double bond.
* `#` represents a triple bond.
* `@` indicates stereochemistry (e.g., `@` or `@@` for chirality).
* `/` and `\` indicate the stereochemistry of double bonds (Z/E or cis/trans configurations).

When you input a SMILES string (e.g., `C=C`, `C#C`, `C@C`, `C/C=C/C`), the ChemBERTa tokenizer processes it just like any other text. Here's how:

Tokenization:

The SMILES string is tokenized into sub-units based on the training corpus. ChemBERTa uses a BPE tokenizer to split the input into tokens (for instance, `C`, `=`, `C`, or `C`, `@`, `C` as separate tokens).
The tokenizer learns which symbols or combinations of symbols represent meaningful information by seeing many examples of SMILES strings during training.
Handling of Special Characters:

Special bond characters like `=`, `#`, `@`, `/`, `\` are treated as part of the tokenization process. They are either represented by their own tokens or grouped with nearby atoms if that's how the tokenizer was trained to split SMILES strings.
For instance, in the SMILES string `C=C`, the tokenizer might treat `C`, `=`, and `C` as separate tokens. In more complex cases (like stereochemistry with `/` and `\`), tokens might be split differently, depending on how the tokenizer was trained.
Learned Representations:

The model doesn't understand the chemical meaning of these symbols inherently. Instead, it learns from the data to associate certain patterns (e.g., `C=C`) with particular molecular properties or activities through training on labeled datasets (e.g., for classification or regression tasks).
ChemBERTa learns embeddings (dense vector representations) for each token, including the ones representing bonds, and then uses these embeddings as part of its predictive mechanism.

In [None]:
def augment(data, n_variants=10):
    """
    Add ``n_variants`` randomized (non-canonical) SMILES columns to ``data`` in place.

    Column ``'SMILES{k+2}'`` ('SMILES2' ... 'SMILES11' by default) holds one
    randomized SMILES per row. It is produced by shuffling the RDKit atom order
    and writing a non-canonical SMILES of the renumbered molecule. The shuffle
    uses Python's ``random`` seeded with ``123*k``, so the augmentation is
    reproducible. ``MolToSmiles(doRandom=True)`` would draw from RDKit's own
    generator, which ``random.seed`` does not affect.

    Rows whose SMILES is missing or cannot be parsed get None in every new column.

    Args:
        data: DataFrame with a 'SMILES' column; modified in place.
        n_variants: number of randomized columns to add (default 10).

    Returns:
        The same DataFrame object.
    """
    # Parse every molecule once instead of once per variant.
    mols = [Chem.MolFromSmiles(s) if isinstance(s, str) else None
            for s in data['SMILES'].tolist()]

    for i in range(n_variants):
        random.seed(123 * i)

        smiles_ = []
        for mol in mols:
            if mol is None:
                smiles_.append(None)  # handle invalid SMILES
                continue
            atom_order = list(range(mol.GetNumAtoms()))
            random.shuffle(atom_order)
            shuffled = Chem.RenumberAtoms(mol, atom_order)
            smiles_.append(Chem.MolToSmiles(shuffled, canonical=False))
        data[f'SMILES{i+2}'] = smiles_

    return data

In [None]:
def data_split(data, train_ind):
    '''
    Split the augmented data into train and test sets and reshape the training set.

    The training rows (``train_ind``, the index labels of the original,
    non-augmented split) are melted from wide to long format, so every SMILES
    column ('SMILES', 'SMILES2', ...) becomes its own training example. Duplicate
    (label, SMILES) pairs and missing SMILES (invalid molecules in ``augment``)
    are removed. The test set keeps only the original 'SMILES' column.

    Args:
        data: DataFrame with 'Oxd', 'Substrate Name', 'IUPAC Name' and one or
            more columns whose names contain 'SMILES'. Not modified.
        train_ind: index labels of the training rows. Labels that disappear when
            duplicate rows are dropped are ignored.

    Returns:
        train_data (Series named 'SMILES'), test_data, train_labels, test_labels.
        Labels are int64, and missing 'Oxd' values count as 0.
    '''
    data = data.drop_duplicates()
    # Missing labels -> 0, then round to integer classes. Assigning the result
    # (rather than fillna(inplace=True) on a column) also works under pandas Copy-on-Write.
    data = data.assign(Oxd=data['Oxd'].fillna(0).round().astype('int64'))

    # Reuse the train/test indices from before, keeping the caller's order.
    present = set(data.index)
    train_ind = [idx for idx in train_ind if idx in present]
    train_dt, test_dt = data.loc[train_ind], data.drop(index=train_ind)

    # Augment the training set: wide-to-long, then drop redundant/missing SMILES.
    smiles_cols = [col for col in data.columns if 'SMILES' in col]
    train_dt = pd.melt(train_dt, id_vars=['Oxd', 'Substrate Name', 'IUPAC Name'],
                       value_vars=smiles_cols, var_name='SMILES_type')
    train_dt = (train_dt
                .drop(['Substrate Name', 'IUPAC Name', 'SMILES_type'], axis=1)
                .drop_duplicates()
                .dropna(subset=['value']))

    train_data, train_labels = train_dt['value'].rename('SMILES'), train_dt['Oxd']
    test_data, test_labels = test_dt['SMILES'], test_dt['Oxd']  # disregard alternative representations for the test data

    return train_data, test_data, train_labels, test_labels

In [None]:
# Extend each dataset with randomized SMILES (augment returns the same frame it
# modifies) and re-split it with the original training indices.
data_tve_ = augment(data_tve.copy())
train_data_tve_, test_data_tve_, train_labels_tve_, test_labels_tve_ = data_split(
    data_tve_, train_data_tve.index.to_list())

data_mth_ = augment(data_mth.copy())
train_data_mth_, test_data_mth_, train_labels_mth_, test_labels_mth_ = data_split(
    data_mth_, train_data_mth.index.to_list())

data_bpu_ = augment(data_bpu.copy())
train_data_bpu_, test_data_bpu_, train_labels_bpu_, test_labels_bpu_ = data_split(
    data_bpu_, train_data_bpu.index.to_list())

# display set sizes:
set_size_report = [
    ("Number of training examples (f-tve):", len(train_data_tve_)),
    ("Number of test examples (f-tve):", len(test_data_tve_)),
    ('\nNumber of training examples (f-mth)', len(train_data_mth_)),
    ("Number of test examples (f-mth):", len(test_data_mth_)),
    ('\nNumber of training examples (bpu-lac)', len(train_data_bpu_)),
    ("Number of test examples (bpu-lac):", len(test_data_bpu_)),
]
for message, size in set_size_report:
    print(message, size)

### Retrain and rebuild the model on the augmented data (for f-tve)

In [None]:
# Hyper-parameter search, final training and evaluation on the augmented f-tve data.
study_tve_ = tune_model(train_data_tve_, train_labels_tve_)
(final_model_tve_, logits_tve_, attentions_tve_, h_states_tve_,
 inputs_tve_, mask_tve_, preds_tve_, true_labels_tve_) = rebuild_model(
    train_data_tve_, train_labels_tve_, test_data_tve_, test_labels_tve_, study_tve_)
evaluate_model(true_labels_tve_, preds_tve_)

## Extra: 

For training the model with frozen layers: 

* RoBERTa is frozen → requires_grad = False
* only the classifier (self.classifier) is trainable
* `torch.no_grad()` is used for the forward pass through the frozen layers
* optimizer updated to only train classifier parameters

```python

class CustomRoberta(nn.Module):
    """RoBERTa encoder (ChemBERTa checkpoint) with a frozen backbone and a trainable 2-class head.

    Every parameter of ``self.roberta`` has ``requires_grad = False``; only ``self.classifier``
    is trained.

    Args:
        dropout_prob: dropout probability applied to the pooled encoder output before the
            classification layer.
    """

    def __init__(self, dropout_prob):
        super().__init__()
        self.checkpoint = 'seyonec/PubChem10M_SMILES_BPE_450k'
        # attentions and hidden states are requested so that forward() can return them
        self.roberta = RobertaModel.from_pretrained(self.checkpoint, output_attentions=True, output_hidden_states=True).to(device)

        # freeze all RoBERTa parameters (embeddings, encoder layers and pooler):
        for param in self.roberta.parameters():
            param.requires_grad = False

        self.config = self.roberta.config
        self.current_embeddings = self.roberta.get_input_embeddings()
        self.tokenizer = RobertaTokenizerFast.from_pretrained(self.checkpoint)
        self.dropout = nn.Dropout(dropout_prob).to(device)
        self.classifier = nn.Linear(self.roberta.config.hidden_size, 2).to(device)

        # the classification head is the only trainable part. A new nn.Linear already
        # requires grad; this loop just makes the intent explicit.
        for param in self.classifier.parameters():
            param.requires_grad = True

    def prepare_data(self, data, labels, batch_size, shuffle):
        """Tokenize SMILES strings and wrap them, with their labels, in a DataLoader.

        Args:
            data: pandas Series of SMILES strings (``.tolist()`` is called on it).
            labels: pandas Series of class labels (``.values`` is converted to a tensor).
            batch_size: number of examples per batch.
            shuffle: whether the DataLoader reshuffles the data every epoch.

        Returns:
            DataLoader yielding ``(input_ids, attention_mask, labels)`` batches.
        """
        # pad every sequence to the longest one in ``data``; anything over 180 tokens is truncated
        tokenized = self.tokenizer(data.tolist(), padding=True, truncation=True, max_length=180, return_tensors='pt')
        input_ids = tokenized['input_ids']
        attention_mask = tokenized['attention_mask']
        label_tensor = torch.tensor(labels.values)
        dataset = TensorDataset(input_ids, attention_mask, label_tensor)
        # worker_init_fn only runs when num_workers > 0 (DataLoader defaults to 0)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, worker_init_fn=lambda _: np.random.seed(42))

    def forward(self, input_ids, attention_mask=None):
        """Run the frozen encoder and the trainable classification head.

        Returns:
            tuple ``(logits, attentions, hidden_states)``; ``logits`` has one column per class (2).
        """
        # The frozen encoder runs without recording an autograd graph.
        # NOTE: no_grad does not disable dropout inside self.roberta (that follows
        # self.train()/self.eval()). It also stops gradients from reaching the input
        # embeddings, so gradient-based attribution (e.g. Captum) will likely fail
        # through this forward as written.
        with torch.no_grad():
            outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
            # pooler_output (the same tensor as outputs[1]) is the first (<s>, i.e. [CLS])
            # token's last hidden state passed through RoBERTa's dense+tanh pooler. It is
            # not the raw token representation.
            # NOTE(review): if this MLM checkpoint ships no pooler weights, from_pretrained
            # initializes them randomly and they stay frozen here - check the load warnings.
            pooled_output = outputs.pooler_output

        dropped_output = self.dropout(pooled_output)
        logits = self.classifier(dropped_output)

        return logits, outputs.attentions, outputs.hidden_states

```
