In [150]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re
%config Completer.use_jedi = False

In [161]:
data = pd.read_csv('metagenomics/abundance_stoolsubset.csv', dtype='str')
cols = data.columns

# def warn(*args, **kwargs):
#     pass
# import warnings
# warnings.warn = warn

In [None]:
data['disease'].value_counts()

In [3]:
### preprocess data ###

# filter for healthy samples without disease
# processed = data[(data['gender'].isin(['male', 'female'])) & (data['disease'] == 'n')].copy()


# filter for Type 2 diabetes and randomly select some controls then concatenate them and shuffle

temp1 = data[data['disease'].isin(['t2d', 'impaired_glucose_tolerance'])].copy()
# temp1 = data[data['disease'].isin(['cirrhosis'])].copy()

processed = pd.concat([temp1, data[(data['disease'] == 'n')].sample(300)]).sample(frac=1).reset_index(drop = True)

# drop columns that are not needed
to_drop = list(cols[2:4]) + list(cols[8:20]) + list(cols[21:211])
processed.drop(columns = to_drop, inplace = True)

# 2nd filtering round

bacteria = list(processed.columns)[7:]

s = re.compile(r's__\w+$')

not_species = [i for i in bacteria if not s.search(i)]

processed.drop(columns = not_species, inplace=True)


# create a new column as labels
processed['label'] = processed['disease'].apply(lambda x:  0 if (x == 'n') else 1)

processed.reset_index(inplace=True, drop=True)

species = processed.columns[7:833]

processed[species] = processed[species].apply(pd.to_numeric)

In [4]:
# rename feature names to shorter one

s = re.compile(r's__(\w+)')

short = []

for i in list(species):
    short.append( s.search(i).group(1).replace("_", " "))
    
new = dict(zip(species, short))

processed.rename(columns = new, inplace=True)

In [5]:
# data scaling

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn import metrics
from sklearn.metrics import balanced_accuracy_score

x = processed.iloc[:, 7:833]

scaler = StandardScaler()
x = StandardScaler().fit_transform(x)

y = processed['label']

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)
                                                   

In [6]:
def model_performance(model, x, y, n=50):
    ''' does 50 train test splits and calculates model metrics'''
    from sklearn import metrics
    
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)
    
    model.fit(x_train, y_train)
    
    report = metrics.classification_report(y_test, model.predict(x_test), output_dict=True)
    
    report_std = {}
    
    for i in report.keys():
        if i == 'accuracy':
            report_std[i] = []
        else:
            report_std[i] = {}
            for j in report[i].keys():
                report_std[i][j] = []

    
    scoring = ['balanced_accuracy', 'accuracy', 'f1','precision','recall','roc_auc']

    cross_val ={'balanced_accuracy':0, 'accuracy':0, 'f1':0,'precision':0,'recall':0,'roc_auc':0}
        
    cross_val_std ={'balanced_accuracy':[], 'accuracy':[], 'f1':[],'precision':[],'recall':[],'roc_auc':[]}
    
    for count in range(n):
        
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)
        
        model.fit(x_train, y_train)
        
        test = metrics.classification_report(y_test, model.predict(x_test), output_dict=True)
        
        for i in report.keys():
            if i == 'accuracy':
                report[i] += test[i]
                report_std[i].append(test[i])
            else:
                for j in report[i].keys():
                    report[i][j] += test[i][j]
                    report_std[i][j].append(test[i][j])
       
        for p in scoring:
            scores = cross_val_score(model, x_train, y_train, scoring = p)

            cross_val[p] += np.mean(scores)
            
            cross_val_std[p].append((np.mean(scores)))
            
    
    for i in scoring:
        print(i)
        print('mean: %0.3f' % (cross_val[i]/n))
        print('std dev: %0.3f' % (np.std(cross_val_std[i])))
        print()
        
    for i in report.keys():
        if i == 'accuracy':
            report[i] = report[i]/(n+1)
            report_std[i] = np.std(report_std[i])
        else:
            for j in report[i].keys():
                report[i][j] = report[i][j]/(n+1)
                report_std[i][j] = np.std(report_std[i][j])
        
    return report, report_std


# def model_performance(model, x, y):
#     ''' 20 fold cross val score'''
#     scoring = ['balanced_accuracy', 'accuracy', 'f1','precision','recall','roc_auc']
    
