In [None]:
import functools
import json
from pathlib import Path
import pickle

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from multihist import Hist1d, Histdd
import pandas as pd
import PIL
from scipy import stats
from tqdm import tqdm

import deepdarksub as dds

import lenstronomy as ls
import manada
import fastai.vision.all as fv
import torch

# Dataset directory, located relative to the manada package root.
data_dir = dds.MANADA_ROOT.parent / 'datasets' / 'd_los_sigma_sub'

ls.laconic()  # NOTE(review): presumably reduces lenstronomy log output — confirm
meta, galaxy_indices = dds.load_metadata(data_dir, remove_bad=False)
manada_config = dds.load_manada_config()

# Training histories (.json) and evaluation plots (.png) are written here.
result_dir = Path('.') / 'train_results'
result_dir.mkdir(exist_ok=True)

# Report the GPU we will train on. Check availability first:
# torch.cuda.current_device() raises when no CUDA device is present, so
# calling it unconditionally crashed instead of reporting `False`.
cuda_available = torch.cuda.is_available()
if cuda_available:
    print(torch.cuda.get_device_name(torch.cuda.current_device()), cuda_available)
else:
    print('No CUDA device available; training will not use a GPU.', cuda_available)

# Setup

In [None]:
# Uncertainty model for the network outputs; passed to the data block, the
# loss (dds.loss_for) and the output decoder (normalizer.decode).
uncertainty = 'diagonal'

mp = 'main_deflector_parameters_'
# Parameters to fit: {metadata column name: axis label used in plots}.
fit_parameters = {
    mp + 'theta_E': "Einstein radius [arcsec]",
    'subhalo_parameters_sigma_sub': r'$\Sigma_\mathrm{sub}$',
    #'los_parameters_delta_los': r'$\delta_\mathrm{LOS}$',

# Make sure to turn off rotation augmentation when fitting these
# (or reformulate them in rotation-invariant variables, or find out how to
# transform y values in fastai)
#     mp + 'center_x': "Main deflector $x$",
#     mp + 'center_y': "Main deflector $y$",
#     mp + 'gamma': r"Main deflector $\gamma$",
#     mp + 'gamma1': r"Main deflector $\gamma_1$",
#     mp + 'gamma2': r"Main deflector $\gamma_2$",
#     mp + 'e1': "Main deflector $e_1$",
#     mp + 'e2': "Main deflector $e_2$",
#     'source_scaled_flux_radius': "Source flux radius [arcsec]",
#     'source_parameters_sersicfit_n': "Source best-fit Sersic index",
#     'source_parameters_sersicfit_q': "Source best-fit Sersic axis ratio",
}

n_params = len(fit_parameters)

# Optional per-image training weights, stored in meta['training_weight'].
# When the file is absent we continue without them (debiased stays False)
# instead of crashing with FileNotFoundError, which made the False branch
# unreachable before.
# NOTE: pickle.load can execute arbitrary code — only load weight files you
# produced yourself.
debiasing_weights_file = Path('debiasing_weights.pkl')
debiased = debiasing_weights_file.exists()
if debiased:
    with open(debiasing_weights_file, mode='rb') as f:
        meta['training_weight'] = pickle.load(f)
else:
    print(f"{debiasing_weights_file} not found; training without debiasing weights.")

In [None]:
# Build the fastai DataBlock for the images in data_dir, with the
# fit_parameters columns of meta as targets (see dds.data_block) and free
# rotation augmentation.
dblock = dds.data_block(
    meta,
    fit_parameters,
    data_dir,
    augment_rotation='free',
    uncertainty=uncertainty,
)
# To debug the pipeline, run: dblock.summary(None)

In [None]:
# Images per batch (alternative tried: 512).
batch_size = 256
dls = dblock.dataloaders(
    None,
    bs=batch_size,
)
# Visual sanity check: one batch from the validation set.
dls.valid.show_batch()

In [None]:
# # Note the pixel value distribution is very skewed:
# bla = dls.train.one_batch()[0].cpu().numpy()
# Hist1d(bla[:,0,:,:].ravel(), bins=400).plot()
# #plt.yscale('log')

