设置环境

In [12]:
import sys
# NOTE(review): hard-coded absolute path to one user's checkout of the project.
# Presumably this is what makes the `from dataset import ...` cells below
# resolvable; adjust it when running the notebook on another machine.
sys.path.append("/public/home/cs177h/tengyue/Project/ShanghaiTech-CS177H-MSA-Scoring")

from tqdm import tqdm
from pathlib import Path

import numpy as np
import plotly.graph_objs as go
import matplotlib.pyplot as plt
# Global matplotlib defaults: 14pt for body text, axis labels/titles and the
# legend; 10pt for tick labels on both axes.
plt.rcParams.update({
    'font.size': 14,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'legend.fontsize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
})

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# One seed for every random number generator the notebook touches:
# NumPy, PyTorch on the CPU, and PyTorch on the current / all CUDA devices.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and switch off its auto-tuner so that
# repeated runs select the same algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import esm

# Every artefact lives below this directory. `Path()` is the current working
# directory, so all paths below are relative to wherever the kernel was started.
PROJECT_DIR = Path() / "Project" / "ShanghaiTech-CS177H-MSA-Scoring"

DATASET_PATH = PROJECT_DIR / "dataset" / "CASP14_fm"
MODEL_PATH = PROJECT_DIR / "model"                      # per-epoch checkpoints are written/read here
EMBDEDDINGS_PATH = PROJECT_DIR / "embeddings"           # (sic) spelling kept: later cells use this name
TRANSFORMER_PATH = PROJECT_DIR / "esm_msa1b_t12_100M_UR50S.pt"

# hyperparameters
MAX_DEPTH = 256        # NOTE(review): not referenced by any cell visible in this notebook
EPOCHES = 30           # last epoch index of the training loop
LEARNING_RATE = 1e-4   # learning rate handed to the Adam optimizer

读入数据

In [3]:
from dataset import EScoreDataset
# Build the train split first, then the test split; both are constructed from
# the same embeddings directory and dataset root and differ only in `is_train`.
train_dataset, test_dataset = (
    EScoreDataset(EMBDEDDINGS_PATH, root = DATASET_PATH, is_train = is_train)
    for is_train in (True, False)
)

In [4]:
# Sanity check: show the name of the first training MSA and the score stored for it.
first_msa_name = train_dataset.msa_name_list[0]
print(first_msa_name)
print(train_dataset.msa_score[first_msa_name])

T1024-D1_aug_fm
70.208


预测模型

In [18]:
class MSAPredictor(nn.Module):
    """Score regressor working on pre-computed embeddings (mean-pooling variant).

    For every sample in the batch the first entry along the sample's leading
    axis is dropped, the remainder is averaged over dim 1, and the pooled
    vector is passed through a 768 -> 128 -> 32 -> 1 MLP whose output is
    squashed by a sigmoid, so predictions lie in (0, 1).

    The MSA Transformer encoder is NOT instantiated here: the code that loaded
    it (and froze its parameters) was disabled in this notebook.
    """

    def __init__(self, msa_transformer_path = TRANSFORMER_PATH):
        """Create the regression head.

        Args:
            msa_transformer_path: accepted for backward compatibility only; it
                is not used because the encoder is no longer loaded here.
        """
        super(MSAPredictor, self).__init__()

        # Pooled inputs of the most recent forward pass (one tensor per sample).
        self.em = []

        # Regressor module (to be tested)
        self.fc1 = nn.Linear(768, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, x):
        """Predict one score per sample.

        Args:
            x: batch tensor, indexed as ``x[i, 1:, :]`` per sample
                (assumed to be 3-D — TODO confirm against EScoreDataset).

        Returns:
            Tensor of sigmoid outputs with the trailing size-1 axis squeezed away.
        """
        # Pooling has no trainable parameters, so it is kept out of autograd.
        # (The per-sample debug print of the tensor size was removed: it wrote
        # one line per sample on every forward pass.)
        with torch.no_grad():
            # NOTE(review): fc1 expects 768 input features. If the last axis of
            # `x` is the 768-wide feature axis, averaging over dim 1 collapses
            # the features instead of the rows and dim 0 was probably intended
            # — confirm before using this variant.
            self.em = [torch.mean(x[i, 1:, :], dim = 1) for i in range(x.size(0))]

        x = torch.vstack(self.em)
        # Intended shape: BATCH_SIZE x 768

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        x = x.squeeze(-1)
        return x

解析模型

