In [135]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [136]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from datasets import load_dataset, load_dataset_builder

In [137]:
import sys
sys.path.append('../..')

from baselines.galaxy10_decals.dataset_wrapper import DatasetWrapper

In [138]:
# I assume this will eventually be on huggingface and downloaded directly

# !wget https://www.astro.utoronto.ca/~hleung/shared/Galaxy10/Galaxy10_DECals.h5 -O ../../scripts/gz10/Galaxy10_DECals.h5

In [139]:
from scripts.gz10 import build_parent_sample

# takes around 10 minutes

# TODO replace with your paths if you prefer
# build_parent_sample.save_in_standard_format(
#     input_path='../../scripts/gz10/Galaxy10_DECals.h5',
#     output_dir='../../scripts/gz10/datafiles'  # expected location when running with __main__ in build_parent_sample.py
# )


In [140]:
# Load the "gz10_rgb_images" configuration through the local loading script (not from the Hub).
# trust_remote_code=True lets `datasets` execute that script.
dataset = load_dataset('../../scripts/gz10/gz10.py', name="gz10_rgb_images", trust_remote_code=True)


Repo card metadata block was not found. Setting CardData to empty.


In [141]:
dataset

DatasetDict({
    train: Dataset({
        features: ['gz10_label', 'ra', 'dec', 'redshift', 'object_id', 'rgb_image', 'rgb_pixel_scale'],
        num_rows: 17736
    })
})

In [142]:
# np.array(dataset['train'][0]['rgb_image']).shape  # image column is 'rgb_image'; there is no 'images' column

In [143]:

# this runs successfully (though since it has no images I won't use it further here)
# dataset = load_dataset('../../scripts/gz10/gz10.py', name="gz10_with_healpix", trust_remote_code=True)


In [144]:
# currently has everything in 'train' split
# Fixed seed: without it the 70/30 split is different on every run, so the
# normalisation statistics and metrics computed further down are not reproducible.
dataset_dict_with_splits = dataset['train'].train_test_split(test_size=0.3, seed=42)

In [145]:

# Unpack the two halves of the split under the names used for the rest of the notebook.
decals10_train, decals10_test = dataset_dict_with_splits['train'], dataset_dict_with_splits['test']

In [146]:
np.array(decals10_train[0]['rgb_image']).shape  # pure python list

(256, 256, 3)

In [147]:
# set_format changes the existing dataset objects in place, so nothing is reassigned here;
# indexing them afterwards returns torch tensors.
decals10_train.set_format('torch')
decals10_test.set_format('torch')

In [148]:
decals10_train[0]['rgb_image'].shape  # now a tensor

torch.Size([256, 256, 3])

In [149]:
import torchvision.transforms.v2 as T


def pytorchify_images(examples, augment=True):
    """Convert a batch's ``rgb_image`` entries into resized torchvision images.

    Intended for ``Dataset.with_transform``, which calls it with a single
    batch argument and uses the returned dict as the batch.

    Args:
        examples: batch dict holding an ``'rgb_image'`` list. Only that entry
            is replaced; the dict is modified in place and also returned.
        augment: when True (the default, identical to the previous behaviour)
            random horizontal/vertical flips and a rotation of up to 45 degrees
            are applied after resizing. Pass False, e.g. through
            ``functools.partial(pytorchify_images, augment=False)``, to get
            deterministic images for an evaluation split.

    Returns:
        The same ``examples`` dict with ``'rgb_image'`` replaced by the list
        of transformed images.
    """
    steps = [
        T.ToImage(),  # PIL/numpy to CHW tensor
        T.Resize((128, 128)),  # for speed
    ]
    if augment:
        # whatever augmentations you want; could add other transforms here
        steps += [
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomRotation(45),
        ]
    torchvision_transform = T.Compose(steps)
    examples["rgb_image"] = [torchvision_transform(image) for image in examples['rgb_image']]
    return examples

