# Evolving surfaces

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Run the project-local notebook setup script (its contents are not shown here).
%run notebook_setup.py

In [None]:
import time

import exoplanet as exo
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import pymc3.distributions.transforms as tr
import starry
import theano
import theano.sparse as ts
import theano.tensor as tt
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.interpolate import interp1d
from scipy.signal import correlate
from scipy.sparse import csr_matrix, csc_matrix
from tqdm.notebook import tqdm

In [None]:
# Reproducible random draws for everything below
np.random.seed(0)

# starry: evaluate eagerly (plain numpy results, not theano graphs) and
# run in quiet mode
starry.config.lazy = False
starry.config.quiet = True

## Generate

Simple attribute containers for the true (simulated) parameters and for the data:

In [None]:
class Truth:
    """Plain attribute bag for the true (simulated) parameters."""


truth = Truth()

In [None]:
class Data:
    """Plain attribute bag for the simulated observations."""


data = Data()

Parameters:

In [None]:
# --- True star and spot-population parameters ---
truth.ydeg = 20  # spherical harmonic degree of the simulated surface
truth.inc = 85.0  # inclination [deg]
truth.prot = 1.129337  # equatorial rotation period [days]
truth.alpha = 0.1  # differential rotation shear

truth.nspots = 30
truth.tau_mu = 20 * truth.prot  # mean spot timescale [days]
truth.tau_sig = 1.0  # scatter of the spot timescale [days]

# --- Observing setup ---
data.tmax = 100  # baseline [days]
data.tpad = 50.0  # spots may peak up to this long before t = 0 [days]
data.npts = 1001
data.ferr = 1e-4  # std of the Gaussian photometric noise

# Uniform time grid with a tiny random jitter, re-sorted to stay increasing
data.t = np.sort(
    np.linspace(0, data.tmax, data.npts)
    + np.random.randn(data.npts) * (1e-3 * data.tmax / data.npts)
)
data.ferr = 1e-4

Generate the spherical harmonic expansion of each individual spot:

In [None]:
map = starry.Map(truth.ydeg)


def lat():
    """Random spot latitude [deg], isotropic on the sphere."""
    return (np.arccos(2 * np.random.random() - 1) - 0.5 * np.pi) * 180 / np.pi


def lon():
    """Random spot longitude [deg], uniform in [0, 360)."""
    return 360.0 * np.random.random()


def sigma():
    """Random spot size: log-normal, floored at 0.01."""
    return max(0.01, np.exp(-3.5 + 0.4 * np.random.randn()))


def intensity():
    """Random (negative) spot intensity, magnitude capped at 0.5."""
    return -min(0.5, np.exp(-3 + 0.5 * np.random.randn()))


# Amplitude-weighted Ylm coefficients and latitude of every true spot.
# The random draws happen in the order lat, lon, sigma, intensity.
truth.y = np.empty((truth.nspots, (truth.ydeg + 1) ** 2))
truth.lats = np.zeros(truth.nspots)
for i in tqdm(range(truth.nspots)):
    map.reset()
    truth.lats[i] = lat()
    map.add_spot(lat=truth.lats[i], lon=lon(), sigma=sigma(), intensity=intensity())
    truth.y[i] = map.amp * map.y

# Zero the constant (Y_{0,0}) coefficient of every spot
truth.y[:, 0] = 0

Compute the empirical mean and covariance of this distribution:

In [None]:
# Monte Carlo estimate of the prior on one spot's l >= 1 Ylm coefficients:
# draw many random spots from the distributions defined above.
N = 99999
y = np.empty((N, (truth.ydeg + 1) ** 2 - 1))
for i in tqdm(range(N)):
    map.reset()
    map.add_spot(lat=lat(), lon=lon(), sigma=sigma(), intensity=intensity())
    y[i] = map.amp * map.y[1:]

truth.ymu = y.mean(axis=0)
truth.ycov = np.cov(y, rowvar=False)
# Tiny diagonal jitter for numerical stability of the covariance
np.fill_diagonal(truth.ycov, truth.ycov.diagonal() + 1e-12)

Draw a sample from a multivariate Gaussian with this mean and covariance, and visualize it:

In [None]:
# One random surface drawn from the Gaussian approximation of the prior
map.reset()
map[1:, :] = np.random.multivariate_normal(mean=truth.ymu, cov=truth.ycov)
map.show(colorbar=True, projection="moll")

Here are the individual spots we actually generated:

In [None]:
# Grid of Mollweide thumbnails, one per spot
nx = 1 + int(np.ceil(np.sqrt(truth.nspots)))
ny = max(1, -(-truth.nspots // nx))  # ceil(nspots / nx), at least one row
fig, ax = plt.subplots(ny, nx, figsize=(12, 5))
ax = ax.flatten()
for axis in ax:
    axis.axis("off")

# Outline of the Mollweide ellipse (semi-axes 1 and 0.5); same for every panel
x_el = np.linspace(-1, 1, 1000)
y_el = 0.5 * np.sqrt(1 - x_el ** 2)

for k in tqdm(range(truth.nspots)):
    map.reset()
    map[1:, :] = truth.y[k, 1:]
    img = np.pi * map.render(projection="moll", res=100)
    ax[k].imshow(
        img,
        origin="lower",
        extent=(-1, 1, -0.5, 0.5),
        cmap="Greys_r",
        vmin=0.9,
        vmax=1,
    )
    for sign in (1, -1):
        ax[k].plot(x_el, sign * y_el, "k-", lw=1, clip_on=False)

Draw the spot timescales (the width in time of each spot's Gaussian amplitude):

In [None]:
# Per-spot timescales: Gaussian scatter about the mean timescale
truth.tau = truth.tau_sig * np.random.randn(truth.nspots) + truth.tau_mu

In [None]:
# Distribution of the drawn spot timescales
plt.hist(x=truth.tau)
plt.xlabel(xlabel="timescale [days]");

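Draw the time at which each spot's amplitude peaks (uniform over the padded baseline, so some spots may already be decaying at t = 0):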

In [None]:
# Peak times, uniform over [-tpad, tmax), sorted in increasing order
truth.t0 = np.sort(
    (data.tmax + data.tpad) * np.random.random_sample(truth.nspots) - data.tpad
)

In [None]:
# Distribution of the drawn spot peak times
plt.hist(x=truth.t0)
plt.xlabel(xlabel="emergence time [days]");

Compute each spot's amplitude as a function of time (a Gaussian centered on its peak time `t0` with width `tau`):

In [None]:
# a[i, j]: amplitude of spot i at time t[j], a unit-height Gaussian in time
# centered on t0[i] with standard deviation tau[i]
truth.a = np.exp(
    -((data.t[np.newaxis, :] - truth.t0[:, np.newaxis]) ** 2)
    / (2 * truth.tau[:, np.newaxis] ** 2)
)

In [None]:
# Spot amplitude vs. time (one row per spot, rows in order of peak time);
# white ticks mark each spot's peak time t0
plt.imshow(
    truth.a,
    extent=(0, data.tmax, truth.nspots, 0),
    aspect="auto",
    vmax=1,
    vmin=0,
)
plt.colorbar(label="amplitude")
plt.plot(truth.t0, np.arange(truth.nspots) + 0.5, "w|", ms=7.5)
plt.xlim(0, data.tmax)
plt.xlabel("time [days]")
plt.ylabel("spot number");

Visualize the star:

In [None]:
def get_movie(
    t=data.t,
    y=truth.y,
    lats=truth.lats,
    prot=truth.prot,
    alpha=truth.alpha,
    a=truth.a,
    downsamp=10,
    res=300,
):
    """Render Mollweide frames of the evolving, differentially rotating star.

    Frames are rendered in the frame co-rotating with the equator, so each
    spot only moves by its rotation offset relative to the equator.

    Args:
        t: Observation times [days].
        y: Ylm coefficients of each spot, shape (nspots, (ydeg + 1) ** 2).
        lats: Latitude of each spot [deg]; sets the spot's rotation rate.
        prot: Equatorial rotation period [days].
        alpha: Differential rotation shear: a spot at latitude ``lat``
            rotates at ``1 - alpha * sin(lat) ** 2`` times the equatorial rate.
        a: Amplitude of each spot at each time, shape (nspots, len(t)).
        downsamp: Render only every ``downsamp``-th time.
        res: Image resolution (pixels per side).

    Returns:
        Array of shape (len(t[::downsamp]), res, res). Each frame is ones
        plus, for every spot, ``pi`` times its rendered image weighted by its
        amplitude at that time.

    Raises:
        ValueError: If ``y.shape[1]`` is not a perfect square.
    """
    # Spherical harmonic degree implied by the number of coefficients.
    # Pass starry an integer degree (np.sqrt alone would give a float).
    ncoeff = y.shape[1]
    ydeg = int(round(np.sqrt(ncoeff))) - 1
    if (ydeg + 1) ** 2 != ncoeff:
        raise ValueError(
            "y must have (ydeg + 1) ** 2 columns, got {}".format(ncoeff)
        )
    smap = starry.Map(ydeg=ydeg)

    def _render(y, theta, res):
        """Theano graph: render ``y`` on a Mollweide grid at each ``theta`` [rad].

        Returns a tensor of shape (len(theta), res, res).
        """
        # Cartesian coordinates of the Mollweide grid pixels
        xyz = smap.ops.compute_moll_grid(res)[-1]

        # Polynomial basis evaluated on the grid
        pT = smap.ops.pT(xyz[0], xyz[1], xyz[2])

        # Rotate one copy of the map per phase in theta
        # (argument order as required by starry's ops.left_project)
        Ry = smap.ops.left_project(
            tt.transpose(tt.tile(y, [theta.shape[0], 1])),
            np.array(0.5 * np.pi),
            np.array(0.0),
            theta,
            np.array(0.0),
            np.array(np.inf),
            np.array(0.0),
        )

        # Change basis from Ylms to polynomials
        A1Ry = ts.dot(smap.ops.A1, Ry)

        # Evaluate on the grid and reorder to (nframes, npix, npix)
        frames = tt.reshape(tt.dot(pT, A1Ry), [res, res, -1])
        return frames.dimshuffle(2, 0, 1)

    # Compile the theano renderer (test values off while building the graph)
    with theano.configparser.change_flags(compute_test_value="off"):
        _y = tt.dvector()
        _theta = tt.dvector()
        _res = tt.iscalar()
        render_spot = theano.function([_y, _theta, _res], _render(_y, _theta, _res))

    # Rotational phase [rad] of each spot, then its offset from the equator
    nim = len(t[::downsamp])
    img = np.ones((nim, res, res))
    theta_eq = 2 * np.pi / prot * t
    theta = theta_eq.reshape(1, -1) * (
        1 - alpha * np.sin(np.pi / 180 * lats.reshape(-1, 1)) ** 2
    )
    theta_diff = theta - theta_eq.reshape(1, -1)

    # Sum the amplitude-weighted contribution of each spot
    for k in tqdm(range(len(lats))):
        imgk = np.pi * render_spot(y[k], theta_diff[k, ::downsamp], res)
        img += a[k, ::downsamp].reshape(-1, 1, 1) * imgk

    return img

In [None]:
# Render the true surface evolution and display the frames
truth.movie = get_movie()
map.show(projection="moll", image=truth.movie, colorbar=True)

Get the light curve:

In [None]:
def get_model(
    t=data.t,
    y=truth.y,
    lats=truth.lats,
    prot=truth.prot,
    inc=truth.inc,
    alpha=truth.alpha,
    a=truth.a,
):
    """Compute the light curve of a spotted, differentially rotating star.

    Each spot rotates rigidly at its own latitude-dependent rate and is
    scaled by its time-dependent amplitude. The flux is one plus the sum of
    the spot contributions.

    Args:
        t: Observation times [days].
        y: Ylm coefficients of each spot, shape (nspots, (ydeg + 1) ** 2).
        lats: Latitude of each spot [deg].
        prot: Equatorial rotation period [days].
        inc: Stellar inclination [deg], passed to ``starry.Map``.
        alpha: Differential rotation shear: a spot at latitude ``lat``
            rotates at ``1 - alpha * sin(lat) ** 2`` times the equatorial rate.
        a: Amplitude of each spot at each time, shape (nspots, len(t)).

    Returns:
        Float flux array with one entry per time in ``t``.

    Raises:
        ValueError: If ``y.shape[1]`` is not a perfect square.
    """
    # Spherical harmonic degree implied by the number of coefficients.
    # Pass starry an integer degree (np.sqrt alone would give a float).
    ncoeff = y.shape[1]
    ydeg = int(round(np.sqrt(ncoeff))) - 1
    if (ydeg + 1) ** 2 != ncoeff:
        raise ValueError(
            "y must have (ydeg + 1) ** 2 columns, got {}".format(ncoeff)
        )
    smap = starry.Map(ydeg=ydeg, inc=inc)

    # Rotational phase [deg] of each spot at each time
    theta_eq = 360.0 / prot * t
    theta = theta_eq.reshape(1, -1) * (
        1 - alpha * np.sin(np.pi / 180 * lats.reshape(-1, 1)) ** 2
    )

    # Unit baseline plus each spot's amplitude-weighted light curve.
    # Force a float dtype so the in-place sum works even for integer times.
    model = np.ones_like(t, dtype=float)
    for k in range(len(lats)):
        model += a[k] * smap.design_matrix(theta=theta[k]).dot(y[k])

    return model

In [None]:
def plot_lc(t, fluxes, styles=None, nrow=5, ncol=3, figsize=(12, 10)):
    """Plot light curves: the full baseline on top, zoomed-in chunks below.

    The top panel spans the first two grid rows and shows every curve over
    the whole time range. Each of the remaining ``(nrow - 2) * ncol`` panels
    shows one consecutive, roughly equal slice of the samples.

    Args:
        t: Times shared by all curves.
        fluxes: Iterable of flux arrays, each indexed like ``t``.
        styles: Optional sequence of ``Axes.plot`` keyword dicts, one per
            flux (``zip`` pairing: extras in the longer sequence are
            ignored). Entries with a ``label`` appear in the legend.
        nrow: Total number of grid rows.
        ncol: Number of grid columns.
        figsize: Figure size in inches.
    """
    # Materialize once: fluxes may be iterated twice below
    fluxes = list(fluxes)

    plt.figure(figsize=figsize)
    ax_main = plt.subplot2grid((nrow, ncol), (0, 0), colspan=ncol, rowspan=2)
    ax_sub = [
        plt.subplot2grid((nrow, ncol), (2 + i, j))
        for i in range(nrow - 2)
        for j in range(ncol)
    ]
    nsub = len(ax_sub)
    npts = len(t)

    if styles is None:
        styles = [{} for _ in fluxes]
    for flux, style in zip(fluxes, styles):
        ax_main.plot(t, flux, **style)

        # Slice k of the samples goes in zoom panel k
        for k, ax in enumerate(ax_sub):
            a = int(k * npts / nsub)
            b = int((k + 1) * npts / nsub)
            ax.plot(t[a:b], flux[a:b], **style)

    # Only draw a legend when something is labeled; otherwise matplotlib
    # warns that no artists with labels were found
    if ax_main.get_legend_handles_labels()[1]:
        ax_main.legend(fontsize=8, loc="lower left")

    for label in ax_main.get_yticklabels() + ax_main.get_xticklabels():
        label.set_fontsize(10)
    for ax in ax_sub:
        for label in ax.get_yticklabels() + ax.get_xticklabels():
            label.set_fontsize(8)
    ax_main.set_ylabel("flux")
    for ax in ax_sub[-ncol:]:
        ax.set_xlabel("time [days]", fontsize=12)
    for ax in ax_sub[::ncol]:
        ax.set_ylabel("flux", fontsize=12)

In [None]:
# Light curves with differential rotation and, for comparison, solid-body (alpha = 0)
truth.flux = get_model(alpha=truth.alpha)
truth.flux0 = get_model(alpha=0)

In [None]:
# Solid-body vs. differentially rotating light curves
plot_lc(
    data.t,
    [truth.flux0, truth.flux],
    styles=[
        {"color": "C1", "lw": 1, "alpha": 0.5, "label": "solid"},
        {"color": "C0", "lw": 2, "label": "diff rot"},
    ],
)

Generate the dataset:

In [None]:
# Simulated observations: true flux plus white Gaussian noise with std ferr
data.flux = truth.flux + np.random.randn(truth.flux.size) * data.ferr

In [None]:
# True light curve vs. the noisy observations
plot_lc(
    data.t,
    [truth.flux, data.flux],
    styles=[
        {"color": "C0", "lw": 1, "alpha": 0.5, "label": "true"},
        {"color": "k", "ls": "None", "marker": ".", "ms": 2, "label": "observed"},
    ],
)

## Inference

In [None]:
# --- Model resolution ---
ydeg = 10  # degree of the inferred surface (the truth uses truth.ydeg)
N = (ydeg + 1) ** 2  # number of Ylm coefficients, including Y_{0,0}
nnodes = 60  # number of time nodes for the evolving surface

# --- Observed dataset ---
t, flux = data.t, data.flux
ferr, npts = data.ferr, data.npts

# --- Prior on the l >= 1 coefficients ---
# truth.ymu / truth.ycov exclude Y_{0,0}; keep their first N - 1 entries
ymu = truth.ymu[: N - 1]
ycov = truth.ycov[: N - 1, : N - 1]

# --- Quantities assumed known ---
inc, prot = truth.inc, truth.prot

In [None]:
# starry design matrix for a rigidly rotating surface, evaluated at the
# equatorial rotational phase [deg] of each observation. Only the l >= 1
# columns are kept; the constant term is handled as a fixed offset of 1.
map = starry.Map(ydeg, inc=inc)
theta_x = 360.0 / prot * t
X = map.design_matrix(theta=theta_x)[:, 1:]

In [None]:
# Linear ("tent") interpolation weights: diags[k, i] is the weight of time
# node k at observation t[i], nonzero only within one node spacing of it.
tnodes = np.linspace(t[0], t[-1], nnodes)
dt = tnodes[1] - tnodes[0]
diags = np.zeros((nnodes, npts))
for k in range(nnodes):
    w = 1 - np.abs(tnodes[k] - t) / dt
    w[w < 0] = 0
    diags[k] = w

# Interpolation matrix of shape (npts, nnodes * npts): its k-th (npts, npts)
# block is diag(diags[k]). Built once -- it is large (~480 MB here).
I = np.hstack([np.diag(diag) for diag in diags])

# Visualize the weights and the sparsity pattern of I
plt.figure()
for diag in diags:
    plt.plot(diag)
plt.figure()
f = np.where(I != 0, 1.0, np.nan)
plt.imshow(f, aspect="auto");

In [None]:
# Pre-compute the interpolation matrix (cubic Catmull-Rom). Disabled by
# default; set the flag to build and visualize it. When enabled it
# overwrites `tnodes`, `dt` and `diags` from the linear cell above.
use_cubic_interp = False
if use_cubic_interp:

    tnodes = np.linspace(t[0], t[-1], nnodes)
    dt = tnodes[1] - tnodes[0]
    diags = np.zeros((nnodes, npts))

    # Catmull-Rom basis: on a segment with local coordinate u in [0, 1),
    # p(u) = [1, u, u^2, u^3] @ M @ [P_{k-1}, P_k, P_{k+1}, P_{k+2}]
    M = np.array(
        [[0, 1, 0, 0], [-0.5, 0, 0.5, 0], [1, -2.5, 2, -0.5], [-0.5, 1.5, -1.5, 0.5]]
    )

    # Interior segments [tnodes[k], tnodes[k + 1]) use nodes k - 1 .. k + 2
    for k in range(1, nnodes - 2):
        u = (t - tnodes[k]) / dt
        in_seg = (u >= 0) & (u < 1)
        U = np.vander(u[in_seg], N=4, increasing=True)
        diags[k - 1 : k + 3, in_seg] += U.dot(M).T

    # The first and last segments lack a fourth node: interpolate linearly.
    # Accumulate (+=) into the end nodes instead of assigning: nodes 0 and -1
    # also carry cubic weights on the second and second-to-last segments,
    # which an assignment would throw away (weights would no longer sum to 1).
    for k in [0, -1]:
        w = 1 - np.abs(tnodes[k] - t) / dt
        w[w < 0] = 0
        diags[k] += w
    # Rising weight of node 1 on the first segment [tnodes[0], tnodes[1]).
    # Exclude t == tnodes[1] (w == 1): the cubic segment k = 1 already gives
    # node 1 full weight there.
    w = 1 - (tnodes[1] - t) / dt
    w[w < 0] = 0
    w[w >= 1] = 0
    diags[1] += w
    # Falling weight of node -2 on the last segment [tnodes[-2], tnodes[-1]]
    w = 1 + (tnodes[-2] - t) / dt
    w[w > 1] = 0
    diags[-2] += w

    # Visualize
    I3 = np.hstack([np.diag(diag) for diag in diags])
    plt.figure()
    for diag in diags:
        plt.plot(diag)
    plt.figure()
    f = np.where(I3 != 0, 1.0, np.nan)
    plt.imshow(f, aspect="auto");

In [None]:
# The full design matrix, A = I @ blockdiag(X, ..., X), with one X per node.
# Column block k of A is I[:, k-th npts-wide block] @ X, so compute the
# blocks directly instead of materializing the block-diagonal matrix.
# That matrix is (nnodes * npts) x (nnodes * (N - 1)), several GB here, and
# almost entirely zeros. The result is the same product.
A = np.hstack(
    [I[:, k * npts : (k + 1) * npts].dot(X) for k in range(nnodes)]
)

In [None]:
# Values of the design matrix, with exact zeros hidden (shown as NaN)
f = np.where(A == 0, np.nan, A)
plt.imshow(f, aspect="auto")
plt.colorbar();

In [None]:
# Spatial part of the prior covariance: the per-node Ylm covariance (a copy)
L1 = np.copy(ycov)
plt.imshow(L1)
plt.colorbar();

In [None]:
# Temporal part of the prior covariance: squared-exponential correlation
# between time nodes, with amplitude `amp` and timescale `tau` (time units)
amp = 1
tau = 2.5
k = np.arange(nnodes)[np.newaxis, :] - np.arange(nnodes)[:, np.newaxis]
L2 = amp * np.exp(-0.5 * (k * dt / tau) ** 2)
plt.imshow(L2)
plt.colorbar();

In [None]:
# Full prior covariance over all node coefficients: kron(temporal, spatial)
L = np.kron(L2, L1)
# Tiny diagonal jitter, added in place instead of allocating a dense
# (nnodes * (N - 1))^2 identity matrix just to touch the diagonal
L[np.diag_indices_from(L)] += 1e-12

In [None]:
fig = plt.figure(figsize=(12, 12))

# Normalize the prior covariance to [-1, 1] and zero out tiny entries
f = np.array(L)
vmax = np.max(np.abs(f))
f /= vmax
f[np.abs(f) < 1e-5] = 0

# Overlay two log-scaled layers: negative entries in blue, positive in red.
# log10 of non-positive values gives -inf/NaN by design (matplotlib leaves
# such pixels blank in each layer), so silence the resulting warnings.
with np.errstate(divide="ignore", invalid="ignore"):
    log_neg = np.log10(-f)
    log_pos = np.log10(f)
ticks = [-5, -4, -3, -2, -1, 0]

imm = plt.imshow(log_neg, cmap="Blues", vmin=-5, vmax=0)
cbm = plt.colorbar(shrink=0.65)
cbm.set_ticks(ticks)
cbm.set_ticklabels([r"$-10^{%d}$" % n for n in cbm.get_ticks()])
imp = plt.imshow(log_pos, cmap="Reds", vmin=-5, vmax=0)
cbp = plt.colorbar(shrink=0.65)
cbp.set_ticks(ticks)
cbp.set_ticklabels([r"$10^{%d}$" % n for n in cbp.get_ticks()])

In [None]:
# Prior mean: the same per-node mean, repeated for every time node
mu = np.tile(ymu, nnodes)

# Linear Gaussian solve for the coefficients' posterior mean and covariance
y_guess, y_guess_cov = starry.linalg.solve(
    A, flux - 1.0, C=ferr ** 2, mu=mu, L=L
)

In [None]:
model_guess = 1 + A.dot(y_guess)

# Predictive variance at each time: diag(A @ cov @ A.T), computed row-wise
# without forming the full (npts, npts) product
model_var = np.sum(A.dot(y_guess_cov) * A, axis=1)
# 1-sigma model uncertainty (clip tiny negative round-off before the sqrt)
model_sig = np.sqrt(np.clip(model_var, 0, None))

In [None]:
# Observations vs. the linear-solve model
plot_lc(
    t,
    [flux, model_guess],
    styles=[
        {"color": "k", "ls": "None", "marker": ".", "ms": 2, "label": "observed"},
        {"color": "C0", "lw": 1, "alpha": 0.5, "label": "MAP"},
    ],
    figsize=(12, 12),
    ncol=3,
    nrow=6,
)

In [None]:
def get_y(y, time):
    """Linearly interpolate the per-node coefficients ``y`` to ``time``.

    ``y`` is reshaped to (nnodes, N - 1) using the module-level ``nnodes``,
    ``N`` and ``tnodes``. Times before ``tnodes[0]`` return the first node's
    coefficients; times at or after ``tnodes[-1]`` return the last node's.
    """
    nodes = y.reshape(nnodes, N - 1)
    if time < tnodes[0]:
        return nodes[0]
    if time >= tnodes[-1]:
        return nodes[-1]

    # `lo` is the last node at or before `time` (first node after it, minus 1)
    lo = np.argmin(time >= tnodes) - 1
    hi = lo + 1
    return ((time - tnodes[lo]) * nodes[hi] + (tnodes[hi] - time) * nodes[lo]) / (
        tnodes[hi] - tnodes[lo]
    )

In [None]:
def _get_image(_y):
    """Theano graph: pi times the 300x300 rect-projection render of the
    degree-``ydeg`` map whose l >= 1 coefficients are ``_y``."""
    smap = starry.Map(ydeg, lazy=True)
    smap[1:, :] = _y
    return np.pi * smap.render(projection="rect", res=300)


# Compile the renderer once (test values off while building the graph)
with theano.configparser.change_flags(compute_test_value="off"):
    _y = tt.dvector()
    get_image = theano.function([_y], _get_image(_y))


# Render the inferred surface at every `downsamp`-th observation time
downsamp = 2
t_frames = t[::downsamp]
nim = len(t_frames)
img = np.zeros((nim, 300, 300))
for k in tqdm(range(nim)):
    img[k] = get_image(get_y(y_guess, t_frames[k]))

In [None]:
# Display the rendered frames of the inferred surface
map.show(
    image=img,
    projection="rect",
    colorbar=True,
    interval=15,
)

In [None]:
# Latitude / longitude of every pixel of the 300x300 rect-projection grid
# (this rebinds the `lat` / `lon` names used for spot sampling earlier)
lat, lon = map.get_latlon_grid(
    300,
    projection="rect",
)

In [None]:
# For each latitude row, estimate the longitudinal shift of surface features
# between frames `lag` apart, using circular cross-correlation in longitude.
lag = 30
max_lat = 80  # rows with |latitude| above this (deg) are skipped
nrows = lat.shape[0]

# Z[j, k]: shift (deg of longitude) of row k between frames j - lag and j;
# NaN where not computed (first `lag` frames, skipped rows)
Z = np.full((nim, nrows), np.nan)
for k in range(nrows):

    if np.abs(lat[k][0]) > max_lat:
        continue

    f = img[:, k, :]
    idx = ~np.isnan(f[0])
    dlon = np.nanmean(np.diff(lon[k]))  # pixel spacing in longitude
    for j in range(lag, nim):
        f0 = f[j - lag][idx]
        fj = f[j][idx]
        if len(fj):
            # Tiling f0 makes the "valid" correlation circular in longitude
            corr = correlate(np.tile(f0, 2), fj, mode="valid")
            cc = len(corr) - 1 - np.argmax(corr)
            # Wrap the pixel shift into (-L/2, L/2]
            if cc > len(corr) // 2:
                cc -= len(corr) - 1
            Z[j, k] = cc * dlon

In [None]:
fig, ax = plt.subplots(3, figsize=(10, 12), sharex=True)

# Top: measured shift map, on a color scale symmetric about zero
vmax = max(-np.nanmin(Z), np.nanmax(Z))
vmin = -vmax
im = ax[0].imshow(
    Z, aspect="auto", extent=(-90, 90, data.tmax, 0), vmin=vmin, vmax=vmax, cmap="RdBu",
)
plt.colorbar(im, ax=ax[0])

# Middle: time-averaged shift vs. latitude
mean = np.nanmean(Z, axis=0)
med = np.nanmedian(Z, axis=0)
std = np.nanstd(Z, axis=0)
ax[1].plot(lat[:, 0], mean)
ax[1].plot(lat[:, 0], med)
ax[1].set_xlim(-90, 90)
# Hidden colorbar keeps this panel's width aligned with the top one
cb_ = plt.colorbar(im, ax=ax[1])
cb_.ax.set_visible(False)

# Bottom: shift expected from the true differential rotation law.
# Z compares rendered *frames* `lag` apart, and frames were rendered every
# `downsamp` samples, so the time baseline is t[::downsamp][lag] - t[0]
# (not t[lag] - t[0]).
delt = t[::downsamp][lag] - t[0]
signal = -truth.alpha * 360.0 / truth.prot * delt * np.sin(lat[:, 0] * np.pi / 180) ** 2
ax[2].plot(lat[:, 0], signal)
ax[2].set_xlim(-90, 90)
cb_ = plt.colorbar(im, ax=ax[2])
cb_.ax.set_visible(False)

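Baseline: check the same lag-correlation estimator on a known case, the Earth map rendered with differential rotation (`alpha = 1`):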

In [None]:
# Load the Earth map and render it (rect projection) for nim phases spanning
# 0-360 deg with differential rotation switched on (alpha = 1). alpha is reset
# to 0 afterwards, even if rendering fails.
map.load("earth")
map.alpha = 1
try:
    img = map.render(projection="rect", theta=np.linspace(0, 360.0, nim))
finally:
    map.alpha = 0.0

In [None]:
# Display the rendered Earth frames
map.show(
    image=img,
    projection="rect",
    colorbar=True,
    interval=15,
)

In [None]:
# Latitude / longitude of every pixel of the 300x300 rect-projection grid
lat, lon = map.get_latlon_grid(
    300,
    projection="rect",
)

In [None]:
# Same lag-correlation shift estimate as above, applied to the Earth frames
lag = 30
max_lat = 80  # rows with |latitude| above this (deg) are skipped
nrows = lat.shape[0]

# Z[j, k]: shift (deg of longitude) of row k between frames j - lag and j;
# NaN where not computed (first `lag` frames, skipped rows)
Z = np.full((nim, nrows), np.nan)
for k in range(nrows):

    if np.abs(lat[k][0]) > max_lat:
        continue

    f = img[:, k, :]
    idx = ~np.isnan(f[0])
    dlon = np.nanmean(np.diff(lon[k]))  # pixel spacing in longitude
    for j in range(lag, nim):
        f0 = f[j - lag][idx]
        fj = f[j][idx]
        if len(fj):
            # Tiling f0 makes the "valid" correlation circular in longitude
            corr = correlate(np.tile(f0, 2), fj, mode="valid")
            cc = len(corr) - 1 - np.argmax(corr)
            # Wrap the pixel shift into (-L/2, L/2]
            if cc > len(corr) // 2:
                cc -= len(corr) - 1
            Z[j, k] = cc * dlon

In [None]:
fig, ax = plt.subplots(3, figsize=(10, 12), sharex=True)

# Top: measured shift map, on a color scale symmetric about zero
vmax = max(-np.nanmin(Z), np.nanmax(Z))
vmin = -vmax
im = ax[0].imshow(
    Z, aspect="auto", extent=(-90, 90, data.tmax, 0), vmin=vmin, vmax=vmax, cmap="RdBu",
)
plt.colorbar(im, ax=ax[0])

# Middle: time-averaged shift vs. latitude
mean = np.nanmean(Z, axis=0)
med = np.nanmedian(Z, axis=0)
std = np.nanstd(Z, axis=0)
ax[1].plot(lat[:, 0], mean)
ax[1].plot(lat[:, 0], med)
ax[1].set_xlim(-90, 90)
# Hidden colorbar keeps this panel's width aligned with the top one
cb_ = plt.colorbar(im, ax=ax[1])
cb_.ax.set_visible(False)

# Bottom: expected shift. The frames were rendered at
# theta = np.linspace(0, 360, nim), i.e. spaced 360 / (nim - 1) deg apart,
# with alpha = 1.
dtheta = 360.0 / (nim - 1)
signal = -dtheta * lag * np.sin(lat[:, 0] * np.pi / 180) ** 2
ax[2].plot(lat[:, 0], signal)
ax[2].set_xlim(-90, 90)
cb_ = plt.colorbar(im, ax=ax[2])
cb_.ax.set_visible(False)

## Sampling

In [None]:
# Deliberately stop "Run All" here, before the sampling section below.
# A bare `breakpoint()` opens an interactive debugger instead: it blocks
# waiting for input, and headless runs (nbconvert, papermill) cannot
# answer it. Set RUN_SAMPLING = True to continue past this cell.
RUN_SAMPLING = False
if not RUN_SAMPLING:
    raise RuntimeError(
        "Stopped before the sampling section; set RUN_SAMPLING = True to run it."
    )

In [None]:
with pm.Model() as model:

    # One block of l >= 1 Ylm coefficients per time node. Each block has
    # N - 1 entries, matching ymu / ycov, the linear-solve guess, and the
    # column layout of the design matrix A (nnodes blocks of N - 1 columns).
    ncoeff = N - 1
    y_guess_nodes = y_guess.reshape(nnodes, ncoeff)
    y = tt.zeros((nnodes, ncoeff))
    for k in range(nnodes):
        y = tt.set_subtensor(
            y[k],
            pm.MvNormal(
                "y{}".format(k),
                mu=ymu,
                cov=ycov,
                shape=(ncoeff,),
                testval=y_guess_nodes[k],
            ),
        )

    # Model light curve: unit baseline plus the design-matrix contribution
    m = 1.0 + tt.dot(A, tt.flatten(y))
    pm.Deterministic("m", m)
    m_guess = exo.eval_in_model(m)

    # Likelihood
    pm.Normal("obs", mu=m, sd=ferr, observed=flux)

In [None]:
# Numerically optimize the model (MAP), capped at 399 optimizer iterations
with model:
    map_soln = exo.optimize(options={"maxiter": 399})

In [None]:
# Observations vs. the starting guess and the optimized solution
plot_lc(
    t,
    [flux, m_guess, map_soln["m"]],
    styles=[
        {"color": "k", "ls": "None", "marker": ".", "ms": 2, "label": "observed"},
        {"color": "C1", "lw": 2, "alpha": 0.5, "label": "guess"},
        {"color": "C0", "lw": 1, "alpha": 1, "label": "MAP"},
    ],
    figsize=(12, 12),
    ncol=3,
    nrow=6,
)