<a href="https://colab.research.google.com/github/anandraiyer/access_forums_eval/blob/main/testing_framework.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd
import numpy as np
import math
from itertools import combinations
from itertools import chain
from itertools import product
from itertools import starmap
from functools import partial
import networkx as nx

In [2]:
#The ground truth is stored here
data = pd.read_csv('/content/drive/MyDrive/final.csv')

In [3]:
data.columns

Index(['Thread', 'DateTime', 'Author', 'Post', 'ParentPosts', 'PostID',
       'ThreadID', 'AuthorID', 'OriginID', 'DialogAct', 'ParentID_List'],
      dtype='object')

In [4]:
df = data[['Thread','ThreadID','PostID','DateTime','Author','Post','ParentID_List']]

In [5]:
df.head()

Unnamed: 0,Thread,ThreadID,PostID,DateTime,Author,Post,ParentID_List
0,Testing the new Site!,1,0,2015-11-30 11:59:19.544597,Tim Ford,Thank you very much for all those who worked o...,-1
1,Netflix not accessible to blind people using a...,2,1,2015-11-30 12:05:11.019288,Tim Ford,"Hi All, For those out there who want to use N...",-1
2,Testing the new Site!,1,2,2015-11-30 12:06:01.976409,"Walker, Michael E","Hi Tim, the group is working fine. I got your ...",[0]
3,Testing the new Site!,1,3,2015-11-30 12:12:07.800324,ken lawrence,Should the JDH mail be deleted?,[0]
4,Netflix not accessible to blind people using a...,2,4,2015-11-30 12:14:27.873186,Greg Nickel,Will do…,[1]


In [6]:
def get_conversation_dag(tid):
  #Returns the Directed Acyclic Graph representing the Thread. Here nodes are messages and edges capture the 'reply to' relationship.
  #Use this function to calculate metrics using networkx library on graphs downstream.
  temp = df[df['ThreadID']==tid]

  threads = list(temp['Thread'])
  postids = list(temp['PostID'])
  posts = list(temp['Post'])
  datetimes = list(temp['DateTime'])
  authors = list(temp['Author'])
  parentids = list(temp['ParentID_List'])

  G = nx.DiGraph()
  G.add_nodes_from(posts)

  edges = []
  for i in range(len(postids)):
    parent = parentids[i]
    child = postids[i]

    if parent == "-1":
      continue
    else:
      if parent == "[]":
        if i > 1:
          parent = "["+str(postids[i-1])+"]"
        else:
          parent = "["+str(postids[0])+"]"
      parent = int(parent[1:-1])
    edges.append((parent,child))
  G.add_edges_from(edges)
  return G

In [7]:
def get_subthreads(tid):
  #Get a list of lists. Given a thread id get all the subthreads.
  temp = df[df['ThreadID']==tid]

  threads = list(temp['Thread'])
  postids = list(temp['PostID'])
  if len(postids)<2:
    return (postids)
  posts = list(temp['Post'])
  datetimes = list(temp['DateTime'])
  authors = list(temp['Author'])
  parentids = list(temp['ParentID_List'])

  G = nx.DiGraph()
  G.add_nodes_from(posts)

  edges = []
  for i in range(len(postids)):
    parent = parentids[i]
    child = postids[i]

    if parent == "-1":
      continue  
    else:
      if parent == "[]":
        if i > 1:
          parent = "["+str(postids[i-1])+"]"
        else:
          parent = "["+str(postids[0])+"]"
      parent = int(parent[1:-1])
    edges.append((parent,child))
  G.add_edges_from(edges)
  chaini = chain.from_iterable
  roots = (n for n,d in G.in_degree() if d==0)
  leaves = (n for n,d in G.out_degree() if d==0)
  all_paths = partial(nx.all_simple_paths, G)
  ans = list(chaini(starmap(all_paths, product(roots, leaves))))
  return(ans)

In [8]:
def get_list_subthreads(cluster_name,threadID):
  conversations = []
  for i, val in enumerate(get_subthreads(threadID)):
    conversations.append(cluster_name+str(i)+":"+str(val)[1:-1].replace(',',''))
  return conversations

In [9]:
def get_n(s1):
  #size of dataset D. Total number of posts in a thread.
  u = set([])
  for _,v in s1.items():
    u = u.union(set(v))
  return(len(u))

In [10]:
def get_length_clustering(s):
  #Get length of each clustering within a cluster s = {s1,s2,s3}, here we return [len(s1),len(s2)...]
  list_len = []
  for _,v in s.items():
    list_len.append(len(v))
  return(list_len)

In [11]:
def get_points(s):
  #Get all points in each cluster in the clustering s
  u = set([])
  for _,v in s.items():
    u = u.union(set(v))
  return(u)

In [12]:
def create_contingency_table(gold, auto):
  #Refer https://en.wikipedia.org/wiki/Rand_index#The_contingency_table
  #https://people.eng.unimelb.edu.au/baileyj/papers/yanglei.pdf
  #Generalization for Soft Clusters : http://derektanderson.com/pdfs/05482124.pdf
  table = []
  names_gold = []
  names_auto = []
  for i, _ in gold.items():
    names_gold.append(i)
  for i, _ in auto.items():
    names_auto.append(i)
  for i, v1 in gold.items():
    table_row = []
    for j, v2 in auto.items():
      table_row.append(len(v1.intersection(v2)))
    table.append(table_row)
  table = np.array(table)
  sum_rows = np.sum(table, axis = 1)
  sum_cols = np.sum(table, axis = 0)
  n = get_n(gold)
  #Generalization for soft clusterings
  n_max = np.sum(sum_rows)
  phi = n/n_max
  table = np.multiply(phi,table)
  sum_rows = np.sum(table, axis = 1)
  sum_cols = np.sum(table, axis = 0)
  return table, sum_rows, sum_cols, names_gold, names_auto

In [13]:
def get_information_entropy(s1,s2):
  c,c_rows,c_cols,g_name, a_name = create_contingency_table(s1,s2)
  H_u = 0.0
  H_v = 0.0
  total = get_n(s1)
  for i in range(len(c_rows)):
    if c_rows[i] > 0:
      H_u = H_u - ((c_rows[i]/total) * math.log(c_rows[i]/total,2))
  for i in range(len(c_cols)):
    if c_cols[i] > 0:
      H_v = H_v - ((c_cols[i]/total) * math.log(c_cols[i]/total,2))
  return (H_u,H_v)

In [14]:
def get_mi(s1,s2):
  c,c_rows,c_cols,g_name, a_name = create_contingency_table(s1,s2)
  I_uv = 0.0
  total = get_n(s1)
  for i in range(len(c_rows)):
    for j in range(len(c_cols)):
      if c[i][j] > 0:
        I_uv = I_uv + ((c[i][j] / total) * math.log((c[i][j] * total) / (c_rows[i] * c_cols[j]), 2.0))
      else:
        continue
  return(I_uv)

In [15]:
def get_joint_entropy(s1,s2):
  c,c_rows,c_cols,g_name, a_name = create_contingency_table(s1,s2)
  H_uv = 0.0
  total = get_n(s1)
  for i in range(len(c_rows)):
    for j in range(len(c_cols)):
      if c[i][j] > 0:
        H_uv = H_uv - ((c[i][j] / total) * math.log(c[i][j] / total, 2.0))
      else:
        continue
  return(H_uv)

