## CNN

In [None]:
# IPython shell escape: run `nvidia-smi` to show the GPUs visible to this machine
# before the CUDA-only cells below are executed.
!nvidia-smi

In [1]:
import os
import sys
from pathlib import Path

# Project root. ``Path()`` is the relative path ".", i.e. the notebook's
# current working directory.
PROJECT_PATH = Path()

# Make project-local modules importable from the notebook. The membership
# check keeps the cell idempotent: re-running it does not stack duplicate
# entries onto ``sys.path``.
if str(PROJECT_PATH) not in sys.path:
    sys.path.append(str(PROJECT_PATH))
from tqdm import tqdm
from pathlib import Path

import torch
import torchvision
import torch.utils.data as Data
import torch.nn as nn
import torch.nn.functional as F

# Data and artifact locations, all expressed relative to PROJECT_PATH.
DATASET_PATH     = PROJECT_PATH / "dataset" / "CASP14_fm"
MODEL_PATH       = PROJECT_PATH / "model"
EMBEDDINGS_PATH  = PROJECT_PATH / "embeddings"
# Misspelled name kept as an alias because later cells refer to it.
EMBDEDDINGS_PATH = EMBEDDINGS_PATH

# hyperparameters
LEARNING_RATE = 1e-4    # learning rate handed to the Adam optimizer
EPOCHS = 20             # last epoch number of the training loop (inclusive)
BATCH_SIZE = 32         # training mini-batch size
checkpoint_epoch = 0    # > 0: resume from that epoch's checkpoint; 0: start fresh
model_name = "cnn_256"  # embedded in checkpoint file names

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prepare dataset
from dataset import EmbeddingScoreDataset, PairDataset

# Training reads single samples ("embedding"/"score" keys are used by the
# training loop); evaluation reads pairs ("embedding1"/"embedding2", ...).
train_dataset = EmbeddingScoreDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train=True)
test_dataset = PairDataset(EMBDEDDINGS_PATH, DATASET_PATH, is_train=False)

NUM_GPU = torch.cuda.device_count()

# Shuffled mini-batches for training, loaded in the main process.
train_loader = Data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
# One pair per step, in dataset order, for evaluation.
test_loader = Data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
)


In [3]:
# model
class CNN(nn.Module):
    """LeNet-style regressor: two conv/ReLU/max-pool stages and three linear layers.

    ``forward`` takes a batch of 2-D inputs without a channel axis, adds a
    single channel, and returns one sigmoid-squashed value per item.

    Args:
        flat_features: Number of features entering ``fc1``, i.e. ``16 * H' * W'``
            where ``H' x W'`` is the spatial size left after both conv/pool
            stages (each stage maps a side ``s`` to ``(s - 4) // 2``).
            Defaults to 432432, the size this notebook was written for, so
            ``CNN()`` builds exactly the same network as before.

    Sub-module names (``conv1``, ``conv2``, ``fc1``, ``fc2``, ``fc3``) are part
    of the checkpoint format (``state_dict`` keys) and must not be renamed.
    """

    def __init__(self, flat_features=432432):
        super().__init__()
        # Stage 1: 1 -> 6 channels, 5x5 valid convolution, then 2x2 max-pool.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=6,
                kernel_size=5,
                stride=1,
                padding=0
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # Stage 2: 6 -> 16 channels, 5x5 valid convolution, then 2x2 max-pool.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=6,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=0
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # Regression head; the input width depends on the input's spatial size.
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, x):
        """Map a ``(batch, H, W)`` tensor to ``(batch, 1)`` scores in [0, 1]."""
        x = x.unsqueeze(1)  # add the single input channel expected by conv1
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(-1, self.num_flat_features(x))  # flatten all non-batch dims
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

    def num_flat_features(self, x):
        """Return the product of every dimension of ``x`` except the first."""
        num_features = 1
        for s in x.size()[1:]:
            num_features *= s
        return num_features

In [4]:
# Decide up front whether the model will be replicated across several GPUs.
NUM_GPU = torch.cuda.device_count()
USE_PARALLEL = NUM_GPU > 1

# Build the network on the GPU; wrap it only when more than one GPU exists.
model = CNN().cuda()
if USE_PARALLEL:
    model = torch.nn.DataParallel(model)

print(model)

CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=432432, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=1, bias=True)
)


In [5]:
# Training objective (mean squared error) and Adam over all model parameters.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

if checkpoint_epoch <= 0 :
    print(f"Start training {model_name} model from the 1st epoch.")
else :
    # Resume: restore model and optimizer state saved at `checkpoint_epoch`.
    checkpoint = torch.load(MODEL_PATH / f"model_{model_name}_epoch{checkpoint_epoch}.pth")
    # Checkpoints hold the bare module's weights, so unwrap DataParallel first.
    if USE_PARALLEL :
        model.module.load_state_dict(checkpoint['model_state_dict'])
    else :
        model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"The {model_name} model loaded has been trained for {epoch} epoche(s), with {checkpoint['train_mse']} training loss, {checkpoint['valid_mse']} validation MSE and {checkpoint['test_acc']} test accuracy. ")

