In [1]:
# !sinfo -O Nodehost,Gres:.30,GresUsed:.45
# !salloc -N 1 --cpus-per-task=4 -p CS177h --gres=gpu:TeslaM4024GB:1
# !salloc -N 1 --cpus-per-task=8 -p CS177h --gres=gpu:TeslaM4024GB:2
# !salloc -N 1 --cpus-per-task=12 -p CS177h --gres=gpu:TeslaM4024GB:3
# !salloc -N 1 --cpus-per-task=16 -p CS177h --gres=gpu:TeslaM4024GB:4
# !jupyter-lab --no-browser --ip=0.0.0.0 --port=7774

In [2]:
# Show the GPUs visible to this kernel (driver / CUDA version, memory use, utilisation).
!nvidia-smi

Wed Dec 14 04:50:02 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 512.36       Driver Version: 512.36       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   54C    P5    17W /  N/A |   1364MiB /  6144MiB |     18%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# MSA Transformer + MLP Baseline

In [3]:
import os
import random
import sys
from pathlib import Path

# PROJECT_PATH = Path() / "Project" / "ShanghaiTech-CS177H-MSA-Scoring"
PROJECT_PATH = Path() 
# Put the project root on sys.path so project-local modules (e.g. `dataset`) can be imported.
sys.path.append(str(PROJECT_PATH))
from tqdm import tqdm

import numpy as np
import plotly.graph_objs as go
import matplotlib.pyplot as plt
plt.rc('font', size=14)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('legend', fontsize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Seed every RNG in use (Python, NumPy, PyTorch) so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)            # seeds the CPU generator and every CUDA device
torch.cuda.manual_seed_all(SEED)   # explicit for clarity; silently ignored without CUDA
# Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import esm

DATASET_PATH     = PROJECT_PATH / "dataset" / "CASP14_fm"
MODEL_PATH       = PROJECT_PATH / "model"          # checkpoints: model_{model_name}_epoch{N}.pth
EMBDEDDINGS_PATH = PROJECT_PATH / "embeddings"
TRANSFORMER_PATH = PROJECT_PATH / "esm_msa1b_t12_100M_UR50S.pt"

# hyperparameters and settings
LEARNING_RATE    = 1e-4
BATCH_SIZE       = 64
EPOCHS           = 50
lambda_reg       = 5.0             # weight of the score-MSE term added to the pairwise BCE loss
checkpoint_epoch = 0               # 0: train from scratch; N > 0: resume from the epoch-N checkpoint
model_name       = "pair_bos_256"  # "bos_256", "mean_256", or a name starting with "pair"

## Define the model

In [4]:
# !scp -r tengyue@10.15.89.191:/public/home/cs177h/tengyue/Project/ShanghaiTech-CS177H-MSA-Scoring/embeddings /public/home/cs177h/lianyh/perl5/project/embeddings

In [5]:
class MSAPredictor(nn.Module):
    """MLP regressor that mean-pools per-token embeddings into one unnormalised score.

    The MSA Transformer encoder is not run here. The inputs are presumably precomputed
    embeddings (the datasets are built from ``EMBDEDDINGS_PATH``), so
    ``msa_transformer_path`` is unused and kept only for call-site compatibility.

    Args:
        msa_transformer_path: Unused; see above.
        dropout: Drop probability of both dropout layers. The default of 0.5 is
            ``nn.Dropout``'s default, which these layers always used.
    """

    def __init__(self, msa_transformer_path = TRANSFORMER_PATH, dropout = 0.5):
        super().__init__()
        # 768 matches the last dimension of the embedding batches (e.g. [64, 584, 768]).
        self.fc1 = nn.Linear(768, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 1)
        self.dp1 = nn.Dropout(dropout)
        self.dp2 = nn.Dropout(dropout)

    def forward(self, x):
        """Return one unnormalised score per sample.

        Args:
            x: Embedding tensor, presumably (batch, tokens, 768). Token position 0 is
               skipped (``MSAPredictorBOS`` uses it as the BOS token) and the rest are
               averaged.

        Returns:
            Tensor of shape (batch, 1). No sigmoid is applied here; callers apply it.
        """
        x = torch.mean(x[:, 1:, :], dim = 1)   # (batch, 768)
        x = self.dp1(F.leaky_relu(self.fc1(x)))
        x = self.dp2(F.leaky_relu(self.fc2(x)))
        return self.fc3(x)