#     for p in scoring:
#         scores = cross_val_score(model, x, y, cv=20, scoring = p)
#         print(p)
#         print('mean: %0.3f' % np.mean(scores))
#         print('std dev: %0.2f' % np.std(scores))
#         print()

In [10]:
from collections import Counter

c = Counter(pop)

test = dict(c)

test2 = dict(sorted(test.items(), key=lambda item: -item[1]))

import csv

with open('dict_diabetes.csv', 'w') as csv_file:  
    writer = csv.writer(csv_file)
    for key, value in test2.items():
       writer.writerow([key, value])

### Random Forest ###

In [7]:
from sklearn.ensemble import RandomForestClassifier

forest = RandomForestClassifier()

forest_params = {'n_estimators':[i for i in range(110)[10::20]], 'max_depth':[j for j in range(81)[5::15]], 
                 'min_samples_leaf':list(np.linspace(0.001,0.25,num=10)), 'max_features':['sqrt','log2']}

forest_cv = GridSearchCV(forest, forest_params, scoring = 'balanced_accuracy', n_jobs=2).fit(x_train, y_train)

forest_cv.best_params_

{'max_depth': 20,
 'max_features': 'log2',
 'min_samples_leaf': 0.001,
 'n_estimators': 50}

In [8]:
### for diabetes ###

# forest = RandomForestClassifier(max_features = 'sqrt')

# forest_params = {'n_estimators':[i for i in range(120)[65:85:2]], 'max_depth':[j for j in range(100)[45:65:2]], 
#                  'min_samples_leaf':list(np.linspace(0.0001,0.002,num=10))}

# forest_cv = GridSearchCV(forest, forest_params, scoring = 'balanced_accuracy', n_jobs=2).fit(x_train, y_train)

# forest_cv.best_params_

### for cirrhosis ###

forest = RandomForestClassifier(max_features = 'log2')

forest_params = {'n_estimators':[i for i in range(80)[45:65:2]], 'max_depth':[j for j in range(50)[15:35:2]], 
                 'min_samples_leaf':list(np.linspace(0.0001,0.02,num=20))}

forest_cv = GridSearchCV(forest, forest_params, scoring = 'balanced_accuracy', n_jobs=2).fit(x_train, y_train)

forest_cv.best_params_

{'max_depth': 15, 'min_samples_leaf': 0.015810526315789476, 'n_estimators': 55}

In [10]:
### diabetes ###

from sklearn.ensemble import RandomForestClassifier

forest = RandomForestClassifier(n_estimators = 63, min_samples_leaf=0.00157, max_features='sqrt', max_depth = 69)

### cirrhosis ###
# from sklearn.ensemble import RandomForestClassifier

# forest = RandomForestClassifier(n_estimators = 55, min_samples_leaf=0.0158, max_features='sqrt', max_depth = 15)

# model_performance(forest, x_train, y_train)

In [None]:
report, report_std = model_performance(forest, x, y)

In [None]:
report

In [None]:
report_std

In [64]:
forest.fit(x_train, y_train)

# print(metrics.classification_report(y_test, forest.predict(x_test)))

RandomForestClassifier(max_depth=69, max_features='sqrt',
                       min_samples_leaf=0.00157, n_estimators=63)

### AdaBoost ###

In [11]:
from sklearn.ensemble import AdaBoostClassifier

ada = AdaBoostClassifier()

ada_params = {'n_estimators':[i for i in range(150)[50:150:10]]}

ada_cv = GridSearchCV(ada, ada_params, scoring = 'balanced_accuracy', n_jobs=2).fit(x_train, y_train)

ada_cv.best_params_

{'n_estimators': 120}

In [12]:
### diabetes ###

# ada = AdaBoostClassifier()

# ada_params = {'n_estimators':[i for i in range(150)[75:95]]}

# ada_cv = GridSearchCV(ada, ada_params, scoring = 'balanced_accuracy', n_jobs=2).fit(x_train, y_train)

# ada_cv.best_params_

### cirrhosis ###


ada = AdaBoostClassifier()

ada_params = {'n_estimators':[i for i in range(150)[115:135]]}

ada_cv = GridSearchCV(ada, ada_params, scoring = 'balanced_accuracy', n_jobs=2).fit(x_train, y_train)

