In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd 
from torch.utils.data import DataLoader





In [2]:
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, z):
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def loss_fn(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [69]:
# Load your data
X = pd.read_csv("../Data/Clean.csv", sep=",") # This should be the matrix of microbiome data
#X= pd.DataFrame(X)
y = X[["Sample", "Case"]] # This should be the binary variable indicating healthy/diseased
X.drop(X[["Sample", "Case"]], axis=1, inplace=True)
X = X.applymap(lambda x: 1 if x == 'Healthy' else x)

#X.drop([['Sample', 'Case']], axis=1)

split_index = int(X.shape[0] * 0.8)  # 80% for training, 20% for validation

# Split the data into training and validation sets
X_train = X[:split_index]
X_val = X[split_index:]
y_train = y[:split_index]
y_val = y[split_index:]

In [70]:
X_train.dtypes

d__Bacteria;p__Actinobacteriota;c__Actinomycetia;o__Actinomycetales;f__Bifidobacteriaceae;g__Bifidobacterium        float64
d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Bacteroides                        float64
d__Bacteria;p__Firmicutes;c__Bacilli;o__Lactobacillales;f__Lactobacillaceae;g__Lactobacillus                        float64
d__Bacteria;p__Firmicutes;c__Bacilli;o__Lactobacillales;f__Streptococcaceae;g__Lactococcus                          float64
d__Bacteria;p__Firmicutes;c__Bacilli;o__Lactobacillales;f__Streptococcaceae;g__Streptococcus                        float64
                                                                                                                     ...   
d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Pectobacterium     object
d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Pluralibacter      object
d__Bacte

In [71]:
# Convert the data to PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)

# Define the training data
train_dataset = TensorDataset(X_train, y_train)

# Define the batch size
batch_size = 64

# Define the train loader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

ValueError: could not determine the shape of object type 'DataFrame'

In [10]:
input_dim = 2000 # Define the size of your input data
hidden_dim = 200 # Define the size of your hidden layer
latent_dim = 2 # Define the size of the latent space

model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 50


In [11]:
# Train the model
for epoch in range(num_epochs):
    running_loss = 0.0
    
    for i, data in enumerate(train_loader, 0):
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        recon_x, mu, logvar = model(X_train)
        loss = criterion(recon_x, X_train)
        
        # Backward pass and optimization step
        loss.backward()
        optimizer.step()
        
        # Print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # Print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
            
print('Finished training')

NameError: name 'train_loader' is not defined

In [68]:
X_train

Unnamed: 0,d__Bacteria;p__Actinobacteriota;c__Actinomycetia;o__Actinomycetales;f__Bifidobacteriaceae;g__Bifidobacterium,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Bacteroides,d__Bacteria;p__Firmicutes;c__Bacilli;o__Lactobacillales;f__Lactobacillaceae;g__Lactobacillus,d__Bacteria;p__Firmicutes;c__Bacilli;o__Lactobacillales;f__Streptococcaceae;g__Lactococcus,d__Bacteria;p__Firmicutes;c__Bacilli;o__Lactobacillales;f__Streptococcaceae;g__Streptococcus,d__Bacteria;p__Firmicutes_A;c__Clostridia;o__Clostridiales;f__Clostridiaceae;g__Clostridium,d__Bacteria;p__Firmicutes_C;c__Negativicutes;o__Veillonellales;f__Veillonellaceae;g__Veillonella,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Enterobacter,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Escherichia,Unclassified,...,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Burkholderiales;f__Burkholderiaceae;g__Cupriavidus,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Burkholderiales;f__Burkholderiaceae;g__Glaciimonas,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Burkholderiales;f__Burkholderiaceae;g__Noviherbaspirillum,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Alteromonadaceae;g__,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Alteromonadaceae;g__Rheinheimera,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Pectobacterium,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Pluralibacter,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Serratia,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Serratia_B,d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Pseudomonadales;f__Moraxellaceae;g__Moraxella_A
0,0.000413,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.002067,0.007907,0.000000,...,0,0,0,0,0,0,0,0,0,0
1,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.004143,0.000000,...,0,0,0,0,0,0,0,0,0,0
2,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.005116,0.000000,...,0,0,0,0,0,0,0,0,0,0
3,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.007135,0.000000,...,0,0,0,0,0,0,0,0,0,0
4,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.012832,0.066742,...,0,0.00734527318068465,0,0,0.00503287236454319,0,0,0,0,0.00439809566991612
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
776,0.032696,0.023972,0.0,0.000000,0.001979,0.050949,0.006158,0.000000,0.000660,0.000000,...,1,1,1,1,1,1,1,1,1,1
777,0.280462,0.000000,0.0,0.000000,0.054192,0.000000,0.004525,0.000000,0.000000,0.000000,...,1,1,1,1,1,1,1,1,1,1
778,0.493594,0.009108,0.0,0.001002,0.024231,0.002310,0.000000,0.000000,0.021180,0.000000,...,1,1,1,1,1,1,1,1,1,1
779,0.390175,0.000000,0.0,0.005230,0.017747,0.000000,0.000000,0.000000,0.004801,0.000000,...,1,1,1,1,1,1,1,1,1,1


In [None]:
# Validate the model
with torch.no_grad():
    val_recon_x, val_mu, val_logvar = model(X_val)
    val_loss = criterion(val_recon_x, X_val)
    print('Validation loss:', val_loss.item())