# DeepLense Regression

A FastAI-based tool for performing regression on strong lensing images to predict axion mass density of galaxies.


In [None]:
from fastai.basics import *
from fastai.vision.all import *
from fastai.callback.all import *
import torch
from torchvision import transforms

from models.xresnet_hybrid import xresnet_hybrid101
from utils.utils import inv_standardize,standardize, file_path, dir_path
from utils.custom_activation_functions import Mish_layer
from utils.custom_loss_functions import root_mean_squared_error, mae_loss_wgtd
from data.custom_datasets import RegressionNumpyArrayDataset

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from tqdm import tqdm
import warnings

# NOTE(review): `matplotlib.use('Agg')` selects the non-interactive Agg backend,
# but the `%matplotlib inline` IPython magic on the next line switches the
# backend again for notebook display, so the Agg call appears to have no lasting
# effect in a notebook session -- confirm whether it is still needed.
matplotlib.use('Agg')
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline
# Suppress all Python warnings for the rest of the session to keep cell output clean.
warnings.filterwarnings('ignore')

## Load the Data

In [None]:
# Placeholder paths: replace with the real locations of the raw image and mass
# arrays. NOTE(review): the '.py' extension looks like part of the placeholder;
# np.memmap expects a raw binary dump of the array -- confirm the real filenames.
path_to_images = '/path_to_images.py'
path_to_masses = '/path_to_masses.py'

# Height and width of each image, in pixels.
image_shape = (150, 150)
# Number of images
images_num = 28000
# Load the dataset.
# np.memmap maps the file lazily: data is read from disk only when accessed, so
# the full array does not have to be loaded into RAM up front.
# Each image is stored as a single uint16 plane of size image_shape, giving an
# array of shape (images_num, height, width).
images = np.memmap(path_to_images,
                   dtype='uint16',
                   mode='r',
                   shape=(images_num, *image_shape))

# One float32 mass label per image, shape (images_num, 1).
labels = np.memmap(path_to_masses,
                   dtype='float32',
                   mode='r',
                   shape=(images_num, 1))

In [None]:
# Sanity-check the dimensions of the memory-mapped arrays.
for _label_name, _array in (('images', images), ('masses', labels)):
    print(f'Shape of {_label_name}: {_array.shape}')

## Split the data

In [None]:
# Reproducible random 85% / 10% / 5% split of image indices into
# train / valid / test sets.
np.random.seed(234)
num_of_images = labels.shape[0]
max_indx_of_train_images = int(num_of_images*0.85)
max_indx_of_valid_images = max_indx_of_train_images + int(num_of_images*0.1)
# The test split takes every remaining index. A third int() truncation could
# leave a few images in no split at all when the truncated sizes do not add up
# to num_of_images.
max_indx_num_of_test_images = num_of_images
permutated_indx = np.random.permutation(num_of_images)
train_indx = permutated_indx[:max_indx_of_train_images]
valid_indx = permutated_indx[max_indx_of_train_images:max_indx_of_valid_images]
test_indx = permutated_indx[max_indx_of_valid_images:max_indx_num_of_test_images]

In [None]:
# Report the sizes of the index arrays the split actually produced, rather than
# recomputing them from the ratios (which can disagree with the real slices).
print(f'Number of images in train: {len(train_indx)}')
print(f'Number of images in valid: {len(valid_indx)}')
print(f'Number of images in test: {len(test_indx)}')

## Transforms

In [None]:
# Deterministic preprocessing shared by every split: resize to image_shape.
base_image_transforms = [transforms.Resize(image_shape)]

# Random augmentations for training only: vertical and horizontal flips plus a
# rotation by an angle sampled from the range (0, 360) degrees.
# NOTE: the misspelled name is kept because later cells reference it.
rotation_image_transofrms = [
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=(0, 360)),
]

## Create a Dataset

In [None]:
# Training samples are resized and randomly flipped/rotated; validation and
# test samples are only resized.
_train_transform = transforms.Compose([*base_image_transforms, *rotation_image_transofrms])
train_dataset = RegressionNumpyArrayDataset(images, labels, train_indx, _train_transform)
valid_dataset = RegressionNumpyArrayDataset(
    images, labels, valid_indx, transforms.Compose(base_image_transforms))
