In [1]:
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score


In [2]:
DATA_PATH = '../data/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Load and preprcess data
print("Loading and pre-processing data...")
train_log_df = pd.read_csv(os.path.join(DATA_PATH, 'train_log.csv'))
metadata_full = train_log_df[['object_id', 'Z', 'EBV', 'target']].copy()
all_lc_df_list = [pd.read_csv(os.path.join(DATA_PATH, s, 'train_full_lightcurves.csv')) for s in train_log_df['split'].unique()]
full_lc_df = pd.concat(all_lc_df_list).dropna()

def preprocess_lightcurves(df):
    processed_dfs = []
    for object_id, group in tqdm(df.groupby('object_id'), desc="Processing Lightcurves"):
        group = group.copy()
        scaler = StandardScaler()
        group[['Flux', 'Flux_err']] = scaler.fit_transform(group[['Flux', 'Flux_err']])
        group['Time (MJD)'] = group['Time (MJD)'] - group['Time (MJD)'].min()
        processed_dfs.append(group)
    return pd.concat(processed_dfs)

processed_lc_df = preprocess_lightcurves(full_lc_df)
scaler_static = StandardScaler()
metadata_full[['Z', 'EBV']] = scaler_static.fit_transform(metadata_full[['Z', 'EBV']])
grouped_lc = processed_lc_df.groupby('object_id')
print("Pre-processing complete.")


Using device: cuda
Loading and pre-processing data...


Processing Lightcurves: 100%|██████████| 3043/3043 [00:03<00:00, 883.64it/s]


Pre-processing complete.


In [3]:
# Multi Channel PyTorch Dataset and DataLoader
FILTERS = ['u', 'g', 'r', 'i', 'z', 'y']

class MALLORNMultiChannelDataset(Dataset):
    def __init__(self, metadata, grouped_lc):
        self.metadata = metadata
        self.grouped_lc = grouped_lc
        self.object_ids = metadata['object_id'].tolist()

    def __len__(self):
        return len(self.object_ids)

    def __getitem__(self, idx):
        object_id = self.object_ids[idx]
        
        # Get data for the object
        lc_data = self.grouped_lc.get_group(object_id)
        meta_row = self.metadata[self.metadata['object_id'] == object_id]

        # Create a dictionary to hold the sequence for each filter
        sequences = {}
        for f in FILTERS:
            filter_data = lc_data[lc_data['Filter'] == f]
            # Features are now just Time, Flux, Flux_err
            if not filter_data.empty:
                sequences[f] = torch.tensor(
                    filter_data[['Time (MJD)', 'Flux', 'Flux_err']].values,
                    dtype=torch.float32
                )
            else:
                # If no data for this filter, create an empty tensor with correct feature dim
                sequences[f] = torch.empty((0, 3), dtype=torch.float32)

        static_features = torch.tensor(meta_row[['Z', 'EBV']].values.flatten(), dtype=torch.float32)
        target = torch.tensor(float(meta_row['target'].values[0]), dtype=torch.float32)

        return {'sequences': sequences, 'static': static_features, 'target': target}

def collate_fn_multi_channel(batch):
    # This function is more complex as it handles 6 parallel sequences
    
    # Batch data for each filter separately
    batch_sequences = {f: [] for f in FILTERS}
    for item in batch:
        for f in FILTERS:
            batch_sequences[f].append(item['sequences'][f])
            
    # Pad each filter's sequence list
    padded_sequences = {}
    for f in FILTERS:
        padded_sequences[f] = torch.nn.utils.rnn.pad_sequence(
            batch_sequences[f], batch_first=True, padding_value=0.0
        )
        
    statics = torch.stack([item['static'] for item in batch])
    targets = torch.stack([item['target'] for item in batch]).unsqueeze(1)
    
    return {'sequences': padded_sequences, 'static': statics, 'target': targets}

In [4]:
# --- Sanity Check ---
print("\nRunning sanity check on the new multi-channel data pipeline...")
train_meta, val_meta = train_test_split(
    metadata_full, test_size=0.2, random_state=42, stratify=metadata_full['target']
)
train_dataset = MALLORNMultiChannelDataset(train_meta, grouped_lc)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn_multi_channel)

one_batch = next(iter(train_loader))
seq_tensors = one_batch['sequences']
static_tensor = one_batch['static']
target_tensor = one_batch['target']

print(f"\nSuccessfully loaded one multi-channel batch!")
print("Shapes of sequence tensors per filter:")
for f, tensor in seq_tensors.items():
    print(f"  Filter '{f}': {tensor.shape}")
