In [663]:
from SPARQLWrapper import SPARQLWrapper, JSON

def get_citations(source_celex, cites_depth=1, cited_depth=1):
    """
    Gets all the citations one to X steps away. Hops can be specified as either
    the source document citing another (defined by `cites_depth`) or another document
    citing it (`cited_depth`). Any numbers higher than 1 denote that new source document
    citing a document of its own.

    This specific implementation does not care about intermediate steps, it simply finds
    anything X or fewer hops away without linking those together.
    """    
    sparql = SPARQLWrapper('https://publications.europa.eu/webapi/rdf/sparql')
    sparql.setReturnFormat(JSON)
    sparql.setQuery('''
        prefix cdm: <http://publications.europa.eu/ontology/cdm#>
        prefix xsd: <http://www.w3.org/2001/XMLSchema#>

        SELECT DISTINCT * WHERE
        {
        {
            SELECT ?name2 WHERE {
                ?doc cdm:resource_legal_id_celex "%s"^^xsd:string .
                ?doc cdm:work_cites_work{1,%i} ?cited .
                ?cited cdm:resource_legal_id_celex ?name2 .
            }
        } UNION {
            SELECT ?name2 WHERE {
                ?doc cdm:resource_legal_id_celex "%s"^^xsd:string .
                ?cited cdm:work_cites_work{1,%i} ?doc .
                ?cited cdm:resource_legal_id_celex ?name2 .
            }
        }
        }''' % (source_celex, cites_depth, source_celex, cited_depth))
    ret = sparql.queryAndConvert()

    targets = set()
    for bind in ret['results']['bindings']:
        target = bind['name2']['value']
        targets.add(target)
    targets = set([el for el in list(targets) if el.startswith('3')]) #Filters the list. Filtertype: '3'=legislation, '6'=case law.
        
    return targets

def get_citations_multiple(sources, cites_depth=1, cited_depth=1, union=True):
    """
    Gets citations coming from multiple sources (given as a list of CELEX IDs).
    By default gets the union of all the resulting CELEXes, but of interest
    might be the intersect instead, returning only documents that are common
    between all the sources.
    """
    results = [get_citations(source, cites_depth, cited_depth) for source in sources]
    results.append(sources) #ensures that source nodes (ie starting points) are included in nodes list

    if union:
        return set().union(*results)    
    else:
        start_set = results[0]
        if len(sources) > 1:
            return start_set.union(*results[1:])
        else:
            return start_set

def get_citations_structure(source, cites_depth=1, cited_depth=1, dont_repeat=set()):
    if cites_depth > 0 and cited_depth > 0:
        cites, nodes1 = get_citations_structure(source, cites_depth, 0, dont_repeat)
        cited, nodes2 = get_citations_structure(source, 0, cited_depth, dont_repeat)
        return cites.union(cited), nodes1.union(nodes2)


    new_cites_depth = max(cites_depth - 1, 0)
    new_cited_depth = max(cited_depth - 1, 0)

    dont_repeat = dont_repeat.union({source})

    links = set()
    nodes = {source}
    targets = get_citations(source, min(cites_depth, 1), min(cited_depth, 1))

    for target in targets:
        nodes.add(target)
        # We're looking for citations from the source
        if cites_depth > 0:
            links.add((source, target))
        # Or to the source
        else:
            links.add((target, source))

        if new_cites_depth or new_cited_depth and target not in dont_repeat:
            new_links, new_nodes = get_citations_structure(target, new_cites_depth, new_cited_depth)
            links = links.union(new_links)
            nodes = nodes.union(new_nodes)

    return links, nodes

def get_citations_structure_multiple(sources, cites_depth=1, cited_depth=1):
    links = set()
    nodes = set(sources)
    for source in sources:
        if source.startswith('3'):
            new_links, new_nodes = get_citations_structure(source, cites_depth, cited_depth)
            links = links.union(new_links)
            nodes = nodes.union(new_nodes)
#            nodes = set([el for el in list(nodes) if el.startswith('3')]) #Filters the list. Filtertype: '3'=legislation, '6'=case law.
    return links, nodes

