In [1]:
import numpy as np
import pandas as pd
from pandas.core.frame import DataFrame
from typing import List
from collections import OrderedDict
import random

import torch.nn as nn
import torch.nn.functional as F
import torch

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torch.autograd as autograd
from torch.autograd import Variable

from annoy import AnnoyIndex

In [2]:
# ---- Configuration (encoder/decoder stage; the GAN stage sets its own) ----
key = 'batch'            # not referenced again in the visible notebook
n_epochs = 150           # epochs for the encoder/decoder and for each GAN
num_workers = 0          # DataLoader worker processes
lr, b1, b2 = 0.0005, 0.5, 0.999   # Adam learning rate and betas
latent_dim = 256         # width of the content embedding
n_critic = 5             # NOTE: train() below uses its own local n_critic
lambda_co, lambda_rc = 3, 1       # weights of content-consistency / reconstruction losses
seed = 8

In [3]:
def setup_seed(seed):
    """Seed every RNG this notebook uses so Restart & Run All is reproducible.

    Seeds Python's ``random`` (used by ``random.choice`` during MNN sampling),
    NumPy's global RNG, and torch on the CPU and all CUDA devices. It also asks
    cuDNN for deterministic kernels.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
# Seed all RNGs once, before any model construction or data sampling.
setup_seed(seed)

In [4]:
class ScDataset(Dataset):
    """Mixup sampler for the encoder/decoder stage.

    Each item is a list holding one sample per matrix in ``self.dataset``.
    Each sample is a random convex combination of two randomly chosen rows of
    that matrix. The epoch length is fixed at 10 * 1024 items and does not
    depend on the data sizes; ``index`` is ignored.
    """

    def __init__(self):
        self.dataset = []   # list of 2-D arrays (in this notebook: one per batch)
        self.trees = []
        self.variable = None
        self.labels = None
        self.transform = None
        self.sample = None

    def __len__(self):
        return 10240

    def __getitem__(self, index):
        return [self._mix_two_rows(rows) for rows in self.dataset]

    @staticmethod
    def _mix_two_rows(rows):
        """Blend two uniformly drawn rows of ``rows`` with a U(0, 1) weight."""
        first = np.random.randint(len(rows))
        second = np.random.randint(len(rows))
        weight = np.random.uniform(0, 1)
        return weight*rows[first] + (1-weight)*rows[second]

In [5]:
class Encoder(nn.Module):
    """MLP mapping a (batch, data_size) feature matrix to a ``latent_dim`` code.

    Reads the module-level ``data_size`` and ``Mish`` at construction time, so
    both must be defined before an Encoder is instantiated.
    """

    def __init__(self, latent_dim):
        super().__init__()
        layers = [
            nn.Linear(data_size, 1024), nn.BatchNorm1d(1024), Mish(),
            nn.Linear(1024, 512), nn.BatchNorm1d(512), Mish(),
            nn.Linear(512, latent_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        code = self.encoder(x)
        return code

class Decoder(nn.Module):
    """Decode a content code ``ec`` together with a batch code ``es``.

    Output is ``ReLU(decoder(cat(ec, es)) + decoder2(es))``, so it is
    non-negative. ``latent_dim`` is the width of the concatenated input
    (content + batch code). Reads the module-level ``data_size``,
    ``n_classes`` and ``Mish`` at construction time.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.decoder = self._branch(latent_dim)
        self.decoder2 = self._branch(n_classes)

    @staticmethod
    def _branch(in_features):
        """Three-layer Mish MLP from ``in_features`` to ``data_size``."""
        return nn.Sequential(
            nn.Linear(in_features, 512), Mish(),
            nn.Linear(512, 1024), Mish(),
            nn.Linear(1024, data_size),
        )

    def forward(self, ec, es):
        joint = torch.cat((ec, es), dim=-1)
        combined = self.decoder(joint) + self.decoder2(es)
        return self.relu(combined)