ada_cv.best_params_

{'n_estimators': 125}

In [19]:
### diabetes ###

from sklearn.ensemble import AdaBoostClassifier

ada = AdaBoostClassifier(n_estimators = 80)

# model_performance(ada, x_train, y_train)


### cirrhosis ###

# from sklearn.ensemble import AdaBoostClassifier

# ada = AdaBoostClassifier(n_estimators = 125)

# model_performance(ada, x_train, y_train)

In [None]:
report, report_std = model_performance(ada, x, y)

In [None]:
report

In [None]:
report_std

In [65]:
ada.fit(x_train, y_train)

# print(metrics.classification_report(y_test, ada.predict(x_test)))

AdaBoostClassifier(n_estimators=80)

### Gradient Boosting ###

In [6]:
from sklearn.ensemble import GradientBoostingClassifier

grad = GradientBoostingClassifier()

grad_params = {'n_estimators':[i for i in range(200)[100:300:20]]}

grad_cv = GridSearchCV(grad, grad_params).fit(x_train, y_train)

grad_cv.best_params_

In [87]:
### diabetes ###
# grad = GradientBoostingClassifier()

# grad_params = {'n_estimators':[i for i in range(200)[115:135:2]], 'max_depth':[i for i in range(6)[3:8]]}

# grad_cv = GridSearchCV(grad, grad_params).fit(x_train, y_train)

# grad_cv.best_params_


grad = GradientBoostingClassifier()

grad_params = {'n_estimators':[i for i in range(200)[170:190:2]], 'max_depth':[i for i in range(6)[3:]]}

grad_cv = GridSearchCV(grad, grad_params).fit(x_train, y_train)

grad_cv.best_params_

{'max_depth': 4, 'n_estimators': 119}

In [16]:
### diabetes ###

# grad = GradientBoostingClassifier(n_estimators = 119)

# grad_params = {'max_depth':[i for i in range(12)[4:10]]}

# grad_cv = GridSearchCV(grad, grad_params).fit(x_train, y_train)

# grad_cv.best_params_

### cirrhosis ###

grad = GradientBoostingClassifier(max_depth = 4)

grad_params = {'n_estimators':[i for i in range(200)[110:130:2]]}

grad_cv = GridSearchCV(grad, grad_params).fit(x_train, y_train)

grad_cv.best_params_

{'n_estimators': 114}

In [17]:
### diabetes ###

# grad = GradientBoostingClassifier(max_depth=4, n_estimators = 119)

# grad_params = {'min_samples_split':[i for i in range(15)[2::2]], 'max_features':['sqrt', 'log2']}

# grad_cv = GridSearchCV(grad, grad_params).fit(x_train, y_train)

# grad_cv.best_params_

### cirrhosis ###

grad = GradientBoostingClassifier(max_depth=4, n_estimators = 114)

grad_params = {'min_samples_split':[i for i in range(10)[2:]], 'max_features':['sqrt', 'log2']}

grad_cv = GridSearchCV(grad, grad_params).fit(x_train, y_train)

grad_cv.best_params_

{'max_features': 'log2', 'min_samples_split': 9}

In [96]:
### diabetes ###

# grad = GradientBoostingClassifier(max_depth=4, n_estimators = 119, max_features='sqrt')

# grad_params = {'min_samples_split':[i for i in range(25)[4::2]]}

# grad_cv = GridSearchCV(grad, grad_params).fit(x_train, y_train)

# grad_cv.best_params_

{'min_samples_split': 10}

In [25]:
### diabetes ###

from sklearn.ensemble import GradientBoostingClassifier

grad = GradientBoostingClassifier(max_depth=4, n_estimators = 125, max_features='sqrt', min_samples_split=10)

# grad.fit(x_train, y_train)

# model_performance(grad, x_train, y_train)

### cirrhosis ###

# from sklearn.ensemble import GradientBoostingClassifier

# grad = GradientBoostingClassifier(max_depth=4, n_estimators = 114, max_features='log2', min_samples_split=9)

# grad.fit(x_train, y_train)

# model_performance(grad, x_train, y_train)

In [None]:
report, report_std = model_performance(grad, x, y)

In [None]:
report

In [None]:
report_std

In [66]:
grad.fit(x_train, y_train)

