### Setup.

In [None]:
import os
# cuBLAS workspace setting. PyTorch's reproducibility notes list ":4096:8" as one of the
# values needed for deterministic cuBLAS results (presumed intent here — TODO confirm).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Expose only GPU 0 to this process. Both variables are set before `torch` is imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import numpy as np
import random

# Single seed shared by every random number generator used in this notebook;
# it is also handed to TrainingArguments further down.
seed = 7

# The three generators are independent of each other, so the seeding order is irrelevant.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

import ankh

from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import Trainer, TrainingArguments, EvalPrediction
from datasets import load_dataset

from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
from scipy import stats
from functools import partial
import pandas as pd
from tqdm.auto import tqdm

In [None]:
def get_num_params(model):
    """Return the total number of elements over all parameters of `model`."""
    total = 0
    for parameter in model.parameters():
        total += parameter.numel()
    return total

### Select the available device.

In [None]:
# Prefer the first CUDA device when one is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Available device:', device)

### Load Ankh large model.

In [None]:
# Load the pretrained Ankh large model together with its tokenizer.
model, tokenizer = ankh.load_large_model()
# The model is only used for embedding extraction in this notebook, so switch it to eval mode.
model.eval()
# Move the model to the device selected above.
model.to(device=device)

In [None]:
# Report the size of the loaded model (plain string: there is nothing to interpolate).
print("Number of parameters:", get_num_params(model))

### Load the datasets

In [None]:
# Dataset repository holding the training split and the four benchmark test sets.
name = "proteinea/secondary_structure_prediction"

# The training data is exposed under the 'train' split ...
training_dataset = load_dataset(name, data_files={'train': ['training_hhblits.csv']})

# ... while every benchmark CSV is loaded as its own 'test' split, in the order listed.
casp12_dataset, casp14_dataset, ts115_dataset, cb513_dataset = (
    load_dataset(name, data_files={'test': [csv_file]})
    for csv_file in ('CASP12.csv', 'CASP14.csv', 'TS115.csv', 'CB513.csv')
)

In [None]:
# Names of the dataset columns read below.
input_column_name = 'input'
labels_column_name = 'dssp3' # You can change it to "dssp8" if you want to work with 8 states.
disorder_column_name = 'disorder'

# For every split, pull out the (sequences, labels, disorder) columns in that order.
(
    (training_sequences, training_labels, training_disorder),
    (casp12_sequences, casp12_labels, casp12_disorder),
    (casp14_sequences, casp14_labels, casp14_disorder),
    (ts115_sequences, ts115_labels, ts115_disorder),
    (cb513_sequences, cb513_labels, cb513_disorder),
) = (
    tuple(
        split[column]
        for column in (input_column_name, labels_column_name, disorder_column_name)
    )
    for split in (
        training_dataset['train'],
        casp12_dataset['test'],
        casp14_dataset['test'],
        ts115_dataset['test'],
        cb513_dataset['test'],
    )
)

In [None]:
def preprocess_dataset(sequences, labels, disorder, max_length=None):
    """Split the raw dataset columns into per-position token lists.

    Args:
        sequences: Iterable of sequence strings. Whitespace is removed and every
            remaining character becomes one token.
        labels: Iterable of label strings, processed the same way as `sequences`.
        disorder: Iterable of strings holding whitespace-separated flags; every
            whitespace-delimited item becomes one token.
        max_length: Maximum number of tokens kept per entry. When None, the length
            of the longest whitespace-stripped sequence is used (0 for empty input).

    Returns:
        A `(seqs, labels, disorder)` tuple of lists of token lists.

    Raises:
        AssertionError: If the three columns do not hold the same number of entries.
    """
    cleaned_sequences = ["".join(seq.split()) for seq in sequences]

    if max_length is None:
        # `default=0` keeps an empty dataset from crashing inside max().
        max_length = max(map(len, cleaned_sequences), default=0)

    seqs = [list(seq)[:max_length] for seq in cleaned_sequences]
    label_tokens = [list("".join(label.split()))[:max_length] for label in labels]
    # str.split() already collapses runs of whitespace, so no re-joining is needed.
    # Distinct loop names also avoid shadowing the `disorder` parameter.
    disorder_tokens = [flags.split()[:max_length] for flags in disorder]

    assert len(seqs) == len(label_tokens) == len(disorder_tokens), (
        "sequences, labels and disorder must contain the same number of entries"
    )
    return seqs, label_tokens, disorder_tokens

