In [5]:
%reset -f
import os
import time
import math
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from torchsummaryX import summary

# Run configuration. The three dataset parameters only select which test file
# is loaded; strand_length is also the length every read is trimmed/padded to.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

strand_num = 50000      # used to build the test file name
strand_length = 120     # target length for every read (and positional table size)
error_rate = 0.01       # used to build the test file name
data_path = 'test_' + str(strand_num) + '_' + str(strand_length) + '_' + str(error_rate) + '.npz'

# Seed the RNGs: the read-length normalisation below deletes/inserts bases at
# random positions, so without a seed each run evaluates on different inputs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

checkpoints = [
               'save_error0.01/epoch_x_x.pth',
              ]

print(strand_num, strand_length, error_rate, device)

def convert_strands(strand):
    """Encode a nucleotide sequence as integer codes A=1, G=2, C=3, T=4.

    Characters other than upper-case 'A', 'G', 'C', 'T' are skipped silently,
    so the returned list can be shorter than the input.
    """
    code_of = {'A': 1, 'G': 2, 'C': 3, 'T': 4}
    return [code_of[base] for base in strand if base in code_of]

loaded = np.load(data_path)
data = {}

# Labels ('test_y'): integer-encode every reference strand. np.array only
# yields a dense 2-D array if all encoded labels have the same length.
label_rows = [convert_strands(seq) for seq in loaded['test_y']]
data['test_y'] = np.array(label_rows)
print('test_y', 'shape', data['test_y'].shape)

# processing inputs: integer-encode every read and force it to exactly
# `strand_length` bases so each cluster stacks into a dense array.
def _fit_to_length(codes, target_len):
    """Randomly delete or insert bases in `codes` (in place) until len == target_len.

    Deletions remove a uniformly random existing position. Insertions add a
    uniformly random base code (1-4) at a uniformly random slot in
    0..len(codes), so padding can also land at the end of the read.
    Returns the same (mutated) list.
    """
    while len(codes) > target_len:
        del codes[np.random.randint(len(codes))]
    while len(codes) < target_len:
        # randint's upper bound is exclusive: +1 makes the end of the read a
        # valid insertion slot and keeps the bound positive for empty reads.
        slot = np.random.randint(len(codes) + 1)
        codes.insert(slot, np.random.choice([1, 2, 3, 4]))
    return codes


for split in ['test_x']:
    strands = []
    lengths = []
    for c in loaded[split]:
        cluster = []
        for s in c:
            s = convert_strands(s)
            lengths.append(len(s))  # raw length, recorded before trimming/padding
            cluster.append(_fit_to_length(s, strand_length))
        strands.append(cluster)
    data[split] = np.array(strands)
    print(split, 'shape', data[split].shape)

    # check strands length
    l = np.array(lengths)
    print('lengths shape', l.shape)
    print('lengths stat', l.mean(), l.min(), l.max())


class StandardScaler:
    """Z-score scaler with a fixed, externally supplied mean and std.

    Works on anything supporting arithmetic with the stored mean/std
    (scalars, numpy arrays, tensors).
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, data):
        """Map raw values to (data - mean) / std."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo `transform`: data * std + mean."""
        rescaled = data * self.std
        return rescaled + self.mean
# Standardise the encoded reads, then wrap inputs/labels in a DataLoader.
# NOTE(review): mean/std come from the test set itself; if the checkpoint was
# trained with training-set statistics, those values should be reused here.
scaler = StandardScaler(mean=data['test_x'].mean(), std=data['test_x'].std())

scaled_x = scaler.transform(data['test_x']).astype(np.float32)
data['test_x'] = torch.from_numpy(scaled_x)
data['test_y'] = torch.from_numpy(data['test_y'].astype('int64'))

test_dataset = TensorDataset(data['test_x'], data['test_y'])
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)


class PositionalEncoding(torch.nn.Module):
    """Add a fixed sinusoidal position table to sequence-first input, then dropout.

    The table is stored as buffer 'pe' with shape (max_len, 1, d_model), so
    forward() expects x laid out as (seq_len, batch, d_model) with
    seq_len <= max_len. The buffer name is part of the saved state_dict.
    """

    def __init__(self, d_model, dropout, max_len=strand_length):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        positions = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        freqs = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = positions * freqs  # (max_len, d_model // 2)
        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(angles)  # even channels
        table[:, 0, 1::2] = torch.cos(angles)  # odd channels
        self.register_buffer('pe', table)

    def forward(self, x):
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len])
    
