In [None]:
#imports

import numpy as np
import matplotlib as plt
import scipy as sc

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F



In [None]:
# init dataset
# TODO: replace this placeholder with the real data-loading code.
# `torch.tensor("data")` raised TypeError (torch.tensor cannot build a tensor
# from a str), which aborted Restart & Run All at this cell. No later cell
# visible in this notebook reads `x`, so an explicit placeholder is safe here.
x = None

In [None]:
#denoiser
class NoiseReduction:
    """Corrupt a 2-D array by forcing randomly chosen entries to 0 and to 1.

    Presumably used to build corrupted inputs for denoising training (the
    flip-rate names suggest binary data) -- confirm against the caller.
    """

    def __init__(self):
        pass

    def transform(self, X, rate_1_0, rate_0_1):
        """Return a corrupted copy of ``X``.

        In every row, ``int(n_cols * rate_1_0)`` randomly chosen positions are
        set to 0. Then, using an independent mask, ``int(n_cols * rate_0_1)``
        positions are set to 1; where both masks hit, the 1 wins. Positions are
        chosen regardless of their current value, so a "flip" may leave an
        entry unchanged.

        Args:
            X: 2-D array (anything with ``.shape``) of shape (n_rows, n_cols).
            rate_1_0: Fraction in [0, 1] of each row forced to 0.
            rate_0_1: Fraction in [0, 1] of each row forced to 1.

        Returns:
            numpy array with the same shape as ``X``.

        Raises:
            ValueError: If either rate is outside [0, 1].
        """
        Y = np.where(self._gen_mask(X.shape, rate_1_0), np.zeros(X.shape), X)
        Y = np.where(self._gen_mask(X.shape, rate_0_1), np.ones(X.shape), Y)
        return Y

    def _gen_mask(self, shape, rate):
        """Return an int mask of shape ``(shape[0], shape[1])`` with random ones.

        Each row holds exactly ``int(shape[1] * rate)`` ones at positions chosen
        by ``np.random.shuffle``; the rest are zeros.

        Raises:
            ValueError: If ``rate`` is outside [0, 1] (or NaN).
        """
        if not 0 <= rate <= 1:
            raise ValueError(f"rate must be in [0, 1], got {rate}")
        n_rows, n_cols = shape[0], shape[1]
        n_ones = int(n_cols * rate)
        # Derive the zero count from n_ones so every row has exactly n_cols
        # entries; computing int(n_cols * (1 - rate)) separately can lose one to
        # float truncation and break broadcasting against X.
        row_template = [1] * n_ones + [0] * (n_cols - n_ones)
        # Preallocate so a zero-row shape still yields a 2-D (0, n_cols) mask.
        mask = np.empty((n_rows, n_cols), dtype=int)
        for i in range(n_rows):
            row = list(row_template)
            np.random.shuffle(row)
            mask[i] = row
        return mask

    


