
# Day 1: Transient Classification with Images + Metadata 
### Benny Border <Borde206@umn.edu>, Felipe Fontinele Nunes <fonti007@umn.edu>
NB author: Benny Border








With models from Nabeel Rehemtulla (Northwestern) and timm (huggingface), and many "willful" contributions I am currently forgetting  :)



Overview: We start with a review of supervised learning, then take a look at MLPs and how they can learn complex relationships, see how different imagenets work, and finally go over image + metadata tactics for transient classification.



In [9]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from astropy.coordinates import SkyCoord
import astropy.units as u
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from time import time
from copy import deepcopy
from LossFunc import GreatCircleLoss, GreatCircleLoss_no_average
from tqdm import tqdm
import time
from torch.utils.data import Dataset
import torch.nn.functional as F
import random
from dataloader import get_dataloaders
import numpy as np
from sklearn.metrics import precision_recall_curve, auc, roc_auc_score, roc_curve
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn as nn
from pathlib import Path
from sklearn.metrics import auc as sklearn_auc
from plotter import plot_combined_results
from datetime import datetime as t
from train_utils import select_gpu, calculate_pr_auc, get_class_counts, calculate_val_loss

In [None]:


def train(config):
    """Train ``config['model']`` with early stopping, then evaluate the best checkpoint.

    Each epoch runs one pass over the training loader, computes the validation
    PR-AUC and validation loss, steps the LR scheduler, and saves the weights
    whenever the macro-mean validation PR-AUC improves. Training stops after
    ``config['patience']`` consecutive epochs without improvement, after
    ``config['epochs']`` epochs, or on Ctrl-C. In all three cases a random
    PR-AUC baseline is printed and the saved weights are reloaded and passed
    to ``plot_combined_results`` on the test loader.

    Args:
        config (dict): Run configuration. Keys read here: ``'model'``,
            ``'optimizer'``, ``'scheduler'``, ``'epochs'``, ``'patience'``,
            ``'seed'``, ``'max_norm'``, ``'savepath'``, ``'class_names'`` and
            ``'classes'``. The whole dict is also forwarded to
            ``get_dataloaders``, ``calculate_pr_auc`` and
            ``random_baseline_pr_auc``.

    Returns:
        None. Progress is printed and the checkpoint is written to
        ``f"{config['savepath']}.pth"``.
    """
    model = config['model']
    optimizer = config['optimizer']
    scheduler = config['scheduler']

    # Follow the device the model already lives on instead of hard-coding one;
    # batches must be moved to the same device as the weights.
    DEVICE = str(next((p.device for p in model.parameters()), "cpu"))
    print(f"Using device:{DEVICE}")

    EPOCHS = config['epochs']
    PATIENCE = config['patience']

    # Seed every RNG in play so runs are repeatable.
    seed = config['seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Configure PyTorch for deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print("Loading data...")
    train_loader, val_loader, test_loader, _classes = get_dataloaders(config)
    print("Finished loading data...")

    criterion = nn.BCELoss()
    checkpoint_path = f"{config['savepath']}.pth"

    best_pr_auc = 0
    best_val_loss = 10
    epochs_no_improve = 0
    checkpoint_saved = False

    #============================================================
    # Main Training Loop
    #============================================================
    try:
        for epoch in range(EPOCHS):
            model.train()
            train_loss = 0.0

            for batch in tqdm(train_loader, unit='batch', desc='Training', leave=False):
                metadata = batch['metadata'].to(DEVICE)
                image = batch['image'].to(DEVICE)
                target = batch['target'].to(DEVICE)

                optimizer.zero_grad()
                outputs = model(metadata, image=image)
                loss = criterion(outputs, target)
                loss.backward()

                nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['max_norm'])
                optimizer.step()
                train_loss += loss.item()

            val_pr_auc_mean, val_pr_aucs, _, _ = calculate_pr_auc(val_loader, model, DEVICE, config)
            val_loss = calculate_val_loss(val_loader, model, criterion, DEVICE)

            train_loss /= len(train_loader)

            # config['scheduler'] holds a scheduler *object*, so dispatch on its
            # type. (Comparing it to a name string never matched, which meant
            # the learning rate was never adjusted.)
            if isinstance(scheduler, ReduceLROnPlateau):
                # The plateau scheduler minimises its metric, so feed it 1 - AUPRC.
                scheduler.step(1 - val_pr_auc_mean)
            elif hasattr(scheduler, "step"):
                scheduler.step()

            if val_pr_auc_mean > best_pr_auc:
                print(val_pr_auc_mean)
                best_pr_auc = val_pr_auc_mean
                best_val_loss = val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= PATIENCE:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

            pr_auc_str = "|".join(
                f"{name}:{score:.3f}"
                for name, score in zip(config['class_names'], val_pr_aucs)
            )
            print(f"Epoch {epoch+1}/{EPOCHS}|"
                f"Train Loss:{train_loss:.4f}|"
                f"Val loss:{val_loss:.3f}|"
                f"Macro mean AUPRC:{val_pr_auc_mean:.4f}|"
                f"Class AUPRCs:{pr_auc_str}")
    except KeyboardInterrupt:
        # Ctrl-C ends training early but still falls through to evaluation.
        print("Training interrupted; evaluating the best checkpoint so far.")

    print(f'best loss:{best_val_loss}')

    # Evaluation
    random_stats = random_baseline_pr_auc(test_loader, config, n_trials=1000)
    print(f"Random Baseline PR-AUCs (mean ± std):")
    for i, class_name in enumerate(config['classes']):
        print(f"{class_name}: {random_stats['mean'][i]:.3f} ± {random_stats['std'][i]:.3f}")

    # Plot and save results
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        # Nothing was written this run; do not pick up a stale file from disk.
        print("No checkpoint was saved during this run; evaluating the current weights.")
    plot_combined_results(test_loader, model, DEVICE)