In [6]:
class MSAPredictorBOS(MSAPredictor):
    """``MSAPredictor`` variant that scores from the first token's embedding (the BOS
    position) instead of mean-pooling the remaining tokens. It reuses the parent's
    fc/dropout layers unchanged."""

    def __init__(self, msa_transformer_path = TRANSFORMER_PATH):
        super().__init__(msa_transformer_path)

    def forward(self, x):
        # Keep only token position 0 -> (batch, 768).
        bos_embedding = x[:, 0, :]
        hidden = self.dp1(F.leaky_relu(self.fc1(bos_embedding)))
        hidden = self.dp2(F.leaky_relu(self.fc2(hidden)))
        # Unnormalised score of shape (batch, 1); callers apply the sigmoid.
        return self.fc3(hidden)

In [7]:
class PairMSAPredictor(nn.Module):
    """Siamese wrapper: scores two inputs with one shared regressor.

    ``type == 'pair_bos_256'`` selects ``MSAPredictorBOS``; any other value falls back
    to the mean-pooling ``MSAPredictor``.

    ``forward`` returns ``(sigmoid(s1 - s2), sigmoid(s1), sigmoid(s2))``. These are the
    probability that input 1 outranks input 2, followed by each input's own score.
    """

    def __init__(self, msa_transformer_path = TRANSFORMER_PATH, type = 'pair_bos_256'):
        super().__init__()
        regressor_cls = MSAPredictorBOS if type == 'pair_bos_256' else MSAPredictor
        self.regressor = regressor_cls(msa_transformer_path)

    def forward(self, x1, x2):
        score1 = self.regressor(x1)
        score2 = self.regressor(x2)
        return torch.sigmoid(score1 - score2), torch.sigmoid(score1), torch.sigmoid(score2)

In [8]:
# Build the regressor named by `model_name`. An unknown name fails loudly instead of
# leaving `model` undefined (or silently reusing one from an earlier run of this cell).
if model_name == "bos_256":
    model = MSAPredictorBOS().cuda()
elif model_name == "mean_256":
    model = MSAPredictor().cuda()
elif model_name.startswith("pair"):
    model = PairMSAPredictor(type = model_name).cuda()
else:
    raise ValueError(
        f"Unknown model_name {model_name!r}: expected 'bos_256', 'mean_256' or a name starting with 'pair'."
    )

# Replicate across GPUs when more than one is visible. Checkpoints are saved and loaded
# through `.module`, so their keys do not depend on this wrapper.
NUM_GPU = torch.cuda.device_count()
USE_PARALLEL = NUM_GPU > 1
if USE_PARALLEL:
    model = torch.nn.DataParallel(model)

In [9]:
# Pair models: BCE on the predicted probability that sample 1 outranks sample 2, plus an
# MSE term on the absolute scores. Single-score models use only the MSE term.
loss_fn = nn.BCELoss()
loss_reg_fn = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Resume from a saved checkpoint when checkpoint_epoch > 0.
if checkpoint_epoch > 0:
    checkpoint = torch.load(MODEL_PATH / f"model_{model_name}_epoch{checkpoint_epoch}.pth")
    # Checkpoints hold the unwrapped module's weights, so load into `.module` under DataParallel.
    (model.module if USE_PARALLEL else model).load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    # 'train_mse' is the regression MSE; the combined training loss is stored separately as 'train_loss'.
    print(f"The {model_name} model loaded has been trained for {epoch} epoch(s), with {checkpoint['train_mse']} training MSE, {checkpoint['valid_mse']} validation MSE and {checkpoint['test_acc']} test accuracy. ")
