# FACS Training Data

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from IPython.display import HTML

from plnn.dataset import get_dataloaders

In [2]:
# Decomposition of the FACS data to use. The value is inserted into the data
# path below and also selects the axis limits of the animations
# ("nmf" vs. anything else).
PCA_OR_NMF = "pca"
# PCA_OR_NMF = "nmf"

# Dataset variant to load and the name of the output subdirectory. The prefix
# ("dec1" or "dec2") determines which transition's data directory is read.
# Uncomment exactly one of the alternatives below to switch variants.
SUBDIR = "dec1"
# SUBDIR = "dec1_fitonsubset"
# SUBDIR = "dec1_varnorm"
# SUBDIR = "dec1_varnorm_fitonsubset"
# SUBDIR = "dec2"
# SUBDIR = "dec2_fitonsubset"

In [3]:
# All files written by this notebook go into a per-variant output directory,
# which is created on demand (no error if it already exists).
OUTDIR = "out/facs_training_data/" + SUBDIR
os.makedirs(OUTDIR, exist_ok=True)

In [4]:
# Root directory of the FACS training data for the chosen decomposition.
basedatdir = f"../data/training_data/facs/{PCA_OR_NMF}"

# The SUBDIR prefix selects the transition index and the matching data
# directory for that transition.
if SUBDIR.startswith("dec1"):
    transition_idx = 1
    datdir = f"{basedatdir}/{SUBDIR}/transition1_subset_epi_tr_ce_an_pc12"
elif SUBDIR.startswith("dec2"):
    transition_idx = 2
    datdir = f"{basedatdir}/{SUBDIR}/transition2_subset_ce_pn_m_pc12"
else:
    # Fail fast with a clear message instead of a NameError further down
    # when `transition_idx` / `datdir` would otherwise be left undefined.
    raise ValueError(
        f"Unrecognized SUBDIR {SUBDIR!r}: expected a name starting with "
        "'dec1' or 'dec2'."
    )

datdir_train = f"{datdir}/training"
datdir_valid = f"{datdir}/validation"

# Each split stores its number of simulations in a small text file.
nsims_train = np.genfromtxt(f"{datdir_train}/nsims.txt", dtype=int)
nsims_valid = np.genfromtxt(f"{datdir_valid}/nsims.txt", dtype=int)

print(f"Found {nsims_train} training simulations.")
print(f"Found {nsims_valid} validation simulations.")

Found 7 training simulations.
Found 4 validation simulations.


In [10]:
# Arguments for building the training/validation loaders and datasets.
# Batches of one item, no shuffling, and a fixed seed are requested here.
# NOTE(review): `ncells_sample` and `length_multiplier` are interpreted by
# plnn.dataset.get_dataloaders; their exact semantics are not visible in this
# notebook -- confirm against the library before changing them.
loader_kwargs = dict(
    datdir_train=datdir_train,
    datdir_valid=datdir_valid,
    nsims_train=nsims_train,
    nsims_valid=nsims_valid,
    batch_size_train=1,
    batch_size_valid=1,
    shuffle_train=False,
    shuffle_valid=False,
    ndims=2,
    return_datasets=True,
    ncells_sample=200,
    length_multiplier=3,
    seed=42,
)

# With return_datasets=True the call returns both loaders and both datasets.
train_dloader, valid_dloader, train_dset, valid_dset = get_dataloaders(
    **loader_kwargs
)

In [11]:
# Report the dataset size, then walk through the whole training dataset once,
# printing the position of every item that could be retrieved.
n_train_items = len(train_dset)
print(n_train_items)
for idx, _sample in enumerate(train_dset):
    print(idx)

126
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125


In [None]:
# Inspect a single training item. The dataset is indexed exactly once so that
# the inputs (t0, x0, t1, sigparams) and the target x1 all belong to the SAME
# item; mixing two different indices would plot unrelated snapshots together.
# NOTE(review): x0 / x1 are presumably the cell positions at t0 / t1 --
# confirm against plnn.dataset.
item_idx = 3
item = train_dset[item_idx]
t0, x0, t1, sigparams = item[0]
x1 = item[1]

print("t0:", t0)
print("t1:", t1)
print("x0.shape:", x0.shape)
print("x1.shape:", x1.shape)

# Scatter the first two coordinates of both snapshots: x0 in black, x1 in red.
fig, ax = plt.subplots()
ax.plot(x0[:, 0], x0[:, 1], '.k', label="x0")
ax.plot(x1[:, 0], x1[:, 1], '.r', label="x1")
ax.set_title(f"Training item {item_idx}")
ax.legend();

In [None]:
# Names of the training conditions, listed in their unshifted index order.
_train_condition_names = [
    "NO CHIR",
    "CHIR 2-3",
    "CHIR 2-4",
    "CHIR 2-5",
    "CHIR 2-5 FGF 2-3",
    "CHIR 2-5 FGF 2-4",
    "CHIR 2-5 FGF 2-5",
]

# Keys start at 0 for transition 1; for any other transition every key is
# shifted down by one, so the first name lands on key -1.
k = 0 if transition_idx == 1 else 1
CONDITION_MAP_TRAIN = {
    position - k: name
    for position, name in enumerate(_train_condition_names)
}

# For transition 2 the shifted first entry ("NO CHIR" at key -1) is dropped,
# leaving keys 0..5.
if transition_idx == 2:
    del CONDITION_MAP_TRAIN[-1]


# Validation conditions, keyed 0..3 regardless of the transition.
CONDITION_MAP_VALID = dict(enumerate([
    "CHIR 2-2.5",
    "CHIR 2-3.5",
    "CHIR 2-5 FGF 2-3.5",
    "CHIR 2-5 FGF 2-4.5",
]))

In [None]:
# Preview the animation of one training condition inline in the notebook.
i = 0

# The same limits are used on both axes; they depend on the decomposition.
if PCA_OR_NMF == 'nmf':
    axis_lims = [-0.1, 1]
else:
    axis_lims = [-8, 8]

ani = train_dset.animate(
    i,
    show=False,
    interval=1000,
    xlims=list(axis_lims),  # separate list objects for x and y
    ylims=list(axis_lims),
    col1='grey',
    col2='k',
    title=CONDITION_MAP_TRAIN[i],
)

HTML(ani)

In [None]:
# Save one animation per condition for both the training and validation data.
# Each split is animated from its OWN dataset: validation animations must be
# rendered from `valid_dset`, not from `train_dset`.

# Identical limits on both axes, chosen by decomposition (hoisted out of the
# loops because they do not change between conditions).
axis_lims = [-0.1, 1] if PCA_OR_NMF == 'nmf' else [-8, 8]

for dset, condition_map, split_name in (
    (train_dset, CONDITION_MAP_TRAIN, "training"),
    (valid_dset, CONDITION_MAP_VALID, "validation"),
):
    for i, cond_name in condition_map.items():
        ani = dset.animate(
            i, show=False, interval=1000,
            xlims=list(axis_lims),  # separate list objects for x and y
            ylims=list(axis_lims),
            col1='grey',
            col2='k',
            title=cond_name,
            fps=1,
            saveas=f"{OUTDIR}/{PCA_OR_NMF}_{split_name}_data_anim_{cond_name}.mp4",
        )

    