def random_baseline_pr_auc( loader, config, n_trials=1000):
    """Estimate the per-class PR-AUC a random classifier would achieve.

    Collects every ``batch['target']`` from ``loader``, then for each trial
    draws one random probability vector per sample from a flat Dirichlet
    distribution and scores each class one-vs-rest with the area under the
    precision-recall curve.

    Args:
        loader: Iterable of batches; each batch is indexed with ``'target'``
            and must yield a tensor. Two-dimensional targets are reduced to
            class indices with ``argmax`` over dimension 1.
        config (dict): Only ``config['classes']`` is read; its length is the
            number of classes.
        n_trials (int): Number of random draws. Trial ``t`` is seeded with
            ``t``, so results are reproducible.

    Returns:
        dict: ``'mean'`` and ``'std'`` (arrays of length ``num_classes``) and
        ``'all_trials'`` (array of shape ``(n_trials, num_classes)``).
    """
    label_batches = []
    for batch in loader:
        batch_targets = batch['target']

        # Convert one-hot to class indices if needed
        if batch_targets.dim() == 2:
            batch_targets = torch.argmax(batch_targets, dim=1)

        label_batches.append(batch_targets.cpu().numpy())
    targets = np.concatenate(label_batches)

    num_classes = len(config['classes'])
    trial_pr_aucs = np.zeros((n_trials, num_classes))

    # The one-vs-rest labels do not depend on the trial, so build them once.
    class_masks = [(targets == class_idx).astype(int) for class_idx in range(num_classes)]
    flat_prior = np.ones(num_classes)

    for trial in range(n_trials):
        # A private RandomState yields the same draws as np.random.seed(trial)
        # without overwriting the global NumPy seed the caller may have set.
        rng = np.random.RandomState(trial)
        # Generate random probabilities that sum to 1
        random_probs = rng.dirichlet(flat_prior, size=len(targets))

        for class_idx, mask in enumerate(class_masks):
            precision, recall, _ = precision_recall_curve(mask, random_probs[:, class_idx])
            trial_pr_aucs[trial, class_idx] = sklearn_auc(recall, precision)

    return {
        'mean': np.mean(trial_pr_aucs, axis=0),
        'std': np.std(trial_pr_aucs, axis=0),
        'all_trials': trial_pr_aucs
    }


