In [1]:
import sys
sys.path.append('../..')
sys.path.append('../')
from datasets import load_dataset, load_dataset_builder
import lightning as L

from astropile.utils import cross_match_datasets
from photo_z_data_wrapper import PhotoZWrapper
from photo_z_model import SimpleCNN, TrainingOnlyProgressBar
from utils import split_dataset

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Configuration: data location and global seed.
# NOTE(review): absolute cluster path — override DATA_ROOT when running on another machine.
DATA_ROOT = '/mnt/ceph/users/polymathic/AstroPile_tiny'
SEED = 42

# Seed Python/NumPy/torch (and dataloader worker processes) before anything stochastic runs,
# so any randomness in the split, the weight initialisation and the shuffling below is
# reproducible under Restart & Run All.
L.seed_everything(SEED, workers=True)

# Load the HSC and DESI dataset builders from their AstroPile loading scripts.
# trust_remote_code=True executes those scripts — only point DATA_ROOT at trusted code.
hsc_builder = load_dataset_builder(f'{DATA_ROOT}/hsc/hsc.py', trust_remote_code=True)
desi_builder = load_dataset_builder(f'{DATA_ROOT}/desi/desi.py', trust_remote_code=True)

# Cross-match DESI and HSC sources with AstroPile.
# NOTE(review): units of matching_radius are defined by cross_match_datasets (presumably arcsec) — confirm.
hsc_meets_desi = cross_match_datasets(
    desi_builder,
    hsc_builder,
    matching_radius=1.0,
    keep_in_memory=True,
)
hsc_meets_desi.set_format('torch')



Initial number of matches:  1286
Number of matches lost at healpix region borders:  0
Final size of cross-matched catalog:  1286


In [3]:
# Partition the cross-matched catalogue into train and test subsets.
# split_dataset currently supports only the 'naive' strategy.
split_strategy = 'naive'
train_dataset, test_dataset = split_dataset(hsc_meets_desi, split=split_strategy)

In [20]:
# Wrap the train/test splits in a PhotoZWrapper: 'image.array' supplies the features and
# 'Z' the regression target (presumably the DESI redshift — confirm).
wrapper_options = dict(
    feature_flag='image.array',
    label_flag='Z',
    feature_dynamic_range=True,    # dynamic-range option enabled for features only
    label_dynamic_range=False,
    feature_z_score=True,          # z-score both features and labels
    label_z_score=True,
    loading='iterated',            # 'iterated' or 'full'
    batch_size=128,
    num_workers=16,
    val_size=0.1,                  # presumably the fraction of train_dataset held out for validation
)
photo_z = PhotoZWrapper(train_dataset, test_dataset, **wrapper_options)

# Small CNN regressor with a single output unit.
# NOTE(review): input_channels=5 presumably matches the number of image bands in 'image.array' — confirm.
model_options = dict(
    input_channels=5,
    layer_width=32,
    num_layers=5,
    num_classes=1,
    learning_rate=5e-3,
)
model = SimpleCNN(**model_options)

In [21]:
# Keep only the single checkpoint with the lowest validation loss.
# NOTE(review): assumes SimpleCNN logs a metric named 'val_loss' during validation — confirm.
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

# Configure training (`L` is the `lightning` import from the first cell; no re-import here).
# accelerator='auto' picks the GPU when one is available, as in the recorded run, and falls
# back to CPU instead of raising on machines without CUDA.
trainer = L.Trainer(
    max_epochs=30,
    accelerator='auto',
    logger=True,
    callbacks=[
        TrainingOnlyProgressBar(),
        checkpoint_callback,
    ],
    enable_checkpointing=True,
    fast_dev_run=False,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [22]:
# Train `model`; Lightning draws the dataloaders from the `photo_z` datamodule.
trainer.fit(model, datamodule=photo_z)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type              | Params
------------------------------------------------------
0 | conv_layers     | Sequential        | 38.5 K
1 | global_avg_pool | AdaptiveAvgPool2d | 0     
2 | fc              | Linear            | 33    
------------------------------------------------------
38.5 K    Trainable params
0         Non-trainable params
38.5 K    Total params
0.154     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [31]:
# Reload the weights from the checkpoint the ModelCheckpoint callback ranked best
# (monitor='val_loss', mode='min').
best_ckpt_path = trainer.checkpoint_callback.best_model_path
if not best_ckpt_path:
    # An empty path means no checkpoint was ever written; fail with an actionable message
    # instead of an obscure file-not-found from load_from_checkpoint('').
    raise RuntimeError(
        "No best checkpoint recorded; check that 'val_loss' is logged during validation."
    )
model = SimpleCNN.load_from_checkpoint(best_ckpt_path)

In [35]:
import numpy as np
import torch
from sklearn.metrics import r2_score

# Evaluate the reloaded model on the held-out test split.
# NOTE(review): label_z_score=True is set on the PhotoZWrapper, so targets and predictions are
# presumably in standardized units here. R^2 is unaffected by a shared affine rescaling, but the
# MSE below is then NOT in redshift units — confirm against PhotoZWrapper before quoting it.
model.eval()
device = model.device  # run on whatever device load_from_checkpoint placed the weights

y_batches, y_hat_batches = [], []
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for x, y_true in photo_z.test_dataloader():
        y_pred = model(x.to(device))
        # Flatten both sides: the model emits (batch, 1) (num_classes=1). If targets arrive as
        # (batch,), the two shapes would broadcast to (N, N) in the MSE below.
        y_batches.append(np.asarray(y_true).ravel())
        y_hat_batches.append(y_pred.cpu().numpy().ravel())

y, y_hat = np.concatenate(y_batches), np.concatenate(y_hat_batches)
r2 = r2_score(y, y_hat)
mse = np.mean((y - y_hat) ** 2)

print(f"R^2: {r2}, MSE: {mse}")

R^2: 0.5190151865554204, MSE: 0.3407876193523407
