In [1]:
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 2000)
pd.set_option('max_colwidth', 400)

In [9]:
df = pd.read_csv("../results/RIMB_wide_search3.csv")
# df = pd.read_csv("../results/RIMB_wide_search1.csv")
# df = pd.read_csv("../results/RIMB_targeted_search.csv")


# df = pd.concat([
#     pd.read_csv("../results/RIMB_wide_search2.csv"), 
#     pd.read_csv("../results/RIMB_wide_search1.csv"), 
#     pd.read_csv("../results/RIMB_targeted_search.csv"),
#     pd.read_csv("../results/RIMB_wide_search3.csv")
# ])
columns = [
    "oversampling_factor", 
    "activation",
    "redundant",
    "batch_size", 
    "initial_learning_rate", 
    "block_conv_layers",
    "layers", 
    "steps", 
    "filters", 
    "filter_scaling",
    "input_kernel_size",
    "residual_weights", 
    "train_cost", 
    "train_chi_squared", 
    "experiment_id",

]
df[df["total_items"] >= 10000][columns].sort_values(by="train_chi_squared")

Unnamed: 0,oversampling_factor,activation,redundant,batch_size,initial_learning_rate,block_conv_layers,layers,steps,filters,filter_scaling,input_kernel_size,residual_weights,train_cost,train_chi_squared,experiment_id


In [1]:
from exorim import RIM, PhysicalModel
from exorim.simulated_data import CenteredBinariesDataset
from exorim.definitions import DTYPE, rad2mas
from exorim.models import Model
import tensorflow as tf
import os, json
from argparse import Namespace
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from mpl_toolkits.axes_grid1 import make_axes_locatable
from ipywidgets import interactive

In [2]:
# model = "RIMB_wide_search2_005_TS8_F256_FS1.0_IK3_L2_NLtanh_B1_RWuniform_220131130235"
# model = "RIMB_wide_search2_004_TS6_F256_FS2.0_IK5_L2_NLtanh_B10_RWuniform_220131130235"
# model = "RIMB_wide_search2_008_TS10_F128_FS2.0_IK7_L2_NLtanh_B10_RWuniform_220131130237"
model = "RIMB_wide_search3_040_TS6_F128_FS2.0_IK7_L2_BCL1_NLleaky_relu_B1_RWuniform_OF3.0_220331213952"
# model = "RIMB_targeted_search_010_TS12_FS1.0_IK7_OF1.0_220324122728"
# model = "RIMB_targeted_search_013_TS12_FS2.0_IK5_OF2.0_220324123430"
model_dir = os.path.join(os.getenv("EXORIM_PATH"), "models", model)
with open(os.path.join(model_dir, "model_hparams.json"), "r") as f:
    model_hparams = json.load(f)

args = Namespace()
with open(os.path.join(model_dir, "script_params.json"), "r") as f:
    vars(args).update(json.load(f))

model = Model(**model_hparams)

pixels = args.pixels
phys = PhysicalModel(
    pixels=pixels,
    wavelength=args.wavelength,
    logim=True,
    oversampling_factor=args.oversampling_factor,
    chi_squared=args.chi_squared
)
rim = RIM(
    model=model,
    physical_model=phys,
    steps=args.steps,
    log_floor=args.log_floor,
    adam=True,
)

ckpt = tf.train.Checkpoint(step=tf.Variable(1),  net=model)
checkpoint_manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6ca93df9d0>

In [9]:
def super_gaussian(I, x0, y0, width):
    x = np.arange(pixels) - pixels//2 + 0.5 * (pixels%2)
    x, y = np.meshgrid(x, x)

    rho = np.hypot(x - x0, y - y0)
    im = np.exp(-0.5 * (rho/width)**2)
    im /= im.sum()
    im *= I
    return im