In [19]:
class MSAPredictorBOS(MSAPredictor):
    """MSAPredictor variant that feeds a single embedding row to the MLP.

    Instead of averaging, each sample contributes only ``x[i, 1, :]``
    (presumably the BOS-token embedding, going by the class name — TODO
    confirm that index 1 rather than 0 is the intended position). The
    fully-connected layers are inherited unchanged from MSAPredictor.
    """

    def __init__(self, msa_transformer_path = TRANSFORMER_PATH):
        """Build the inherited regression head; the path is forwarded to the parent."""
        super().__init__(msa_transformer_path)
        self.em = []

    def forward(self, x):
        """Return one sigmoid score per sample in the batch ``x``."""
        # Row selection is done outside autograd, exactly like the parent's pooling.
        with torch.no_grad():
            picked_rows = [sample[1, :] for sample in x]
        # Keep the selected rows of the latest batch on the instance.
        self.em = picked_rows

        # BATCH_SIZE x 768
        features = torch.vstack(picked_rows)

        hidden = F.relu(self.fc1(features))
        hidden = F.relu(self.fc2(hidden))
        scores = torch.sigmoid(self.fc3(hidden))
        return scores.squeeze(-1)

生成模型

In [20]:
model_name = "bos_256"

# Instantiate the predictor variant selected by `model_name` on the GPU.
if model_name == "bos_256" :
    model = MSAPredictorBOS().cuda()
elif model_name == "mean_256" :
    model = MSAPredictor().cuda()
else :
    # Fail loudly: falling through would leave `model` undefined on a fresh
    # kernel, or silently reuse a model left over from an earlier run.
    raise ValueError(f"Unknown model_name {model_name!r}; expected 'bos_256' or 'mean_256'.")

# Wrap the model in DataParallel only when more than one GPU is visible.
# USE_PARALLEL is consulted later to reach the underlying module when
# saving / loading state dicts.
NUM_GPU = torch.cuda.device_count()
USE_PARALLEL = NUM_GPU > 1
if USE_PARALLEL :
    model = torch.nn.DataParallel(model)

迭代器

In [21]:
BATCH_SIZE = 32

# Training batches: reshuffled every epoch, loaded by four worker processes per
# visible GPU into pinned host memory.
train_loader = DataLoader(
    train_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = 4 * NUM_GPU,
    pin_memory = True,
)

# Evaluation batches: fixed order, two samples each. The evaluation code in
# this notebook treats every batch as one pair whose ranking is checked.
test_loader = DataLoader(
    test_dataset,
    batch_size = 2,
    shuffle = False,
)

In [22]:
# Mean-squared-error objective, optimised with Adam.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Set `checkpoint_epoch` to N > 0 to resume from the checkpoint written after
# epoch N; with 0 training starts from scratch.
checkpoint_epoch = 0
checkpoint = 0
if checkpoint_epoch <= 0 :
    print(f"Start training {model_name} model from the 1st epoch.")
else :
    checkpoint = torch.load(MODEL_PATH / f"model_{model_name}_epoch{checkpoint_epoch}.pth")
    # State dicts are stored for the bare module, so unwrap DataParallel first.
    (model.module if USE_PARALLEL else model).load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"The {model_name} model loaded has been trained for {epoch} epoche(s), with {checkpoint['train_mse']} training loss, {checkpoint['valid_mse']} validation MSE and {checkpoint['test_acc']} test accuracy. ")

Start training bos_256 model from the 1st epoch.


In [23]:
def calc_test_metric(model):
    """Evaluate `model` on the global `test_loader`.

    Every batch is treated as one ranking group: it counts as correct when the
    sample with the highest predicted score is also the one with the highest
    ground-truth score.

    Args:
        model: the predictor to evaluate; it is switched to eval mode and left
            in that mode.

    Returns:
        dict with
            'accuracy': fraction of batches that were ranked correctly,
            'mse': squared error summed over all samples, divided by the
                   number of samples actually seen.
        Both values are NaN when the loader yields no batch.
    """
    model.eval()
    squared_error, correct = 0.0, 0
    n_samples, n_groups = 0, 0
    with torch.no_grad():
        with tqdm(total = len(test_loader), ncols=80) as bar:
            for sample in test_loader :

                x = sample["embedding"].cuda(non_blocking = True)
                y = sample["score"][0].cuda()
                pred = model(x)

                squared_error += torch.sum((y - pred) ** 2).item()
                correct += int((torch.argmax(y) == torch.argmax(pred)).item())
                # Count what was really in the batch rather than assuming two
                # samples, so a shorter final batch does not skew the MSE.
                n_samples += y.size(0)
                n_groups += 1

                bar.set_postfix({
                    "acc" : f"{correct / n_groups:.4f}",
                    "mse": f"{squared_error / n_samples:.4f}"
                })
                bar.update(1)

    if n_groups == 0 or n_samples == 0:
        # Nothing to evaluate: report NaN instead of dividing by zero.
        return {'accuracy' : float('nan'), 'mse' : float('nan')}

    return {'accuracy' : correct / n_groups, 'mse' : squared_error / n_samples}

