In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torch.distributions import Multinomial
import biom
from biom import load_table, Table
from biom.util import biom_open

from mmvec.util import split_tables, format_params

In [37]:
class MMVec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, num_metabolites),
            # [batch, sample, metabolite]
            nn.Softmax(dim=2)
        )
        
    # X = Batch x num_microbesl
    # Y = Batch_size x num_metaboites
    def forward(self, X, Y):
        z = self.encoder(X)
        y_pred = self.decoder(z)
#         print(y_pred.shape)
#         raise
        
        # total_count=0 and validate_args=False allows skipping total count when calling log_prob. Were having floating point issues with support.
        lp = Multinomial(total_count=0, validate_args=False, probs=y_pred).log_prob(Y).sum(0).sum()
        return lp

Getting data prepped:

In [3]:
# manually to start
microbes = load_table("./soil_microbes.biom")
metabolites = load_table("./soil_metabolites.biom")

# X = microbes.to_dataframe().T
# Y = metabolites.to_dataframe().T
# X = X.loc[Y.index]

# trainX = X.iloc[:-2]
# trainY = Y.iloc[:-2]
# testX = X.iloc[-2:]
# testY = Y.iloc[-2:]

# # index dictionaries for the inputs
# microbeIdxs = {microbe: i for i, microbe in enumerate(trainX)}
# metaboliteIdxs = {metabolite: i for i, metabolite in enumerate(trainY)}

# print(microbeIdxs, "\n------\n", metaboliteIdxs, "\n----\n", Y)
# exX = torch.tensor(trainX.values, requires_grad=True)
# exY = torch.tensor(trainY.values, requires_grad=True)


#### put the data in a dataloader(pytorch iterator)

In [40]:
class MicrobeMetaboliteData(Dataset):
    def __init__(self, microbes: biom.table, metabolites: biom.table, num_test_samples):
        # arrange
        self.microbes = microbes.to_dataframe().T   
        self.metabolites = metabolites.to_dataframe().T
        
        # only samples that have results
        self.microbes = self.microbes.loc[self.metabolites.index]
        
        # microbe index
#         self.microbe_idxs = {microbe: i for i, microbe in enumerate(self.microbes)}
        
        # make tensors
        self.microbes = torch.tensor(self.microbes.values, dtype=torch.int)
        # dtype must be integer to avoid floating point errors in multinomial
        self.metabolites = torch.tensor(self.metabolites.values, dtype=torch.int64)
        
        # inputs in dict so indexs available...not sure if we need this
        self.metabolite_idxs = {i: sample for i, sample in enumerate(self.metabolites)}
        
        self.microbe_count = self.microbes.shape[1]
        self.metabolite_count = self.metabolites.shape[1]
        
        self.metabolites_rel_freq = microbe_relative_frequency = (
            self.metabolites.T/self.metabolites.sum(1)).T
        
       
    def __len__(self):
        return len(self.microbes)
    
#     def __iter__(self):
#         # What do we want out?
#         # x random feature indexes, where x = the total number of counts in sample
#         # the expected outputs for this sample
        
#         # get the total number of observed features in the sample
# #         sample_microbes = self.microbes[idx]
# #         sample_microbe_obs = sample_microbes.sum()
        
#         # generate indexes to feed to the embedding
#         relative_frequency = (train_dataset.microbes.T/train_dataset.microbes.sum(1)).T
#         batch_multinomial = torch.multinomial(relative_frequency, 50).T
       
#         return batch_multinomial, self.metabolites

def collater(batch_size):
    torch.multinomial(microbe_relative_frequency, batch_size, replacement=True).T

In [5]:
train_dataset = MicrobeMetaboliteData(microbes, metabolites)
# train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=True)

# test_dataset = MicrobeMetaboliteData(testX, testY)
# test_dataloader = DataLoader(test_dataset, batch_size=8,
#                                  shuffle=True)


## Training...

In [17]:
learning_rate = 1e-3
batch_size = 500
epochs = 25

In [38]:
mmvec_model = MMVec(train_dataset.microbe_count, train_dataset.metabolite_count, 10)

optimizer = torch.optim.Adam(mmvec_model.parameters(), lr=learning_rate, maximize=True)

In [39]:
def train_loop(dataset, model, optimizer, batch_size):
    n_batches = torch.div(dataset.microbes.sum(),
                          batch_size,
                          rounding_mode = 'floor') + 1
    
    microbe_relative_frequency = (dataset.microbes.T/dataset.microbes.sum(1)).T
    
    for batch in range(n_batches * epochs):
    
        lp = model(torch.multinomial(microbe_relative_frequency, batch_size, replacement=True).T,
              dataset.metabolites_rel_freq)
    
        optimizer.zero_grad()
        lp.backward()
        optimizer.step()
        
        if batch % 100 == 0:
            print(f"loss: {lp.item()}\nBatch #: {batch}")
        
        
    
train_loop(train_dataset, mmvec_model, optimizer, 500)

loss: -40077.8125
Batch #: 0
loss: -36404.13671875
Batch #: 100
loss: -31702.857421875
Batch #: 200
loss: -27709.734375
Batch #: 300
loss: -25352.328125
Batch #: 400
loss: -24233.96875
Batch #: 500
loss: -23553.99609375
Batch #: 600
loss: -23262.22265625
Batch #: 700
loss: -23128.380859375
Batch #: 800
loss: -23025.640625
Batch #: 900
loss: -22967.24609375
Batch #: 1000
loss: -22924.705078125
Batch #: 1100
loss: -22890.03515625
Batch #: 1200
loss: -22856.15234375
Batch #: 1300
loss: -22833.298828125
Batch #: 1400
loss: -22811.720703125
Batch #: 1500
loss: -22770.0
Batch #: 1600
loss: -22737.439453125
Batch #: 1700
loss: -22731.591796875
Batch #: 1800
loss: -22710.056640625
Batch #: 1900
loss: -22683.625
Batch #: 2000
loss: -22667.99609375
Batch #: 2100
loss: -22668.8203125
Batch #: 2200
loss: -22634.103515625
Batch #: 2300
loss: -22620.904296875
Batch #: 2400
loss: -22590.16015625
Batch #: 2500
loss: -22598.8515625
Batch #: 2600
loss: -22579.01171875
Batch #: 2700
loss: -22565.62890625

In [9]:
out1 = mmvec_model(torch.tensor([[0, 0, 0],
                                 [1, 1, 1]]), 
           train_dataset.metabolites[0:3, :])

In [10]:
out1

tensor(-2.6305e+10, grad_fn=<MeanBackward0>)

In [11]:
out1[0, 0] + out1[1, 0]

IndexError: too many indices for tensor of dimension 0

In [None]:
out1[0, 1] + out1[1, 1]

In [None]:
out1.mean(0)

In [None]:
train_dataset.microbes.sum(1)

In [None]:
relative_frequency = (train_dataset.microbes.T/train_dataset.microbes.sum(1)).T

In [None]:
torch.multinomial(relative_frequency, 50).T

In [None]:
batched_out = mmvec_model(torch.multinomial(relative_frequency, 50).T, train_dataset.metabolites)

In [None]:
batched_out.mean(0).mean()