In [81]:
# !sinfo -O Nodehost,Gres:.30,GresUsed:.45
# !salloc -N 1 --cpus-per-task=4 -p CS177h --gres=gpu:TeslaM4024GB:1
# !salloc -N 1 --cpus-per-task=8 -p CS177h --gres=gpu:TeslaM4024GB:2
# !salloc -N 1 --cpus-per-task=12 -p CS177h --gres=gpu:TeslaM4024GB:3
# !salloc -N 1 --cpus-per-task=16 -p CS177h --gres=gpu:TeslaM4024GB:4
# !jupyter-lab --no-brows --ip=0.0.0.0 --port=7774 

In [82]:
!nvidia-smi

Wed Dec 14 20:04:58 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 512.36       Driver Version: 512.36       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   65C    P0    56W /  N/A |   5916MiB /  6144MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# MSA Transformer + MLP Baseline

In [83]:
import os
import sys
from pathlib import Path

# PROJECT_PATH = Path() / "Project" / "ShanghaiTech-CS177H-MSA-Scoring"
PROJECT_PATH = Path() 
sys.path.append(str(PROJECT_PATH))
from tqdm import tqdm
from pathlib import Path

import numpy as np
import plotly.graph_objs as go

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import esm

DATASET_PATH     = PROJECT_PATH / "dataset" / "CASP14_fm"
MODEL_PATH       = PROJECT_PATH / "model"
EMBDEDDINGS_PATH = PROJECT_PATH / "embeddings"
TRANSFORMER_PATH = PROJECT_PATH / "esm_msa1b_t12_100M_UR50S.pt"

# hyperparameters and settings
LEARNING_RATE    = 1e-3
BATCH_SIZE       = 64
EPOCHS           = 20
lambda_reg       = 0.5
checkpoint_epoch = 20
model_name       = "pair_bos_256"

Define model

In [84]:
# !scp -r tengyue@10.15.89.191:/public/home/cs177h/tengyue/Project/ShanghaiTech-CS177H-MSA-Scoring/embeddings /public/home/cs177h/lianyh/perl5/project/embeddings

In [85]:
class MSAPredictor(nn.Module):
    def __init__(self, msa_transformer_path = TRANSFORMER_PATH):
        super(MSAPredictor, self).__init__()
        """
        if msa_transformer_path:
            self.encoder, msa_alphabet = esm.pretrained.load_model_and_alphabet_local(msa_transformer_path)
        else :
            self.encoder, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
        
        self.encoder = self.encoder.eval()
        self.batch_converter = msa_alphabet.get_batch_converter()
        """

        # Freeze parameters of MSATransformer
        """
        for param in self.encoder.parameters():
            param.requires_grad = False
        """
        # Regressor module (to be tested)
        # self.conv1 = nn.Conv2d(1, 6, 5)
        # self.pool = nn.MaxPool2d(3, 3)
        # self.conv2 = nn.Conv2d(6, 16, 5)
        # self.fc1 = nn.Linear(25232, 2048)
        # self.fc2 = nn.Linear(2048, 512)
        # self.fc3 = nn.Linear(512, 1)
        self.fc1 = nn.Linear(768, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 1)
        self.dp1 = nn.Dropout(0.2)
        self.dp2 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = torch.mean(x[:, 1:, :], dim = 1)
        # BATHCH_SIZE x 768

        # x = self.pool(F.relu(self.conv1(x)))
        # x = self.pool(F.relu(self.conv2(x)))

        # x = torch.flatten(x, 1)
        # x = F.relu(self.fc1(x))
        # x = F.relu(self.fc2(x))
        # x = self.fc3(x)
        # x = torch.sigmoid(self.fc3(x))
        
        x = self.dp1(F.leaky_relu(self.fc1(x)))
        x = self.dp2(F.leaky_relu(self.fc2(x)))
        x = self.fc3(x)
        # x = torch.sigmoid(self.fc3(x))
        return x

In [86]:
class MSAPredictorBOS(MSAPredictor):
    def __init__(self, msa_transformer_path = TRANSFORMER_PATH):
        super(MSAPredictorBOS, self).__init__(msa_transformer_path)
        
    def forward(self, x):
        
        x = x[:, 0, :]
        # BATHCH_SIZE x 768
        
        x = self.dp1(F.leaky_relu(self.fc1(x)))
        x = self.dp2(F.leaky_relu(self.fc2(x)))
        x = self.fc3(x)
        # x = torch.sigmoid(self.fc3(x))
        return x

