In [1]:
import pickle
import numpy as np
import biom
import csv

In [2]:
# Load a previously generated dataset to use as a reference for the expected layout.
# NOTE: pickle.load can execute arbitrary code -- only open files from a trusted source.
with open("oral_medium.p", mode="rb") as fh:
    d = pickle.load(fh)

In [3]:
# List the entries stored in the reference pickle.
d.keys()

dict_keys(['theta', 'Ytrain', 'Ytest', 'Vtrain', 'Vtest'])

In [4]:
# Shape of the reference theta matrix (rows x columns).
d['theta'].shape

(20, 21)

In [5]:
# Element type of the reference theta matrix.
d['theta'].dtype

dtype('float64')

In [6]:
# Number of training trajectories in the reference pickle.
len(d['Ytrain'])

22

In [7]:
# Number of test trajectories in the reference pickle.
len(d['Ytest'])

6

# top 20 taxa + the rest


Note: the remaining taxa (i.e. those ranked 21st, 22nd, ...) are pooled into a single category, referred to here as taxa0.

In [8]:
# Build the 20 x 21 theta matrix in a single, re-runnable cell.
#
# Each tuple (lo, mid, hi) describes one row of theta:
#   columns [lo, mid) are set to +1, columns [mid, hi) are set to -1,
#   and every other column stays 0.
# Presumably each row is one internal split of the phylogenetic tree
# (+1 side vs. -1 side) -- TODO confirm against the tree used for the ordering.
theta_splits = [
    (0, 1, 21),    # row 0
    (1, 4, 21),    # row 1
    (1, 2, 4),     # row 2
    (4, 5, 21),    # row 3
    (2, 3, 4),     # row 4
    (5, 6, 9),     # row 5
    (5, 9, 21),    # row 6
    (6, 7, 9),     # row 7
    (9, 10, 21),   # row 8
    (7, 8, 9),     # row 9
    (10, 13, 21),  # row 10
    (10, 11, 13),  # row 11
    (13, 15, 21),  # row 12
    (11, 12, 13),  # row 13
    (13, 14, 15),  # row 14
    (15, 16, 21),  # row 15
    (16, 18, 21),  # row 16
    (16, 17, 18),  # row 17
    (18, 19, 21),  # row 18
    (19, 20, 21),  # row 19
]

theta = np.zeros((len(theta_splits), 21))
for row, (lo, mid, hi) in enumerate(theta_splits):
    theta[row, lo:mid] = 1
    theta[row, mid:hi] = -1

# Cheap guards against a typo in the table above: the matrix must have one
# row fewer than columns and its rows must be linearly independent.
assert theta.shape == (20, 21)
assert np.linalg.matrix_rank(theta) == theta.shape[0]

In [29]:
# Number of most-abundant taxa to keep as individual columns.
num_taxa = 20

In [30]:
# Load the BIOM count table from the working directory.
table = biom.load_table("all.biom")

In [31]:
# Dense count matrix, transposed so that rows are indexed by `table.ids()`
# and columns by `table.ids(axis="observation")`.
# Uses the public `matrix_data` property rather than the private `_data` attribute.
data = table.matrix_data.toarray().T

In [32]:
# (number of samples, number of taxa) after the transpose.
data.shape

(1050, 6400)

In [33]:
# Labels for the rows (samples) and columns (observations / taxa) of `data`.
sample_ids, taxon_ids = table.ids(axis="sample"), table.ids(axis="observation")

In [34]:
# Rank taxa by their share of the summed per-sample relative abundance
# and keep the column indices of the `num_taxa` highest-ranked ones.
data_norm = data / data.sum(axis=-1, keepdims=True)   # each row sums to 1
data_sum = data_norm.sum(axis=0) / data_norm.sum()    # per-taxon share of the total
sel_idxes = list(np.argsort(-data_sum)[:num_taxa])    # descending by share
sel_taxon_ids = taxon_ids[sel_idxes]

# Reorder the selected indices to match the ordering in the phylogenetic tree

In [35]:
# Show the abundance-ranked column indices before they are reordered by hand.
print("original odering:", sel_idxes)

original odering: [17, 15, 19, 62, 11, 106, 54, 117, 1, 9, 14, 23, 154, 24, 12, 5, 135, 141, 3, 124]


In [36]:
# Hand-specified ordering of the 20 selected column indices (tree order).
tree_order = [154, 54, 124, 3, 135, 19, 15, 23, 62, 12, 117, 5, 14, 24, 106, 11, 141, 1, 17, 9]

# The manual list must contain exactly the indices selected by abundance;
# otherwise the data or `num_taxa` changed and the ordering is stale.
assert sorted(tree_order) == sorted(int(i) for i in sel_idxes), \
    "manual ordering is not a permutation of the computed top-taxa indices"

sel_idxes = tree_order
sel_taxon_ids = taxon_ids[sel_idxes]  # keep the id list in the same order as sel_idxes

In [37]:
# Sanity check: the reordered list should still hold `num_taxa` indices.
len(sel_idxes)

20

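## Note: the first taxon, i.e. taxa 0, is the pooled remainder of the taxa. (Caveat: `get_traj_and_input` below appends the remainder counts as the *last* column of each observation; TODO confirm which convention `theta` assumes.)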

In [38]:
def _pack_trajectory(traj, Input, timestamps, min_len=5):
    """Convert one trajectory's accumulated lists into its final arrays.

    Returns ``None`` when the trajectory holds fewer than ``min_len`` samples.
    Otherwise returns ``(traj_arr, input_arr)``:

    * ``traj_arr``  -- column 0 is the time index relative to the first sample,
      followed by the count columns of each sample.
    * ``input_arr`` -- column 0 is the relative time index of sample ``k``,
      followed by the [drink, eat, sleep] flags recorded with sample ``k + 1``
      (so it has one row fewer than ``traj_arr``).
    """
    if len(traj) < min_len:
        return None
    ts = np.array(timestamps)
    ts = ts - ts[0]  # re-base so every trajectory starts at time index 0
    traj_arr = np.concatenate([ts[:, None], np.array(traj)], axis=-1)
    input_arr = np.concatenate([ts[:-1, None], np.array(Input)[1:, :]], axis=-1)
    return traj_arr, input_arr


def get_traj_and_input(data, sel_idxes, ids=None, map_path="map.onlymedtime.txt"):
    """Extract per-host saliva trajectories and their inputs from the mapping file.

    The tab-separated mapping file is read in file order. Rows whose sample id
    is not in ``ids`` or whose "body site" is not "saliva" are skipped. Sample
    times ("conttime") are binned into 15-unit time indices. A new trajectory
    starts when the host id changes or when the gap to the previous kept sample
    of the same host exceeds two hours' worth of bins. A sample that falls in
    the same time bin as the previous kept sample of the same host is ignored.
    Trajectories with fewer than 5 samples are discarded.

    Args:
        data: 2-D count array; rows are looked up through ``ids``, columns are taxa.
        sel_idxes: column indices of the taxa kept as individual columns.
        ids: sample ids labelling the rows of ``data``. Defaults to the
            notebook-level ``sample_ids``.
        map_path: path of the tab-separated mapping file.

    Returns:
        ``(trajs, inputs)`` -- two equally long lists of arrays.
        ``trajs[k]`` has columns [relative time index, selected taxa counts...,
        pooled count of all other taxa]; ``inputs[k]`` has columns
        [relative time index, drink, eat, sleep] and one row fewer.
    """
    if ids is None:
        ids = sample_ids  # notebook-level array obtained from table.ids()

    # sample id -> every row of `data` carrying that id (expected to be exactly one)
    rows_by_id = {}
    for row, sid in enumerate(ids):
        rows_by_id.setdefault(sid, []).append(row)

    sample_interval = 15                  # width of one time bin, in "conttime" units
    max_gap = 2 * 60 / sample_interval    # bins; a larger gap splits the trajectory

    trajs, inputs = [], []
    traj, Input, timestamps = [], [], []
    current_host_id, prev_time_idx = None, None

    with open(map_path) as tsv:
        for line in csv.DictReader(tsv, delimiter="\t"):
            sampled_id = line["#SampleID"]
            host_id = line["hostid"]
            conttime = int(line["conttime"])

            rows = rows_by_id.get(sampled_id)
            if rows is None:
                continue
            if line["body site"] != "saliva":
                continue

            time_idx = conttime // sample_interval

            # A host change always starts a new trajectory. The duplicate/gap
            # checks compare against the previous kept sample, so they only
            # make sense within one host.
            start_new = host_id != current_host_id
            if not start_new and prev_time_idx is not None:
                if time_idx == prev_time_idx:
                    continue  # second sample in the same time bin: ignore it
                # NOTE(review): the truthiness test on line["hour"] is kept from the
                # original; presumably it exempts rows without an hour -- TODO confirm.
                if time_idx - prev_time_idx > max_gap and line["hour"]:
                    start_new = True

            if start_new:
                packed = _pack_trajectory(traj, Input, timestamps)
                if packed is not None:
                    trajs.append(packed[0])
                    inputs.append(packed[1])
                current_host_id = host_id
                traj, Input, timestamps = [], [], []

            # The current sample is the reference for the next row's checks,
            # including when it is the first sample of a fresh trajectory.
            prev_time_idx = time_idx

            # observations (counts)
            assert len(rows) == 1
            row_idx = rows[0]
            sel_taxon_counts = data[row_idx][sel_idxes]                    # (num_taxa,)
            remain_counts = data[row_idx].sum() - sel_taxon_counts.sum()   # scalar
            # NOTE(review): the pooled remainder is appended as the LAST column,
            # while the notebook's markdown calls it taxon 0 -- confirm against theta.
            traj.append(np.concatenate([sel_taxon_counts, [remain_counts]]))

            # inputs [drink, eat, sleep]: True when the column is non-empty
            Input.append([line["drink"] != "", line["eat"] != "", line["sleep"] != ""])

            timestamps.append(time_idx)

    # Flush the last trajectory with the same rules as all the others.
    packed = _pack_trajectory(traj, Input, timestamps)
    if packed is not None:
        trajs.append(packed[0])
        inputs.append(packed[1])

    return trajs, inputs

In [39]:
def save_data(trajs, inputs, seed=None, max_tries=100000):
    """Split trajectories into train/test sets and pickle them together with theta.

    80% of the trajectories (rounded down) go to the training set. The random
    assignment is redrawn until the training set holds between 75% and 80% of
    all timestamps. The result is written to ``oral_<Dy>_taxa.p`` in the
    working directory, where ``Dy`` is the number of columns of a trajectory
    minus the leading time column. The notebook-level ``theta`` is stored
    under the key ``'theta'``.

    Args:
        trajs: list of 2-D trajectory arrays (first column is time).
        inputs: list of input arrays, aligned with ``trajs``.
        seed: optional integer seed for a reproducible split. ``None`` keeps
            using NumPy's global random state.
        max_tries: maximum number of shuffles before giving up.

    Raises:
        ValueError: if ``trajs`` is empty.
        RuntimeError: if no shuffle meets the 75%-80% criterion within
            ``max_tries`` attempts (instead of looping forever).
    """
    if len(trajs) == 0:
        raise ValueError("save_data needs at least one trajectory")

    rng = np.random if seed is None else np.random.RandomState(seed)
    idx = list(range(len(trajs)))
    n_train = int(len(trajs) * 0.8)
    total_num = np.sum([len(traj) for traj in trajs])

    for _ in range(max_tries):
        rng.shuffle(idx)
        train_idx = idx[:n_train]
        test_idx = idx[n_train:]
        train_num = np.sum([len(trajs[i]) for i in train_idx])
        test_num = np.sum([len(trajs[i]) for i in test_idx])
        if 0.75 <= train_num / total_num <= 0.8:
            break
    else:
        raise RuntimeError(
            "no train/test split with 75%-80% of the timestamps in the training "
            "set was found after {} shuffles".format(max_tries)
        )

    c_data = {}
    c_data['theta'] = theta
    c_data['Ytrain'] = [trajs[i] for i in train_idx]  # count
    c_data['Ytest'] = [trajs[i] for i in test_idx]
    c_data['Vtrain'] = [inputs[i] for i in train_idx]
    c_data['Vtest'] = [inputs[i] for i in test_idx]

    Dy = len(trajs[0][0]) - 1  # number of count columns (time column excluded)
    with open("oral_{}_taxa.p".format(Dy), "wb") as f:
        pickle.dump(c_data, f)

    # stats
    print("Num of trajs", len(trajs))
    print("Num of total timestamps", total_num)
    print("Num of train timestamps", train_num)
    print("Num of test timestamps", test_num)

In [40]:
# Build trajectories using the reordered 20 taxa columns plus the pooled remainder.
trajs, inputs = get_traj_and_input(data, sel_idxes)

In [41]:
# Number of input arrays; should equal the number of trajectories.
len(inputs)

28

In [42]:
# Number of extracted trajectories.
len(trajs)

28

In [43]:
# Split and pickle using the current notebook-level `theta`. With 20 selected
# taxa each sample has 21 count columns, so the file is named oral_21_taxa.p.
save_data(trajs, inputs)

Num of trajs 28
Num of total timestamps 817
Num of train timestamps 622
Num of test timestamps 195


# TOP 15 taxa

In [44]:
# Build the 15 x 16 theta matrix from a table of splits.
#
# Each tuple (lo, mid, hi) describes one row of theta:
#   columns [lo, mid) are set to +1, columns [mid, hi) are set to -1,
#   and every other column stays 0.
# Presumably each row is one internal split of the phylogenetic tree
# (+1 side vs. -1 side) -- TODO confirm against the tree used for the ordering.
theta_splits = [
    (0, 1, 16),    # row 0
    (1, 3, 16),    # row 1
    (1, 2, 3),     # row 2
    (3, 6, 16),    # row 3
    (3, 4, 6),     # row 4
    (6, 7, 16),    # row 5
    (4, 5, 6),     # row 6
    (7, 9, 16),    # row 7
    (7, 8, 9),     # row 8
    (9, 11, 16),   # row 9
    (9, 10, 11),   # row 10
    (11, 12, 16),  # row 11
    (12, 13, 16),  # row 12
    (13, 14, 16),  # row 13
    (14, 15, 16),  # row 14
]

theta = np.zeros((len(theta_splits), 16))
for row, (lo, mid, hi) in enumerate(theta_splits):
    theta[row, lo:mid] = 1
    theta[row, mid:hi] = -1

# Cheap guards against a typo in the table above: the matrix must have one
# row fewer than columns and its rows must be linearly independent.
assert theta.shape == (15, 16)
assert np.linalg.matrix_rank(theta) == theta.shape[0]

In [46]:
# Number of most-abundant taxa to keep as individual columns in this section.
num_taxa = 15

In [47]:
# Rank taxa by their share of the summed per-sample relative abundance
# and keep the column indices of the `num_taxa` highest-ranked ones.
data_norm = data / data.sum(axis=-1, keepdims=True)   # each row sums to 1
data_sum = data_norm.sum(axis=0) / data_norm.sum()    # per-taxon share of the total
sel_idxes = list(np.argsort(-data_sum)[:num_taxa])    # descending by share
sel_taxon_ids = taxon_ids[sel_idxes]

# Reorder the selected indices to match the ordering in the phylogenetic tree

In [48]:
# Show the abundance-ranked column indices before they are reordered by hand.
print("original odering:", sel_idxes)

original odering: [17, 15, 19, 62, 11, 106, 54, 117, 1, 9, 14, 23, 154, 24, 12]


In [49]:
# Number of indices selected by abundance; expected to equal `num_taxa`.
len(sel_idxes)

15

In [50]:
# Hand-specified ordering of the 15 selected column indices (tree order).
tree_order = [54, 154, 19, 15, 23, 62, 117, 12, 14, 24, 106, 11, 1, 17, 9]

# The manual list must contain exactly the indices selected by abundance;
# otherwise the data or `num_taxa` changed and the ordering is stale.
assert sorted(tree_order) == sorted(int(i) for i in sel_idxes), \
    "manual ordering is not a permutation of the computed top-taxa indices"

sel_idxes = tree_order
sel_taxon_ids = taxon_ids[sel_idxes]  # keep the id list in the same order as sel_idxes

In [51]:
# Sanity check: the reordered list should still hold `num_taxa` indices.
len(sel_idxes)

15

In [52]:
# Rebuild the trajectories for the 15-taxa selection and pickle them with the
# current notebook-level `theta`. Each sample has 16 count columns, so the
# file is named oral_16_taxa.p.
trajs, inputs = get_traj_and_input(data, sel_idxes)
save_data(trajs, inputs)

Num of trajs 28
Num of total timestamps 817
Num of train timestamps 614
Num of test timestamps 203


# TOP 10 taxa

In [53]:
# Build the 10 x 11 theta matrix from a table of splits.
#
# Each tuple (lo, mid, hi) describes one row of theta:
#   columns [lo, mid) are set to +1, columns [mid, hi) are set to -1,
#   and every other column stays 0.
# Presumably each row is one internal split of the phylogenetic tree
# (+1 side vs. -1 side) -- TODO confirm against the tree used for the ordering.
theta_splits = [
    (0, 1, 11),    # row 0
    (1, 2, 11),    # row 1
    (2, 4, 11),    # row 2
    (2, 3, 4),     # row 3
    (4, 5, 11),    # row 4
    (5, 6, 11),    # row 5
    (6, 8, 11),    # row 6
    (6, 7, 8),     # row 7
    (8, 9, 11),    # row 8
    (9, 10, 11),   # row 9
]

theta = np.zeros((len(theta_splits), 11))
for row, (lo, mid, hi) in enumerate(theta_splits):
    theta[row, lo:mid] = 1
    theta[row, mid:hi] = -1

# Cheap guards against a typo in the table above: the matrix must have one
# row fewer than columns and its rows must be linearly independent.
assert theta.shape == (10, 11)
assert np.linalg.matrix_rank(theta) == theta.shape[0]

In [55]:
# Number of most-abundant taxa to keep as individual columns in this section.
num_taxa = 10

In [56]:
# Rank taxa by their share of the summed per-sample relative abundance
# and keep the column indices of the `num_taxa` highest-ranked ones.
data_norm = data / data.sum(axis=-1, keepdims=True)   # each row sums to 1
data_sum = data_norm.sum(axis=0) / data_norm.sum()    # per-taxon share of the total
sel_idxes = list(np.argsort(-data_sum)[:num_taxa])    # descending by share
sel_taxon_ids = taxon_ids[sel_idxes]

# Reorder the selected indices to match the ordering in the phylogenetic tree

In [57]:
# Show the abundance-ranked column indices before they are reordered by hand.
print("original odering:", sel_idxes)

original odering: [17, 15, 19, 62, 11, 106, 54, 117, 1, 9]


In [58]:
# Number of indices selected by abundance; expected to equal `num_taxa`.
len(sel_idxes)

10

In [59]:
# Hand-specified ordering of the 10 selected column indices (tree order).
tree_order = [54, 15, 19, 117, 62, 11, 106, 9, 17, 1]

# The manual list must contain exactly the indices selected by abundance;
# otherwise the data or `num_taxa` changed and the ordering is stale.
assert sorted(tree_order) == sorted(int(i) for i in sel_idxes), \
    "manual ordering is not a permutation of the computed top-taxa indices"

sel_idxes = tree_order
sel_taxon_ids = taxon_ids[sel_idxes]  # keep the id list in the same order as sel_idxes

In [60]:
# Sanity check: the reordered list should still hold `num_taxa` indices.
len(sel_idxes)

10

In [61]:
# Rebuild the trajectories for the 10-taxa selection and pickle them with the
# current notebook-level `theta`. Each sample has 11 count columns, so the
# file is named oral_11_taxa.p.
trajs, inputs = get_traj_and_input(data, sel_idxes)
save_data(trajs, inputs)

Num of trajs 28
Num of total timestamps 817
Num of train timestamps 635
Num of test timestamps 182