In [664]:
expert_docs = {
    '32018R1139',
    '32010R0996',
    '32012R1025',
    '32008R0765',
    '32008D0768',
    '32006L0042',
    '32009L0048',
    '32012R0748',
    '32015R0640',
    '32019R0945',
    '32008R0300',
    '32014R0376',
    '32019R0947',
    '32012R0923',
    '32021R0666',
    '32011R1178',
    '32012R0965',
    '32014R1321',
    '32011R1332',
    '32018R1048',
    '32017R0373',
    '32021R0664',
    '32021R0665',
    '32016L1148',
    '32019R0881',
    '32016R0679',
    '31985L0374',
    '32004R0785',
}

In [781]:
source = '32021R0664'
sources = ['32019R0945','32021R0664']

In [782]:
links, nodes = get_citations_structure(source, cited_depth=0, cites_depth=2)

In [783]:
links, nodes = get_citations_structure_multiple(sources, cited_depth=1, cites_depth=2)

In [784]:
# Example on how to go through nodes in a set
for node in nodes:
    print(node)

31991L0477
32011R1332
32002R2342
31997L0067
32008R0765
32007R0614
32007D0344
31993R0793
31981L0851
31998L0034
32004L0018
31967L0548
32014R1321
31977L0388
31999R1073
32012R0923
32006L0042
32004R0550
32002L0021
32008R0216
32002L0095
31994L0062
32006D1926
31970L0156
32001R1049
32014L0035
32011R0182
32000X1218(01)
31989L0048
32012R1025
32010D0048
31989L0392
32006R1907
31996R2185
32002D0676
31999L0044
32002L0024
32009L0081
31999Y0128(01)
31992L0005
31992R2913
32008L0063
31990L0314
31998L0006
32005L0036
32006D1673
31993L0013
32021R0665
32014L0030
32009L0048
32015R0640
32020R0639
32021R0664
31987L0102
31997L0066
32004R0549
31992L0028
32021R0666
31999D0468
32012R0965
32008D0768
32002R1605
32020R0746
31975L0319
31984L0450
31994L0047
32001L0095
32015R0340
31989L0552
32004R1935
32017R0373
32000L0031
32008R1008
32019R0945
31995R2988
32011R1178
31985L0374
32002L0096
31984D0133
32012R0748
31998L0027
31999L0093
31995L0046
31993L0022
31998L0084
32014L0053
32016R0679
32001R0045
32006L0012
31998L0043
31

In [785]:
def do_stats(nodes, print_res=True):
    nodes = set(nodes)

    precision = len(nodes.intersection(expert_docs)) / float(len(nodes))
    recall = len(nodes.intersection(expert_docs)) / float(len(expert_docs))
    f1 = 2 * (precision * recall) / (precision + recall)
    if print_res:
        print(f'Total nodes found in search: {len(nodes)}')
        print(f'Precision: {precision}\nRecall: {recall}\nF1: {f1}')

        print(f'Common nodes ({len(nodes.intersection(expert_docs))}): {nodes.intersection(expert_docs)}')
        print(f'Missed nodes ({len(expert_docs - nodes)}): {expert_docs - nodes}')
        print(f"Extra nodes (ones it shouldn't ({len(nodes - expert_docs)}): {nodes - expert_docs}")
    return (precision, recall, f1)

do_stats(nodes)