print(f"Shape of static features tensor: {static_tensor.shape}")
print(f"Shape of target tensor: {target_tensor.shape}")
print("\nThe data pipeline is now ready for the multi-channel model.")


Running sanity check on the new multi-channel data pipeline...

Successfully loaded one multi-channel batch!
Shapes of sequence tensors per filter:
  Filter 'u': torch.Size([16, 15, 3])
  Filter 'g': torch.Size([16, 19, 3])
  Filter 'r': torch.Size([16, 46, 3])
  Filter 'i': torch.Size([16, 44, 3])
  Filter 'z': torch.Size([16, 49, 3])
  Filter 'y': torch.Size([16, 36, 3])
Shape of static features tensor: torch.Size([16, 2])
Shape of target tensor: torch.Size([16, 1])

The data pipeline is now ready for the multi-channel model.


In [5]:
train_meta, val_meta = train_test_split(metadata_full, test_size=0.2, random_state=42, stratify=metadata_full['target'])
train_dataset = MALLORNMultiChannelDataset(train_meta, grouped_lc)
val_dataset = MALLORNMultiChannelDataset(val_meta, grouped_lc)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn_multi_channel)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn_multi_channel)


In [6]:
# Multi-Channel Model Architecture
class FilterEncoder(nn.Module):
    """An expert GRU for a single filter channel."""
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(FilterEncoder, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True,
                          bidirectional=True, dropout=dropout if num_layers > 1 else 0)
    
    def forward(self, x):
        _, hidden = self.gru(x)
        hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
        return hidden

class MultiChannelClassifier(nn.Module):
    def __init__(self, input_size, static_size, hidden_size, num_layers, dropout):
        super(MultiChannelClassifier, self).__init__()
        
        # Create a dictionary of expert encoders, one for each filter
        self.filter_encoders = nn.ModuleDict({
            f: FilterEncoder(input_size, hidden_size, num_layers, dropout) for f in FILTERS
        })
        
        # The size of the combined feature vector from all GRUs
        combined_gru_size = len(FILTERS) * hidden_size * 2 # (bidirectional)
        
        self.classifier = nn.Sequential(
            nn.Linear(combined_gru_size + static_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1)
        )

    def forward(self, sequences, static):
        filter_outputs = [
            self.filter_encoders[f](sequences[f]) for f in FILTERS
        ]
        
        # Concatenate the outputs from all filter encoders
        combined_gru_output = torch.cat(filter_outputs, dim=1)
        
        # Concatenate with static features
        final_features = torch.cat((combined_gru_output, static), dim=1)
        
        # Make final classification
        output = self.classifier(final_features)
        return output


In [7]:
def train_model(model, train_loader, val_loader, epochs, learning_rate, pos_weight):
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    best_f1 = -1

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Training]"):
            sequences, statics, targets = batch['sequences'], batch['static'].to(DEVICE), batch['target'].to(DEVICE)
            # Move sequence tensors to device
            sequences = {f: seq.to(DEVICE) for f, seq in sequences.items()}
            
            optimizer.zero_grad()
            outputs = model(sequences, statics)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        avg_train_loss = total_loss / len(train_loader)
        
        # Evaluation with threshold optimization
        model.eval()
        all_preds_proba, all_targets = [], []
        with torch.no_grad():
            for batch in val_loader:
                sequences, statics, targets = batch['sequences'], batch['static'].to(DEVICE), batch['target'].to(DEVICE)
                sequences = {f: seq.to(DEVICE) for f, seq in sequences.items()}
                outputs = model(sequences, statics)
                all_preds_proba.append(torch.sigmoid(outputs).cpu().numpy())
                all_targets.append(targets.cpu().numpy())

        all_preds_proba = np.concatenate(all_preds_proba).flatten()
        all_targets = np.concatenate(all_targets).flatten()
        
        thresholds = np.linspace(0.01, 0.99, 100)
        f1_values = [f1_score(all_targets, (all_preds_proba > t).astype(int)) for t in thresholds]
        best_f1_epoch = np.max(f1_values)
        best_threshold_epoch = thresholds[np.argmax(f1_values)]
        
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val F1: {best_f1_epoch:.4f} at Threshold: {best_threshold_epoch:.2f}")
        
        if best_f1_epoch > best_f1:
            best_f1 = best_f1_epoch
            print(f"New best F1 score: {best_f1:.4f}. Saving model...")
            torch.save(model.state_dict(), 'best_multi_channel_model.pth')
            
    return best_f1