In [87]:
class PairMSAPredictor(nn.Module):
    def __init__(self, msa_transformer_path = TRANSFORMER_PATH, type = 'pair_bos_256'):
        super(PairMSAPredictor, self).__init__()
        self.regressor = MSAPredictorBOS(msa_transformer_path) if type == 'pair_bos_256' else MSAPredictor(msa_transformer_path)
        
    def forward(self, x1, x2):
        y1, y2 = self.regressor(x1), self.regressor(x2)
        delta, y1, y2 = torch.sigmoid(y1 - y2), torch.sigmoid(y1), torch.sigmoid(y2)
        return delta, y1, y2

In [88]:
if model_name == "bos_256" :
    model = MSAPredictorBOS().cuda()
elif model_name == "mean_256" :
    model = MSAPredictor().cuda()
elif model_name[:4] == "pair" :
    model = PairMSAPredictor(type = model_name).cuda()

NUM_GPU = torch.cuda.device_count()
USE_PARALLEL = NUM_GPU > 1
if USE_PARALLEL :
    model = torch.nn.DataParallel(model)

In [89]:
loss_fn = nn.BCELoss()
loss_reg_fn = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

if checkpoint_epoch > 0 :
    checkpoint = torch.load(MODEL_PATH / f"model_{model_name}_epoch{checkpoint_epoch}.pth")
    (model.module if USE_PARALLEL else model).load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"The {model_name} model loaded has been trained for {epoch} epoche(s), with {checkpoint['train_mse']} training loss, {checkpoint['valid_mse']} validation MSE and {checkpoint['test_acc']} test accuracy. ")
else :
    print(f"Start training {model_name} model from the 1st epoch.")

The pair_bos_256 model loaded has been trained for 20 epoche(s), with 0.024229296451756872 training loss, 0.021515664323171795 validation MSE and 0.9578947368421052 test accuracy. 


Prepare dataset

In [90]:
from dataset import EmbeddingScoreDataset, PairDataset
from torch.utils.data import random_split

train_dataset = PairDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = True) if model_name[:4] == "pair" else EmbeddingScoreDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = True)
test_dataset  = PairDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = False)

In [91]:
from dataset import read_embedding
x1 = read_embedding('T1024-D1_rand10_fm', 584, EMBDEDDINGS_PATH)
x2 = read_embedding('T1024-D1_aug_fm', 584, EMBDEDDINGS_PATH)
x3 = read_embedding('T1024-D1_deduplicated_fm', 584, EMBDEDDINGS_PATH)
x4 = read_embedding('T1024-D1_meta_fm', 584, EMBDEDDINGS_PATH)

x = torch.stack([x1, x2, x3, x4])
with torch.no_grad():
    score = torch.sigmoid(model.regressor(x.cuda()))
score.cpu().numpy()

array([[0.40373233],
       [0.7987787 ],
       [0.8523313 ],
       [0.9566743 ]], dtype=float32)

In [28]:
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, num_workers = NUM_GPU * 4, pin_memory = True, shuffle = True)
test_loader  = DataLoader(test_dataset, batch_size = 1, shuffle = False)

In [29]:
# explore batch shape
it = iter(train_loader)
it = next(it)
if model_name[:4] == "pair" :
    print(it['embedding1'].shape, it['score1'].shape, len(it['name1']))
    print(it['embedding2'].shape, it['score2'].shape, len(it['name2']))
else :
    print(it['embedding'].shape, it['score'].shape, len(it['name']))

torch.Size([64, 584, 768]) torch.Size([64, 1]) 64
torch.Size([64, 584, 768]) torch.Size([64, 1]) 64


