In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from datasets import load_dataset, load_dataset_builder

In [3]:
import sys
sys.path.append('../..')

from baselines.galaxy10_decals.dataset_wrapper import DatasetWrapper

In [4]:
# I assume this will eventually be on huggingface and downloaded directly
# !wget https://www.astro.utoronto.ca/~hleung/shared/Galaxy10/Galaxy10_DECals.h5 -O ../../scripts/gz10/Galaxy10_DECals.h5

In [5]:
# from scripts.gz10 import build_parent_sample

# build_parent_sample.save_in_standard_format(
#     input_path='/home/walml/repos/AstroPile_prototype/baselines/galaxy10_decals/data/raw/Galaxy10_DECals.h5',
#     output_dir='/home/walml/repos/AstroPile_prototype/baselines/galaxy10_decals/data'
# )
# # takes around 10 minutes

In [6]:
# relative import as scripts not part of astropile package

# dataset = load_dataset('../../scripts/gz10/gz10.py', name="gz10_images", trust_remote_code=True)  # currently fails as no object_id in original dataset h5

dataset = load_dataset('../../scripts/gz10/gz10.py', name="gz10_with_healpix_with_images", trust_remote_code=True)


Resolving data files:   0%|          | 0/766 [00:00<?, ?it/s]

Repo card metadata block was not found. Setting CardData to empty.


In [7]:
# currently has everything in 'train' split
dataset_dict_with_splits = dataset['train'].train_test_split(test_size=0.3)

In [8]:

decals10_train = dataset_dict_with_splits['train']
decals10_test = dataset_dict_with_splits['test']

In [9]:

decals10_train.set_format('torch')
decals10_test.set_format('torch')

In [10]:
decals10_test

Dataset({
    features: ['gz10_label', 'ra', 'dec', 'object_id', 'images', 'pixel_scale'],
    num_rows: 5321
})

In [11]:
len(decals10_train)

12415

In [12]:
decals10_train[0]['gz10_label']

tensor(7)

In [13]:
decals10_train[0]['images'].shape

torch.Size([256, 256, 3])

In [14]:
# takes about 5 minutes (strangely?)

def channels_first_to_channels_last(example):
    example["images"] =  example['images'].transpose(2, 0)
    return example

# dataset wrapper assumes channels last (pytorch convention) for e.g. taking a feature mean
decals10_train = decals10_train.map(channels_first_to_channels_last)
decals10_test = decals10_test.map(channels_first_to_channels_last)

decals10_train[0]['images'].shape

Map:   0%|          | 0/12415 [00:00<?, ? examples/s]

Map:   0%|          | 0/5321 [00:00<?, ? examples/s]

torch.Size([3, 256, 256])

In [15]:
wrapper = DatasetWrapper(
    train_dataset=decals10_train,
    test_dataset=decals10_test,
    feature_flag='images',
    label_flag='gz10_label',
    loading='iterated'
)

In [17]:
wrapper.prepare_data()
wrapper.feature_mean, wrapper.feature_std

In [None]:
train_loader = wrapper.train_dataloader()
val_loader = wrapper.val_dataloader()
test_loader = wrapper.test_dataloader()

In [None]:
for images, labels in train_loader:
    print(images.shape, labels.shape)

torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([32, 3, 224, 224]) torch.Size([32])


In [21]:
im_np = images[0].numpy().transpose(1, 2, 0)
im_np.min(), im_np.max()

(-0.11923878, 2.390703)

In [27]:
import galaxy10_model

config = galaxy10_model.default_config()


model = galaxy10_model.SmallConvModel(config)

In [38]:
trainer = pl.Trainer(max_epochs=1, accelerator='gpu', logger=False, enable_checkpointing=False)


trainer.fit(model, wrapper)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | encoder        | Sequential         | 19.4 K
1 | head           | Sequential         | 1.4 M 
2 | model          | Sequential         | 1.4 M 
3 | train_accuracy | MulticlassAccuracy | 0     
4 | val_accuracy   | MulticlassAccuracy | 0     
------------------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.618     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/walml/miniforge3/envs/zoobot39_dev/lib/python3.9/site-packages/pytorch_lightning/core/module.py:491: You called `self.log('val_acc', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`


Training: |          | 0/? [00:00<?, ?it/s]

In [29]:
# alternative non-lightning loading option

In [30]:
# target_size = 64
# from torchvision.transforms import Compose, Normalize, Resize, Lambda, RandomHorizontalFlip, RandomVerticalFlip


# normalize = Normalize(mean=1., std=1.)
#                     #   https://pytorch.org/vision/main/generated/torchvision.transforms.Lambda.html)

# _transforms = Compose(
#     [
#         Resize(target_size),
#         RandomHorizontalFlip(),
#         RandomVerticalFlip(),
#         Lambda(func=lambda x: np.arcsinh(x)),
#         normalize
#     ]
# )

# def transforms(examples):
#     examples["pixel_values"] = [_transforms(image.convert("RGB")) for image in examples["image"]]
#     return examples