class ConsensusTransformer(torch.nn.Module):
        """Transformer encoder mapping a cluster of reads to per-position base probabilities.

        The submodule attribute names (embed, pos_encoder, encoder, linear) and the
        layer layout inside `linear` become the state_dict keys restored from the
        checkpoint files, so renaming or reordering them breaks checkpoint loading.
        """
        def __init__(self):
            super().__init__()
            # Hard-coded hyper-parameters; they must match the trained checkpoint,
            # otherwise load_state_dict reports missing keys / shape mismatches.
            d_model = 128   # width of each position's embedding
            d_hid =  512    # feed-forward width inside each encoder layer
            nhead = 8       # attention heads per encoder layer
            dropout = 0.1   # only active in train mode; model.eval() disables it
            nlayers = 6     # number of stacked encoder layers
            
            # forward() feeds each sequence position a vector of size 10 (the
            # read axis of the (bs, 10, 120) input), so clusters must hold
            # exactly 10 reads.
            self.embed = torch.nn.Linear(10, d_model)
            self.pos_encoder = PositionalEncoding(d_model, dropout)
            
            # batch_first is left at its default (False): the encoder expects
            # (seq_len, batch, d_model), which forward() arranges before calling it.
            encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
            self.encoder = torch.nn.TransformerEncoder(encoder_layer, nlayers)

            # Per-position classification head: d_model -> 64 -> 4 base classes.
            self.linear = torch.nn.Sequential(
                torch.nn.Linear(d_model, 64),
                torch.nn.LeakyReLU(),
                torch.nn.Linear(64, 4),
            )
            # Softmax over the class axis of the flattened (positions, 4) output.
            self.softmax = torch.nn.Softmax(dim=1)
            
        def forward(self, x): # x shape (bs, 10, 120)
            """Return class probabilities of shape (bs * seq_len, 4).

            Rows are ordered sample-major (all positions of sample 0, then
            sample 1, ...), so callers can view the result as (bs, seq_len, 4).
            """
            x = x.transpose(1, 2)      # (bs, 120, 10)
            x = self.embed(x)          # (bs, 120, d_model)
            x = x.transpose(0, 1)      # (120, bs, d_model): sequence-first layout
            x = self.pos_encoder(x)    # add the sinusoidal position table
            x = self.encoder(x)        # (120, bs, d_model)
            x = self.linear(x)         # (120, bs, 4)
            x = x.transpose(0, 1)      # (bs, 120, 4)
            x = x.reshape(x.size(0)*x.size(1), x.size(2))  # (bs * 120, 4)
            x = self.softmax(x)
            return x

In [6]:
def get_acc(predictions, labels):
    """Count perfectly predicted rows and correctly predicted positions.

    predictions and labels are 2-D tensors of the same shape (rows = strands,
    columns = positions). Returns a pair of 0-d tensors:
    (number of rows where every position matches, total number of matches).
    """
    hits_per_row = (predictions == labels).sum(dim=1)
    perfect_rows = (hits_per_row == labels.size(1)).sum()
    matched_positions = hits_per_row.sum()
    return perfect_rows, matched_positions

def get_index(predictions, labels):
    """Return the column (position) index of every mismatching element.

    Indices are listed row by row in row-major order; a position that is wrong
    in several rows appears once per row.
    """
    mismatched = predictions != labels
    return mismatched.nonzero(as_tuple=True)[1].tolist()

# Evaluate every checkpoint on the test set.
# predict_res[i] holds the predicted class ids (0..3) of checkpoints[i],
# shape (num_strands, strand_length), kept on the CPU.
predict_res = []
for i, ckpt_path in enumerate(checkpoints):
    model = ConsensusTransformer()
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(ckpt_path, map_location=device)['model_state'])
    model.to(device)
    t1 = time.time()
    total = 0            # strands evaluated
    total_entries = 0    # positions (bases) evaluated
    test_strand_acc = 0
    test_entry_acc = 0
    model.eval()
    res = []
    miscls_idx = []      # per batch: positions of every wrong prediction
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device)
            # Labels are encoded 1..4; shift to 0..3 to match argmax class ids.
            batch_y = batch_y.to(device) - 1

            outputs = model(batch_x)
            outputs = outputs.view(batch_y.size(0), batch_y.size(1), -1)
            outputs = outputs.argmax(dim=2)
            res.append(outputs)

            total += batch_y.size(0)
            total_entries += batch_y.numel()
            strand_acc, entry_acc = get_acc(outputs, batch_y)
            miscls_idx.append(get_index(outputs, batch_y))
            test_strand_acc += strand_acc
            test_entry_acc += entry_acc
    mtest_strand_acc = test_strand_acc / total
    # Fraction of correctly predicted positions (previously divided by the
    # number of strands, which gave mean correct bases per strand, not an accuracy).
    mtest_entry_acc = test_entry_acc / total_entries
    print('Test Strand Acc:', mtest_strand_acc.item(), 'Test Entry Acc:', mtest_entry_acc.item())
    t2 = time.time()
    print('Test Time:', t2 - t1)
    predict_res.append(torch.cat(res).cpu())

In [7]:
# visualization
from matplotlib import pyplot as plt

# Where along the strand do errors occur? Histogram of mismatch positions
# collected for the last checkpoint evaluated above.
t = np.array([pos for batch_positions in miscls_idx for pos in batch_positions])
fig, ax = plt.subplots()
ax.hist(t, 20)
ax.set(title='Misclassified positions', xlabel='Position in strand', ylabel='Count')
plt.show()