In [112]:
import pickle
from custom_nets.resnet import ResNet, train_model, evaluate
from torchsummary import summary
import torch
import pandas as pd

In [113]:
# Relative-abundance table, transposed so that its index can be matched against
# the metadata sample ids further down.
microbiome = pd.read_csv(
    '../../data/raw/curated_metagenomics/relative_abundance.csv',
    index_col=0,
).T

# Sample metadata indexed by `sample_id`. low_memory=False makes pandas read the
# file in one pass instead of inferring column dtypes chunk by chunk.
metadata = pd.read_csv(
    '../../data/raw/curated_metagenomics/metadata.csv',
    index_col='sample_id',
    low_memory=False,
)

# %% [markdown]
# For this example we will try to distinguish diseased from healthy samples based on the microbiome. A sample counts as diseased if it is labelled as diseased in the original data, or if its BMI is below 16 or at least 30 (BMI < 16 or BMI >= 30). These are the boundaries of severe underweight and obesity, respectively.
# Study that forms the "target" cohort; all remaining studies form the "base" cohort.
study_name = 'QinN_2014'

# Keep stool samples only. `.copy()` gives an independent frame so the label
# assignments below write to it directly and cannot raise SettingWithCopyWarning.
metadata = metadata.loc[metadata.body_site == 'stool', :].copy()

# Override the disease label from BMI: >= 30 is tagged as obesity ...
to_change = metadata.BMI >= 30
metadata.loc[to_change, 'disease'] = 'obesity'

# ... and < 16 is tagged as severe underweight.
to_change = metadata.BMI < 16
metadata.loc[to_change, 'disease'] = 'severe_underweight'

# Drop samples without a recorded BMI.
metadata = metadata.loc[metadata.BMI.notna(), :]

# Keep only samples labelled 'healthy'. This also removes every sample that was
# tagged 'obesity' / 'severe_underweight' above and every sample with a missing
# disease label.
# NOTE(review): the notebook intro describes a disease-vs-healthy classification,
# but after this filter only healthy samples remain — confirm this is intended.
metadata = metadata.loc[metadata.disease == 'healthy', :]

# Exclude newborns.
to_keep = metadata.age_category != 'newborn'
metadata = metadata.loc[to_keep, :]

# Restrict both tables to the samples they share. Index.intersection keeps a
# deterministic order, unlike iterating over a set of strings, whose order can
# change from one kernel session to the next.
overlapping_samples = list(metadata.index.intersection(microbiome.index, sort=False))
microbiome = microbiome.loc[overlapping_samples, :]
metadata = metadata.loc[overlapping_samples, :]

# Base cohort: every study except `study_name`.
base_metadata = metadata.loc[metadata.study_name != study_name, :]
base_microbiome = microbiome.loc[base_metadata.index, :]

# Target cohort: samples from `study_name` only.
target_metadata = metadata.loc[metadata.study_name == study_name, :]
target_microbiome = microbiome.loc[target_metadata.index, :]



In [114]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`.

    SECURITY: unpickling can execute arbitrary code, so only point this at
    files produced by this project.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# One saved ResNet keyword-argument dict per parameter file.
params_H = _load_pickle('../../data/raw/curated_metagenomics/resnet_params.pkl')

params_Y = _load_pickle('../../data/raw/curated_metagenomics/resnet_params_YachidaS_2019.pkl')

params_Q = _load_pickle('../../data/raw/curated_metagenomics/resnet_params_qin.pkl')

params_all = _load_pickle('../../data/raw/curated_metagenomics/resnet_params_all.pkl')

In [115]:
# All-zero tensor of shape (148, 2047); 2047 equals the `d_numerical` value of
# the saved ResNet parameter sets.
# NOTE(review): despite its name this tensor is not random — it is filled with zeros.
# Allocate directly on the GPU when one is present and fall back to the CPU
# otherwise, so the cell no longer fails on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
random = torch.zeros(size=(148, 2047), device=device)


In [116]:
# Build the network from an explicitly named parameter set. The previous version
# read `sublist`, which is only created by the loop in a later cell, so this cell
# raised NameError under Restart & Run All.
# NOTE(review): params_H is an assumption — confirm which parameter set this
# model is supposed to be built from.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet(**params_H).to(device)
# Switch to evaluation mode (presumably disables the dropout configured via
# hidden_dropout / residual_dropout — verify against the ResNet implementation).
model = model.eval()

In [117]:
def count_parameters(model):
    """Return the total number of elements across the trainable parameters of `model`.

    A parameter counts as trainable when its `requires_grad` flag is truthy;
    parameters with the flag switched off are skipped.
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total

In [120]:
# Saved ResNet parameter sets and, in the same order, the label printed for each.
pickle_list = [params_H,
               params_Y,
               params_Q,
               params_all]
names_list = [
    'H','Y', 'Q', 'all'
]
# zip() would silently drop unmatched entries, so fail loudly if the lists diverge.
assert len(names_list) == len(pickle_list), 'names_list and pickle_list must have the same length'

# A parameter stays trainable when its name contains one of these substrings;
# every other parameter is frozen.
trainable_tags = ('head', 'last_normalization', 'first_layer')

# For each parameter set: build the network, freeze everything except the tagged
# parameters, and report how many trainable parameter elements remain.
for label, sublist in zip(names_list, pickle_list):
    print(label)
    print(sublist)
    model = ResNet(**sublist)
    for name, param in model.named_parameters():
        param.requires_grad = any(tag in name for tag in trainable_tags)

    print(count_parameters(model))
    print('\n')


H
{'d_numerical': 2047, 'd': 148, 'd_hidden_factor': 1, 'n_layers': 1, 'hidden_dropout': 0.36734979065758694, 'residual_dropout': 0.37222088474141257, 'd_out': 1}
303549


Y
{'d_numerical': 2047, 'd': 278, 'd_hidden_factor': 1, 'n_layers': 5, 'hidden_dropout': 0.16080099411546672, 'residual_dropout': 0.33468549816694226, 'd_out': 1}
570179


Q
{'d_numerical': 2047, 'd': 65, 'd_hidden_factor': 1, 'n_layers': 1, 'hidden_dropout': 0.41919453514817584, 'residual_dropout': 0.19197424191729415, 'd_out': 1}
133316


all
{'d_numerical': 2047, 'd': 97, 'd_hidden_factor': 2, 'n_layers': 1, 'hidden_dropout': 0.32422753112973324, 'residual_dropout': 0.2917433119476729, 'd_out': 1}
198948