In [None]:
# DVAE: a deep stacked variational autoencoder
class DVAE(nn.Module):
    """Deep stacked autoencoder with a reparameterised (VAE-style) latent.

    Encoder: ``input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]``, each
    ``Linear`` followed by ReLU. The decoder mirrors it,
    ``hidden_dims[-1] -> ... -> hidden_dims[0] -> input_dim``, with ReLU
    between layers and a sigmoid on the output.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dims: Non-empty sequence of encoder widths; the last entry is
            the latent size.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, input_dim, hidden_dims):
        super(DVAE, self).__init__()
        if len(hidden_dims) == 0:
            raise ValueError("hidden_dims must contain at least one layer width")
        self.encoder_hidden_layers = nn.ModuleList()
        self.decoder_hidden_layers = nn.ModuleList()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims

        widths = [input_dim, *hidden_dims]

        # Build encoder layers: each layer must consume the previous layer's
        # output width, otherwise stacks of 2+ layers fail with shape errors.
        for in_features, out_features in zip(widths, widths[1:]):
            self.encoder_hidden_layers.append(nn.Linear(in_features, out_features))

        # Build decoder layers as the exact mirror of the encoder.
        mirrored = widths[::-1]
        for in_features, out_features in zip(mirrored, mirrored[1:]):
            self.decoder_hidden_layers.append(nn.Linear(in_features, out_features))

    def encoder(self, x):
        """Return the latent mean: every encoder layer followed by ReLU.

        Because the last layer is also followed by ReLU, the mean is
        non-negative.
        """
        z_mean = x
        for layer in self.encoder_hidden_layers:
            z_mean = torch.relu(layer(z_mean))
        return z_mean

    def reparameterize(self, z_mean, z_log_var):
        """Sample ``z = z_mean + eps * exp(0.5 * z_log_var)`` with ``eps ~ N(0, 1)``."""
        epsilon = torch.randn_like(z_log_var)
        z = z_mean + epsilon * torch.exp(0.5 * z_log_var)
        return z

    def decoder(self, z):
        """Map a latent code back to input space; the output lies in [0, 1] (sigmoid)."""
        x_prime = z
        last = len(self.decoder_hidden_layers) - 1
        for i, layer in enumerate(self.decoder_hidden_layers):
            x_prime = layer(x_prime)
            x_prime = torch.sigmoid(x_prime) if i == last else torch.relu(x_prime)
        return x_prime

    def forward(self, x):
        """Encode, sample, decode and return ``(loss, z)``.

        Args:
            x: Input batch whose last dimension is ``input_dim``; values must
                lie in [0, 1] because the loss is ``binary_cross_entropy``.

        Returns:
            loss: Scalar tensor = mean reconstruction BCE + summed KL term.
            z: The sampled latent code after ``normalize(z, dim=0)``.
        """
        z_mean = self.encoder(x)
        # NOTE(review): the log-variance is fresh Gaussian noise on every call
        # rather than a network output, so it is never learned and the KL term
        # only constrains z_mean. Confirm this is intended.
        z_log_var = torch.randn_like(z_mean)
        z = self.reparameterize(z_mean, z_log_var)
        x_prime = self.decoder(z)

        # Normalize the sampled latent code.
        # NOTE(review): for a (batch, latent) tensor, dim=0 normalises each
        # latent unit across the batch rather than each sample's code -- confirm
        # whether dim=1 was intended.
        z = nn.functional.normalize(z, dim=0)

        # Compute reconstruction loss
        xent_loss = nn.functional.binary_cross_entropy(x_prime, x, reduction='mean')

        # Compute KL divergence loss
        # NOTE(review): summed while the BCE above is averaged, so the two terms
        # are on different scales.
        kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

        # Compute total loss
        loss = xent_loss + kl_loss
        return loss, z



In [None]:
# Training loop for DVAE.
def train_dvae(dvae, train_loader, optimizer, epochs=10, print_every=10):
    """Optimise ``dvae`` in place on ``train_loader`` for ``epochs`` passes.

    Each batch must unpack as ``(inputs, labels)``; labels are ignored.
    ``dvae(inputs)`` must return ``(loss, features)``; only the loss is used.
    After every complete window of ``print_every`` batches, the mean loss of
    that window is printed. The running total restarts at every epoch, so an
    incomplete trailing window is never reported.

    Returns:
        None.
    """
    dvae.train()
    for epoch in range(epochs):
        window_total = 0
        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            inputs, _ = batch
            batch_loss, _ = dvae(inputs)
            batch_loss.backward()
            optimizer.step()
            window_total += batch_loss.item()

            # Only report on the last batch of each print_every-sized window.
            if step % print_every != print_every - 1:
                continue
            message = '[%d, %5d] loss: %.3f' % (epoch + 1, step + 1, window_total / print_every)
            print(message)
            window_total = 0.0

# Feature extraction through a stack of variational levels.
def extract_features(self, x):
    """Run ``x`` through every latent level and return the final sample.

    ``self`` must provide a ``preprocess`` callable, a sized ``ldim`` (its
    length is the number of levels), a ``norm`` callable, and indexable
    per-level callables ``encoder``, ``z_mean`` and ``z_log_var``. At each
    level: encode -> norm -> (mean, log-variance) -> reparameterised sample,
    which feeds the next level. ``preprocess`` runs outside
    ``torch.no_grad()``; the levels run without gradient tracking.

    NOTE(review): this is a module-level function that takes ``self``.
    Neither ``DVAE`` nor ``SVAE`` in this notebook defines ``preprocess``,
    ``ldim``, ``norm``, an indexable ``encoder``, ``z_mean`` or ``z_log_var``,
    so as written it cannot be called on either model.
    """
    features = self.preprocess(x)
    with torch.no_grad():
        n_levels = len(self.ldim)
        for level in range(n_levels):
            hidden = self.norm(self.encoder[level](features))  # normalize output of encoder
            mu = self.z_mean[level](hidden)
            log_var = self.z_log_var[level](hidden)
            noise = torch.randn_like(mu)
            features = mu + torch.exp(0.5 * log_var) * noise
        return features

In [None]:
class SVAE(nn.Module):
    """Stacked variational autoencoder with a sparsity penalty on the latent code.

    Architecture (ReLU between layers):
      encoder: ``fdim -> ddim[0] -> ... -> ddim[dlayers-1]``
      latent:  two parallel chains (mean and log-variance),
               ``ddim[-1] -> ldim[0] -> ... -> ldim[llayers-1]``
      decoder: ``ldim[-1] -> ddim[-1] -> ddim[dlayers-1] -> ... -> ddim[0] -> fdim``

    ``decode`` ends in a plain ``Linear`` layer, so its output is unbounded and
    the losses below treat it as logits.

    Args:
        dlayers: Number of deterministic encoder layers.
        ddim: Widths of the deterministic layers.
        llayers: Number of latent layers.
        ldim: Widths of the latent layers.
        fdim: Input feature dimension.
        rho: Target mean activation for the sparsity penalty; must be in (0, 1).
        beta: Weight of the sparsity penalty.
    """

    def __init__(self, dlayers, ddim, llayers, ldim, fdim, rho, beta):
        super(SVAE, self).__init__()
        self.encoder_layers = nn.ModuleList()
        self.decoder_layers = nn.ModuleList()
        self.llayers = llayers
        self.fdim = fdim
        self.rho = rho
        self.beta = beta

        # Encoder
        input_dim = fdim
        for i in range(dlayers):
            output_dim = ddim[i]
            self.encoder_layers.append(nn.Linear(input_dim, output_dim))
            input_dim = output_dim

        # Variational hidden layers
        # NOTE(review): the chain starts from ddim[-1] and the decoder from
        # ldim[-1], which assumes len(ddim) == dlayers and len(ldim) == llayers.
        self.z_mean_layers = nn.ModuleList()
        self.z_log_var_layers = nn.ModuleList()
        input_dim = ddim[-1]
        for i in range(llayers):
            output_dim = ldim[i]
            self.z_mean_layers.append(nn.Linear(input_dim, output_dim))
            self.z_log_var_layers.append(nn.Linear(input_dim, output_dim))
            input_dim = output_dim

        # Decoder
        self.decoder_layers.append(nn.Linear(ldim[-1], ddim[-1]))
        input_dim = ddim[-1]
        for i in range(dlayers-1, -1, -1):
            output_dim = ddim[i]
            self.decoder_layers.append(nn.Linear(input_dim, output_dim))
            input_dim = output_dim
        self.decoder_layers.append(nn.Linear(ddim[0], fdim))

    def encode(self, x):
        """Return ``(z_mean, z_log_var)``; the final latent layer has no activation."""
        for layer in self.encoder_layers:
            x = torch.relu(layer(x))
        z_mean = self.z_mean_layers[0](x)
        z_log_var = self.z_log_var_layers[0](x)
        for i in range(1, self.llayers):
            z_mean = self.z_mean_layers[i](torch.relu(z_mean))
            z_log_var = self.z_log_var_layers[i](torch.relu(z_log_var))
        return z_mean, z_log_var

    def decode(self, z):
        """Map a latent code to unbounded reconstruction logits (no output activation)."""
        for layer in self.decoder_layers[:-1]:
            z = torch.relu(layer(z))
        x_prime = self.decoder_layers[-1](z)
        return x_prime

    def reparameterize(self, z_mean, z_log_var):
        """Sample ``z = z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, 1)``."""
        epsilon = torch.randn_like(z_log_var)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

    def forward(self, x):
        """Return ``(x_prime, z_mean, z_log_var)``; ``x_prime`` holds logits."""
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        x_prime = self.decode(z)
        return x_prime, z_mean, z_log_var

    @staticmethod
    def _gaussian_kl(z_mean, z_log_var):
        """Summed closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I))."""
        return -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

    def svae3_loss(self, x, x_prime, z_mean, z_log_var):
        """Return summed reconstruction BCE plus summed KL divergence.

        Args:
            x: Target batch with values in [0, 1].
            x_prime: Raw ``decode`` output for ``x``, treated as logits.
            z_mean: Latent means from ``encode``.
            z_log_var: Latent log-variances from ``encode``.

        Returns:
            Scalar tensor ``xent + kl`` (both summed over all elements).
        """
        # binary_cross_entropy requires inputs in [0, 1], which decode() does
        # not guarantee; the with-logits variant applies the sigmoid internally
        # and is numerically stable.
        xent_loss = F.binary_cross_entropy_with_logits(x_prime, x, reduction='sum')
        kl_divergence = self._gaussian_kl(z_mean, z_log_var)
        return xent_loss + kl_divergence

    def _svae3_batch_terms(self, x):
        """Return ``(xent, sparsity, kl)`` for one batch; xent and kl are per-sample means."""
        n = x.size(0)
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        x_recon = self.decode(z)
        xent_loss = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum') / n

        # KL sparsity penalty needs mean activations strictly inside (0, 1):
        # squash the unbounded sample with a sigmoid and clamp so logs stay finite.
        eps = 1e-6
        rho = self.rho
        rho_hat = torch.sigmoid(z).mean(dim=0).clamp(eps, 1 - eps)
        sparsity_loss = self.beta * torch.sum(
            rho * torch.log(rho / rho_hat) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))
        )

        kl_loss = self._gaussian_kl(z_mean, z_log_var) / n
        return xent_loss, sparsity_loss, kl_loss

    def train_svae3(model, train_loader, learning_rate, rate_1_0, rate_0_1, cluster, epoch, min_delta, batch_size):
        """Train ``model`` with Adam and return it.

        The first parameter is named ``model`` instead of ``self``; call as
        ``svae.train_svae3(loader, ...)`` or ``SVAE.train_svae3(svae, loader, ...)``.

        Per-batch loss::

            xent + rate_1_0 * relu(sparsity - cluster) + rate_0_1 * kl

        NOTE(review): rate_1_0 / rate_0_1 share their names with
        NoiseReduction.transform's flip rates but are used here as loss
        weights; inputs are not corrupted. Confirm intent.

        Args:
            model: The SVAE to train (updated in place).
            train_loader: Iterable of ``(inputs, labels)`` batches; labels ignored.
            learning_rate: Adam learning rate.
            rate_1_0: Weight of the hinged sparsity term.
            rate_0_1: Weight of the KL term.
            cluster: Sparsity level below which no penalty is applied.
            epoch: Maximum number of epochs.
            min_delta: Stop when the epoch-over-epoch loss decrease falls below this.
            batch_size: Unused. Losses are normalised by each batch's actual
                size, so a short final batch is weighted correctly. Kept for
                signature compatibility.

        Returns:
            The trained ``model``.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        model.train()
        train_loss = []

        for e in range(epoch):
            epoch_loss = 0.0
            epoch_xent_loss = 0.0
            epoch_sparsity_loss = 0.0
            epoch_kl_loss = 0.0
            num_batches = 0

            for x, _ in train_loader:
                # reshape (not view) also accepts non-contiguous batches.
                x = x.reshape(-1, model.fdim)

                optimizer.zero_grad()
                xent_loss, sparsity_loss, kl_loss = model._svae3_batch_terms(x)

                # Hinge on sparsity: only the excess above `cluster` is penalised.
                loss = xent_loss + rate_1_0 * F.relu(sparsity_loss - cluster) + rate_0_1 * kl_loss

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_xent_loss += xent_loss.item()
                epoch_sparsity_loss += sparsity_loss.item()
                epoch_kl_loss += kl_loss.item()
                num_batches += 1

            if num_batches == 0:
                raise ValueError("train_loader yielded no batches")

            epoch_loss /= num_batches
            epoch_xent_loss /= num_batches
            epoch_sparsity_loss /= num_batches
            epoch_kl_loss /= num_batches
            train_loss.append(epoch_loss)

            if e > 0 and (train_loss[e-1] - epoch_loss) < min_delta:
                print(f"Early stopping on epoch {e} due to loss decrease of {train_loss[e-1] - epoch_loss:.6f}")
                break

            if e % 10 == 0:
                print(f"Epoch {e}, Loss: {epoch_loss:.6f}, Recon Loss: {epoch_xent_loss:.6f}, Sparsity Loss: {epoch_sparsity_loss:.6f}, KL Loss: {epoch_kl_loss:.6f}")

        return model

In [None]:
# Test SVAE
def test_svae3(svae, test_loader):
    """
    Tests the SVAE and returns the extracted features of the input tensor.

    Args:
        svae (SVAE): The SVAE object to test.
        test_loader (torch.utils.data.DataLoader): The data loader for the test data.

    Returns:
        torch.Tensor: The extracted features of the input tensor.
    """

    # Set the model to evaluation mode
    svae.eval()

    # Define empty list for storing extracted features
    extracted_features = []

    # Disable gradient calculation
    with torch.no_grad():

        # Iterate over batches of test data
        for batch_idx, (data, _) in enumerate(test_loader):

            # Forward pass
            z_mean, z_log_var, z = svae.extract_features(data)
            output = svae.decoder(z)

            # Store extracted features
            extracted_features.append(z.detach().cpu())

    # Concatenate features from all batches and return as tensor
    return torch.cat(extracted_features, dim=0)