In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torch.distributions import Multinomial
import biom
from biom import load_table, Table
from biom.util import biom_open

from mmvec.util import split_tables, format_params

In [2]:
# Example data: load the soil microbe and metabolite tables from BIOM files
# located next to the notebook (paths are relative to the working directory).
microbes = load_table("./soil_microbes.biom")
metabolites = load_table("./soil_metabolites.biom")

In [3]:
class MicrobeMetaboliteData(Dataset):
    """Paired microbe/metabolite tensors built from two BIOM tables.

    Both tables are converted to DataFrames and transposed so that each row is
    a sample and each column a feature. Only samples present in both tables
    are kept, in the order they appear in the metabolite table.

    No ``__getitem__`` is defined; the data is consumed through the attributes
    below rather than by indexing.

    Attributes:
        microbes: ``torch.int`` tensor of microbe values, one row per sample.
        metabolites: ``torch.int64`` tensor of metabolite values, one row per
            sample (same row order as ``microbes``).
        microbe_count: number of microbe columns.
        metabolite_count: number of metabolite columns.
        microbe_relative_frequency: ``microbes`` with every row divided by its
            row sum.
        metabolite_relative_frequency: ``metabolites`` with every row divided
            by its row sum.
        total_microbe_observations: 0-dim tensor holding the sum of every
            entry of ``microbes``.
    """

    def __init__(self, microbes: "biom.Table", metabolites: "biom.Table"):
        """Build the tensors from two tables exposing ``to_dataframe()``.

        Args:
            microbes: table of microbe values.
            metabolites: table of metabolite values.
        """
        # arrange: after the transpose, rows are samples and columns features
        microbe_frame = microbes.to_dataframe().T
        metabolite_frame = metabolites.to_dataframe().T

        # only samples that have results in BOTH tables; a metabolite sample
        # without a microbe row is dropped instead of raising a KeyError
        has_microbes = metabolite_frame.index.isin(microbe_frame.index)
        metabolite_frame = metabolite_frame.loc[has_microbes]
        microbe_frame = microbe_frame.loc[metabolite_frame.index]

        # make tensors
        self.microbes = torch.tensor(microbe_frame.values, dtype=torch.int)
        self.metabolites = torch.tensor(metabolite_frame.values, dtype=torch.int64)

        # counts
        self.microbe_count = self.microbes.shape[1]
        self.metabolite_count = self.metabolites.shape[1]

        # relative frequencies: keepdim=True lets the row sums broadcast
        # against the rows directly.
        # NOTE: a sample whose values sum to zero produces NaN in its row.
        self.microbe_relative_frequency = (
            self.microbes / self.microbes.sum(1, keepdim=True)
        )
        self.metabolite_relative_frequency = (
            self.metabolites / self.metabolites.sum(1, keepdim=True)
        )

        # kept as a tensor because callers read it with ``.item()``
        self.total_microbe_observations = self.microbes.sum()

    def __len__(self):
        """Return the total number of microbe observations as a Python int."""
        return int(self.total_microbe_observations)

In [4]:
def batch_collater(batch):
    """Placeholder collate function; the collation logic is not written yet.

    Args:
        batch: the items that should be combined into a single batch.

    Raises:
        NotImplementedError: always, until the per-item logic is implemented.
    """
    # TODO: decide how each item in `batch` is combined, then replace the raise.
    raise NotImplementedError("batch_collater is not implemented yet")

In [None]:
# Wrap the two loaded tables in the dataset object used by the rest of the notebook.
example_data = MicrobeMetaboliteData(microbes=microbes, metabolites=metabolites)

In [None]:
# Show the total number of microbe observations as a plain Python number
# (`.item()` unwraps the 0-dim tensor stored on the dataset).
example_data.total_microbe_observations.item()