In [None]:
def embed_dataset(model, sequences, shift_left = 0, shift_right = -1):
    """Embed every sequence with `model` and return one NumPy array per sequence.

    Each sequence (a list of tokens) is encoded on its own with the module-level
    `tokenizer`, run through `model` on the module-level `device` without gradient
    tracking, and the first output of the model is copied back to the CPU.

    Args:
        model: Callable taking `input_ids` and returning an indexable output.
            NOTE(review): element 0 is presumably the last hidden state — confirm.
        sequences: Iterable of token lists.
        shift_left: Start index of the slice kept along the first axis of each embedding.
        shift_right: End index of that slice. The default of -1 drops the final position,
            presumably the end-of-sequence token added by the tokenizer — TODO confirm.

    Returns:
        List of NumPy arrays, one per input sequence, in input order.
    """
    kept_positions = slice(shift_left, shift_right)
    inputs_embedding = []
    with torch.no_grad():
        for residues in tqdm(sequences):
            encoded = tokenizer.batch_encode_plus(
                [residues],
                add_special_tokens=True,
                padding=True,
                is_split_into_words=True,
                return_tensors="pt",
            )
            input_ids = encoded['input_ids'].to(device)
            model_output = model(input_ids=input_ids)[0]
            # Batch size is 1, so take the only item before converting to NumPy.
            per_position = model_output[0].detach().cpu().numpy()
            inputs_embedding.append(per_position[kept_positions])
    return inputs_embedding

### Preprocess the dataset.

In [None]:
# Tokenise every split. No `max_length` is passed, so each split is truncated
# to the length of its own longest sequence (i.e. not truncated at all).
training_sequences, training_labels, training_disorder = preprocess_dataset(
    training_sequences, training_labels, training_disorder
)
casp12_sequences, casp12_labels, casp12_disorder = preprocess_dataset(
    casp12_sequences, casp12_labels, casp12_disorder
)
casp14_sequences, casp14_labels, casp14_disorder = preprocess_dataset(
    casp14_sequences, casp14_labels, casp14_disorder
)
ts115_sequences, ts115_labels, ts115_disorder = preprocess_dataset(
    ts115_sequences, ts115_labels, ts115_disorder
)
cb513_sequences, cb513_labels, cb513_disorder = preprocess_dataset(
    cb513_sequences, cb513_labels, cb513_disorder
)

### Extract sequence embeddings.

In [None]:
# Embed only the first 10 sequences of every split; the datasets built further
# down slice the labels with the same [:10] so both stay aligned.
(
    training_embeddings,
    casp12_embeddings,
    casp14_embeddings,
    ts115_embeddings,
    cb513_embeddings,
) = (
    embed_dataset(model, split_sequences[:10])
    for split_sequences in (
        training_sequences,
        casp12_sequences,
        casp14_sequences,
        ts115_sequences,
        cb513_sequences,
    )
)

### Create a unique tag for each state. With the default `dssp3` labels there are only 3 states.

In [None]:
# Consider each label as a tag for each token
# Consider each label as a tag for each token
unique_tags = {tag for doc in training_labels for tag in doc}
# Enumerate the tags in sorted order. Iterating the set directly would hand out ids
# in an order that changes between interpreter sessions (string hashing is randomised),
# which would make the tag <-> id mapping, and any saved checkpoint, non-reproducible.
tag2id = {tag: index for index, tag in enumerate(sorted(unique_tags))}
id2tag = {index: tag for tag, index in tag2id.items()}

### Encode the tags in the dataset

In [None]:
def encode_tags(labels):
    """Replace every tag in every document with its integer id from the module-level `tag2id`.

    Raises KeyError for a tag that is not present in `tag2id`.
    """
    encoded_docs = []
    for doc in labels:
        encoded_docs.append([tag2id[tag] for tag in doc])
    return encoded_docs

