Investigating the $L_{n,\mathbf y}$ matrix
==

In [None]:
%matplotlib notebook
import mushi
import histories
import utils
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
from numpy.linalg import cond, svd
from matplotlib.colors import LogNorm

### Define a flat (constant) $\eta(t)$

In [None]:
# Time grid: 0, then 100 log-spaced points from 1e-1 to 1e4, then ∞.
t = np.r_[0, np.logspace(-1, 4, 100), np.inf]
# Constant η(t) = 100 in each of the len(t) - 1 epochs.
y = np.full(len(t) - 1, 100.0)
# History object built from the interior grid points t[1:-1] and the values y.
η = histories.η(t[1:-1], y)

### The condition numbers of $L_{n,\mathbf y}$ for several values of $n$
The condition number blows up (things go to 💩) for $n$ larger than a couple hundred or so. This is a problem: most of the ill-conditioning comes from $C$.

In [None]:
# Vector of distinct sample sizes n, log-spaced between 10 and 1000.
# ``np.int`` was deprecated in NumPy 1.20 and removed in 1.24, so cast with
# the builtin ``int``. Truncating to integers repeats values at the low end
# (10, 10, 10, 11, ...); ``np.unique`` drops those so no condition number is
# computed twice. The result stays sorted in increasing order.
n_array = np.unique(np.logspace(1, 3, 100).astype(int))

def L(n: int):
    """Return ``utils.C(n) @ utils.M(n, t, y)`` cast to float64.

    Uses the module-level time grid ``t`` and η values ``y``.
    """
    product = utils.C(n) @ utils.M(n, t, y)
    return product.astype('float64')


# Condition number of L for every sample size in n_array.
condition_numbers = [cond(L(size)) for size in n_array]

In [None]:
plt.figure(figsize=(5, 3))
# semilogy draws the curve with a logarithmic y axis (n stays linear).
plt.semilogy(n_array, condition_numbers)
plt.xlabel('$n$')
plt.ylabel('condition number')
plt.tight_layout()
plt.show()

### Singular value spectrum of $L_{n,\mathbf y}$ with $n=100$

In [None]:
# SVD of L for a sample size of n = 100.
n = 100
sfs = mushi.kSFS(η, n=n)
U, σ, Vh = svd(sfs.L.astype('float64'), full_matrices=False)

plt.figure(figsize=(6, 2))
# Number the singular values 1..len(σ) straight from the SVD output, so the
# x coordinates always match σ whatever the shape of sfs.L is.
plt.plot(np.arange(1, len(σ) + 1), σ, '.')
plt.yscale('log')
plt.xlabel('index')
plt.ylabel('singular value')
plt.tight_layout()
plt.show()

### Top 20 right singular vectors

In [None]:
# Long-format table of the right singular vectors: one row per
# (singular-vector index, epoch start time), indexed by the vector index.
df = (pd.DataFrame(Vh.T, index=pd.Index(t[:-1], name='time'))
      .melt()
      .set_index('variable'))
df['singular value'] = σ[df.index]
# melt() discards the time index, so tile the epoch start times once per
# singular vector (there are Vh.shape[0] of them).
df['time'] = np.tile(t[:-1], Vh.shape[0])

# Keep only the 20 leading singular vectors.
df = df.loc[df.index < 20]

In [None]:
plt.figure(figsize=(10, 3))
# One line per right singular vector in df, coloured by its singular value;
# estimator=None with units= draws each vector's raw values, unaggregated.
ax = sns.lineplot(x='time', y='value', hue='singular value', units='singular value',
                  data=df, estimator=None,
                  palette=sns.color_palette("RdBu", n_colors=df['singular value'].nunique()),
                  legend=False)
# Apply the log scale BEFORE tight_layout() so the layout is computed for the
# final tick labels.
ax.set_xscale('log')
ax.set_ylabel('singular vector entry')
plt.tight_layout()
plt.show()

### TMRCA CDF

In [None]:
plt.figure(figsize=(3, 2))
# NOTE(review): assumes tmrca_cdf() returns one value per change point of η — confirm.
plt.plot(η.change_points, sfs.tmrca_cdf())
plt.xscale('symlog')
plt.ylim([0, 1])
plt.xlabel('$t$')
plt.ylabel('TMRCA CDF')
plt.tight_layout()
plt.show()