## Inference: evaluate the fine-tuned ProstT5 classifier on the validation set

In [None]:
import pandas as pd
import numpy as np

# Validation split, used as the held-out evaluation set for this notebook.
val_df = pd.read_csv("./data/val.csv", index_col=0)

# Parallel per-example lists: amino-acid sequence, `seq3d` structure string, class label.
test_sequences, test_sequences_3d, test_labels = (
    val_df[column].tolist() for column in ("sequences", "seq3d", "label")
)

In [None]:
from transformers import T5Tokenizer
from torch.utils.data import DataLoader
import torch
from transformers import T5Tokenizer
from datasets import Dataset
import re
import datasets
import pandas as pd
from datasets import Dataset
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
import os
import torch
import torch.nn as nn
from transformers import T5EncoderModel, T5ForSequenceClassification, T5PreTrainedModel
from transformers import PretrainedConfig, T5Config
import evaluate
import numpy as np


def compute_metrics(eval_preds):
    """Score classifier logits with support-weighted F1, precision and recall.

    Args:
        eval_preds: ``(logits, labels)`` pair as handed over by ``Trainer``;
            the class dimension of ``logits`` is the last axis.

    Returns:
        The dict produced by the combined ``evaluate`` metric
        (one entry per metric: f1, precision, recall).
    """
    scorer = evaluate.combine(["f1", "precision", "recall"])
    raw_logits, gold_labels = eval_preds
    # Pick the highest-scoring class for every example.
    predicted_labels = np.argmax(raw_logits, axis=-1)
    return scorer.compute(
        predictions=predicted_labels,
        references=gold_labels,
        average="weighted",
    )


class ProteinDataset(torch.utils.data.Dataset):
    """Expose tokenised rows as dictionaries of tensors for the Trainer.

    ``hf_dataset`` can be anything that supports ``len()`` and integer
    indexing and yields mappings containing every key in ``_COLUMNS``
    (in this notebook, the Hugging Face ``Dataset`` from ``create_dataset``).
    """

    # Keys copied from each row, in the order they appear in every item.
    _COLUMNS = (
        "input_ids_sequence",
        "attention_mask_sequence",
        "input_ids_structure",
        "attention_mask_structure",
        "labels",
    )

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        # torch.tensor builds a fresh tensor from each stored value.
        return {column: torch.tensor(row[column]) for column in self._COLUMNS}


def create_dataset(tokenized_sequences, tokenized_structures, labels):
    """Assemble per-example tokenizer outputs into a Hugging Face ``Dataset``.

    Args:
        tokenized_sequences: one tokenizer output per example for the
            amino-acid input; each must provide ``input_ids`` and
            ``attention_mask`` tensors.
        tokenized_structures: the same for the structure input, aligned
            index-for-index with ``tokenized_sequences``.
        labels: one label per example.

    Returns:
        ``Dataset`` with columns ``input_ids_sequence``,
        ``attention_mask_sequence``, ``input_ids_structure``,
        ``attention_mask_structure`` and ``labels``.
    """

    def _squeezed(encodings, key):
        # squeeze() drops every size-1 axis; presumably this is just the
        # leading batch axis of a single-example encoding.
        return [encoding[key].squeeze() for encoding in encodings]

    columns = {
        "input_ids_sequence": _squeezed(tokenized_sequences, "input_ids"),
        "attention_mask_sequence": _squeezed(tokenized_sequences, "attention_mask"),
        "input_ids_structure": _squeezed(tokenized_structures, "input_ids"),
        "attention_mask_structure": _squeezed(tokenized_structures, "attention_mask"),
        "labels": labels,
    }
    return Dataset.from_dict(columns)


