### Setup.

In [None]:
import os
# These environment variables are set before `import torch` below, presumably so
# they are in place before torch initialises CUDA.
# cuBLAS workspace size. PyTorch's reproducibility notes require this setting for
# deterministic cuBLAS results when deterministic algorithms are enabled
# (torch.use_deterministic_algorithms is not called in this visible code).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Expose only GPU index 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import numpy as np
import random

# Single seed shared by torch, NumPy and Python's `random` (and later passed to
# TrainingArguments as well).
seed = 7

torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

from torch import nn
from torch.utils.data import Dataset, DataLoader
import ankh
from transformers import Trainer, TrainingArguments, EvalPrediction
from datasets import load_dataset
import transformers.models.convbert as c_bert
from scipy import stats
from functools import partial
import pandas as pd
from tqdm.auto import tqdm

In [None]:
def get_num_params(model):
    """Return the total number of scalar elements across all of ``model``'s parameters.

    Trainable and frozen parameters are both counted.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

### Select the available device.

In [None]:
# Prefer the first visible GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('Available device:', device)

### Load Ankh large model.

In [None]:
# The Ankh encoder is only used for inference here: move it to the selected
# device and switch it to evaluation mode (disables dropout). Both calls return
# the module itself, so the last expression still displays the model.
model, tokenizer = ankh.load_large_model()
model.to(device=device).eval()

In [None]:
# Interpolate the count into the f-string (the original f-string had no
# placeholders); the printed text is unchanged.
print(f"Number of parameters: {get_num_params(model)}")

### Load the datasets.

In [None]:
# A later cell defines its own `load_dataset` helper, which rebinds the name
# imported from `datasets`. Calling through the module keeps this cell correct
# even when it is re-run after that redefinition.
import datasets

dataset = datasets.load_dataset("ElnaggarLab/fluorosence")

In [None]:
# Pull the raw primary sequences and log-fluorescence targets out of each split.
train_split = dataset['train']
validation_split = dataset['validation']
test_split = dataset['test']

training_sequences, training_labels = train_split['primary'], train_split['log_fluorescence']
validation_sequences, validation_labels = validation_split['primary'], validation_split['log_fluorescence']
test_sequences, test_labels = test_split['primary'], test_split['log_fluorescence']

In [None]:
# Mean of the training targets. It is handed to the regression head so the
# final layer's bias can start at this value, which speeds up convergence on
# regression tasks. `training_labels` is the train split's
# 'log_fluorescence' column.
training_labels_mean = np.mean(training_labels)
training_labels_mean

In [None]:
def load_dataset(sequences, labels, max_length=None):
    '''
    Split protein sequences into per-residue token lists, optionally truncating them.

    NOTE: this helper rebinds the name ``load_dataset`` that the setup cell
    imports from the ``datasets`` library.

    Args:
        sequences: list of str, the protein primary sequences.
        labels: the labels aligned with ``sequences``; returned unchanged.
        max_length: int or None. Sequences longer than this are post-truncated
            (the tail is dropped). ``None`` keeps every sequence at full length.

    Returns:
        A ``(splitted_sequences, labels)`` tuple, where each sequence is a
        list of single-character strings.
    '''
    # Slicing with ``None`` is a no-op, so ``max_length=None`` means "no
    # truncation". The length is taken only from the arguments, never from
    # another split held in a global variable.
    splitted_sequences = [list(seq[:max_length]) for seq in sequences]
    return splitted_sequences, labels

In [None]:
def embed_dataset(model, sequences, shift_left = 0, shift_right = -1):
    """
    Embed each sequence with ``model`` and return the per-token embeddings.

    Sequences are encoded one at a time with the notebook-level ``tokenizer``
    and moved to the notebook-level ``device``. No gradients are tracked.

    Args:
        model: encoder whose first output (``model(...)[0]``) holds the hidden
            states for the encoded batch.
        sequences: iterable of pre-split sequences (lists of residues), encoded
            with ``is_split_into_words=True``.
        shift_left: number of leading token positions to drop.
        shift_right: stop index of the kept token slice. The default ``-1``
            drops the final token (presumably the end-of-sequence token added
            by ``add_special_tokens=True`` -- confirm against the tokenizer).
            ``0`` or ``None`` keeps all trailing tokens.

    Returns:
        list of NumPy arrays, one per sequence, presumably each of shape
        ``(kept_tokens, hidden_dim)``.
    """
    # A stop index of 0 would slice away every token and silently produce
    # empty embeddings; treat it like None (keep everything to the end).
    stop = None if shift_right == 0 else shift_right
    inputs_embedding = []
    with torch.no_grad():
        for sample in tqdm(sequences):
            ids = tokenizer.batch_encode_plus([sample], add_special_tokens=True,
                                              padding=True, is_split_into_words=True,
                                              return_tensors="pt")
            embedding = model(input_ids=ids['input_ids'].to(device))[0]
            # Take the only batch item, bring it to host memory, then trim tokens.
            embedding = embedding[0].detach().cpu().numpy()[shift_left:stop]
            inputs_embedding.append(embedding)
    return inputs_embedding

### Preprocess the dataset.

In [None]:
# Convert each split's sequences into lists of single-residue tokens, in
# train/validation/test order. The results rebind the same variable names, so
# the raw strings are replaced.
training_sequences, training_labels = load_dataset(training_sequences, training_labels)
validation_sequences, validation_labels = load_dataset(validation_sequences, validation_labels)
test_sequences, test_labels = load_dataset(test_sequences, test_labels)

### Extract sequence embeddings.

In [None]:
# Run the Ankh encoder (inference only) over every split. Each split's result
# is bound as soon as it finishes, so a failure in a later split does not
# discard the embeddings already computed.
training_embeddings = embed_dataset(model, training_sequences)
validation_embeddings = embed_dataset(model, validation_sequences)
test_embeddings = embed_dataset(model, test_sequences)

In [None]:
class FluorescenceDataset(Dataset):
    """Map-style dataset pairing stored per-sequence inputs with regression targets.

    Each item is a dict with:
      * ``'embed'``: ``torch.tensor`` built from the stored input.
      * ``'labels'``: float32 tensor with a trailing singleton dimension
        (a scalar label becomes shape ``(1,)``).
    """

    def __init__(self, sequences, labels):
        # In this notebook ``sequences`` receives the precomputed embeddings.
        self.sequences = sequences
        self.labels = labels

    def __getitem__(self, idx):
        stored_input, target = self.sequences[idx], self.labels[idx]
        embed_tensor = torch.tensor(stored_input)
        label_tensor = torch.tensor(target, dtype=torch.float32).unsqueeze(-1)
        return {'embed': embed_tensor, 'labels': label_tensor}

    def __len__(self):
        return len(self.labels)

In [None]:
# Wrap each split's embeddings and targets for the Trainer.
training_dataset, validation_dataset, test_ds = (
    FluorescenceDataset(training_embeddings, training_labels),
    FluorescenceDataset(validation_embeddings, validation_labels),
    FluorescenceDataset(test_embeddings, test_labels),
)

### Model initialization function for Hugging Face's `Trainer`.

In [None]:
def model_init(embed_dim, training_labels_mean=None):
    """
    Build a fresh ConvBERT regression head for the Hugging Face Trainer.

    Every call constructs a new ``ankh.ConvBertForRegression`` with the fixed
    hyperparameters below.

    Args:
        embed_dim: width of the input per-residue embeddings. The hidden width
            is half of it.
        training_labels_mean: optional mean target, forwarded to the head
            (used there to initialise the final layer's bias).

    Returns:
        The head, moved to the notebook-level ``device``.
    """
    hidden_dim = int(embed_dim / 2)
    num_hidden_layers = 1
    nlayers = 1
    nhead = 4
    dropout = 0.2
    conv_kernel_size = 7
    pooling = 'max' # available pooling methods ['avg', 'max']
    downstream_model = ankh.ConvBertForRegression(input_dim=embed_dim, 
                                                  nhead=nhead, 
                                                  hidden_dim=hidden_dim, 
                                                  num_hidden_layers=num_hidden_layers, 
                                                  num_layers=nlayers, 
                                                  kernel_size=conv_kernel_size,
                                                  dropout=dropout, 
                                                  pooling=pooling, 
                                                  training_labels_mean=training_labels_mean)
    # `.cuda()` fails outright on CPU-only machines; follow the device selected
    # at the top of the notebook, which the encoder also uses.
    return downstream_model.to(device=device)

### Function for computing metrics. Spearman correlation is used for this regression task.

In [None]:
def compute_metrics(p: "EvalPrediction"):
    """
    Compute the Spearman rank correlation between labels and predictions.

    Args:
        p: Trainer evaluation output exposing ``label_ids`` and
           ``predictions`` array-likes. ``predictions`` may be a tuple, in
           which case its first element is taken as the regression output.

    Returns:
        dict with a single ``"spearmanr"`` entry. The Trainer reports it as
        ``eval_spearmanr``, the name used by ``metric_for_best_model``.
    """
    predictions = p.predictions
    # Models that return extra outputs make ``predictions`` a tuple.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # Labels carry a trailing singleton dimension; flatten both sides so
    # spearmanr correlates two 1-D vectors rather than stacking columns.
    rho, _ = stats.spearmanr(np.ravel(p.label_ids), np.ravel(predictions))
    return {
        "spearmanr": rho,
    }

### Create and configure HuggingFace's TrainingArguments instance.

In [None]:
model_type = 'ankh_large'
experiment = f'flu_{model_type}'

training_args = TrainingArguments(
    output_dir=f'./results_{experiment}',
    num_train_epochs=5,
    # Batch size 1 with 16 accumulation steps gives an effective batch of 16
    # per device.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=1000,
    learning_rate=1e-03,
    weight_decay=0.0,
    logging_dir=f'./logs_{experiment}',
    logging_steps=200,
    do_train=True,
    do_eval=True,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="epoch",
    gradient_accumulation_steps=16,
    fp16=False,
    # Apex optimisation levels are "O0".."O3" (letter O); the previous "02"
    # used a zero and would be rejected once fp16 with Apex is enabled.
    fp16_opt_level="O2",
    run_name=experiment,
    seed=seed,
    # Keep the checkpoint with the highest validation Spearman correlation;
    # this requires save_strategy to match the evaluation strategy ("epoch").
    load_best_model_at_end=True,
    metric_for_best_model="eval_spearmanr",
    greater_is_better=True,
    save_strategy="epoch"
)

### Create HuggingFace Trainer.

In [None]:
# Width of the per-residue embeddings fed to the downstream head (1536 for
# Ankh large). Reading it off the extracted embeddings keeps the head in sync
# with whichever encoder produced them, instead of relying on a hard-coded value.
model_embed_dim = training_embeddings[0].shape[-1]

trainer = Trainer(
    # Passing model_init (rather than a model instance) lets the Trainer build
    # a freshly initialised head; the partial binds both of its arguments.
    model_init=partial(model_init, embed_dim=model_embed_dim, training_labels_mean=training_labels_mean),
    args=training_args,
    train_dataset=training_dataset,
    eval_dataset=validation_dataset,
    compute_metrics=compute_metrics,
)

### Train the model.

In [None]:
# Fine-tune the downstream head. With load_best_model_at_end=True, the trainer
# restores the checkpoint with the best eval_spearmanr when training ends.
trainer.train()