In [30]:
def calc_test_metric(model, print_wrong_predictions = False, print_correct_predictions = False):
    model.eval()
    mse, correct, tot = 0.0, 0, 0
    with torch.no_grad():
        with tqdm(total = len(test_loader), ncols = 80, file = sys.stdout) as bar:
            for sample in test_loader :
                x1, x2, y1, y2 = sample["embedding1"].cuda(), sample["embedding2"].cuda(), sample["score1"].cuda(), sample["score2"].cuda()
                x, y = torch.vstack([x1, x2]), torch.vstack([y1, y2])
                
                predictor = model.module if USE_PARALLEL else model
                pred = torch.sigmoid(predictor.regressor(x) if model_name[:4] == 'pair' else predictor(x))

                if torch.argmax(y) == torch.argmax(pred) :
                    correct += 1
                    if print_correct_predictions:
                        for name, y_gt, y_pred in zip([sample['name1'][0], sample['name2'][0]], y, pred) :
                            tqdm.write(f"{name:>20}  y_gt:{y_gt.item():.4f}  y_pred:{y_pred.item() :.4f}")
                        tqdm.write("-----------------------------------------------")
                else :
                    if print_wrong_predictions:
                        for name, y_gt, y_pred in zip([sample['name1'][0], sample['name2'][0]], y, pred) :
                            tqdm.write(f"{name:>20}  y_gt:{y_gt.item() :.4f}  y_pred:{y_pred.item() :.4f}")
                        tqdm.write("===================================================")
                
                mse += torch.sum((y - pred) ** 2).item()
                tot += 2

                bar.set_postfix({
                    "acc": f"{correct / (tot // 2):.4f}",
                    "mse": f"{mse / tot:.4f}"
                })
                bar.update(1)
    
    return {'accuracy' : correct / (tot // 2), 'mse' : mse / tot}

Train:

In [31]:
for epoch in range(checkpoint_epoch + 1, EPOCHS + 1):
    losses, losses_reg, tot = 0, 0, 0
    
    model.train()

    with tqdm(total = len(train_loader), ncols = 120) as bar:
        for batch, sample in enumerate(train_loader) :
            
            bar.set_description(f"[epoch#{epoch:>2}/{EPOCHS:>2}][{batch * BATCH_SIZE:>5}/{len(train_dataset):>5}]")

            if model_name[:4] == 'pair' :
                x1, x2, y1, y2 = sample["embedding1"].cuda(), sample["embedding2"].cuda(), sample["score1"].cuda(), sample["score2"].cuda()
                # delta = torch.sigmoid((y1 - y2) / SIGMOID_SCALE)
                delta = torch.where(y1 > y2, 1.0, 0.0)

                logit, y1_pred, y2_pred = model(x1, x2)

                loss_reg = loss_reg_fn(torch.vstack([y1_pred, y2_pred]), torch.vstack([y1, y2])) 
                loss_bce = loss_fn(logit, delta)
                loss = loss_bce + lambda_reg * loss_reg
            else :
                x, y = sample["embedding"].cuda(), sample["score"].cuda()
                pred = torch.sigmoid(model(x))
                loss_reg = loss_reg_fn(pred, y)
                loss = loss_reg

            # Backpropagation
            optimizer.zero_grad()     
            loss.backward()
            optimizer.step()
            
            losses_reg += loss_reg.item() * len(sample) 
            losses += loss.item() * len(sample)
            tot += len(sample)

            bar.set_postfix({
                # "batch loss" : f"{loss.item():.5f}",
                "mse" : f"{losses_reg / tot:.5f}", 
                "loss": f"{losses / tot:.5f}"
            })
            bar.update(1)

    test_metric =  calc_test_metric(model)
    train_loss, train_mse, test_acc, valid_mse = losses / tot, losses_reg / tot, test_metric['accuracy'], test_metric['mse']
    torch.save({
        'epoch'               : epoch,
        'model_state_dict'    : (model.module if USE_PARALLEL else model).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_mse'           : train_mse,
        'train_loss'          : train_loss,
        'valid_mse'           : valid_mse,
        'test_acc'            : test_acc
    }, MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth")
    

[epoch# 1/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:51<00:00,  4.03it/s, mse=0.04752, loss=0.63095]


100%|███████████████████| 95/95 [00:01<00:00, 61.07it/s, acc=0.9474, mse=0.0327]


[epoch# 2/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:46<00:00,  4.44it/s, mse=0.04117, loss=0.58139]


100%|██████████████████| 95/95 [00:00<00:00, 107.49it/s, acc=0.9474, mse=0.0323]


[epoch# 3/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:35<00:00,  5.86it/s, mse=0.04027, loss=0.56038]


100%|██████████████████| 95/95 [00:00<00:00, 164.36it/s, acc=0.9474, mse=0.0347]


[epoch# 4/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:35<00:00,  5.84it/s, mse=0.03881, loss=0.54502]


100%|██████████████████| 95/95 [00:00<00:00, 178.78it/s, acc=0.9684, mse=0.0332]


[epoch# 5/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:34<00:00,  6.01it/s, mse=0.03918, loss=0.53572]


100%|██████████████████| 95/95 [00:00<00:00, 144.09it/s, acc=0.9684, mse=0.0296]


[epoch# 6/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:35<00:00,  5.92it/s, mse=0.03828, loss=0.52797]


100%|██████████████████| 95/95 [00:00<00:00, 164.08it/s, acc=0.9579, mse=0.0310]


[epoch# 7/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:35<00:00,  5.94it/s, mse=0.03842, loss=0.52212]


100%|██████████████████| 95/95 [00:00<00:00, 176.25it/s, acc=0.9684, mse=0.0300]


[epoch# 8/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:34<00:00,  6.00it/s, mse=0.03821, loss=0.51277]


100%|██████████████████| 95/95 [00:00<00:00, 191.80it/s, acc=0.9684, mse=0.0301]


[epoch# 9/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:34<00:00,  6.00it/s, mse=0.03789, loss=0.50793]


100%|██████████████████| 95/95 [00:00<00:00, 118.79it/s, acc=0.9789, mse=0.0316]


[epoch#10/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:34<00:00,  5.96it/s, mse=0.03752, loss=0.50392]


100%|██████████████████| 95/95 [00:00<00:00, 187.36it/s, acc=0.9579, mse=0.0280]


[epoch#11/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:34<00:00,  6.01it/s, mse=0.03739, loss=0.50225]


100%|██████████████████| 95/95 [00:00<00:00, 167.49it/s, acc=0.9684, mse=0.0308]


[epoch#12/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:35<00:00,  5.89it/s, mse=0.03714, loss=0.50105]


100%|██████████████████| 95/95 [00:00<00:00, 163.40it/s, acc=0.9789, mse=0.0290]


[epoch#13/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:35<00:00,  5.92it/s, mse=0.03701, loss=0.49385]


100%|██████████████████| 95/95 [00:00<00:00, 132.39it/s, acc=0.9579, mse=0.0292]


[epoch#14/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:34<00:00,  5.97it/s, mse=0.03843, loss=0.49344]


100%|██████████████████| 95/95 [00:00<00:00, 148.94it/s, acc=0.9579, mse=0.0308]


[epoch#15/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:37<00:00,  5.48it/s, mse=0.03803, loss=0.48684]


100%|██████████████████| 95/95 [00:00<00:00, 111.63it/s, acc=0.9684, mse=0.0293]


[epoch#16/20][ 1600/13300]:  12%|███▏                       | 25/208 [00:09<01:12,  2.52it/s, mse=0.03744, loss=0.48407]


KeyboardInterrupt: 

In [None]:
calc_test_metric(model, print_wrong_predictions = True)

T1030-D1_original_fm  y_gt:0.8572  y_pred:0.8304                                
  T1030-D1_rand13_fm  y_gt:0.8149  y_pred:0.9053                                
    T1030-D2_base_fm  y_gt:0.9118  y_pred:0.9039                                
T1030-D2_original_fm  y_gt:0.6008  y_pred:0.9076                                
T1046s1-D1_rand13_fm  y_gt:0.6354  y_pred:0.6863                                
T1046s1-D1_rosetta_fm  y_gt:0.9791  y_pred:0.6065                               
T1047s2-D2_rand16_fm  y_gt:0.7530  y_pred:0.8924                                
 T1047s2-D2_rand5_fm  y_gt:0.9398  y_pred:0.7512                                
    T1053-D1_base_fm  y_gt:0.9542  y_pred:0.7899                                
     T1053-D1_our_fm  y_gt:0.7264  y_pred:0.8641                                
  T1070-D3_rand13_fm  y_gt:0.7369  y_pred:0.7078                                
  T1070-D3_rand16_fm  y_gt:0.8322  y_pred:0.7017                                
  T1100-D1_rand13_fm  y_gt:0

{'accuracy': 0.9263157894736842, 'mse': 0.014492791874554793}

In [43]:
def plot_training_history(model_name = "mean_256", trained_epoches = 20):
    epoches = list(range(1, trained_epoches + 1))
    train_losses = []
    test_losses = []
    test_accs = []
    for epoch in range(1, trained_epoches+1):
        checkpoint = torch.load(MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth")
        train_losses.append(checkpoint["train_mse"])
        test_losses.append(checkpoint["valid_mse"])
        test_accs.append(checkpoint["test_acc"])
    
    trace1 = go.Scatter(
        x = epoches,
        y = train_losses,
        name= "Train Loss",
        xaxis='x',
        yaxis='y1',
        mode='lines+markers'
    )
    trace2 = go.Scatter(
        x = [epoch + 0.5 for epoch in epoches],
        y = test_losses,
        name= "Valid MSE",
        xaxis='x', 
        yaxis='y1',
        mode='lines+markers'
    )
    trace3 = go.Scatter(
        x = epoches,
        y = test_accs,
        name='Test Accuracy',
        xaxis='x', 
        yaxis='y2',
        mode='lines+markers'
    )
    
    data = [trace1, trace2, trace3]
    layout = go.Layout(
        yaxis2=dict(overlaying = 'y', side = 'right', title = "Accuracy", range = [0.75, 1.0]),
        yaxis1=dict(title = "MSE Loss"),
        xaxis = dict(title = "Epoch"),
        legend=dict(x=0.75, y=0.55, font=dict(size=12, color="black"))
    )
    
    fig = go.Figure(data=data, layout=layout)
    fig.show()

plot_training_history("pair_bos_256", 20)