<a href="https://colab.research.google.com/github/ParsecInstitute/Astromatic2022/blob/hackproblemsP2/Problems/P2_lens_inference/weighing_galaxies.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
from google.colab import drive
import torch
from torch import nn
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import h5py
import os
from glob import glob
from tqdm import tqdm
# Load the TensorBoard notebook extension
%load_ext tensorboard
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Introduction

papers:
https://arxiv.org/pdf/1208.5229.pdf -> Direct measurement of baryonic fraction

The study of galaxy formation and evolution is a hot topic in current research.

Gravitational lensing is a direct way to measure the gravitational mass within the Einstein radius. A large population of these objects would help constrain the relation between the baryonic mass fraction and the halo mass of these galaxies, which in turn would help constrain the models for galaxy formation (e.g. how efficient galaxies are at forming stars).

In [16]:
parameters_of_interest = ["theta_E"]


class ProblemDataset(Dataset):
    """Map-style dataset serving lens images and labels from HDF5 shards.

    Every ``*.h5`` file found in ``dataset_dir`` is treated as one shard. A
    global sample index is split into (shard number, position in shard),
    which assumes that all shards hold the same number of samples as the
    first one.

    Args:
        dataset_dir: Directory containing the ``*.h5`` shards.
        device: Torch device on which the returned tensors are created.
        parameters_of_interest: Names of the label entries returned with
            each image.

    Raises:
        FileNotFoundError: If ``dataset_dir`` contains no ``*.h5`` file.
    """

    def __init__(
            self,
            dataset_dir,
            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
            parameters_of_interest=parameters_of_interest,
            ):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.device = device
        # Copy so later edits of the module-level default list do not leak in.
        self.parameters_of_interest = list(parameters_of_interest)
        # Sorted so that the index -> sample mapping is reproducible.
        self.files = sorted(glob(os.path.join(dataset_dir, "*.h5")))
        if not self.files:
            raise FileNotFoundError(
                f"No .h5 shard found in dataset directory: {dataset_dir!r}"
            )
        with h5py.File(self.files[0], "r") as hf:
            # Images are indexed along their first axis in __getitem__, so
            # the length of that axis is the number of samples per shard.
            self.shard_size = hf["base"]["lenses"].shape[0]
        self.size = self.shard_size * len(self.files)

    @staticmethod
    def preprocessing(x):
        """
        Define preprocessing of a single image here (shape = [channels, pix, pix])
        """
        return x

    def __len__(self):
        """Return the total number of samples over all shards."""
        return self.size

    def __getitem__(self, index):
        """Return the ``(image, labels)`` tensor pair for a global index."""
        shard, offset = divmod(index, self.shard_size)
        with h5py.File(self.files[shard], "r") as hf:
            group = hf["base"]
            X = group["lenses"][offset]  # read the image from disc
            # NOTE(review): assumes each parameter of interest is stored as a
            # dataset directly under the "base" group -- TODO confirm against
            # the actual layout of the shards.
            Y = np.stack([group[p][offset] for p in self.parameters_of_interest])
        X = torch.tensor(X, device=self.device)  # put this data into the device
        X = self.preprocessing(X)  # apply user defined transformation
        Y = torch.tensor(Y, device=self.device)  # put labels into the device
        return X, Y



# Infer the mass of a lensing galaxy given a lensed image

In [None]:
# Define your model here
class Model(torch.nn.Module):
    """Template network: declare the layers in ``__init__`` and the
    computation in ``forward``.

    Args:
        hyperparameters: User-defined configuration of the network; it is
            stored unchanged on the instance as ``self.hyperparameters``.
    """

    def __init__(self, hyperparameters):
        # nn.Module must be initialised before parameters/submodules can be
        # registered and before .parameters() or .to() can be called.
        super().__init__()
        self.hyperparameters = hyperparameters

    def forward(self, x):
        """Compute the prediction for the input batch ``x``.

        Raises:
            NotImplementedError: Until the forward pass is defined.
        """
        raise NotImplementedError("Define the forward pass of Model here.")

In [None]:
# ---- Training configuration --------------------------------------------------
EPOCHS = 10
BATCH_SIZE = 32
LOGDIR = "logs/"
LEARNING_RATE = 1e-3
DATADIR = "/content/drive/..."
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Configuration handed to Model; fill in the entries your model expects.
hyperparameters = {}