In [16]:
def get_vi(s1,s2):
  #Refer : Original Paper : https://sites.stat.washington.edu/mmp/Papers/compare-colt.pdf
  #Refer : Code from Kummerfeld : https://github.com/jkkummerfeld/irc-disentanglement/blob/master/tools/evaluation/conversation-eval.py
  #Refer : *IMP* Adjustment for Soft Clusterings -> https://people.eng.unimelb.edu.au/baileyj/papers/yanglei.pdf
  H_uv = get_joint_entropy(s1,s2)
  I_uv = get_mi(s1,s2)
  #H_u,H_v = get_information_entropy(s1,s2)
  VI = H_uv - I_uv
  #VI = H_u + H_v - 2*I_uv (Equivalent Formula)
  return(VI)

In [17]:
def get_one_minus_scaled_vi(s1,s2):
  c,c_rows,c_cols,g_name, a_name = create_contingency_table(s1,s2)
  VI = get_vi(s1,s2)
  n = np.sum(c_rows)
  max_score = math.log(n, 2.0)
  scaled_VI = VI / max_score
  
  return round(1 - scaled_VI,3)

In [18]:
def get_normalized_vi(s1,s2):
  H_uv = get_joint_entropy(s1,s2)
  I_uv = get_mi(s1,s2)
  NVI = 1 - (I_uv / H_uv)
  return (round(NVI,3))

In [19]:
def get_nmi_joint(s1,s2):
  H_uv = get_joint_entropy(s1,s2)
  I_uv = get_mi(s1,s2)
  nmi = (I_uv / H_uv)
  return (round(nmi,3))

In [20]:
def get_nmi_max(s1,s2):
  H_u,H_v = get_information_entropy(s1,s2)
  I_uv = get_mi(s1,s2)
  nmi = (I_uv / max(H_u,H_v))
  return (round(nmi,3))

In [21]:
def get_nmi_sum(s1,s2):
  H_u,H_v = get_information_entropy(s1,s2)
  I_uv = get_mi(s1,s2)
  nmi = (I_uv / (H_u + H_v))
  return (round(nmi,3))

In [22]:
def get_nmi_sqrt(s1,s2):
  H_u,H_v = get_information_entropy(s1,s2)
  I_uv = get_mi(s1,s2)
  nmi = (I_uv / math.sqrt(H_u * H_v))
  return (round(nmi,3))

In [23]:
def get_nmi_min(s1,s2):
  H_u,H_v = get_information_entropy(s1,s2)
  I_uv = get_mi(s1,s2)
  nmi = (I_uv / min(H_u,H_v))
  return (round(nmi,3))

In [24]:
def get_one_to_one(s1,s2):
  #https://aclanthology.org/P08-1095.pdf 
  contingency,rows_sums,col_sums,gold_name,auto_name = create_contingency_table(s1,s2)
  X = rows_sums
  Y = col_sums
  N = get_n(s1)
  B = nx.Graph()
  left_nodes = []
  for i in gold_name:
    left_nodes.append('Left_'+str(i))
  right_nodes = []
  for i in auto_name:
    right_nodes.append('Right_'+str(i))
  B.add_nodes_from(left_nodes, bipartite=0)
  B.add_nodes_from(right_nodes, bipartite=1)

  for i in range(len(X)):
    for j in range(len(Y)):
      if contingency[i][j] > 0:
        B.add_edge(left_nodes[i], right_nodes[j], weight = contingency[i][j])
  
  matches = nx.algorithms.matching.max_weight_matching(B)
  one_one_ratio = 0.0
  for u,v in matches:
    one_one_ratio = one_one_ratio + (100*B.get_edge_data(u,v)['weight']/N)
  return(round(one_one_ratio,3))

In [25]:
def get_omega_score(s1,s2):
  #Refer : https://iopscience.iop.org/article/10.1088/1742-5468/2011/02/P02017/pdf (pg 6)
  n = get_n(s1)
  N = n*(n-1.0)/2.0
  all_points = get_points(s1)
  all_pairs = combinations(all_points,2)
  s1_count = {}
  s2_count = {}
  for i,j in all_pairs:
    for m,n in s1.items():
      if i in n and j in n:
        if (i,j) in s1_count:
          s1_count[(i,j)] = s1_count.get((i,j)) + 1
        else:
          s1_count[(i,j)] = 1
      else:
        if (i,j) in s1_count:
          s1_count[(i,j)] = s1_count.get((i,j)) + 0
        else:
          s1_count[(i,j)] = 0
    for m,n in s2.items():
      if i in n and j in n:
        if (i,j) in s2_count:
          s2_count[(i,j)] = s2_count.get((i,j)) + 1
        else:
          s2_count[(i,j)] = 1
      else:
        if (i,j) in s2_count:
          s2_count[(i,j)] = s2_count.get((i,j)) + 0
        else:
          s2_count[(i,j)] = 0
  
  count_s1_rev = {}
  count_s2_rev = {}
  for (i,j),k in s1_count.items():
    if count_s1_rev.get((k),-1)!=-1:
      count_s1_rev[k] = count_s1_rev.get(k).union(set([(i,j)]))
    else:
      count_s1_rev[k] = set([(i,j)])
  for (i,j),k in s2_count.items():
    if count_s2_rev.get((k),-1)!=-1:
      count_s2_rev[k] = count_s2_rev.get(k).union(set([(i,j)]))
    else:
      count_s2_rev[k] = set([(i,j)])
  
  i_max = max(list(count_s1_rev.keys()))
  j_max = max(list(count_s2_rev.keys()))
  min_ij = min(i_max,j_max)

  count_s1_s2_rev = {}
  for i in range(0,min_ij+1):
    inter = count_s1_rev.get(i,set([])).intersection(count_s2_rev.get(i,set([])))
    count_s1_s2_rev[i] = inter

  U_sum = 0.0
  for i in range(0,min_ij+1):
    U_sum = U_sum + len(count_s1_s2_rev.get(i,set([])))
  U_sum = U_sum / N
  E_sum = 0.0
  for i in range(0,min_ij+1):
    E_sum = E_sum + (len(count_s1_rev.get(i,set([])))*len(count_s2_rev.get(i,set([]))))
  E_sum = E_sum / (N * N)
  if(E_sum != 1):
    omega = (U_sum - E_sum)/(1.0- E_sum)
  else:
    omega = 0
  return(round(omega,3))

In [26]:
def create_shen_precision_table(s1, s2):
  #Definition : https://www.microsoft.com/en-us/research/wp-content/uploads/2006/01/p35-shen.pdf (pg 40)
  table = []
  contingency,rows_sums,col_sums,gold_name,auto_name = create_contingency_table(s1,s2)
  X = rows_sums
  Y = col_sums

  a = 0
  b = 0
  for i, v1 in s1.items():
    table_row = []
    n_i = X[a]
    b = 0
    for j, v2 in s2.items():
      n_j = Y[b]
      if(n_j == 0):
        table_row.append(0)
      else:
        table_row.append(len(v1.intersection(v2))/n_j)
      b = b+1
    table.append(table_row)
    a=a+1
  table = np.array(table)
  
  return table