In [None]:
# Convert the tag lists of every split to integer ids, in the order listed.
(
    train_labels_encodings,
    casp12_labels_encodings,
    casp14_labels_encodings,
    ts115_labels_encodings,
    cb513_labels_encodings,
) = map(
    encode_tags,
    (training_labels, casp12_labels, casp14_labels, ts115_labels, cb513_labels),
)

### Mask disordered tokens. The mask value is -100, which is the default `ignore_index` of PyTorch's cross-entropy loss.

In [None]:
def mask_disorder(labels, masks):
    """Overwrite, in place, every label whose disorder flag is the string "0.0" with -100.

    Args:
        labels: List of mutable label lists; modified in place.
        masks: List of flag lists, paired with `labels` entry by entry.

    Returns:
        The same `labels` object that was passed in.
    """
    ignore_value = -100
    for encoded, flags in zip(labels, masks):
        flagged_positions = [pos for pos, flag in enumerate(flags) if flag == "0.0"]
        for pos in flagged_positions:
            encoded[pos] = ignore_value
    return labels

In [None]:
# Apply the disorder mask to the encoded labels of every split, in the order listed.
(
    train_labels_encodings,
    casp12_labels_encodings,
    casp14_labels_encodings,
    ts115_labels_encodings,
    cb513_labels_encodings,
) = (
    mask_disorder(encoded_labels, disorder_flags)
    for encoded_labels, disorder_flags in (
        (train_labels_encodings, training_disorder),
        (casp12_labels_encodings, casp12_disorder),
        (casp14_labels_encodings, casp14_disorder),
        (ts115_labels_encodings, ts115_disorder),
        (cb513_labels_encodings, cb513_disorder),
    )
)

