In [2]:
import os
import sys
import json
import argparse
from collections import Counter
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset,random_split,SubsetRandomSampler, WeightedRandomSampler
from tqdm import tqdm
import numpy as np
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [None]:
class SequenceDataset(Dataset):
    """Map-style dataset yielding ``(x_cont, x_other, label)`` triples from a DataFrame.

    Columns used (from the lookups below):
      * ``'X_cont'`` - one sequence per cell; returned as a float32 tensor.
      * ``'label'``  - target; returned as an int64 (``torch.long``) scalar tensor.
      * ``'Unit1'``  - excluded from the features (presumably an identifier - TODO confirm).
      * every other column - static features, returned as a float32 vector in the
        DataFrame's column order.

    Rows are addressed positionally, so the frame's index does not need to be 0..N-1.
    The static-feature matrix is materialised at construction time, so mutating
    ``dataframe`` afterwards is not reflected in ``x_other``.

    Args:
        dataframe: source ``pandas.DataFrame`` containing the columns above.
    """

    # Columns that are not part of the static feature vector.
    _NON_FEATURE_COLUMNS = ['label', 'X_cont', 'Unit1']

    def __init__(self, dataframe):
        self.dataframe = dataframe
        # Build the static-feature matrix once. Calling ``drop`` inside __getitem__
        # copies the whole frame for every sample (O(N) per item, O(N^2) per epoch).
        # A single float32 conversion also avoids object-dtype rows, which torch
        # cannot convert, when columns mix bool/int/float dtypes.
        self._x_other = (
            dataframe.drop(columns=self._NON_FEATURE_COLUMNS)
            .to_numpy(dtype=np.float32)
        )
        self._x_cont = dataframe['X_cont']
        self._labels = dataframe['label']

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the ``idx``-th (positional) row as ``(x_cont, x_other, label)`` tensors."""
        x_cont = torch.tensor(self._x_cont.iloc[idx], dtype=torch.float32)
        # torch.tensor copies, so the returned tensor never aliases the cached matrix.
        x_other = torch.tensor(self._x_other[idx], dtype=torch.float32)
        label = torch.tensor(self._labels.iloc[idx], dtype=torch.long)
        return x_cont, x_other, label

In [3]:
class BayesianCNN(nn.Module):
    """1-D CNN over a sequence input fused with an MLP over static features.

    ``forward`` is a plain deterministic network. ``model`` / ``guide`` are Pyro
    programs for mean-field variational inference over every ``*.weight`` tensor.
    Bias tensors get no prior, so ``random_module`` leaves them as ordinary
    learnable parameters.

    NOTE(review): ``pyro.random_module`` is deprecated since Pyro 1.0 in favour of
    ``pyro.nn.PyroModule``; kept here to preserve the existing interface.

    Args:
        num_channels: number of input channels of ``x_dynamic`` (``conv1`` in_channels).
        output_size: despite the name, the number of *static input* features fed to
            ``fc_static1`` (its ``in_features``), not the output width.
    """

    def __init__(self, num_channels, output_size):
        super().__init__()
        # Convolutional layers for processing time series data
        self.conv1 = nn.Conv1d(num_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(32, 64, kernel_size=3, padding=1)

        # Fully connected layers applied to the pooled convolution features
        self.fc1 = nn.Linear(64, 50)
        self.fc2 = nn.Linear(50, 10)

        # MLP layers for processing static data
        self.fc_static1 = nn.Linear(output_size, 64)
        self.fc_static2 = nn.Linear(64, 64)
        self.fc_static3 = nn.Linear(64, 10)

        # Final layer combining the 10 sequence features and 10 static features
        self.fc_final = nn.Linear(20, 1)

    def forward(self, x_dynamic, x_static):
        """Deterministic forward pass.

        Args:
            x_dynamic: ``(batch, num_channels, length)`` tensor for ``nn.Conv1d``.
            x_static: ``(batch, output_size)`` tensor of static features.

        Returns:
            ``(batch, 1)`` tensor of sigmoid probabilities.
        """
        # Dynamic part processing using convolutional layers
        x = torch.relu(self.conv1(x_dynamic))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.mean(dim=-1)  # Global average pooling over the sequence axis

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))

        # Static part processing using MLP layers
        y = torch.relu(self.fc_static1(x_static))
        y = torch.relu(self.fc_static2(y))
        y = torch.relu(self.fc_static3(y))

        # Combine both paths; sigmoid gives a binary-classification probability
        z = torch.cat((x, y), dim=1)
        return torch.sigmoid(self.fc_final(z))

    def _weight_params(self):
        """Return ``{name: parameter}`` for every weight tensor that gets a prior.

        Deriving the sites from the registered parameters keeps ``model`` and
        ``guide`` in lock-step with the layer definitions above.
        """
        return {
            name: param
            for name, param in self.named_parameters()
            if name.endswith('weight')
        }

    def model(self, x_dynamic, x_static, y=None):
        """Pyro model: N(0, 1) priors on all weights, Bernoulli likelihood per example.

        Args:
            x_dynamic, x_static: as for ``forward``.
            y: optional binary labels of shape ``(batch,)`` or ``(batch, 1)``; any
                numeric dtype (converted to float for the Bernoulli likelihood).

        Returns:
            ``(batch,)`` probabilities from the sampled network.
        """
        # Priors are built on the parameters' own device/dtype.
        priors = {
            name: dist.Normal(torch.zeros_like(param), torch.ones_like(param)).to_event(param.dim())
            for name, param in self._weight_params().items()
        }
        lifted_module = pyro.random_module("module", self, priors)  # Lift weights to random variables
        lifted_reg_model = lifted_module()

        # Condition on the observed data
        with pyro.plate("data", x_dynamic.size(0)):
            # forward returns (batch, 1); drop the trailing dim so each plate element
            # is one scalar Bernoulli matching a (batch,) label vector.
            probs = lifted_reg_model(x_dynamic, x_static).squeeze(-1)
            obs = None if y is None else y.to(probs.dtype).reshape(probs.shape)
            pyro.sample("obs", dist.Bernoulli(probs=probs), obs=obs)
        return probs

    def guide(self, x_dynamic, x_static, y=None):
        """Pyro guide: a learnable mean-field Normal over every weight tensor.

        Locations and (softplus-transformed) scales are ``pyro.param`` sites, so they
        persist in the param store and are optimised by SVI.

        Returns:
            The network with weights sampled from the variational posterior.
        """
        posteriors = {}
        for name, param in self._weight_params().items():
            template = param.detach()
            # Small init: unit-variance weights saturate the final sigmoid at the start.
            loc = pyro.param(
                f"guide.{name}.loc",
                lambda t=template: 0.1 * torch.randn_like(t),
            )
            # softplus(-3) ~= 0.05 initial posterior std; softplus keeps the scale positive.
            raw_scale = pyro.param(
                f"guide.{name}.raw_scale",
                lambda t=template: torch.full_like(t, -3.0),
            )
            posteriors[name] = dist.Normal(
                loc, nn.functional.softplus(raw_scale)
            ).to_event(param.dim())
        lifted_module = pyro.random_module("module", self, posteriors)  # Lift weights to random variables
        return lifted_module()

AttributeError: module 'torchbnn' has no attribute 'BayesConv1d'

In [None]:
def train(model, guide, train_loader, num_epochs=5):
    """Fit a Pyro model with SVI on the ``Trace_ELBO`` objective.

    Args:
        model: object exposing Pyro ``model(x_dynamic, x_static, y)`` and
            ``guide(...)`` methods plus ``train()`` (e.g. ``BayesianCNN``).
        guide: Pyro guide passed to SVI; if ``None``, ``model.guide`` is used.
        train_loader: sized iterable of ``(x_dynamic, x_static, y)`` batches.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        list with, for each epoch, the loss returned by ``svi.step`` averaged over batches.

    Raises:
        ValueError: if ``train_loader`` has no batches.

    Note: the Pyro param store is not cleared here, so repeated calls continue
    from the current variational parameters.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader has no batches")

    # Named to avoid shadowing the module-level ``torch.optim as optim`` import.
    svi_optimizer = Adam({"lr": 0.01})
    svi_guide = guide if guide is not None else model.guide
    svi = SVI(model.model, svi_guide, svi_optimizer, loss=Trace_ELBO())

    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for x_dynamic, x_static, y in train_loader:
            total_loss += svi.step(x_dynamic, x_static, y)
        mean_loss = total_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")
    return epoch_losses