# print(metrics.classification_report(y_test, grad.predict(x_test)))

GradientBoostingClassifier(max_depth=4, max_features='sqrt',
                           min_samples_split=10, n_estimators=125)

In [None]:
fig, ax = plt.subplots(1,2, figsize = (22,12))

#plot ROC curve

metrics.plot_roc_curve(forest, x_test, y_test, ax = ax[0], name = "Random Forest", ls = "-.", color='r', linewidth = 2)
metrics.plot_roc_curve(ada, x_test, y_test, ax = ax[0], name = "Adaboost")
metrics.plot_roc_curve(grad, x_test, y_test, ax = ax[0], name = "Gradient Boosting", ls = '--', color='black', linewidth=2)
ax[0].tick_params(axis='both', labelsize= 14)
ax[0].set_xlabel('False Postive Rate', fontsize = 24)
ax[0].set_ylabel('True Postive Rate', fontsize = 24)
ax[0].text(-0.08, 1.065, "A", fontsize=24, fontweight='bold', va='top', ha='right')
ax[0].legend(loc = 'lower left', fontsize = 14)

#plot precision-recall curve

metrics.plot_precision_recall_curve(forest, x_test, y_test, ax = ax[1], name = "Random Forest", ls = "-.", color='r', linewidth = 2)
metrics.plot_precision_recall_curve(ada, x_test, y_test, ax = ax[1], name = "Adaboost")
metrics.plot_precision_recall_curve(grad, x_test, y_test, ax = ax[1], name = "Gradient Boosting", ls = '--', color='black', linewidth=2)
ax[1].set_xlabel('Recall', fontsize = 24)
ax[1].set_ylabel('Precision', fontsize = 24)
ax[1].legend(loc = 'lower left',fontsize=14)
ax[1].tick_params(axis='both', labelsize= 16)
ax[1].text(-0.08, 1.035, "B", fontsize=24, fontweight='bold', va='top', ha='right')


plt.tight_layout()

plt.savefig('diabetes_performance',dpi=300)
# plt.savefig('cirrhosis_performance',dpi=300)


#always put savefig before show(), if not will save empty image.
plt.show()

In [21]:
from sklearn.inspection import permutation_importance


grad_importance = permutation_importance(grad, x_test, y_test, scoring = 'balanced_accuracy', n_jobs = 2, n_repeats = 100, random_state = 42)


In [7]:
# 50 fold permutation importance


from sklearn.inspection import permutation_importance

from sklearn.ensemble import GradientBoostingClassifier

impt_score = np.zeros((826,))

for i in range(50):
    # split data
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)

    grad = GradientBoostingClassifier(max_depth=4, n_estimators = 125, max_features='sqrt', min_samples_split=10).fit(x_train, y_train)
    
    perm_impt = permutation_importance(grad, x_test, y_test, scoring='balanced_accuracy', n_jobs = 2)
    
    # filter dataset
    
    impt_score += perm_impt['importances_mean']
    
temp1 = pd.DataFrame(columns = ['features','impt'])
    
temp1['features'] = short
temp1['impt'] = impt_score/50
    
temp2 = temp1[temp1['impt'] > 0].copy()
temp2.sort_values(by='impt', ascending=False, inplace=True, ignore_index=True)

In [None]:
temp2[:10]

In [126]:
### diabetes ###

# dataframe to record the features and importance

# perm_impt = pd.DataFrame(columns = ['features','impt','impt_std'])

# perm_impt['features'] = short
# perm_impt['impt'] = forest_importance['importances_mean']
# perm_impt['impt_std'] = forest_importance['importances_std']

# diabetes_impt = perm_impt[perm_impt['impt'] > 0].copy()

# diabetes_impt.sort_values(by='impt', ascending=False, inplace=True, ignore_index=True)

# # export the csv file
# diabetes_impt.to_csv('diabetes_impt.csv', index=False)

In [22]:
### cirrhosis ###

# dataframe to record the features and importance

perm_impt = pd.DataFrame(columns = ['features','impt','impt_std'])

perm_impt['features'] = short
perm_impt['impt'] = grad_importance['importances_mean']
perm_impt['impt_std'] = grad_importance['importances_std']

cirrho_impt = perm_impt[perm_impt['impt'] > 0].copy()

cirrho_impt.sort_values(by='impt', ascending=False, inplace=True, ignore_index=True)