test_dataset = RegressionNumpyArrayDataset(
    images, labels, test_indx, transforms.Compose(base_image_transforms))

## Create a DataLoader

In [None]:
# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

batch_size = 64
# Wrap the train/valid datasets in fastai DataLoaders with two loader workers.
dls = DataLoaders.from_dsets(
    train_dataset,
    valid_dataset,
    batch_size=batch_size,
    device=device,
    num_workers=2,
)

## Model Architecture

In [None]:
# Seed torch's global RNG before building the model, presumably so the weight
# initialization is reproducible.
torch.manual_seed(50)
# Hybrid XResNet-101 with one output and single-channel input (c_in=1).
# NOTE(review): sa=True presumably enables self-attention as in fastai's xresnet;
# act_cls selects the project's Mish_layer activation -- confirm in models/.
model = xresnet_hybrid101(
    n_out=1,
    sa=True,
    act_cls=Mish_layer,
    c_in=1,
    device=device,
)

## Create a Learner

In [None]:
# Bundle data, model, optimizer, loss and metric into a fastai Learner.
# Optimizer: fastai's `ranger`; loss: project root_mean_squared_error;
# reported metric: project mae_loss_wgtd.
# model_dir='' saves checkpoints directly under the learner's path instead of
# fastai's default 'models' subdirectory.
learn = Learner(
    dls,
    model,
    opt_func=ranger,
    loss_func=root_mean_squared_error,
    metrics=[mae_loss_wgtd],
    model_dir='',
)

## Find a Learning Rate

In [None]:
# Run fastai's learning-rate range test and plot loss against learning rate;
# use the plot to pick `lr` for training.
learn.lr_find()

In [None]:
num_of_epochs = 1
lr = 1e-2
# Callbacks: plot train/valid losses live, and checkpoint the weights to
# 'best_model' whenever the monitored metric mae_loss_wgtd reaches a new best.
training_callbacks = [
    ShowGraphCallback,
    SaveModelCallback(monitor='mae_loss_wgtd', fname='best_model'),
]
# One-cycle learning-rate schedule peaking at `lr`.
learn.fit_one_cycle(num_of_epochs, lr, cbs=training_callbacks)

## Load the best model

In [None]:
# Restore the checkpoint written by SaveModelCallback during training.
learn.load('best_model', device=device)
# Make sure the restored weights live on the selected device.
learn.model = learn.model.to(torch.device(device))

## Get Predictions for the Test Dataset

In [None]:
# Build an ordered (shuffle=False) loader over the test split. Batches of
# `batch_size` replace batch_size=1, which needed one forward pass per image.
# get_preds runs the model in eval mode, so predictions should not depend on
# the batch size.
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, device=device)
# reorder=False: keep predictions in the loader's (sequential) order.
m_pred, m_true = learn.get_preds(dl=test_dl, reorder=False)

## Plot the results

In [None]:
# Scatter predicted vs. observed mass on the test split, with the y = x line
# marking perfect predictions and the test metric annotated on the plot.
test_mae = mae_loss_wgtd(m_pred, m_true)
# Flatten the (N, 1) prediction/target tensors into 1-D arrays for plotting.
observed = np.asarray(m_true).ravel()
predicted = np.asarray(m_pred).ravel()
# Span the reference line over the actual data range instead of a fixed [0, 6].
line_min = float(min(observed.min(), predicted.min()))
line_max = float(max(observed.max(), predicted.max()))

plt.figure(figsize=(6, 6), dpi=100)
plt.scatter(observed, predicted, color='black')
line = np.linspace(line_min, line_max, 100)
plt.plot(line, line)
plt.xlabel('Observed mass')
plt.ylabel('Predicted mass')
# Axes coordinates (0..1) keep the annotation inside the plot for any data range.
plt.text(0.05, 0.9, 'MAE: {:.4f}'.format(float(test_mae)), transform=plt.gca().transAxes)