def simuluate_binary(phys, angle=0., contrast=10., separation=2., width=2, sigma=1e-2):
    images = np.zeros(shape=[1, pixels, pixels, 1])
    for j in range(2):
        x0 = separation * np.cos(angle + j * np.pi)/2
        y0 = separation * np.sin(angle + j * np.pi)/2
        images[0, ..., 0] += super_gaussian(1. if j == 0 else 1/contrast, x0, y0, width)

    images = images / images.sum(axis=(1, 2), keepdims=True)  # renormalize in the range [0, 1]
    images = tf.constant(images, dtype=DTYPE)
    X, sigma = phys.noisy_forward(images, np.tile(np.array(sigma)[None, None], [1, phys.nbuv]))
    return X, images, sigma

In [35]:
floor = rim.log_floor
def f(angle=0., contrast=10, separation=10, width=2, sigma=1e-2):
    _sigma = sigma
    fig, axs = plt.subplots(1, 4, figsize=(32, 8))
    X, images, sigma = simuluate_binary(phys, angle, contrast, separation, width, sigma)
    predictions, chi_squared = rim.call(X, sigma)
    print(chi_squared)
    print(X)
    
    ax = axs[0]
    im = ax.imshow(predictions[-1, 0, ..., 0], cmap="hot", vmin=np.log10(floor), vmax=0)
    ax.axis("off")
    ax.set_title("Prediction")
    
    ax = axs[1]
    im = ax.imshow(np.maximum(images[0, ..., 0], floor), cmap="hot", norm=LogNorm(vmin=floor, vmax=1, clip=True))
    ax.axis("off")
    ax.set_title("Ground Truth")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    
    fft = np.abs(np.fft.fftshift(np.fft.fft2(images[..., 0])))[0]

    uv = phys.operators.UVC
    wavel = args.wavelength
    rho = np.hypot(uv[:, 0], uv[:, 1])
    fftfreq = np.fft.fftshift(np.fft.fftfreq(phys.pixels, phys.plate_scale))

    ax = axs[2]
    im = ax.imshow(np.abs(fft) + np.random.normal(size=np.abs(fft).shape, scale=_sigma), cmap="hot", extent=[fftfreq.min(), fftfreq.max()] * 2)
    ufreq = 1 / rad2mas(1 / uv[:, 0] * wavel)
    vfreq = 1 / rad2mas(1 / uv[:, 1] * wavel)
    ax.plot(ufreq, vfreq, "bo")
    ax.set_title("UV coverage")
    ax.axis("off")
    
    cp_gt = phys.forward(images)[0, phys.nbuv:]
    cp_pred = phys.forward(predictions[-1])[0, phys.nbuv:]
    axs[3].plot(cp_gt * 180 / np.pi, color="k", label="Ground Truth")
    axs[3].plot(cp_pred * 180 / np.pi, ls="--", lw=3, color="b", label="RIM")
    axs[3].set_title("7-holes JWST mask")
    axs[3].set_xlabel("Closure triangle")
    axs[3].set_ylabel("Closure phases (degrees)")
    axs[3].legend()
    
#     cp_gt = X[0, :phys.nbuv]#phys.forward(images)[0, :phys.nbuv]
#     cp_pred = phys.forward(phys.image_link(predictions[-1]) / phys.image_link(predictions[-1]).numpy().sum())[0, :phys.nbuv]
#     axs[3].plot(cp_gt, color="k", label="Ground Truth")
#     axs[3].plot(cp_pred, ls="--", lw=3, color="b", label="RIM")
#     axs[3].set_title("7-holes JWST mask")
#     axs[3].set_xlabel("Baselines")
#     axs[3].set_ylabel("Visibility")
#     axs[3].legend()
    
interactive(f, separation=(2, 20), angle=(0, np.pi), contrast=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 100, 1000, 10000], width=[1, 2, 3, 4], sigma=[1e-1, 1e-2, 1e-3, 1e-4, 1e-8])

interactive(children=(FloatSlider(value=0.0, description='angle', max=3.141592653589793), Dropdown(description…

In [41]:
phys.image_link(predictxions[-1]).numpy().sum()

0.9561462