In [1]:
import numpy as np
import torch
import torch.nn as nn
from time import time
from tqdm import tqdm

In [2]:
class inference_dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # convert from numpy to tensor
        x = torch.tensor(self.data[index], dtype=torch.float32)

        return x

In [3]:
# MIMENet Deep Neural Network
class MIMENetEnsemble(nn.Module):
    # Constructor for 5 fully connected layers with bottleneck for layers 4 and 5
    # dropout after each layer
    def __init__(self, input_size, hidden_size_factor, bottleneck, output_size):
        super(MIMENetEnsemble, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size_factor*input_size)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(hidden_size_factor*input_size, hidden_size_factor*input_size)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(hidden_size_factor*input_size, hidden_size_factor*input_size)
        self.dropout3 = nn.Dropout(p=0.2)
        self.fc4 = nn.Linear(hidden_size_factor*input_size, int(hidden_size_factor*input_size*bottleneck))
        self.dropout4 = nn.Dropout(p=0.2)
        self.fc5 = nn.Linear(int(hidden_size_factor*input_size*bottleneck), output_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    # Forward pass
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout1(x)
        x = self.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.relu(self.fc3(x))
        x = self.dropout3(x)
        x = self.relu(self.fc4(x))
        x = self.dropout4(x)
        x = self.sigmoid(self.fc5(x))

        return x

In [4]:
def predict(model, x, n, batch_size):
    """This function is used for inference using the model. It takes in a 2d tensor
    and predicts the output using the model. For uncertainty quantification, the
    model is run n times and the full distribution is returned.

    Args:
        x (tensor): Input tensor
        n (int): Number of times to run the model
    """
    # set length of input
    length = x.shape[0]

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    
    # define x as inference dataset
    x = inference_dataset(x)

    # define dataloader
    dataloader = torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=False)

    # set model
    model = model.to(device)

    # initialize output
    output = np.zeros((length, n))
    # run model n times
    for i in tqdm(range(n)):
        for j, data in enumerate(dataloader):
            # only send data to device
            data = data.to(device)

            # add to output
            output[j*batch_size:(j+1)*batch_size, i] = model(data).squeeze().cpu().detach().numpy()
            
    return output

In [5]:
# test predict function
model = MIMENetEnsemble(535*4+8, 2, 0.5, 1)
batch_size = 2**14
# 535*4 by 535*4+8 tensor random input
x = np.random.rand(535*3, 535*4+8)

start = time()
predictions = predict(model, x, 1000, batch_size)
end = time()
print(end-start)
print(predictions.shape)

cuda:0


100%|██████████| 1000/1000 [00:27<00:00, 35.97it/s]

28.6442608833313
(1605, 1000)





In [6]:
# test predict function
model = MIMENetEnsemble(535*4+8, 2, 0.5, 1)
batch_size = 2**14
# 535*4 by 535*4+8 tensor random input
x = np.random.rand(1285605, 535*4+8)

start = time()
predictions = predict(model, x, 1000, batch_size)
end = time()
print(end-start)
print(predictions.shape)

cuda:0


  0%|          | 3/1000 [01:18<7:17:10, 26.31s/it]


KeyboardInterrupt: 

In [None]:
def inferSingleKds(model, n_protein_concentrations, n_rounds, path_wildtype, n : int, batch_size : int):

    # read in wildtype
    with open(path_wildtype, 'r') as f:
        wildtype = f.read()

    n_features = len(wildtype) * 4 + n_protein_concentrations * n_rounds
    
    # initialize prediction example (number of nucleotides by number of features)
    prediction_example_mutation = np.zeros((len(wildtype)*3, n_features))
    prediction_example_wildtype = np.zeros((len(wildtype), n_features))

    for pos, feature in enumerate(range(n_protein_concentrations*n_rounds, n_features, 4)):

        # get wildtype at position
        wildtype_nucleotide = wildtype[pos]

        order_nucleotides = ['A', 'C', 'G', 'T']

        # get index of wildtype nucleotide
        index_wildtype = order_nucleotides.index(wildtype_nucleotide)

        # set wildtype nucleotide
        prediction_example_wildtype[pos, feature + index_wildtype] = 1

        # set mutant nucleotides
        indices_mutants = [i for i in range(4) if i != index_wildtype]
        prediction_example_mutation[pos*3, feature + indices_mutants[0]] = 1
        prediction_example_mutation[pos*3+1, feature + indices_mutants[1]] = 1
        prediction_example_mutation[pos*3+2, feature + indices_mutants[2]] = 1


        # predict
        with torch.no_grad():
            wildtype_prediction = predict(model, prediction_example_wildtype, n, batch_size)
            mutation_prediction = predict(model, prediction_example_mutation, n, batch_size)

            # compute predictions to kds
            wildtype_prediction = 1 / wildtype_prediction - 1
            mutation_prediction = 1 / mutation_prediction - 1

            # repeat every row in wildtype_prediction 3 times
            wildtype_prediction = np.repeat(wildtype_prediction, 3, axis=0)

            # correct kds
            kds_nucleotide = mutation_prediction / wildtype_prediction

            # reshape to 3d so that 3 rows are grouped together as one position
            kds_position = kds_nucleotide.reshape((len(wildtype), 3, n))

            # take max of 3 rows
            kds_position = np.max(kds_position, axis=1)


    return kds_nucleotide, kds_position

In [21]:
wildtype = "a"*535
n_features = len(wildtype) * 4 + 4 * 2
    