**TODO:** Some source images still have a large mean pixel value. Should we subtract it so that zero-padding makes sense, or is something else going on?

In [None]:
# Show image
# img = dls.valid.one_batch()[0][0].cpu().numpy()
# import deepdarksub as dds
# dds.plot_image(img[0], vmin=1e-6)

In [None]:
# Check train & val shapes match
# dls.train.one_batch()[0][0].shape, dls.valid.one_batch()[0][0].shape, 

In [None]:
n_epochs = 25
architecture = fv.xresnet18
pretrained = False
# Epochs with a frozen body before unfreezing; only used by fine_tune,
# i.e. when pretrained is True.
freeze_epochs = 2 if len(meta) < 20_000 else 1

# Name under which history, plots and weights of this run are saved.
result_name = '_'.join([architecture.__name__, '%04d' % n_epochs, data_dir.stem])
if not pretrained:
    result_name += '_fromscratch'
if debiased:
    result_name += '_debiased'

# rho_sub is only meaningful (and only defined) when Sigma_sub is being fit;
# building the metric list conditionally avoids a NameError otherwise.
metrics = []
if 'subhalo_parameters_sigma_sub' in fit_parameters:
    _sigma_sub_i = list(fit_parameters).index('subhalo_parameters_sigma_sub')

    def rho_sub(x, y):
        """Pearson correlation between predicted and true Sigma_sub.

        x: model outputs, y: targets (numpy arrays, via AccumMetric(to_np=True)).
        NOTE(review): assumes column _sigma_sub_i of the outputs holds the
        Sigma_sub point estimate, i.e. uncertainty outputs come after the
        n_params estimates — confirm against dds.n_out.
        """
        return stats.pearsonr(x[:, _sigma_sub_i], y[:, _sigma_sub_i])[0]

    metrics.append(fv.AccumMetric(rho_sub, to_np=True, flatten=False))

learn = fv.cnn_learner(
    dls, architecture,
    n_in=1,
    n_out=dds.n_out(n_params, uncertainty),
    loss_func=dds.loss_for(n_params, uncertainty),
    metrics=metrics,
    pretrained=pretrained)
#learn.summary()

In [None]:
# To save a learner for use outside this notebook,
# first recreate the learner without metrics,
# then skip training and instead do:
# learn = learn.load(result_name)
# learn.export(result_name)

## Training

In [None]:
# Learning-rate range test. The base_lr used for training below is
# hard-coded, so this is a sanity check rather than an input.
learn.lr_find()

In [None]:
# The learning rate scales linearly with batch size, relative to a reference
# batch size of 64. Reference values: 0.002 is the fine_tune default; 0.001
# follows the fastai default at
# https://github.com/fastai/fastai/blob/75f4c17dc019aee9a0af08bd458a56e00d7393f8/fastai/learner.py#L18
# (a from-scratch densenet needed a lower LR, ~1e-4).
lr_at_bs64 = 0.002 if pretrained else 0.001
base_lr = lr_at_bs64 * batch_size / 64
if not pretrained:
    learn.fit_one_cycle(n_epochs, lr_max=base_lr)
else:
    learn.fine_tune(n_epochs, base_lr=base_lr, freeze_epochs=freeze_epochs)

0.09

## Save results

In [None]:
# Collect the per-epoch history and the run configuration into one
# JSON-serializable dict, written to result_dir/<result_name>.json.
# Each row of recorder.values is [train_loss, valid_loss, *metric values].
# The names come from the learner's metrics so the columns stay aligned when
# the metric list changes (rho_sub is only present when Sigma_sub is fit).
history_names = ['train_loss', 'val_loss'] + [m.name for m in learn.metrics]
out = dict(zip(history_names, np.stack(learn.recorder.values).T.tolist()))
# Per-batch training loss; float() works regardless of tensor device.
out['train_loss_hr'] = [float(x) for x in learn.recorder.losses]
out['epoch_duration'] = learn.recorder.log[-1]   # only last epoch... oh well
out['architecture'] = architecture.__name__
out['freeze_epochs'] = freeze_epochs
out['base_lr'] = base_lr
out['batch_size'] = batch_size
out['pretrained'] = pretrained
out['fit_parameters'] = fit_parameters
out['uncertainty'] = uncertainty
out['n_images'] = len(meta)
out['n_epochs'] = n_epochs