In [27]:
def create_shen_recall_table(s1, s2):
  #Definition : https://www.microsoft.com/en-us/research/wp-content/uploads/2006/01/p35-shen.pdf (pg 40)
  table = []
  contingency,rows_sums,col_sums,gold_name,auto_name = create_contingency_table(s1,s2)
  X = rows_sums
  Y = col_sums

  a = 0
  b = 0
  for i, v1 in s1.items():
    table_row = []
    n_i = X[a]
    b = 0
    for j, v2 in s2.items():
      n_j = Y[b]
      if n_i == 0:
        table_row.append(0)
      else:
        table_row.append(len(v1.intersection(v2))/n_i)
      b = b+1
    table.append(table_row)
    a=a+1
  table = np.array(table)
  
  return table

In [28]:
def create_shen_F_table(s1, s2):
  #Definition : https://www.microsoft.com/en-us/research/wp-content/uploads/2006/01/p35-shen.pdf (pg 40)
  prec_table = create_shen_precision_table(s1,s2)
  recall_table = create_shen_recall_table(s1,s2)
  contingency,rows_sums,col_sums,gold_name,auto_name = create_contingency_table(s1,s2)
  X = rows_sums
  Y = col_sums
  table = []

  a = 0
  b = 0
  max_f = []
  for i, v1 in s1.items():
    table_row = []
    n_i = X[a]
    b = 0
    for j, v2 in s2.items():
      n_j = Y[b]
      if (prec_table[a][b] + recall_table[a][b]) != 0:
        f = (2 * prec_table[a][b] * recall_table[a][b])/(prec_table[a][b] + recall_table[a][b])
      else:
        f = 0
      table_row.append(f)
      b = b+1
    #Select Max_{j} F(i,j) for every row i
    table.append(max(table_row))
    a=a+1

  return table

In [29]:
def get_shen_f1(s1,s2):
  #Refer : https://github.com/jkkummerfeld/irc-disentanglement/blob/master/tools/evaluation/conversation-eval.py
  #Definition : https://www.microsoft.com/en-us/research/wp-content/uploads/2006/01/p35-shen.pdf (pg 40)
  contingency,rows_sums,col_sums,gold_name,auto_name = create_contingency_table(s1,s2)
  X = rows_sums
  Y = col_sums
  f_table = create_shen_F_table(s1,s2)
  sum_f = 0.0
  n = np.sum(Y)

  for i in range(len(f_table)):
      sum_f = sum_f + ((X[i]) * f_table[i])
  sum_f = sum_f / n
  return(round(sum_f,3))

In [30]:
gold = {'C0':{1,2,3,4},'C1':{5,6}}
auto = {'C0':{1,2,3,4},'C1':{5,6}}
VI = get_vi(gold, auto)
one_one = get_one_to_one(gold,auto)
omega = get_omega_score(gold,auto)
shen_f1 = get_shen_f1(gold,auto)

print(VI, one_one, omega,shen_f1)
#Sanity check 

0.0 100.0 1.0 1.0


In [31]:
model1_test = pd.read_csv('/content/drive/MyDrive/inference.forum.test.out', header=None) #2019 ACL Model by Kummerfeld et al. trained on IRC
model2_test = pd.read_csv('/content/drive/MyDrive/inference.forum.test.1.out', header=None) #2019 ACL Model by Kummerfeld et al. trained on Forum
model3_test = pd.read_csv('/content/drive/MyDrive/inference-forum.ptr.out', header=None) #2020 EMNLP Model by Yu et al. trained on IRC
model4_test = pd.read_csv('/content/drive/MyDrive/train.forum.ptr.test.3.out', header=None) #2020 EMNLP Model by Yu et al. trained on Forum


In [32]:
threadID_list = []
parent_list = []
child_list = []
for i in list(model1_test[0]):
  if i[0]=='#':
    continue
  else:
    path = i.split(':')
    thread = path[0].split('.')
    threadID = thread[3]
    anno = path[1].split(' ')
    parent = anno[1]
    child = anno[0]
    threadID_list.append(int(threadID))
    parent_list.append(int(parent))
    child_list.append(int(child))
preds = pd.DataFrame()
preds['ThreadID'] = threadID_list
preds['P'] = parent_list
preds['C'] = child_list
preds = preds[preds['P']!=preds['C']]
preds = preds.sort_values(['ThreadID','P','C'])
preds = preds.reset_index(drop=True)

In [33]:
preds #predictions by model1 (Kummerfeld trained on IRC, inference prediction on our Forum test dataset) -> post id is relative for the thread

Unnamed: 0,ThreadID,P,C
0,66,4,5
1,66,6,12
2,69,8,10
3,69,8,12
4,69,8,15
...,...,...,...
10679,22614,10,16
10680,22614,11,20
10681,22614,13,14
10682,22614,17,21


In [34]:
threadID2_list = []
parent2_list = []
child2_list = []
for i in list(model2_test[0]):
  if i[0]=='#':
    continue
  else:
    path = i.split(':')
    thread = path[0].split('.')
    threadID = thread[3]
    anno = path[1].split(' ')
    parent = anno[1]
    child = anno[0]
    threadID2_list.append(int(threadID))
    parent2_list.append(int(parent))
    child2_list.append(int(child))
preds2 = pd.DataFrame()
preds2['ThreadID'] = threadID2_list
preds2['P'] = parent2_list
preds2['C'] = child2_list
preds2 = preds2[preds2['P']!=preds2['C']]
preds2 = preds2.sort_values(['ThreadID','P','C'])
preds2 = preds2.reset_index(drop=True)

In [35]:
preds2 #predictions by model1 (Kummerfeld trained on Forum dataset, inference prediction on our Forum test dataset) -> post id is relative for the thread

Unnamed: 0,ThreadID,P,C
0,66,0,1
1,66,1,2
2,66,2,3
3,66,2,7
4,66,3,4
...,...,...,...
19061,22614,18,19
19062,22614,18,20
19063,22614,18,22
19064,22614,20,21


In [36]:
threadID3_list = []
parent3_list = []
child3_list = []
for i in list(model3_test[0]):
  if i[0]=='#':
    continue
  else:
    path = i.split(':')
    thread = path[0].split('.')
    threadID = thread[2]
    anno = path[1].split(' ')
    parent = anno[1]
    child = anno[0]
    threadID3_list.append(int(threadID))
    parent3_list.append(int(parent))
    child3_list.append(int(child))
preds3 = pd.DataFrame()
preds3['ThreadID'] = threadID3_list
preds3['P'] = parent3_list
preds3['C'] = child3_list
preds3 = preds3[preds3['P']!=preds3['C']]
preds3 = preds3.sort_values(['ThreadID','P','C'])
preds3 = preds3.reset_index(drop=True)

In [37]:
preds3 #predictions on Pointer Network model by Yu et al. 

Unnamed: 0,ThreadID,P,C
0,66,4,5
1,66,4,6
2,66,6,12
3,66,9,10
4,69,4,5
...,...,...,...
10706,22614,11,20
10707,22614,12,16
10708,22614,13,14
10709,22614,13,15


In [38]:
threadID4_list = []
parent4_list = []
child4_list = []
for i in list(model4_test[0]):
  if i[0]=='#':
    continue
  else:
    path = i.split(':')
    thread = path[0].split('.')
    threadID = thread[2]
    anno = path[1].split(' ')
    parent = anno[1]
    child = anno[0]
    threadID4_list.append(int(threadID))
    parent4_list.append(int(parent))
    child4_list.append(int(child))
