In [42]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Multinomial
import biom
from biom import load_table, Table
from biom.util import biom_open

from mmvec.util import split_tables, format_params

In [84]:
class MMVec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, num_metabolites),
            nn.Softmax(dim=0)
        )
        
    # X = Batch x num_microbesl
    # Y = Batch_size x num_metaboites
    def forward(self, X, Y):
        z = self.encoder(X)
        y_pred = self.decoder(z)
        print(y_pred.shape)
        print(Y.shape)
        # total_count=0 and validate_args=False allows skipping being concerned
        # about total count when calling log_prob
        lp = Multinomial(total_count=0, validate_args=False, probs=y_pred).log_prob(Y)
        return lp

Getting data prepped:

In [44]:
# manually to start
microbes = load_table("./soil_microbes.biom")
metabolites = load_table("./soil_metabolites.biom")

# X = microbes.to_dataframe().T
# Y = metabolites.to_dataframe().T
# X = X.loc[Y.index]

# trainX = X.iloc[:-2]
# trainY = Y.iloc[:-2]
# testX = X.iloc[-2:]
# testY = Y.iloc[-2:]

# # index dictionaries for the inputs
# microbeIdxs = {microbe: i for i, microbe in enumerate(trainX)}
# metaboliteIdxs = {metabolite: i for i, metabolite in enumerate(trainY)}

# print(microbeIdxs, "\n------\n", metaboliteIdxs, "\n----\n", Y)
# exX = torch.tensor(trainX.values, requires_grad=True)
# exY = torch.tensor(trainY.values, requires_grad=True)


#### put the data in a dataloader(pytorch iterator)

In [45]:
class MicrobeMetaboliteData(Dataset):
    def __init__(self, microbes: biom.table, metabolites: biom.table):
        # arrange
        self.microbes = microbes.to_dataframe().T   
        self.metabolites = metabolites.to_dataframe().T
        
        # only samples that have results
        self.microbes = self.microbes.loc[self.metabolites.index]
        
        # microbe index
#         self.microbe_idxs = {microbe: i for i, microbe in enumerate(self.microbes)}
        
        # make tensors
        self.microbes = torch.tensor(self.microbes.values, dtype=torch.int)
        # dtype must be integer to avoid floating point errors in multinomial
        self.metabolites = torch.tensor(self.metabolites.values, dtype=torch.int64)
        
        # inputs in dict so indexs available...not sure if we need this
        self.idxs = {i: sample for i, sample in enumerate(self.microbes)}
        
        self.microbe_count = self.microbes.shape[1]
        self.metabolite_count = self.metabolites.shape[1]
        
       
    def __len__(self):
        return len(self.microbes)
    
    def __iter__(self, batch_size):
        # What do we want out?
        # x random feature indexes, where x = the total number of counts in sample
        # the expected outputs for this sample
        
        # get the total number of observed features in the sample
#         sample_microbes = self.microbes[idx]
#         sample_microbe_obs = sample_microbes.sum()
        
        # generate indexes to feed to the embedding
        relative_frequency = (train_dataset.microbes.T/train_dataset.microbes.sum(1)).T
        batch_multinomial = torch.multinomial(relative_frequency, batch_size).T
       
        return (batch_multinomial, self.metabolites)

In [46]:
train_dataset = MicrobeMetaboliteData(microbes, metabolites)
train_dataloader = DataLoader(train_dataset, batch_size=1,
                                 shuffle=True)

# test_dataset = MicrobeMetaboliteData(testX, testY)
# test_dataloader = DataLoader(test_dataset, batch_size=8,
#                                  shuffle=True)

## Training...

In [85]:
mmvec_model = MMVec(train_dataset.microbe_count, train_dataset.metabolite_count, 10)
loss_fn = nn.MSELoss
optimizer = torch.optim.Adam

In [48]:
learning_rate = 1e-3
batch_size = 4
epochs = 25

In [129]:
mmvec_model(torch.tensor([[1], [2], [3], [0]]), 
           train_dataset.metabolites[0:2, :].T)

torch.Size([2, 3, 85])
torch.Size([2, 85])


RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

In [130]:
out1 = mmvec_model(torch.tensor([[0, 0, 0],
                                 [1, 1, 1]]), 
           train_dataset.metabolites[0:3, :])

torch.Size([2, 3, 85])
torch.Size([3, 85])


In [135]:
out1

tensor([[-1.3880e+10, -7.8928e+09, -1.4143e+10],
        [-1.4395e+10, -9.9818e+09, -1.6388e+10]], grad_fn=<AddBackward0>)

In [121]:
out1[0, 0] + out1[1, 0]

tensor(-2.8275e+10, grad_fn=<AddBackward0>)

In [117]:
out1[0, 1] + out1[1, 1]

tensor(-1.7875e+10, grad_fn=<AddBackward0>)

In [137]:
out1.mean(0)

tensor([-1.4138e+10, -8.9373e+09, -1.5266e+10], grad_fn=<MeanBackward1>)

In [146]:
train_dataset.microbes.sum(1)

tensor([20428, 21871, 26627, 26559, 24862, 19972, 21567, 20785, 21752, 23502,
        20983, 23540, 28194, 23107, 22898, 21088, 19254, 20009, 17848])

In [159]:
relative_frequency = (train_dataset.microbes.T/train_dataset.microbes.sum(1)).T

In [162]:
torch.multinomial(relative_frequency, 50).T

torch.Size([50, 19])

In [163]:
batched_out = mmvec_model(torch.multinomial(relative_frequency, 50).T, train_dataset.metabolites)

torch.Size([50, 19, 85])
torch.Size([19, 85])


In [166]:
batched_out.mean(0).mean()

tensor(-1.3293e+10, grad_fn=<MeanBackward0>)