Total nodes found in search: 102
Precision: 0.20588235294117646
Recall: 0.75
F1: 0.3230769230769231
Common nodes (21): {'32011R1332', '32021R0664', '32008R0765', '32021R0666', '32008D0768', '32012R0965', '32014R1321', '32006L0042', '32012R0923', '32017R0373', '32019R0945', '32011R1178', '31985L0374', '32012R0748', '32012R1025', '32016R0679', '32019R0947', '32014R0376', '32021R0665', '32015R0640', '32009L0048'}
Missed nodes (7): {'32010R0996', '32008R0300', '32018R1139', '32004R0785', '32018R1048', '32019R0881', '32016L1148'}
Extra nodes (ones it shouldn't (81): {'32002R2342', '31997L0067', '32007R0614', '31993R0793', '32004L0018', '31999R1073', '31994L0062', '31970L0156', '32001R1049', '32000X1218(01)', '31989L0048', '32010D0048', '31989L0392', '31996R2185', '32002L0024', '31999Y0128(01)', '31992L0005', '31998L0006', '32020R0639', '31987L0102', '32004R0549', '31999D0468', '32020R0746', '31975L0319', '31994L0047', '32015R0340', '32004R1935', '32000L0031', '31995R2988', '32002L0096', '31

(0.20588235294117646, 0.75, 0.3230769230769231)

In [786]:
# Go through matrix of depths and calculate stats for each
for i in range(5):
    for j in range(5):
        try:
            precision, recall, f1 = do_stats(get_citations_multiple(sources, i, j), print_res=False)
            print(f'Cites depth: {i:02d} | Cited depth: {j:02d} | Pr {precision:.2f} Re {recall:.2f} F1 {f1:.2f}')
        except ZeroDivisionError:
            print("Tried dividing by zero, skipping")

Cites depth: 00 | Cited depth: 00 | Pr 1.00 Re 0.07 F1 0.13
Cites depth: 00 | Cited depth: 01 | Pr 0.67 Re 0.14 F1 0.24
Cites depth: 00 | Cited depth: 02 | Pr 0.57 Re 0.14 F1 0.23
Cites depth: 00 | Cited depth: 03 | Pr 0.50 Re 0.14 F1 0.22
Cites depth: 00 | Cited depth: 04 | Pr 0.40 Re 0.14 F1 0.21
Cites depth: 01 | Cited depth: 00 | Pr 0.76 Re 0.57 F1 0.65
Cites depth: 01 | Cited depth: 01 | Pr 0.70 Re 0.57 F1 0.63
Cites depth: 01 | Cited depth: 02 | Pr 0.67 Re 0.57 F1 0.62
Cites depth: 01 | Cited depth: 03 | Pr 0.64 Re 0.57 F1 0.60
Cites depth: 01 | Cited depth: 04 | Pr 0.59 Re 0.57 F1 0.58
Cites depth: 02 | Cited depth: 00 | Pr 0.21 Re 0.75 F1 0.33
Cites depth: 02 | Cited depth: 01 | Pr 0.21 Re 0.75 F1 0.32
Cites depth: 02 | Cited depth: 02 | Pr 0.20 Re 0.75 F1 0.32
Cites depth: 02 | Cited depth: 03 | Pr 0.20 Re 0.75 F1 0.32
Cites depth: 02 | Cited depth: 04 | Pr 0.20 Re 0.75 F1 0.31
Cites depth: 03 | Cited depth: 00 | Pr 0.07 Re 0.75 F1 0.12
Cites depth: 03 | Cited depth: 01 | Pr 0

In [787]:
#prepare the edges list for csv
edges_list_for_csv = []
for i in links:
    to_add = i[0][1:]+','+i[1][1:]
    edges_list_for_csv.append(to_add)
#write to files
with open('./csvoutput/extracted_edges(network).csv', 'w', newline='') as f:  
    for entries in edges_list_for_csv:
        f.write(entries)
        f.write("\n")

In [762]:
import networkx as nx

# Build complete network of docs
links, _ = get_citations_structure_multiple(sources, cites_depth=2, cited_depth=2)

g = nx.Graph()
g.add_edges_from(links)

In [713]:
# Filter by degree
filter_degree = 2
filtered = {node for node, degree in g.degree if degree >= filter_degree}

do_stats(filtered)

Total nodes found in search: 33
Precision: 0.45454545454545453
Recall: 0.5357142857142857
F1: 0.49180327868852464
Common nodes (15): {'32012R0748', '32014R1321', '32021R0664', '32008R0765', '32019R0947', '32012R1025', '32006L0042', '32012R0923', '32008D0768', '32012R0965', '32017R0373', '32014R0376', '32019R0945', '32009L0048', '31985L0374'}
Missed nodes (13): {'32011R1332', '32015R0640', '32010R0996', '32008R0300', '32004R0785', '32018R1139', '32021R0666', '32018R1048', '32016R0679', '32016L1148', '32019R0881', '32021R0665', '32011R1178'}
Extra nodes (ones it shouldn't (18): {'32002R2342', '32001R1049', '32011R0182', '31995L0046', '32004R0549', '32001L0095', '32014L0053', '31999R1073', '31999D0468', '31998L0034', '32000L0031', '32002R1605', '31996R2185', '31995R2988', '32008R0216', '32002L0021', '32004R0551', '32014L0030'}


(0.45454545454545453, 0.5357142857142857, 0.49180327868852464)

In [None]:
# saving graph created above in gexf format
nx.write_gexf(g, "./networkfiles/networkfile.gexf")

In [60]:
from pyvis.network import Network

# Visualize filtered network
nt = Network(720, 1280)
nt.add_nodes(filtered)
nt.add_edges([(source, target) for source, target in g.edges if source in filtered and target in filtered])
nt.show('filtered.html')