In [4]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset,DataLoader
from scipy import stats
from torch import cuda

In [31]:
# Size of the LSTM hidden state and number of stacked LSTM layers; both are
# read by CRISPR_OTE when it builds its recurrent branch.
hidden, layer = 16, 1

# The checkpoint that gets loaded is modelpth + modelname + '.pth'.
modelname = 'CRISPR-OTE1.0'  # alternative variant: CRISPR-OTE2.0
modelpth = './model/Cas9/'

In [6]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (sequence, indel frequency, Tm, RNAfold features).

    The base class is spelled out as ``torch.utils.data.Dataset`` because this
    class reuses the name ``Dataset``; with the bare name, re-running the cell
    would make the class inherit from its own previous definition.

    Parameters
    ----------
    csv_file : str
        CSV holding the indel-frequency label. The label is read from the 7th
        column for ``./data/csv/elegans.csv`` and from the 8th column otherwise.
    nrows : int
        Number of data rows read from ``csv_file`` and ``feas_file``.
    seqs_file : str
        ``.npy`` file with the encoded sequences (the original comment gives the
        layout as batch,34,4).
    feas_file : str
        CSV of per-sequence features whose first column is used as the index.

    Attributes
    ----------
    seqs, IFs, Tms, new_feas : numpy.ndarray
        Raw arrays (sequences, labels as a column, feature column 0 as a column,
        feature columns 40-43).

    Raises
    ------
    ValueError
        If sequences, labels and features do not have the same number of rows.
    """

    def __init__(self, csv_file = r"./data/csv/Hela.csv", 
                 nrows = 8101, seqs_file = r"./data/SEQs/Hela/SEQs.npy", 
                 feas_file = r"./data/new_fea/Hela_fea_datas.csv"):

        # Only the position of the label column differs between the tables.
        label_col = 7 - 1 if csv_file == r"./data/csv/elegans.csv" else 8 - 1
        self.df = pd.read_csv(csv_file, sep = ",", header = 0, usecols = [label_col], nrows = nrows,
                              names = ['Indel freqeuncy'], dtype = np.float64)

        self.seqs = np.load(seqs_file) # batch,34,4
        print(self.seqs.shape)
        self.feas = pd.read_csv(feas_file, sep = ",", header = 0, index_col=0, nrows = nrows, dtype = np.float64) # batch,44
        self.new_feas = self.feas.iloc[:,[40, 41, 42, 43]].values  # 4RNAfold
        print(self.new_feas.shape)
        self.Tms = self.feas.iloc[:, 0].values.reshape(-1, 1)  # Tm34, as a column
        self.IFs = self.df['Indel freqeuncy'].values.reshape(-1, 1)  # labels, as a column

        # __len__ follows the sequence array, so any other length would either
        # fail part-way through iteration or silently pair rows incorrectly.
        row_counts = {"seqs": len(self.seqs), "labels": len(self.IFs), "features": len(self.feas)}
        if len(set(row_counts.values())) != 1:
            raise ValueError("Row counts differ between inputs: {}".format(row_counts))

        # Convert to float32 tensors once; doing this inside __getitem__ would
        # copy every full array again for each single sample.
        self._X = torch.from_numpy(self.seqs).float()
        self._Y = torch.from_numpy(self.IFs).float()
        self._T = torch.from_numpy(self.Tms).float()
        self._R = torch.from_numpy(self.new_feas).float()

    def __len__(self):
        """Number of samples (rows of the sequence array)."""
        return len(self.seqs)

    def __getitem__(self, index):
        """Return (sequence, label, Tm, RNAfold features) for ``index``.

        The returned tensors are views into the cached tensors, so callers
        should not modify them in place.
        """
        return self._X[index], self._Y[index], self._T[index], self._R[index]

In [7]:
# Held-out evaluation sets. Each one is described by its label CSV, the encoded
# sequence array, the feature CSV, and the number of rows to read.
testDatasetHct116_lib1 = Dataset(
    csv_file = r"./data/csv/Hct116_lib1.csv",
    seqs_file = r"./data/SEQs/Hct116_lib1/SEQs.npy",
    feas_file = r"./data/new_fea/Hct116_lib1_fea_datas.csv",
    nrows = 4239,
)
testDatasetCiona = Dataset(
    csv_file = r"./data/csv/Ciona.csv",
    seqs_file = r"./data/SEQs/Ciona/SEQs.npy",
    feas_file = r"./data/new_fea/Ciona_fea_datas.csv",
    nrows = 72,
)
testDatasetZ_fish_GZ = Dataset(
    csv_file = r"./data/csv/Z_fish_GZ.csv",
    seqs_file = r"./data/SEQs/Z_fish_GZ/SEQs.npy",
    feas_file = r"./data/new_fea/Z_fish_GZ_fea_datas.csv",
    nrows = 111,
)
testDatasetZ_fish_VZ = Dataset(
    csv_file = r"./data/csv/Z_fish_VZ.csv",
    seqs_file = r"./data/SEQs/Z_fish_VZ/SEQs.npy",
    feas_file = r"./data/new_fea/Z_fish_VZ_fea_datas.csv",
    nrows = 102,
)
# The elegans CSV is the one table whose label sits in a different column;
# Dataset recognises it by this exact path.
testDatasetElegans = Dataset(
    csv_file = r"./data/csv/elegans.csv",
    seqs_file = r"./data/SEQs/elegans/SEQs.npy",
    feas_file = r"./data/new_fea/elegans_fea_datas.csv",
    nrows = 50,
)

(4239, 34, 4)
(4239, 4)
(72, 34, 4)
(72, 4)
(111, 34, 4)
(111, 4)
(102, 34, 4)
(102, 4)
(50, 34, 4)
(50, 4)


In [32]:
class CRISPR_OTE(nn.Module):
    """Regressor combining a bidirectional LSTM branch, a 1-D CNN branch and Tm.

    Both sequence branches consume the same (batch, 34, 4) input and are each
    reduced to 80 features; these are concatenated with a 30-feature embedding
    of the scalar Tm and mapped to a single output value. The sequence length
    of 34 is hard-coded in the linear layer sizes.

    Parameters
    ----------
    hidden_size : int, optional
        LSTM hidden size. Defaults to the notebook-level ``hidden``.
    num_layers : int, optional
        Number of stacked LSTM layers. Defaults to the notebook-level ``layer``.

    The sizes are captured at construction time so that ``forward`` keeps
    working even if the notebook globals are changed afterwards.
    """

    def __init__(self, hidden_size = None, num_layers = None):
        super(CRISPR_OTE, self).__init__()

        self.hidden_size = hidden if hidden_size is None else hidden_size
        self.num_layers = layer if num_layers is None else num_layers

        # Recurrent branch: BiLSTM over the 34 positions, flattened to 34*2*hidden.
        self.lstm = nn.LSTM(4, self.hidden_size, self.num_layers, batch_first = True, bidirectional = True)
        self.layerl = nn.Sequential(nn.Dropout(p=0.3), nn.Linear(34*2*self.hidden_size, 80), nn.ReLU(True))

        # Convolutional branch: kernel 5 with stride 1, then average pooling
        # (stride equals the pooling kernel size), giving (34-5+1)/2 = 15 positions.
        self.conv1 = nn.Sequential(nn.Conv1d(in_channels = 4, out_channels = 80, kernel_size = 5),
                      nn.ReLU(),nn.AvgPool1d(kernel_size = 2))
        self.layerc = nn.Sequential(nn.Dropout(p=0.3), nn.Linear(80*15, 80), nn.ReLU(True))

        # CRISPR-OTE2.0
#         self.fc_Tm = nn.Linear(1, 30)
#         self.fea = nn.Linear(4, 30)
#         self.layer1_ft = nn.Sequential(nn.Dropout(p=0.3), nn.Linear(220, 80), nn.ReLU(True))

        # CRISPR-OTE1.0
        self.fc_Tm = nn.Linear(1, 30)
        self.layer1_ft = nn.Sequential(nn.Dropout(p=0.3), nn.Linear(190, 80), nn.ReLU(True))

        self.layer4_ft = nn.Sequential(nn.Dropout(p=0.3), nn.Linear(80, 1))


    def forward(self, x, T, R):
        """Predict one value per sample.

        Parameters
        ----------
        x : torch.Tensor
            Encoded sequences, shape (batch, 34, 4).
        T : torch.Tensor
            Tm values, shape (batch, 1).
        R : torch.Tensor
            Extra features; accepted for interface compatibility but unused by
            the CRISPR-OTE1.0 variant that is active here.

        Returns
        -------
        torch.Tensor
            Predictions of shape (batch, 1).
        """
        x34 = x
        batch = x34.size(0)

        # LSTM
        # new_zeros creates the initial states with the input's device and
        # dtype, so the model works wherever its inputs and weights live.
        h0 = x34.new_zeros(2*self.num_layers, batch, self.hidden_size)
        c0 = x34.new_zeros(2*self.num_layers, batch, self.hidden_size)

        out34, _ = self.lstm(x34, (h0, c0))  # batch,34,hidden*direction

        x_l = out34.contiguous().view(batch, -1) # batch,34*hidden*2
        x_l = self.layerl(x_l) # batch,80

        # CNN
        x34 = x34.permute(0,2,1) # batch,4,34
        x_c = self.conv1(x34)  # batch,80,(34-5+1)/2
        x_c = x_c.view(x_c.size(0), -1) # batch,1200
        x_c = self.layerc(x_c)  # batch,80

        T = self.fc_Tm(T)  # batch,30

        # CRISPR-OTE2.0
#         R = self.fea(R)
#         x = torch.cat([x_l, x_c, T, R], dim = 1) # batch,220

        # CRISPR-OTE1.0
        x = torch.cat([x_l, x_c, T], dim = 1) # batch,190

        x = self.layer1_ft(x)
        x = self.layer4_ft(x)
        return x


# Build the network, move it to the GPU when one is available, and show its layers.
model = CRISPR_OTE()
if cuda.is_available():
    model = model.cuda()
print(model)

CRISPR_OTE(
  (lstm): LSTM(4, 16, batch_first=True, bidirectional=True)
  (layerl): Sequential(
    (0): Dropout(p=0.3, inplace=False)
    (1): Linear(in_features=1088, out_features=80, bias=True)
    (2): ReLU(inplace=True)
  )
  (conv1): Sequential(
    (0): Conv1d(4, 80, kernel_size=(5,), stride=(1,))
    (1): ReLU()
    (2): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))
  )
  (layerc): Sequential(
    (0): Dropout(p=0.3, inplace=False)
    (1): Linear(in_features=1200, out_features=80, bias=True)
    (2): ReLU(inplace=True)
  )
  (fc_Tm): Linear(in_features=1, out_features=30, bias=True)
  (layer1_ft): Sequential(
    (0): Dropout(p=0.3, inplace=False)
    (1): Linear(in_features=190, out_features=80, bias=True)
    (2): ReLU(inplace=True)
  )
  (layer4_ft): Sequential(
    (0): Dropout(p=0.3, inplace=False)
    (1): Linear(in_features=80, out_features=1, bias=True)
  )
)


In [24]:
def test(model, test_loader):
    model.eval()
    with torch.no_grad():
        test_output = []
        test_label = []
        for k,(inputs, labels, T, R) in enumerate(test_loader):
            length = len(test_loader)
            if  cuda.is_available():
                inputs = inputs.cuda()
                T = T.cuda()
                R = R.cuda()
                outputs = model(inputs, T, R).cpu()
            else:
                outputs = model(inputs, T, R)

#             r,p = stats.spearmanr(labels.detach().numpy(),outputs.detach().numpy())
            test_output.extend(outputs.squeeze().detach().numpy())
            test_label.extend(labels.squeeze().detach().numpy())

        test_spear_r, _ = stats.spearmanr(test_output, test_label)
        test_pear_r, _ = stats.pearsonr(test_output, test_label)
        print("test_spr {:.4f}, test_pear {:.4f}".format(test_spear_r, test_pear_r))

    return(test_spear_r, test_pear_r)

In [33]:
# Restore the trained weights together with the metrics stored next to them.
# map_location='cpu' makes the file loadable on hosts without a GPU even if it
# contains CUDA tensors; load_state_dict copies the values into the model's
# existing parameters, so the model stays on the device it already occupies.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
checkpoint = torch.load(modelpth + modelname + '.pth', map_location = 'cpu')
model.load_state_dict(checkpoint['model'])
run_time = checkpoint['time']
run_spr = checkpoint['spr']
run_pear = checkpoint['pear']
print('time {}, spr {:.4f}, pear {:.4f}'.format(run_time, run_spr, run_pear))

time 3341.0395591259003, spr 0.4562, pear 0.4555


In [34]:
# Hct116_lib1: one batch covering the whole set, then print the correlations.
test_loaderHct116_lib1 = DataLoader(
    dataset = testDatasetHct116_lib1,
    batch_size = 4239,
    num_workers = 4,
    pin_memory = True,
)
_, _ = test(model, test_loaderHct116_lib1)

test_spr 0.5154, test_pear 0.5084


In [35]:
# Ciona: one batch covering the whole set, then print the correlations.
test_loaderCiona = DataLoader(
    dataset = testDatasetCiona,
    batch_size = 72,
    num_workers = 4,
    pin_memory = True,
)
_, _ = test(model, test_loaderCiona)

test_spr 0.4126, test_pear 0.4273


In [36]:
# Z_fish_GZ: one batch covering the whole set, then print the correlations.
test_loaderZ_fish_GZ = DataLoader(
    dataset = testDatasetZ_fish_GZ,
    batch_size = 111,
    num_workers = 4,
    pin_memory = True,
)
_, _ = test(model, test_loaderZ_fish_GZ)

test_spr 0.4181, test_pear 0.3835


In [37]:
# Z_fish_VZ: one batch covering the whole set, then print the correlations.
test_loaderZ_fish_VZ = DataLoader(
    dataset = testDatasetZ_fish_VZ,
    batch_size = 102,
    num_workers = 4,
    pin_memory = True,
)
_, _ = test(model, test_loaderZ_fish_VZ)

test_spr 0.2945, test_pear 0.2732


In [38]:
# Elegans: one batch covering the whole set, then print the correlations.
test_loaderElegans = DataLoader(
    dataset = testDatasetElegans,
    batch_size = 50,
    num_workers = 4,
    pin_memory = True,
)
_, _ = test(model, test_loaderElegans)

test_spr 0.4699, test_pear 0.3746
