In [1]:
import pandas as pd
import numpy as np
from itertools import product, combinations
import math
import pickle
from tqdm import tqdm

In [2]:
def loadObj(f):
    """Read the file at path ``f`` in binary mode and return the unpickled object.

    NOTE(review): unpickling can execute arbitrary code, so only use this on
    pickle files produced by this project.
    """
    with open(f, 'rb') as handle:
        restored = pickle.load(handle)
    return restored

In [3]:
def saveObj(obj, f):
    """Pickle ``obj`` into the file at path ``f`` (opened with mode ``'wb'``, so an existing file is overwritten)."""
    with open(f, 'wb') as handle:
        pickle.dump(obj, handle)

In [4]:
def extractBaseGroups(S, P, d):
    """Pick the elements of ``S`` found at the positions listed in ``P``.

    Parameters
    ----------
    S : indexable
        Sequence to read from (indexed as ``S[i]``).
    P : iterable
        When ``d == 2``: an iterable of position groups, each an iterable of
        indices. Otherwise: a flat iterable of indices.
    d : int
        Nesting depth of ``P``; only the value ``2`` selects the nested form.

    Returns
    -------
    tuple
        A tuple of tuples (one inner tuple per group) when ``d == 2``,
        otherwise a flat tuple of the selected elements.
    """
    if d != 2:
        return tuple(S[idx] for idx in P)
    return tuple(tuple(S[idx] for idx in group) for group in P)

In [5]:
def getCounts(count_col, seq_col, position_groups):
    """Build a weighted frequency dictionary for every position group.

    For each entry of ``position_groups`` the bases found at those positions
    are extracted from every sequence (via ``extractBaseGroups``) and the
    matching values of ``count_col`` are summed per distinct base group.

    Parameters
    ----------
    count_col : iterable
        One weight per sequence, parallel to ``seq_col``.
    seq_col : iterable
        Indexable sequences, parallel to ``count_col``.
    position_groups : sequence
        Either tuples of integer positions, e.g. ``[(0, 21), (1,)]``, or
        tuples of such tuples, e.g. ``[((0, 21), (1,)), ...]``. The form is
        detected from the first element only.

    Returns
    -------
    dict
        ``{position_group: {base_group: summed_count}}``. Empty when
        ``position_groups`` is empty.
    """
    # Nothing to count; indexing position_groups[0][0] below would fail.
    if len(position_groups) == 0:
        return {}

    # Integer first element -> flat groups (d=1); otherwise nested groups (d=2).
    # NumPy integers are accepted as well so index arrays can be used.
    first_position = position_groups[0][0]
    d = 1 if isinstance(first_position, (int, np.integer)) else 2

    # Pair the two columns once: avoids re-iterating them for every group and
    # keeps one-shot iterators from being exhausted after the first group.
    rows = list(zip(count_col, seq_col))

    dct_counts = {}

    for pg in tqdm(position_groups):

        pg_counts = {}

        for c, s in rows:
            bg = extractBaseGroups(s, pg, d)
            if bg in pg_counts:
                pg_counts[bg] += c
            else:
                pg_counts[bg] = c

        dct_counts[pg] = pg_counts

    return dct_counts

In [6]:
def getMI_groups(dct_single, dct_paired):
    """Compute the mutual information (natural log) for every pair of position groups.

    Parameters
    ----------
    dct_single : dict
        ``{position_group: {base_group: count}}`` for single position groups.
    dct_paired : dict
        ``{(pg1, pg2): {(bg1, bg2): count}}`` for pairs of position groups;
        ``pg1``/``pg2`` and ``bg1``/``bg2`` must be keys of ``dct_single``
        and of its inner dictionaries respectively.

    Returns
    -------
    pandas.DataFrame
        One row per pair with columns ``mi``, ``pg1`` and ``pg2``; the group
        labels are the positions joined with ``'-'`` (e.g. ``'0-21'``).

    Notes
    -----
    Base-group pairs whose count is zero contribute nothing to the sum
    (the usual ``0 * log 0 = 0`` convention) instead of raising a
    ``ValueError`` from ``math.log(0)``.
    """
    # Total count, taken from the first single-group table. Tables produced by
    # getCounts are all accumulated over the same rows, so any of them would do.
    first_counts = next(iter(dct_single.values()), {})
    count_sum = sum(first_counts.values())

    res = []
    for pgc, paired_counts in dct_paired.items():
        pg1, pg2 = pgc
        mi_terms = []
        for bgc, paired_count in paired_counts.items():
            # A zero joint frequency has a zero contribution; log(0) is undefined.
            if paired_count == 0:
                continue

            bg1, bg2 = bgc

            paired_frequency = paired_count / count_sum
            partial_frequency_1 = dct_single[pg1][bg1] / count_sum
            partial_frequency_2 = dct_single[pg2][bg2] / count_sum

            # p(x, y) * log( p(x, y) / (p(x) * p(y)) )
            mi_terms.append(
                paired_frequency
                * math.log(paired_frequency / (partial_frequency_1 * partial_frequency_2))
            )

        res.append([pgc, sum(mi_terms)])

    df_res = pd.DataFrame(res, columns=['pgc', 'mi'])
    df_res['pg1'] = ['-'.join(str(x) for x in pgc[0]) for pgc in df_res['pgc']]
    df_res['pg2'] = ['-'.join(str(x) for x in pgc[1]) for pgc in df_res['pgc']]

    # The tuple key is only needed to derive the two label columns.
    return df_res.drop(columns='pgc')