def preprocess_data(sequences, structures, tokenizer):
    """Format paired amino-acid / structure strings for ProstT5 and tokenize them.

    Each amino-acid string has U, Z, O and B replaced by X and is split into
    space-separated characters. It is prefixed with ``<AA2fold>`` only when
    the result is all upper case. Each structure string is split the same way
    and always prefixed with ``<fold2AA>``.

    Args:
        sequences: amino-acid strings.
        structures: structure strings, aligned with ``sequences``
            (``zip`` stops at the shorter of the two).
        tokenizer: callable used to encode each formatted string.

    Returns:
        ``(tokenized_sequences, tokenized_structures)``: two lists holding one
        tokenizer output per input pair, in input order.
    """
    seq_encodings, struct_encodings = [], []

    for raw_seq, raw_struct in zip(sequences, structures):
        spaced_seq = " ".join(re.sub(r"[UZOB]", "X", raw_seq))
        # NOTE(review): the structure string is used as-is. The ProstT5 model card
        # expects 3Di tokens in lower case; presumably `seq3d` already is. Confirm.
        spaced_struct = " ".join(raw_struct)

        # NOTE(review): a sequence that is not all upper case silently gets no
        # direction prefix at all.
        if spaced_seq.isupper():
            spaced_seq = "<AA2fold> " + spaced_seq
        spaced_struct = "<fold2AA> " + spaced_struct

        # padding="longest" does nothing for a single string; cross-example
        # padding happens later, in the collate function.
        seq_encodings.append(
            tokenizer(
                spaced_seq,
                add_special_tokens=True,
                padding="longest",
                return_tensors="pt",
            )
        )
        struct_encodings.append(
            tokenizer(
                spaced_struct,
                add_special_tokens=True,
                padding="longest",
                return_tensors="pt",
            )
        )

    return seq_encodings, struct_encodings


