# MSA Transformer + CNN

In [80]:
import os
import sys
from pathlib import Path

PROJECT_PATH = Path() 
sys.path.append(str(PROJECT_PATH))
from tqdm import tqdm
from pathlib import Path

import numpy as np
import plotly.graph_objs as go

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

DATASET_PATH = Path() / "dataset" / "CASP14_fm"
MODEL_PATH       = PROJECT_PATH / "model"
EMBDEDDINGS_PATH = PROJECT_PATH / "embeddings"

# hyperparameters
LEARNING_RATE = 1e-4
EPOCHS = 20
BATCH_SIZE = 64
lambda_reg = 2.5
checkpoint_epoch = 5
model_name = "pair_cnn_256"

Prepare dataset

In [81]:
from dataset import EmbeddingScoreDataset, PairDataset

train_dataset = PairDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = True) if model_name[:4] == "pair" else EmbeddingScoreDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = True)
test_dataset  = PairDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train = False)

NUM_GPU = torch.cuda.device_count()

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, num_workers = 4 * NUM_GPU, pin_memory = True, shuffle = True)
test_loader  = DataLoader(test_dataset, batch_size = 1, shuffle = False)


In [82]:
# model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels = 1,
                out_channels = 8,
                kernel_size = 5,
                stride = 2,
                padding = 0,
                bias = False
            ),
            nn.BatchNorm2d(num_features = 8),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size = 2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels = 8,
                out_channels = 16,
                kernel_size = 5,
                stride = 2,
                padding = 0,
                bias = False
            ),
            nn.BatchNorm2d(num_features = 16),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size = 2)
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels = 16,
                out_channels = 32,
                kernel_size = 5,
                stride = 1,
                padding = 0,
                bias = False
            ),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(10080, 256),
            nn.LeakyReLU(),
            nn.Dropout()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(256, 32),
            nn.LeakyReLU(),
            nn.Dropout()
        )
        self.fc3 = nn.Linear(32, 1)
        
    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # print(self.num_flat_features(x))
        x = x.view(-1, self.num_flat_features(x)) 
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
    
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [83]:
class RestNetBasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        output = self.conv1(x)
        output = F.relu(self.bn1(output))
        output = self.conv2(output)
        output = self.bn2(output)
        return F.relu(x + output)


class RestNetDownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetDownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        extra_x = self.extra(x)
        output = self.conv1(x)
        out = F.relu(self.bn1(output))

        out = self.conv2(out)
        out = self.bn2(out)
        return F.relu(extra_x + out)

class RestNet18(nn.Module):
    def __init__(self):
        super(RestNet18, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=2, padding=0)
        self.bn1 = nn.BatchNorm2d(6)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

        self.layer1 = nn.Sequential(RestNetBasicBlock(6, 6, 1),
                                    RestNetBasicBlock(6, 6, 1))

        self.layer2 = nn.Sequential(RestNetDownBlock(6, 16, [2, 1]),
                                    RestNetBasicBlock(16, 16, 1))

        self.layer3 = nn.Sequential(RestNetDownBlock(16, 32, [2, 1]),
                                    RestNetBasicBlock(32, 32, 1))

        self.layer4 = nn.Sequential(RestNetDownBlock(32, 64, [2, 1]),
                                    RestNetBasicBlock(64, 64, 1))

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc = nn.Linear(64, 1)

    def forward(self, x):
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out

In [84]:
class PairCNN(nn.Module):
    def __init__(self, type = 'pair_cnn_256'):
        super(PairCNN, self).__init__()
        self.regressor = CNN() if type == 'pair_cnn_256' else RestNet18()
        
    def forward(self, x1, x2):
        y1, y2 = self.regressor(x1), self.regressor(x2)
        delta, y1, y2 = torch.sigmoid(y1 - y2), torch.sigmoid(y1), torch.sigmoid(y2)
        return delta, y1, y2

In [85]:
if model_name[:4] == "pair" :
    model = PairCNN(type = model_name).cuda()
elif model_name[:6] == "resnet" :
    model = RestNet18().cuda()
else :
    model = CNN().cuda()

NUM_GPU = torch.cuda.device_count()
USE_PARALLEL = NUM_GPU > 1
if USE_PARALLEL :
    model = torch.nn.DataParallel(model)
    
print(model)