dataset = ProblemDataset(DATADIR, device=DEVICE)
# NOTE: `dataset` is rebound to the DataLoader, so len(dataset) below is the
# number of batches per epoch, not the number of samples.
dataset = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Model.__init__ takes the configuration as one positional argument.
model = Model(hyperparameters).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = None  # define a schedule here for fine tuning
data_augmentation = T.Compose([])  # add image augmentations here
def loss_fn(y_pred, y_true): # define your loss function here
    """Return the mean squared error between predictions and targets.

    This is a default that matches the "MSE" tag used for logging; replace
    it with your own loss if needed. The result is a scalar tensor so that
    ``.backward()`` can be called on it.
    """
    return nn.functional.mse_loss(y_pred, y_true)

writer = SummaryWriter(LOGDIR)
# ====== Training loop =========================================================
step = 0  # number of optimizer updates performed so far
for epoch in (pbar := tqdm(range(EPOCHS))):
    cost = 0.0
    for x, y in dataset:
        optimizer.zero_grad()
        # make a prediction with the model
        y_pred = model(x)
        # compute the loss function
        loss = loss_fn(y_pred, y)
        # backpropagation
        loss.backward()
        optimizer.step()
        # update the learning rate
        # scheduler.step()
        # ========== Summary and logs ==========================================
        # .item() extracts a Python float; summing the tensors themselves
        # would keep every batch's autograd graph alive for the whole epoch.
        cost += loss.item()
        step += 1

    # Average over the batches of this epoch (guard against an empty loader).
    cost /= max(len(dataset), 1)
    learning_rate = optimizer.param_groups[0]['lr']
    writer.add_scalar("MSE", cost, step)
    writer.add_scalar("Learning Rate", learning_rate, step)
    # NOTE(review): this check runs once per epoch, so it only prints when the
    # cumulative step count happens to be a multiple of 500 -- confirm intent.
    if step % 500 == 0:
        # writer
        print(f"epoch {epoch} | cost {cost:.3e} "
        f"| learning rate {optimizer.param_groups[0]['lr']:.2e}")
    writer.flush()
    # cost is a plain float here, so use the NumPy check.
    if np.isnan(cost):
        print("Training broke the Universe")
        break


# Deblend the lens light to extract the background image

In [None]:
parameters_of_interest = ["lens_light_params"]


class ProblemDataset(Dataset):
    """Map-style dataset serving lens images and labels from HDF5 shards.

    Every ``*.h5`` file found in ``dataset_dir`` is treated as one shard. A
    global sample index is split into (shard number, position in shard),
    which assumes that all shards hold the same number of samples as the
    first one.

    Args:
        dataset_dir: Directory containing the ``*.h5`` shards.
        device: Torch device on which the returned tensors are created.
        parameters_of_interest: Names of the label entries returned with
            each image.

    Raises:
        FileNotFoundError: If ``dataset_dir`` contains no ``*.h5`` file.
    """

    def __init__(
            self,
            dataset_dir,
            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
            parameters_of_interest=parameters_of_interest,
            ):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.device = device
        # Copy so later edits of the module-level default list do not leak in.
        self.parameters_of_interest = list(parameters_of_interest)
        # Sorted so that the index -> sample mapping is reproducible.
        self.files = sorted(glob(os.path.join(dataset_dir, "*.h5")))
        if not self.files:
            raise FileNotFoundError(
                f"No .h5 shard found in dataset directory: {dataset_dir!r}"
            )
        with h5py.File(self.files[0], "r") as hf:
            # Images are indexed along their first axis in __getitem__, so
            # the length of that axis is the number of samples per shard.
            self.shard_size = hf["base"]["lenses"].shape[0]
        self.size = self.shard_size * len(self.files)

    @staticmethod
    def preprocessing(x):
        """
        Define preprocessing of a single image here (shape = [channels, pix, pix])
        """
        return x

    def __len__(self):
        """Return the total number of samples over all shards."""
        return self.size

    def __getitem__(self, index):
        """Return the ``(image, labels)`` tensor pair for a global index."""
        shard, offset = divmod(index, self.shard_size)
        with h5py.File(self.files[shard], "r") as hf:
            group = hf["base"]
            X = group["lenses"][offset]  # read the image from disc
            # NOTE(review): assumes each parameter of interest is stored as a
            # dataset directly under the "base" group -- TODO confirm against
            # the actual layout of the shards.
            Y = np.stack([group[p][offset] for p in self.parameters_of_interest])
        X = torch.tensor(X, device=self.device)  # put this data into the device
        X = self.preprocessing(X)  # apply user defined transformation
        Y = torch.tensor(Y, device=self.device)  # put labels into the device
        return X, Y