# export the csv file
cirrho_impt.to_csv('cirrhosis_impt.csv', index=False)

In [None]:
# diabetes_impt = pd.read_csv('diabets_impt.csv')

# diabetes_impt.head()

cirrho_impt = pd.read_csv('cirrhosis_impt.csv')

cirrho_impt.head()


In [None]:
### diabetes ####

#changing labels for control and obese

processed['label'].apply(lambda x: 'Diabetes' if x == 1 else 'Ctrl')

test = list(temp2['features'][:10])

# create a dataframe for mean abundance
means = processed.groupby(by = 'label').mean().iloc[:,6:-1]

diabetes_means = pd.DataFrame(columns = ['Bacteria','Ctrl','Diabetes'])

diabetes_means['Bacteria'] = diabetes_impt['features'][:10]
diabetes_means['Diabetes'] = list(means[test].iloc[1, :])
diabetes_means['Ctrl'] = list(means[test].iloc[0, :])



In [37]:
### cirrhosis ###

# processed['label'].apply(lambda x: 'Cirrhosis' if x == 1 else 'Ctrl')


# test = list(cirrho_impt['features'][:10])

# # create a dataframe for mean abundance
# means = processed.groupby(by = 'label').mean().iloc[:,6:-1]

# cirrho_means = pd.DataFrame(columns = ['Bacteria','Ctrl','Cirrhosis'])

# cirrho_means['Bacteria'] = cirrho_impt['features'][:10]
# cirrho_means['Cirrhosis'] = list(means[test].iloc[1, :])
# cirrho_means['Ctrl'] = list(means[test].iloc[0, :])


In [None]:
### diabetes ###

importance vs sample line

plt.figure(figsize=(12,8))

plt.plot(diabetes_impt['impt'])

plt.xlabel('Species', fontsize = 16, fontweight = 'bold', labelpad=10)
plt.tick_params(axis='both', labelsize=10)

plt.ylabel('Permutation Importance', fontsize = 16, fontweight = 'bold', labelpad=10)

plt.tight_layout()

plt.savefig('diabetes_impt_line.png', dpi=200)

In [None]:
### cirrhosis ###

# importance vs sample line

# plt.figure(figsize=(12,8))

# plt.plot(cirrho_impt['impt'])

# plt.xlabel('Species', fontsize = 16, fontweight = 'bold', labelpad=10)
# plt.tick_params(axis='both', labelsize=10)

# plt.ylabel('Permutation Importance', fontsize = 16, fontweight = 'bold', labelpad=10)

# plt.tight_layout()

# plt.savefig('cirrhosis_impt_line.png', dpi=200)

In [35]:
cols = list(processed.columns[:7]) + list(temp2['features'][:10]) + ['label']

diab = processed[cols].copy()

diab['label'] = diab['label'].apply(lambda x: 'Diabetes' if x == 1 else 'Ctrl')

In [46]:
diab_means = diab.groupby(by='label').mean().T

In [62]:
test = list(temp2['features'][:10])

In [71]:
stuff = list(processed.columns[:7]) + test + ['label']

grad_stuff = processed[stuff].copy()

grad_stuff['label'] = grad_stuff['label'].apply(lambda x: 'Diabetes' if x == 1 else 'Ctrl')

In [74]:
import scipy as sp

mann = []

for i in test:
    ctrl = grad_stuff[i][grad_stuff['label'] == 'Ctrl']
    fat = grad_stuff[i][grad_stuff['label'] == 'Diabetes']
    mann.append(sp.stats.mannwhitneyu(ctrl, fat))

In [76]:
test

['Collinsella aerofaciens',
 'Odoribacter splanchnicus',
 'Megasphaera unclassified',
 'Parabacteroides unclassified',
 'Ruminococcus sp 5 1 39BFAA',
 'Clostridium bolteae',
 'Clostridium citroniae',
 'Bifidobacterium adolescentis',
 'Bifidobacterium longum',
 'Bacteroides vulgatus']

In [75]:
mann