PairCNN(
  (regressor): CNN(
    (conv1): Sequential(
      (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), bias=False)
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv2): Sequential(
      (0): Conv2d(8, 16, kernel_size=(5, 5), stride=(2, 2), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv3): Sequential(
      (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (f

In [86]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.BCELoss()
loss_reg_fn = nn.MSELoss()

if checkpoint_epoch > 0 :
    checkpoint = torch.load(MODEL_PATH / f"model_{model_name}_epoch{checkpoint_epoch}.pth")
    (model.module if USE_PARALLEL else model).load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"The {model_name} model loaded has been trained for {epoch} epoche(s), with {checkpoint['train_mse']} training loss, {checkpoint['valid_mse']} validation MSE and {checkpoint['test_acc']} test accuracy. ")
else :
    print(f"Start training {model_name} model from the 1st epoch.")

The pair_cnn_256 model loaded has been trained for 5 epoche(s), with 0.020737510590921514 training loss, 0.020974770894482436 validation MSE and 0.968421052631579 test accuracy. 


In [87]:
from dataset import read_embedding
x1 = read_embedding('T1024-D1_rand10_fm', 584, EMBDEDDINGS_PATH)
x2 = read_embedding('T1024-D1_aug_fm', 584, EMBDEDDINGS_PATH)
x3 = read_embedding('T1024-D1_deduplicated_fm', 584, EMBDEDDINGS_PATH)
x4 = read_embedding('T1024-D1_meta_fm', 584, EMBDEDDINGS_PATH)

x = torch.stack([x1, x2, x3, x4])
with torch.no_grad():
    score = torch.sigmoid(model.regressor(x.cuda()))
score.cpu().numpy()

array([[0.5724634 ],
       [0.6360061 ],
       [0.98783386],
       [0.9598515 ]], dtype=float32)

In [56]:
def calc_test_metric(model, print_wrong_predictions = False, print_correct_predictions = False):
    model.eval()
    mse, correct, tot = 0.0, 0, 0
    with torch.no_grad():
        with tqdm(total = len(test_loader), ncols = 80, file = sys.stdout) as bar:
            for sample in test_loader :
                x1, x2, y1, y2 = sample["embedding1"].cuda(), sample["embedding2"].cuda(), sample["score1"].cuda(), sample["score2"].cuda()
                x, y = torch.vstack([x1, x2]), torch.vstack([y1, y2])
                
                predictor = model.module if USE_PARALLEL else model
                pred = torch.sigmoid(predictor.regressor(x) if model_name[:4] == 'pair' else predictor(x))

                if torch.argmax(y) == torch.argmax(pred) :
                    correct += 1
                    if print_correct_predictions:
                        for name, y_gt, y_pred in zip([sample['name1'][0], sample['name2'][0]], y, pred) :
                            tqdm.write(f"{name:>20}  y_gt:{y_gt.item():.4f}  y_pred:{y_pred.item() :.4f}")
                        tqdm.write("-----------------------------------------------")
                else :
                    if print_wrong_predictions:
                        for name, y_gt, y_pred in zip([sample['name1'][0], sample['name2'][0]], y, pred) :
                            tqdm.write(f"{name:>20}  y_gt:{y_gt.item() :.4f}  y_pred:{y_pred.item() :.4f}")
                        tqdm.write("===================================================")
                
                mse += torch.sum((y - pred) ** 2).item()
                tot += 2

                bar.set_postfix({
                    "acc": f"{correct / (tot // 2):.4f}",
                    "mse": f"{mse / tot:.4f}"
                })
                bar.update(1)
    
    return {'accuracy' : correct / (tot // 2), 'mse' : mse / tot}

Training

In [57]:
for epoch in range(checkpoint_epoch + 1, EPOCHS + 1):
    losses, losses_reg, tot = 0, 0, 0
    
    model.train()

    with tqdm(total = len(train_loader), ncols = 120) as bar:
        for batch, sample in enumerate(train_loader) :
            
            bar.set_description(f"[epoch#{epoch:>2}/{EPOCHS:>2}][{batch * BATCH_SIZE:>5}/{len(train_dataset):>5}]")

            if model_name[:4] == 'pair' :
                x1, x2, y1, y2 = sample["embedding1"].cuda(), sample["embedding2"].cuda(), sample["score1"].cuda(), sample["score2"].cuda()
                # delta = torch.sigmoid((y1 - y2) / SIGMOID_SCALE)
                delta = torch.where(y1 > y2, 1.0, 0.0)

                logit, y1_pred, y2_pred = model(x1, x2)

                loss_reg = loss_reg_fn(torch.vstack([y1_pred, y2_pred]), torch.vstack([y1, y2])) 
                loss_bce = loss_fn(logit, delta)
                loss = loss_bce + lambda_reg * loss_reg
            else :
                x, y = sample["embedding"].cuda(), sample["score"].cuda()
                pred = torch.sigmoid(model(x))
                loss_reg = loss_reg_fn(pred, y)
                loss = loss_reg

            # Backpropagation
            optimizer.zero_grad()     
            loss.backward()
            optimizer.step()
            
            losses_reg += loss_reg.item() * len(sample) 
            losses += loss.item() * len(sample)
            tot += len(sample)

            bar.set_postfix({
                # "batch loss" : f"{loss.item():.5f}",
                "mse" : f"{losses_reg / tot:.5f}", 
                "loss": f"{losses / tot:.5f}"
            })
            bar.update(1)

    test_metric =  calc_test_metric(model)
    train_loss, train_mse, test_acc, valid_mse = losses / tot, losses_reg / tot, test_metric['accuracy'], test_metric['mse']
    torch.save({
        'epoch'               : epoch,
        'model_state_dict'    : (model.module if USE_PARALLEL else model).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_mse'           : train_mse,
        'train_loss'          : train_loss,
        'valid_mse'           : valid_mse,
        'test_acc'            : test_acc
    }, MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth")
    

[epoch# 1/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:49<00:00,  4.23it/s, mse=0.05217, loss=0.80508]


100%|██████████████████| 95/95 [00:00<00:00, 128.05it/s, acc=0.9579, mse=0.0372]


[epoch# 2/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:40<00:00,  5.10it/s, mse=0.03522, loss=0.67248]


100%|██████████████████| 95/95 [00:00<00:00, 189.66it/s, acc=0.9474, mse=0.0276]


[epoch# 3/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:35<00:00,  5.88it/s, mse=0.02726, loss=0.60144]


100%|██████████████████| 95/95 [00:00<00:00, 185.57it/s, acc=0.9579, mse=0.0230]


[epoch# 4/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:37<00:00,  5.58it/s, mse=0.02286, loss=0.55142]


100%|██████████████████| 95/95 [00:00<00:00, 182.48it/s, acc=0.9579, mse=0.0245]


[epoch# 5/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:36<00:00,  5.64it/s, mse=0.02074, loss=0.50855]


100%|██████████████████| 95/95 [00:00<00:00, 188.50it/s, acc=0.9684, mse=0.0210]


[epoch# 6/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:38<00:00,  5.37it/s, mse=0.01957, loss=0.47065]


100%|██████████████████| 95/95 [00:00<00:00, 183.45it/s, acc=0.9895, mse=0.0245]


[epoch# 7/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:56<00:00,  3.69it/s, mse=0.01972, loss=0.43947]


100%|██████████████████| 95/95 [00:00<00:00, 167.88it/s, acc=0.9789, mse=0.0263]


[epoch# 8/20][13248/13300]: 100%|██████████████████████████| 208/208 [01:57<00:00,  1.78it/s, mse=0.02075, loss=0.41676]


100%|██████████████████| 95/95 [00:00<00:00, 159.03it/s, acc=0.9789, mse=0.0257]


[epoch# 9/20][13248/13300]: 100%|██████████████████████████| 208/208 [02:11<00:00,  1.58it/s, mse=0.02157, loss=0.40260]


100%|██████████████████| 95/95 [00:00<00:00, 162.57it/s, acc=0.9789, mse=0.0227]


[epoch#10/20][13248/13300]: 100%|██████████████████████████| 208/208 [01:24<00:00,  2.46it/s, mse=0.02243, loss=0.38603]


100%|██████████████████| 95/95 [00:00<00:00, 160.94it/s, acc=0.9789, mse=0.0238]


[epoch#11/20][13248/13300]: 100%|██████████████████████████| 208/208 [01:33<00:00,  2.23it/s, mse=0.02327, loss=0.36949]


100%|██████████████████| 95/95 [00:00<00:00, 159.29it/s, acc=0.9684, mse=0.0280]


[epoch#12/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:41<00:00,  4.99it/s, mse=0.02348, loss=0.36323]


100%|██████████████████| 95/95 [00:00<00:00, 166.29it/s, acc=0.9789, mse=0.0280]


[epoch#13/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:42<00:00,  4.92it/s, mse=0.02418, loss=0.35004]


100%|██████████████████| 95/95 [00:00<00:00, 168.77it/s, acc=0.9684, mse=0.0263]


[epoch#14/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:43<00:00,  4.75it/s, mse=0.02423, loss=0.34971]


100%|██████████████████| 95/95 [00:00<00:00, 163.61it/s, acc=0.9789, mse=0.0291]


[epoch#15/20][13248/13300]: 100%|██████████████████████████| 208/208 [01:34<00:00,  2.19it/s, mse=0.02515, loss=0.34326]


100%|██████████████████| 95/95 [00:00<00:00, 158.53it/s, acc=0.9789, mse=0.0264]


[epoch#16/20][13248/13300]: 100%|██████████████████████████| 208/208 [01:27<00:00,  2.39it/s, mse=0.02517, loss=0.33507]


100%|██████████████████| 95/95 [00:00<00:00, 161.12it/s, acc=0.9789, mse=0.0333]


[epoch#17/20][13248/13300]: 100%|██████████████████████████| 208/208 [01:38<00:00,  2.11it/s, mse=0.02527, loss=0.33254]


100%|██████████████████| 95/95 [00:00<00:00, 164.95it/s, acc=0.9684, mse=0.0380]


[epoch#18/20][13248/13300]: 100%|██████████████████████████| 208/208 [00:46<00:00,  4.49it/s, mse=0.02585, loss=0.32088]


100%|██████████████████| 95/95 [00:00<00:00, 170.53it/s, acc=0.9684, mse=0.0323]


[epoch#19/20][13248/13300]: 100%|██████████████████████████| 208/208 [02:11<00:00,  1.58it/s, mse=0.02632, loss=0.31867]


100%|██████████████████| 95/95 [00:00<00:00, 157.22it/s, acc=0.9789, mse=0.0285]


[epoch#20/20][13248/13300]: 100%|██████████████████████████| 208/208 [02:24<00:00,  1.44it/s, mse=0.02599, loss=0.31709]


100%|██████████████████| 95/95 [00:00<00:00, 143.66it/s, acc=0.9684, mse=0.0289]


In [58]:
LEARNING_RATE, lambda_reg

(0.0001, 2.5)

Print wrong predictions

In [59]:
calc_test_metric(model, print_wrong_predictions = True)

T1030-D1_original_fm  y_gt:0.8572  y_pred:0.9384                                
  T1030-D1_rand13_fm  y_gt:0.8149  y_pred:0.9453                                
  T1070-D3_rand13_fm  y_gt:0.7369  y_pred:0.8095                                
  T1070-D3_rand16_fm  y_gt:0.8322  y_pred:0.6946                                
  T1101-D1_rand13_fm  y_gt:0.9789  y_pred:0.9987                                
  T1101-D1_rand18_fm  y_gt:0.9940  y_pred:0.9950                                
100%|██████████████████| 95/95 [00:00<00:00, 173.24it/s, acc=0.9684, mse=0.0289]


{'accuracy': 0.968421052631579, 'mse': 0.028878701904231044}

In [71]:
def plot_training_history(model_name = "mean_256", trained_epoches = 20):
    epoches = list(range(1, trained_epoches + 1))
    train_losses = []
    test_losses = []
    test_accs = []
    for epoch in range(1, trained_epoches+1):
        checkpoint = torch.load(MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth")
        train_losses.append(checkpoint["train_mse"])
        test_losses.append(checkpoint["valid_mse"])
        test_accs.append(checkpoint["test_acc"])
    
    trace1 = go.Scatter(
        x = epoches,
        y = train_losses,
        name= "Train Loss",
        xaxis='x',
        yaxis='y1',
        mode='lines+markers'
    )
    trace2 = go.Scatter(
        x = [epoch + 0.5 for epoch in epoches],
        y = test_losses,
        name= "Valid MSE",
        xaxis='x', 
        yaxis='y1',
        mode='lines+markers'
    )
    trace3 = go.Scatter(
        x = epoches,
        y = test_accs,
        name='Test Accuracy',
        xaxis='x', 
        yaxis='y2',
        mode='lines+markers'
    )
    
    data = [trace1, trace2, trace3]
    layout = go.Layout(
        yaxis2=dict(overlaying = 'y', side = 'right', title = "Accuracy", range = [0.85, 1.0]),
        yaxis1=dict(title = "MSE Loss"),
        xaxis = dict(title = "Epoch"),
        legend=dict(x=0.75, y=0.55, font=dict(size=12, color="black"))
    )
    
    fig = go.Figure(data=data, layout=layout)
    fig.show()

plot_training_history("pair_cnn_256", 20)