In [2]:
import os
import pandas as pd
from combat.pycombat import pycombat
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import f1_score
import joblib

In [3]:
# Move into the merged-dataset folder; every relative data/model path used later
# in the notebook is resolved from here.
# Guarded so that re-running this cell does not climb two further directory levels
# (the relative chdir is not idempotent). Assumes the notebook is not itself started
# from a directory named 'Merged' — TODO confirm if the layout changes.
if os.path.basename(os.getcwd()) != 'Merged':
    os.chdir("../../Dataset/Merged")

In [4]:
dataset = pd.read_csv('MergedDatasetZeroes.csv', index_col=0)

# Take the metadata columns out of the frame so that ComBat only sees feature values.
sampleID = dataset.pop('SampleID')
indicator = dataset.pop('Label')
# Batch label for ComBat: the part of each SampleID before the first '-'.
datasetID = sampleID.apply(lambda sid: sid.partition('-')[0]).values

# Rows are samples here, so the frame is transposed (features x samples) for
# pycombat and the corrected result transposed back.
dataset = pycombat(dataset.T, datasetID).T

# Put the metadata back in front of the corrected features.
dataset.insert(0, 'SampleID', sampleID)
dataset.insert(1, 'Label', indicator)

def getPatientID(sampleID):
    """Build a patient key from a sample ID.

    The key is the text before the first '-' joined with '-' to everything after
    the first '_' of the second '-'-separated field; any further '-' fields are
    ignored. For example "GSE1-Control_P01" -> "GSE1-P01".

    Raises IndexError if the ID contains no '-' or its second field has no '_'.
    """
    fields = sampleID.split('-')
    prefix, body = fields[0], fields[1]
    patient = body.split('_', 1)[1]
    return f"{prefix}-{patient}"

# Add a patient-level key directly after SampleID so that samples from the same
# patient can be handled together (grouping here, group-aware split later).
dataset.insert(1, 'PatientID', dataset['SampleID'].map(getPatientID))
gruppi = dataset.groupby(by='PatientID')

def sanity_check(gruppi):
    """Print every group whose samples are not uniformly Control or non-Control.

    ``gruppi`` is iterated as (group_name, group_frame) pairs, e.g. a pandas
    GroupBy. A sample counts as a control when the substring 'Control' occurs in
    its 'SampleID'. Each group is compared against its first sample; a single
    line "Errore in gruppo: <name>" ("error in group") is printed for each
    mixed group. Nothing is returned and nothing is raised.
    """
    for name, members in gruppi:
        ids = members['SampleID']
        first_is_control = 'Control' in ids.iloc[0]
        if any(('Control' in sid) != first_is_control for sid in ids):
            print("Errore in gruppo:", name)

# Report (print only) any patient group that mixes Control and non-Control samples.
sanity_check(gruppi)

# NOTE(review): ComBat was applied to all samples earlier in this cell, before the
# split, so test samples contributed to the batch correction — potential leakage.

# Patient-level hold-out: every PatientID ends up entirely in train or in test.
# Only the first split is consumed. GroupShuffleSplit draws successive splits from a
# single RNG seeded by random_state, so n_splits=1 yields the same first split that
# the former n_splits=2 did.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state = 42)
split = splitter.split(dataset, groups=dataset['PatientID'])
train_inds, test_inds = next(split)

# Invariants the evaluation depends on.
assert set(dataset['PatientID'].iloc[train_inds]).isdisjoint(dataset['PatientID'].iloc[test_inds]), \
    "a patient appears in both train and test"
assert len(train_inds) + len(test_inds) == len(dataset), "split does not cover every sample"

# Reproducibly shuffle the rows of each partition.
train = dataset.iloc[train_inds].sample(frac=1, random_state=42)
test = dataset.iloc[test_inds].sample(frac=1, random_state=42)

# Class balance report: "malati" counts Label == 1, "sani" counts Label == 0.
print("Dataset di train:")
print(train.shape)
print("I malati sono: ", (train['Label'] == 1).sum())
print("I sani sono: ", (train['Label'] == 0).sum())

print("\nDataset di test:")
print(test.shape)
print("I malati sono: ", (test['Label'] == 1).sum())
print("I sani sono: ", (test['Label'] == 0).sum())

# Features are every column except the metadata columns.
y_train = train['Label']
x_train = train.drop(columns=['SampleID', 'Label', 'PatientID'])

y_test = test['Label']
x_test = test.drop(columns=['SampleID', 'Label', 'PatientID'])

# Pre-trained soft-voting ensemble. joblib.load unpickles, so only load trusted files.
ensemble = joblib.load('../../Modelli/DatasetZeroes/ensembleSoft.pkl')

Found 7 batches.
Adjusting for 0 covariate(s) or covariate level(s).
Standardizing Data across genes.
Fitting L/S model and finding priors.
Finding parametric adjustments.


  np.absolute(d_new-d_old)/d_old))  # maximum difference between new and old estimate