In [150]:
decals10_train = decals10_train.with_transform(pytorchify_images)
# The test split needs the same CHW conversion and 128px resize as the training split,
# otherwise its images keep the raw 256px HWC layout and do not match the model input.
# NOTE(review): pytorchify_images applies random flips/rotation by default, so test images
# are augmented too — consider a deterministic variant for evaluation.
decals10_test = decals10_test.with_transform(pytorchify_images)

In [151]:
np.array(decals10_train[0]['rgb_image']).shape  # after the transform: channel-first and resized to 128px

(3, 128, 128)

In [152]:
decals10_train[0]['rgb_image'].shape

torch.Size([3, 128, 128])

In [153]:
decals10_train[0]['gz10_label']

2

In [154]:
decals10_train[0]['rgb_image'].shape  # tensor in CHW layout after the transform

torch.Size([3, 128, 128])

In [155]:
# Bundle both splits into the object that later supplies prepare_data(), the
# train/val/test dataloaders, and is handed to trainer.fit as the data source.
wrapper = DatasetWrapper(
    train_dataset=decals10_train,
    test_dataset=decals10_test,
    feature_flag='rgb_image',  # presumably the column used as model input — confirm in DatasetWrapper
    label_flag='gz10_label',  # presumably the column used as target — confirm in DatasetWrapper
    loading='iterated'  # NOTE(review): meaning is defined in DatasetWrapper, not visible here
)

In [156]:
# NOTE(review): prepare_data() presumably computes the statistics displayed on the next line — confirm in DatasetWrapper
wrapper.prepare_data()
wrapper.feature_mean, wrapper.feature_std  # one value per image channel (shape [3, 1, 1])

(tensor([[[39.0345]],
 
         [[37.9581]],
 
         [[37.0555]]]),
 tensor([[[33.5472]],
 
         [[31.3643]],
 
         [[29.4987]]]))

In [157]:
wrapper.feature_mean.shape

torch.Size([3, 1, 1])

In [158]:
# NOTE(review): only train and test datasets were passed to DatasetWrapper, so the
# validation loader presumably comes from a split made inside the wrapper — confirm.
train_loader = wrapper.train_dataloader()
val_loader = wrapper.val_dataloader()
test_loader = wrapper.test_dataloader()

In [159]:
# Pull a single batch as a shape check. `images` and `labels` stay bound after the
# break; the following cell inspects images[0].
for images, labels in train_loader:
    print(images.shape, labels.shape)
    break

torch.Size([32, 3, 128, 128]) torch.Size([32])


In [160]:
# CHW -> HWC (channel-last) for inspection as an ordinary image array
im_np = images[0].numpy().transpose(1, 2, 0)
# NOTE(review): values are not in 0-255, presumably because the loader standardises
# with feature_mean / feature_std — confirm in DatasetWrapper
im_np.min(), im_np.max()

(-0.40736386, 1.5729375)

In [161]:
import galaxy10_model

# Start from the module defaults and override only the encoder output size.
config = galaxy10_model.default_config()
# Must stay in step with the 128px Resize in pytorchify_images; changing the image size means changing this value.
config['representation_dim'] = 32*14*14  # 14*14 for 128px, 30*30 for 256px, TODO could set automatically


model = galaxy10_model.SmallConvModel(config)

In [162]:
# Metrics are written as CSV under ./results; no model checkpoints are saved.
logger = pl.loggers.CSVLogger(save_dir='results')
# accelerator='auto' uses the GPU when one is available and otherwise falls back to
# another available device (e.g. CPU), so this cell does not raise on machines without a GPU.
trainer = pl.Trainer(max_epochs=50, accelerator='auto', logger=logger, enable_checkpointing=False)


trainer.fit(model, wrapper)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | encoder        | Sequential         | 19.4 K
1 | head           | Sequential         | 402 K 
2 | model          | Sequential         | 421 K 
3 | train_accuracy | MulticlassAccuracy | 0     
4 | val_accuracy   | MulticlassAccuracy | 0     
------------------------------------------------------
421 K     Trainable params
0         Non-trainable params
421 K     Total params
1.686     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]