# Classifying Population Dynamics

In [1]:
%run ../src/dataset.py
%run ../src/augmentation.py
%run ../src/simulation.py
%run ../src/nets.py
%run ../src/trainer.py
%run ../src/utils.py

Train a model based on Wright-Fisher simulations (without age structure).
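The simulator lives in `../src/simulation.py` and its code isn't shown here. For orientation only, here is a minimal standard-library sketch of a Wright-Fisher-style copying model with overlapping generations. How `beta`, `mu` and `p_death` enter the model is an assumption: each timestep a fraction `p_death` of agents is replaced. Each newcomer innovates a brand-new variant with probability `mu`, or otherwise copies an existing variant with probability proportional to its count raised to `1 + beta`, so `beta = 0` means neutral copying. The function names and the "number of distinct variants" summary are hypothetical.

```python
import random
from collections import Counter


def wright_fisher_step(pop, beta, mu, p_death, rng, next_id):
    """One timestep of a hypothetical biased-copying model (no age structure).

    pop: list of variant ids, one per agent. Returns (new_pop, next_id).
    """
    counts = Counter(pop)
    variants = list(counts)
    # Frequency-dependent copying weights; beta = 0 gives neutral (unbiased) copying
    weights = [counts[v] ** (1.0 + beta) for v in variants]
    new_pop = list(pop)
    n_replace = int(round(p_death * len(pop)))
    # Choose which agents die, and which variant each newcomer would copy
    dead = rng.sample(range(len(pop)), n_replace)
    copied = rng.choices(variants, weights=weights, k=n_replace)
    for i, variant in zip(dead, copied):
        if rng.random() < mu:
            new_pop[i] = next_id  # innovation: a never-seen variant
            next_id += 1
        else:
            new_pop[i] = variant
    return new_pop, next_id


def simulate(beta, mu, p_death, n_agents=1000, timesteps=10, seed=0):
    """Return the number of distinct variants after each timestep (hypothetical summary)."""
    rng = random.Random(seed)
    pop = list(range(n_agents))  # every agent starts with a unique variant
    next_id = n_agents
    diversity = []
    for _ in range(timesteps):
        pop, next_id = wright_fisher_step(pop, beta, mu, p_death, rng, next_id)
        diversity.append(len(set(pop)))
    return diversity
```

With `mu = 0` no new variants can appear, so diversity can only stay constant or shrink; with `p_death = 0` nobody is replaced and the population never changes.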

In [2]:
classifier = TimeseriesClassifier(
    classifier=MultiLayerPerceptron(128, layers=(64, 64, 64)),
    embedder=ResNet(1)
)

The following cell provides some provisional parameter priors to generate samples:

In [8]:
from torch.distributions import Uniform, Normal

neutral_simulator = Simulator(
    n_agents=1000,
    timesteps=10,
    age_window=None,
    disable_pbar=True,
    summarize=True
)

prior = IndependentPriors(
    Normal(0., 0.1),      # beta prior
    Uniform(0.0001, 0.1), # mu prior
    Uniform(0.1, 0.5)     # p_death prior
)

theta = prior.sample()
print(theta)
neutral_simulator(theta)

tensor([-0.0779,  0.0025,  0.2322])


array([[56.        , 44.41817557, 38.27924108, 34.67361245, 32.27472244,
        30.5183134 , 29.14879624]])
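`IndependentPriors` comes from the `../src/` scripts and its code isn't shown. Judging by its use here, it combines one distribution per parameter into a joint prior over $\theta = (\beta, \mu, p_{\text{death}})$. The following standard-library sketch illustrates that idea. The class names are hypothetical, and the real class wraps `torch.distributions` objects. It draws each component independently and, as a joint prior typically would, sums the marginal log-densities:

```python
import math
import random


class UniformPrior:
    """Uniform(low, high) with support [low, high)."""

    def __init__(self, low, high):
        self.low, self.high = low, high

    def sample(self, rng):
        return rng.uniform(self.low, self.high)

    def log_prob(self, x):
        if self.low <= x < self.high:
            return -math.log(self.high - self.low)
        return -math.inf  # outside the support


class NormalPrior:
    """Normal(loc, scale)."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class IndependentPriorsSketch:
    """Hypothetical joint prior: a product of independent marginals."""

    def __init__(self, *marginals):
        self.marginals = marginals

    def sample(self, rng):
        # Each parameter is drawn from its own marginal, independently of the others
        return [m.sample(rng) for m in self.marginals]

    def log_prob(self, theta):
        # Independence: the joint log-density is the sum of the marginal log-densities
        return sum(m.log_prob(t) for m, t in zip(self.marginals, theta))
```

Because the joint log-density is a sum, a draw outside any marginal's support (for example $p_{\text{death}} = 0.6$) gets log-density $-\infty$ regardless of the other components.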

Train the actual model. In each epoch, we simulate 10,000 samples with parameters $\theta$ drawn from the prior. The training procedure can probably be improved, but it already reaches a validation AUC of about 0.93 to 0.94 within the first few epochs.

In [6]:
classifier, train_loader, val_loader = train(
    simulator=neutral_simulator,
    prior=prior,
    num_simulations=10000,
    classifier=classifier, 
    n_epochs=40, 
    batch_size=500, 
    device="cuda", 
    learning_rate=0.001,
    num_workers=40
)

Epoch  1: train loss = 0.44, val loss = 0.60, AUC = 0.89 ++
Epoch  2: train loss = 0.32, val loss = 0.32, AUC = 0.92 ++
Epoch  3: train loss = 0.31, val loss = 0.32, AUC = 0.93 --
Epoch  4: train loss = 0.31, val loss = 0.31, AUC = 0.93 ++
Epoch  5: train loss = 0.30, val loss = 0.33, AUC = 0.93 --
Epoch  6: train loss = 0.30, val loss = 0.29, AUC = 0.93 ++
Epoch  7: train loss = 0.29, val loss = 0.28, AUC = 0.93 ++
Epoch  8: train loss = 0.30, val loss = 0.28, AUC = 0.93 ++
Epoch  9: train loss = 0.30, val loss = 0.29, AUC = 0.93 --
Epoch 10: train loss = 0.30, val loss = 0.28, AUC = 0.94 ++
Epoch 11: train loss = 0.28, val loss = 0.28, AUC = 0.94 --
Epoch 12: train loss = 0.29, val loss = 0.28, AUC = 0.94 ++
Epoch 13: train loss = 0.29, val loss = 0.29, AUC = 0.93 --
Epoch 14: train loss = 0.29, val loss = 0.31, AUC = 0.93 --
Epoch 15: train loss = 0.28, val l
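The log reports a validation AUC and a `++`/`--` marker for each epoch. The trainer's code in `../src/trainer.py` isn't shown, so what follows is only a standard-library sketch. `roc_auc` computes AUC as the probability that a randomly chosen positive (biased) sample scores higher than a randomly chosen negative (neutral) one, with ties counted as one half. `improvement_markers` implements one plausible reading of the markers: flag a new best validation loss. That reading is an inference from the log (it matches the first epochs), not something the source confirms.

```python
def roc_auc(scores, labels):
    """AUC = P(score of random positive > score of random negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for q in neg:
            wins += 1.0 if p > q else 0.5 if p == q else 0.0
    return wins / (len(pos) * len(neg))


def improvement_markers(val_losses):
    """'++' when validation loss beats the best seen so far, '--' otherwise (assumed meaning)."""
    best = float("inf")
    marks = []
    for loss in val_losses:
        if loss < best:
            best = loss
            marks.append("++")
        else:
            marks.append("--")
    return marks
```

The pairwise loop is quadratic in the number of samples, which is fine for illustration. For 10,000 validation samples, a rank-based formula or `sklearn.metrics.roc_auc_score` is the usual choice.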

In [56]:
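import numpy as np
import torch

classifier.eval()

def classify(theta, simulator, num_experiments=100):
    # Simulate num_experiments independent runs with the same parameters theta
    x = np.vstack([simulator(theta) for _ in range(num_experiments)])
    with torch.no_grad():
        # Add a channel dimension: (num_experiments, 1, sequence_length)
        output = classifier(torch.FloatTensor(x).unsqueeze(1).to("cuda"))
    # Print the fraction of runs with a classifier score above 0.5, i.e. classified as biased
    print(f"p(x|theta=1) = {(output > 0.5).float().mean().item():.3f}")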

Test the performance of the model by classifying some generated samples. For a given $\theta$, `classify` simulates 100 samples and uses the trained classifier to predict whether each one is an example of neutral evolution or not. It prints the fraction of samples classified as biased.
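Because `classify` uses only 100 simulated runs per $\theta$, the printed fraction is itself a noisy estimate. The following standard-library sketch computes the same fraction from a list of classifier scores and attaches an approximate 95% Wilson score interval. The helper names are hypothetical and not part of `../src/utils.py`, and the 0.5 threshold mirrors the one in `classify`.

```python
import math


def fraction_biased(scores, threshold=0.5):
    """Fraction of simulated samples whose classifier score exceeds the threshold."""
    return sum(s > threshold for s in scores) / len(scores)


def wilson_interval(k, n, z=1.96):
    """Approximate 95% Wilson score interval for a binomial proportion k / n."""
    p = k / n
    denom = 1 + z * z / n
    centre = (p + z * z / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom
    return max(0.0, centre - half), min(1.0, centre + half)
```

For 2 biased calls out of 100, the interval is roughly 0.006 to 0.07, so differences of a percentage point or two between settings should not be over-interpreted.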

First, let's test whether the model accurately classifies neutral samples to be neutral:

In [57]:
theta = torch.tensor([0., 0.01, 0.5])
classify(theta, neutral_simulator)

p(x|theta=1) = 0.020


That seems to work: only 2% of the neutral samples are classified as biased. And what about non-neutral, biased samples?

In [58]:
theta = torch.tensor([0.1, 0.01, 0.5])
classify(theta, neutral_simulator)

p(x|theta=1) = 1.000


Also not bad: all 100 biased samples are classified as biased.

Next, we generate samples with the age-structured model and see whether the model (which was only trained on samples _without_ age structure) can still accurately discriminate between neutral and non-neutral samples:

In [24]:
age_structured_simulator = Simulator(
    n_agents=1000, 
    timesteps=1,
    age_window=(0, 2),
    disable_pbar=True,
    summarize=True
)

In [59]:
theta = torch.tensor([0., 0.01, 0.5])
classify(theta, age_structured_simulator)

p(x|theta=1) = 0.010


In [60]:
theta = torch.tensor([0.1, 0.01, 0.5])
classify(theta, age_structured_simulator)

p(x|theta=1) = 0.990


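These preliminary results suggest that age structure does not strongly affect the classifier's performance: 1% of the neutral and 99% of the biased age-structured samples are classified as biased, compared with 2% and 100% without age structure. Let's look at that in some more detail.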
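One way to look in more detail is to sweep the bias parameter $\beta$ and compare both simulators at each value. Here is a minimal sketch. It assumes a hypothetical `classify_fraction(theta, simulator)` that returns the fraction instead of printing it (the notebook's `classify` only prints), and it keeps `mu` and `p_death` fixed at the values used in the cells above.

```python
def sweep_bias(classify_fraction, simulators, betas, mu=0.01, p_death=0.5):
    """Map each simulator name to the fraction classified as biased at every beta."""
    results = {}
    for name, simulator in simulators.items():
        # theta = [beta, mu, p_death], in the same order as the prior above;
        # wrap it in torch.tensor(...) if the simulator expects a tensor
        results[name] = [classify_fraction([beta, mu, p_death], simulator) for beta in betas]
    return results


def max_gap(results, name_a, name_b):
    """Largest absolute difference between two simulators across the beta grid."""
    return max(abs(a - b) for a, b in zip(results[name_a], results[name_b]))
```

Calling `sweep_bias` with `{"neutral": neutral_simulator, "age": age_structured_simulator}` and a grid of small positive and negative $\beta$ values would show where, if anywhere, age structure shifts the classifier's decisions, which single points at $\beta = 0$ and $\beta = 0.1$ cannot reveal.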