In [3]:
# Extend the module search path with the project checkout so that modules
# living there can be imported from this notebook.
# NOTE(review): this is a hard-coded, user-specific absolute path; the notebook
# will only import project modules on a machine with this exact directory.
import sys
sys.path.append("/public/home/cs177h/tengyue/Project/ShanghaiTech-CS177H-MSA-Scoring")

from tqdm import tqdm
from pathlib import Path

import numpy as np
import plotly.graph_objs as go
import matplotlib.pyplot as plt
# Global matplotlib defaults for every figure in this notebook:
# size 14 for the base font, axis labels, axis titles and legends,
# size 10 for the x/y tick labels.
plt.rc('font', size=14)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('legend', fontsize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Reproducibility: seed NumPy and every torch generator (CPU, current CUDA
# device, all CUDA devices) with the same value.
SEED = 42
np.random.seed(SEED)
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

# Ask cuDNN for deterministic kernels and turn off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import esm

# All project artefacts live under one directory, resolved relative to the
# kernel's current working directory.
PROJECT_DIR = Path() / "Project" / "ShanghaiTech-CS177H-MSA-Scoring"

DATASET_PATH = PROJECT_DIR / "dataset" / "CASP14_fm"
MODEL_PATH = PROJECT_DIR / "model"
# NOTE(review): the name is misspelt ("EMBDEDDINGS"), but other cells reference
# it, so it is kept as-is here.
EMBDEDDINGS_PATH = PROJECT_DIR / "embeddings"
TRANSFORMER_PATH = PROJECT_DIR / "esm_msa1b_t12_100M_UR50S.pt"

# hyperparameters
MAX_DEPTH = 256
EPOCHES = 50
LEARNING_RATE = 1e-4

In [57]:
from dataset import MSAScoreDataset

# Build the train split first, then the test split, from the same root folder.
train_dataset, test_dataset = (
    MSAScoreDataset(root=DATASET_PATH, is_train=train_flag)
    for train_flag in (True, False)
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 2660/2660 [00:42<00:00, 63.19it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 190/190 [00:07<00:00, 25.26it/s]


In [65]:
# Find the training MSA with the most entries (the first one wins on ties)
# and report both its size and its name.
most, name_ = 0, None
for name in train_dataset.msa_name_list:
    depth = len(train_dataset.msa_data[name]['msa'])
    if depth > most:
        most, name_ = depth, name
print(most)
print(name_)

256
T1024-D1_base_fm


In [55]:
# Load the MSA Transformer checkpoint and its alphabet from the local file.
msa_transformer, msa_alphabet = esm.pretrained.load_model_and_alphabet_local(TRANSFORMER_PATH)

# The model is only used for inference here: switch to eval mode, then move it to the GPU.
msa_transformer = msa_transformer.eval()
msa_transformer = msa_transformer.cuda()

# Callable that turns raw MSA data into the (labels, strings, tokens) batch used below.
msa_batch_converter = msa_alphabet.get_batch_converter()

In [28]:
# Number of MSAs in the training split.
print(len(train_dataset.msa_name_list))
# Hand-written progress note kept as the cell's displayed value.
# The original wrote two adjacent string literals, which Python concatenates
# without a separator ('problem: 269done: 1184'); use one literal instead.
# NOTE(review): what "problem" and "done" count is not shown in this notebook — confirm.
"problem: 269, done: 1184"

2660


'problem: 269done: 1184'

In [70]:
# Run every training MSA through the MSA Transformer and cache one embedding
# matrix per MSA on disk as "<name>.pt".
# `padding` ends up as the largest first dimension (len(embeddings)) seen.
padding = 0

# torch.save does not create missing directories, so make sure the target exists.
EMBDEDDINGS_PATH.mkdir(parents=True, exist_ok=True)

# Eval mode does not change between iterations; set it once up front.
msa_transformer.eval()

for i, name in enumerate(train_dataset.msa_name_list):
    # The batch converter takes a list of MSAs; wrap this one as a batch of size 1.
    t_data = [train_dataset.msa_data[name]['msa']]
    msa_batch_labels, msa_batch_strs, msa_batch_tokens = msa_batch_converter(t_data)

    with torch.no_grad():
        torch.cuda.empty_cache()
        results = msa_transformer(msa_batch_tokens.cuda(non_blocking = True), repr_layers=[12])
        # Keep index 0 along the first two axes of the layer-12 representation
        # (presumably batch item 0 and the first alignment row — confirm).
        # The slice is a view onto the full representation tensor, and torch.save
        # serializes a view's whole underlying storage; .clone() gives the slice
        # its own storage so only the 2-D matrix is written to disk.
        embeddings = results['representations'][12][0, 0, :, :].clone()

    # NOTE(review): the tensor is saved as a CUDA tensor, as before; loading it on
    # a machine without a GPU needs torch.load(..., map_location="cpu").
    torch.save(embeddings, EMBDEDDINGS_PATH / (name + ".pt"))
    padding = max(padding, len(embeddings))
    print(i," : ",name, " ", embeddings.shape, " ", padding)
    

0  :  T1024-D1_aug_fm   torch.Size([194, 768])   194
1  :  T1024-D1_base_fm   torch.Size([194, 768])   194
2  :  T1024-D1_cov50_fm   torch.Size([194, 768])   194
3  :  T1024-D1_deduplicated_fm   torch.Size([194, 768])   194
4  :  T1024-D1_original_fm   torch.Size([194, 768])   194
5  :  T1024-D1_our_fm   torch.Size([194, 768])   194
6  :  T1024-D1_rand11_fm   torch.Size([194, 768])   194
7  :  T1024-D1_rand12_fm   torch.Size([194, 768])   194
8  :  T1024-D1_rand13_fm   torch.Size([194, 768])   194
9  :  T1024-D1_rand14_fm   torch.Size([194, 768])   194
10  :  T1024-D1_rand15_fm   torch.Size([194, 768])   194
11  :  T1024-D1_rand16_fm   torch.Size([194, 768])   194
12  :  T1024-D1_rand17_fm   torch.Size([194, 768])   194
13  :  T1024-D1_rand18_fm   torch.Size([194, 768])   194
14  :  T1024-D1_rand19_fm   torch.Size([194, 768])   194
15  :  T1024-D1_rand1_fm   torch.Size([194, 768])   194
16  :  T1024-D1_rand20_fm   torch.Size([194, 768])   194
17  :  T1024-D1_rand21_fm   torch.Size([194