In [6]:
class Mish(nn.Module):
    """Parameter-free Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        return x * F.softplus(x).tanh()

In [7]:
def weights_init_normal(m):
    """Module initialiser intended for ``model.apply``.

    Modules whose class name contains "Linear" get weights ~ N(0, 0.02); their
    biases are left untouched. Modules whose class name contains "BatchNorm"
    get weights ~ N(1, 0.02) and zero biases. All other modules are skipped.
    """
    kind = type(m).__name__
    if "Linear" in kind:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in kind:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [8]:
def cat_data(data_A: np.ndarray, data_B: np.ndarray, labels: List[np.ndarray]=None):
    """Stack two sample matrices row-wise and build the matching labels.

    Args:
        data_A: first block of rows.
        data_B: second block of rows, appended after ``data_A``.
        labels: optional pair ``[labels_A, labels_B]``. When given, the two
            arrays are stacked the same way and returned as the labels.

    Returns:
        ``(data, label)``. ``data`` is ``data_A`` followed by ``data_B``.
        Without ``labels``, ``label`` is an ``(n_A + n_B, 1)`` float column:
        0 for rows from ``data_A`` and 1 for rows from ``data_B``.
    """
    data = np.r_[data_A, data_B]
    if labels is None:
        label = np.zeros((len(data_A) + len(data_B), 1))
        # Index from the start of B: `label[-len(data_B):]` would mark every
        # row when data_B is empty, because `-0:` slices the whole array.
        label[len(data_A):] = 1
    else:
        label = np.r_[labels[0], labels[1]]
    return data, label

In [9]:
# Relative-abundance table, transposed so rows are samples and columns are taxa.
# NOTE(review): the name `train` is later rebound to the train() function, so
# cells that use this DataFrame fail if re-run after that definition.
train = pd.read_csv('./4347_final_relative_abundances.txt',sep='\t',index_col=0).T
train

Unnamed: 0,s__Abiotrophia_defectiva,s__Acetobacter_unclassified,s__Achromobacter_piechaudii,s__Achromobacter_unclassified,s__Achromobacter_xylosoxidans,s__Acidaminococcus_fermentans,s__Acidaminococcus_intestini,s__Acidaminococcus_sp_BV3L6,s__Acidaminococcus_sp_D21,s__Acidaminococcus_sp_HPA0509,...,s__Vibrio_furnissii,s__Vibrio_kanaloae,s__Weissella_cibaria,s__Weissella_confusa,s__Weissella_halotolerans,s__Weissella_koreensis,s__Weissella_paramesenteroides,s__Weissella_unclassified,s__Wohlfahrtiimonas_chitiniclastica,s__Yersinia_enterocolitica
ACVD_1,0.01142,0.0,0.0,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.31613,0.11288,0.0,0.0,0.0,0.0,0.0,0.0
ACVD_2,0.00000,0.0,0.0,0.0,0.0,0.06002,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.0,0.0,0.0
ACVD_3,0.00088,0.0,0.0,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.0,0.0,0.0
ACVD_4,0.00000,0.0,0.0,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00058,0.0,0.0,0.0,0.0,0.0,0.0
ACVD_5,0.00000,0.0,0.0,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Underweight_4343,0.00031,0.0,0.0,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.0,0.0,0.0
Underweight_4344,0.00000,0.0,0.0,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.0,0.0,0.0
Underweight_4345,0.00000,0.0,0.0,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.0,0.0,0.0
Underweight_4346,0.00000,0.0,0.0,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.0,0.0,0.0


In [10]:
# Per-sample metadata, transposed so rows are samples.
# low_memory=False makes pandas infer each column's dtype from the whole file
# rather than chunk by chunk. This avoids the mixed-type DtypeWarning that the
# residual output of this cell appears to come from.
metadata = pd.read_csv('./Final_metadata_4347.csv', index_col=0, low_memory=False).T

  interactivity=interactivity, compiler=compiler, result=result)


In [11]:
# Inspect the transposed metadata (one row per sample).
metadata

study,study.1,Study No. (From VG sheet (V-*) from SB sheet (S-*)),Title of Paper,Author (year),Journal,Study Accession,Sample Accession or Sample ID,Sample title (ENA/SRA),Sample title (Paper),Subject Id (If available),...,s__Subdoligranulum_variabile,s__Succinatimonas_hippei,s__Sutterella_wadsworthensis,s__Turicibacter_sanguinis,s__Varibaculum_cambriense,s__Veillonella_atypica,s__Veillonella_dispar,s__Veillonella_parvula,s__Weissella_cibaria,s__Weissella_confusa
V-2_ACVD,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142287,ZSL-004,ZSL-004,ZSL-004,...,0.000501901,0,0.0313742,0,0,0.00342788,0.00199693,0.134307,0.337587,0.120542
V-2_ACVD.1,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142288,ZSL-007,ZSL-007,ZSL-007,...,0.00354857,0,0,0,0,4.50212,0.122268,1.75344,0,0
V-2_ACVD.2,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142293,ZSL-010,ZSL-010,ZSL-010,...,0,0,0,0,0,0,0,0.00401139,0,0
V-2_ACVD.3,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142291,ZSL-011,ZSL-011,ZSL-011,...,0,0,0,0,0,0.00139602,0.00176556,0.145084,0,0.000595363
V-2_ACVD.4,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142284,ZSL-019,ZSL-019,ZSL-019,...,0.00126287,0,0,0,0,0.0129392,0.000238082,0.0163449,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
S-7_Underweight,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431948,SZAXPI029564-74,SZAXPI029564-74,GZCT014,...,0,0,0.16279,0,0,0.0322849,0.0587168,0.10368,0,0
S-7_Underweight.1,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431949,SZAXPI029565-77,SZAXPI029565-77,GZCT015,...,0,0,0,0,0,0,0.00125461,0.312764,0,0
S-7_Underweight.2,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431951,SZAXPI029567-80,SZAXPI029567-80,GZCT017,...,0.00115932,0,0,0,0,0.0865797,0.0366468,0.0827324,0,0
S-7_Underweight.3,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431964,SZAXPI029580-98,SZAXPI029580-98,GZCT030,...,0,0,0,0,0,0,0,0.00855007,0,0


In [12]:
# Attach each sample's study ("Author (year)") as a `batch` column.
# NOTE(review): .tolist() drops the index, so this aligns by position. Metadata
# rows are named like 'V-2_ACVD' and abundance rows like 'ACVD_1', so this
# assumes both files list samples in the same order; confirm. Re-running the
# cell raises because the column already exists.
train.insert(0,'batch',metadata['Author (year)'].tolist())

In [13]:
# Placeholder sample-type column; filled in by the next cell.
train.insert(0,'Type','')

In [14]:
# Sample type = index-label prefix before the first '_' (e.g. 'ACVD_1' -> 'ACVD').
train['Type'] = train.index.str.split('_').str[0].tolist()

In [16]:
# Keep only the 'Overweight' samples. Type is the non-empty index-label prefix,
# so comparing with '' could never match; the displayed result is the
# Overweight subset. .copy() makes train_p an independent frame, so later
# column assignments on it do not trigger SettingWithCopyWarning.
train_p = train[train['Type'] == 'Overweight'].copy()

In [17]:
# Inspect the selected subset.
train_p

Unnamed: 0,Type,batch,s__Abiotrophia_defectiva,s__Acetobacter_unclassified,s__Achromobacter_piechaudii,s__Achromobacter_unclassified,s__Achromobacter_xylosoxidans,s__Acidaminococcus_fermentans,s__Acidaminococcus_intestini,s__Acidaminococcus_sp_BV3L6,...,s__Vibrio_furnissii,s__Vibrio_kanaloae,s__Weissella_cibaria,s__Weissella_confusa,s__Weissella_halotolerans,s__Weissella_koreensis,s__Weissella_paramesenteroides,s__Weissella_unclassified,s__Wohlfahrtiimonas_chitiniclastica,s__Yersinia_enterocolitica
Overweight_3579,Overweight,Jie (2017),0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3580,Overweight,Jie (2017),0.0,0.0,0.0,0.0,0.0,0.0,0.00653,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3581,Overweight,Jie (2017),0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,...,0.0,0.0,0.02384,0.10129,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3582,Overweight,Jie (2017),0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3583,Overweight,Jie (2017),0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Overweight_3812,Overweight,Obregon-Tito (2015),0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3813,Overweight,Obregon-Tito (2015),0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3814,Overweight,Obregon-Tito (2015),0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3815,Overweight,Obregon-Tito (2015),0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0


In [18]:
# All study names. NOTE: overwritten further down once small studies are
# pooled into 'Other'.
batches = sorted(list(set(train['batch'])))

In [19]:
# Number of samples per study within train_p.
train_p['batch'].value_counts()

Qin (2012)                 48
Jie (2017)                 40
Zhang (2015)               27
Feng (2015)                20
Nielsen (2014)             19
Karlsson (2013)            17
Vogtmann (2016)            12
Zeller (2014)              12
Obregon-Tito (2015)         9
Qin (2014)                  7
He (2017)                   7
Sankaranarayanan (2015)     4
Le Chatelier (2013)         4
Raymond (2015)              4
Karlsson (2012)             3
Guthrie (2017)              3
Nishijima (2016)            2
Name: batch, dtype: int64

In [20]:
# Studies ranked 7th or lower by sample count; these are pooled into 'Other'.
train_p['batch'].value_counts().iloc[6:].index

Index(['Vogtmann (2016)', 'Zeller (2014)', 'Obregon-Tito (2015)', 'Qin (2014)',
       'He (2017)', 'Sankaranarayanan (2015)', 'Le Chatelier (2013)',
       'Raymond (2015)', 'Karlsson (2012)', 'Guthrie (2017)',
       'Nishijima (2016)'],
      dtype='object')

In [21]:
# Pool every study outside the 6 largest (by sample count) into one 'Other'
# batch. The set of small studies is computed once, not once per sample.
small_studies = set(train_p['batch'].value_counts().iloc[6:].index)
batches_list = ['Other' if batch in small_studies else batch for batch in train_p['batch']]

In [22]:
# Sorted pooled batch names; list positions are the batch ids used below.
batches = sorted(set(batches_list))

In [23]:
# Inspect the pooled batch names.
batches

['Feng (2015)',
 'Jie (2017)',
 'Karlsson (2013)',
 'Nielsen (2014)',
 'Other',
 'Qin (2012)',
 'Zhang (2015)']

In [24]:
# Keep the original (un-pooled) study labels before they are replaced below.
# `b` is not used again in the visible notebook.
b = train_p['batch']

In [25]:
# Replace study labels with the pooled ones. assign() returns a new frame, so
# this never writes through a possible view of `train` and cannot trigger
# SettingWithCopyWarning.
train_p = train_p.assign(batch=batches_list)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.


In [26]:
# Encoder-stage sampler; its fields are filled in by the following cells.
scd = ScDataset()

In [27]:
# Feature (taxon) names. Skip the 'Type' and 'batch' columns so this matches
# the `.iloc[:, 2:]` slice used to build the data matrices; [1:] also kept 'batch'.
scd.variable = np.array(train_p.columns.tolist()[2:])

In [28]:
# One (n_samples, n_taxa) matrix per pooled batch, in `batches` order. The
# first two columns ('Type', 'batch') are not features.
adata_values = [train_p.loc[train_p['batch'] == batch].iloc[:, 2:].values for batch in batches]

In [29]:
# Sum of per-feature standard deviations for each batch; used to rank batches.
std_ = [np.std(matrix, axis=0).sum() for matrix in adata_values]

In [30]:
# Batch indices by decreasing total std. orders[0] becomes the anchor batch
# that every other batch is mapped onto.
orders = np.argsort(std_)[::-1]
orders

array([4, 5, 1, 6, 0, 3, 2])

In [31]:
# Sample (row) names of each pooled batch, in `batches` order.
obs_names = [np.array( train_p[train_p['batch']== batch].index.tolist()) for batch in batches ]  

In [32]:
# Inspect the per-batch sample names.
obs_names

[array(['Overweight_3722', 'Overweight_3723', 'Overweight_3724',
        'Overweight_3725', 'Overweight_3726', 'Overweight_3727',
        'Overweight_3728', 'Overweight_3729', 'Overweight_3730',
        'Overweight_3731', 'Overweight_3732', 'Overweight_3733',
        'Overweight_3734', 'Overweight_3735', 'Overweight_3736',
        'Overweight_3737', 'Overweight_3738', 'Overweight_3739',
        'Overweight_3740', 'Overweight_3741'], dtype='<U15'),
 array(['Overweight_3579', 'Overweight_3580', 'Overweight_3581',
        'Overweight_3582', 'Overweight_3583', 'Overweight_3584',
        'Overweight_3585', 'Overweight_3586', 'Overweight_3587',
        'Overweight_3588', 'Overweight_3589', 'Overweight_3590',
        'Overweight_3591', 'Overweight_3592', 'Overweight_3593',
        'Overweight_3594', 'Overweight_3595', 'Overweight_3596',
        'Overweight_3597', 'Overweight_3598', 'Overweight_3599',
        'Overweight_3600', 'Overweight_3601', 'Overweight_3602',
        'Overweight_3603', '

In [33]:
# Sample names concatenated in `orders` order, i.e. the row order of the
# embeddings computed later (None if there are no batches).
_ordered_names = [obs_names[idx] for idx in orders]
ec_obs_names = np.concatenate(_ordered_names) if _ordered_names else None

In [34]:
# Inspect the reordered sample names.
ec_obs_names

array(['Overweight_3619', 'Overweight_3620', 'Overweight_3621',
       'Overweight_3622', 'Overweight_3623', 'Overweight_3624',
       'Overweight_3625', 'Overweight_3626', 'Overweight_3627',
       'Overweight_3628', 'Overweight_3629', 'Overweight_3630',
       'Overweight_3696', 'Overweight_3697', 'Overweight_3698',
       'Overweight_3699', 'Overweight_3700', 'Overweight_3701',
       'Overweight_3702', 'Overweight_3769', 'Overweight_3770',
       'Overweight_3771', 'Overweight_3772', 'Overweight_3773',
       'Overweight_3774', 'Overweight_3775', 'Overweight_3776',
       'Overweight_3777', 'Overweight_3778', 'Overweight_3779',
       'Overweight_3780', 'Overweight_3781', 'Overweight_3782',
       'Overweight_3783', 'Overweight_3784', 'Overweight_3785',
       'Overweight_3786', 'Overweight_3787', 'Overweight_3788',
       'Overweight_3789', 'Overweight_3790', 'Overweight_3791',
       'Overweight_3792', 'Overweight_3793', 'Overweight_3794',
       'Overweight_3795', 'Overweight_37

In [35]:
# Per-batch matrices reordered so index 0 is the anchor (highest-std) batch.
scd.dataset = [adata_values[i] for i in orders]

In [36]:
# Each step yields up to 512 items. Every item is a list with one mixed
# sample per batch (see ScDataset.__getitem__).
dataloader = DataLoader(
        dataset = scd,
        batch_size=512,
        num_workers=num_workers,
    )

In [37]:
# NOTE: `global` at module (notebook) scope is a no-op. data_size and n_classes
# become module globals simply by being assigned in the next cell.
global data_size
global n_classes

In [38]:
# Feature count and number of batches; Encoder/Decoder read these as globals.
data_size = scd.dataset[0].shape[1]
n_classes = len(scd.dataset)

In [39]:
# Number of features (taxa).
data_size

903

In [40]:
# Number of pooled batches.
n_classes

7

In [41]:
# The decoder input is the content code (latent_dim) concatenated with a
# one-hot batch code (n_classes) built in the training loop.
EC = Encoder(latent_dim)
Dec = Decoder(latent_dim + n_classes)
mse_loss = torch.nn.MSELoss()

In [42]:
# Show the encoder architecture.
EC

Encoder(
  (encoder): Sequential(
    (0): Linear(in_features=903, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish()
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Mish()
    (6): Linear(in_features=512, out_features=256, bias=True)
  )
)

In [43]:
# Show the decoder architecture.
Dec

Decoder(
  (relu): ReLU()
  (decoder): Sequential(
    (0): Linear(in_features=263, out_features=512, bias=True)
    (1): Mish()
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): Mish()
    (4): Linear(in_features=1024, out_features=903, bias=True)
  )
  (decoder2): Sequential(
    (0): Linear(in_features=7, out_features=512, bias=True)
    (1): Mish()
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): Mish()
    (4): Linear(in_features=1024, out_features=903, bias=True)
  )
)

In [44]:
# Pick CPU or CUDA tensor types once for the whole notebook.
cuda = torch.cuda.is_available()
if cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor

In [45]:
# Move the models (and the loss module) to the GPU when one is available.
if cuda:
    EC.cuda()
    Dec.cuda()
    mse_loss.cuda()

In [46]:
# Re-initialise Linear weights ~ N(0, 0.02) and BatchNorm weights ~ N(1, 0.02)
# with zero BatchNorm biases. Linear biases keep PyTorch's default init.
EC.apply(weights_init_normal)

Encoder(
  (encoder): Sequential(
    (0): Linear(in_features=903, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish()
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Mish()
    (6): Linear(in_features=512, out_features=256, bias=True)
  )
)

In [47]:
# Same initialisation for the decoder (it has no BatchNorm layers).
Dec.apply(weights_init_normal)

Decoder(
  (relu): ReLU()
  (decoder): Sequential(
    (0): Linear(in_features=263, out_features=512, bias=True)
    (1): Mish()
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): Mish()
    (4): Linear(in_features=1024, out_features=903, bias=True)
  )
  (decoder2): Sequential(
    (0): Linear(in_features=7, out_features=512, bias=True)
    (1): Mish()
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): Mish()
    (4): Linear(in_features=1024, out_features=903, bias=True)
  )
)

In [48]:
# Separate Adam optimisers; both are stepped on the same combined loss.
optimizer_Dec = torch.optim.Adam(Dec.parameters(), lr=lr, betas=(b1, b2))
optimizer_EC = torch.optim.Adam(EC.parameters(), lr=lr, betas=(b1, b2))

In [49]:
# Redundant: the training loop below calls Dec.train()/EC.train() every epoch.
Dec.train()

Decoder(
  (relu): ReLU()
  (decoder): Sequential(
    (0): Linear(in_features=263, out_features=512, bias=True)
    (1): Mish()
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): Mish()
    (4): Linear(in_features=1024, out_features=903, bias=True)
  )
  (decoder2): Sequential(
    (0): Linear(in_features=7, out_features=512, bias=True)
    (1): Mish()
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): Mish()
    (4): Linear(in_features=1024, out_features=903, bias=True)
  )
)

In [50]:
# Encoder/decoder training.
#  * ae_loss (reconstruction): decode each row from its content code plus the
#    one-hot code of its own batch (ES_data1) and match the input.
#  * loss4 (content consistency): decode with a random batch code (ES_data2),
#    re-encode, and match the original content code.
for epoch in range(n_epochs):
    Dec.train()
    EC.train()

    for i, data in enumerate(dataloader):
        # One tensor per batch/dataset, each holding `batch_size` mixed samples.
        datum = [item.type(FloatTensor) for item in data]
        batch_size = datum[0].shape[0]

        # True batch codes. Rows are grouped dataset by dataset, matching the
        # torch.cat(datum) order used below.
        ES_data1 = np.zeros((n_classes * batch_size, n_classes))
        for j in range(n_classes):
            ES_data1[j*batch_size:(j+1)*batch_size, j] = 1
        ES_data1 = torch.tensor(ES_data1).type(FloatTensor)
        # Random batch codes: one uniformly drawn batch per row.
        ES_data2 = np.zeros((n_classes * batch_size, n_classes))
        ES_data2[np.arange(n_classes*batch_size), np.random.randint(n_classes, size=n_classes*batch_size)] = 1
        ES_data2 = torch.tensor(ES_data2).type(FloatTensor)

        optimizer_Dec.zero_grad()
        optimizer_EC.zero_grad()

        loss1_data1 = torch.cat(datum, dim=0)
        loss4 = mse_loss(EC(loss1_data1), EC(Dec(EC(loss1_data1), ES_data2)))
        ae_loss = mse_loss(Dec(EC(loss1_data1), ES_data1), loss1_data1)

        all_loss = (lambda_co * loss4) + (lambda_rc * ae_loss)
        all_loss.backward()

        optimizer_Dec.step()
        optimizer_EC.step()

    # Losses of the last mini-batch of the epoch.
    print(
        "[Epoch %d/%d] [Reconstruction loss: %f] [Content loss: %f]"
        % (epoch+1, n_epochs,
           ae_loss.item(),
           loss4.item(),
          )
    )

[Epoch 1/150] [Reconstruction loss: 0.846651] [Cotent loss: 0.005998]
[Epoch 2/150] [Reconstruction loss: 0.708342] [Cotent loss: 0.014234]
[Epoch 3/150] [Reconstruction loss: 0.451807] [Cotent loss: 0.005853]
[Epoch 4/150] [Reconstruction loss: 0.366602] [Cotent loss: 0.003486]
[Epoch 5/150] [Reconstruction loss: 0.342170] [Cotent loss: 0.002818]
[Epoch 6/150] [Reconstruction loss: 0.323349] [Cotent loss: 0.005186]
[Epoch 7/150] [Reconstruction loss: 0.302919] [Cotent loss: 0.005173]
[Epoch 8/150] [Reconstruction loss: 0.281084] [Cotent loss: 0.003649]
[Epoch 9/150] [Reconstruction loss: 0.252270] [Cotent loss: 0.003011]
[Epoch 10/150] [Reconstruction loss: 0.242686] [Cotent loss: 0.002927]
[Epoch 11/150] [Reconstruction loss: 0.223873] [Cotent loss: 0.002698]
[Epoch 12/150] [Reconstruction loss: 0.219147] [Cotent loss: 0.002932]
[Epoch 13/150] [Reconstruction loss: 0.208345] [Cotent loss: 0.002996]
[Epoch 14/150] [Reconstruction loss: 0.175432] [Cotent loss: 0.002670]
[Epoch 15/150] 

[Epoch 117/150] [Reconstruction loss: 0.052600] [Cotent loss: 0.000572]
[Epoch 118/150] [Reconstruction loss: 0.050767] [Cotent loss: 0.000545]
[Epoch 119/150] [Reconstruction loss: 0.048986] [Cotent loss: 0.000624]
[Epoch 120/150] [Reconstruction loss: 0.048873] [Cotent loss: 0.000549]
[Epoch 121/150] [Reconstruction loss: 0.048319] [Cotent loss: 0.000572]
[Epoch 122/150] [Reconstruction loss: 0.048933] [Cotent loss: 0.000537]
[Epoch 123/150] [Reconstruction loss: 0.048944] [Cotent loss: 0.000541]
[Epoch 124/150] [Reconstruction loss: 0.048938] [Cotent loss: 0.000567]
[Epoch 125/150] [Reconstruction loss: 0.047996] [Cotent loss: 0.000553]
[Epoch 126/150] [Reconstruction loss: 0.046778] [Cotent loss: 0.000639]
[Epoch 127/150] [Reconstruction loss: 0.046458] [Cotent loss: 0.000600]
[Epoch 128/150] [Reconstruction loss: 0.047169] [Cotent loss: 0.000555]
[Epoch 129/150] [Reconstruction loss: 0.048502] [Cotent loss: 0.000533]
[Epoch 130/150] [Reconstruction loss: 0.048079] [Cotent loss: 0.

In [51]:
# Embed every batch with the trained encoder (eval mode, no gradients) and
# stack the embeddings in `orders` order. `label` is an (n_rows, 1) array with
# the batch name of each row.
Dec.eval()
EC.eval()

with torch.no_grad():
    # Anchor batch (orders[0]) first.
    data = Variable(FloatTensor(scd.dataset[0]))
    label = np.full((len(scd.dataset[0]),1), batches[orders[0]])
    static_sample = EC(data)
    transform_data = static_sample.cpu().detach().numpy()
    # Append the remaining batches and their labels one by one.
    for j in range(1, len(scd.dataset)):
        data = Variable(FloatTensor(scd.dataset[j]))
        static_sample = EC(data)
        fake_data = static_sample.cpu().detach().numpy()
        fake_label = np.full((len(scd.dataset[j]),1), batches[orders[j]])
        transform_data, label = cat_data(transform_data, fake_data, [label, fake_label])

In [52]:
# (total rows, latent_dim)
transform_data.shape

(238, 256)

In [53]:
def setup_seed(seed):
    """Seed torch (CPU + all CUDA devices), NumPy and Python's ``random``, and
    request deterministic cuDNN kernels.

    NOTE(review): this redefinition is not called anywhere afterwards in the
    visible notebook, so defining it does not re-seed the GAN stage.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    random.seed(seed)