# Transfer learning:
When training any neural network, the most important resource is time. Because of this, instead of training your models from scratch every single time, it is often helpful to load pretrained models (usually from [huggingface](https://huggingface.co/)) for the bulkier parts of a model and fine-tune them for your own problem. 

It may not seem like a model trained to distinguish a moped from a space shuttle would be very good at detecting transients, but you'll see how it makes a difference.

# Example:
When looking for extragalactic transients, an important type of object to filter out is [cataclysmic variable stars](https://en.wikipedia.org/wiki/Cataclysmic_variable_star). While these events usually aren't bright enough to be visible outside their respective galaxies, the ones in our own galaxy are more than bright enough to show up in surveys. Because of this and other reasons, many extragalactic transient surveys (including BTS) avoid the galactic plane altogether.  
 <img src="figures/BTSmap.png" width=685>  
But since this doesn't cut out all of them, let's train a neural network to take the coordinates of an object and give us the probability of that object being a cataclysmic variable.    

ZTF CVs (cataclysmic variables) in galactic coordinates:  
  
<img src="figures/gal_aitoff_plot.png" width=750>

#### Let's do this by leveraging our pretrained coordinate transformation MLP from earlier

### 2a:
#### Copy your model from before down here, but add another nn.Sequential block that takes the output from self.end and outputs only one feature 

In [18]:
'''
There are tons of different options for how to structure these, but we'll do a simple one for this example:
'''

class EquatorialToGalacticMLP(nn.Module):
    """Small MLP that maps equatorial coordinates to galactic coordinates.

    The network is two ``Linear -> ReLU`` hidden blocks followed by a
    ``Linear -> Tanh`` output block, so every output lies in [-1, 1]; the
    targets are therefore expected to be normalized to that range.

    Args:
        input_size (int, optional): Number of input features. Defaults to 2
            for (ra, dec).
        output_size (int, optional): Number of output features. Defaults to 2
            for (l, b).
        hidden_size (int, optional): Width of both hidden layers.
            Defaults to 64.

    Example:
        >>> model = EquatorialToGalacticMLP()
        >>> equatorial_coords = torch.tensor([[0.5, -0.2]])  # normalized (ra, dec)
        >>> galactic_coords = model(equatorial_coords)  # predicted (l, b)
    """

    def __init__(self, input_size=2, output_size=2, hidden_size=64):
        super().__init__()

        # The attribute names below become the state_dict key prefixes, so they
        # must stay as they are for saved checkpoints to keep loading.
        self.block1 = nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU())

        # Output block: Tanh bounds the prediction to [-1, 1].
        self.end = nn.Sequential(nn.Linear(hidden_size, output_size), nn.Tanh())

    def forward(self, x):
        """Run ``x`` through both hidden blocks and the Tanh output block."""
        hidden = self.block2(self.block1(x))
        return self.end(hidden)




class Coordinate_Tower(nn.Module):
    """Two-class head on top of the pretrained coordinate-transform MLP.

    Loads a saved ``EquatorialToGalacticMLP``, feeds it columns 7 and 8 of the
    metadata tensor, and passes its two outputs through ``Linear(2, 2)``
    followed by a softmax over the class dimension.

    Args:
        hidden_size (int, optional): Currently unused; kept so existing calls
            keep working.
        coords_hidden_size (int, keyword-only): Hidden width of the pretrained
            MLP. It must match the width the checkpoint was trained with.
            Defaults to 256.
        checkpoint_path (str, keyword-only): File holding the pretrained MLP's
            state_dict. Defaults to ``'best_coords_model.pth'``.
    """

    def __init__(self, hidden_size=128, *, coords_hidden_size=256,
                 checkpoint_path='best_coords_model.pth'):
        super().__init__()

        # use the same stats your saved model has
        self.coords = EquatorialToGalacticMLP(hidden_size=coords_hidden_size)
        # Load onto CPU first; the caller moves the whole tower with .to(device).
        self.coords.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

        self.end = nn.Sequential(
            nn.Linear(2, 2),
            # Normalise across the two class scores of each sample. Leaving dim
            # unset relies on deprecated implicit-dimension behaviour.
            nn.Softmax(dim=1)
        )

    def forward(self, metadata, image=None):
        """Return per-sample class probabilities of shape ``(batch, 2)``.

        Args:
            metadata: 2-D tensor; only columns 7 and 8 are used.
            image: Ignored. Accepted so this tower can be called like the
                image + metadata models, e.g. ``model(metadata, image=image)``.
        """
        feats = self.coords(metadata[:, [7, 8]])
        return self.end(feats)
    




In [19]:
# Notebook cell: build the coordinate-only model, its optimizer/scheduler and
# the run configuration, then launch training.
device = 'cpu'  # switch to 'cuda' to train on a GPU

# Two class groups: everything in the first list versus cataclysmic variables.
CLASSES = [
    ['AGN', 'Tidal Disruption Event', 'SN Ia', 'SN Ic', 'SN Ib', 'SN IIP', 'SN IIn', 'SN II'],
    ['Cataclysmic'],
]
CLASS_NAMES = ["nuclear", "Cataclysmic"]
learning_rate = 1e-3

model = Coordinate_Tower().to(device)

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    eps=5e-10,
)

# 'min' mode: the scheduler lowers the LR when the monitored value stops falling.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    min_lr=5e-10,
    patience=5,
    factor=0.4,
)