Start training cnn_256 model from the 1st epoch.


In [6]:
def calc_test_metric(model, print_wrong_predictions = False, print_correct_predictions = False):
    """Evaluate ``model`` on the pairs served by the module-level ``test_loader``.

    Every loader item carries two embeddings and their ground-truth scores.
    The two are stacked and scored together; the pair counts as correct when
    the arg-max of the predictions matches the arg-max of the ground truth.
    The squared error is accumulated over all individual scores.

    Args:
        model: The network to evaluate, optionally wrapped in
            ``nn.DataParallel``. It is switched to eval mode and left there.
        print_wrong_predictions: If true, write both members of every
            mis-ranked pair (name, ground truth, prediction) via ``tqdm.write``.
        print_correct_predictions: Same, for correctly ranked pairs.

    Returns:
        dict with ``'accuracy'`` (fraction of correctly ranked pairs) and
        ``'mse'`` (mean squared error over all scores). Both are ``nan`` when
        the loader yields nothing, instead of raising ``ZeroDivisionError``.
    """

    def _write_pair(sample, y, pred, separator):
        # One line per pair member, followed by a separator line.
        names = [sample['name1'][0], sample['name2'][0]]
        for name, y_gt, y_pred in zip(names, y, pred) :
            tqdm.write(f"{name:>20}  y_gt:{y_gt.item():.4f}  y_pred:{y_pred.item():.4f}")
        tqdm.write(separator)

    model.eval()

    # Loop-invariant: unwrap DataParallel based on the argument itself rather
    # than a global flag, and decide once which forward path is used.
    predictor = model.module if isinstance(model, nn.DataParallel) else model
    is_pair_model = model_name[:4] == 'pair'

    mse, correct, pairs, tot = 0.0, 0, 0, 0
    with torch.no_grad():
        with tqdm(total = len(test_loader), ncols = 80, file = sys.stdout) as bar:
            for sample in test_loader :
                x1, x2 = sample["embedding1"].cuda(), sample["embedding2"].cuda()
                y1, y2 = sample["score1"].cuda(), sample["score2"].cuda()
                x, y = torch.vstack([x1, x2]), torch.vstack([y1, y2])

                # "pair*" models expose a `regressor` head that returns raw
                # values; every other model already returns its final score.
                pred = torch.sigmoid(predictor.regressor(x)) if is_pair_model else predictor(x)

                if torch.argmax(y) == torch.argmax(pred) :
                    correct += 1
                    if print_correct_predictions:
                        _write_pair(sample, y, pred, "-----------------------------------------------")
                elif print_wrong_predictions:
                    _write_pair(sample, y, pred, "===================================================")

                mse += torch.sum((y - pred) ** 2).item()
                pairs += 1
                tot += y.numel()  # two scores per pair with the batch_size=1 loader

                bar.set_postfix({
                    "acc": f"{correct / pairs:.4f}",
                    "mse": f"{mse / tot:.4f}"
                })
                bar.update(1)

    if pairs == 0:
        # Nothing was evaluated, so the metrics are undefined.
        return {'accuracy' : float('nan'), 'mse' : float('nan')}

    return {'accuracy' : correct / pairs, 'mse' : mse / tot}

In [7]:
# train
# torch.save does not create missing directories, so make sure the
# checkpoint directory exists before the first epoch finishes.
MODEL_PATH.mkdir(parents=True, exist_ok=True)