preds4 = pd.DataFrame()
preds4['ThreadID'] = threadID4_list
preds4['P'] = parent4_list
preds4['C'] = child4_list
preds4 = preds4[preds4['P']!=preds4['C']]
preds4 = preds4.sort_values(['ThreadID','P','C'])
preds4 = preds4.reset_index(drop=True)

In [39]:
preds4

Unnamed: 0,ThreadID,P,C
0,66,1,2
1,66,2,3
2,66,4,5
3,66,5,6
4,66,6,7
...,...,...,...
16907,22614,18,19
16908,22614,19,20
16909,22614,20,21
16910,22614,21,22


In [40]:
#1000 uniformly random chosen threads with atleast 10 messages / posts.
test_threads = set([66,69,73,79,83,84,93,100,109,113,115,129,134,135,140,147,148,150,154,158,164,166,173,179,180,198,206,223,224,226,228,255,259,260,274,279,284,301,304,306,311,328,336,348,366,367,371,383,384,405,406,407,410,417,420,427,429,436,442,453,459,460,466,476,500,506,507,512,515,516,521,531,532,542,554,556,559,564,569,578,586,589,594,596,600,603,611,621,623,625,629,632,642,644,650,651,652,654,661,680,687,688,699,701,713,715,717,723,726,742,743,750,769,772,785,788,791,794,809,816,818,821,828,829,833,842,845,846,854,856,865,873,881,885,891,896,901,903,910,912,913,915,919,925,935,939,953,960,973,983,985,988,992,1021,1025,1031,1037,1056,1073,1106,1107,1117,1119,1124,1192,1197,1216,1246,1249,1275,1287,1305,1327,1346,1349,1357,1406,1408,1410,1412,1415,1434,1452,1486,1508,1514,1518,1530,1566,1584,1595,1597,1598,1606,1618,1631,1632,1653,1670,1682,1683,1686,1700,1736,1759,1762,1776,1777,1778,1779,1795,1802,1803,1810,1815,1826,1829,1834,1839,1843,1846,1858,1863,1867,1876,1889,1893,1894,1895,1904,1908,1912,1913,1915,1923,1931,1938,1944,1948,1949,1950,1954,1955,1976,1984,1987,1993,1999,2001,2014,2017,2018,2023,2025,2036,2046,2047,2052,2063,2065,2070,2073,2076,2094,2099,2119,2140,2148,2180,2183,2184,2202,2203,2208,2209,2215,2224,2228,2229,2245,2274,2294,2307,2318,2319,2335,2344,2347,2354,2358,2359,2372,2378,2383,2396,2398,2400,2402,2409,2410,2415,2416,2422,2430,2436,2446,2447,2472,2494,2499,2500,2503,2517,2521,2535,2558,2561,2577,2578,2597,2604,2612,2614,2616,2617,2618,2625,2637,2653,2666,2689,2741,2761,2762,2767,2773,2775,2785,2789,2799,2800,2824,2836,2837,2840,2851,2852,2853,2854,2863,2864,2870,2875,2883,2885,2893,2894,2896,2902,2912,2925,2927,2928,2932,2935,2942,2945,2946,2954,2966,2979,2982,2992,2996,3003,3004,3007,3015,3016,3040,3044,3047,3048,3050,3057,3058,3059,3060,3072,3090,3095,3107,3118,3125,3127,3128,3133,3137,3151,3158,3160,3166,3172,3205,3207,3215,3220,3221,3230,3235,3237,3275,3282,3286,3287,3289,3294,3330,3332,3333,3336,3344,3347,3360,3361,3378,3402,3415,3418,3420,3451,3453,3475,3476,3479,3489,3491,3517,3522,3525,3530,3542,3543,3549,3565,3570,3576,3593,3594,3612,3641,3644,3646,3655,3659,3668,3669,3679,3721,3722,3745,3752,3757,3758,3769,3772,3775,3796,3797,3799,3802,3806,3812,3819,3834,3838,3843,3847,3859,3866,3885,3889,3906,3908,3912,3915,3918,3919,3921,3923,3926,3927,3943,3946,3953,3963,3967,3970,3971,3973,12104,12107,12111,12113,12115,21159,12119,12120,12122,12124,12125,12127,21164,21166,21167,12133,12135,21170,21172,12137,12139,21176,21179,12159,21186,21189,12166,21194,21195,21197,12171,12174,21204,21207,21209,12181,12182,21212,21214,12191,21216,21217,21220,12196,21225,21228,21233,12203,12204,12206,21237,21241,12212,12213,12214,12216,12223,21262,21266,21269,21270,21271,12235,21274,21278,21288,12249,21295,21297,21299,12252,12258,12261,21321,21323,21330,21331,21332,21333,21339,21347,21350,21358,21364,21373,12301,21381,21383,12307,12317,21398,12318,12319,12320,12322,21404,21405,12327,21412,21414,12332,12335,21421,21426,21428,21440,21442,12346,21450,21457,12353,21465,21468,21469,21470,21474,12359,21475,21476,21481,21486,21493,21494,21497,21500,21503,21506,21507,12366,21511,21515,21517,21521,21527,21528,21532,21533,21535,21538,21541,21544,12390,21549,21550,21552,21556,21558,21564,12416,21567,21568,21569,12423,21583,21584,21592,21599,21602,12432,12436,21604,21605,21620,21623,21628,21631,21634,21638,21640,12452,21641,21642,21645,21647,12458,12460,21649,21650,21651,21652,21659,21664,12468,21668,21669,21673,21678,12475,21685,21689,21699,21704,21707,21709,12489,21719,21720,12494,12498,21731,12507,21738,21741,12512,12514,21745,21750,21755,21756,21758,21762,21766,21768,12518,21775,21778,21779,21780,21781,12523,21787,21788,21792,12528,21793,21797,12530,21798,12537,21803,12540,12541,21805,21808,21811,21813,12548,12549,21818,21821,21824,21828,21829,21832,21833,12562,21838,21839,21841,21848,21849,21853,21854,21855,21866,21868,12570,12571,12575,21878,12583,21884,21886,21889,21890,21891,12589,21893,21897,21898,12595,21905,21906,21907,21912,21914,21915,21925,21928,21930,12597,21935,21941,21942,21945,21949,21950,21954,21955,21956,21958,21959,21960,12611,12614,12615,21972,21975,21978,12617,21981,21982,21985,21993,12627,12632,21996,22000,12641,12642,22004,22006,22008,22009,22018,22019,22022,12654,22025,22028,22032,12656,22033,22036,22041,22042,22043,12664,22046,22047,22048,22052,22056,12671,12676,22064,22067,22068,12680,12682,12684,12687,22070,22071,22078,22081,12693,12694,22086,12701,22095,22102,22106,12713,22108,22110,22111,22112,22114,22117,22120,22121,22124,22131,22134,12723,22135,12726,22146,22157,12728,22163,22169,22172,22179,22181,22183,22194,22200,22203,12749,22216,22218,22219,22223,22227,12757,22230,22236,22237,22238,22241,22242,22243,22246,22249,22250,22255,22260,12772,22264,22267,12775,22272,22276,22277,22282,22283,22284,22286,12780,12781,22293,22305,22310,12785,22325,12791,12792,12794,22342,12796,22358,22359,22361,22362,22363,22370,22373,12815,22380,22382,12817,22386,22387,22388,12823,22391,22392,22393,22395,22396,22397,12824,12828,22403,12839,22406,22412,22413,12843,12845,22421,22423,22428,22429,12854,22443,12858,22445,12868,12869,12870,22466,22472,22484,22485,12889,22497,22498,22499,22500,22522,22524,22526,12912,22535,22540,22550,22556,22558,12928,12931,12933,22567,22568,22569,12938,22580,12941,22583,12952,12958,12960,22603,12961,12965,12967,12968,22614])