In [None]:
class SSPDataset(Dataset):
    """Dataset pairing precomputed embeddings with their encoded labels.

    Items are dictionaries with the keys 'embed' and 'labels', both converted
    to tensors on access. The dataset length is the number of label entries.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'embed': torch.tensor(self.encodings[idx]),
            'labels': torch.tensor(self.labels[idx]),
        }

In [None]:
# Wrap the embeddings of every split together with the first 10 label lists
# (the same subset that was embedded above).
(
    training_dataset,
    casp12_dataset,
    casp14_dataset,
    ts115_dataset,
    cb513_dataset,
) = (
    SSPDataset(embeddings, encoded_labels[:10])
    for embeddings, encoded_labels in (
        (training_embeddings, train_labels_encodings),
        (casp12_embeddings, casp12_labels_encodings),
        (casp14_embeddings, casp14_labels_encodings),
        (ts115_embeddings, ts115_labels_encodings),
        (cb513_embeddings, cb513_labels_encodings),
    )
)

### Functions for computing metrics. Accuracy is used in this task.

In [None]:
def align_predictions(predictions: np.ndarray, label_ids: np.ndarray, id_to_tag=None):
    """Turn raw predictions and label ids into per-sequence tag lists, skipping ignored positions.

    Args:
        predictions: Array of scores; the arg-max over axis 2 is taken as the predicted id.
        label_ids: 2-D array of gold label ids aligned with `predictions`. Positions equal
            to the `ignore_index` of `torch.nn.CrossEntropyLoss` are dropped.
        id_to_tag: Optional mapping from id to tag. Defaults to the module-level `id2tag`.

    Returns:
        `(preds_list, out_label_list)`: two lists with one list of tags per sequence,
        containing only the positions that are not ignored.
    """
    tag_lookup = id2tag if id_to_tag is None else id_to_tag
    # Look the sentinel up once instead of building a loss module for every token.
    ignore_index = torch.nn.CrossEntropyLoss().ignore_index

    preds = np.argmax(predictions, axis=2)
    batch_size, seq_len = preds.shape

    preds_list = []
    out_label_list = []
    for i in range(batch_size):
        row_labels = label_ids[i, :seq_len]
        keep = row_labels != ignore_index
        # .tolist() yields plain Python ints for the dictionary lookups.
        out_label_list.append([tag_lookup[label] for label in row_labels[keep].tolist()])
        preds_list.append([tag_lookup[pred] for pred in preds[i][keep].tolist()])

    return preds_list, out_label_list

def compute_metrics(p: EvalPrediction):
    """Compute accuracy, precision, recall and F1 for one evaluation pass.

    The predictions and label ids carried by `p` are first converted to tag lists
    with `align_predictions`; every metric is then called as `metric(gold, predicted)`.
    """
    predicted_tags, gold_tags = align_predictions(p.predictions, p.label_ids)
    scorers = (
        ("accuracy", accuracy_score),
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    return {metric_name: scorer(gold_tags, predicted_tags) for metric_name, scorer in scorers}

### Model initialization function for HuggingFace's trainer.

In [None]:
def model_init(num_tokens, embed_dim):
    """Build a fresh ConvBert classification head for the Trainer.

    Args:
        num_tokens: Passed to the head as `num_tokens`; the notebook passes the
            number of distinct tags.
        embed_dim: Dimension of the input embeddings. The hidden dimension is half of it.

    Returns:
        The downstream model, placed on the GPU when CUDA is available and on the CPU otherwise.
    """
    hidden_dim = int(embed_dim / 2)
    num_hidden_layers = 1 # Number of hidden layers in ConvBert.
    nlayers = 1 # Number of ConvBert layers.
    nhead = 4
    dropout = 0.2
    conv_kernel_size = 7
    downstream_model = ankh.ConvBertForMultiClassClassification(num_tokens=num_tokens,
                                                                input_dim=embed_dim, 
                                                                nhead=nhead, 
                                                                hidden_dim=hidden_dim, 
                                                                num_hidden_layers=num_hidden_layers, 
                                                                num_layers=nlayers, 
                                                                kernel_size=conv_kernel_size,
                                                                dropout=dropout)
    # An unconditional .cuda() raises on CPU-only machines, although the notebook
    # itself falls back to the CPU when selecting its device.
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return downstream_model.to(target_device)

### Create and configure HuggingFace's TrainingArguments instance.

In [None]:
# Names used for the output/log directories and as the run name.
model_type = 'ankh_large'
experiment = f'ssp3_{model_type}'

training_args = TrainingArguments(
    output_dir=f'./results_{experiment}',
    num_train_epochs=5,
    # One sample per step, accumulated over 16 steps before each optimizer update.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=1000,
    learning_rate=1e-03,
    weight_decay=0.0,
    logging_dir=f'./logs_{experiment}',
    logging_steps=200,
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    gradient_accumulation_steps=16,
    fp16=False,
    # Apex optimisation levels are spelled with the letter "O" ('O0'..'O3'); the previous
    # value "02" used the digit zero. Only relevant if fp16 training via Apex is enabled.
    fp16_opt_level="O2",
    run_name=experiment,
    seed=seed,
    # Evaluate and save once per epoch, then restore the checkpoint with the best eval accuracy.
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    greater_is_better=True,
    save_strategy="epoch"
)

### Create HuggingFace Trainer.

In [None]:
# Embedding dimension for ankh large.
model_embed_dim = 1536

# The Trainer calls `model_init` without arguments, so the tag count and the
# embedding size are bound up front with functools.partial.
# CASP12 serves as the evaluation set during training.
trainer = Trainer(
    args=training_args,
    model_init=partial(model_init, num_tokens=len(unique_tags), embed_dim=model_embed_dim),
    compute_metrics=compute_metrics,
    train_dataset=training_dataset,
    eval_dataset=casp12_dataset,
)

### Train the model.

In [None]:
# Run the training loop configured by `training_args`.
trainer.train()

In [None]:
# `test_dataset` is not defined in any cell above, so the original call raised NameError.
# CASP12 already serves as the evaluation set during training, so report on a
# different benchmark here; swap in casp14_dataset or ts115_dataset as needed.
test_dataset = cb513_dataset
predictions, labels, metrics_output = trainer.predict(test_dataset)

In [None]:
# Display the metrics returned by `trainer.predict` in the previous cell.
metrics_output