# Load Data

In [1]:
import pandas as pd 

# columns correspond to unitigs so we must transpose this table
def load_unitig_data(path) -> pd.DataFrame:
    sr = pd.read_csv(path, sep=' ')
    sr.set_index('pattern_id',inplace=True)
    return sr.T[1:]

azm_sr = load_unitig_data("azm_sr_gwas_filtered_unitigs.Rtab")
cfx_sr = load_unitig_data("cfx_sr_gwas_filtered_unitigs.Rtab")
cip_sr = load_unitig_data("cip_sr_gwas_filtered_unitigs.Rtab")

metadata = pd.read_csv('metadata.csv')
metadata.set_index('Sample_ID',inplace=True)


In [77]:
metadata.info()

<class 'pandas.core.frame.DataFrame'>
Index: 3786 entries, ERR1549286 to ERR2172354
Data columns (total 30 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   Year            3536 non-null   float64
 1   Country         3785 non-null   object 
 2   Continent       3785 non-null   object 
 3   Beta.lactamase  1927 non-null   object 
 4   Azithromycin    3480 non-null   object 
 5   Ciprofloxacin   3129 non-null   object 
 6   Ceftriaxone     3436 non-null   object 
 7   Cefixime        3405 non-null   object 
 8   Tetracycline    1472 non-null   object 
 9   Penicillin      1465 non-null   object 
 10  NG_MAST         3779 non-null   object 
 11  Group           3786 non-null   int64  
 12  azm_mic         3478 non-null   float64
 13  cip_mic         3088 non-null   float64
 14  cro_mic         3434 non-null   float64
 15  cfx_mic         3401 non-null   float64
 16  tet_mic         1472 non-null   float64
 17  pen_mic         1465 no

In [2]:
metadata.head(5)

Unnamed: 0_level_0,Year,Country,Continent,Beta.lactamase,Azithromycin,Ciprofloxacin,Ceftriaxone,Cefixime,Tetracycline,Penicillin,...,log2_cro_mic,log2_cfx_mic,log2_tet_mic,log2_pen_mic,azm_sr,cip_sr,cro_sr,cfx_sr,tet_sr,pen_sr
Sample_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ERR1549286,2015.0,UK,Europe,,>256,,0.016,,,,...,-5.965784,,,,1.0,,0.0,,,
ERR1549290,2015.0,UK,Europe,,>256,,0.004,,,,...,-7.965784,,,,1.0,,0.0,,,
ERR1549291,2015.0,UK,Europe,,>256,,0.006,,,,...,-7.380822,,,,1.0,,0.0,,,
ERR1549287,2015.0,UK,Europe,,>256,,0.006,,,,...,-7.380822,,,,1.0,,0.0,,,
ERR1549288,2015.0,UK,Europe,,>256,,0.008,,,,...,-6.965784,,,,1.0,,0.0,,,


    # Null/NA Cleaning in Labels

1. Remove rows with NaN in the labels we are trying to predict: 'azm_sr','cfx_sr', 'cip_sr'

---

Note to Jacob:
If I remove all rows with nulls, then we drop down to ~1k entries. Not ideal. I will remove only the nulls in the target labels. By doing that I was able to preserve ~2800 entries.

Additionally, I can technically replace the nulls in the feature set with averages whether they are continuous or discrete, but since we are going to be building some kind of predictive model later, it would be bad practice to run column averages in df.fillna() before we split our dataset into training and test sets.

Therefore, I am going to split the dataset into a training and test set first.

What do you think?

-Jacob
We could evaluate how to fill the NaN values based on how the data is skewed for each feature we want to predict. 
if skew > 0 -> fill NaN with mean .. there are more 0s than 1s, mean will reflect the distribution of 0s and 1s
else if skew < 0 -> fill NaN with median .. there are more 1s than 0s, using the median ensures our negatively skewed data wont be affected by an uneven distribution of 0s and 1s when we fill it


In [21]:

def impute_cols_by_skew(df, columns):
    for column_name in columns:
        sr_skew = df[column_name].skew()
        impute_value = None
        if sr_skew > 0:
            impute_value = df[column_name].mean()
        elif sr_skew < 0:
            impute_value = df[column_name].median()
        df[column_name].fillna(impute_value)

#metadata.dropna(axis=0, how='any', inplace=True, subset=['azm_sr', 'cfx_sr', 'cip_sr'])

It looks like they're all positively skewed so we can use mean values, lets try continuous for now and see what happens. It was worth trying this though

In [4]:
metadata.head(5)

Unnamed: 0_level_0,Year,Country,Continent,Beta.lactamase,Azithromycin,Ciprofloxacin,Ceftriaxone,Cefixime,Tetracycline,Penicillin,...,log2_cro_mic,log2_cfx_mic,log2_tet_mic,log2_pen_mic,azm_sr,cip_sr,cro_sr,cfx_sr,tet_sr,pen_sr
Sample_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ERR1549286,2015.0,UK,Europe,,>256,,0.016,,,,...,-5.965784,,,,1.0,,0.0,,,
ERR1549290,2015.0,UK,Europe,,>256,,0.004,,,,...,-7.965784,,,,1.0,,0.0,,,
ERR1549291,2015.0,UK,Europe,,>256,,0.006,,,,...,-7.380822,,,,1.0,,0.0,,,
ERR1549287,2015.0,UK,Europe,,>256,,0.006,,,,...,-7.380822,,,,1.0,,0.0,,,
ERR1549288,2015.0,UK,Europe,,>256,,0.008,,,,...,-6.965784,,,,1.0,,0.0,,,


# Removing Un-needed Columns
2. Removing 'Year', and unimportant labels 'cro_sr', 'tet_sr', 'pen_sr'

In [5]:
useless_columns = ['Year', 'cro_sr', 'tet_sr', 'pen_sr']

metadata.drop(labels=useless_columns, axis=1, inplace=True)

In [58]:
metadata.columns

Index(['Country', 'Continent', 'Beta.lactamase', 'Azithromycin',
       'Ciprofloxacin', 'Ceftriaxone', 'Cefixime', 'Tetracycline',
       'Penicillin', 'NG_MAST', 'Group', 'azm_mic', 'cip_mic', 'cro_mic',
       'cfx_mic', 'tet_mic', 'pen_mic', 'log2_azm_mic', 'log2_cip_mic',
       'log2_cro_mic', 'log2_cfx_mic', 'log2_tet_mic', 'log2_pen_mic',
       'azm_sr', 'cip_sr', 'cfx_sr'],
      dtype='object')

# Cleaning non-numeric entries in numeric fields to NaN.

3. Turn Non Numeric Entries in Numeric Columns into NaN
4. Cast all numeric rows into float32
---


Notes: Turning them into NaN for now. Will engineer values for all NaNs after train and test splits are made

In [6]:
numeric_columns = [
'Azithromycin',
'Ciprofloxacin',
'Ceftriaxone',
'Cefixime',
'Tetracycline',
'Penicillin',
'NG_MAST',
'Group',
'azm_mic',
'cip_mic',
'cro_mic',
'cfx_mic',
'tet_mic',
'pen_mic',
'log2_azm_mic',
'log2_cip_mic',
'log2_cro_mic',
'log2_cfx_mic',
'log2_tet_mic',
'log2_pen_mic',
'azm_sr',
'cip_sr',
'cfx_sr',
]

for column in numeric_columns:
    metadata[column] = pd.to_numeric(metadata[column], errors='coerce', downcast="float") #https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.to_numeric.html

In [7]:
metadata.info()

<class 'pandas.core.frame.DataFrame'>
Index: 3786 entries, ERR1549286 to ERR2172354
Data columns (total 26 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   Country         3785 non-null   object 
 1   Continent       3785 non-null   object 
 2   Beta.lactamase  1927 non-null   object 
 3   Azithromycin    3373 non-null   float32
 4   Ciprofloxacin   3069 non-null   float32
 5   Ceftriaxone     3432 non-null   float32
 6   Cefixime        3366 non-null   float32
 7   Tetracycline    1470 non-null   float32
 8   Penicillin      1461 non-null   float32
 9   NG_MAST         3238 non-null   float32
 10  Group           3786 non-null   float32
 11  azm_mic         3478 non-null   float32
 12  cip_mic         3088 non-null   float32
 13  cro_mic         3434 non-null   float32
 14  cfx_mic         3401 non-null   float32
 15  tet_mic         1472 non-null   float32
 16  pen_mic         1465 non-null   float32
 17  log2_azm_mic    3478 no

In [8]:
metadata.head(10)

Unnamed: 0_level_0,Country,Continent,Beta.lactamase,Azithromycin,Ciprofloxacin,Ceftriaxone,Cefixime,Tetracycline,Penicillin,NG_MAST,...,pen_mic,log2_azm_mic,log2_cip_mic,log2_cro_mic,log2_cfx_mic,log2_tet_mic,log2_pen_mic,azm_sr,cip_sr,cfx_sr
Sample_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ERR1549286,UK,Europe,,,,0.016,,,,9768.0,...,,9.0,,-5.965784,,,,1.0,,
ERR1549290,UK,Europe,,,,0.004,,,,9768.0,...,,9.0,,-7.965784,,,,1.0,,
ERR1549291,UK,Europe,,,,0.006,,,,9768.0,...,,9.0,,-7.380822,,,,1.0,,
ERR1549287,UK,Europe,,,,0.006,,,,9768.0,...,,9.0,,-7.380822,,,,1.0,,
ERR1549288,UK,Europe,,,,0.008,,,,9768.0,...,,9.0,,-6.965784,,,,1.0,,
ERR1549299,UK,Europe,,,,0.012,,,,,...,,9.0,,-6.380822,,,,1.0,,
ERR1549292,UK,Europe,,,,0.023,,,,9768.0,...,,9.0,,-5.442222,,,,1.0,,
ERR1549298,UK,Europe,,0.5,,0.094,,,,,...,,-1.0,,-3.411196,,,,0.0,,
ERR1549296,UK,Europe,,0.5,,0.094,,,,,...,,-1.0,,-3.411196,,,,0.0,,
ERR1549300,UK,Europe,,,,0.008,,,,,...,,9.0,,-6.965784,,,,1.0,,


# One Hot Encode Categorical Columns

5. Turn Categorical Location entries into numerical representation https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.get_dummies.html
6. Handle 'Beta.lactamase' special case, as it is discrete with many NaNs. Cannot be engineered in the same way as continuous features. One hot encode like the other discrete features but set to all 0 if NaN

In [86]:
geographic_columns=['Country', 'Continent']
metadata = pd.get_dummies(data=metadata, prefix="Encoded", columns=geographic_columns, dtype=float)
metadata = pd.get_dummies(data=metadata, prefix="Encoded_Beta.lactamase", columns=['Beta.lactamase'], dtype=float)


In [9]:
metadata.head(5)

Unnamed: 0_level_0,Country,Continent,Beta.lactamase,Azithromycin,Ciprofloxacin,Ceftriaxone,Cefixime,Tetracycline,Penicillin,NG_MAST,...,pen_mic,log2_azm_mic,log2_cip_mic,log2_cro_mic,log2_cfx_mic,log2_tet_mic,log2_pen_mic,azm_sr,cip_sr,cfx_sr
Sample_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ERR1549286,UK,Europe,,,,0.016,,,,9768.0,...,,9.0,,-5.965784,,,,1.0,,
ERR1549290,UK,Europe,,,,0.004,,,,9768.0,...,,9.0,,-7.965784,,,,1.0,,
ERR1549291,UK,Europe,,,,0.006,,,,9768.0,...,,9.0,,-7.380822,,,,1.0,,
ERR1549287,UK,Europe,,,,0.006,,,,9768.0,...,,9.0,,-7.380822,,,,1.0,,
ERR1549288,UK,Europe,,,,0.008,,,,9768.0,...,,9.0,,-6.965784,,,,1.0,,


In [10]:
metadata.info()

<class 'pandas.core.frame.DataFrame'>
Index: 3786 entries, ERR1549286 to ERR2172354
Data columns (total 26 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   Country         3785 non-null   object 
 1   Continent       3785 non-null   object 
 2   Beta.lactamase  1927 non-null   object 
 3   Azithromycin    3373 non-null   float32
 4   Ciprofloxacin   3069 non-null   float32
 5   Ceftriaxone     3432 non-null   float32
 6   Cefixime        3366 non-null   float32
 7   Tetracycline    1470 non-null   float32
 8   Penicillin      1461 non-null   float32
 9   NG_MAST         3238 non-null   float32
 10  Group           3786 non-null   float32
 11  azm_mic         3478 non-null   float32
 12  cip_mic         3088 non-null   float32
 13  cro_mic         3434 non-null   float32
 14  cfx_mic         3401 non-null   float32
 15  tet_mic         1472 non-null   float32
 16  pen_mic         1465 non-null   float32
 17  log2_azm_mic    3478 no

# Split Dataframe into Train and Test

In [13]:
from sklearn.model_selection import train_test_split
train_inputs,test_inputs = train_test_split(metadata, test_size=0.4, random_state=42)   # partition and mix entries
print(train_inputs.size,":", test_inputs.size)

59046 : 39390


In [23]:
# apply skew based imputation

targets = ['azm_sr','cfx_sr','cip_sr']
impute_cols_by_skew(train_inputs, targets)
impute_cols_by_skew(test_inputs, targets)

train_inputs.head(5)


Unnamed: 0_level_0,Country,Continent,Beta.lactamase,Azithromycin,Ciprofloxacin,Ceftriaxone,Cefixime,Tetracycline,Penicillin,NG_MAST,...,pen_mic,log2_azm_mic,log2_cip_mic,log2_cro_mic,log2_cfx_mic,log2_tet_mic,log2_pen_mic,azm_sr,cip_sr,cfx_sr
Sample_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
15335_5#74,USA,America,S,1.0,8.0,0.015,,,,1407.0,...,,0.0,3.0,-6.058894,,,,0.0,1.0,0.001471
16853_3#19,Italy,Europe,,0.094,0.004,0.002,0.016,,,10676.0,...,,-3.411196,-7.965784,-8.965784,-5.965784,,,0.0,0.0,0.0
10625_6#18,Japan,Asia,S,0.75,32.0,0.094,0.19,2.0,2.0,7328.0,...,2.0,-0.415038,5.0,-3.411196,-2.395929,1.0,1.0,0.0,1.0,0.0
10_062,Greece,Europe,,16.0,,0.047,0.125,,,3806.0,...,,4.0,,-4.411195,-3.0,,,1.0,0.457735,0.0
8727_5#86,USA,America,S,0.5,32.0,0.03,0.06,2.0,2.0,1978.0,...,2.0,-1.0,5.0,-5.058894,-4.058894,1.0,1.0,0.0,1.0,0.0


# Normalize Numerical Features 

In [25]:
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler

normalizer = MinMaxScaler()

train_inputs[numeric_columns] = normalizer.fit_transform(train_inputs[numeric_columns])
test_inputs[numeric_columns] = normalizer.fit_transform(test_inputs[numeric_columns])

train_inputs.head(5)

Unnamed: 0_level_0,Country,Continent,Beta.lactamase,Azithromycin,Ciprofloxacin,Ceftriaxone,Cefixime,Tetracycline,Penicillin,NG_MAST,...,pen_mic,log2_azm_mic,log2_cip_mic,log2_cro_mic,log2_cfx_mic,log2_tet_mic,log2_pen_mic,azm_sr,cip_sr,cfx_sr
Sample_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SRR3360667,UK,Europe,,0.01464,0.499984,,,0.03122,0.062471,0.009818,...,0.062471,0.248393,0.937366,,,0.665905,0.732724,0.0,1.0,0.001468
8727_8#71,USA,America,S,0.03028,0.000203,0.014754,0.062266,0.003875,0.007781,,...,0.007781,0.311027,0.244704,0.493648,0.665714,0.465447,0.532266,0.0,0.0,0.0
SRR2736190,Canada,America,,0.499499,,0.00175,0.00075,,,0.202451,...,,0.561562,,0.250715,0.167143,,,1.0,0.46976,0.0
17176_1#77,USA,America,S,0.01464,9.4e-05,0.014754,0.031008,,,0.247275,...,,0.248393,0.187902,0.493648,0.582142,,,0.0,0.0,0.0
10356_1#2,Caribbean,America,R,0.000438,0.0,0.00075,0.003751,0.001453,0.093722,,...,0.093722,0.032793,0.062634,0.167143,0.334287,0.372173,0.77181,0.0,0.0,0.0


# Re-engineer Numerical NaN features if possible

### Data Stats

In [88]:
j=[0,0,0]
samples = metadata.index
for sample in samples:
    if metadata['azm_sr'][sample]:
        j[0] +=1
    if metadata['cfx_sr'][sample]:
        j[1] +=1
    if metadata['cip_sr'][sample]:
        j[2] +=1

print(j[0]/len(samples), "% of samples have resistance to azm")
print(j[1]/len(samples), "% of samples have resistance to cfx")
print(j[2]/len(samples), "% of samples have resistance to cip")


0.05650319829424307 % of samples have resistance to azm
0.0017768301350390902 % of samples have resistance to cfx
0.4541577825159915 % of samples have resistance to cip


In [89]:
samples = azm_sr.index

# some random unitig from azm_sr

from random import randint
randomUnitig = azm_sr.columns[randint(0,azm_sr.shape[1])]
j = 0
print(azm_sr.shape)
for i in range(azm_sr.shape[0]):
    if azm_sr[randomUnitig][samples[i]]:
        j += 1

print(randomUnitig,"\npresent in", j/azm_sr.shape[0],"% of azm_sr samples (",j,"/",azm_sr.shape[0],')' )

(3970, 515)
ACCCTGGACGCCGGCTACCGCTACCACAACT 
present in 0.9954659949622167 % of azm_sr samples ( 3952 / 3970 )
