# LAISS-VAE, April 17th
### written by Alex Gagliano (gaglian2@mit.edu)

A basic variational autoencoder to encode the light curves of ZTF BTS supernovae. Current todo list is: 
* Correct for extinction
* (don't) Incorporate redshift information
* Validate
* Explore the latent space

Note that this code assumes that the supernova photometry passed in is $g$ and $R$-band photometry from the Zwicky Transient Facility! Accommodating other survey data will likely require some re-training.

In [366]:
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class LSTMVAE(nn.Module):
    """LSTM variational autoencoder for padded, variable-length sequences.

    The encoder LSTM summarises each sequence into the final hidden state of
    its last layer, which is mapped to the mean and log-variance of a diagonal
    Gaussian. A latent sample is repeated at every time step and fed through
    the decoder LSTM and a linear head to reconstruct the input features.

    After each call to ``forward`` (or ``reparameterize``) the attribute
    ``kl`` holds the KL divergence between the approximate posterior and a
    standard normal prior, summed over the batch and the latent dimensions.
    """

    def __init__(self, input_size, hidden_size, latent_size, device):
        """
        :param input_size: Number of features per time step.
        :param hidden_size: Hidden size of the encoder and decoder LSTMs.
        :param latent_size: Dimensionality of the latent vector.
        :param device: Device identifier, stored on the instance as ``device``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.device = device
        # Populated by reparameterize(); None until the first forward pass.
        self.kl = None

        # LSTM encoder and decoder
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=3, batch_first=True)
        self.decoder_lstm = nn.LSTM(latent_size, hidden_size, num_layers=3, batch_first=True)

        # Latent space parameters
        self.fc_mu = nn.Linear(hidden_size, latent_size)
        self.fc_logvar = nn.Linear(hidden_size, latent_size)
        self.fc_output = nn.Linear(hidden_size, input_size)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)`` and record the KL term.

        :param mu: Posterior means.
        :param logvar: Posterior log-variances, same shape as ``mu``.
        :return: A latent sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like already matches the device and dtype of std.
        eps = torch.randn_like(std)
        # Closed-form KL(N(mu, std^2) || N(0, 1)) = 0.5 * (std^2 + mu^2 - 1 - log std^2),
        # which is exactly zero when the posterior equals the prior.
        self.kl = 0.5 * (logvar.exp() + mu.pow(2) - 1.0 - logvar).sum()
        return mu + eps * std

    def forward(self, x, lens):
        """Encode ``x``, sample a latent vector and decode a reconstruction.

        :param x: Padded batch of shape (batch, seq_len, input_size).
        :param lens: True (unpadded) length of each sequence in the batch.
        :return: Reconstruction of shape (batch, seq_len, input_size).
        """
        seq_len = x.size(1)

        # pack_padded_sequence requires the lengths as a CPU int64 tensor.
        lens = torch.as_tensor(lens, dtype=torch.int64).cpu()

        # Encoding: pack so the LSTM does not run over the padded steps
        packed_x = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed_x)

        # Last layer's hidden state for mu and logvar
        hidden = hidden[-1]
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)

        # Reparameterization trick
        z = self.reparameterize(mu, logvar)

        # Decoding - repeat z across the full (padded) sequence length
        z = z.unsqueeze(1).repeat(1, seq_len, 1)
        recon_x, _ = self.decoder_lstm(z)
        recon_x = self.fc_output(recon_x)  # Linear head applied at every time step
        return recon_x

In [367]:
#data here 
import glob
from pathlib import Path

# glob returns matches in arbitrary, filesystem-dependent order; sort them so
# the seeded shuffle and train/test split further down are reproducible.
fns = sorted(glob.glob("/Users/alexgagliano/Documents/Research/multimodal-supernovae/data/ZTFBTS/light-curves/*.csv"))

In [368]:
import random
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import torch
from scipy import interpolate
import numpy as np 
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from sklearn.model_selection import train_test_split

# fns is expected to already hold the light-curve filenames
random.seed(42)  # fixed seed for reproducibility
random.shuffle(fns)  # shuffles the list in place

# Leading 80% of the shuffled list goes to train, the remainder to test
split_ratio = 0.8
split_index = int(split_ratio * len(fns))
train_filenames, test_filenames = fns[:split_index], fns[split_index:]

In [369]:
import extinction 

def correct_extinction(mags, wvs, filters, AV):
    """
    Return extinction-corrected magnitudes without modifying the input array.

    :param mags: Observed magnitudes, one per observation.
    :param wvs: Wavelengths handed to ``extinction.fm07``; entry ``i`` is the
        wavelength used for observations whose filter index is ``i``.
    :param filters: Integer filter index for each entry of ``mags``.
    :param AV: Extinction value handed to ``extinction.fm07``.
    :return: A new float array with the per-filter extinction subtracted.
    """
    # Work on a copy: the caller's array may be a view into a DataFrame column
    # (or read-only), and must not be overwritten as a side effect.
    corrected = np.array(mags, dtype=float)
    filters = np.asarray(filters)

    alams = extinction.fm07(wvs, AV)
    for i, alam in enumerate(alams):
        corrected[filters == i] -= alam
    return corrected

In [370]:
# Per-observation features: phase, dereddened mag, magerr, effective wavelength
# Effective wavelengths, indexed by filter_idx: ZTF-g 4746.48, ZTF-r 6366.38
ZTFBTS = pd.read_csv("/Users/alexgagliano/Documents/Research/multimodal-supernovae/data/ZTFBTS/ZTFBTS_TransientTable.csv")
wvs = np.array([4746.48, 6366.38])
names = []
data = []

for fn in fns:
    df = pd.read_csv(fn)
    if len(df) < 2:
        # interp1d needs at least two points
        continue

    name = Path(fn).stem

    # Look up the A_V value for this object; skip objects that are missing
    # from the transient table instead of crashing with an IndexError.
    av_matches = ZTFBTS.loc[ZTFBTS['ZTFID'] == name, 'A_V'].values
    if len(av_matches) == 0:
        print(f'Skipping {name}: no A_V entry in the transient table')
        continue
    AV = av_matches[0]

    # Estimate the time of peak as the minimum of the linearly interpolated
    # magnitudes on an evenly spaced time grid (all bands together).
    f = interpolate.interp1d(df['time'].values, df['mag'].values)
    time_range = np.linspace(np.nanmin(df['time'].values), np.nanmax(df['time'].values))
    mag_range = f(time_range)
    peakMJD = time_range[np.argmin(mag_range)]

    df['phase'] = df['time'] - peakMJD

    # Band 'R' maps to filter index 1; every other band maps to index 0
    df['filter_idx'] = (df['band'] == 'R').astype(int)
    # Effective wavelength taken from wvs (single source of truth), scaled by 1e3
    df['wveff'] = wvs[df['filter_idx'].values] / 1.e3

    # Correct for reddening on a copy so the raw 'mag' column is left untouched
    raw_mags = df['mag'].to_numpy(dtype=float, copy=True)
    df['mag_dered'] = correct_extinction(raw_mags, wvs, df['filter_idx'].values, AV)

    names.append(name)
    data.append(torch.tensor(df[['phase', 'mag_dered', 'magerr', 'wveff']].values))

# Pad every light curve with zeros to the longest one, remembering true lengths
sequence_lengths = torch.tensor([len(seq) for seq in data])
padded_data = pad_sequence(data, batch_first=True, padding_value=0)
X_train, X_test, y_train, y_test, len_train, len_test = train_test_split(padded_data, names, sequence_lengths, test_size=0.33, random_state=42)

In [379]:
class CustomDataset(Dataset):
    def __init__(self, features, lengths, names):
        """
        Initializes the dataset.
        :param features: A tensor of padded input features, indexed by sample along its first axis.
        :param lengths: A tensor holding the true (unpadded) length of each sample.
        :param names: A sequence of identifiers, one per sample.
        :raises ValueError: If the three inputs do not describe the same number of samples.
        """
        if not (len(features) == len(lengths) == len(names)):
            raise ValueError("features, lengths and names must have the same number of samples")

        self.features = features.type(torch.float32)
        # int64 because pack_padded_sequence only accepts int64 length tensors
        self.lengths = lengths.type(torch.int64)
        self.names = names

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.features)

    def __getitem__(self, index):
        """
        Generates one sample of data.
        :param index: The index of the sample to return.
        :return: A tuple (features, length, name) for that sample.
        """
        features = self.features[index]
        lengths = self.lengths[index]
        names = self.names[index]

        return features, lengths, names

# Wrap the train and test splits in datasets, then batch them
batch_size = 64

train_dataset = CustomDataset(X_train, len_train, y_train)
test_dataset = CustomDataset(X_test, len_test, y_test)

# Only the training loader shuffles its samples
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)

In [380]:
def masked_mse(x, x_hat, lens, device='cpu'):
    """
    Compute the masked MSE loss, ignoring the padded regions of the batch.

    The squared error is averaged over every non-padded element, i.e. over
    (valid time steps) x (features), so the value is a true per-element mean.

    Args:
    x (torch.Tensor): The original padded sequences, shape (batch, seq_len, features).
    x_hat (torch.Tensor): The reconstructed sequences, corresponding to x.
    lens (list or torch.Tensor): The actual lengths of each sequence in the batch.
    device (torch.device): Unused; kept for backward compatibility. The mask is
        always built on the device of x.

    Returns:
    torch.Tensor: The mean squared error loss, averaged over non-padded elements
    (0 if the batch contains no valid elements).
    """
    _, seq_len, n_features = x.size()

    # Accept lists as well as tensors, and keep everything on x's device
    lens = torch.as_tensor(lens, device=x.device)
    steps = torch.arange(seq_len, device=x.device)

    # mask[b, t, 0] is True while t is inside the true length of sequence b
    mask = (steps.unsqueeze(0) < lens.unsqueeze(1)).unsqueeze(2)

    # Zero out the squared error on padded time steps
    squared_error = (x - x_hat) ** 2 * mask

    # Each valid time step contributes n_features elements to the average;
    # clamp avoids a 0/0 NaN when every sequence has length zero.
    n_valid = (mask.sum() * n_features).clamp(min=1)
    return squared_error.sum() / n_valid

In [381]:
# Hyperparameters
input_size = 4  # number of features: phase, mag, magerr, wveff
hidden_size = 128
latent_size = 64
num_epochs = 500
device = 'cpu'  # torch.device("cuda")

# Build the model on the chosen device and show its layers
VAE_model = LSTMVAE(input_size, hidden_size, latent_size, device).to(device)
print(VAE_model)

# Adam optimizer over all model parameters
opt = torch.optim.Adam(VAE_model.parameters(), lr=1e-4)

def train_model(model, train_loader, optimizer, num_epochs=10):
    """
    Train ``model`` with ``optimizer`` and print the average loss of every epoch.

    The loss of a batch is ``masked_mse(x, x_hat, lens) + model.kl``, where
    ``model.kl`` is the KL term the model stores during its forward pass.

    :param model: Model called as ``model(x, lens)`` that sets ``model.kl``.
    :param train_loader: Iterable yielding ``(x, lens, names)`` batches.
    :param optimizer: Optimizer that updates the parameters of ``model``.
    :param num_epochs: Number of passes over ``train_loader``.
    :raises ValueError: If ``train_loader`` yields no batches.
    """
    # Use the arguments rather than notebook globals, so the function trains
    # whichever model/optimizer it is actually given.
    model_device = next(model.parameters()).device

    model.train()  # Set the model to training mode
    for epoch in range(num_epochs):
        total_loss = 0.
        num_batches = 0
        for x, lens, _names in train_loader:
            x = x.to(model_device)
            optimizer.zero_grad()

            x_hat = model(x, lens)
            loss = masked_mse(x, x_hat, lens) + model.kl  # Reconstruction + KL divergence

            loss.backward()  # Compute gradients
            optimizer.step()  # Update parameters

            total_loss += loss.item()  # Accumulate the loss
            num_batches += 1

        if num_batches == 0:
            raise ValueError('train_loader produced no batches')

        average_loss = total_loss / num_batches  # Average loss over all batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}')

def validate_model(model, test_loader):
    """
    Evaluate ``model`` on ``test_loader`` and print the average loss.

    The loss of a batch is ``masked_mse(x, x_hat, lens) + model.kl``, computed
    with gradients disabled and the model in evaluation mode.

    :param model: Model called as ``model(x, lens)`` that sets ``model.kl``.
    :param test_loader: Iterable yielding ``(x, lens, names)`` batches.
    :raises ValueError: If ``test_loader`` yields no batches.
    """
    model_device = next(model.parameters()).device

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        total_loss = 0
        num_batches = 0
        # Iterate the loader that was passed in (not the global training loader)
        for x, lens, _names in test_loader:
            x = x.to(model_device)
            x_hat = model(x, lens)

            loss = masked_mse(x, x_hat, lens) + model.kl  # Reconstruction + KL divergence
            total_loss += loss.item()  # Accumulate the loss
            num_batches += 1

        if num_batches == 0:
            raise ValueError('test_loader produced no batches')

        average_loss = total_loss / num_batches  # Average loss over all batches
        print(f'Validation Average Loss: {average_loss:.4f}')

LSTMVAE(
  (lstm): LSTM(4, 128, num_layers=3, batch_first=True)
  (decoder_lstm): LSTM(64, 128, num_layers=3, batch_first=True)
  (fc_mu): Linear(in_features=128, out_features=64, bias=True)
  (fc_logvar): Linear(in_features=128, out_features=64, bias=True)
  (fc_output): Linear(in_features=128, out_features=4, bias=True)
)


In [None]:
# Run a short training pass (overrides the epoch count set earlier)
num_epochs = 5
train_model(VAE_model, train_loader, opt, num_epochs=num_epochs)

In [None]:
# Report the average loss of the model on the held-out loader
validate_model(VAE_model, test_loader)