In [488]:
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

In [489]:
from typing import TypedDict

# Weight of each answer type ('A'..'D') for a single question; after
# normalize_likelihood the four values sum to 1.
Likelihood = TypedDict('Likelihood', {'A': float, 'B': float, 'C': float, 'D': float})
    
def create_likelihood(a: float, b: float, c: float, d: float) -> Likelihood:
    """Pack the four per-type weights into a Likelihood dict (keys in A, B, C, D order)."""
    return dict(zip('ABCD', (a, b, c, d)))

def random_likelihood(rng: "np.random.Generator | None" = None) -> Likelihood:
    """Draw a random Likelihood whose four weights sum to 1.

    Args:
        rng: Optional NumPy ``Generator`` for reproducible draws. When omitted,
            the legacy global ``np.random`` state is used, so existing callers
            behave as before (seed it with ``np.random.seed`` for repeatability).

    Returns:
        A Likelihood with uniform-random weights rescaled to sum to 1.
    """
    draws = np.random.random(4) if rng is None else rng.random(4)
    # Rescale the four uniform draws in [0, 1) so they form a distribution.
    a, b, c, d = draws / draws.sum()
    return create_likelihood(a, b, c, d)

def normalize_likelihood(likelihood: Likelihood) -> Likelihood:
    """Return a new dict with every weight divided by the total, so they sum to 1.

    Key order is preserved. A zero total is not handled (division by zero).
    """
    scale = sum(likelihood.values())
    normalized = {}
    for key, weight in likelihood.items():
        normalized[key] = weight / scale
    return normalized


def sum_to_one(likelihood: Likelihood, tol: float = 1e-9) -> Likelihood:
    """Spread any missing probability mass evenly over the zero-valued entries.

    Zero entries are treated as "unspecified". If the weights sum to less than
    1, the remainder ``1 - total`` is split equally among the zero entries so
    the result sums to 1.

    Args:
        likelihood: Weight per answer type.
        tol: Absolute tolerance on the total. Floating-point round-off
            (e.g. a total of ``1 + 1e-12``) neither raises nor triggers a fill.

    Returns:
        A new dict with the zeros filled when there is mass to distribute.
        Otherwise ``likelihood`` itself, unchanged. A total below 1 with no
        zero entries is left as-is for ``normalize_likelihood`` to rescale.

    Raises:
        ValueError: if the weights sum to more than ``1 + tol``.
    """
    total = sum(likelihood.values())
    complement = 1 - total
    if complement < -tol:
        raise ValueError(f"Total is greater than 1: {total}")
    zero_keys = [k for k, v in likelihood.items() if v == 0]
    # Nothing to distribute, or no zero entry to receive it (avoids dividing by zero).
    if complement <= tol or not zero_keys:
        return likelihood
    fill_value = complement / len(zero_keys)
    return {k: fill_value if v == 0 else v for k, v in likelihood.items()}

In [490]:
# How many answer slots of each type exist across the quiz; the slot totals
# must equal the number of questions for a perfect assignment.
node_count_dict = {'A': 8, 'B': 4, 'C': 4, 'D': 3}
node_types = list(node_count_dict)
node_count = list(node_count_dict.values())
n_questions = sum(node_count_dict.values())

In [491]:
# Prior likelihood of each answer type, one row per question (rows beyond
# n_questions are ignored).
init_answers_df = pd.read_csv('init_answers.csv', index_col=0)
assert len(init_answers_df) >= n_questions, (
    f"init_answers.csv has {len(init_answers_df)} rows, expected at least {n_questions}"
)
# Keep only the answer-type columns so stray CSV columns cannot leak into the sums.
init_answers_df = init_answers_df.iloc[:n_questions][node_types]
# Fill unspecified (zero) entries with the leftover mass, then rescale to sum to 1.
init_answers = [
    normalize_likelihood(sum_to_one(record))
    for record in init_answers_df.to_dict('records')
]

In [492]:
# One node per available answer slot, named "<type><1-based slot>" (A1, A2, ..., B1, ...).
node_answers_type = [
    f"{answer_type}{slot}"
    for answer_type in node_types
    for slot in range(1, node_count_dict[answer_type] + 1)
]
assert len(node_answers_type) == n_questions, f"Number of answers: {len(node_answers_type)}, expected: {n_questions}"
# One node per question: Q1, Q2, ..., Q<n_questions>.
node_questions = [f"Q{number}" for number in range(1, n_questions + 1)]