In [41]:
print(len(test_threads)) #Check 1000 threads
thread_list = []
parent_list = []
child_list = []
post_sn = {}
for i in test_threads:
  if i == '':
    continue
  else:
    
    temp = df[df['ThreadID']==int(i)]
    postIDs = sorted(temp['PostID'])
    
    k = 0
    for j in postIDs:
      post_sn[(int(i),int(j))] = int(k) #Dictionary to get relative post ID given thread ID and post ID 
      k = k + 1 
    s = get_conversation_dag(int(i))
    
    for j,k in s.edges():
      thread_list.append(int(i))

      p = post_sn.get((i,j))
      c = post_sn.get((i,k))

      parent_list.append(p)
      child_list.append(c)

1000


In [42]:
gold = pd.DataFrame()
gold['ThreadID'] = thread_list
gold['P'] = parent_list
gold['C'] = child_list
gold = gold[gold['P']!=gold['C']]
gold = gold.sort_values(['ThreadID','P','C'])
gold = gold.reset_index(drop=True)

In [43]:
gold #ground truth of forum test dataset

Unnamed: 0,ThreadID,P,C
0,66,0,1
1,66,0,3
2,66,1,2
3,66,1,5
4,66,3,4
...,...,...,...
19061,22614,17,18
19062,22614,17,23
19063,22614,19,20
19064,22614,19,22


In [44]:
def get_set_of_sets(s1,name):
  #Given a prediction, get a set of subthreads for each thread
  set_set_list = []
  for i in test_threads:
    temp = s1[s1['ThreadID']==int(i)]
    p = list(temp['P'])
    c = list(temp['C'])
    edges = []
    for j in range(len(p)):
      edges.append((int(p[j]),int(c[j])))
    all_nodes = list(set(p).union(set(c)))
    G = nx.DiGraph()
    G.add_nodes_from(all_nodes)
    G.add_edges_from(edges)
    chaini = chain.from_iterable
    roots = (n for n,d in G.in_degree() if d==0)
    leaves = (n for n,d in G.out_degree() if d==0)
    all_paths = partial(nx.all_simple_paths, G)
    ans = chaini(starmap(all_paths, product(roots, leaves)))
    set_ans = {}
    k = 0
    for i in ans:
      set_ans[name+str(k)] = set(i)
      k = k + 1
    set_set_list.append(set_ans)
  return set_set_list

In [45]:
gt = get_set_of_sets(gold,'G') #Ground Truth
pr = get_set_of_sets(preds,'AC') #Model 1
pr2 = get_set_of_sets(preds2,'AF') #Model 2
pr3 = get_set_of_sets(preds3,'EC') #Model 3
pr4 = get_set_of_sets(preds4,'EF') #Model 3

In [46]:
temp = pd.DataFrame()
temp['TID'] = list(test_threads)
temp['GT'] = gt
temp['PR'] = pr
temp['PR2'] = pr2
temp['PR3'] = pr3
temp['PR4'] = pr4

In [47]:
preds4

Unnamed: 0,ThreadID,P,C
0,66,1,2
1,66,2,3
2,66,4,5
3,66,5,6
4,66,6,7
...,...,...,...
16907,22614,18,19
16908,22614,19,20
16909,22614,20,21
16910,22614,21,22


In [48]:
temp

Unnamed: 0,TID,GT,PR,PR2,PR3,PR4
0,2052,"{'G0': {0, 1}, 'G1': {0, 2, 3, 4, 5, 6, 7, 8},...","{'AC0': {0, 2}, 'AC1': {0, 4, 6}, 'AC2': {0, 8...","{'AF0': {0, 1, 2}, 'AF1': {0, 3, 4, 5, 6, 7, 8...","{'EC0': {0, 1}, 'EC1': {0, 2}, 'EC2': {0, 4, 5...","{'EF0': {0, 1, 2, 3}, 'EF1': {4, 5, 6, 7, 8}, ..."
1,22535,"{'G0': {0, 1, 2, 3, 4, 5, 6}, 'G1': {0, 1, 2, ...","{'AC0': {0, 2, 4}, 'AC1': {0, 2, 6, 7, 9}, 'AC...","{'AF0': {0, 1, 2, 3, 4, 5, 6}, 'AF1': {0, 7}, ...","{'EC0': {0, 1}, 'EC1': {0, 2}, 'EC2': {3, 4, 5...","{'EF0': {0, 1}, 'EF1': {0, 2, 3, 4, 5}, 'EF2':..."
2,22540,"{'G0': {0, 1}, 'G1': {0, 2}, 'G2': {0, 6}, 'G3...","{'AC0': {0, 3}, 'AC1': {0, 10, 11}, 'AC2': {0,...","{'AF0': {0, 1, 2}, 'AF1': {0, 3, 4}, 'AF2': {0...","{'EC0': {0, 3, 13}, 'EC1': {4, 5}, 'EC2': {4, ...","{'EF0': {0, 1, 2}, 'EF1': {3, 4, 5, 6, 7, 8, 9..."
3,12301,"{'G0': {0, 2}, 'G1': {0, 3}, 'G2': {0, 1, 4}, ...","{'AC0': {0, 4}, 'AC1': {11, 3, 12}, 'AC2': {5,...","{'AF0': {0, 1, 2}, 'AF1': {0, 3}, 'AF2': {0, 1...","{'EC0': {3, 5, 6}, 'EC1': {8, 7}, 'EC2': {9, 1...","{'EF0': {0, 1, 2, 3}, 'EF1': {0, 1, 2, 4, 5}, ..."
4,2063,"{'G0': {0, 1, 2, 3, 4}, 'G1': {0, 5}, 'G2': {0...","{'AC0': {0, 8, 2}, 'AC1': {1, 3}, 'AC2': {9, 5...","{'AF0': {0, 1}, 'AF1': {0, 2, 3}, 'AF2': {0, 2...","{'EC0': {0, 2, 4}, 'EC1': {0, 8}, 'EC2': {1, 3...","{'EF0': {0, 1}, 'EF1': {2, 4}, 'EF2': {2, 3, 5..."
...,...,...,...,...,...,...
995,2036,"{'G0': {0, 1, 2}, 'G1': {0, 1, 3}, 'G2': {0, 1...","{'AC0': {8, 1, 3}, 'AC1': {11, 1, 3, 4}, 'AC2'...","{'AF0': {0, 1, 2, 3}, 'AF1': {0, 1, 4, 5, 6, 7...","{'EC0': {0, 1}, 'EC1': {2, 3, 4}, 'EC2': {8, 2...","{'EF0': {0, 1, 2, 4, 5}, 'EF1': {0, 1, 2, 3, 6..."
996,22522,"{'G0': {0, 1}, 'G1': {0, 2, 3, 4, 5}, 'G2': {0...","{'AC0': {3, 5}, 'AC1': {9, 6, 7}}","{'AF0': {0, 1}, 'AF1': {0, 2, 3, 4, 5, 6, 7}, ...","{'EC0': {9, 4}, 'EC1': {4, 7}}","{'EF0': {0, 1, 2, 3}, 'EF1': {8, 9, 7}}"
997,22524,"{'G0': {0, 1, 2, 3, 4}, 'G1': {0, 5, 6}, 'G2':...","{'AC0': {0, 2}, 'AC1': {0, 4}, 'AC2': {0, 1, 6...","{'AF0': {0, 1, 2, 3}, 'AF1': {0, 1, 2, 4}, 'AF...","{'EC0': {0, 2}, 'EC1': {0, 4}, 'EC2': {0, 1, 6...","{'EF0': {0, 1, 6, 7}, 'EF1': {0, 1, 6, 8}, 'EF..."
998,2046,"{'G0': {0, 1}, 'G1': {0, 2}, 'G2': {0, 3}, 'G3...","{'AC0': {0, 11}, 'AC1': {8, 1, 10}, 'AC2': {2,...","{'AF0': {0, 1}, 'AF1': {0, 2, 3}, 'AF2': {0, 4...","{'EC0': {2, 3, 4, 5}, 'EC1': {2, 3, 4, 6, 7}, ...","{'EF0': {0, 1, 2, 3, 4, 5, 6, 7}, 'EF1': {0, 1..."