for epoch in range(checkpoint_epoch + 1, EPOCHS + 1):
    # Sample-weighted running sums of the training objective and of the
    # plain MSE, plus the number of samples seen so far in this epoch.
    losses, losses_reg, tot = 0.0, 0.0, 0

    model.train()

    with tqdm(total = len(train_loader), ncols = 130) as bar:
        for batch, sample in enumerate(train_loader) :

            bar.set_description(f"[epoch#{epoch:>2}/{EPOCHS:>2}][{batch * BATCH_SIZE:>5}/{len(train_dataset):>5}]")

            x, y = sample["embedding"].cuda(), sample["score"].cuda()
            pred = model(x)
            # Align the target with the prediction so the loss cannot silently
            # broadcast e.g. (B,) against (B, 1) into a (B, B) error matrix.
            if y.shape != pred.shape and y.numel() == pred.numel():
                y = y.reshape(pred.shape)
            loss = criterion(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by the real batch size. `len(sample)` would be the number
            # of keys in the batch dict, which over-weights the final, smaller
            # batch in the epoch average.
            n_samples = x.size(0)
            losses += loss.item() * n_samples
            with torch.no_grad():
                losses_reg += F.mse_loss(pred.detach(), y).item() * n_samples
            tot += n_samples

            bar.set_postfix({
                # "batch loss" : f"{loss.item():.5f}",
                "loss": f"{losses / tot:.5f}"
            })
            bar.update(1)

    test_metric = calc_test_metric(model)
    seen = max(tot, 1)  # avoid dividing by zero if the loader was empty
    train_loss, train_mse = losses / seen, losses_reg / seen
    test_acc, valid_mse = test_metric['accuracy'], test_metric['mse']
    # Always store the bare module's weights so the checkpoint loads with or
    # without DataParallel.
    torch.save({
        'epoch'               : epoch,
        'model_state_dict'    : (model.module if USE_PARALLEL else model).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_mse'           : train_mse,
        'train_loss'          : train_loss,
        'valid_mse'           : valid_mse,
        'test_acc'            : test_acc
    }, MODEL_PATH / f"model_{model_name}_epoch{epoch}.pth")

[epoch# 1/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:24<00:00,  3.47it/s, loss=0.04622]


100%|███████████████████| 95/95 [00:01<00:00, 79.81it/s, acc=0.9684, mse=0.0333]


[epoch# 2/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:20<00:00,  4.03it/s, loss=0.01755]


100%|███████████████████| 95/95 [00:01<00:00, 81.67it/s, acc=0.9474, mse=0.0191]


[epoch# 3/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.94it/s, loss=0.01250]


100%|███████████████████| 95/95 [00:01<00:00, 82.38it/s, acc=0.9579, mse=0.0160]


[epoch# 4/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.92it/s, loss=0.00886]


100%|███████████████████| 95/95 [00:01<00:00, 80.15it/s, acc=0.9684, mse=0.0178]


[epoch# 5/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:20<00:00,  4.03it/s, loss=0.00848]


100%|███████████████████| 95/95 [00:01<00:00, 82.52it/s, acc=0.9684, mse=0.0148]


[epoch# 6/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:20<00:00,  4.03it/s, loss=0.00603]


100%|███████████████████| 95/95 [00:01<00:00, 84.65it/s, acc=0.9684, mse=0.0137]


[epoch# 7/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.88it/s, loss=0.00571]


100%|███████████████████| 95/95 [00:01<00:00, 85.10it/s, acc=0.9579, mse=0.0152]


[epoch# 8/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.87it/s, loss=0.00530]


100%|███████████████████| 95/95 [00:01<00:00, 81.46it/s, acc=0.9474, mse=0.0153]


[epoch# 9/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.89it/s, loss=0.00425]


100%|███████████████████| 95/95 [00:01<00:00, 81.11it/s, acc=0.9579, mse=0.0134]


[epoch#10/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.93it/s, loss=0.00319]


100%|███████████████████| 95/95 [00:01<00:00, 75.86it/s, acc=0.9684, mse=0.0137]


[epoch#11/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.83it/s, loss=0.00309]


100%|███████████████████| 95/95 [00:01<00:00, 75.74it/s, acc=0.9789, mse=0.0144]


[epoch#12/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.89it/s, loss=0.00261]


100%|███████████████████| 95/95 [00:01<00:00, 83.26it/s, acc=0.9684, mse=0.0131]


[epoch#13/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.90it/s, loss=0.00231]


100%|███████████████████| 95/95 [00:01<00:00, 80.02it/s, acc=0.9684, mse=0.0133]


[epoch#14/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.91it/s, loss=0.00243]


100%|███████████████████| 95/95 [00:01<00:00, 82.45it/s, acc=0.9684, mse=0.0139]


[epoch#15/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.96it/s, loss=0.00173]


100%|███████████████████| 95/95 [00:01<00:00, 82.45it/s, acc=0.9789, mse=0.0129]


[epoch#16/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.98it/s, loss=0.00179]


100%|███████████████████| 95/95 [00:01<00:00, 82.59it/s, acc=0.9789, mse=0.0130]


[epoch#17/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.95it/s, loss=0.00173]


100%|███████████████████| 95/95 [00:01<00:00, 81.25it/s, acc=0.9789, mse=0.0124]


[epoch#18/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.97it/s, loss=0.00128]


100%|███████████████████| 95/95 [00:01<00:00, 82.16it/s, acc=0.9789, mse=0.0126]


[epoch#19/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.98it/s, loss=0.00106]


100%|███████████████████| 95/95 [00:01<00:00, 82.88it/s, acc=0.9789, mse=0.0130]


[epoch#20/20][ 2656/ 2660]: 100%|███████████████████████████████████████████████████| 84/84 [00:21<00:00,  3.99it/s, loss=0.00097]


100%|███████████████████| 95/95 [00:01<00:00, 82.73it/s, acc=0.9789, mse=0.0126]


In [8]:
# Final evaluation of the trained model; also lists every mis-ranked test pair.
# As the cell's last expression, the returned metrics dict is displayed.
calc_test_metric(model, print_wrong_predictions = True)

T1030-D1_original_fm  y_gt:0.8572  y_pred:0.7634                                
  T1030-D1_rand13_fm  y_gt:0.8149  y_pred:0.8119                                
T1047s2-D2_rand16_fm  y_gt:0.7530  y_pred:0.8398                                
 T1047s2-D2_rand5_fm  y_gt:0.9398  y_pred:0.7485                                
100%|███████████████████| 95/95 [00:01<00:00, 64.70it/s, acc=0.9789, mse=0.0126]


{'accuracy': 0.9789473684210527, 'mse': 0.012627777705998405}