In [None]:
class MMVec(nn.Module):
    """Embed microbe indexes and score metabolite data under a multinomial.

    ``encoder`` maps each microbe index to a ``latent_dim`` vector; ``decoder``
    (a linear layer followed by a softmax) turns that vector into a
    probability vector over metabolites.
    """

    def __init__(self, num_microbes, num_metabolites, latent_dim):
        """Create the embedding and the softmax decoder.

        Args:
            num_microbes: number of distinct microbe indexes (embedding rows).
            num_metabolites: length of each predicted probability vector.
            latent_dim: width of the microbe embedding.
        """
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, num_metabolites),
            # Normalise over the metabolite axis. It is always the last axis
            # (e.g. [batch, sample, metabolite]), so dim=-1 also covers index
            # tensors with fewer leading dimensions.
            nn.Softmax(dim=-1),
        )

    def forward(self, X, Y):
        """Return the mean multinomial log-probability of ``Y``.

        Args:
            X: integer tensor of microbe indexes (a batch of draws).
            Y: expected metabolite data; its last axis has ``num_metabolites``
                entries and it must broadcast against the predictions.

        Returns:
            A 0-dim tensor: the log-probabilities averaged over the first
            axis, then over whatever axes remain.
        """
        latent = self.encoder(X)
        metabolite_probs = self.decoder(latent)

        # total_count=0 together with validate_args=False lets log_prob be
        # evaluated without checking Y against a total count; floating point
        # error would otherwise make valid totals look "incorrect".
        likelihood = Multinomial(
            total_count=0,
            validate_args=False,
            probs=metabolite_probs,
        )
        log_probs = likelihood.log_prob(Y)

        # Average over the first axis, then take the overall mean of the rest.
        return log_probs.mean(0).mean()

In [None]:
# Build the model, sized from the dataset's feature counts, with a
# 15-dimensional latent space.
mmvec_model = MMVec(
    num_microbes=example_data.microbe_count,
    num_metabolites=example_data.metabolite_count,
    latent_dim=15,
)

In [None]:
def train_loop(dataset, model, optimizer, batch_size, n_epochs=None):
    """Optimise ``model`` on microbe draws sampled from ``dataset``.

    Every step draws ``batch_size`` microbe indexes per sample (with
    replacement, weighted by ``dataset.microbe_relative_frequency``), evaluates
    ``model`` on those draws against ``dataset.metabolite_relative_frequency``
    and takes one optimizer step. The loss is printed every 100 steps.

    Args:
        dataset: object exposing ``total_microbe_observations`` (read with
            ``.item()``), ``microbe_relative_frequency`` and
            ``metabolite_relative_frequency``.
        model: callable ``model(X, Y)`` returning a scalar tensor.
        optimizer: optimizer providing ``zero_grad()`` and ``step()``.
        batch_size: number of draws per sample in each step; must be >= 1.
        n_epochs: number of passes over the data. When omitted, the
            notebook-level ``epochs`` variable is used, as before.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if n_epochs is None:
        # Backward compatibility: earlier callers relied on the global
        # `epochs` being defined before this function is called.
        n_epochs = epochs

    # Plain integer arithmetic: torch.div does not accept two Python numbers
    # (its input must be a tensor), so floor-divide the unwrapped total.
    total_observations = int(dataset.total_microbe_observations.item())
    n_batches = total_observations // batch_size + 1

    for step in range(n_batches * n_epochs):
        # torch.multinomial returns one row of draws per sample; transpose so
        # the draw (batch) axis comes first.
        microbe_draws = torch.multinomial(
            dataset.microbe_relative_frequency,
            batch_size,
            replacement=True,
        ).T
        lp = model(microbe_draws, dataset.metabolite_relative_frequency)

        optimizer.zero_grad()
        lp.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"loss: {lp.item()}\nBatch #: {step}")
        
        

In [None]:
# Training hyperparameters. `epochs` is read by train_loop as a notebook-level
# variable, so it must be defined before the call below.
learning_rate, batch_size, epochs = 1e-3, 500, 25

# maximize=True: the value returned by the model is a log-probability, which
# training should increase rather than decrease.
optimizer = torch.optim.Adam(
    params=mmvec_model.parameters(),
    lr=learning_rate,
    maximize=True,
)

# run the training loop
train_loop(example_data, mmvec_model, optimizer, batch_size)