Adjusting the Data
Dataset di train:
(1593, 12094)
I malati sono:  695
I sani sono:  898

Dataset di test:
(520, 12094)
I malati sono:  245
I sani sono:  275


In [None]:
def calcScores(x_test, y_test, model, metric=None, verbose=True):
    """Mean-ablation feature importance.

    For every column of ``x_test`` that column alone is replaced by its mean over
    ``x_test``, the model is re-scored, and the resulting metric is recorded.

    Parameters
    ----------
    x_test : pandas.DataFrame
        Feature matrix to evaluate on. It is not modified.
    y_test : array-like
        True labels aligned with the rows of ``x_test``.
    model : object
        Anything exposing ``predict(DataFrame)``.
    metric : callable, optional
        ``metric(y_true, y_pred) -> float``. Defaults to ``f1_score``.
    verbose : bool, default True
        Print a running 1-based counter after each column (previous behaviour).

    Returns
    -------
    dict
        Maps each column name to the metric obtained with that column
        mean-ablated, in ``x_test.columns`` order.
    """
    if metric is None:
        metric = f1_score
    # Iterate over x_test's own columns. The previous version used the global
    # x_train, which breaks whenever the two frames differ or x_train is undefined.
    # A single working copy is used: each column is ablated and then restored,
    # instead of copying the whole frame once per feature.
    x = x_test.copy()
    scores = {}
    for i, col in enumerate(x_test.columns, start=1):
        x[col] = x_test[col].mean()
        scores[col] = metric(y_test, model.predict(x))
        x[col] = x_test[col]  # restore the original values before the next feature
        if verbose:
            print(i)
    return scores

# Raw ablation scores: feature -> F1 with that feature mean-ablated.
scores = calcScores(x_test, y_test, ensemble)
# Saved under the 'ensembleSoft_' prefix (matching the model file and the path the
# reload cell reads). The previous name 'ensemble_ablationScoresPURI.pkl' was never
# read back, so a fresh Restart & Run All loaded a stale or missing file.
joblib.dump(scores, '../../ShapValues/DatasetZeroes/ensembleSoft_ablationScoresPURI.pkl')

In [5]:
# Reload the cached raw ablation scores (presumably feature -> F1 with that feature
# mean-ablated, as produced by calcScores). joblib.load unpickles: trusted files only.
scores = joblib.load(filename='../../ShapValues/DatasetZeroes/ensembleSoft_ablationScoresPURI.pkl')

In [6]:
# Baseline F1 of the ensemble on the untouched test set.
original = f1_score(y_test, ensemble.predict(x_test))
# Importance ratio = baseline F1 / F1 with the feature mean-ablated:
#   > 1 -> ablating the feature lowered F1 (the feature helps)
#   < 1 -> ablating the feature raised F1 (the feature looks like noise)
# An ablated F1 of 0 previously raised ZeroDivisionError. It now maps to +inf
# (the feature is indispensable), or to 1.0 when the baseline is also 0 (no change).
new_scores = {
    key: original / value if value else (float('inf') if original else 1.0)
    for key, value in scores.items()
}
# joblib.dump(new_scores, '../../ShapValues/DatasetZeroes/ensembleSoft_ablationScores.pkl')

In [4]:
# NOTE(review): this cell (execution count In [4], i.e. run out of order) replaces
# the new_scores computed in the previous cell with the cached copy on disk, so under
# Restart & Run All the cached values are used. Confirm that is intended.
# joblib.load unpickles: trusted files only.
new_scores = joblib.load(filename='../../ShapValues/DatasetZeroes/ensembleSoft_ablationScores.pkl')

In [7]:
# Rank features by importance ratio, highest first (ties keep their original order).
sorted_scores = {
    feature: new_scores[feature]
    for feature in sorted(new_scores, key=new_scores.get, reverse=True)
}
# Assuming new_scores holds baseline-F1 / ablated-F1 ratios:
# ratio above 1 -> removing the feature hurt F1 ("migliori" = best);
# ratio below 1 -> removing the feature helped F1 ("rumorose" = noisy).
bestFeatures = {feature: ratio for feature, ratio in sorted_scores.items() if ratio > 1.0}
worstFeatures = {feature: ratio for feature, ratio in sorted_scores.items() if ratio < 1.0}
print("Le feature migliori sono: ", len(bestFeatures), bestFeatures)
print("Le feature rumorose sono: ", len(worstFeatures), worstFeatures)

Le feature migliori sono:  3 {'DDX17': 1.0021938441388343, 'HLA-DRB1': 1.0019646365422397, 'HLA-DRB5': 1.0010665169800728}
Le feature rumorose sono:  7 {'TCN1': 0.9980353634577603, 'RNF10': 0.9980353634577603, 'FRG1': 0.9980353634577603, 'FBLN2': 0.9980353634577603, 'TRMT5': 0.9980353634577603, 'TRIM21': 0.9980353634577603, 'RPS28': 0.9980353634577603}
