In [None]:
import networkx as nx
import pandas as pd
import numpy as np
from typing import Type
import random
from collections import defaultdict

In [None]:
# Imports the specific files used in this notebook
def importFiles(
    physical_path: str = r"C:\Users\jd_fr\Desktop\Summer Research 2020\physical.csv",
    first_level_path: str = r"C:\Users\jd_fr\Desktop\Summer Research 2020\\MIPSFirstLevel.csv",
) -> tuple:
    """Read the edge list and the gene/colour table from disk.

    Parameters
    ----------
    physical_path : str
        Space-separated file with two columns per line; loaded with the
        column names "Gene1" and "Gene2". Defaults to the path originally
        hard-coded in this notebook.
    first_level_path : str
        File whose two columns are separated by a run of 2 to 6 whitespace
        characters; loaded with the column names "Gene" and "Colors".
        Defaults to the path originally hard-coded in this notebook.

    Returns
    -------
    tuple
        ``(first, physical)`` - two pandas DataFrames, the gene/colour table
        first and the edge list second.
    """
    # Import the path file
    physical = pd.read_csv(physical_path, names=["Gene1", "Gene2"], sep=" ")

    # Import the data file.
    # The separator is a regular expression, so it is written as a raw string
    # (a plain "\s" is an invalid escape sequence) and the python parser engine
    # is requested explicitly, since the C engine does not support regex
    # separators.
    col_names = ["Gene", "Colors"]
    first = pd.read_csv(first_level_path, names=col_names, sep=r"\s{2,6}", engine="python")

    return first, physical

In [None]:
# Builds an undirected graph from the two imported tables
def populateGraph(first, physical) -> Type[nx.classes.graph.Graph]:
    """Create a networkx Graph from a node table and an edge table.

    Each row of ``first`` becomes a node: the first cell is the node key and
    the second cell is stored as the node's ``colors`` attribute. Each row of
    ``physical`` becomes an edge between its first and second cells.
    """
    graph = nx.Graph()

    # Nodes: (key, colors) taken positionally from every row
    for _, node_row in first.iterrows():
        cells = node_row.values
        graph.add_node(cells[0], colors = cells[1])

    # Edges: endpoint pair taken positionally from every row
    graph.add_edges_from(
        (edge_row.values[0], edge_row.values[1])
        for _, edge_row in physical.iterrows()
    )

    return graph

In [None]:
# Returns a dictionary containing frequency information of the overall data set
def getDataMode(path: str = r"C:\Users\jd_fr\Desktop\Summer Research 2020\\MIPSFirstLevel.csv") -> dict:
    """Count how often each colour token occurs in the gene/colour file.

    Parameters
    ----------
    path : str
        File whose two columns ("Gene", "Colors") are separated by a run of
        2 to 6 whitespace characters. Defaults to the path originally
        hard-coded in this notebook.

    Returns
    -------
    dict
        Mapping of colour token -> number of occurrences, as produced by
        ``toFreq``.
    """
    col_names = ["Gene", "Colors"]
    # Raw-string regex (a plain "\s" is an invalid escape sequence) and an
    # explicit python engine, which is the one that handles regex separators.
    first = pd.read_csv(path, names=col_names, sep=r"\s{2,6}", engine="python")
    # One string of colours per gene -> flat list of single tokens -> counts
    colors_per_gene = freqOverall(first)
    all_tokens = convertColors(colors_per_gene)
    return toFreq(all_tokens)

# Returns the first remaining cell of every row once the "Gene" column is dropped
def freqOverall(first: list) -> list:
    """Collect one value per row from the table's first non-"Gene" column.

    Despite the ``list`` annotation, ``first`` must support ``.drop`` and
    ``.iterrows`` (i.e. behave like a pandas DataFrame with a "Gene" column).
    """
    without_gene = first.drop("Gene", axis = 1)
    return [row.values[0] for _, row in without_gene.iterrows()]

# Flattens a list of whitespace-separated colour strings into single tokens
def convertColors(colors: list) -> list:
    """Split every string in ``colors`` on whitespace and concatenate the
    resulting tokens, preserving their order."""
    return [token for entry in colors for token in entry.split()]

# Converts a list to a dict of occurrence counts
def toFreq(lst: list) -> dict:
    """Count occurrences of each element of ``lst``.

    The result is a ``defaultdict(int)``, so looking up an element that never
    occurred yields 0. Keys appear in order of first occurrence.
    """
    tally = {}
    for item in lst:
        tally[item] = tally.get(item, 0) + 1
    return defaultdict(int, tally)

In [None]:
# Finds the colors of a set of neighbors
def findColors(neigh, Graph: "nx.Graph", atrs):
    """Collect the colour tokens of every gene in ``neigh``.

    Parameters
    ----------
    neigh : iterable
        Gene names to look up.
    Graph : nx.Graph
        Not used by this function; kept so existing callers keep working.
        (Annotated as a string because the module is imported as ``nx``;
        the bare name ``networkx`` is not defined in this notebook.)
    atrs : mapping
        Gene name -> whitespace-separated colour string.

    Returns
    -------
    list
        All colour tokens of the listed genes, in order. Genes missing from
        ``atrs`` are skipped, as is any token containing '#'.
    """
    neighs = []
    for gene in neigh:
        try:
            tokens = atrs[gene].split()
        except KeyError:
            # Gene has no entry in the attribute map: contributes nothing
            continue
        # str.split() never yields empty tokens, so only '#' needs filtering
        neighs.extend(token for token in tokens if '#' not in token)
    return neighs