In [49]:
def get_metrics(s1,s2):
  one_minus_scaled_VI_list = []
  NVI_list = []
  omega_list = []
  one_one_list = []
  shen_f1_list = []
  counter = 0
  for i in range(len(s1)):
    if len(s2[i]) == 0:
      counter = counter + 1
      NVI_list.append(0)
      one_minus_scaled_VI_list.append(0)
      omega_list.append(0)
      one_one_list.append(0)
      shen_f1_list.append(0)
    else:
      one_minus_scaled_VI = get_one_minus_scaled_vi(s1[i],s2[i])
      NVI = 1 - get_normalized_vi(s1[i],s2[i])
      OM = get_omega_score(s1[i],s2[i])
      OVO = get_one_to_one(s1[i],s2[i])
      SF1 = get_shen_f1(s1[i],s2[i])

      print('Metrics',i,one_minus_scaled_VI, NVI,OM,OVO,SF1)
      one_minus_scaled_VI_list.append(one_minus_scaled_VI)
      NVI_list.append(NVI)
      omega_list.append(OM)
      one_one_list.append(OVO)
      shen_f1_list.append(SF1)

  print('No Predictions :'+str(counter))
  metrics = pd.DataFrame()
  metrics['TID'] = list(test_threads)
  metrics['1-Scaled VI'] = one_minus_scaled_VI_list
  metrics['1-Normalized VI'] = NVI_list
  metrics['Omega'] = omega_list
  metrics['one_to_one'] = one_one_list
  metrics['shen_F1'] = shen_f1_list
  return metrics

In [50]:
metrics = get_metrics(gt, pr)

Metrics 0 -0.047 0.013000000000000012 0.063 30.556 0.939
Metrics 1 0.127 0.0040000000000000036 -0.006 35.0 0.849
Metrics 2 -0.243 0.08499999999999996 0.164 22.917 0.872
Metrics 3 0.114 0.22399999999999998 -0.046 30.0 0.645
Metrics 4 0.259 0.20499999999999996 0.071 38.095 0.574
Metrics 5 -0.165 0.06999999999999995 0.017 15.741 0.723
Metrics 6 -0.299 0.08199999999999996 -0.037 9.705 0.756
Metrics 7 -0.027 0.10599999999999998 0.034 26.087 0.714
Metrics 8 0.203 0.237 -0.152 26.087 0.413
Metrics 9 -0.084 0.03200000000000003 0.084 30.0 0.697
Metrics 10 0.237 0.253 0.126 34.783 0.597
Metrics 11 -0.251 0.02300000000000002 -0.014 9.009 0.535
Metrics 12 0.548 0.25 -0.077 50.0 0.233
Metrics 13 -0.04 0.10899999999999999 0.038 27.586 0.694
Metrics 14 0.024 0.08199999999999996 0.002 13.333 0.524
Metrics 15 0.02 0.03600000000000003 0.095 35.714 0.543
Metrics 16 -0.139 0.06499999999999995 0.136 22.642 0.828
Metrics 17 -0.015 0.07799999999999996 -0.035 22.222 0.453
Metrics 18 0.086 0.15100000000000002 

  after removing the cwd from sys.path.


0.15600000000000003 0.062 31.034 0.449
Metrics 143 -0.25 0.049000000000000044 0.06 19.298 0.768
Metrics 144 -0.092 0.125 0.004 19.565 0.589
Metrics 145 0.17 0.238 -0.059 27.027 0.528
Metrics 146 -0.021 0.051000000000000045 0.024 35.0 0.657
Metrics 147 0.59 0.49 0.042 45.0 0.374
Metrics 148 0.048 0.06000000000000005 -0.011 33.333 0.518
Metrics 149 0.049 0.09099999999999997 0.147 30.769 0.659
Metrics 150 -0.255 0.016000000000000014 0.11 29.63 0.918
Metrics 151 0.145 0.16500000000000004 0.147 26.087 0.538
Metrics 152 -0.07 0.07099999999999995 0.041 26.087 0.799
Metrics 153 0.263 0.23199999999999998 -0.155 30.0 0.411
Metrics 154 0.143 0.129 0.009 26.667 0.41
Metrics 155 -0.226 0.031000000000000028 0.217 14.433 0.806
Metrics 156 -0.076 0.08899999999999997 0.081 23.333 0.728
Metrics 157 -0.017 0.09799999999999998 -0.054 20.0 0.628
Metrics 158 -0.056 0.08699999999999997 0.064 30.435 0.714
Metrics 159 -0.445 0.02100000000000002 0.145 16.176 0.939
Metrics 160 0.214 0.139 -0.007 28.571 0.418
Met

In [51]:
metrics2 = get_metrics(gt, pr2)

Metrics 0 -0.024 0.015000000000000013 0.161 35.556 1.394
Metrics 1 0.344 0.020000000000000018 -0.116 50.0 1.123
Metrics 2 -0.455 0.016000000000000014 0.222 19.737 1.216
Metrics 3 -0.715 0.010000000000000009 0.577 16.216 1.516
Metrics 4 -0.432 0.016000000000000014 0.257 18.28 1.25
Metrics 5 -0.502 0.0040000000000000036 0.479 10.944 1.865
Metrics 6 -0.651 0.008000000000000007 0.374 7.5 1.281
Metrics 7 -0.145 0.02400000000000002 0.182 17.822 1.17
Metrics 8 -0.64 0.013000000000000012 0.778 16.374 1.573
Metrics 9 0.024 0.02300000000000002 0.531 36.111 1.422
Metrics 10 -0.395 0.018000000000000016 0.296 18.947 1.183
Metrics 11 -0.723 0.007000000000000006 0.391 9.091 1.419
Metrics 12 -0.409 0.016000000000000014 0.522 27.083 1.369
Metrics 13 -0.449 0.02100000000000002 0.481 21.591 1.483
Metrics 14 -0.66 0.006000000000000005 0.178 10.178 1.645
Metrics 15 -0.131 0.02300000000000002 0.291 38.462 0.993
Metrics 16 -0.25 0.019000000000000017 0.085 17.105 0.97
Metrics 17 -0.626 0.01200000000000001 0.4

