In [13]:
import math
import os
import pickle
import random

import numpy as np
from scipy.stats import pearsonr, spearmanr

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader

# Evaluation configuration.
batch_size = 256
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fix RNG seeds. DataManager.data_set shuffles with `random` and `np.random`, and
# loader_only drops the last partial batch, so without a seed the evaluated
# subset (and therefore the reported correlations) changes from run to run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# path -> trained model checkpoint; override with the MHSA_CNN_CHECKPOINT env var.
path = os.environ.get(
    "MHSA_CNN_CHECKPOINT",
    "/home/kwlee/Projects_gflas/Team_BI/Projects/1.Knockout_project/Data/Results/3.Model_test/MHSA_CNN/target1/set4/checkpoints/latest_net.pth",
)
# file -> pickled test dataset; override with the MHSA_CNN_TEST_DATA env var.
file = os.environ.get(
    "MHSA_CNN_TEST_DATA",
    "/home/kwlee/Projects_gflas/Team_BI/Projects/1.Knockout_project/Data/Finalsets/Data/Cas9_HF1_wang_parsing.pkl",
)

In [5]:
class DataWrapper:
    """Map-style dataset over a dict of parallel, equally long columns.

    Column ``"X"`` holds nucleotide strings, which are encoded to a long tensor
    of indices (A=0, C=1, G=2, T=3). Every other column is converted to a
    float tensor. ``len()`` is taken from column ``"X"``.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform  # stored but never applied by this class
        self.nuc_to_idx = {"A": 0, "C": 1, "G": 2, "T": 3}

    def __len__(self):
        return len(self.data["X"])

    def __getitem__(self, idx):
        """Return ``{column: tensor}`` for the sample at position ``idx``."""
        if torch.is_tensor(idx):
            # Tensors expose .tolist() (not .to_list()); a 0-d tensor becomes a plain int.
            idx = idx.tolist()

        res = dict()
        for col, values in self.data.items():
            if col == "X":
                res[col] = torch.tensor(
                    [self.nuc_to_idx[x] for x in values[idx]], dtype=torch.long
                )
            else:
                res[col] = torch.tensor(values[idx], dtype=torch.float)
        return res

In [6]:
class DataManager:
    """Load a pickled ``{"X": [...], "Y": [...]}`` dataset, split it, and batch it."""

    # n_return -> (test_ratio, val_ratio) of the whole dataset.
    _SPLIT_RATIOS = {3: (0.15, 0.15), 2: (0.15, 0), 1: (0, 0)}

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.seqlen = 33  # not read by the methods of this class

    def data_set(self, file, n_return, ratio=1.0):
        """Load ``file``, min-max scale ``Y`` to [0, 1], shuffle and split.

        Args:
            file: path to a pickle holding a dict with parallel lists ``"X"`` and ``"Y"``.
            n_return: 1 -> return ``train``; 2 -> ``(train, test)`` with a 15% test
                split; 3 -> ``(train, valid, test)`` with 15% validation and 15% test.
            ratio: fraction of the remaining training portion to keep.

        Returns:
            One, two or three dicts ``{"X": [...], "Y": [...]}`` as described above.

        Raises:
            ValueError: if ``n_return`` is not 1, 2 or 3.
        """
        if n_return not in self._SPLIT_RATIOS:
            raise ValueError(f"n_return must be 1, 2 or 3, got {n_return!r}")
        test_ratio, val_ratio = self._SPLIT_RATIOS[n_return]

        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(file, "rb") as fh:
            data = pickle.load(fh)

        data_size = len(data["X"])
        indice = list(range(data_size))
        np.random.shuffle(indice)

        # Min-max scale over the whole file (before splitting).
        minY = min(data["Y"])
        maxY = max(data["Y"])
        span = maxY - minY
        if span == 0:
            # All targets identical: scaling is undefined, so map everything to 0.
            data["Y"] = [0.0 for _ in data["Y"]]
        else:
            data["Y"] = [(y - minY) / span for y in data["Y"]]

        train_size = int(np.floor(data_size * (1 - (val_ratio + test_ratio))) * ratio)
        valid_size = int(np.floor(data_size * val_ratio))
        test_size = int(np.floor(data_size * test_ratio))

        # Layout of the shuffled indices: [valid | test | train]. The sample calls
        # are kept in the original Val, Test, Train order so that global seeds
        # reproduce the same splits.
        held_out = valid_size + test_size
        val_idx = random.sample(indice[:valid_size], valid_size)
        test_idx = random.sample(indice[valid_size:held_out], test_size)
        train_idx = random.sample(indice[held_out : held_out + train_size], train_size)

        def subset(idx):
            # Gather the X/Y pairs at the given positions.
            return {"X": [data["X"][i] for i in idx], "Y": [data["Y"][i] for i in idx]}

        train_set = subset(train_idx)
        if n_return == 1:
            return train_set
        test_set = subset(test_idx)
        if n_return == 2:
            return train_set, test_set
        return train_set, subset(val_idx), test_set

    def loader_only(self, data, drop_last=True, num_workers=8):
        """Wrap ``data`` in a non-shuffling DataLoader using this manager's batch size.

        ``drop_last`` defaults to True for backward compatibility. Pass
        ``drop_last=False`` when evaluating, so that no sample is discarded.
        """
        return DataLoader(
            DataWrapper(data),
            batch_size=self.batch_size,
            num_workers=num_workers,
            drop_last=drop_last,
        )

In [7]:
class ForeverDataIterator:
    """A data iterator that will never stop producing data.

    Wraps a DataLoader (anything supporting ``iter()`` and ``len()``) and
    transparently restarts it when it is exhausted. A dict batch is turned into
    a list of its values in key order. If ``device`` is given, tensors in the
    batch are moved there.
    """

    def __init__(self, data_loader: DataLoader, device=None):
        self.data_loader = data_loader
        self.iter = iter(self.data_loader)
        self.device = device

    def __iter__(self):
        return self

    def __next__(self):
        try:
            data = next(self.iter)
        except StopIteration:
            # Underlying loader exhausted: start a fresh pass. An empty loader
            # still raises StopIteration here.
            self.iter = iter(self.data_loader)
            data = next(self.iter)
        return self._prepare(data)

    def __len__(self):
        return len(self.data_loader)

    def _prepare(self, data):
        # Normalise a raw batch: dict -> list of values, then optional device transfer.
        if isinstance(data, dict):
            data = list(data.values())
        if self.device is not None:
            data = self._to_device(data, self.device)
        return data

    @staticmethod
    def _to_device(obj, device):
        # Recursively move tensors inside lists/tuples/dicts to `device`;
        # anything that is not a tensor or container is returned unchanged.
        if torch.is_tensor(obj):
            return obj.to(device)
        if isinstance(obj, list):
            return [ForeverDataIterator._to_device(o, device) for o in obj]
        if isinstance(obj, tuple):
            return tuple(ForeverDataIterator._to_device(o, device) for o in obj)
        if isinstance(obj, dict):
            return {k: ForeverDataIterator._to_device(v, device) for k, v in obj.items()}
        return obj

In [8]:
# Evaluation pipeline: load the whole file as one split (n_return=1, ratio=1.0),
# batch it, and wrap it so `test` can pull batches with next().
# NOTE(review): loader_only batches with drop_last=True, so the final partial
# batch (fewer than batch_size samples) is not evaluated.
DM = DataManager(batch_size)
test_loader = DM.loader_only(DM.data_set(file, n_return=1, ratio=1.0))
test_target_iter = ForeverDataIterator(test_loader)

In [9]:
class Flattening(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Dimension 0 (batch) is kept; dimensions 1..N are merged.
        return x.flatten(start_dim=1)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Adapted from https://nlp.seas.harvard.edu/2018/04/03/attention.html.
    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000^(-2k/dim)``. Odd ``dim`` is supported.

    Args:
        dim: feature size of the input (last axis).
        dropout: dropout probability applied after the addition.
        max_len: number of precomputed positions; the input's dim 1 must not exceed it.
    """

    def __init__(self, dim, dropout=0.1, max_len=43):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0.0, dim, 2) * -(math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: dim // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, dim) so it broadcasts over the batch
        self.register_buffer("pe", pe)

    def forward(self, x):
        # `pe` is a buffer (never requires grad), so no Variable wrapper is needed.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

class MHSA_CNN(nn.Module):
    """Regress one value per token-index sequence from two parallel branches.

    * Attention branch: embedding (scaled by sqrt(embedding_dim)) plus
      positional encoding, then ``multi_head_num`` independent single-head
      scaled dot-product attentions. Their ReLU'd outputs are concatenated,
      projected by ``MultiHeadLinear``, max-pooled along the feature axis and
      flattened.
    * Convolution branch: one-hot encoding of the same indices, passed through
      two ``Conv1d`` layers with ``padding="same"``, then flattened.

    Both flattened vectors are concatenated and fed to the MLP ``predictor``.

    Args:
        len: number of positions per input sequence. The name shadows the
            builtin inside ``__init__``; it is kept for ``MHSA_CNN(len=33)`` callers.
    """

    # Number of distinct token indices (embedding rows and one-hot classes).
    NUM_TOKENS = 4
    # Output channels of each Conv1d in the convolution branch.
    CONV_CHANNELS = 32
    # MaxPool1d settings used on the attention branch.
    POOL_KERNEL = 3
    POOL_STRIDE = 2

    def __init__(self, len: int):
        super(MHSA_CNN, self).__init__()

        self.seq_len = len
        self.embedding_dim = len + 1
        self.dropout_rate = 0.4  # not referenced by any layer below
        self.single_head_size = 32
        self.multi_head_num = 8
        self.multi_head_size = 100

        self.embedding_layer = nn.Embedding(
            num_embeddings=self.NUM_TOKENS, embedding_dim=self.embedding_dim, max_norm=True
        )
        self.position_encoding = PositionalEncoding(
            dim=self.embedding_dim, max_len=self.seq_len, dropout=0.1
        )

        # Per-head projections, created in Q, K, V order (same parameter names as before).
        self.Q = self._make_head_projections()
        self.K = self._make_head_projections()
        self.V = self._make_head_projections()

        self.relu = nn.ModuleList([nn.ReLU() for _ in range(self.multi_head_num)])
        self.MultiHeadLinear = nn.Sequential(
            nn.LayerNorm(self.single_head_size * self.multi_head_num),
            nn.Linear(
                in_features=self.single_head_size * self.multi_head_num,
                out_features=self.multi_head_size,
            ),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )
        self.ConvLayer = nn.Sequential(
            nn.Conv1d(
                self.NUM_TOKENS, self.CONV_CHANNELS,
                kernel_size=3, padding="same", stride=1, bias=False,
            ),
            nn.BatchNorm1d(self.CONV_CHANNELS),
            nn.ReLU(),
            nn.Conv1d(
                self.CONV_CHANNELS, self.CONV_CHANNELS,
                kernel_size=3, padding="same", stride=1, bias=False,
            ),
            nn.BatchNorm1d(self.CONV_CHANNELS),
            nn.ReLU(),
            nn.Dropout(),
        )
        self.flattening = Flattening()
        self.avgpool = nn.AdaptiveAvgPool1d(output_size=1)  # not used in forward
        self.maxpool = nn.MaxPool1d(kernel_size=self.POOL_KERNEL, stride=self.POOL_STRIDE)

        # Width of the concatenated feature vector, derived rather than hard-coded.
        # For len=33: 33 * 49 (pooled attention) + 32 * 33 (CNN) = 2673.
        pooled_width = (self.multi_head_size - self.POOL_KERNEL) // self.POOL_STRIDE + 1
        flat_dim = self.seq_len * pooled_width + self.CONV_CHANNELS * self.seq_len

        self.predictor = nn.Sequential(
            nn.BatchNorm1d(flat_dim),
            nn.Linear(in_features=flat_dim, out_features=512),
            nn.ReLU(),
            nn.Dropout(),
            nn.BatchNorm1d(512),
            nn.Linear(in_features=512, out_features=32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Linear(in_features=32, out_features=1),
        )

    def _make_head_projections(self):
        """Return one Linear(embedding_dim -> single_head_size) per attention head."""
        return nn.ModuleList(
            [
                nn.Linear(in_features=self.embedding_dim, out_features=self.single_head_size)
                for _ in range(self.multi_head_num)
            ]
        )

    def attention(self, query, key, value, mask=None, dropout=0.0):
        """Scaled dot-product attention.

        Based on https://nlp.seas.harvard.edu/2018/04/03/attention.html.
        Returns ``(softmax(Q K^T / sqrt(d_k)) V, attention_weights)``. If ``mask``
        is given, positions where ``mask == 0`` are excluded before the softmax.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        p_attn = F.dropout(p_attn, p=dropout, training=self.training)
        return torch.matmul(p_attn, value), p_attn

    def forward(self, inputs):
        """Predict one value per sequence.

        Args:
            inputs: long tensor of token indices in ``[0, NUM_TOKENS)``;
                presumably shaped ``(batch, seq_len)`` — TODO confirm for other callers.

        Returns:
            Tensor of shape ``(batch,)``; a batch of one stays 1-D.
        """
        # Attention branch.
        embedded = self.embedding_layer(inputs) * math.sqrt(self.embedding_dim)
        encoded = self.position_encoding(embedded)

        head_outputs = []
        for q_proj, k_proj, v_proj, act in zip(self.Q, self.K, self.V, self.relu):
            head_out, _ = self.attention(
                q_proj(encoded), k_proj(encoded), v_proj(encoded), dropout=0.0
            )
            head_outputs.append(act(head_out))
        attn_concat = torch.cat(head_outputs, dim=2)

        attn_out = self.MultiHeadLinear(attn_concat)
        attn_out = self.maxpool(attn_out)
        attn_out = self.flattening(attn_out)

        # Convolution branch. num_classes is fixed so that a batch in which some
        # token never occurs still yields NUM_TOKENS channels for the first Conv1d.
        conv_out = F.one_hot(inputs, num_classes=self.NUM_TOKENS).to(torch.float)
        conv_out = conv_out.transpose(1, 2)
        conv_out = self.ConvLayer(conv_out)
        conv_out = self.flattening(conv_out)

        output = torch.cat((attn_out, conv_out), dim=1)
        output = self.predictor(output)
        # Drop only the trailing size-1 output dim, so (1, 1) -> (1,) rather than 0-d.
        return output.squeeze(-1)


In [10]:
# Rebuild the architecture for 33-position inputs and restore the trained weights.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
model = MHSA_CNN(len=33).to(device)
model.load_state_dict(torch.load(path, map_location=device))

<All keys matched successfully>

In [11]:
def test(data):
    """Evaluate the global ``model`` on ``len(data)`` batches drawn from ``data``.

    Args:
        data: an object supporting ``next()`` and ``len()`` that yields
            ``(X, y)`` pairs, e.g. a ForeverDataIterator over a DataLoader.

    Returns:
        ``(spearman_r, pearson_r)`` between the real and predicted values.
    """
    results = {"predicted_value": [], "real_value": []}

    model.eval()
    with torch.no_grad():
        for _ in range(len(data)):
            X, y = next(data)
            outputs = model(X.to(device))
            # reshape(-1) keeps this a 1-D list even if the model returns a 0-d
            # tensor (e.g. squeeze() on a batch of one); list += float would fail.
            results["predicted_value"].extend(outputs.reshape(-1).cpu().tolist())
            results["real_value"].extend(y.reshape(-1).cpu().tolist())

    corrs = spearmanr(results["real_value"], results["predicted_value"])[0]
    corrp = pearsonr(results["real_value"], results["predicted_value"])[0]
    return corrs, corrp

In [14]:
# Run the evaluation once and report rank (Spearman) and linear (Pearson) correlation.
corrs, corrp = test(test_target_iter)
for report_line in (
    f"Spearman Correlation.\t{corrs}",
    f"Pearson Correlation.\t{corrp}",
):
    print(report_line)

Spearman Correlation.	0.6675447458572075
Pearson Correlation.	0.6454832626750231