In [13]:
# Input CSV; the 'varseq' and 'count' columns are read from it further down.
dataset = '../../datasets/datasets_prepped/strc_km.csv'

# Groups of sequence positions that are analysed as one unit: either a pair of
# positions or a single position.
position_groups = [
    (0, 21), (1,), (2, 23), (3, 22), (4, 20), (5, 19), (6, 18), (7, 17),
    (8,), (9,), (10, 14), (11,), (12,), (13,), (15,), (16,),
]

# Every unordered pair of position groups, kept nested as (group_a, group_b) ...
position_groups_combined = list(combinations(position_groups, 2))
# ... and the same pairs with both groups concatenated into one flat tuple.
position_groups_combined_flattened = [left + right for left, right in position_groups_combined]

In [14]:
# Load only the two columns the analysis uses.
df = pd.read_csv(
    dataset,
    usecols=['varseq', 'count'],
)


In [23]:
# Randomly permute the counts relative to the sequences (the outputs written
# below carry the "_shuf" suffix). A fixed seed makes the permutation, and
# everything derived from it, reproducible across kernel restarts.
# NOTE: this overwrites df['count'] in place, so every later cell that reads
# `df` sees the shuffled counts.
SHUFFLE_SEED = 0
df['count'] = df['count'].sample(frac=1, random_state=SHUFFLE_SEED).values

---
### All together

In [25]:
# Frequency tables over the whole dataset: one per position group ...
dct_counts_single = getCounts(
    count_col=df['count'],
    seq_col=df['varseq'],
    position_groups=position_groups,
)
# ... and one per pair of position groups.
dct_counts_paired = getCounts(
    count_col=df['count'],
    seq_col=df['varseq'],
    position_groups=position_groups_combined,
)

100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [01:24<00:00,  5.25s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 120/120 [21:32<00:00, 10.77s/it]


In [26]:
# Persist both frequency tables so the slow counting step need not be repeated.
saveObj(obj=dct_counts_single, f='dct_counts_single_all_km_shuf.pkl')
saveObj(obj=dct_counts_paired, f='dct_counts_paired_all_km_shuf.pkl')

In [27]:
# Mutual information for every pair of position groups.
df_res = getMI_groups(
    dct_single=dct_counts_single,
    dct_paired=dct_counts_paired,
)

In [28]:
# Write the MI table (index included) to the working directory.
df_res.to_csv(path_or_buf='mi_groups_all_km_shuf.csv')

---
### Per local king

In [None]:
# Load neighborhood assignment
# Neighborhood assignment table; the first CSV column is used as the index.
df_nbh_assignments = pd.read_csv(
    '../fitness_lanscape_eda/neighborhood_assignment/lk_dists_and_nbh_assignments.csv',
    index_col=0,
)

# Identifiers (names) of the local kings.
lks = [0, 11, 113, 550, 673]

# Attach the assignment columns to the dataset column-wise (rows are matched on index).
# NOTE(review): if the shuffle cell above has been run in this kernel, df['count']
# holds permuted counts at this point - confirm that is intended for this section.
df_merged = pd.concat([df, df_nbh_assignments], axis=1)

# Keep only "determinable" sequences (type 'd'), i.e. those clearly closest to
# one particular local king; per the original note this is about 70% of the dataset.
df_merged = df_merged.loc[df_merged['type'] == 'd']

In [None]:
for lk in lks:
    # Rows belonging to the neighborhood of this local king
    df_sub = df_merged.loc[df_merged['lk'] == lk]

    # Frequency tables for the single groups and for the pairs of groups
    dct_counts_single = getCounts(df_sub['count'], df_sub['varseq'], position_groups)
    dct_counts_paired = getCounts(df_sub['count'], df_sub['varseq'], position_groups_combined)

    # Persist the tables, one pair of pickle files per local king
    saveObj(dct_counts_single, f'dct_counts_single_LK{lk}.pkl')
    saveObj(dct_counts_paired, f'dct_counts_paired_LK{lk}.pkl')

    # Mutual information between the groups within this neighborhood
    df_res = getMI_groups(dct_counts_single, dct_counts_paired)

    # Write the MI table for this local king
    df_res.to_csv(f'mi_groups_LK{lk}.csv')