In [52]:
metrics3 = get_metrics(gt, pr3)

Metrics 0 -0.019 0.017000000000000015 0.082 33.333 1.282
Metrics 1 0.113 0.0010000000000000009 -0.034 33.333 0.864
Metrics 2 0.033 0.08099999999999996 -0.103 24.0 0.528
Metrics 3 0.527 0.473 -0.157 40.0 0.41
Metrics 4 -0.009 0.122 0.121 32.353 0.705
Metrics 5 -0.151 0.09099999999999997 -0.016 17.647 0.72
Metrics 6 -0.292 0.10199999999999998 0.048 12.222 0.963
Metrics 7 -0.244 0.052000000000000046 0.05 21.795 0.932
Metrics 8 -0.282 0.03200000000000003 0.224 17.241 0.856
Metrics 9 -0.049 0.0040000000000000036 -0.057 33.333 0.713
Metrics 10 0.133 0.19499999999999995 0.043 32.0 0.584
Metrics 11 0.235 0.32099999999999995 -0.097 23.81 0.485
Metrics 12 0.11 0.07299999999999995 0.062 33.333 0.548
Metrics 13 0.251 0.14900000000000002 0.139 33.333 0.486
Metrics 14 -0.147 0.06899999999999995 0.047 14.286 0.882
Metrics 15 -0.046 0.010000000000000009 0.134 31.579 0.676
Metrics 16 -0.103 0.05500000000000005 0.164 24.194 0.904
Metrics 17 0.623 0.29700000000000004 0.037 60.0 0.272
Metrics 18 -0.115 0.

  after removing the cwd from sys.path.


-0.104 33.333 0.352
Metrics 149 -0.055 0.09199999999999997 0.017 29.032 0.738
Metrics 150 -0.117 0.03300000000000003 0.053 38.889 0.692
Metrics 151 -0.264 0.049000000000000044 0.235 15.574 0.837
Metrics 152 0.024 0.121 -0.099 30.303 0.639
Metrics 153 0.437 0.378 0.128 42.857 0.679
Metrics 154 0.508 0.351 -0.066 42.857 0.282
Metrics 155 0.204 0.17900000000000005 0.067 28.0 0.377
Metrics 156 0.127 0.17500000000000004 0.04 27.273 0.625
Metrics 157 0.15 0.16400000000000003 -0.051 24.138 0.554
Metrics 158 -0.07 0.09799999999999998 0.034 38.095 0.752
Metrics 159 -0.446 0.02100000000000002 0.229 16.923 0.947
Metrics 160 0.527 0.246 -0.045 40.0 0.205
Metrics 161 0.388 0.18700000000000006 0.147 50.0 0.426
Metrics 162 0.161 0.08299999999999996 0.167 28.571 0.485
Metrics 163 -0.029 0.07499999999999996 0.101 29.412 0.817
Metrics 164 -0.126 0.03700000000000003 0.176 20.93 0.732
Metrics 165 0.24 0.22199999999999998 0.112 34.375 0.498
Metrics 166 0.17 0.118 0.106 43.75 0.607
Metrics 167 0.19 0.199999

In [53]:
metrics4 = get_metrics(gt, pr4)

Metrics 0 0.197 0.05800000000000005 -0.05 37.037 0.895
Metrics 1 0.252 0.0 0.119 41.667 1.091
Metrics 2 0.038 0.09899999999999998 -0.003 25.0 0.604
Metrics 3 -0.487 0.006000000000000005 0.141 15.556 1.248
Metrics 4 -0.077 0.128 -0.069 22.727 0.746
Metrics 5 -0.089 0.10799999999999998 -0.057 18.947 0.694
Metrics 6 0.045 0.19999999999999996 0.036 19.048 0.66
Metrics 7 -0.338 0.007000000000000006 0.303 15.301 1.634
Metrics 8 -0.303 0.007000000000000006 -0.003 12.15 0.886
Metrics 9 0.08 0.027000000000000024 0.017 36.364 1.185
Metrics 10 -0.147 0.129 0.042 24.074 0.856
Metrics 11 -0.067 0.18799999999999994 -0.043 19.512 0.725
Metrics 12 -0.332 0.03700000000000003 0.309 27.778 1.06
Metrics 13 -0.176 0.126 -0.078 25.0 0.782
Metrics 14 -0.495 0.02200000000000002 0.22 10.309 1.687
Metrics 15 0.072 0.14800000000000002 -0.077 35.294 0.674
Metrics 16 0.028 0.20099999999999996 -0.126 24.444 0.651
Metrics 17 0.104 0.19999999999999996 -0.01 33.333 0.631
Metrics 18 -0.3 0.015000000000000013 0.363 23.7

In [54]:
metrics4.describe() #model4 EMNLP2020 Pointer Network Trained on Forum Data

Unnamed: 0,TID,1-Scaled VI,1-Normalized VI,Omega,one_to_one,shen_F1
count,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0
mean,10570.267,-0.144286,0.07406,0.046099,22.973363,0.951876
std,9136.185749,0.236139,0.071671,0.147326,9.976866,0.310653
min,66.0,-0.769,0.0,-0.463,2.317,0.248
25%,2016.25,-0.3045,0.015,-0.046,15.72725,0.72775
50%,8038.5,-0.128,0.0565,0.043,22.6115,0.88
75%,21602.5,0.02625,0.112,0.132,29.032,1.12225
max,22614.0,0.841,0.731,1.0,75.0,2.491


In [55]:
metrics3.describe() #model3 EMNLP2020 Pointer Network Pretrained

Unnamed: 0,TID,1-Scaled VI,1-Normalized VI,Omega,one_to_one,shen_F1
count,1000.0,1000.0,997.0,1000.0,1000.0,1000.0
mean,10570.267,0.060166,0.136778,0.047133,29.273597,0.635064
std,9136.185749,0.264798,0.116092,0.104883,12.573207,0.220365
min,66.0,-0.573,0.0,-0.227,0.0,0.0
25%,2016.25,-0.12725,0.054,-0.01925,20.408,0.4945
50%,8038.5,0.036,0.105,0.037,27.778,0.6285
75%,21602.5,0.205,0.188,0.10025,35.294,0.7725
max,22614.0,1.0,0.667,0.577,100.0,1.589


In [56]:
metrics2.describe() #model 2 ACL 2019 Trained on 

Unnamed: 0,TID,1-Scaled VI,1-Normalized VI,Omega,one_to_one,shen_F1
count,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0
mean,10570.267,-0.436024,0.012731,0.316285,18.925462,1.364437
std,9136.185749,0.199315,0.007378,0.19001,8.143707,0.272943
min,66.0,-0.832,0.0,-0.333,1.679,0.621
25%,2016.25,-0.578,0.008,0.18775,13.06975,1.209
50%,8038.5,-0.469,0.011,0.299,18.889,1.332
75%,21602.5,-0.3285,0.016,0.4305,23.438,1.47725
max,22614.0,0.344,0.056,1.0,55.556,3.682


In [57]:
metrics.describe() #model 1 ACL 2019 Pretrained 