config = dict(
    # --- run bookkeeping ---
    savepath='this.pth',  # train() appends ".pth" to this when saving
    model=model,
    classes=CLASSES,
    show_classes=CLASS_NAMES,
    class_names=CLASS_NAMES,
    scheduler=scheduler,
    optimizer=optimizer,
    npy_dir="good_samples",
    timm_model="cnn",
    learning_rate=learning_rate,
    num_workers=24,
    pretrain=1,
    # --- training loop ---
    epochs=30,
    patience=10,
    batch_size=256,
    seed=135,
    loader_seed=125,
    # --- tower / fusion sizes ---
    num_experts=4,
    towers_hidden_dims=8,
    towers_outdims=4,
    fusion_hidden_dims=8,
    fusion_router_dims=16,
    fusion_outdims=16,
    weight_exp=0.85,
    max_norm=1,
    # --- CNN branch ---
    conv1_channels=32,
    conv2_channels=64,
    conv_kernel=5,
    conv_dropout1=0.5,
    conv_dropout2=0.55,
    # --- metadata / combined heads ---
    meta_fc1_neurons=128,
    meta_fc2_neurons=128,
    meta_dropout=0.25,
    comb_fc_neurons=8,
    comb_dropout=0.2,
)


train(config)

Using device:cuda
Loading data...
getting dataset


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'good_samples'

In [8]:
class TowerBlock(nn.Module):
    """Two-layer MLP tower for a group of metadata features.

    Pipeline: ``BatchNorm1d -> Linear -> ReLU -> Dropout -> Linear -> ReLU``.
    The closing ReLU makes every output non-negative.

    Args:
        input_dim (int): Number of input features.
        hidden_dim (int): Width of the hidden layer.
        output_dim (int): Number of output features.
        dropout (float, optional): Dropout probability after the first
            activation. Defaults to 0.25.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.25):
        super().__init__()
        layers = [
            nn.BatchNorm1d(input_dim),          # normalise the raw features
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(),
        ]
        # Keep the attribute name: it is the state_dict key prefix.
        self.metapath = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the tower to ``x`` and return the resulting features."""
        return self.metapath(x)