In [24]:
# Make sure the checkpoint directory exists before spending an epoch on
# training; otherwise the first torch.save below would fail.
MODEL_PATH.mkdir(parents = True, exist_ok = True)

# Train from the epoch after the loaded checkpoint (or from epoch 1) up to and
# including EPOCHES; evaluate and write a checkpoint after every epoch.
for epoch in range(checkpoint_epoch + 1, EPOCHES + 1):
    losses, tot = 0, 0

    # calc_test_metric() leaves the model in eval mode, so re-enable training mode.
    model.train()

    with tqdm(total= len(train_loader), ncols=130) as bar:
        for batch, sample in enumerate(train_loader) :

            bar.set_description(f"[epoch#{epoch}/{EPOCHES}][{batch * BATCH_SIZE}/{len(train_dataset)}]")

            x = sample["embedding"].cuda(non_blocking = True)
            y = sample["score"][0].cuda()
            pred = model(x)
            loss = loss_fn(pred.float(), y.float())

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by the number of samples in the batch.
            # (`len(sample)` would be the number of keys of the batch dict, not
            # the batch size, which mis-weights a smaller final batch.)
            samples_in_batch = y.size(0)
            losses += loss.item() * samples_in_batch
            tot += samples_in_batch

            bar.set_postfix({
                "batch loss" : f"{loss.item():.5f}",
                "loss": f"{losses / tot:.5f}"
            })
            bar.update(1)

    test_metric =  calc_test_metric(model)
    train_mse, test_acc, valid_mse =  losses / tot, test_metric['accuracy'], test_metric['mse']
    # One checkpoint per epoch; the state dict is taken from the bare module so
    # it can be loaded with or without DataParallel.
    torch.save({
        'epoch': epoch,
        'model_state_dict': (model.module if USE_PARALLEL else model).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_mse': train_mse,
        'valid_mse': valid_mse,
        'test_acc': test_acc
    }, MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth")

[epoch#1/30][2656/2660]: 100%|██████████████████████████████████| 84/84 [00:04<00:00, 19.60it/s, batch loss=0.02384, loss=0.05288]
100%|██████████████████| 95/95 [00:00<00:00, 185.55it/s, acc=0.8842, mse=0.0578]
[epoch#2/30][2656/2660]: 100%|██████████████████████████████████| 84/84 [00:04<00:00, 20.23it/s, batch loss=0.04275, loss=0.03746]
100%|██████████████████| 95/95 [00:00<00:00, 184.85it/s, acc=0.9263, mse=0.0395]
[epoch#3/30][2656/2660]: 100%|██████████████████████████████████| 84/84 [00:04<00:00, 20.03it/s, batch loss=0.02084, loss=0.02804]
100%|██████████████████| 95/95 [00:00<00:00, 179.44it/s, acc=0.9158, mse=0.0320]
[epoch#4/30][2656/2660]: 100%|██████████████████████████████████| 84/84 [00:04<00:00, 19.67it/s, batch loss=0.03095, loss=0.02346]
100%|██████████████████| 95/95 [00:00<00:00, 148.54it/s, acc=0.9263, mse=0.0287]
[epoch#5/30][2656/2660]: 100%|██████████████████████████████████| 84/84 [00:04<00:00, 20.05it/s, batch loss=0.02693, loss=0.02054]
100%|████████████████

In [26]:
from dataset import MSAScoreDataset
# Build the test split of MSAScoreDataset from the same dataset root.
# NOTE(review): `test_dataset_1` is not referenced by any later cell visible in
# this notebook — confirm whether this cell is still needed.
test_dataset_1 = MSAScoreDataset(root = DATASET_PATH, is_train = False)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 190/190 [00:07<00:00, 25.81it/s]


In [29]:
def print_predictions(model, only_print_wrong = True):
    """Evaluate `model` on the global `test_loader`, reporting progress on stdout.

    Every batch is treated as one ranking group: it counts as correct when the
    sample with the highest predicted score is also the one with the highest
    ground-truth score.

    Args:
        model: the predictor to evaluate; it is switched to eval mode.
        only_print_wrong: accepted for backward compatibility. Per-sample
            printing was disabled in this notebook, so the flag currently has
            no effect.

    Returns:
        dict with
            'accuracy': fraction of batches that were ranked correctly,
            'mse': squared error summed over all samples, divided by the
                   number of samples actually seen.
        Both values are NaN when the loader yields no batch.
    """
    model.eval()
    squared_error, correct = 0.0, 0
    n_samples, n_groups = 0, 0
    with torch.no_grad():
        with tqdm(total = len(test_loader), ncols=80, file=sys.stdout) as bar:
            for sample in test_loader :

                x = sample["embedding"].cuda(non_blocking = True)
                y = sample["score"][0].cuda()
                pred = model(x)

                squared_error += torch.sum((y - pred) ** 2).item()

                if torch.argmax(y) == torch.argmax(pred) :
                    correct += 1

                # Count what was really in the batch rather than assuming two
                # samples, so a shorter final batch does not skew the MSE.
                n_samples += y.size(0)
                n_groups += 1

                bar.set_postfix({
                    "acc" : f"{correct / n_groups:.4f}",
                    "mse": f"{squared_error / n_samples:.4f}"
                })
                bar.update(1)

    if n_groups == 0 or n_samples == 0:
        # Nothing to evaluate: report NaN instead of dividing by zero.
        return {'accuracy' : float('nan'), 'mse' : float('nan')}

    return {'accuracy' : correct / n_groups, 'mse' : squared_error / n_samples}

# Evaluate the current model on the test loader; the returned metrics dict is
# displayed as the cell output.
print_predictions(model)

100%|██████████████████| 95/95 [00:00<00:00, 154.98it/s, acc=0.9684, mse=0.0184]


{'accuracy': 0.968421052631579, 'mse': 0.018424894298719956}

In [31]:
def plot_training_history(model_name = "mean_256", trained_epoches = 10):
    """Plot the metrics stored in the per-epoch checkpoints of one model.

    Reads ``MODEL_PATH / f"model_{model_name}_epoch{k}.pth"`` for
    k = 1 .. trained_epoches and draws three curves with plotly: training
    loss and validation MSE on the left axis, test accuracy on the right axis.

    Args:
        model_name: name used in the checkpoint file names.
        trained_epoches: number of consecutive epochs (starting at 1) to plot;
            a checkpoint file must exist for each of them.
    """
    epoches = list(range(1, trained_epoches + 1))
    train_losses = []
    test_losses = []
    test_accs = []
    for epoch in tqdm(epoches):
        # Only three scalars are needed, so keep the stored model/optimizer
        # tensors on the CPU instead of restoring them to the GPU they were
        # saved from; this also lets the plot be drawn on a machine without CUDA.
        checkpoint = torch.load(
            MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth",
            map_location = "cpu",
        )
        train_losses.append(checkpoint["train_mse"])
        test_losses.append(checkpoint["valid_mse"])
        test_accs.append(checkpoint["test_acc"])

    trace1 = go.Scatter(
        x = epoches,
        y = train_losses,
        name= "Train Loss",
        xaxis='x',
        yaxis='y1',
        mode='lines+markers'
    )
    # The validation curve is drawn half an epoch to the right of the others.
    trace2 = go.Scatter(
        x = [epoch + 0.5 for epoch in epoches],
        y = test_losses,
        name= "Valid MSE",
        xaxis='x',
        yaxis='y1',
        mode='lines+markers'
    )
    trace3 = go.Scatter(
        x = epoches,
        y = test_accs,
        name='Test Accuracy',
        xaxis='x',
        yaxis='y2',
        mode='lines+markers'
    )

    data = [trace1, trace2, trace3]
    layout = go.Layout(
        # Secondary axis for accuracy, overlaid on the loss axis; its range is
        # fixed to [0.85, 1.0], so lower accuracies fall outside the view.
        yaxis2=dict(overlaying = 'y', side = 'right', title = "Accuracy", range = [0.85, 1.0]),
        yaxis1=dict(title = "MSE Loss"),
        xaxis = dict(title = "Epoch"),
        legend=dict(x=0.75, y=0.45, font=dict(size=12, color="black"))
    )

    fig = go.Figure(data=data, layout=layout)
    fig.show()

# Plot the metrics of the 30 per-epoch checkpoints written for the "bos_256" run.
plot_training_history("bos_256", 30)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 177.35it/s]
