In [3]:
import re
from math import log
import json
from itertools import combinations
from collections import defaultdict
from typing import Any, Optional, Dict, Union, Callable, Tuple, Type, List
import numpy as np
import math
import random
from enum import Enum
from operator import itemgetter
from functools import singledispatch
from pathlib import Path

We'll begin by taking a look at a sample of the data: each JSON file holds one news story stored as a list of sentences. (To see how I collected this data, or to gather your own for your preferred sources and keywords, feel free to use the script I wrote on GitHub that pulls news stories through NewsAPI.)

In [4]:
# JSON text is UTF-8 by specification (RFC 8259), so decode explicitly instead of
# relying on the platform's locale default (which is not UTF-8 on e.g. Windows).
with open('./data/CBC-story-1.json', encoding='utf-8') as f:
    data = json.load(f)
# The file is closed once parsing is done; printing does not need it open.
print(data)

['Every January, along with the new year\xa0come scads of predictions about what lies ahead.', 'This year, device fatigue is rampant, savvy consumers aren\'t swayed by novelty\xa0and the so-called\xa0"techlash" against big tech companies like Facebook and Google is in full force.', 'And with\xa0climate change and increasing natural disasters, people are starting to question the carbon\xa0footprint of devices and data.', 'Plus, 2020 isn\'t just "any" decade.', 'We have officially entered "the future," a fact made clear when you consider that\xa0some of the most well-known science fiction movies are now set in the past; the rainy, neon-lit dystopian world of <em So what does that mean for this year?', 'Sure, there will be flashy wearables, new smartphone models\xa0and efforts to bring virtual reality mainstream.', 'But the real innovations and trends to watch are the ones you might have to look a bit harder to see.', "Our smartphones don't seem to be going anywhere, but our relationship 

Now we need to create some tools to help deal with the text data.

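To get the best results, we need to filter by part of speech. Specifically, we filter down to the "open classes", such as verbs, nouns and adjectives. They are so called because, unlike closed classes such as prepositions and articles, new words of these types are added to the vocabulary often. The `filter_pos` helper below keeps tags that start with N or J (nouns and adjectives in the Penn Treebank tag set), plus `ADJ` and `CD` (cardinal numbers).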

In [5]:
# Matches any character that is NOT a lowercase ASCII letter or digit; used to strip
# such characters from already-lowercased tokens.
ASCII = re.compile(r"[^a-z0-9]")
# Matches part-of-speech tags beginning with "N" or "J".
POS = re.compile(r"^[NJ]")


def filter_pos(token: Dict[str, str]) -> bool:
    """Decide whether a tagged token should be kept, based on its ``"pos"`` tag.

    :param token: A mapping with at least a ``"pos"`` key holding the tag string.
    :returns: True when the tag starts with ``N`` or ``J``, or is exactly ``"ADJ"``
        or ``"CD"``; False otherwise.
    """
    tag = token["pos"]
    return bool(POS.match(tag)) or tag in ("ADJ", "CD")

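This function calculates the sentence-similarity metric that Mihalcea and Tarau use in the TextRank paper (though you could easily substitute other similarity metrics here). It is the number of words two sentences share, normalized by the sum of the logs of their lengths: $\text{Similarity}(S_i, S_j) = \frac{|\{w_k \mid w_k \in S_i \land w_k \in S_j\}|}{\log |S_i| + \log |S_j|}$, where $|S_i|$ is the number of distinct words in sentence $S_i$.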

In [6]:
def overlap(s1: str, s2: str, **kwargs) -> float:
    """Sentence similarity from Mihalcea & Tarau's TextRank paper.

    Counts the distinct whitespace-separated words the two sentences share and
    divides by ``log(|s1|) + log(|s2|)``, where ``|s|`` is the number of distinct
    words in ``s``.

    :param s1: The first (normalized) sentence.
    :param s2: The second (normalized) sentence.
    :param kwargs: Ignored. Accepted so this function can be used as the ``sim``
        callable of ``sentence_graph``, which also passes the raw sentences and
        their indices as keyword arguments.
    :returns: The similarity score (``>= 0``). Returns ``0.0`` where the formula is
        undefined: when either sentence has no words (``log(0)``), or when both
        consist of a single distinct word (denominator ``log(1) + log(1) == 0``).
    """
    words1 = set(s1.split())
    words2 = set(s2.split())
    # A sentence can normalize to nothing (e.g. it was only punctuation).
    if not words1 or not words2:
        return 0.0
    norm = log(len(words1)) + log(len(words2))
    # Two one-word sentences give a zero denominator; treat them as unrelated.
    if norm == 0:
        return 0.0
    return len(words1 & words2) / norm

This function normalizes a sentence by lowercasing each word and removing its punctuation and other non-alphanumeric characters, using `norm_token`.

In [7]:
def norm_sentence(sent: str) -> str:
    """Normalize every whitespace-separated word of ``sent`` with ``norm_token``.

    The normalized words are re-joined with single spaces. A word that normalizes
    to the empty string (e.g. a lone "-") still occupies a slot, so the result can
    contain leading, trailing or doubled spaces.

    :param sent: The raw sentence.
    :returns: The normalized sentence.
    """
    return " ".join(map(norm_token, sent.split()))

This function normalizes a token by lowercasing and removing non-alphanumeric characters.

In [8]:
def norm_token(token: str) -> str:
    """Lowercase ``token``, then delete every character that is not in ``[a-z0-9]``.

    Lowercasing happens first, so uppercase ASCII letters are kept (as lowercase).
    Punctuation, whitespace and non-ASCII characters are removed.

    :param token: A single word.
    :returns: The normalized word (possibly the empty string).
    """
    return ASCII.sub("", token.lower())

This function builds a vocabulary that maps each distinct token to an index. The indices are contiguous (0, 1, 2, …), and tokens are numbered in the order in which they first appear in the token list.

In [9]:
def build_vocab(tokens: List[str]) -> Dict[str, int]:
    """Assign each distinct token a contiguous index in order of first appearance.

    :param tokens: Tokens, possibly with repeats.
    :returns: A plain dict ``token -> index`` with indices ``0..n_distinct-1``;
        repeated tokens keep the index of their first occurrence.
    """
    vocab: Dict[str, int] = {}
    for tok in tokens:
        # len(vocab) is evaluated before insertion, so it is the next free index.
        vocab.setdefault(tok, len(vocab))
    return vocab

Now to create the graph! First we create a Vertex object:

In [10]:
class Vertex:
    """A graph vertex: a label plus weighted outgoing and incoming edge maps.

    Both edge maps are keyed by the index of the vertex at the other end of the
    edge and map to that edge's weight.

    Note: because ``__eq__`` is defined without ``__hash__``, Python makes
    instances of this class unhashable.
    """

    def __init__(self, value: str):
        # The vertex label.
        self.value = value
        # other-vertex index -> weight, for edges that start here
        self._edges_out: Dict[int, float] = {}
        # other-vertex index -> weight, for edges that end here
        self._edges_in: Dict[int, float] = {}

    @property
    def edges_out(self) -> Dict[int, float]:
        """Edges starting at this vertex, as target index -> weight.

        The underlying dict itself is returned (not a copy), so changes to it are
        changes to this vertex."""
        return self._edges_out

    @property
    def edges_in(self) -> Dict[int, float]:
        """Edges ending at this vertex, as source index -> weight.

        The underlying dict itself is returned (not a copy), so changes to it are
        changes to this vertex."""
        return self._edges_in

    @property
    def degree_in(self) -> int:
        """How many edges end at this vertex."""
        return len(self._edges_in)

    @property
    def degree_out(self) -> int:
        """How many edges start at this vertex."""
        return len(self._edges_out)

    def __str__(self) -> str:
        """Short summary: label plus in/out degree."""
        return f"V(term={self.value}, in={self.degree_in}, out={self.degree_out})"

    def __eq__(self, other) -> bool:
        """Structural equality: same label and identical outgoing/incoming edge maps.

        :raises TypeError: If ``other`` is not a Vertex.
        """
        if not isinstance(other, Vertex):
            raise TypeError(f"Can only compare to other Vertex objects, got {type(other)}")
        return self is other or (
            self.value == other.value
            and self._edges_out == other._edges_out
            and self._edges_in == other._edges_in
        )

And a Graph object:

In [11]:
class Graph:
    def __init__(self, vertices: Union[Dict[str, int], List[str]]):
        """A directed simple graph.
        :param vertices: A mapping of vertex labels to integer indices or a
            list of vertex labels. If the latter then indices are assigned in order.
            A mapping is copied (re-ordered by index), so later ``add_vertex`` calls
            never mutate the caller's dict.
        :raises ValueError: If mapping indices are not exactly ``0..n-1``, or a label
            list contains duplicate labels.
        """
        label2idx: Dict[str, int]
        if isinstance(vertices, dict):
            if set(vertices.values()) != set(range(len(vertices))):
                raise ValueError("Vertex indices must be contiguous")
            # Copy, iterating in index order so that label2idx's order matches indices.
            label2idx = dict(sorted(vertices.items(), key=lambda item: item[1]))
        else:
            labels = list(vertices)
            # Duplicates would silently overwrite earlier indices and leave gaps.
            if len(set(labels)) != len(labels):
                raise ValueError("Vertex labels must be unique")
            label2idx = {n: i for i, n in enumerate(labels)}
        self.label2idx: Dict[str, int] = label2idx
        self.idx2label: Dict[int, str] = {i: k for k, i in self.label2idx.items()}

    def __getitem__(self, key: Union[str, int]) -> Union[int, str]:
        """Get either the index or vertex label based on the other one.
        :param key: The vertex label or index
        :returns: The vertex index if the label is given, or the vertex label
            if the index is given.
        """
        if isinstance(key, int):
            return self.idx2label[key]
        return self.label2idx[key]

    def __contains__(self, key: Union[str, int]) -> bool:
        """Check if the graph has a vertex labeled key.
        :param key: The vertex label or index you are asking about
        :returns: True if the vertex exists, False otherwise
        """
        if isinstance(key, int):
            return key in self.idx2label
        return key in self.label2idx

    def add_vertex(self, label: Optional[str]) -> int:
        """Add a vertex to the graph.
        :param label: The label to give the new vertex.
        :returns: The vertex index
        """
        raise NotImplementedError

    def _add_vertex(self, label: Optional[str]) -> int:
        """Register a new vertex in ``label2idx``/``idx2label``.
        :param label: The label for the new vertex, or ``None`` to auto-generate one:
            the smallest integer ``>= len(label2idx)`` whose string form is unused.
        :returns: The index of the new vertex (``len(label2idx)`` before the call)
        :raises ValueError: If ``label`` is already in use.
        """
        idx = len(self.label2idx)
        if label is None:
            candidate = idx
            # A user label such as "1" may already occupy the obvious choice.
            while str(candidate) in self.label2idx:
                candidate += 1
            label = str(candidate)
        if label in self.label2idx:
            raise ValueError(f"Node labels must be unique, label {label} is already in use.")
        self.label2idx[label] = idx
        self.idx2label[idx] = label
        return idx

    def add_edge(self, source: Union[str, int], target: Union[str, int], weight: float = 1.0) -> None:
        """Add an edge to the graph.
        :param source: The vertex label or index of the edge source
        :param target: The vertex label or index of the edge target
        :param weight: The weight to put on the edge
        :raises ValueError: When the source and target node are the same, when the weight is less than zero
        """
        raise NotImplementedError

    @property
    def density(self) -> float:
        r"""Get the density of the graph.
        The density of a graph is the ratio of edges that the graph has to the number
        it could possibly have, this is bounded by 0 and 1.
        ```math
            D = \frac{|E|}{|V|(|V| - 1)}
        ```
        A graph with fewer than two vertices cannot hold any edge, so its density is
        reported as ``0.0`` instead of dividing by zero.
        """
        possible = self.vertex_count * (self.vertex_count - 1)
        if possible == 0:
            return 0.0
        return self.edge_count / possible

    @property
    def edge_count(self) -> int:
        """The number of edges in the graph."""
        raise NotImplementedError

    @property
    def vertex_count(self) -> int:
        """The number of vertices in the graph."""
        raise NotImplementedError

    def __str__(self) -> str:
        """A summary of the graph.
        Graph summary includes the number of vertices and edges as well as the
        density of the graph.
        """
        return f"G(V={self.vertex_count}, E={self.edge_count}, D={self.density})"

    def print_graph(self, label_length: Optional[int] = None) -> None:
        """Print the graph in a human readable way.
        :param label_length: A cut-off on the length of a single label while printing.
        """
        raise NotImplementedError

    def to_dot(self, directed: bool = False, label_length: Optional[int] = None) -> str:
        """Get a dot representation of the graph.
        The dot graph includes vertex labels and edge weights.
        :param directed: Should the dot representation be directed of not. Most graphs created
            in the package are directed but have the same weight in either direction so we
            can collapse the graph into an undirected weighted graph for cleaner plotting.
            Note: Collapsing this doesn't check that the weights in each direction are the same,
            it just plots a single edge.
        :param label_length: A cut-off on the length allowed for a single label in the printing.
        :returns: The representations of the graph as a dot string.
        """
        raise NotImplementedError

In [15]:
class AdjacencyList(Graph):
    """A graph stored as a list of :class:`Vertex` objects, one per vertex index.

    Each vertex keeps its own weighted ``edges_out`` / ``edges_in`` maps keyed by the
    index of the vertex at the other end, and ``self.vertices[i]`` is vertex ``i``.
    """

    def __init__(self, vertices: Union[Dict[str, int], List[str]]):
        """:param vertices: A label -> index mapping or a list of labels (see ``Graph``)."""
        super().__init__(vertices)
        # Build the list in *index* order rather than dict-iteration order, so
        # ``self.vertices[i]`` is the vertex labelled ``idx2label[i]`` even when the
        # caller's mapping was not inserted in index order.
        self._vertices: List[Vertex] = [Vertex(self.idx2label[i]) for i in range(len(self.idx2label))]

    @property
    def vertices(self) -> List[Vertex]:
        """The vertices in this graph, positioned by vertex index."""
        return self._vertices

    def add_vertex(self, label: Optional[str]) -> int:
        """Add a vertex to the graph.
        :param label: The label to give the new vertex, or ``None`` to let the graph
            generate one.
        :returns: The vertex index
        :raises ValueError: If the label is already in use, or the new index does not
            line up with the end of ``self.vertices``.
        """
        idx = self._add_vertex(label)
        if idx != len(self.vertices):
            raise ValueError(
                f"The added vertex has a label that is out of order, expected: {len(self.vertices)} found: {idx}"
            )
        # ``_add_vertex`` generates a label when ``label`` is None, so read back the
        # label it registered instead of storing ``None`` on the vertex.
        self.vertices.append(Vertex(self[idx]))
        return idx

    def _resolve(self, key: Union[str, int]) -> int:
        """Map a vertex label or index to a validated vertex index.
        :raises IndexError: For an integer index outside ``0..vertex_count-1``
            (negative indices would otherwise silently wrap around the list).
        :raises KeyError: For an unknown label.
        """
        if isinstance(key, int):
            if not 0 <= key < len(self.vertices):
                raise IndexError(f"Vertex index {key} out of range for graph with {len(self.vertices)} vertices")
            return key
        return self[key]

    def add_edge(self, source: Union[str, int], target: Union[str, int], weight: float = 1.0) -> None:
        """Add an edge to the graph, overwriting any existing edge source -> target.
        :param source: The vertex label or index of the edge source
        :param target: The vertex label or index of the edge target
        :param weight: The weight to put on the edge
        :raises ValueError: When the source and target node are the same, when the weight is
            negative or NaN
        :raises IndexError: When an integer index is out of range
        :raises KeyError: When a label is unknown
        """
        # ``not weight >= 0`` also rejects NaN, which ``weight < 0`` would let through.
        if not weight >= 0.0:
            raise ValueError(f"Edge weight must be non-negative, got {weight}")
        source_idx = self._resolve(source)
        target_idx = self._resolve(target)
        if source_idx == target_idx:
            raise ValueError(f"Self loops are not allowed, found edge with source and target of {source_idx}")
        # The edge is recorded on both endpoints so in- and out-traversal stay in sync.
        self.vertices[source_idx].edges_out[target_idx] = weight
        self.vertices[target_idx].edges_in[source_idx] = weight

    @property
    def vertex_count(self) -> int:
        """The number of vertices in the graph."""
        return len(self.vertices)

    @property
    def edge_count(self) -> int:
        """The number of (directed) edges in the graph."""
        return sum(v.degree_out for v in self.vertices)

    def print_graph(self, label_length: Optional[int] = None) -> None:
        """Print the graph in a human readable way.
        :param label_length: A cut-off on the length of a single label while printing.
        """
        print(str(self))
        for v in self.vertices:
            src = self[v.value]
            print(f"\tVertex {src}: {v.value[:label_length]}")
            print("\t\tOutbound:")
            for idx, weight in v.edges_out.items():
                print(f"\t\t\t{src} -> {idx}: {weight}")
            print("\t\tInbound:")
            for idx, weight in v.edges_in.items():
                print(f"\t\t\t{src} <- {idx}: {weight}")

    def to_dot(self, directed: bool = False, label_length: Optional[int] = None) -> str:
        """Get a dot representation of the graph.
        The dot graph includes vertex labels and edge weights.
        :param directed: Should the dot representation be directed of not. Most graphs created
            in the package are directed but have the same weight in either direction so we
            can collapse the graph into an undirected weighted graph for cleaner plotting.
            Note: Collapsing this doesn't check that the weights in each direction are the same,
            it just plots a single edge.
        :param label_length: A cut-off on the length allowed for a single label in the printing.
        :returns: The representations of the graph as a dot string.
        """
        if directed:
            return self._to_directed_dot(label_length)
        return self._to_undirected_dot(label_length)

    @staticmethod
    def _dot_label(label: str, label_length: Optional[int]) -> str:
        """Truncate ``label`` to ``label_length`` characters, then escape backslashes and
        double quotes so it is safe inside a double-quoted DOT string."""
        return label[:label_length].replace("\\", "\\\\").replace('"', '\\"')

    def _to_directed_dot(self, label_length: Optional[int] = None) -> str:
        """Get a dot representation of the graph as a directed graph.
        :param label_length: A cut-off on the length allowed for a single label in the printing.
        :returns: The representations of the graph as a dot string.
        """
        dot = ["digraph G {"]
        for v in self.vertices:
            src = self[v.value]
            dot.append(f'\t{src} [label="{self._dot_label(v.value, label_length)}"];')
            for idx, weight in v.edges_out.items():
                dot.append(f'\t{src} -> {idx} [label="{weight}"];')
        dot.append("}")
        return "\n".join(dot)

    def _to_undirected_dot(self, label_length: Optional[int] = None) -> str:
        """Get a dot representation of the graph as a undirected graph.
        Note:
            This doesn't check that graph edges can actually be collapsed into a single edge;
            the first direction encountered supplies the weight.
        :param label_length: A cut-off on the length allowed for a single label in the printing.
        :returns: The representations of the graph as a dot string.
        """
        dot = ["graph G {"]
        seen = set()
        for v in self.vertices:
            src = self[v.value]
            dot.append(f'\t{src} [label="{self._dot_label(v.value, label_length)}"];')
            for idx, weight in v.edges_out.items():
                # Skip the reverse direction of an edge already emitted.
                if (src, idx) in seen or (idx, src) in seen:
                    continue
                dot.append(f'\t{src} -- {idx} [label="{weight}"];')
                seen.add((src, idx))
        dot.append("}")
        return "\n".join(dot)

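And now we get into TextRank! This first cell defines the two early-stopping criteria: stop when all node scores have converged, or when any one of them has. It also defines a helper that calculates the total weight of a collection of edges.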

In [12]:
# Early-stopping criterion for ``text_rank``: stop once ALL vertex scores moved less
# than the tolerance in an iteration, or as soon as ANY single one did.
ConvergenceType = Enum("ConvergenceType", "ALL ANY")


def sum_edges(edges: Dict[int, float]) -> float:
    """Total weight of an edge map.

    :param edges: A mapping of neighbour vertex index -> edge weight, such as
        ``Vertex.edges_out``; only the weights are used.
    :returns: The sum of the weights (``0`` for an empty mapping).
    """
    return sum(edges.values())

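This function accumulates the score flowing into a given node from every node that has an edge pointing to it. Each neighbour $V_j$ contributes its current score, weighted by the edge weight $w_{ji}$ and divided by the total weight of $V_j$'s outgoing edges. This is the summation in the TextRank update $WS(V_i) = (1 - d) + d \sum_{V_j \in In(V_i)} \frac{w_{ji}}{\sum_{V_k \in Out(V_j)} w_{jk}} WS(V_j)$.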

In [13]:
def accumulate_score(vertex: Vertex, ws: List[float], denom: List[float]):
    """Sum the score contributions arriving at ``vertex`` along its incoming edges.

    Each source ``j`` contributes ``weight / denom[j] * ws[j]``: its current score,
    scaled by the fraction of ``j``'s total outgoing weight this edge carries.
    ``math.fsum`` tracks partial sums exactly, avoiding accumulated rounding error.

    :param vertex: The vertex being scored; ``vertex.edges_in`` maps source index -> weight.
    :param ws: Current score of every vertex, by vertex index.
    :param denom: Total outgoing edge weight of every vertex, by vertex index.
    :returns: The accumulated (not yet damped) score.
    """
    contributions = (
        in_weight / denom[src] * ws[src] for src, in_weight in vertex.edges_in.items()
    )
    return math.fsum(contributions)

This function generates the initial score for each node and pre-computes each node's outgoing strength (the total weight of its outbound edges) for the adjacency-list graph. That total doesn't change while TextRank runs, because it depends only on the edge weights in the graph and not on the node scores (`ws`). We can therefore compute it once and reuse it, rather than recalculating it on every iteration.

In [16]:
def text_rank_init(
    graph: AdjacencyList, uniform: bool = True, seed: Optional[int] = None
) -> Tuple[List[float], List[float]]:
    """Create the starting score vector and the per-vertex outgoing-weight totals.

    :param graph: The graph text rank will run on.
    :param uniform: If True every vertex starts at ``1 / |V|``; otherwise scores are
        drawn with ``random()`` and normalized to sum to 1.
    :param seed: Seed for the random initialization. A private ``random.Random``
        instance is used, so the global ``random`` state is left untouched; a given
        seed still yields the same numbers ``random.seed(seed)`` would have.
    :returns: ``(ws, denom)``. ``ws[i]`` is the initial score of vertex ``i``.
        ``denom[i]`` is the total weight of vertex ``i``'s outgoing edges, or ``1.0``
        when that total is zero. Both lists are empty for an empty graph.
    """
    # Outgoing strength depends only on the (fixed) edge weights, so compute it once.
    denom = [sum_edges(v.edges_out) for v in graph.vertices]
    # If the sum off all outgoing edges of V_j is 0.0 then the incoming edge from V_j to V_i will be 0.0
    # We can use anything as the denominator and the value will still be zero
    denom = [d if d != 0.0 else 1.0 for d in denom]
    n = len(graph.vertices)
    if n == 0:
        # Nothing to score; avoids dividing by zero below.
        return [], denom
    if uniform:
        ws = [1 / n] * n
    else:
        rng = random.Random(seed)
        ws = [rng.random() for _ in range(n)]
        norm = sum(ws)
        ws = [w / norm for w in ws]
    return ws, denom

This function calculates the new score for each node.

In [22]:
def text_rank_update(
    graph: AdjacencyList, ws: List[float], denom: List[float], dampening: float = 0.85
) -> List[float]:
    """Run one TextRank iteration and return the new score of every vertex.

    Vertex ``i`` gets ``(1 - dampening) + dampening * accumulate_score(...)``, computed
    from the previous scores ``ws``. The input list is not modified.

    :param graph: The graph being ranked.
    :param ws: Scores from the previous iteration, by vertex index.
    :param denom: Total outgoing edge weight of every vertex, by vertex index.
    :param dampening: Damping factor ``d``.
    :returns: A new list of scores, by vertex index.
    """
    base = 1 - dampening
    return [
        base + dampening * accumulate_score(vertex, ws, denom)
        for vertex in graph.vertices
    ]

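This function normalizes the final TextRank scores so they sum to 1 and returns (vertex label, score) pairs sorted from highest to lowest score.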

In [17]:
def text_rank_output(graph: AdjacencyList, ws: List[float]) -> List[Tuple[str, float]]:
    """Pair each vertex label with its normalized score, best first.

    :param graph: The ranked graph; ``graph.vertices[i]`` owns score ``ws[i]``.
    :param ws: Final scores, by vertex index.
    :returns: ``(label, score)`` pairs sorted by descending score (ties keep vertex
        order). Scores are divided by their total so they sum to 1. If the total is 0
        (e.g. ``dampening=1`` on a graph without edges), the scores are returned
        unnormalized instead of raising ``ZeroDivisionError``.
    """
    total = sum(ws)
    if total:
        ws = [w / total for w in ws]
    labels = [v.value for v in graph.vertices]
    return sorted(zip(labels, ws), key=itemgetter(1), reverse=True)

This function runs TextRank.

In [18]:
def text_rank(
    graph: Graph,
    dampening: float = 0.85,
    convergence: float = 0.0001,
    convergence_type: ConvergenceType = ConvergenceType.ALL,
    niter: int = 200,
    uniform: bool = False,
    seed: Optional[int] = None,
) -> List[Tuple[str, float]]:
    """Implementation of text rank from here https://www.aclweb.org/anthology/W04-3252.pdf
    :param graph: The graph we are running text rank on
    :param dampening: A scalar between 0 and 1. Used to simulate randomly jumping from one vertex to another.
    :param convergence: An early stopping criteria, when any or all of the node scores change by less than `convergence`
        we stop updating the graph. Set to `0` to turn off early stopping.
    :param convergence_type: Should we stop when all nodes move less than `convergence` or when a single node does
    :param niter: An upper bound on the number of iterations to run. With ``niter <= 0`` no update is
        performed and the (normalized) initial scores are returned.
    :param uniform: Should we initialize state vector to have equal prob for each node?
    :param seed: A reproducibility seed for the initialization of the node scores.
    :returns: Pairs of (node label, scores) sorted by score
    :raises ValueError: If ``dampening`` is outside ``[0, 1]``.
    """
    if not 0 <= dampening <= 1:
        raise ValueError(f"dampening must be between `0` and `1`, got {dampening}")
    converge = all if convergence_type is ConvergenceType.ALL else any

    ws_prev, denom = text_rank_init(graph, uniform=uniform, seed=seed)
    # Pre-bind ``ws`` so that ``niter <= 0`` returns the initial scores instead of
    # raising UnboundLocalError at the return below.
    ws = ws_prev

    for _ in range(niter):
        ws = text_rank_update(graph, ws_prev, denom, dampening)
        if converge(abs(p - c) < convergence for p, c in zip(ws_prev, ws)):
            break
        ws_prev = ws

    return text_rank_output(graph, ws)

In [19]:
def sentence_graph(
    sentences: List[str],
    sim: Callable[..., float] = overlap,
    norm: Callable[[str], str] = norm_sentence,
    GraphType: Type[Graph] = AdjacencyList,
) -> Tuple[Graph, Dict[str, List[int]]]:
    """Generate a fully connected graph with edges between all sentences.
    Note:
        This also generates a dict mapping normalized vertex labels to their offsets in the original
        data. This can be used to run text rank on normalized data but return the original strings.
        You can also sort the output by offsets to make it maybe more readable?
        Sentences that normalize to the same string share a single vertex, and no edge is added
        between such duplicates (it would be a self loop).
    :param sentences: The sentences to summarize.
    :param sim: A callable that returns the similarity between two vertices, used to set the weight of the edge.
        The callable should have a signature like:
            sim(
                normed_s1,
                normed_s2,
                raw_s1=raw_s1,
                raw_s2=raw_s2,
                s1_idx=s1_idx,
                s2_idx=s2_idx,
            ) -> float:
        Where normed_s1/2 is the normalized strings of the two sentences, raw_s1/2 is the version of the sentence
        before getting normalized and s1/2_idx is the index of the sentences in the sentence list. This should
        facilitate both simple and complex similarity functions and also experiments that use the actual flow of
        text to determine connections.
    :param norm: A function that returns a normalized version of the input sentence. Default implementation lowercases
        string and removes non alpha-numeric characters.
        This is used so simple similarity functions like the set overlap in the paper work well.
    :param GraphType: The Graph class to use.
    :returns: The constructed graph and offsets mapping normalized vertex labels to their place in the original text.
    """
    offsets: Dict[str, List[int]] = defaultdict(list)
    normed = [norm(sentence) for sentence in sentences]
    # Use a distinct loop name so the ``norm`` parameter is not shadowed.
    for i, normed_sentence in enumerate(normed):
        offsets[normed_sentence].append(i)

    vocab = build_vocab(normed)
    graph = GraphType(vocab)

    for (i, src), (j, tgt) in combinations(enumerate(normed), 2):
        # Identical normalized sentences map to the same vertex; connecting them
        # would be a self loop, which the graph rejects with ValueError.
        if src == tgt:
            continue
        graph.add_edge(src, tgt, sim(src, tgt, raw_s1=sentences[i], raw_s2=sentences[j], s1_idx=i, s2_idx=j))
        graph.add_edge(tgt, src, sim(tgt, src, raw_s1=sentences[j], raw_s2=sentences[i], s1_idx=j, s2_idx=i))

    return graph, offsets

And finally, our summarize function:

In [20]:
def summarize(
    sentences: List[str],
    nsents: Optional[int] = None,
    keep_order: bool = True,
    dampening: float = 0.85,
    convergence: float = 0.0001,
    convergence_type: ConvergenceType = ConvergenceType.ALL,
    niter: int = 200,
    seed: Optional[int] = None,
    sim: Callable[..., float] = overlap,
    norm: Callable[[str], str] = norm_sentence,
    GraphType: Type[Graph] = AdjacencyList,
    uniform: bool = False,
) -> List[str]:
    """Extractive summary: return the ``nsents`` sentences ranked highest by TextRank.

    :param sentences: The sentences of the document to summarize.
    :param nsents: How many sentences to return. Defaults to ``len(sentences) // 3``,
        so documents with fewer than three sentences yield an empty summary.
    :param keep_order: If True, return the chosen sentences in document order;
        otherwise return them from highest to lowest score.
    :param dampening: Passed to ``text_rank``.
    :param convergence: Passed to ``text_rank``.
    :param convergence_type: Passed to ``text_rank``.
    :param niter: Passed to ``text_rank``.
    :param seed: Passed to ``text_rank`` (seeds the random score initialization).
    :param sim: Passed to ``sentence_graph`` as the edge-weight function.
    :param norm: Passed to ``sentence_graph`` as the sentence normalizer.
    :param GraphType: Passed to ``sentence_graph`` as the graph class.
    :param uniform: Passed to ``text_rank``; start from uniform scores instead of random ones.
    :returns: The selected sentences as the original (un-normalized) strings. Sentences
        that normalize to the same text share one vertex, and its first occurrence is
        the one returned.
    :raises ValueError: If ``nsents`` is negative.
    """
    if nsents is not None and nsents < 0:
        raise ValueError(f"nsents must be non-negative, got {nsents}")
    graph, offsets = sentence_graph(sentences, sim, norm, GraphType)
    if nsents is None:
        nsents = len(sentences) // 3
    ranked = text_rank(
        graph,
        dampening=dampening,
        convergence=convergence,
        convergence_type=convergence_type,
        niter=niter,
        uniform=uniform,
        seed=seed,
    )
    # Map each winning normalized label back to the index of its first raw sentence.
    indices = [offsets[label][0] for label, _score in ranked[:nsents]]
    if keep_order:
        indices.sort()
    return [sentences[i] for i in indices]

In [23]:
summs = []
sents = []
# One entry per JSON file; each entry is that story's list of sentences.
sentences = []
# ``Path.glob`` yields files in arbitrary filesystem order; sort so the stories (and
# therefore the printed summaries) come out in a stable, reproducible order.
for file_name in sorted(Path('./data').glob('*.json')):
    # JSON text is UTF-8 by specification, so don't rely on the locale default.
    with open(file_name, encoding='utf-8') as f:
        sentences.append(json.load(f))
for sents in sentences:
    summs.append(summarize(sents, nsents=2))

In [25]:
# Print each summary (a list of sentences) followed by a blank separator line.
for summ in summs:
    print(summ, end="\n\n")

['4 Min Read MANILA (Reuters) - Schools and businesses shut across the Philippine capital on Monday as a volcano belched clouds of ash across the city and seismologists warned an eruption could happen at any time, potentially triggering a tsunami.', 'Thousands of people were forced to evacuate their homes around Taal, one of the world’s smallest active volcanoes, which spewed ash for a second day from its crater in the middle of a lake about 70 km (45 miles) south of central Manila.']

['Pfizer, in a statement to Reuters, said: “The FDA and its 2016 advisory panel had access to all of the data and science on Chantix, and all of the adverse events reports.” The plaintiff experts’ reports, the company noted, were not original science and instead reflected views of the underlying science that differed from the FDA and the advisory panel’s conclusions.', 'CPSC spent the next few months trying to get Yamaha to comply, at one point complaining to the company that it was sending “duplicative”