In [1]:
import typing as ty
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import FrEIA.framework as ff
import FrEIA.modules as fm
# Global plotting defaults: seaborn's default theme, with text rendered through
# LaTeX in a Helvetica sans-serif face.
sns.set()
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = ["Helvetica"]

In [2]:
data_path = "./data/microbiom"

# Physical line of readcounts.csv (0-indexed, the header is line 0) that is
# skipped during parsing and repaired by hand below. The repair treats the
# first two comma-separated fields as a single taxon name, i.e. the name itself
# appears to contain a comma -- TODO confirm against the raw file.
bad_line_no = 4774

# Load every record except the malformed one. At this point rows are taxa and
# columns are samples.
counts = (
    pd.read_csv(f"{data_path}/readcounts.csv", skiprows=[bad_line_no])
    .rename(columns={"Unnamed: 0": "Name"})
    .set_index("Name")
)

# Parse the skipped line manually: fields 0-1 form the name, the rest are the
# integer read counts (int() tolerates the trailing newline on the last field).
with open(f"{data_path}/readcounts.csv", "r") as f:
    content = f.readlines()
error_line = content[bad_line_no].split(',')
idx_name = ''.join(error_line[:2])
error_line = [int(x) for x in error_line[2:]]
error_line = {k: {idx_name: v} for k, v in zip(counts.columns, error_line)}
error_line = pd.DataFrame(error_line, index=pd.Series([idx_name], name='Name'))

# Put the repaired record back at its original position. Because the header
# occupies physical line 0, physical line `bad_line_no` is data row
# `bad_line_no - 1`; slicing at `bad_line_no` would insert it one row too late.
insert_at = bad_line_no - 1
counts = pd.concat([counts.iloc[:insert_at], error_line, counts.iloc[insert_at:]])

# Final layout: one row per sample, one column per taxon.
counts = counts.transpose()
counts

Name,Candidatus_Korarchaeota (Archaea),Acidilobus_saccharovorans (Archaea),Caldisphaera_lagunensis (Archaea),Desulfurococcus_mucosus (Archaea),Ignicoccus_hospitalis (Archaea),Pyrodictium_delaneyi (Archaea),Pyrodictium_occultum (Archaea),Pyrolobus_fumarii (Archaea),Acidianus_hospitalis (Archaea),Sulfolobus_tokodaii (Archaea),...,Cardiovirus_B (Viruses),ssRNA_positive-strand_viruses_no_DNA_stage (Viruses),Senecavirus_A (Viruses),Rice_tungro_spherical_virus (Viruses),unclassified_bacterial_viruses (Viruses),Enterobacteria_phage_YYZ-2008 (Viruses),Salmonella_phage_118970_sal3 (Viruses),Streptococcus_phage_20617 (Viruses),Streptococcus_phage_phiARI0131-2 (Viruses),Torulaspora_delbrueckii_dsRNA_Mbarr-1_killer_virus (Viruses)
Sample5854,0,0,0,1,0,2,1,0,1,0,...,0,1,3,0,0,2,2,0,0,1
Sample691,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
Sample80,0,0,0,0,0,0,0,0,0,0,...,1,1,1,0,0,0,2,0,0,0
Sample1717,0,0,0,0,0,1,0,0,0,0,...,0,1,1,1,0,0,0,1,0,0
Sample5350,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Sample4090,0,0,0,1,0,0,1,0,0,0,...,0,0,0,0,0,1,1,0,0,0
Sample5937,0,0,0,0,0,0,0,0,0,0,...,1,0,3,0,0,0,0,0,0,0
Sample3006,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
Sample3324,0,0,0,0,0,0,0,0,0,0,...,0,0,1,1,0,2,1,0,0,0


In [5]:
# Load the phenotype table and index it by sample name (the CSV's unnamed
# first column).
phenodata = (
    pd.read_csv(f"{data_path}/phenodata.csv")
    .rename(columns={"Unnamed: 0": "Sample"})
    .set_index("Sample")
)

# Sorted, de-duplicated names of all samples that have a missing value in at
# least one phenotype column.
rows_with_nan = phenodata.isna().any(axis=1)
nan_samples = np.unique(phenodata.index[rows_with_nan]).tolist()

In [6]:
# Sorted, de-duplicated names of samples with a missing value in any phenotype
# column.
nan_samples = np.unique(phenodata.index[phenodata.isna().any(axis=1)]).tolist()

# Drop those samples from both tables. Each frame is filtered by its OWN index:
# a mask built from `phenodata.index` is purely positional when applied to
# `counts`, so it would silently drop the wrong rows whenever the two tables
# are not in identical row order (and fail outright if their lengths differ).
counts = counts[~counts.index.isin(nan_samples)]
phenodata = phenodata[~phenodata.index.isin(nan_samples)]

