### Setup.

In [None]:
import os
# These environment variables are set before torch is imported so that they
# are already in place when the CUDA libraries initialise.
# NOTE(review): ":4096:8" is the cuBLAS workspace setting that PyTorch's
# reproducibility notes pair with deterministic algorithms; nothing visible in
# this notebook calls torch.use_deterministic_algorithms -- confirm intent.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Expose only GPU index 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import numpy as np
import random

# Single seed shared by every RNG below and by TrainingArguments(seed=...).
seed = 7

# Seed the torch, NumPy and stdlib random generators.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

import ankh

from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import Trainer, TrainingArguments, EvalPrediction
from datasets import load_dataset

from sklearn import metrics
from scipy import stats
from functools import partial
import pandas as pd
from tqdm.auto import tqdm

In [None]:
def get_num_params(model):
    '''
        Count the scalar elements held by every parameter of `model`.

        Args:
            model: object exposing `parameters()`, whose items expose `numel()`.

        Returns:
            int, the total element count (0 when there are no parameters).
    '''
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

### Select the available device.

In [None]:
# Use the first CUDA GPU when one is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Available device:', device)

### Load Ankh large model.

In [None]:
# Load the Ankh large model together with its tokenizer.
model, tokenizer = ankh.load_large_model()
# Evaluation mode: in this notebook the model is only used to extract
# embeddings (see embed_dataset), not trained.
model.eval()
# Move the weights to the selected device; as the last expression of the cell
# this also displays the model.
model.to(device=device)

In [None]:
# Report the size of the loaded model. A plain string literal is used because
# the text contains no placeholders, so an f-string prefix is unnecessary.
print("Number of parameters:", get_num_params(model))

### Load the datasets

In [None]:
# Load the "ElnaggarLab/solubility" dataset through Hugging Face `datasets`.
# NOTE: a later cell defines its own `load_dataset` function, which shadows
# this import; re-running this cell after that one would call the local
# function instead and fail.
dataset = load_dataset("ElnaggarLab/solubility")

In [None]:
# Pull the 'sequences' and 'labels' columns out of each split.
training_sequences = dataset['train']['sequences']
training_labels = dataset['train']['labels']

validation_sequences = dataset['validation']['sequences']
validation_labels = dataset['validation']['labels']

test_sequences = dataset['test']['sequences']
test_labels = dataset['test']['labels']

In [None]:
def load_dataset(sequences, labels, max_length=None):
    '''
        Split each protein sequence into a list of single residues,
        truncating sequences that exceed `max_length`.

        NOTE: this function shares its name with `datasets.load_dataset`
        imported at the top of the file and shadows it once defined.

        Args:
            sequences: list, the list which contains the protein primary sequences.
            labels: list, the list which contains the dataset labels (returned unchanged).
            max_length: Integer or None, the maximum sequence length. A sequence
            longer than this is post-truncated. When None, the length of the
            longest entry in `sequences` is used, so nothing is truncated.

        Returns:
            tuple, (list of per-residue lists, labels).
    '''
    if max_length is None:
        # Derive the limit from the argument itself rather than from a
        # module-level variable, so the result depends only on the inputs.
        # `default=0` keeps an empty input from raising.
        max_length = max(map(len, sequences), default=0)
    splitted_sequences = [list(seq[:max_length]) for seq in sequences]
    return splitted_sequences, labels

In [None]:
def embed_dataset(model, sequences, shift_left = 0, shift_right = -1):
    '''
        Embed every sequence with `model`, one sample at a time.

        Relies on the module-level `tokenizer` and `device`.

        Args:
            model: callable taking `input_ids` and returning an indexable
            output whose first item holds the embeddings of the batch.
            sequences: iterable of pre-split sequences (lists of tokens).
            shift_left: Integer, start index of the slice kept along the
            first axis of each sample's embedding.
            shift_right: Integer, end index of that slice. The default -1
            drops the final row (presumably the end-of-sequence token added
            by `add_special_tokens=True` -- TODO confirm).

        Returns:
            list of numpy arrays, one per input sequence.
    '''
    embeddings = []
    with torch.no_grad():
        for residues in tqdm(sequences):
            encoded = tokenizer.batch_encode_plus(
                [residues],
                add_special_tokens=True,
                padding=True,
                is_split_into_words=True,
                return_tensors="pt",
            )
            input_ids = encoded['input_ids'].to(device)
            outputs = model(input_ids=input_ids)
            # First output, first (and only) item of the batch, as NumPy on CPU.
            per_token = outputs[0][0].detach().cpu().numpy()
            embeddings.append(per_token[shift_left:shift_right])
    return embeddings

### Preprocess the dataset.

In [None]:
# Convert each split's sequences into per-residue lists using the local
# `load_dataset` helper (not the Hugging Face function of the same name).
training_sequences, training_labels = load_dataset(training_sequences, training_labels)
validation_sequences, validation_labels = load_dataset(validation_sequences, validation_labels)
test_sequences, test_labels = load_dataset(test_sequences, test_labels)

### Extract sequence embeddings.

In [None]:
# Compute the embeddings of every split once, up front; the datasets built
# further down wrap these precomputed arrays.
training_embeddings = embed_dataset(model, training_sequences)
validation_embeddings = embed_dataset(model, validation_sequences)
test_embeddings = embed_dataset(model, test_sequences)