In [54]:
def calculate_gradient_penalty(real_data, fake_data, D):
    """WGAN-GP gradient penalty for critic ``D``.

    Draws one random point on the segment between each real/fake pair and
    returns ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``. The gradient norm is
    taken over all non-batch dimensions.

    Args:
        real_data: batch of real samples, shape (n, ...).
        fake_data: generated samples with the same shape as ``real_data``.
        D: critic module mapping a batch to one score per sample.

    Returns:
        Scalar tensor built with ``create_graph=True``, so it can be
        back-propagated into ``D``'s parameters.
    """
    n = real_data.size(0)
    # One mixing weight per sample. It is drawn with the CPU RNG (as before) and
    # then moved to the data's device instead of guessing from
    # torch.cuda.is_available(). It broadcasts over all remaining dimensions.
    eta = torch.rand(n, *([1] * (real_data.dim() - 1)), dtype=real_data.dtype)
    eta = eta.to(real_data.device)

    interpolated = eta * real_data + (1 - eta) * fake_data
    # Fresh leaf tensor: gradients are taken w.r.t. the interpolates only, not
    # back through whatever produced fake_data.
    interpolated = interpolated.detach().requires_grad_(True)

    prob_interpolated = D(interpolated)

    gradients = autograd.grad(
        outputs=prob_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(prob_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]

    grad_penalty = ((gradients.reshape(n, -1).norm(2, dim=1) - 1) ** 2).mean()
    return grad_penalty


In [55]:
def normalize(data: np.float32) -> np.float32:
    """Identity hook applied to inputs before the nearest-neighbour search.

    Returns ``data`` itself, unchanged. An inverse-log transform
    (``np.exp2(data) - 1``) was once considered here and is disabled.
    """
    return data

In [56]:
def acquire_pairs(X, Y, k, metric):
    """Return mutual-nearest-neighbour (MNN) index pairs between X and Y.

    A pair ``(i, j)`` is returned when Y[j] is among the ``k`` approximate
    nearest neighbours of X[i] in Y, and X[i] is among the ``k`` approximate
    nearest neighbours of Y[j] in X. Neighbours come from Annoy indexes built
    with 10 trees.

    Args:
        X: (n_x, f) array.
        Y: (n_y, f) array with the same number of columns.
        k: number of neighbours queried per point.
        metric: Annoy metric name, e.g. 'euclidean' or 'angular'.

    Returns:
        List of ``(i, j)`` tuples of NumPy integers, ordered by i then j.
    """
    X = normalize(X)
    Y = normalize(Y)

    f = X.shape[1]
    t1 = AnnoyIndex(f, metric)
    t2 = AnnoyIndex(f, metric)
    for i in range(len(X)):
        t1.add_item(i, X[i])
    for i in range(len(Y)):
        t2.add_item(i, Y[i])
    t1.build(10)
    t2.build(10)

    # dtype=bool instead of np.bool8, which is deprecated since NumPy 1.24 and
    # removed in NumPy 2.0. Neighbour lists are used directly as indices rather
    # than stacked with np.array, which breaks on ragged rows if Annoy returns
    # fewer than k items for some query.
    # x_to_y[i, j]: Y[j] is one of X[i]'s k nearest neighbours in Y.
    x_to_y = np.zeros((len(X), len(Y)), dtype=bool)
    for i, x in enumerate(X):
        x_to_y[i, t2.get_nns_by_vector(x, k)] = True
    # y_to_x[i, j]: X[i] is one of Y[j]'s k nearest neighbours in X.
    y_to_x = np.zeros((len(X), len(Y)), dtype=bool)
    for j, y in enumerate(Y):
        y_to_x[t1.get_nns_by_vector(y, k), j] = True

    mutual = np.logical_and(y_to_x, x_to_y)
    return [(x, y) for x, y in zip(*np.nonzero(mutual))]

In [57]:
def create_pairs_dict(pairs):
    """Group ``(x, y)`` pairs into ``{x: [y, ...]}``.

    Values keep the order in which the pairs were seen. The result is a plain
    dict, so looking up an unseen key raises ``KeyError``.
    """
    grouped = {}
    for left, right in pairs:
        grouped.setdefault(left, []).append(right)
    return grouped

In [58]:
class ScDataset(Dataset):
    """Sampler for the GAN (batch-correction) stage; redefines ScDataset.

    ``dataset`` holds the per-batch data matrices that get corrected.
    ``cali_dataset`` holds the matching per-batch embeddings and is used only
    to find mutual-nearest-neighbour (MNN) pairs. After
    ``calculate_mnn_pairs()``, ``datasetA`` (query rows) and ``datasetB``
    (anchor rows) contain the rows visited by short random walks that start at
    cross-batch MNN pairs. Each item is an independently drawn row from each.
    """
    def __init__(self, n_sample=3000):
        self.dataset = []
        self.cali_dataset = []
        self.variable = None
        self.anchor_index = 0    # index of the reference batch in `dataset`
        self.query_index = 1     # index of the batch currently being corrected
        self.pairs = None
        self.labels = None
        self.transform = None
        self.sample = None
        self.metric = 'euclidean'   # Annoy metric passed to acquire_pairs
        self.k1 = None   # neighbours per query; derived on first use when None
        self.k2 = None   # only used to derive k1 in the visible code
        self.n_sample = n_sample   # max rows sampled from each batch


    def change_dataset(self, index: int=1):
        # Select which batch is the query (the one to be corrected).
        self.query_index = index


    def acquire_anchor(self, index: int=0):
        # Select the reference batch.
        self.anchor_index = index


    def calculate_mnn_pairs(self):
        """Build ``datasetA``/``datasetB`` for the current anchor/query pair.

        Uses NumPy's RNG for row subsampling and Python's ``random`` for the
        random walks, so results depend on the state of both generators.
        """
        # Random subset (up to n_sample rows) of the anchor batch, in embedding space.
        tmp = np.arange(len(self.dataset[self.anchor_index]))
        np.random.shuffle(tmp)
        self.sample = self.cali_dataset[self.anchor_index][tmp[:self.n_sample]]
        ####
        # Same for the query batch.
        tmp2 = np.arange(len(self.dataset[self.query_index]))
        np.random.shuffle(tmp2)
        self.query_sample = self.cali_dataset[self.query_index][tmp2[:self.n_sample]]
        ####
        
        # k2 = 1% of the smaller subset and k1 = half of that (at least 1).
        # Computed only once: later calls reuse the values from the first query.
        if (self.k1 is None) or (self.k2 is None):
            self.k2 = int(min(len(self.sample), len(self.query_sample))/100)
            self.k1 = max(int(self.k2/2), 1)
        
        # Within-batch MNN graphs (used for the walks) and cross-batch MNN pairs.
        print('Calculating Anchor Pairs...')
        anchor_pairs = acquire_pairs(self.sample, self.sample, self.k1, self.metric)
        print('Calculating Query Pairs...')
        query_pairs = acquire_pairs(self.query_sample, self.query_sample, self.k1, self.metric)
        print('Calculating KNN Pairs...')
        pairs = acquire_pairs(self.sample, self.query_sample, self.k1, self.metric)
        print('Calculating Random Walk Pairs...')
        anchor_pairs_dict = create_pairs_dict(anchor_pairs)
        query_pairs_dict = create_pairs_dict(query_pairs)
        # From each cross-batch pair, record 50 pairs: the start plus 49 steps.
        # Each step moves both ends to a random within-batch MNN neighbour.
        # NOTE(review): raises KeyError if a row has no within-batch MNN
        # (presumably rare, since a point is usually its own nearest neighbour).
        pair_plus = []
        for x, y in pairs:
            start = (x, y)
            for i in range(50):
                pair_plus.append(start)
                start = (random.choice(anchor_pairs_dict[start[0]]), random.choice(query_pairs_dict[start[1]]))

        # Map walk indices back to rows of the raw (uncorrected) data:
        # A = query rows, B = anchor rows.
        self.datasetA = self.dataset[self.query_index][tmp2[:self.n_sample]][[y for x,y in pair_plus], :]
        self.datasetB = self.dataset[self.anchor_index][tmp[:self.n_sample]][[x for x,y in pair_plus], :]
        print('Done.')

    def __len__(self):
        # Nominal epoch length; items are drawn with replacement.
        return 10*1024


    def __getitem__(self, index):
        # `index` is ignored. The A and B rows are drawn independently of each
        # other (they are not the paired rows).
        return random.choice(self.datasetA), random.choice(self.datasetB)

In [59]:
def train(scd, n_dataset, n_epochs):
    """Map batch ``n_dataset`` onto the anchor batch with a WGAN-GP.

    1. Point ``scd`` at the query batch and build the MNN-guided training sets
       (``scd.calculate_mnn_pairs``). The dataloader then yields
       (query rows, anchor rows).
    2. Train a residual generator G_AB (query -> anchor-like) against a critic
       D_B using the Wasserstein loss with gradient penalty.
    3. Return G_AB applied to the full query batch.

    NOTE(review): this function rebinds the name `train`, which earlier held
    the abundance DataFrame; earlier cells that use it fail if re-run later.

    Args:
        scd: the GAN-stage ScDataset, with ``dataset``/``cali_dataset`` set.
        n_dataset: index in ``scd.dataset`` of the batch to correct.
        n_epochs: number of training epochs.

    Returns:
        NumPy array with one corrected row per row of the query batch.
    """
    scd.change_dataset(n_dataset)
    scd.calculate_mnn_pairs()

    data_size = scd.dataset[0].shape[1]
    lr = 0.0002
    b1 = 0.5
    b2 = 0.999
    latent_dim = 256
    # One generator step per `n_critic` critic steps. With 10 * 1024 items and
    # batch_size 1024 there are 10 batches per epoch, so G_AB is updated only at
    # i == 0, i.e. once per epoch.
    n_critic = 100

    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    dataloader = DataLoader(
        dataset = scd,
        batch_size=1024,
    )


    class Generator(nn.Module):
        """Residual autoencoder: output = ReLU(decoder(encoder(x)) + x)."""
        def __init__(self):
            super(Generator, self).__init__()
            self.relu = nn.ReLU(inplace=True)
            self.encoder = nn.Sequential(
                nn.Linear(data_size, 1024),
                nn.BatchNorm1d(1024),
                Mish(),
                nn.Linear(1024, 512),
                nn.BatchNorm1d(512),
                Mish(),
                nn.Linear(512, latent_dim),
                nn.BatchNorm1d(latent_dim),
            )
            self.decoder = nn.Sequential(
                nn.Linear(latent_dim, 512),
                nn.BatchNorm1d(512),
                Mish(),
                nn.Linear(512, 1024),
                nn.BatchNorm1d(1024),
                Mish(),
                nn.Linear(1024, data_size),
            )

        def forward(self, x):
            latent_data = self.encoder(x)
            gen_data = self.decoder(latent_data)
            return self.relu(gen_data + x)


    class Discriminator(nn.Module):
        """Critic: MLP producing one unbounded score per sample."""
        def __init__(self):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(data_size, 512),
                Mish(),
                nn.Linear(512, 512),
                Mish(),
            )

            # Output layers
            self.adv_layer = nn.Sequential(nn.Linear(512, 1))

        def forward(self, data):
            out = self.model(data)
            validity = self.adv_layer(out)
            return validity

    # Initialize generator and discriminator (construction order fixes the RNG use).
    G_AB = Generator()
    D_B = Discriminator()

    if cuda:
        G_AB.cuda()
        D_B.cuda()

    # Initialize weights
    G_AB.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

    optimizer_G_AB = torch.optim.Adam(G_AB.parameters(), lr=lr, betas=(b1, b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

    for epoch in range(n_epochs):
        G_AB.train()
        for i, (data_A, data_B) in enumerate(dataloader):
            # data_A: query rows (generator input); data_B: anchor rows (real).
            real_data = data_B.type(FloatTensor)
            z = data_A.type(FloatTensor)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D_B.zero_grad()
            # detach(): the critic loss must not back-propagate through G_AB.
            # G_AB's gradients were discarded anyway (zeroed before its own step).
            gen_data = G_AB(z).detach()

            real_validity = D_B(real_data)
            fake_validity = D_B(gen_data)

            # WGAN-GP gradient penalty (weight 10).
            div_gp = calculate_gradient_penalty(real_data, gen_data, D_B)

            # Wasserstein critic loss.
            db_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + 10*div_gp
            db_loss.backward()
            optimizer_D_B.step()

            # -----------------
            #  Train Generator
            # -----------------
            if i % n_critic == 0:
                optimizer_G_AB.zero_grad()
                gen_data = G_AB(z)
                fake_validity = D_B(gen_data)
                gab_loss = -torch.mean(fake_validity)
                gab_loss.backward()

                optimizer_G_AB.step()

        # --------------
        # Log Progress (losses of the last critic / generator step)
        # --------------
        print(
            "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch+1, n_epochs,
               db_loss.item(),
               gab_loss.item(),
              )
        )

    # Apply the trained generator to every row of the query batch.
    G_AB.eval()
    with torch.no_grad():
        z = FloatTensor(scd.dataset[scd.query_index])
        fake_data = G_AB(z).cpu().numpy()
    return fake_data

In [60]:
# Distinct batch names present in the embedding labels (not used again in the
# visible notebook).
cali_batches = sorted(set(label[:, 0]))

In [61]:
# Keep the encoder-stage sampler before `scd` is rebound to the GAN-stage one.
scd1 = scd

In [62]:
# Total number of embedded rows across all batches.
len(label.T[0])

238

In [63]:
# GAN-stage sampler. n_sample = total row count, so each subsample in
# calculate_mnn_pairs keeps every row of a batch (in shuffled order).
scd = ScDataset(len(label.T[0]))

In [64]:
# Angular (cosine-like) Annoy metric. k1/k2 = None means they are derived from
# the subsample sizes on the first calculate_mnn_pairs call and then reused.
scd.metric = 'angular'
scd.k1 = None
scd.k2 = None
scd.variable = scd1.variable

In [65]:
# Batch names in processing order; the first one is the anchor.
print('Orders:','<-'.join(batches[i] for i in orders))

Orders: Other<-Qin (2012)<-Jie (2017)<-Zhang (2015)<-Feng (2015)<-Nielsen (2014)<-Karlsson (2013)


In [66]:
# Raw per-batch feature matrices, anchor first (same order as the encoder stage).
scd.dataset = [adata_values[i] for i in orders]

In [67]:
# Embeddings as a frame, with the batch name of each row in the first column.
cali_adata = pd.DataFrame(transform_data)
cali_adata.insert(0,'batch',label.T[0])
cali_adata

Unnamed: 0,batch,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
0,Other,-0.085109,-0.134151,0.562397,0.178173,0.503652,0.152453,1.004732,-0.494178,0.846799,...,-0.206212,-0.292491,0.938009,-0.198684,0.148335,-0.127635,-0.475556,-0.245269,0.212345,-0.058509
1,Other,-0.066880,0.033022,0.451024,-0.130119,0.366637,0.132086,0.814253,-0.420701,0.477716,...,-0.184468,-0.061697,0.521373,-0.030952,0.058610,-0.096027,-0.416868,-0.291197,0.256635,0.010130
2,Other,0.131742,0.205865,0.154822,-0.044801,0.363387,0.645000,0.033453,-0.095347,0.352739,...,-0.426277,0.131419,0.643652,-0.103064,-0.056633,0.062488,-0.118545,-0.042113,0.321148,-0.249054
3,Other,0.440696,0.207748,0.015435,0.314654,0.415750,0.116383,-0.178496,0.954549,-0.210281,...,0.153311,0.486356,-0.465357,0.027676,0.003930,-0.515607,-0.315821,0.166881,-0.616933,-0.133195
4,Other,-0.039332,0.469693,-0.258113,-0.486127,0.570958,0.008927,0.070328,0.222047,-0.225921,...,0.318291,0.134490,-0.485869,0.106637,0.025186,-0.305881,-0.149542,0.202563,-0.011708,-0.427119
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
233,Karlsson (2013),-0.054899,-0.114295,-0.299255,-0.344850,0.778089,-0.352746,0.421063,-0.019471,-0.538001,...,0.243075,0.257538,-0.435933,0.257206,0.388614,-0.607854,-0.632491,0.204100,-0.513940,-0.393637
234,Karlsson (2013),-0.143413,0.367240,-0.059782,-0.095537,0.402280,0.330589,0.441663,-0.357425,0.011872,...,-0.389118,0.032960,0.339075,0.154442,0.070765,0.136696,-0.065614,-0.429427,-0.005799,-0.199758
235,Karlsson (2013),-0.352752,0.053045,0.308048,-0.242014,0.201634,0.339689,0.982311,-0.282756,0.772103,...,-0.304490,0.122337,0.465672,-0.062403,0.029476,-0.279917,-0.033707,-0.193340,-0.043484,-0.154153
236,Karlsson (2013),0.132263,0.278398,0.445930,0.102948,0.127591,0.989901,1.062595,0.128021,0.807491,...,-0.877912,0.613692,-0.180695,-0.038836,-0.797772,-0.438369,-0.494734,0.132233,-0.355499,-0.171551


In [68]:
# One matrix per batch (every column except 'batch'), listed in `batches`
# order so that the positions stored in `orders` index them correctly.
cali_adata_values = [
    np.array(cali_adata.loc[cali_adata['batch'] == batch_name].iloc[:, 1:])
    for batch_name in batches
]
cali_orders = orders

In [69]:
# Calibration matrices arranged in the same merge order as scd.dataset.
scd.cali_dataset = list(map(lambda pos: cali_adata_values[pos], cali_orders))

In [70]:
# Seed the running merged matrix with a copy of the anchor dataset so that
# scd.transform never aliases an array held in scd.dataset.
scd.transform = np.array(scd.dataset[scd.anchor_index], order='K', copy=True)

In [71]:
# Merge every non-reference dataset, in merge order, and append each
# generated block to the running matrix scd.transform (row-wise).
for i in range(1, len(scd.dataset)):
    print('Merging dataset {} to {}'.format(batches[orders[i]], batches[orders[0]]))
    fake_data = train(scd, i, n_epochs=n_epochs)
    scd.transform = np.concatenate((scd.transform, fake_data), axis=0)

Merging dataset Qin (2012) to Other
Calculating Anchor Pairs...
Calculating Query Pairs...
Calculating KNN Pairs...
Calculating Random Walk Pairs...
Done.
[Epoch 1/150] [D loss: 2.533240] [G loss: 0.098595]
[Epoch 2/150] [D loss: -4.233782] [G loss: 1.347926]
[Epoch 3/150] [D loss: -7.419712] [G loss: 5.422555]
[Epoch 4/150] [D loss: -9.769296] [G loss: 8.705208]
[Epoch 5/150] [D loss: -11.873896] [G loss: 10.311496]
[Epoch 6/150] [D loss: -13.678122] [G loss: 10.984254]
[Epoch 7/150] [D loss: -16.513418] [G loss: 12.402990]
[Epoch 8/150] [D loss: -19.464573] [G loss: 12.753711]
[Epoch 9/150] [D loss: -25.382742] [G loss: 14.984046]
[Epoch 10/150] [D loss: -28.204611] [G loss: 17.044416]
[Epoch 11/150] [D loss: -29.211567] [G loss: 16.691614]
[Epoch 12/150] [D loss: -30.491737] [G loss: 18.537552]
[Epoch 13/150] [D loss: -29.508541] [G loss: 16.139851]
[Epoch 14/150] [D loss: -28.974674] [G loss: 15.310011]
[Epoch 15/150] [D loss: -29.769655] [G loss: 14.738556]
[Epoch 16/150] [D loss:

[Epoch 146/150] [D loss: -2.866333] [G loss: -2.647931]
[Epoch 147/150] [D loss: -2.840258] [G loss: -3.156275]
[Epoch 148/150] [D loss: -3.056214] [G loss: -3.244450]
[Epoch 149/150] [D loss: -2.784029] [G loss: -3.209890]
[Epoch 150/150] [D loss: -2.727530] [G loss: -2.559269]
Merging dataset Jie (2017) to Other
Calculating Anchor Pairs...
Calculating Query Pairs...
Calculating KNN Pairs...
Calculating Random Walk Pairs...
Done.
[Epoch 1/150] [D loss: 2.689025] [G loss: 0.013690]
[Epoch 2/150] [D loss: -3.373254] [G loss: 0.747743]
[Epoch 3/150] [D loss: -5.885541] [G loss: 3.773993]
[Epoch 4/150] [D loss: -7.500936] [G loss: 5.907306]
[Epoch 5/150] [D loss: -8.758796] [G loss: 6.721192]
[Epoch 6/150] [D loss: -10.825097] [G loss: 5.680718]
[Epoch 7/150] [D loss: -12.718515] [G loss: 5.525244]
[Epoch 8/150] [D loss: -16.577427] [G loss: 5.219119]
[Epoch 9/150] [D loss: -19.577354] [G loss: 6.464717]
[Epoch 10/150] [D loss: -24.283176] [G loss: 6.644366]
[Epoch 11/150] [D loss: -23.19

[Epoch 141/150] [D loss: -4.905488] [G loss: -1.503639]
[Epoch 142/150] [D loss: -5.471201] [G loss: -1.469045]
[Epoch 143/150] [D loss: -6.076669] [G loss: -1.459521]
[Epoch 144/150] [D loss: -5.422449] [G loss: -2.708526]
[Epoch 145/150] [D loss: -5.371319] [G loss: -3.036574]
[Epoch 146/150] [D loss: -5.499758] [G loss: -2.813015]
[Epoch 147/150] [D loss: -5.046124] [G loss: -1.811972]
[Epoch 148/150] [D loss: -5.286959] [G loss: -0.432054]
[Epoch 149/150] [D loss: -5.205064] [G loss: 0.267814]
[Epoch 150/150] [D loss: -4.561729] [G loss: -1.244957]
Merging dataset Zhang (2015) to Other
Calculating Anchor Pairs...
Calculating Query Pairs...
Calculating KNN Pairs...
Calculating Random Walk Pairs...
Done.
[Epoch 1/150] [D loss: 2.334821] [G loss: -0.021755]
[Epoch 2/150] [D loss: -3.830316] [G loss: 0.740123]
[Epoch 3/150] [D loss: -6.736512] [G loss: 4.098629]
[Epoch 4/150] [D loss: -8.365829] [G loss: 6.550574]
[Epoch 5/150] [D loss: -10.246868] [G loss: 6.995980]
[Epoch 6/150] [D l

[Epoch 135/150] [D loss: -5.077820] [G loss: -2.774669]
[Epoch 136/150] [D loss: -5.040736] [G loss: -3.120801]
[Epoch 137/150] [D loss: -4.679348] [G loss: -3.787840]
[Epoch 138/150] [D loss: -5.269039] [G loss: -3.827911]
[Epoch 139/150] [D loss: -5.338739] [G loss: -2.537917]
[Epoch 140/150] [D loss: -5.003823] [G loss: -2.842631]
[Epoch 141/150] [D loss: -4.865599] [G loss: -1.737848]
[Epoch 142/150] [D loss: -5.049326] [G loss: -2.444032]
[Epoch 143/150] [D loss: -5.128186] [G loss: -3.469351]
[Epoch 144/150] [D loss: -5.009744] [G loss: -2.283162]
[Epoch 145/150] [D loss: -4.946242] [G loss: -2.638453]
[Epoch 146/150] [D loss: -5.021830] [G loss: -3.847171]
[Epoch 147/150] [D loss: -5.559969] [G loss: -2.009144]
[Epoch 148/150] [D loss: -5.268880] [G loss: -1.509955]
[Epoch 149/150] [D loss: -5.335052] [G loss: -1.712147]
[Epoch 150/150] [D loss: -5.200483] [G loss: -0.896362]
Merging dataset Feng (2015) to Other
Calculating Anchor Pairs...
Calculating Query Pairs...
Calculating 

[Epoch 130/150] [D loss: -6.600838] [G loss: -6.963813]
[Epoch 131/150] [D loss: -6.421518] [G loss: -8.650974]
[Epoch 132/150] [D loss: -5.956381] [G loss: -8.904876]
[Epoch 133/150] [D loss: -6.082608] [G loss: -9.420974]
[Epoch 134/150] [D loss: -5.569641] [G loss: -9.761418]
[Epoch 135/150] [D loss: -4.801637] [G loss: -8.517113]
[Epoch 136/150] [D loss: -4.965639] [G loss: -8.964037]
[Epoch 137/150] [D loss: -5.450614] [G loss: -8.043010]
[Epoch 138/150] [D loss: -4.392063] [G loss: -8.361324]
[Epoch 139/150] [D loss: -4.942701] [G loss: -7.576364]
[Epoch 140/150] [D loss: -4.114256] [G loss: -6.028533]
[Epoch 141/150] [D loss: -4.337460] [G loss: -4.477823]
[Epoch 142/150] [D loss: -4.427132] [G loss: -5.862222]
[Epoch 143/150] [D loss: -4.237522] [G loss: -5.386871]
[Epoch 144/150] [D loss: -4.410846] [G loss: -5.443810]
[Epoch 145/150] [D loss: -3.914043] [G loss: -7.022650]
[Epoch 146/150] [D loss: -4.435160] [G loss: -5.475119]
[Epoch 147/150] [D loss: -4.096092] [G loss: -3.

[Epoch 125/150] [D loss: -2.188093] [G loss: -2.922372]
[Epoch 126/150] [D loss: -2.411885] [G loss: -4.346983]
[Epoch 127/150] [D loss: -2.395798] [G loss: -4.284204]
[Epoch 128/150] [D loss: -2.083354] [G loss: -4.425739]
[Epoch 129/150] [D loss: -2.193103] [G loss: -3.229205]
[Epoch 130/150] [D loss: -2.817819] [G loss: -3.073701]
[Epoch 131/150] [D loss: -2.562043] [G loss: -2.481215]
[Epoch 132/150] [D loss: -2.456436] [G loss: -2.461132]
[Epoch 133/150] [D loss: -2.405882] [G loss: -2.335457]
[Epoch 134/150] [D loss: -2.017991] [G loss: -2.741501]
[Epoch 135/150] [D loss: -1.923932] [G loss: -3.661996]
[Epoch 136/150] [D loss: -2.407387] [G loss: -4.365159]
[Epoch 137/150] [D loss: -2.607802] [G loss: -3.428320]
[Epoch 138/150] [D loss: -2.548063] [G loss: -3.201672]
[Epoch 139/150] [D loss: -1.991327] [G loss: -3.237089]
[Epoch 140/150] [D loss: -2.508366] [G loss: -3.669039]
[Epoch 141/150] [D loss: -2.129419] [G loss: -1.999306]
[Epoch 142/150] [D loss: -2.230400] [G loss: -1.

[Epoch 122/150] [D loss: -2.236860] [G loss: -0.549646]
[Epoch 123/150] [D loss: -2.652869] [G loss: 0.017008]
[Epoch 124/150] [D loss: -2.343039] [G loss: 0.630740]
[Epoch 125/150] [D loss: -2.413611] [G loss: 0.526171]
[Epoch 126/150] [D loss: -2.587402] [G loss: -0.997939]
[Epoch 127/150] [D loss: -2.289528] [G loss: -1.386266]
[Epoch 128/150] [D loss: -2.288023] [G loss: -0.968568]
[Epoch 129/150] [D loss: -1.924323] [G loss: -0.832170]
[Epoch 130/150] [D loss: -2.461164] [G loss: -0.640040]
[Epoch 131/150] [D loss: -2.382743] [G loss: -1.032893]
[Epoch 132/150] [D loss: -2.432757] [G loss: -1.861722]
[Epoch 133/150] [D loss: -2.430608] [G loss: -1.707263]
[Epoch 134/150] [D loss: -2.164179] [G loss: -1.108292]
[Epoch 135/150] [D loss: -2.446087] [G loss: -0.780252]
[Epoch 136/150] [D loss: -2.127028] [G loss: -0.095571]
[Epoch 137/150] [D loss: -2.468103] [G loss: 0.317473]
[Epoch 138/150] [D loss: -2.441617] [G loss: 0.219344]
[Epoch 139/150] [D loss: -2.655423] [G loss: -0.56591

In [72]:
# Wrap the merged matrix in a DataFrame labelled with every train_p column
# after its first two.
output = pd.DataFrame(scd.transform, columns=train_p.columns.tolist()[2:])

In [73]:
# train_p sample IDs regrouped batch-by-batch in merge order (presumably the
# row order of scd.transform — confirm).
out_index = []
for i in orders:
    out_index.extend(train_p.loc[train_p['batch'] == batches[i]].index.tolist())

In [74]:
# Prepend the sample IDs (merge order) as the first column.
# NOTE: DataFrame.insert raises if 'ID' already exists, so this cell is not re-runnable.
output.insert(loc=0, column='ID', value=out_index)

In [75]:
# Enforce the ordering assumption instead of only displaying it: the rows of
# `output` are labelled with `out_index`, which must equal `ec_obs_names`
# element-wise, otherwise downstream labels are attached to the wrong samples.
ids_match = list(ec_obs_names) == out_index
assert ids_match, 'out_index does not match ec_obs_names; sample IDs would be misaligned'
ids_match

True

In [76]:
# Prepend train's index as an 'ID' column of `metadata`.
# NOTE: DataFrame.insert raises if 'ID' already exists, so this cell is not re-runnable.
metadata.insert(loc=0, column='ID', value=train.index.tolist())

In [77]:
# Attach each sample's batch from `train`, looked up by ID. `.map` yields NaN
# for IDs that `train` lacks (Series.replace with a dict would instead leave
# such IDs unchanged, silently using the ID string as the batch label); the
# assert turns that situation into a visible error.
output.insert(1, 'batch', output['ID'].map(train['batch'].to_dict()))
assert output['batch'].notna().all(), 'some output IDs have no batch in train'

In [78]:
# Sample type = ID prefix before the first underscore
# (e.g. 'Overweight_3619' -> 'Overweight').
output.insert(loc=1, column='Type', value=output['ID'].str.split('_').str[0])

In [79]:
# Inspect the final table: ID, Type, batch, then the feature columns.
output

Unnamed: 0,ID,Type,batch,s__Abiotrophia_defectiva,s__Acetobacter_unclassified,s__Achromobacter_piechaudii,s__Achromobacter_unclassified,s__Achromobacter_xylosoxidans,s__Acidaminococcus_fermentans,s__Acidaminococcus_intestini,...,s__Vibrio_furnissii,s__Vibrio_kanaloae,s__Weissella_cibaria,s__Weissella_confusa,s__Weissella_halotolerans,s__Weissella_koreensis,s__Weissella_paramesenteroides,s__Weissella_unclassified,s__Wohlfahrtiimonas_chitiniclastica,s__Yersinia_enterocolitica
0,Overweight_3619,Overweight,Zeller (2014),0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
1,Overweight_3620,Overweight,Zeller (2014),0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
2,Overweight_3621,Overweight,Zeller (2014),0.0,0.000000,0.0,0.0,0.0,0.025272,0.00000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
3,Overweight_3622,Overweight,Zeller (2014),0.0,0.000000,0.0,0.0,0.0,0.000000,0.00084,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
4,Overweight_3623,Overweight,Zeller (2014),0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
233,Overweight_3691,Overweight,Karlsson (2013),0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
234,Overweight_3692,Overweight,Karlsson (2013),0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.0,0.0,0.014736,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
235,Overweight_3693,Overweight,Karlsson (2013),0.0,0.021887,0.0,0.0,0.0,0.000000,0.00000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.034965,0.0,0.0
236,Overweight_3694,Overweight,Karlsson (2013),0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0


In [80]:
# Persist the merged table, named after the phenotype (the row index is written too).
output.to_csv(path_or_buf='output' + phenotype + '.csv')

In [81]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import umap

In [82]:
def data2umap(data, n_pca=0, random_state=None):
    """Project `data` to 2-D with UMAP, optionally after a PCA reduction.

    Parameters
    ----------
    data : array-like or DataFrame
        Samples in rows, features in columns.
    n_pca : int, default 0
        If > 0, first reduce `data` to this many principal components.
    random_state : int, numpy RandomState or None, default None
        Seed forwarded to both PCA and UMAP so an embedding can be
        reproduced. None keeps the original unseeded behaviour.

    Returns
    -------
    numpy.ndarray
        UMAP coordinates, one row per sample and two columns.
    """
    if n_pca > 0:
        embedding = PCA(n_components=n_pca, random_state=random_state).fit_transform(data)
    else:
        embedding = data
    reducer = umap.UMAP(
        n_neighbors=30,
        min_dist=0.3,
        metric='cosine',
        n_components=2,
        learning_rate=1.0,
        spread=1.0,
        set_op_mix_ratio=1.0,
        local_connectivity=1,
        repulsion_strength=1,
        negative_sample_rate=5,
        angular_rp_forest=False,
        verbose=False,
        random_state=random_state,
    )
    return reducer.fit_transform(embedding)
def umap_plot(data, hue, title, save_path):
    """Scatter-plot a 2-D UMAP embedding coloured by one label column and save it.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain 'UMAP_1', 'UMAP_2' and the column named by `hue`.
    hue : str
        Name of the column in `data` used to colour the points.
    title : str
        Figure title (bold, font size 20).
    save_path : str
        Destination file passed to FacetGrid.savefig.
    """
    # One colour per distinct label level. `len(set(hue))` would count the
    # characters of the column *name*, not the levels in the column.
    n_colors = data[hue].nunique()
    grid = sns.lmplot(
        x='UMAP_1',
        y='UMAP_2',
        data=data,
        fit_reg=False,
        legend=True,
        height=9,  # `size` was renamed to `height` in seaborn 0.9
        hue=hue,
        palette=sns.color_palette("husl", n_colors),
        scatter_kws={'s': 4, "alpha": 0.6},
    )
    plt.title(title, weight='bold').set_fontsize('20')
    grid.savefig(save_path)
    # Close the figure just created so that looping over many plots does not leak memory.
    plt.close()
def gplot(embedding_, batch_info, celltype_info, filename, out_dir='./pic/', title=' '):
    """Save two scatter plots of a 2-D embedding: coloured by batch, then by type.

    Parameters
    ----------
    embedding_ : array-like with two columns
        2-D coordinates, e.g. from ``data2umap``.
    batch_info, celltype_info : array-like
        Per-row labels. They must follow the same row order as `embedding_`.
    filename : str
        File-name stem. ``1.png`` (batch) and ``2.png`` (type) are appended.
    out_dir : str, default './pic/'
        Output directory, created if missing.
    title : str, default ' '
        Title drawn on both figures.
    """
    plot_df = pd.DataFrame(embedding_, columns=['UMAP_1', 'UMAP_2'])
    plot_df['Label1'] = batch_info
    plot_df['Label2'] = celltype_info
    # savefig does not create directories; make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    for i in (1, 2):
        hue = f'Label{i}'
        save_path = os.path.join(out_dir, f'{filename}{i}.png')
        umap_plot(plot_df, hue, title, save_path)

In [83]:
# Feature columns only (positionally skips ID, Type and batch).
output.iloc[:,3:]

Unnamed: 0,s__Abiotrophia_defectiva,s__Acetobacter_unclassified,s__Achromobacter_piechaudii,s__Achromobacter_unclassified,s__Achromobacter_xylosoxidans,s__Acidaminococcus_fermentans,s__Acidaminococcus_intestini,s__Acidaminococcus_sp_BV3L6,s__Acidaminococcus_sp_D21,s__Acidaminococcus_sp_HPA0509,...,s__Vibrio_furnissii,s__Vibrio_kanaloae,s__Weissella_cibaria,s__Weissella_confusa,s__Weissella_halotolerans,s__Weissella_koreensis,s__Weissella_paramesenteroides,s__Weissella_unclassified,s__Wohlfahrtiimonas_chitiniclastica,s__Yersinia_enterocolitica
0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
1,0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
2,0.0,0.000000,0.0,0.0,0.0,0.025272,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
3,0.0,0.000000,0.0,0.0,0.0,0.000000,0.00084,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
4,0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
233,0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
234,0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.014736,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
235,0.0,0.021887,0.0,0.0,0.0,0.000000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.034965,0.0,0.0
236,0.0,0.000000,0.0,0.0,0.0,0.000000,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0


In [84]:
# train_p columns after its first two (the same columns used for `output`).
train_p.iloc[:,2:]

Unnamed: 0,s__Abiotrophia_defectiva,s__Acetobacter_unclassified,s__Achromobacter_piechaudii,s__Achromobacter_unclassified,s__Achromobacter_xylosoxidans,s__Acidaminococcus_fermentans,s__Acidaminococcus_intestini,s__Acidaminococcus_sp_BV3L6,s__Acidaminococcus_sp_D21,s__Acidaminococcus_sp_HPA0509,...,s__Vibrio_furnissii,s__Vibrio_kanaloae,s__Weissella_cibaria,s__Weissella_confusa,s__Weissella_halotolerans,s__Weissella_koreensis,s__Weissella_paramesenteroides,s__Weissella_unclassified,s__Wohlfahrtiimonas_chitiniclastica,s__Yersinia_enterocolitica
Overweight_3579,0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3580,0.0,0.0,0.0,0.0,0.0,0.0,0.00653,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3581,0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.02384,0.10129,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3582,0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3583,0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Overweight_3812,0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3813,0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3814,0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0
Overweight_3815,0.0,0.0,0.0,0.0,0.0,0.0,0.00000,0.0,0.0,0.0,...,0.0,0.0,0.00000,0.00000,0.0,0.0,0.0,0.00000,0.0,0.0


In [85]:
# UMAP of the merged profiles (features start after ID/Type/batch).
embedding_ = data2umap(output.iloc[:, 3:], n_pca=30)
# The embedding follows `output`'s row order, which was regrouped by `orders`,
# so the train_p batch labels must be looked up by sample ID. Taking them
# positionally from train_p pairs points with other samples' labels.
gplot(
    embedding_,
    np.array(train_p.loc[output['ID'], 'batch']),
    np.array(output['Type']),
    'iMAP2-OW-GHMI',
)



In [86]:
# UMAP of the uncorrected input profiles, in train_p's own row order.
embedding_ = data2umap(train_p.iloc[:, 2:], n_pca=30)
# Both label arrays must follow train_p's row order. Type is looked up from
# `output` by sample ID because `output` rows were regrouped by `orders`.
gplot(
    embedding_,
    np.array(train_p['batch']),
    np.array(output.set_index('ID').loc[train_p.index, 'Type']),
    'OW-GHMI',
)