class T5ClassificationHead(nn.Module):
    """Small MLP mapping a pooled embedding to class logits.

    Pipeline: dropout -> dense (d_model -> d_model) -> tanh -> dropout ->
    out_proj (d_model -> num_labels). One dropout module (rate
    ``config.classifier_dropout``) is applied at both dropout steps.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        width = config.d_model
        # Attribute names double as state_dict keys; keep them stable.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        activated = torch.tanh(self.dense(self.dropout(hidden_states)))
        return self.out_proj(self.dropout(activated))


class CustomT5ForSequenceClassification(T5PreTrainedModel):
    """Classifier that encodes an amino-acid input and a structure input with one shared encoder.

    Both inputs pass through the same ``T5EncoderModel`` (loaded from
    ``model_checkpoint``). Each encoding is averaged over dim 1 of its
    ``last_hidden_state``. The two pooled vectors are averaged and then
    scored by ``T5ClassificationHead``.
    """

    def __init__(self, model_checkpoint, config):
        super().__init__(config)
        # These attribute names define the state_dict key prefixes
        # ("transformer.", "classification_head."); do not rename.
        self.transformer = T5EncoderModel.from_pretrained(model_checkpoint)
        self.classification_head = T5ClassificationHead(config)

    def _encode_and_mean_pool(self, input_ids, attention_mask):
        # Run one input stream through the shared encoder and average the
        # final hidden states over dim 1.
        # NOTE(review): this plain mean also includes positions that are only
        # padding (the mask is not applied to the pooling), so the embedding
        # depends on how much padding the batch added. It is kept as-is to match
        # how the loaded checkpoint was presumably trained. Confirm before
        # switching to a mask-weighted mean.
        encoded = self.transformer(input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state.mean(dim=1)

    def forward(
        self,
        input_ids_sequence,
        input_ids_structure,
        attention_mask_sequence=None,
        attention_mask_structure=None,
        labels=None,
    ):
        """Score a batch of (sequence, structure) pairs.

        Args:
            input_ids_sequence: token ids of the amino-acid input.
            input_ids_structure: token ids of the structure input.
            attention_mask_sequence: optional mask for ``input_ids_sequence``.
            attention_mask_structure: optional mask for ``input_ids_structure``.
            labels: optional class indices; when given, a cross-entropy loss
                is computed.

        Returns:
            ``(loss, logits)`` when ``labels`` is provided, otherwise ``logits``.
        """
        # Possible extension: randomly drop one of the two views during
        # training as augmentation.
        pooled_sequence = self._encode_and_mean_pool(
            input_ids_sequence, attention_mask_sequence
        )
        pooled_structure = self._encode_and_mean_pool(
            input_ids_structure, attention_mask_structure
        )

        # Average the two views. Concatenating instead would require the head's
        # input width (config.d_model) to be doubled.
        logits = self.classification_head((pooled_sequence + pooled_structure) / 2.0)

        if labels is None:
            return logits

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, self.config.num_labels), labels.view(-1))
        return loss, logits


# Base model on the Hugging Face Hub and the fine-tuned weights to evaluate.
model_checkpoint = "Rostlab/ProstT5"
model_name = model_checkpoint.rsplit("/", 1)[-1]
# Despite its name, this is the path to a single .safetensors weights file.
model_dir = (
    "/workspace/cath_classification/artifacts/model-ProstT5-v-11:v1/model.safetensors"
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = T5Tokenizer.from_pretrained(model_checkpoint, do_lower_case=False)

# Format and tokenize each (sequence, structure) pair individually.
test_tokenized_sequences, test_tokenized_structures = preprocess_data(
    test_sequences, test_sequences_3d, tokenizer
)

# Hugging Face Dataset, wrapped in the torch Dataset that the Trainer consumes.
test_dataset = ProteinDataset(
    create_dataset(test_tokenized_sequences, test_tokenized_structures, test_labels)
)

# Must equal the output width of the checkpoint's classification head;
# load_state_dict reports a size mismatch otherwise.
# The data-driven alternative would be max(test_labels) + 1.
num_labels = 10
batch_size = 64

preconfig = PretrainedConfig.from_pretrained(model_checkpoint)
preconfig.update({"num_labels": num_labels, "classifier_dropout": 0.1})

model = CustomT5ForSequenceClassification(model_checkpoint, preconfig).to(device)

In [None]:
from safetensors.torch import load_file as load_safetensors

state_dict = load_safetensors(model_dir)

# Load the fine-tuned weights into the model.
# strict=False tolerates key differences between the checkpoint and the model.
# NOTE(review): presumably these are tied/shared embedding weights that were not
# written to the safetensors file; confirm from the lists printed below.
# strict=False also hides a checkpoint that does not match at all. That would
# leave the classification head at its fresh initialisation and make the
# evaluation meaningless, so report any mismatch and fail loudly in that case.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Missing keys (left at their current values): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Unexpected keys in checkpoint (ignored): {load_result.unexpected_keys}")

missing_head_keys = [
    key for key in load_result.missing_keys if key.startswith("classification_head.")
]
if missing_head_keys:
    raise RuntimeError(
        f"Checkpoint {model_dir} did not provide classification head weights: "
        f"{missing_head_keys}"
    )

# Last expression of the cell, so the notebook still displays the load result.
load_result

In [None]:
from torch.utils.data.dataloader import default_collate


def custom_collate_fn(batch):
    """Collate ``ProteinDataset`` items into one right-padded batch.

    Token-id tensors are padded to the longest item with the module-level
    ``tokenizer.pad_token_id``, attention masks are padded with 0, and the
    per-item label tensors are stacked along a new leading dimension.

    Args:
        batch: list of dicts as produced by ``ProteinDataset.__getitem__``.

    Returns:
        dict with the same five keys, each holding a batched tensor.
    """
    pad_id = tokenizer.pad_token_id

    def _padded(key, fill_value):
        return torch.nn.utils.rnn.pad_sequence(
            [example[key] for example in batch],
            batch_first=True,
            padding_value=fill_value,
        )

    return {
        "input_ids_sequence": _padded("input_ids_sequence", pad_id),
        "attention_mask_sequence": _padded("attention_mask_sequence", 0),
        "input_ids_structure": _padded("input_ids_structure", pad_id),
        "attention_mask_structure": _padded("attention_mask_structure", 0),
        "labels": torch.stack([example["labels"] for example in batch]),
    }


# Evaluation-only run: the Trainer is used purely for batched inference + metrics.
#
# fp16 / fp16_full_eval target CUDA. On a CPU-only machine (see the CPU fallback
# for `device` above), they are either rejected by TrainingArguments or run
# unsupported half-precision ops, depending on transformers/torch versions.
# Only enable them when a GPU is present.
use_fp16 = torch.cuda.is_available()

args = TrainingArguments(
    f"{model_name}-finetuned",
    # `evaluation_strategy` (the deprecated alias) is intentionally not passed.
    # Newer transformers releases removed it and raise TypeError on it.
    eval_strategy="epoch",
    per_device_eval_batch_size=batch_size,
    push_to_hub=False,
    fp16=use_fp16,
    fp16_full_eval=use_fp16,
    report_to="none",
)

trainer = Trainer(
    model,
    args,
    eval_dataset=test_dataset,
    # Deprecated in favour of `processing_class` in newer transformers; kept for
    # compatibility with older releases. The custom collator does the padding.
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=custom_collate_fn,
)
trainer.evaluate(test_dataset)