[MannwhitneyuResult(statistic=33397.0, pvalue=7.224711382942652e-05),
 MannwhitneyuResult(statistic=29644.0, pvalue=2.261856987976006e-09),
 MannwhitneyuResult(statistic=34121.0, pvalue=6.408672929580585e-06),
 MannwhitneyuResult(statistic=33558.5, pvalue=6.898033308249406e-05),
 MannwhitneyuResult(statistic=31396.0, pvalue=2.9740521426316235e-07),
 MannwhitneyuResult(statistic=26114.0, pvalue=2.1677522501221008e-14),
 MannwhitneyuResult(statistic=28766.0, pvalue=6.125876764052805e-11),
 MannwhitneyuResult(statistic=38847.0, pvalue=0.14688867697010616),
 MannwhitneyuResult(statistic=32030.0, pvalue=3.5573550970751967e-06),
 MannwhitneyuResult(statistic=32459.0, pvalue=1.1807412955071055e-05)]

In [None]:
### diabetes ###

ind = np.arange(len(diab_means))

width = 0.4

fig, ax = plt.subplots(figsize = (16,10))

color = ['royalblue', 'orange']
ax.barh(ind, diab_means['Ctrl'], width, color = color[1], label = 'Ctrl', edgecolor='black')
ax.barh(ind + width, diab_means['Diabetes'], width, color = color[0], label = 'Diabetes', edgecolor='black')
ax.set(yticks = ind + width/2, yticklabels = temp2['features'][:10])
#ax.set_title('Mea', fontsize = 18, fontweight='bold', pad=20)
ax.set_xscale('log')
ax.set_ylabel('Species', fontsize = 18, fontweight='bold')
ax.set_xlabel("Relative Abundance", fontsize = 18, fontweight = 'bold', labelpad = 20)
ax.legend(loc='upper right', fontsize = 'x-large')
ax.tick_params(axis='both', labelsize = 14)
plt.gca().invert_yaxis()

#plt.savefig('diabetes_impt_abundance.png', dpi=200, bbox_inches = "tight")

plt.show()

In [None]:
### cirrhosis ###

# ind = np.arange(len(cirrho_means))

# width = 0.4

# fig, ax = plt.subplots(figsize = (16,10))

# color = ['royalblue', 'orange']
# ax.barh(ind, cirrho_means.Cirrhosis, width, color = color[0], label = 'Cirrhosis')
# ax.barh(ind + width, cirrho_means.Ctrl, width, color = color[1], label = 'Ctrl')
# ax.set(yticks = ind + width/2, yticklabels = cirrho_means.Bacteria)
# #ax.set_title('Mea', fontsize = 18, fontweight='bold', pad=20)
# ax.set_xscale('log')
# ax.set_ylabel('Species', fontsize = 18, fontweight='bold')
# ax.set_xlabel("Log Relative Abundance", fontsize = 18, fontweight = 'bold', labelpad = 20)
# ax.legend(loc='lower right', fontsize = 'x-large')
# ax.tick_params(axis='both', labelsize = 14)
# plt.gca().invert_yaxis()

# plt.savefig('cirrhosis_impt_abundance.png', dpi=200, bbox_inches = "tight")

# plt.show()

In [169]:
### for plotting correlation heatmap for all data ###


processed = data.copy()

# drop columns that are not needed
to_drop = list(cols[2:4]) + list(cols[8:20]) + list(cols[21:211])
processed.drop(columns = to_drop, inplace = True)

# 2nd filtering round

bacteria = list(processed.columns)[7:]

s = re.compile(r's__\w+$')

not_species = [i for i in bacteria if not s.search(i)]

processed.drop(columns = not_species, inplace=True)


# create a new column as labels
processed['label'] = processed['disease'].apply(lambda x:  0 if (x == 'n') else 1)

processed.reset_index(inplace=True, drop=True)

species = processed.columns[7:833]

processed[species] = processed[species].apply(pd.to_numeric)

# rename feature names to shorter one

s = re.compile(r's__(\w+)')

short = []

for i in list(species):
    short.append( s.search(i).group(1).replace("_", " "))
    
new = dict(zip(species, short))

processed.rename(columns = new, inplace=True)

In [185]:
test = processed.iloc[:, 7:-1].corr()

In [None]:
import seaborn as sns

fig, ax = plt.subplots(1,1, figsize = (10,8))

sns.heatmap(test, xticklabels = 82, yticklabels = 82, ax=ax, cmap='magma')

plt.tight_layout()

plt.savefig('correlation.png', dpi=300)