else:
    print(f"Start training {model_name} model from the 1st epoch.")

Start training pair_bos_256 model from the 1st epoch.


## Prepare the dataset

In [10]:
from dataset import EmbeddingScoreDataset, PairDataset
from torch.utils.data import random_split

# Pair models train on pairs of samples; single-score models train on individual samples.
if model_name.startswith("pair"):
    train_dataset = PairDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = True)
else:
    train_dataset = EmbeddingScoreDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = True)

# Evaluation is pairwise for every model type (calc_test_metric reads the pair keys).
test_dataset = PairDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = False)

In [11]:
# Shuffled, pinned-memory training loader. Worker count scales with the number of visible GPUs.
train_loader = DataLoader(
    train_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_GPU * 4,
    pin_memory = True,
)
# calc_test_metric indexes names with [0] and compares the two scores of a single pair,
# so the test loader must yield exactly one pair per batch.
test_loader = DataLoader(test_dataset, batch_size = 1, shuffle = False)

In [12]:
# explore batch shape
it = iter(train_loader)
it = next(it)
if model_name[:4] == "pair" :
    print(it['embedding1'].shape, it['score1'].shape, len(it['name1']))
    print(it['embedding2'].shape, it['score2'].shape, len(it['name2']))
else :
    print(it['embedding'].shape, it['score'].shape, len(it['name']))

torch.Size([64, 584, 768]) torch.Size([64, 1]) 64
torch.Size([64, 584, 768]) torch.Size([64, 1]) 64


In [13]:
def _print_pair(names, y, pred, separator):
    """Write each sample's ground-truth and predicted score, then a separator, via tqdm.write.

    Args:
        names: The two sample names of the pair.
        y: Ground-truth scores, one row per sample.
        pred: Predicted scores (after sigmoid), aligned with ``y``.
        separator: Line written after the pair.
    """
    for name, y_gt, y_pred in zip(names, y, pred):
        tqdm.write(f"{name:>20}  y_gt:{y_gt.item():.4f}  y_pred:{y_pred.item():.4f}")
    tqdm.write(separator)


def calc_test_metric(model, print_wrong_predictions = False, print_correct_predictions = False, loader = None):
    """Evaluate pairwise ranking accuracy and per-sample score MSE.

    Each batch from the loader holds one pair (batch_size=1). Both members are scored in
    one pass. A pair counts as correct when the member with the higher predicted score
    also has the higher ground-truth score (argmax over the two).

    Args:
        model: The model, possibly wrapped in ``DataParallel`` (unwrapped via ``USE_PARALLEL``).
        print_wrong_predictions: Log each pair whose ranking is wrong.
        print_correct_predictions: Log each pair whose ranking is right.
        loader: Loader yielding one pair per batch; defaults to the global ``test_loader``.

    Returns:
        ``{'accuracy': correct pairs / pairs, 'mse': mean squared error per sample}``.
        Both values are ``nan`` when the loader is empty.

    Leaves ``model`` in eval mode.
    """
    loader = test_loader if loader is None else loader
    model.eval()
    # Loop-invariant: the unwrapped module, and whether to score through its shared regressor.
    predictor = model.module if USE_PARALLEL else model
    is_pair = model_name.startswith('pair')
    sq_err, n_correct, n_pairs = 0.0, 0, 0
    with torch.no_grad(), tqdm(total = len(loader), ncols = 80, file = sys.stdout) as bar:
        for sample in loader:
            x1, x2 = sample["embedding1"].cuda(), sample["embedding2"].cuda()
            y1, y2 = sample["score1"].cuda(), sample["score2"].cuda()
            # Stack the pair so both members are scored by the single-input regressor at once.
            x, y = torch.vstack([x1, x2]), torch.vstack([y1, y2])
            pred = torch.sigmoid(predictor.regressor(x) if is_pair else predictor(x))

            names = [sample['name1'][0], sample['name2'][0]]
            is_correct = bool(torch.argmax(y) == torch.argmax(pred))
            if is_correct:
                n_correct += 1
                if print_correct_predictions:
                    _print_pair(names, y, pred, "-----------------------------------------------")
            elif print_wrong_predictions:
                _print_pair(names, y, pred, "===================================================")

            sq_err += torch.sum((y - pred) ** 2).item()
            n_pairs += 1

            bar.set_postfix({
                "acc": f"{n_correct / n_pairs:.4f}",
                "mse": f"{sq_err / (2 * n_pairs):.4f}"
            })
            bar.update(1)

    if n_pairs == 0:
        return {'accuracy' : float('nan'), 'mse' : float('nan')}
    return {'accuracy' : n_correct / n_pairs, 'mse' : sq_err / (2 * n_pairs)}