# initialize prediction example (number of nucleotides by number of features)
prediction_example_mutation = np.zeros((len(wildtype)*3, n_features))
prediction_example_wildtype = np.zeros((len(wildtype), n_features))

prediction_example_wildtype[0, 0] = 4
prediction_example_wildtype[1, 1] = 4
print(prediction_example_wildtype[:9, :9])

print(prediction_example_mutation.shape)
print(prediction_example_wildtype.shape)

[[4. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 4. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]]
(1605, 2148)
(535, 2148)


In [25]:
kds_nucleotide = np.random.rand(535*3, 1000)
kds_pos = kds_nucleotide.reshape((len(wildtype), 3, 1000)).round(2)
print(kds_pos.shape)
print(kds_pos[:5, :, :10])
kds_pos = np.max(kds_pos, axis=1).round(2)
print(kds_pos.shape)
print(kds_pos[:5, :10])


(535, 3, 1000)
[[[0.79 0.94 0.88 0.38 0.88 0.76 0.34 0.02 0.36 0.83]
  [0.75 0.22 0.37 0.75 0.45 0.66 0.64 0.66 0.55 0.22]
  [0.7  0.89 0.91 0.67 0.9  0.54 0.59 0.35 0.79 0.19]]

 [[0.81 0.32 0.61 0.25 0.94 0.95 0.55 0.85 0.3  0.82]
  [0.42 0.28 0.6  0.49 0.48 0.6  0.79 0.31 0.23 0.16]
  [0.55 0.8  0.65 0.36 0.31 0.08 0.1  0.57 0.38 0.17]]

 [[0.7  0.21 0.31 0.83 0.68 0.85 0.16 0.54 0.25 0.48]
  [0.19 0.02 0.36 0.04 0.75 0.62 0.92 0.75 0.18 0.81]
  [0.12 0.93 0.24 0.73 0.42 0.28 0.97 0.83 0.5  0.25]]

 [[0.63 0.73 0.87 0.27 0.08 0.01 0.34 0.12 0.65 0.45]
  [0.59 0.94 0.05 0.2  0.34 0.17 0.97 0.04 0.16 0.74]
  [0.41 0.21 0.56 0.59 0.48 0.7  0.38 0.69 0.93 0.8 ]]

 [[0.82 0.48 0.56 0.79 0.49 0.11 0.58 0.38 0.33 0.38]
  [0.53 0.08 0.26 0.52 0.86 0.22 0.2  0.49 0.16 0.22]
  [0.67 0.11 0.6  0.49 0.3  0.61 0.53 0.82 0.61 0.9 ]]]
(535, 1000)
[[0.79 0.94 0.91 0.75 0.9  0.76 0.64 0.66 0.79 0.83]
 [0.81 0.8  0.65 0.49 0.94 0.95 0.79 0.85 0.38 0.82]
 [0.7  0.93 0.36 0.83 0.75 0.85 0.97 0.83 0.5  

In [None]:
def inferPairwiseKds(model, n_protein_concentrations, n_rounds, path_wildtype, n : int):

    # read in wildtype
    with open(path_wildtype, 'r') as f:
        wildtype = f.read()

    n_features = len(wildtype) * 4 + n_protein_concentrations * n_rounds
    n_pairs_mut = int((len(wildtype)*(len(wildtype)-1)/2)*3*3)
    n_pairs_wt = int((len(wildtype)*(len(wildtype)-1)/2))
    
    # initialize prediction example
    prediction_example_mutation = np.zeros((n_pairs_mut, n_features))
    prediction_example_wildtype = np.zeros((n_pairs_wt, n_features))

    i = 0
    for pos1, feature1 in enumerate(range(n_protein_concentrations*n_rounds, n_features, 4)):
        for pos2, feature2 in enumerate(range(feature1+4, n_features, 4)):
            
            # get wildtype at both positions
            wildtype1 = wildtype[pos1]
            wildtype2 = wildtype[pos2]

            order_nucleotides = ['A', 'C', 'G', 'T']

            index_wildtype1 = order_nucleotides.index(wildtype1)
            index_wildtype2 = order_nucleotides.index(wildtype2)

            # set wildtype nucleotides
            prediction_example_wildtype[i//9, feature1 + index_wildtype1] = 1
            prediction_example_wildtype[i//9, feature2 + index_wildtype2] = 1

            # iterate over all possible mutations
            for mut1 in range(4):
                for mut2 in range(4):

                    # skip wildtype
                    if mut1 == index_wildtype1 or mut2 == index_wildtype2:
                        continue

                    # set mutant
                    prediction_example_mutation[i, feature1+mut1] = 1
                    prediction_example_mutation[i, feature2+mut2] = 1

                    i += 1

    # predict
    with torch.no_grad():

        wildtype_prediction = predict(model, prediction_example_wildtype, n, batch_size)
        mutation_prediction = predict(model, prediction_example_mutation, n, batch_size)

        # compute predictions to kds
        wildtype_prediction = 1 / wildtype_prediction - 1
        mutation_prediction = 1 / mutation_prediction - 1

        # repeat every row in wildtype_prediction 3 times
        wildtype_prediction = np.repeat(wildtype_prediction, 9, axis=0)

        # correct kds
        kds_pairwise = mutation_prediction / wildtype_prediction


    return kds_pairwise

In [2]:
wildtype = "a"*535
n_features = len(wildtype) * 4 + 4 * 2
n_pairs = int((len(wildtype)*(len(wildtype)-1)/2)*3*3)
print(n_pairs)
print(n_features)

1285605
2148


In [9]:
int((len(wildtype)*(len(wildtype)-1)/2)*3*3-10)//9

142843