# Convolutional Variational Autoencoder Implementation on Burst Windows

In [None]:
from ipyfilechooser import FileChooser
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import plotly.graph_objs as go

### Import data

In [3]:
# Interactive file picker; the chosen path is read from `fc.selected` by the
# loading cell further down.
fc = FileChooser('./')  # starts in the current directory; use an absolute path if needed
fc.title = "<b>Select preprocessing parameter file</b>"
fc.filter_pattern = ['*.npz']  # only list NumPy .npz archives
display(fc)  # `display` is supplied by the IPython/Jupyter runtime

FileChooser(path='C:\Users\omgui\Desktop\BASUS', filename='', title='<b>Select preprocessing parameter file</b…

In [71]:
# Load every array of the selected .npz archive into a plain dict.
if not fc.selected:
    # Nothing has been picked in the chooser yet; fail with a clear message
    # rather than whatever np.load raises for an empty path.
    raise ValueError("No file selected: choose an .npz file in the file chooser first.")

# NOTE(security): allow_pickle=True unpickles object arrays, which can execute
# arbitrary code. Only open archives from a trusted source.
# The context manager closes the underlying file once all entries are copied
# into `data` (dict() reads each array eagerly).
with np.load(fc.selected, allow_pickle=True) as npz_file:
    data = dict(npz_file)

In [72]:
# `.item()` unwraps the object stored under 'handlesA' (indexed like a dict
# below); unwrap it once and reuse it for both fields.
handles_a = data['handlesA'].item()
windows = np.array(handles_a['burst_windows'])
labels = handles_a['species']

# One label per window is required for colouring the latent plot later.
# Raised explicitly (not `assert`) so the check survives `python -O`.
if windows.shape[0] != len(labels):
    raise ValueError(
        f"burst_windows has {windows.shape[0]} rows but species has {len(labels)} entries"
    )
windows.shape

(910, 21)

In [73]:
# Choose the compute device once: GPU when CUDA is available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Convert the windows to float32, insert a channel axis at dim 1 (the Conv1d
# layers of the model take one input channel), apply log(1 + x), and move the
# result to the selected device.
X = torch.log1p(
    torch.tensor(windows, dtype=torch.float32).unsqueeze(1)
).to(DEVICE)

In [74]:
X

tensor([[[0.1466, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[1.6405, 0.0000, 0.0000,  ..., 0.0000, 0.1466, 1.4250]],

        [[0.0000, 0.0000, 0.7694,  ..., 0.0000, 0.0000, 0.1471]],

        ...,

        [[1.4534, 0.8232, 1.4534,  ..., 0.0000, 0.0000, 0.0000]],

        [[1.4535, 0.0000, 1.1872,  ..., 0.0000, 0.8233, 0.0000]],

        [[0.8236, 0.0000, 0.8236,  ..., 0.2459, 0.0000, 1.1875]]])

## Convolutional VAE Architecture

In [75]:
class ConvVAE(nn.Module):
    """1-D convolutional variational autoencoder for fixed-length windows.

    The encoder applies two Conv1d/ReLU/MaxPool1d(2) stages and two linear
    heads that produce the mean and log-variance of a diagonal Gaussian over
    the latent code. The decoder maps a latent sample back to a
    ``(B, 1, input_length)`` tensor through two transposed convolutions and a
    final unpadded convolution that trims the length to ``input_length``.

    Args:
        input_length: Number of samples per input window (must be >= 4 so that
            two pooling stages leave at least one position). Layer sizes are
            derived from this value; the default of 21 yields a 160-wide
            bottleneck and a final kernel size of 6.
        latent_dim: Dimensionality of the latent code.

    Raises:
        ValueError: If ``input_length`` is smaller than 4 or ``latent_dim`` is
            not positive.
    """

    def __init__(self, input_length = 21, latent_dim = 8):
        super().__init__()
        if input_length < 4:
            raise ValueError(f"input_length must be >= 4, got {input_length}")
        if latent_dim < 1:
            raise ValueError(f"latent_dim must be >= 1, got {latent_dim}")

        # Each MaxPool1d(2) floor-halves the length; the padded convs keep it.
        pooled_length = input_length // 2 // 2
        flat_dim = 32 * pooled_length
        # ConvTranspose1d(kernel=4, stride=2) maps n -> 2n + 2, so two of them
        # map pooled_length -> 4 * pooled_length + 6. An unpadded Conv1d with
        # this kernel size brings the length back to exactly input_length.
        trim_kernel = 4 * pooled_length + 7 - input_length

        self.input_length = input_length
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=3, padding=1),   # (B, 16, L)
            nn.ReLU(),
            nn.MaxPool1d(2),                              # (B, 16, L // 2)
            nn.Conv1d(16, 32, kernel_size=3, padding=1),  # (B, 32, L // 2)
            nn.ReLU(),
            nn.MaxPool1d(2),                              # (B, 32, L // 4)
        )
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        self.decoder_input = nn.Linear(latent_dim, flat_dim)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (32, pooled_length)),                 # (B, 32, P)
            nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2),  # (B, 16, 2P + 2)
            nn.ReLU(),
            nn.ConvTranspose1d(16, 1, kernel_size=4, stride=2),   # (B, 1, 4P + 6)
            nn.ReLU(),
            nn.Conv1d(1, 1, kernel_size=trim_kernel),             # (B, 1, L)
        )

    def reparametrize(self, mu, logvar):
        """Draw ``z = mu + std * eps`` with ``eps`` sampled from a standard normal.

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Tensor of shape ``(B, 1, input_length)``.

        Returns:
            Tuple ``(recon, mu, logvar)`` where ``recon`` has shape
            ``(B, 1, input_length)`` and ``mu``/``logvar`` have shape
            ``(B, latent_dim)``.
        """
        enc = self.encoder(x)
        enc_flat = enc.view(x.size(0), -1)
        mu = self.fc_mu(enc_flat)
        logvar = self.fc_logvar(enc_flat)
        z = self.reparametrize(mu, logvar)
        z_dec = self.decoder_input(z)
        recon = self.decoder(z_dec)
        return recon, mu, logvar
    