In [None]:
# Returns the mode of a list; in the event of a tie this returns the smallest number.
def findMode(colors: list) -> int:
    """Return the most frequent value of ``colors`` after converting each
    element with ``int``.

    Ties are broken deterministically in favour of the smallest value, rather
    than depending on set iteration order.

    Raises
    ------
    ValueError
        If ``colors`` is empty, or an element cannot be converted to int.
    """
    counts = {}
    for col in colors:
        value = int(col)
        counts[value] = counts.get(value, 0) + 1
    if not counts:
        raise ValueError("findMode() requires at least one color")
    # Sort key (-count, value): highest count wins, smallest value on a tie
    return min(counts, key=lambda value: (-counts[value], value))

In [None]:
# Returns true if the specified guess is in the string-formatted list of attributes.
def isCorrect(guess: int, atribs: str) -> bool:
    """Check whether ``guess`` is one of the integers listed in ``atribs``.

    ``atribs`` is split on whitespace. If any token is exactly '#', the
    answer is False; otherwise every token is converted with ``int`` and
    ``guess`` is tested for membership.
    """
    tokens = atribs.split()
    if '#' not in tokens:
        return guess in tuple(map(int, tokens))
    return False

In [None]:
# Runs one iteration of testing given the starting Graph
def run(G: "nx.Graph", results_path: str = "DSD-src-0.50//RESULTS.txt", num_neighbors: int = 5) -> float:
    """Score nearest-neighbour colour prediction on a random subset of rows.

    The first line of the results file is read as the list of column gene
    names. Every following line is a gene name followed by one value per
    column. Each line is independently kept or skipped by a coin flip
    (``random.randint(0, 1)``). For a kept line whose gene has a ``colors``
    attribute in ``G``, the ``num_neighbors`` columns with the smallest values
    are taken as neighbours, the most common colour among them is the guess,
    and the guess is correct if it appears in the gene's own colours.

    Parameters
    ----------
    G : nx.Graph
        Graph whose nodes carry a ``colors`` attribute. (Annotated as a
        string because the module is imported as ``nx``; the bare name
        ``networkx`` is not defined in this notebook.)
    results_path : str
        Path of the results file. Defaults to the path originally hard-coded.
    num_neighbors : int
        How many smallest-valued columns to use as neighbours. Defaults to 5.

    Returns
    -------
    float
        correct guesses / evaluated genes, or NaN when no gene was evaluated.
        Genes that are not nodes of ``G`` are not counted; genes whose
        neighbours yield no colours are counted as evaluated but not correct.
    """
    atrs = nx.get_node_attributes(G, 'colors')
    num = 0
    denom = 0

    with open(results_path, "r") as a_file:
        # Header line: the gene name belonging to each value column
        first = a_file.readline().strip().split()

        for line in a_file:
            # Randomly partition here: roughly half of the rows are skipped
            if random.randint(0, 1) == 1:
                continue

            lineList = line.strip().split()
            if not lineList:
                # Blank line (e.g. trailing newline at end of file)
                continue

            geneName = lineList[0]
            if geneName not in atrs:
                # The gene does not exist in the node set, so ignore it
                continue
            denom += 1

            # Parse the values as floats so the ordering is numeric; sorting
            # the raw string tokens would order them lexicographically and
            # would mix the gene name into the ranking.
            distances = np.array(lineList[1:], dtype=float)
            nearest = distances.argsort(kind="stable")[:num_neighbors]
            # NOTE(review): if the row's own gene is also a column, it is
            # presumably its own nearest neighbour and votes with its own
            # colours - confirm whether it should be excluded.
            neighs = [first[i] for i in nearest]

            neighColors = findColors(neighs, G, atrs)
            if len(neighColors) == 0:
                # No usable colours: stays in the denominator as a miss
                continue
            guess = findMode(neighColors)
            if isCorrect(guess, atrs[geneName]):
                num += 1

    # Every row may have been skipped or unknown; avoid dividing by zero
    if denom == 0:
        return float("nan")
    return num / denom

In [None]:
# Runs 10 cycles of the DSD evaluation on the imported data and prints summary statistics.
def main():
    """Load the input tables, build the graph, call ``run`` ten times and
    pass the ten results to ``getStats``."""
    first, physical = importFiles()
    G = populateGraph(first, physical)
    getStats([run(G) for _ in range(10)])

In [None]:
# Takes a list of averages and prints out the mean and standard deviation
def getStats(overall: list):
    """Print the arithmetic mean and the population standard deviation
    (variance divided by ``len(overall)``) of ``overall``."""
    count = len(overall)
    mean = sum(overall)/count
    squared_deviations = ((value-mean)**2 for value in overall)
    stdev = (sum(squared_deviations) / count)**.5
    for label, value in (('Mean: {}', mean), ('StDev: {}', stdev)):
        print(label.format(value))

In [None]:
# Entry point: running this cell executes the whole experiment via main().
main()