In [493]:
# Bipartite graph: answer slots on side 0, questions on side 1.
G = nx.Graph()
G.add_nodes_from(node_answers_type, bipartite=0)
G.add_nodes_from(node_questions, bipartite=1)

assert len(init_answers) == len(node_questions), (
    f"Got {len(init_answers)} likelihood rows for {len(node_questions)} questions"
)

# Connect every question to every answer slot. The edge weight is the prior
# likelihood that the question's answer has that slot's type. The loop names
# deliberately avoid `node_count`, which holds the per-type count list from the
# config cell.
for question_node, init_answer in zip(node_questions, init_answers):
    for answer_type, type_count in node_count_dict.items():
        for slot in range(1, type_count + 1):
            G.add_edge(question_node, f"{answer_type}{slot}", weight=init_answer[answer_type])

print(G)
print(f"Number of answers: {len(node_answers_type)}")

# Set to True to draw the bipartite graph (answers red, questions blue).
DRAW_GRAPH = False
if DRAW_GRAPH:
    pos = nx.bipartite_layout(G, node_answers_type)
    fig, ax = plt.subplots()
    nx.draw_networkx_nodes(G, pos, nodelist=node_answers_type, node_color='r', ax=ax)
    nx.draw_networkx_nodes(G, pos, nodelist=node_questions, node_color='b', ax=ax)
    nx.draw_networkx_edges(G, pos, width=1.0, alpha=0.5, ax=ax)
    plt.show()

Graph with 38 nodes and 361 edges
Number of answers: 19


In [494]:
# Pick one answer slot per question, maximizing the total edge weight (prior
# likelihood). maxcardinality=True restricts the search to maximum-cardinality
# matchings, so as many questions as possible receive a slot.
matched_edges = nx.max_weight_matching(G, maxcardinality=True)
# Each matched edge is an unordered pair; orient it as question -> answer slot.
matching = dict(
    (u, v) if u in node_questions else (v, u)
    for u, v in matched_edges
)
matching_df = pd.Series(matching)

assert len(matching_df) == n_questions, f"Number of questions: {len(matching_df)}, expected: {n_questions}"

In [495]:
# Process matching for nice display. Rebuilt from `matching` (not from the
# previous matching_df) so re-running this cell gives the same result.
matching_df = pd.Series(matching)
# "Q12" -> 12, so questions sort numerically instead of lexically.
matching_df.index = matching_df.index.str.replace('Q', '', regex=False).astype(int)
# Slot names are "<type><slot number>"; strip the trailing digits to keep the type.
# Unlike taking the first character, this also works for multi-character types.
matching_df = matching_df.sort_index().str.rstrip('0123456789')

matching_df.index.name = 'Question'
matching_df.name = 'Answer'

In [496]:
# Numeric code of each matched answer type: its position in node_types.
matching_df_int = matching_df.map(node_types.index).rename('Answer_int')
df = pd.concat([matching_df, matching_df_int], axis=1)

In [497]:
# Prior likelihood of the chosen answer type for each question
# (question numbers are 1-based, init_answers is 0-based).
df['Likeliness'] = [
    init_answers[question - 1][answer]
    for question, answer in zip(df.index, df['Answer'])
]

In [498]:
# Cleaned likelihoods as a frame indexed by 1-based question number, aligned with df.
init_answers_df = pd.DataFrame(init_answers, index=pd.RangeIndex(1, len(init_answers) + 1))

In [499]:
# Attach the per-type likelihoods, drop the helper int column, round for display/export.
df = (
    pd.concat([df, init_answers_df], axis=1)
    .drop(columns='Answer_int')
    .round(2)
)
df.to_csv('quizz_answers.csv')
print(df['Likeliness'].describe())
df

count    19.00
mean      0.25
std       0.00
min       0.25
25%       0.25
50%       0.25
75%       0.25
max       0.25
Name: Likeliness, dtype: float64


Unnamed: 0,Answer,Likeliness,A,B,C,D
1,D,0.25,0.25,0.25,0.25,0.25
2,D,0.25,0.25,0.25,0.25,0.25
3,D,0.25,0.25,0.25,0.25,0.25
4,C,0.25,0.25,0.25,0.25,0.25
5,C,0.25,0.25,0.25,0.25,0.25
6,C,0.25,0.25,0.25,0.25,0.25
7,C,0.25,0.25,0.25,0.25,0.25
8,B,0.25,0.25,0.25,0.25,0.25
9,B,0.25,0.25,0.25,0.25,0.25
10,B,0.25,0.25,0.25,0.25,0.25