class XastroMiNN(nn.Module):
    """
    Image and Metadata transient classifier

    Several small pretrained towers each read a subset of the metadata
    columns, a larger ``TowerBlock`` reads the first 24 columns, and a small
    fusion MLP maps the concatenated tower outputs to ``num_classes`` logits.
    A convolutional image branch and a mixture-of-experts router are built in
    ``__init__`` but are not used by ``forward`` yet.

    Args:
        config (dict): Supplies the conv-branch hyper-parameters
            (``'conv1_channels'``, ``'conv2_channels'``, ``'conv_kernel'``,
            ``'conv_dropout1'``, ``'conv_dropout2'``).
        num_classes (int): Number of output logits.
        num_mlp_experts (int): Output width of ``fusion_router``.
        towers_hidden_dims, towers_outdims, fusion_hidden_dims,
        fusion_router_dims, fusion_outdims (int): Stored on the instance;
            not otherwise used by the active code.

    NOTE(review): ``lc_tower``, ``spatial_tower`` and ``nst_tower`` are not
    defined in the visible part of this notebook (exercise 2b asks for them),
    and the ``Coordinate_Tower`` defined above accepts neither the three
    positional arguments passed here nor a call without ``image``. Confirm the
    tower definitions before instantiating this class.
    """


    def __init__(self, config, num_classes=3, num_mlp_experts=4, towers_hidden_dims = 16,
                 towers_outdims = 32,
                 fusion_hidden_dims = 128,
                 fusion_router_dims = 128,
                 fusion_outdims = 32
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.towers_hidden_dims = towers_hidden_dims
        self.towers_outdims = towers_outdims

        self.config = config

        self.fusion_hidden_dims = fusion_hidden_dims  # was 1024
        self.fusion_router_dims = fusion_router_dims # was 256
        self.fusion_outdims = fusion_outdims

        # ===== Metadata Processing Towers =====

        # LC features tower
        self.lc_tower = lc_tower()
        self.lc_tower.load_state_dict(torch.load('models/lc1_tower.pth'))
        # self.lc2_tower = SmallResidualTowerBlock(13, self.towers_hidden_dims*2, towers_outdims*2, do_gating=False, dropout=0.2)

        # Spatial features tower (distpsnr1, distpsnr2, nmtchps)
        self.spatial_tower = spatial_tower(5, 32, 3)
        self.spatial_tower.load_state_dict(torch.load('models/spatial_tower.pth'))

        # Nearest source features tower 1 (sgscore1, distpsnr1)
        self.nst_tower = nst_tower(2, 16, 2)
        self.nst_tower.load_state_dict(torch.load('models/nst1_tower.pth'))

        # Coord features tower
        self.coord_tower = Coordinate_Tower(2, 128, 1)
        # strict=False tolerates keys that differ between checkpoint and module.
        self.coord_tower.load_state_dict(torch.load('models/best_coord_tower.pth'), strict=False)

        # Tower over the first 24 metadata columns.
        self.mega_tower = TowerBlock(24, 128, 128)

        # ===== Image Processing =====

        self.conv_branch = nn.Sequential(
                nn.Conv2d(3, config['conv1_channels'],
                        kernel_size=config['conv_kernel'], padding='same'),
                nn.ReLU(),
                nn.Conv2d(config['conv1_channels'], config['conv1_channels'],
                        kernel_size=config['conv_kernel'], padding='same'),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Dropout(config['conv_dropout1']),

                nn.Conv2d(config['conv1_channels'], config['conv2_channels'],
                        kernel_size=config['conv_kernel'], padding='same'),
                nn.ReLU(),
                nn.Conv2d(config['conv2_channels'], config['conv2_channels'],
                        kernel_size=config['conv_kernel'], padding='same'),
                nn.ReLU(),
                nn.MaxPool2d(4),
                nn.Dropout(config['conv_dropout2']),

                nn.Flatten()
            )

        # presumably the flattened conv_branch output size for the expected
        # image shape and conv2_channels — TODO confirm
        self.conv_output_size = 2304

        # Width of the concatenated metadata-tower outputs fed to fusion_tower.
        # NOTE(review): must equal nst + spatial + coord + lc + mega_tower
        # output widths — verify once the towers exist.
        meta_fusion_dims = 128 + 1 + 3 + 3 + 2 + 1

        # fusion_dims = 6*towers_outdims + 2*fusion_outdims + 1
        fusion_dims = self.conv_output_size + meta_fusion_dims
        # ===== Modality Fusion MoE =====
        # Combines features from all towers (8 metadata + image)
        # self.fusion_experts = nn.ModuleList([
        #     ResidualExpertBlock(fusion_dims, fusion_hidden_dims, num_classes, do_gating=False, dropout=0.5)
        #     for _ in range(num_mlp_experts)
        # ])

        self.num_mlp_experts=num_mlp_experts
        num_experts=num_mlp_experts

        self.fusion_tower = nn.Sequential(
            nn.Linear(meta_fusion_dims, 8),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(8, num_classes),
            # nn.Sigmoid()
        )

        # self.bts_bot = nn.Sequential(
        #     nn.Linear(fusion_dims, 8),
        #     nn.ReLU(),
        #     nn.Dropout(0.2),
        #     nn.Linear(8, num_classes),
        #     # nn.Sigmoid()
        # )

        # Router for the (currently disabled) mixture-of-experts fusion.
        self.fusion_router = nn.Sequential(
            nn.Linear( fusion_dims, fusion_dims//2),
            nn.ReLU(),
            nn.Linear(fusion_dims//2, num_experts),

        )


    def forward(self, metadata, image=None, training=True):
        """Fuse the metadata towers into class logits.

        Args:
            metadata: 2-D tensor with at least 24 feature columns.
            image: Not used by the active code path.
            training: Not used by the active code path.

        Returns:
            dict: The same fused logits tensor under the keys ``'logits'``,
            ``'expert_weights'`` and ``'fusion_weights'``.
        """

        # Process all metadata features through respective towers
        lc_feats = self.lc_tower(metadata[:, [6, 9, 10, 13, 15, 17, 18, 19, 20, 21, 22, 23]])

        spatial_feats = self.spatial_tower(metadata[:, [0,1,2,3,4]])  # Spatial features

        # __init__ registers this tower as ``nst_tower``; calling the
        # non-existent ``nst1_tower`` raised AttributeError.
        nst = self.nst_tower(metadata[:, [0,2]])  # Nearest source A features

        # NOTE(review): the Coordinate_Tower defined in this notebook slices
        # columns 7 and 8 itself and takes an ``image`` argument — confirm
        # which tower signature is intended here.
        coord_feats = self.coord_tower(metadata[:, [7,8]])

        # # Process image if available (zeros otherwise)
        # image_feats = self.image_tower(image)
        # image_feats = nn.Dropout(0.4)(image_feats)

        # Concatenate all features for fusion
        all_other_feats = torch.cat([nst, spatial_feats, coord_feats, lc_feats ], dim=1)
        # all_other_feats = nn.Dropout(0.3)(all_other_feats)

        # The mega tower sees metadata columns 0..23.
        mega_in_feats = metadata[:, :24]
        megatower = self.mega_tower(mega_in_feats)

        fused_feats = self.fusion_tower(torch.cat([all_other_feats, megatower], dim=1))

        # all_feats = torch.cat([all_other_feats, megatower, image_feats], dim=1)
        # bts_feats = self.bts_bot(all_feats)

        # all_feats = torch.cat([megatower, image_feats], dim=1)
        # all_feats = nn.Dropout(0.4)(all_feats)

        # # Fusion MoE - combine features from all modalities
        # fusion_logits = self.fusion_router(all_feats)
        # fusion_weights = nn.Softmax(dim=-1)(fusion_logits)

        # # Get top-k experts for each sample
        # k = min(2, self.num_mlp_experts)  # k=2 if more than 1 expert, else 1
        # topk_weights, topk_indices = torch.topk(fusion_weights, k=k, dim=-1)  # [B, k]

        # # Initialize output
        # moe_output = torch.zeros(metadata.size(0), self.num_classes, device=metadata.device)

        # # Process through each expert
        # for expert_idx, expert in enumerate(self.fusion_experts):
        #     # Find samples that use this expert in any of their top-k positions
        #     expert_mask = (topk_indices == expert_idx).any(dim=1)  # [B]

        #     if not expert_mask.any():
        #         continue

        #     # Get weights for this expert across all top-k positions
        #     weights = torch.zeros_like(expert_mask, dtype=torch.float32)  # [B]
        #     for k_pos in range(topk_indices.size(1)):
        #         k_mask = (topk_indices[:, k_pos] == expert_idx)
        #         weights[k_mask] += topk_weights[k_mask, k_pos]

        #     # Compute expert output only for relevant samples
        #     expert_out = expert(all_feats[expert_mask])  # [M, num_classes]

        #     # Weighted contribution
        #     moe_output[expert_mask] += weights[expert_mask].unsqueeze(-1) * expert_out

        # The MoE path is disabled, so all three keys carry the fused logits.
        return {
            'logits': fused_feats,
            'expert_weights': fused_feats,
            'fusion_weights': fused_feats
        }


    


### 2b:
#### Go back now and make the missing towers here!

Train each missing tower on a class grouping you see fit, then set the big model up to preload the best saved towers, and train for more classes!

### 2c:
#### Go back now and add the CNN tower!

Does it improve things?