In [7]:
class MicrobiomDataset:
    """Map-style dataset pairing a row of ``x`` with the matching row of ``y``.

    Rows of ``x`` are addressed by position; the corresponding row of ``y`` is
    looked up by the *label* of the selected ``x`` row, so ``y`` may contain
    more rows than ``x`` and may be in a different order.

    Args:
        x: Feature table, one row per sample.
        y: Condition table whose index contains every row label of ``x``.
    """

    def __init__(self, x: pd.DataFrame, y: pd.DataFrame) -> None:
        self.x = x
        self.y = y

    def __len__(self) -> int:
        """Number of rows in the feature table."""
        return len(self.x)

    def __getitem__(self, i: int) -> ty.Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features, conditions)`` for the ``i``-th row of ``x``.

        Both tensors are ``float32``.

        Raises:
            IndexError: if ``i`` is out of range for ``x``.
            KeyError: if the selected row's label is missing from ``y``.
        """
        x = self.x.iloc[i]
        y = self.y.loc[x.name]
        # Convert through NumPy explicitly. Handing a pandas Series straight to
        # torch.tensor makes torch treat it as a generic Python sequence and
        # index it with the integer 0, which pandas resolves by *label* first
        # (deprecated positional fallback; KeyError for integer labels that do
        # not include 0).
        return (
            torch.tensor(x.to_numpy(dtype="float32")),
            torch.tensor(y.to_numpy(dtype="float32")),
        )

In [8]:
# --- Model dimensions (derived from the data) ---
input_dim = len(counts.columns)    # width of one row of `counts`
cond_dim = len(phenodata.columns)  # width of one row of `phenodata`

# --- Architecture ---
n_blocks = 10        # number of permute + coupling pairs in the network
init_scale = 1e-3    # multiplier applied to the random initial subnet weights

# --- Optimisation ---
input_noise = 1e-1   # scale of the random noise added to each training batch
lr = 0.0001
batch_size = 128
n_epochs = 1_000

In [9]:
def subnet_fc(dims_in: int, dims_out: int) -> nn.Sequential:
    """Build the fully connected subnetwork used inside a coupling block.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` with two
    hidden layers of width 128. Every linear layer is re-initialised with
    Gaussian values scaled by the module-level ``init_scale``; the last layer
    is then zeroed, so a freshly built subnet outputs zeros for any input.

    Args:
        dims_in: Number of input features.
        dims_out: Number of output features.

    Returns:
        The initialised ``nn.Sequential``.
    """
    hidden = 128
    subnet = nn.Sequential(
        nn.Linear(dims_in, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, dims_out),
    )

    linear_layers = [layer for layer in subnet if isinstance(layer, nn.Linear)]

    # Small random initialisation: weight first, then bias, layer by layer.
    for layer in linear_layers:
        layer.weight.data = init_scale * torch.randn(layer.weight.shape)
        layer.bias.data = init_scale * torch.randn(layer.bias.shape)

    # Zero the output layer in place.
    output_layer = linear_layers[-1]
    output_layer.weight.data.fill_(0.)
    output_layer.bias.data.fill_(0.)
    return subnet

# Assemble the conditional invertible network graph: an input node, then
# `n_blocks` repetitions of (seeded random permutation -> GLOW coupling block
# that receives the condition node), then the output node.
nodes = [ff.InputNode(input_dim, name="inp")]
cond = ff.ConditionNode(cond_dim, name="cond")

for block_id in range(1, n_blocks + 1):
    permute = ff.Node(
        [nodes[-1].out0],
        fm.PermuteRandom,
        {"seed": block_id},
        name=f"permute_{block_id}",
    )
    coupling = ff.Node(
        [permute.out0],
        fm.GLOWCouplingBlock,
        {"clamp": 2.0, "subnet_constructor": subnet_fc},
        conditions=[cond],
        name=f"coupling_{block_id}",
    )
    nodes += [permute, coupling]

nodes += [ff.OutputNode([nodes[-1].out0], name="out"), cond]

# Build the network on the GPU and attach an Adam optimizer to its parameters.
inn = ff.ReversibleGraphNet(nodes, verbose=False).cuda()
optim = torch.optim.Adam(inn.parameters(), lr=lr)

In [10]:
# Random 80/20 split of the samples into disjoint train and test sets.
# Both slices must share the same cut point: slicing the test set from the 20%
# mark would make it overlap the training set on 60% of the samples.
n_train = int(len(counts) * 0.8)
idx = torch.randperm(len(counts))
train_idx, test_idx = idx[:n_train], idx[n_train:]
train_counts, test_counts = counts.iloc[train_idx.tolist()], counts.iloc[test_idx.tolist()]

# Each dataset looks up its phenotype rows by sample name, so the full
# phenodata table can be shared by both.
trainset = MicrobiomDataset(train_counts, phenodata)
testset = MicrobiomDataset(test_counts, phenodata)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=16)

In [12]:
# Train the network for `n_epochs` passes over the training loader.
for e in tqdm(range(n_epochs)):
    epoch_losses = []
    for x, y in tqdm(train_loader, leave=False):
        x = x.cuda()
        y = y.cuda()

        # Add non-negative noise to the batch in place.
        x.add_((input_noise * torch.randn_like(x)).abs())

        optim.zero_grad()
        z, jac = inn(x, y)
        # Per-sample loss: half the squared norm of z minus `jac`
        # (assumes `jac` is the log-determinant of the Jacobian -- TODO confirm).
        nll = 0.5 * z.pow(2).sum(dim=1) - jac
        # Batch mean, normalised by the number of input dimensions.
        l_fwd = nll.mean() / x.shape[1]
        l_fwd.backward()
        optim.step()

        epoch_losses.append(l_fwd.item())

    # Report the mean batch loss every tenth epoch.
    if not e % 10:
        print(f"{e}: {np.mean(epoch_losses)}")

  0%|          | 0/18 [00:00<?, ?it/s]

0: 730345.4192165799


  0%|          | 0/18 [00:00<?, ?it/s]

1: 2677.797821044922


  0%|          | 0/18 [00:00<?, ?it/s]

2: 477.26744503445093


  0%|          | 0/18 [00:00<?, ?it/s]

3: 105.91987122429742


  0%|          | 0/18 [00:00<?, ?it/s]

4: 29.732804934183758


  0%|          | 0/18 [00:00<?, ?it/s]

5: 10.272218147913614


  0%|          | 0/18 [00:00<?, ?it/s]

6: 4.5491732358932495


  0%|          | 0/18 [00:00<?, ?it/s]

7: 2.501934621069166


  0%|          | 0/18 [00:00<?, ?it/s]

8: 1.6214497023158603


  0%|          | 0/18 [00:00<?, ?it/s]

9: 1.1673606832822163


  0%|          | 0/18 [00:00<?, ?it/s]

10: 0.8890213469664255


  0%|          | 0/18 [00:00<?, ?it/s]

11: 0.704665376080407


  0%|          | 0/18 [00:00<?, ?it/s]

12: 0.5965918021069633


  0%|          | 0/18 [00:00<?, ?it/s]

13: 0.5344894114467833


  0%|          | 0/18 [00:00<?, ?it/s]

14: 0.49204082290331524


  0%|          | 0/18 [00:00<?, ?it/s]

15: 0.46212585601541734


  0%|          | 0/18 [00:00<?, ?it/s]

16: 0.44273560908105636


  0%|          | 0/18 [00:00<?, ?it/s]

17: 0.4271456135643853


  0%|          | 0/18 [00:00<?, ?it/s]

18: 0.4152983609173033


  0%|          | 0/18 [00:00<?, ?it/s]

19: 0.40901060071256423


  0%|          | 0/18 [00:00<?, ?it/s]

20: 0.401022066672643


  0%|          | 0/18 [00:00<?, ?it/s]

21: 0.39929226371977067


  0%|          | 0/18 [00:00<?, ?it/s]

22: 0.39444854855537415


  0%|          | 0/18 [00:00<?, ?it/s]

23: 0.39158908194965786


  0%|          | 0/18 [00:00<?, ?it/s]

24: 0.3880762772427665


  0%|          | 0/18 [00:00<?, ?it/s]

25: 0.384730178448889


  0%|          | 0/18 [00:00<?, ?it/s]

26: 0.381675210263994


  0%|          | 0/18 [00:00<?, ?it/s]

27: 0.38110345436467064


  0%|          | 0/18 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [20]:
# Round-trip check on the first test batch only: run the network forward, feed
# the result back through it in reverse, and compare with the input.
with torch.no_grad():
    for x, y in test_loader:
        x = x.cuda()
        y = y.cuda()
        z, jac = inn(x, y)
        rev, jac = inn(z, y, rev=True)
        break
# Mean squared difference between the batch and the integer-truncated
# reconstruction.
torch.mean((x - rev.int()).pow(2))

tensor(0.2885, device='cuda:0')

In [21]:
# Display the tensor currently bound to `x` (presumably the test batch moved to
# the GPU by the preceding cell -- this depends on cell execution order).
x

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')