with open(result_dir / (result_name + '.json'), mode='w') as f:
    json.dump(out, f)

# Model weights are saved by fastai under learn.path.
learn.save(result_name)

## Quick Evaluation

In [None]:
normalizer = dds.Normalizer(meta, fit_parameters)

# Validation-set predictions without reordering.
# NOTE(review): assumes the validation dataloader follows the row order of
# meta[meta['is_val']], so y_pred and y_true line up — confirm.
preds, targets = learn.get_preds(reorder=False)
# Decode raw outputs into {parameter: values} dicts of estimates and uncertainties.
y_pred, y_unc = normalizer.decode(preds, uncertainty=uncertainty, as_dict=True)
val_meta = meta[meta['is_val']]
y_true = {p: val_meta[p].values for p in fit_parameters}
# Alternative: y_true, _ = normalizer.decode(targets)

In [None]:
# plt.scatter(
#     #y_pred['main_deflector_parameters_theta_E'],
#     meta[meta['is_val']]['main_deflector_parameters_theta_E'],
#     #meta[meta['is_val']]['subhalo_parameters_sigma_sub'],
#     y_pred['subhalo_parameters_sigma_sub'],
#     s=5, marker='.', edgecolors='none', c='k'
# )

In [None]:
# NOTE: stale — `cov_to_std` and `y_cov` are not defined anywhere in this notebook.
# err, corr = zip(*[cov_to_std(q) for q in y_cov])
# err, corr = np.stack(err), np.stack(corr)
# y_unc = {p: err[:,i] for i, p in enumerate(fit_parameters)}
# Hist1d(corr[:,0,1], bins=100).plot()

In [None]:
# Predicted vs. true value of each fitted parameter on the validation set,
# one panel per parameter, annotated with the Pearson correlation.
n_rows = 3  # int(np.ceil(n_params**0.5))
n_cols = int(np.ceil(n_params / n_rows))
fsize = 3.5
# squeeze=False: always a 2D array of Axes, so .ravel() works for any grid.
fig, axes = plt.subplots(n_rows, n_cols,
                         figsize=(n_cols * fsize, n_rows * fsize),
                         squeeze=False)
axes_flat = axes.ravel()
color_by_unc = uncertainty == 'diagonal'

for ax, (p, label) in zip(axes_flat, fit_parameters.items()):
    r, p_r = stats.pearsonr(y_pred[p], y_true[p])
    rmse = np.mean((y_pred[p] - y_true[p])**2)**0.5

    # Color points by predicted uncertainty when available. Pass a colormap
    # only in that case; with a fixed color matplotlib ignores cmap and warns.
    if color_by_unc:
        color_kwargs = dict(c=y_unc[p], cmap=plt.cm.Blues_r)
    else:
        color_kwargs = dict(c='b')
    ax.scatter(y_true[p], y_pred[p], s=1, marker='.', **color_kwargs)
    ax.set_aspect('equal')

    # Draw the y = x diagonal, then restore the auto-scaled limits.
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    mi, ma = min(xlim[0], ylim[0]), max(xlim[1], ylim[1])
    ax.plot([mi, ma], [mi, ma], color='k', linewidth=1, alpha=0.5)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_xlabel(label, fontsize=14)

    # Show the p-value when p > 0.005, otherwise the RMSE.
    ax.text(0.5, 1.1,
            fr"$\rho={r:.03f} \;\; " +
                (fr" (p={p_r:.02f})$"
                 if p_r > 0.005
                 else fr"\mathrm{{RMSE}}={rmse:.3f}$"),
            ha='center',
            transform=ax.transAxes)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

# Hide grid panels that have no parameter.
for ax in axes_flat[n_params:]:
    ax.axis('off')

fig.subplots_adjust(hspace=0.5)
fig.savefig(result_dir / (result_name + '.png'), dpi=200, bbox_inches='tight')
plt.show()