In [1]:
import os

import torch

from GPTModel.GPTModel import *
from GPTModel.ModelConfig import *
from GPTModel.Device import *
from ContactMapPredictionHead import *
from ProteinStructureDataSet import *

In [2]:
# loads a pretrained model from model path, using the configuration in cfg
def load_model(model_path, cfg):
    model = GPTModel(cfg)
    model.out_head = ContactMapPredictionHead(cfg) # replace the GPTModel output head with the classification head
    model.load_state_dict(torch.load(model_path, weights_only=True))
    model.eval()
    return model

In [3]:
"""
Calculates top L/2 contact map accuracy.
Args:
    pred_map: (L, L) Tensor of predicted probabilities
    true_map: (L, L) Tensor of binary ground truth (1: contact, 0: no contact)
    seq_len: Protein sequence length (L)
"""
def calculate_top_l2_accuracy(pred_map, true_map, seq_len):

    # 1. Define top-k: L/2
    k = seq_len // 2

    # 2. Mask diagonal to exclude self-contacts (i=j) and close contacts (e.g., |i-j|<=5)
    # Using a simple diagonal mask here
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=6).bool()

    # Apply mask and flatten
    pred_masked = pred_map[mask]
    true_masked = true_map[mask]

    # 3. Get top-k indices
    if len(pred_masked) < k:
        return 0.0 # Not enough contacts to evaluate

    topk_vals, topk_indices = torch.topk(pred_masked, k)

    # 4. Calculate accuracy
    correct = true_masked[topk_indices].sum()
    accuracy = correct.float() / k

    return accuracy.item()

In [4]:
# performs prediction for the specified model and dataset, and calculate the average top L/2 accuracy
def predict_and_calculate_accuracy(model_filename, cfg, test_data, device, num_samples = 100):
    #load the fine tuned model
    model = load_model(model_filename, cfg)
    model.to(device)
    # go over the test samples and calculate the accuracy of the prediction for each sample
    accuracy_sum = 0.0
    for i in range(num_samples):
        sample = torch.unsqueeze(torch.tensor(test_data[i]["sequence"]),dim = 0) # add a batch dimension to the input sample
        target = test_data[i]["contact_map"] #get the label(the contact map)
        sample_device = sample.to(device)
        target_device = target.to(device)
        logits = model(sample_device) # calculate the model output - the shape is (1, sequence_length, sequence_length, 2)
        if i == 53:
            contact_map = torch.squeeze(torch.argmax(logits, dim = -1))
        prediction = torch.argmax(logits, dim = -1) #predict the contact map by choosing the highest value for each token combination i,j in the output logits
        acc = calculate_top_l2_accuracy(prediction[0,:,:], target_device, target_device.shape[0]) #calculate the accuracy among the top l/2 predicted probabilities
        accuracy_sum += acc # sum the accuracy over all the samples in order to calculate the accuracy mean
    # calculate the mean accuracy over all the test set samples
    accuracy_mean = accuracy_sum / num_samples
    return accuracy_mean

In [5]:
#load the model configuration
CHOOSE_MODEL = "gpt2-small (124M)" # use the small GPT model
cfg, _ = GetModelConfig(CHOOSE_MODEL)

#load the structures data set
data_file_name = os.path.join("Datasets", f"classification_dataset_encoded.dat")
with open(data_file_name, "rb") as infile:
    sequences_with_structures = torch.load(infile)

# split to train and test sets
sample_count = len(sequences_with_structures)
train_set_ratio = 0.9
train_set_size = train_set_ratio * sample_count
test_data = sequences_with_structures[int(train_set_size):]

# get the best device that is supported by the system
device = GetDevice()

In [6]:
pretrain_file_name = "TrainedModels/TrainedModelClassification_final_pretrained.pth" #The trained model path
no_pretrain_file_name = "TrainedModels/TrainedModelClassification_final.pth" #The trained model path

In [7]:
num_samples = 200
pretrain_accuracy = predict_and_calculate_accuracy(pretrain_file_name, cfg, test_data, device, num_samples)
no_pretrain_accuracy = predict_and_calculate_accuracy(no_pretrain_file_name, cfg, test_data, device, num_samples)

# print the mean accuracy
print(f"The mean accuracy without pretraining is {no_pretrain_accuracy}, The mean accuracy with pretraining is {pretrain_accuracy}")

The mean accuracy without pretraining is 0.0861730356537737, The mean accuracy with pretraining is 0.09749160054838285