def vae_loss(recon, x, mu, logvar, beta=1.0):
    """Return the VAE objective: reconstruction error plus weighted KL term.

    Both terms are element-wise means: the reconstruction term is the mean
    squared error between ``recon`` and ``x``, and the KL term is the mean over
    all entries of ``-0.5 * (1 + logvar - mu^2 - exp(logvar))``, i.e. the KL
    divergence of the diagonal Gaussian posterior from a standard normal,
    averaged over batch and latent dimensions.

    Args:
        recon: Reconstructed batch, same shape as ``x``.
        x: Target batch.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.
        beta: Weight applied to the KL term. The default of 1.0 reproduces the
            unweighted sum.

    Returns:
        Scalar tensor ``recon_loss + beta * kl_div``.
    """
    recon_loss = F.mse_loss(recon, x, reduction='mean')
    kl_div = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_div

In [76]:
# Batches are served in dataset order (no shuffling). The latent-extraction
# cell iterates this same loader and relies on that order so the encoded rows
# line up with `labels`.
loader = DataLoader(TensorDataset(X), batch_size=64)

# Take the window length from the data instead of repeating the literal 21.
model = ConvVAE(input_length=X.shape[-1], latent_dim=8).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 25

### Training Loop

In [77]:
# The latent-extraction cell calls model.eval(); switch back to training mode
# so re-running this cell afterwards still trains in the intended mode.
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in loader:
        # TensorDataset yields 1-tuples; the tensor is the only element.
        x_batch = batch[0].to(DEVICE)
        recon, mu, logvar = model(x_batch)
        loss = vae_loss(recon, x_batch, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Reported value is the mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1} - Loss: {total_loss / len(loader):.4f}")

Epoch 1 - Loss: 4.9559
Epoch 2 - Loss: 4.9038
Epoch 3 - Loss: 4.8602
Epoch 4 - Loss: 4.8176
Epoch 5 - Loss: 4.7756
Epoch 6 - Loss: 4.7344
Epoch 7 - Loss: 4.6938
Epoch 8 - Loss: 4.6539
Epoch 9 - Loss: 4.6146
Epoch 10 - Loss: 4.5759
Epoch 11 - Loss: 4.5378
Epoch 12 - Loss: 4.5004
Epoch 13 - Loss: 4.4635
Epoch 14 - Loss: 4.4273
Epoch 15 - Loss: 4.3916
Epoch 16 - Loss: 4.3566
Epoch 17 - Loss: 4.3221
Epoch 18 - Loss: 4.2882
Epoch 19 - Loss: 4.2548
Epoch 20 - Loss: 4.2221
Epoch 21 - Loss: 4.1898
Epoch 22 - Loss: 4.1581
Epoch 23 - Loss: 4.1270
Epoch 24 - Loss: 4.0964
Epoch 25 - Loss: 4.0663


In [78]:
# Encode the whole dataset to posterior means and project them to 3-D with PCA.
model.eval()


def _posterior_mean(inputs):
    """Return the encoder's latent mean for a batch (no sampling, no decoding)."""
    features = model.encoder(inputs)
    return model.fc_mu(features.view(inputs.size(0), -1))


# Gradients are not needed for inference; collect one NumPy block per batch and
# stack them in loader order.
with torch.no_grad():
    latents = np.concatenate(
        [_posterior_mean(batch[0].to(DEVICE)).cpu().numpy() for batch in loader],
        axis=0,
    )

# Reduce the latent codes to their first three principal components.
pca = PCA(n_components=3)
z_pca = pca.fit_transform(latents)


# Interactive 3-D scatter of the PCA-projected latent codes, coloured by label.
marker_style = dict(
    size=4,
    color=labels,
    colorscale='viridis',
    opacity=0.8,
    colorbar=dict(title='Species'),
)
latent_points = go.Scatter3d(
    x=z_pca[:, 0],
    y=z_pca[:, 1],
    z=z_pca[:, 2],
    mode='markers',
    marker=marker_style,
)

fig = go.Figure(data=[latent_points])
fig.update_layout(
    title='Latent Space',
    scene=dict(xaxis_title='PC1', yaxis_title='PC2', zaxis_title='PC3'),
    margin=dict(l=0, r=0, b=0, t=30),
)

fig.show()