In [None]:
class SolubilityDataset(Dataset):
    '''
        Map-style dataset pairing precomputed embeddings with labels.

        Args:
            sequences: indexable collection of per-sample embeddings.
            labels: indexable collection of per-sample labels.
    '''

    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        # The dataset size is driven by the embeddings, not the labels.
        return len(self.sequences)

    def __getitem__(self, idx):
        '''
            Return sample `idx` as a dict with an 'embed' tensor and a
            float32 'labels' tensor that has one trailing axis appended.
        '''
        embed = torch.tensor(self.sequences[idx])
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return {'embed': embed, 'labels': target.unsqueeze(-1)}

In [None]:
# Wrap each split's embeddings and labels in a SolubilityDataset so they can
# be handed to the Trainer.
training_dataset = SolubilityDataset(training_embeddings, training_labels)
validation_dataset = SolubilityDataset(validation_embeddings, validation_labels)
test_dataset = SolubilityDataset(test_embeddings, test_labels)

### Model initialization function for HuggingFace's trainer.

In [None]:
def model_init(embed_dim):
    '''
        Build a new ConvBert binary-classification model for the Trainer.

        Args:
            embed_dim: Integer, the size of the input embeddings; the hidden
            dimension is set to half of it.

        Returns:
            The `ankh.ConvBertForBinaryClassification` instance, placed on
            the module-level `device`.
    '''
    hidden_dim = int(embed_dim / 2)
    num_hidden_layers = 1 # Number of hidden layers in ConvBert.
    nlayers = 1 # Number of ConvBert layers.
    nhead = 4
    dropout = 0.2
    conv_kernel_size = 7
    pooling = 'max' # available pooling methods ['avg', 'max']
    downstream_model = ankh.ConvBertForBinaryClassification(input_dim=embed_dim, 
                                                            nhead=nhead, 
                                                            hidden_dim=hidden_dim, 
                                                            num_hidden_layers=num_hidden_layers, 
                                                            num_layers=nlayers, 
                                                            convsize=conv_kernel_size,
                                                            dropout=dropout, 
                                                            pooling=pooling)
    # Use the device chosen earlier in the notebook instead of an
    # unconditional .cuda(), which raises when no GPU is available even
    # though the notebook otherwise falls back to the CPU.
    return downstream_model.to(device)

### Function for computing metrics. Accuracy is the primary metric for this task; precision, recall and F1 are also reported.

In [None]:
def compute_metrics(p: EvalPrediction):
    '''
        Compute classification metrics from the Trainer's predictions.

        A sigmoid is applied to `p.predictions` and the result is
        thresholded at 0.5 to obtain the predicted classes, which are then
        compared with `p.label_ids`.

        Returns:
            dict with the keys "accuracy", "precision", "recall" and "f1".
    '''
    probabilities = torch.sigmoid(torch.tensor(p.predictions)).numpy()
    preds = (probabilities > 0.5).tolist()
    labels = p.label_ids.tolist()
    # Insertion order of this table fixes the order of the returned dict.
    scorers = {
        "accuracy": metrics.accuracy_score,
        "precision": metrics.precision_score,
        "recall": metrics.recall_score,
        "f1": metrics.f1_score,
    }
    return {name: scorer(labels, preds) for name, scorer in scorers.items()}

### Create and configure HuggingFace's TrainingArguments instance.

In [None]:
# The experiment name is reused for the output/log directories and run name.
model_type = 'ankh_large'
experiment = f'solubility_{model_type}'

training_args = TrainingArguments(
    output_dir=f'./results_{experiment}',
    num_train_epochs=5,
    # One sample per step; together with gradient_accumulation_steps=16
    # below, gradients from 16 samples are accumulated per optimizer update.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=1000,
    learning_rate=1e-03,
    weight_decay=0.0,
    logging_dir=f'./logs_{experiment}',
    logging_steps=200,
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    gradient_accumulation_steps=16,
    fp16=False,
    # Apex optimisation levels are spelled with the capital letter O
    # ("O0".."O3"), not the digit zero. Only relevant if fp16 is enabled.
    fp16_opt_level="O2",
    run_name=experiment,
    seed=seed,
    # Evaluate and save once per epoch, then reload the checkpoint with the
    # highest validation accuracy when training finishes.
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    greater_is_better=True,
    save_strategy="epoch"
)

### Create HuggingFace Trainer.

In [None]:
model_embed_dim = 1536 # Embedding dimension for ankh large.

trainer = Trainer(
    # A factory is passed instead of a model instance; `partial` binds the
    # embedding size so the Trainer can call it without that argument.
    model_init=partial(model_init, embed_dim=model_embed_dim),
    args=training_args,
    train_dataset=training_dataset,
    # The validation split drives the per-epoch evaluation.
    eval_dataset=validation_dataset,
    compute_metrics=compute_metrics,
)

### Train the model.

In [None]:
# Run training with the arguments and datasets configured on the trainer.
trainer.train()

In [None]:
# Run the trained model on the test split and unpack the result into
# predictions, labels and the metrics it reports.
predictions, labels, metrics_output = trainer.predict(test_dataset)

In [None]:
# Display the test-set metrics as the cell output.
metrics_output