## Train

In [14]:
# torch.save below fails if the checkpoint directory does not exist yet.
MODEL_PATH.mkdir(parents = True, exist_ok = True)
is_pair = model_name.startswith('pair')

for epoch in range(checkpoint_epoch + 1, EPOCHS + 1):
    # Running sums of per-sample-weighted losses and the number of samples seen this epoch.
    losses, losses_reg, tot = 0, 0, 0

    # calc_test_metric switches the model to eval mode, so re-enable dropout every epoch.
    model.train()

    with tqdm(total = len(train_loader), ncols = 120) as bar:
        for batch, sample in enumerate(train_loader):

            bar.set_description(f"[epoch#{epoch:>2}/{EPOCHS:>2}][{batch * BATCH_SIZE:>5}/{len(train_dataset):>5}]")

            if is_pair:
                x1, x2, y1, y2 = sample["embedding1"].cuda(), sample["embedding2"].cuda(), sample["score1"].cuda(), sample["score2"].cuda()
                # Hard ranking target: 1 when sample 1 has the higher ground-truth score.
                delta = torch.where(y1 > y2, 1.0, 0.0)

                # delta_pred is already a probability (the model applies the sigmoid),
                # which is what nn.BCELoss expects.
                delta_pred, y1_pred, y2_pred = model(x1, x2)

                loss_reg = loss_reg_fn(torch.vstack([y1_pred, y2_pred]), torch.vstack([y1, y2]))
                loss_bce = loss_fn(delta_pred, delta)
                loss = loss_bce + lambda_reg * loss_reg
                n_samples = y1.shape[0]
            else:
                x, y = sample["embedding"].cuda(), sample["score"].cuda()
                pred = torch.sigmoid(model(x))
                loss_reg = loss_reg_fn(pred, y)
                loss = loss_reg
                n_samples = y.shape[0]

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight each batch mean by its real size. `sample` is a dict, so len(sample)
            # would count its keys. The final batch can be smaller than BATCH_SIZE.
            losses_reg += loss_reg.item() * n_samples
            losses += loss.item() * n_samples
            tot += n_samples

            bar.set_postfix({
                "mse" : f"{losses_reg / tot:.5f}",
                "loss": f"{losses / tot:.5f}"
            })
            bar.update(1)

    test_metric = calc_test_metric(model)
    train_loss, train_mse, test_acc, valid_mse = losses / tot, losses_reg / tot, test_metric['accuracy'], test_metric['mse']
    torch.save({
        'epoch'               : epoch,
        'model_state_dict'    : (model.module if USE_PARALLEL else model).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_mse'           : train_mse,
        'train_loss'          : train_loss,
        'valid_mse'           : valid_mse,
        'test_acc'            : test_acc
    }, MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth")
    

[epoch# 1/50][13248/13300]: 100%|██████████████████████████| 208/208 [00:36<00:00,  5.69it/s, mse=0.05472, loss=0.95031]


100%|██████████████████| 95/95 [00:00<00:00, 136.15it/s, acc=0.9474, mse=0.0498]


[epoch# 2/50][13248/13300]: 100%|██████████████████████████| 208/208 [00:34<00:00,  5.97it/s, mse=0.04019, loss=0.82993]


100%|██████████████████| 95/95 [00:00<00:00, 108.72it/s, acc=0.9579, mse=0.0323]


[epoch# 3/50][11776/13300]:  89%|███████████████████████▏  | 185/208 [00:33<00:03,  6.03it/s, mse=0.03414, loss=0.77551]

In [None]:
# Re-run evaluation and list every pair whose ranking the model gets wrong.
calc_test_metric(model = model, print_wrong_predictions = True, print_correct_predictions = False)

T1030-D1_original_fm  y_gt:0.8572  y_pred:0.8713                                
  T1030-D1_rand13_fm  y_gt:0.8149  y_pred:0.8837                                
T1047s2-D2_rand16_fm  y_gt:0.7530  y_pred:0.9189                                
 T1047s2-D2_rand5_fm  y_gt:0.9398  y_pred:0.7483                                
  T1073-D1_rand12_fm  y_gt:0.9280  y_pred:0.8576                                
  T1073-D1_rand13_fm  y_gt:0.8941  y_pred:0.8616                                
  T1100-D1_rand13_fm  y_gt:0.6080  y_pred:0.7381                                
 T1100-D1_rosetta_fm  y_gt:0.6775  y_pred:0.7155                                
100%|██████████████████| 95/95 [00:00<00:00, 154.99it/s, acc=0.9579, mse=0.0249]


{'accuracy': 0.9578947368421052, 'mse': 0.024935889437696653}

In [None]:
def _load_history(model_name, trained_epoches):
    """Read per-epoch train MSE, validation MSE and test accuracy from saved checkpoints.

    Each checkpoint is loaded onto the CPU (``map_location="cpu"``). Only three scalars
    are read, so there is no need to put every checkpoint's model and optimizer tensors
    on the GPU. Fails if any checkpoint for epochs 1..trained_epoches is missing.
    """
    train_mse, valid_mse, test_acc = [], [], []
    for epoch in range(1, trained_epoches + 1):
        checkpoint = torch.load(MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth", map_location = "cpu")
        train_mse.append(checkpoint["train_mse"])
        valid_mse.append(checkpoint["valid_mse"])
        test_acc.append(checkpoint["test_acc"])
    return train_mse, valid_mse, test_acc


def plot_training_history(model_name = "mean_256", trained_epoches = 20):
    """Plot train/validation MSE (left axis) and test accuracy (right axis) per epoch.

    Args:
        model_name: Model name used in the checkpoint file names.
        trained_epoches: Number of epochs to plot, starting from epoch 1.
    """
    epochs = list(range(1, trained_epoches + 1))
    train_mse, valid_mse, test_acc = _load_history(model_name, trained_epoches)

    # The legend says "MSE" because the checkpoints' 'train_mse' is plotted, not the
    # combined training loss.
    train_trace = go.Scatter(
        x = epochs,
        y = train_mse,
        name = "Train MSE",
        xaxis = 'x',
        yaxis = 'y1',
        mode = 'lines+markers'
    )
    # Shifted half an epoch, presumably because validation runs after each epoch's
    # training pass. NOTE(review): test accuracy comes from the same evaluation but is
    # plotted unshifted. Confirm which is intended.
    valid_trace = go.Scatter(
        x = [epoch + 0.5 for epoch in epochs],
        y = valid_mse,
        name = "Valid MSE",
        xaxis = 'x',
        yaxis = 'y1',
        mode = 'lines+markers'
    )
    acc_trace = go.Scatter(
        x = epochs,
        y = test_acc,
        name = 'Test Accuracy',
        xaxis = 'x',
        yaxis = 'y2',
        mode = 'lines+markers'
    )

    layout = go.Layout(
        yaxis2 = dict(overlaying = 'y', side = 'right', title = "Accuracy", range = [0.75, 1.0]),
        yaxis1 = dict(title = "MSE Loss"),
        xaxis = dict(title = "Epoch"),
        legend = dict(x = 0.75, y = 0.55, font = dict(size = 12, color = "black"))
    )

    go.Figure(data = [train_trace, valid_trace, acc_trace], layout = layout).show()

plot_training_history("pair_bos_256", 15)