In [1]:
import os
from GPTModel.TrainModel import *
from ProteinTokenizer import ProteinTokenizer
from GPTModel.GPTModel import *
from GPTModel.ModelConfig import *
from GPTModel.Device import *
from ContactMapPredictionHead import *
from ProteinStructureDataSet import *

In [2]:
def load_model(pretrained_model_path, cfg, map_location="cpu"):
    """Build a GPTModel and replace its output head with a ContactMapPredictionHead.

    Args:
        pretrained_model_path: path to a saved GPTModel state_dict, or None to keep
            the randomly initialised parameters of ``GPTModel(cfg)``.
        cfg: model configuration dict, passed to both GPTModel and
            ContactMapPredictionHead.
        map_location: device the checkpoint tensors are materialised on while
            loading. Defaults to "cpu" so a checkpoint saved from a GPU model can
            still be loaded on a CPU-only machine; move the returned model to the
            target device afterwards (e.g. ``model.to(device)``).

    Returns:
        The model, in training mode, whose ``out_head`` is a freshly constructed
        ContactMapPredictionHead.
    """
    model = GPTModel(cfg)
    if pretrained_model_path is not None:
        # The checkpoint is loaded into the unmodified GPTModel (strict key match),
        # so it must come from the same architecture/config; the head is swapped below.
        state_dict = torch.load(pretrained_model_path, map_location=map_location, weights_only=True)
        model.load_state_dict(state_dict)
    model.out_head = ContactMapPredictionHead(cfg)
    model.train()
    return model


In [None]:
# Set USE_PRETRAINED to True to start from the pretrained checkpoint below;
# with False, pretrained_model_file is None and load_model keeps random parameters.
USE_PRETRAINED = False
PRETRAINED_MODEL_PATH = "TrainedModels/TrainedModelPretraining.pth"
pretrained_model_file = PRETRAINED_MODEL_PATH if USE_PRETRAINED else None
# Fail here with a clear message rather than deep inside torch.load later.
if pretrained_model_file is not None and not os.path.isfile(pretrained_model_file):
    raise FileNotFoundError(f"pretrained checkpoint not found: {pretrained_model_file}")

In [3]:
# Build the model, load the encoded data set, split it, and create the
# dataloaders, device, optimizer and tokenizer used by the training cell.
SEED = 123
BATCH_SIZE = 10

# Seed torch before the model is built so parameter initialisation (and,
# presumably, the shuffled train dataloader) is reproducible across restarts.
torch.manual_seed(SEED)

# create the model
CHOOSE_MODEL = "gpt2-small (124M)"
cfg, _ = GetModelConfig(CHOOSE_MODEL)
model = load_model(pretrained_model_file, cfg)

# load the structures data set
# NOTE(review): depending on the installed torch's weights_only default, torch.load
# may unpickle arbitrary objects — only load data files from trusted sources.
data_file_name = os.path.join("Datasets", "classification_dataset_encoded.dat")
with open(data_file_name, "rb") as infile:
    sequences_with_structures = torch.load(infile)

# split to train and validation sets: first train_set_ratio of the samples (in file
# order, no shuffling) for training, the remainder for validation
sample_count = len(sequences_with_structures)
train_set_ratio = 0.9
train_set_size = int(train_set_ratio * sample_count)
train_data = sequences_with_structures[:train_set_size]
val_data = sequences_with_structures[train_set_size:]
assert len(val_data) > 0, "validation split is empty; need more samples or a smaller train_set_ratio"
assert len(train_data) >= BATCH_SIZE, (
    "train split is smaller than one batch; with drop_last=True the train dataloader "
    "would presumably yield no batches"
)

# create the dataloaders
train_dataloader = create_protein_structure_dataloader(train_data, context_length=cfg["context_length"], batch_size=BATCH_SIZE,
                         shuffle=True, drop_last=True, num_workers=0)

validation_dataloader = create_protein_structure_dataloader(val_data, context_length=cfg["context_length"], batch_size=BATCH_SIZE,
                         shuffle=False, drop_last=False, num_workers=0)
# get a cuda device if the system supports it
device = GetDevice()
model.to(device)

# create the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)
# create the tokenizer
tokenizer = ProteinTokenizer()

In [None]:
# Run training; returns the per-evaluation train/validation losses and the
# running count of tokens seen.
# NOTE(review): the positional arguments after `device` are passed straight through
# to train_model_simple (star-imported, presumably from GPTModel.TrainModel). Their
# meaning is not visible here — plausibly (num_epochs, eval_freq, eval_iter,
# start_context, tokenizer, checkpoint name prefix); confirm against its signature.
train_losses, val_losses, track_tokens_seen = train_model_simple(
    model,
    train_dataloader,
    validation_dataloader,
    optimizer,
    device,
    50,
    1000,
    2,
    None,
    tokenizer,
    "TrainedModelClassification",
)

In [None]:
# save the final model
# The file name follows the pretrained toggle, so a run that started from the
# pretrained checkpoint does not overwrite the from-scratch model (and vice versa).
if pretrained_model_file is not None:
    file_name = "TrainedModels/TrainedModelClassification_final_pretrained.pth"
else:
    file_name = "TrainedModels/TrainedModelClassification_final.pth"
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
torch.save(model.state_dict(), file_name)