Unnamed: 0,TID,1-Scaled VI,1-Normalized VI,Omega,one_to_one,shen_F1
count,1000.0,1000.0,999.0,1000.0,1000.0,1000.0
mean,10570.267,0.053063,0.127695,0.040897,27.955024,0.613219
std,9136.185749,0.233099,0.107043,0.094611,11.561757,0.200134
min,66.0,-0.531,0.0,-0.305,0.0,0.0
25%,2016.25,-0.10925,0.056,-0.01625,20.0,0.492
50%,8038.5,0.0225,0.101,0.0335,26.748,0.615
75%,21602.5,0.177,0.1655,0.091,33.333,0.73525
max,22614.0,1.0,0.724,0.564,100.0,1.66


In [58]:
#IRC Data for Reference
irc_model = pd.read_csv('/content/drive/MyDrive/inference.irc.test.1.out',header=None)
irc_gold = pd.read_csv('/content/drive/MyDrive/irc_gold.out',header=None)

In [59]:
irc_model

Unnamed: 0,0
0,# Sun Sep 26 17:50:59 2021
1,# disentangle.py inference.irc.test.1 --model ...
2,../data/test/2005-07-06_14.annotation.txt:0 0 -
3,../data/test/2005-07-06_14.annotation.txt:1 0 -
4,../data/test/2005-07-06_14.annotation.txt:2 2 -
...,...
14997,../data/test/2016-06-08_07.annotation.txt:1495...
14998,../data/test/2016-06-08_07.annotation.txt:1496...
14999,../data/test/2016-06-08_07.annotation.txt:1497...
15000,../data/test/2016-06-08_07.annotation.txt:1498...


In [60]:
irc_gold

Unnamed: 0,0
0,993 1000 -
1,995 1000 -
2,1001 1001 -
3,1000 1002 -
4,998 1003 -
...,...
5182,1495 1495 -
5183,1489 1496 -
5184,1492 1497 -
5185,1497 1498 -


In [61]:
threadID_list = []
parent_list = []
child_list = []
for i in list(irc_model[0]):
  if i[0]=='#':
    continue
  else:
    path = i.split(':')
    thread = path[0].split('/')
    threadID = thread[3].split('_')[0]
    anno = path[1].split(' ')
    parent = anno[1]
    child = anno[0]
    threadID_list.append(str(threadID))
    parent_list.append(int(parent))
    child_list.append(int(child))
irc = pd.DataFrame()
irc['ThreadID'] = threadID_list
irc['P'] = parent_list
irc['C'] = child_list
irc = irc[irc['P']!=irc['C']]
irc = irc.sort_values(['ThreadID','P','C'])
irc = irc.reset_index(drop=True)

In [62]:
irc = irc[irc['C']>999]

In [63]:
irc

Unnamed: 0,ThreadID,P,C
704,2005-07-06,993,1000
705,2005-07-06,993,1002
707,2005-07-06,998,1003
708,2005-07-06,1005,1006
709,2005-07-06,1006,1007
...,...,...,...
12047,2016-06-08,1491,1492
12048,2016-06-08,1493,1494
12049,2016-06-08,1493,1497
12050,2016-06-08,1496,1499


In [64]:
irc_threads = sorted(set(irc['ThreadID']))

In [65]:
irc_threads

['2005-07-06',
 '2007-01-11',
 '2007-12-01',
 '2008-07-14',
 '2010-08-17',
 '2013-09-01',
 '2014-06-18',
 '2015-03-18',
 '2016-02-22',
 '2016-06-08']

In [66]:
def get_irc_set_of_sets(s1,name):
  #Given a prediction, get a set of subthreads for each thread
  set_set_list = []
  for i in irc_threads:
    temp = s1[s1['ThreadID']==str(i)]
    p = list(temp['P'])
    c = list(temp['C'])
    edges = []
    for j in range(len(p)):
      edges.append((int(p[j]),int(c[j])))
    all_nodes = list(set(p).union(set(c)))
    G = nx.DiGraph()
    G.add_nodes_from(all_nodes)
    G.add_edges_from(edges)
    chaini = chain.from_iterable
    roots = (n for n,d in G.in_degree() if d==0)
    leaves = (n for n,d in G.out_degree() if d==0)
    all_paths = partial(nx.all_simple_paths, G)
    ans = chaini(starmap(all_paths, product(roots, leaves)))
    set_ans = {}
    k = 0
    for i in ans:
      set_ans[name+str(k)] = set(i)
      k = k + 1
    set_set_list.append(set_ans)
  return set_set_list

In [67]:
irc_pr = get_irc_set_of_sets(irc,'IP')

In [68]:
irc_gold

Unnamed: 0,0
0,993 1000 -
1,995 1000 -
2,1001 1001 -
3,1000 1002 -
4,998 1003 -
...,...
5182,1495 1495 -
5183,1489 1496 -
5184,1492 1497 -
5185,1497 1498 -


In [69]:
irc_gold_threads = ['2005-07-06'] *  506 + ['2007-01-11'] * 520 + ['2007-12-01'] * 516 +['2008-07-14'] * 528 + ['2010-08-17'] * 522 + ['2013-09-01'] * 532 + ['2014-06-18'] * 517 + ['2015-03-18'] * 517 + ['2016-02-22'] * 518 + ['2016-06-08'] * 511

In [70]:
irc_gold['Threads'] = irc_gold_threads

In [71]:
irc_gold_th = irc_gold['Threads'] + [':']*len(irc_gold[0]) + irc_gold[0]

In [72]:
irc_gold = pd.DataFrame()
irc_gold[0]=irc_gold_th

In [73]:
irc_gold

Unnamed: 0,0
0,2005-07-06:993 1000 -
1,2005-07-06:995 1000 -
2,2005-07-06:1001 1001 -
3,2005-07-06:1000 1002 -
4,2005-07-06:998 1003 -
...,...
5182,2016-06-08:1495 1495 -
5183,2016-06-08:1489 1496 -
5184,2016-06-08:1492 1497 -
5185,2016-06-08:1497 1498 -


In [74]:
threadID_list = []
parent_list = []
child_list = []
for i in list(irc_gold[0]):
  if i[0]=='#':
    continue
  else:
    path = i.split(':')
    threadID = path[0]
    anno = path[1].split(' ')
    parent = anno[0]
    child = anno[1]
    threadID_list.append(str(threadID))
    parent_list.append(int(parent))
    child_list.append(int(child))
ircg = pd.DataFrame()
ircg['ThreadID'] = threadID_list
ircg['P'] = parent_list
ircg['C'] = child_list
ircg = ircg[ircg['P']!=ircg['C']]
ircg = ircg.sort_values(['ThreadID','P','C'])
ircg = ircg.reset_index(drop=True)

In [75]:
ircg

Unnamed: 0,ThreadID,P,C
0,2005-07-06,993,1000
1,2005-07-06,995,1000
2,2005-07-06,998,1003
3,2005-07-06,1000,1002
4,2005-07-06,1002,1013
...,...,...,...
4260,2016-06-08,1492,1493
4261,2016-06-08,1492,1497
4262,2016-06-08,1493,1494
4263,2016-06-08,1496,1499


In [76]:
irc_gt = get_irc_set_of_sets(ircg,'IG')

In [None]:
irc_metrics = get_metrics(irc_gt,irc_pr)

Metrics 0 0.244 0.29900000000000004 0.327 13.489 1.502
Metrics 1 -0.077 0.136 0.266 6.248 3.124
Metrics 2 0.03 0.18799999999999994 0.255 6.627 2.784
Metrics 3 0.151 0.25 0.423 9.551 1.997


In [None]:
irc_metrics.describe()