In [8]:
# Run Pipeline
# Hyperparameters
INPUT_SIZE = 3      
STATIC_SIZE = 2     
HIDDEN_SIZE = 64    
NUM_LAYERS = 2
DROPOUT = 0.5
EPOCHS = 20         
LEARNING_RATE = 5e-4 

pos_count = train_meta['target'].sum()
neg_count = len(train_meta) - pos_count
pos_weight = torch.tensor([neg_count / pos_count], device=DEVICE)

model = MultiChannelClassifier(INPUT_SIZE, STATIC_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT).to(DEVICE)
final_f1 = train_model(model, train_loader, val_loader, EPOCHS, LEARNING_RATE, pos_weight)

print(f"\n--- Training Finished ---")
print(f"Best validation F1 score achieved with Multi-Channel model: {final_f1:.4f}")


Epoch 1/20 [Training]: 100%|██████████| 77/77 [00:07<00:00, 10.64it/s]


Epoch 1 | Train Loss: 1.4293 | Val F1: 0.1066 at Threshold: 0.78
New best F1 score: 0.1066. Saving model...


Epoch 2/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.68it/s]


Epoch 2 | Train Loss: 1.4411 | Val F1: 0.1075 at Threshold: 0.80
New best F1 score: 0.1075. Saving model...


Epoch 3/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.65it/s]


Epoch 3 | Train Loss: 1.4793 | Val F1: 0.1429 at Threshold: 0.76
New best F1 score: 0.1429. Saving model...


Epoch 4/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.75it/s]


Epoch 4 | Train Loss: 1.4084 | Val F1: 0.1379 at Threshold: 0.57


Epoch 5/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  9.16it/s]


Epoch 5 | Train Loss: 1.3346 | Val F1: 0.1947 at Threshold: 0.65
New best F1 score: 0.1947. Saving model...


Epoch 6/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.87it/s]


Epoch 6 | Train Loss: 1.3403 | Val F1: 0.1290 at Threshold: 0.58


Epoch 7/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.62it/s]


Epoch 7 | Train Loss: 1.2921 | Val F1: 0.1513 at Threshold: 0.61


Epoch 8/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.72it/s]


Epoch 8 | Train Loss: 1.3153 | Val F1: 0.1263 at Threshold: 0.63


Epoch 9/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.79it/s]


Epoch 9 | Train Loss: 1.2550 | Val F1: 0.1359 at Threshold: 0.65


Epoch 10/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.58it/s]


Epoch 10 | Train Loss: 1.3187 | Val F1: 0.1446 at Threshold: 0.63


Epoch 11/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.59it/s]


Epoch 11 | Train Loss: 1.2973 | Val F1: 0.1505 at Threshold: 0.63


Epoch 12/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.72it/s]


Epoch 12 | Train Loss: 1.2588 | Val F1: 0.1333 at Threshold: 0.65


Epoch 13/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.83it/s]


Epoch 13 | Train Loss: 1.2542 | Val F1: 0.1168 at Threshold: 0.60


Epoch 14/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.88it/s]


Epoch 14 | Train Loss: 1.3752 | Val F1: 0.1231 at Threshold: 0.58


Epoch 15/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.58it/s]


Epoch 15 | Train Loss: 1.2886 | Val F1: 0.1333 at Threshold: 0.63


Epoch 16/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.59it/s]


Epoch 16 | Train Loss: 1.3138 | Val F1: 0.1102 at Threshold: 0.59


Epoch 17/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.70it/s]


Epoch 17 | Train Loss: 1.3924 | Val F1: 0.1266 at Threshold: 0.59


Epoch 18/20 [Training]: 100%|██████████| 77/77 [00:08<00:00,  8.62it/s]


Epoch 18 | Train Loss: 1.2898 | Val F1: 0.1057 at Threshold: 0.37


Epoch 19/20 [Training]: 100%|██████████| 77/77 [00:09<00:00,  8.49it/s]


Epoch 19 | Train Loss: 1.2743 | Val F1: 0.1046 at Threshold: 0.39


Epoch 20/20 [Training]: 100%|██████████| 77/77 [00:09<00:00,  8.39it/s]


Epoch 20 | Train Loss: 1.3554 | Val F1: 0.1090 at Threshold: 0.34

--- Training Finished ---
Best validation F1 score achieved with Multi-Channel model: 0.1947
