# Setup

## Import Libraries

In [None]:
import pandas as pd
import numpy as np
import re
import json
from os import listdir
from tqdm import tqdm
from itertools import chain
from math import factorial
import random

# Seed both random sources used below (`random` for shuffling/sampling,
# NumPy's global generator for permutations) so reruns are repeatable.
random.seed(5678)
np.random.seed(5678)

## Parameters

In [None]:
# Machine-specific absolute locations. Paths further down are built by plain
# string concatenation, so the trailing "/" on each value is required.
path_to_data = "/home/luke/Archive/200—Work/HuntLab/2021 ~ Argument Processor/Data/"
path_to_models = "/home/luke/Archive/200—Work/HuntLab/2021 ~ Argument Processor/Models/"

## Initialisation

In [None]:
# Registry of suggestion operations. Every entry starts empty; the cells below
# add "train" / "validate" / "test" example lists and a "make_prompts" function
# to each one.
ops = {
    "suggest-intermediary-claims": {},
    "suggest-reasons": {},
    "suggest-objections": {},
    "suggest-conclusion": {},
    "suggest-copremise": {},
    "suggest-abstraction": {}
}

## Import Data

In [None]:
# Create function for stripping markdown markup from a string.

from markdown import Markdown
from io import StringIO

def unmark_element(element, stream=None):
    """Serialise an element tree to plain text, dropping all tags.

    Walks `element` depth-first and writes, in document order, the element's
    own text, the text of every descendant, and finally the element's tail.

    Args:
        element: An ElementTree-style node (iterable over its children, with
            `text` and `tail` attributes).
        stream: Buffer to append to. A fresh `StringIO` is created when
            omitted; recursive calls pass the shared buffer along.

    Returns:
        The full contents of the buffer after this element was written.
    """
    buffer = StringIO() if stream is None else stream
    buffer.write(element.text or "")
    for child in element:
        unmark_element(child, buffer)
    buffer.write(element.tail or "")
    return buffer.getvalue()


# patching Markdown
# Register `unmark_element` as an additional serializer named "plain", then
# build one shared converter that uses it.
Markdown.output_formats["plain"] = unmark_element
__md = Markdown(output_format="plain")
# NOTE(review): presumably this stops the converter from trimming the wrapper
# tags that the built-in serializers emit and the plain output lacks — confirm
# against the installed Python-Markdown version.
__md.stripTopLevelTags = False


def unmark(text):
    """Return `text` run through the shared "plain" Markdown converter."""
    return __md.convert(text)

In [None]:
# Import complete Kialo maps.

path = path_to_data + "Kialo—Lenz-2020/Raw-Export/"

# Select the export files to parse:
#   - text files only (this skips system files);
#   - no duplicate downloads, which carry "(1)" in their filename.
fs = [
    name
    for name in listdir(path)
    if name.endswith(".txt") and "(1)" not in name
]

# A claim starts with a dotted id such as "1.2.3.", a space, an optional
# "Pro: " / "Con: " stance marker, and then the claim text. Written as a raw
# string so that `\d` and `\.` reach the regex engine unchanged.
_CLAIM_PATTERN = re.compile(r'((\d+\.)+) (Pro: |Con: )?(.*)')


def make_claim(claim):
    """Parse one raw claim string from a Kialo export.

    Args:
        claim: Raw text of a single claim.

    Returns:
        ``{"id": None}`` when no claim id is found in `claim`. Otherwise a
        dict with:
          - "id": the dotted claim id, including its trailing dot;
          - "type": "Pro" or "Con", or "Seed" when no stance marker is present;
          - "txt": the claim text with markdown stripped by `unmark`. Only the
            remainder of the matched line is captured, because `.` does not
            match newlines.
    """
    match = _CLAIM_PATTERN.search(claim)
    if match is None:
        return {
            "id": None
        }

    # Group 2 is the last repeated id segment and is not needed.
    claim_id, _, stance, text = match.groups()
    return {
        "id": claim_id,
        "type": stance[:3] if stance is not None else "Seed",
        "txt": unmark(text)
    }

def is_child(d, id):
    """Tell whether claim id `d` is a direct child of claim id `id`.

    Ids are dotted strings such as "1.2.". `d` is a direct child when it has
    exactly one more dot-separated level than `id` and begins with `id`.

    Args:
        d: Candidate child id.
        id: Candidate parent id.

    Returns:
        True if `d` sits exactly one level below `id`, otherwise False.
    """
    one_level_deeper = d.count('.') == id.count('.') + 1
    return one_level_deeper and d.startswith(id)

def make_map(filename):
    """Parse one exported Kialo discussion file into a claim map.

    Reads ``path + filename`` (``path`` is the notebook-level export
    directory), converts each claim with ``make_claim`` and links every claim
    to its direct children.

    Args:
        filename: Name of a file inside ``path``.

    Returns:
        A dict with:
          - "title": the first line of the file without its first 18
            characters (presumably a fixed label — confirm against an export);
          - "filename": the `filename` argument;
          - "ids": claim ids in file order;
          - "map": claim id -> claim dict; each claim dict gains a "claims"
            list holding the ids of its direct children, in file order;
          - "n_references": how many claims had text starting with "-> See ".
    """
    with open(path + filename, 'r', encoding='utf-8') as file:
        txt = file.read()

    # Extract title.
    title = txt.split('\n', 1)[0][18:]

    # Get array of claims: skip the first two lines, then cut the remainder
    # wherever a line starts with "1." and restore that prefix on each piece.
    claims = '\n'.join(txt.splitlines()[2:]).split('\n1.')
    seed = claims[0]
    claims = [seed] + ['1.' + s for s in claims[1:]]

    # Reformat each claim and drop the pieces in which no id was found.
    claims = [make_claim(claim) for claim in claims]
    claims = [claim for claim in claims if claim["id"] is not None]

    # Convert to dictionary.
    m = {claim["id"]: claim for claim in claims}

    # Store ids of children. Each id is grouped under its parent id (the id
    # minus its last "<digits>." segment) in a single pass, rather than testing
    # every pair of ids. File order is preserved within each group.
    ids = [claim["id"] for claim in claims]
    children = {}
    for claim_id in ids:
        parent_id = claim_id[:claim_id.rstrip('.').rfind('.') + 1]
        children.setdefault(parent_id, []).append(claim_id)
    for claim_id in ids:
        m[claim_id]["claims"] = list(children.get(claim_id, ()))

    # Resolve references to other claims. This runs in file order, so a
    # reference to a claim that is itself a not-yet-resolved reference copies
    # that claim's raw "-> See ..." text.
    n_references = 0
    for claim_id in ids:
        if m[claim_id]["txt"].startswith("-> See "):
            n_references = n_references + 1
            ref_id = m[claim_id]["txt"][7:]
            if ref_id in m:
                m[claim_id]["txt"] = m[ref_id]["txt"]
            else:
                # Not an id of this map: keep everything after the first two
                # space-separated tokens.
                # NOTE(review): raises IndexError when fewer than three
                # tokens are present — confirm that cannot occur in the data.
                m[claim_id]["txt"] = ref_id.split(' ', 2)[2]

    return {
        "title": title,
        "filename": filename,
        "ids": ids,
        "map": m,
        "n_references": n_references
    }

# Parse every selected export file; tqdm only adds a progress bar.
maps = [ make_map(f) for f in tqdm(fs) ]

# Save maps to file.
# The parsed maps are cached as JSON under path_to_models.

with open(path_to_models + 'maps.json', 'w') as fp:
    json.dump(maps, fp)

In [None]:
# Read in maps.
# Reloads the JSON cache of parsed maps; from here on `maps` holds plain
# JSON-decoded lists and dicts.

with open(path_to_models + 'maps.json') as json_file:
    maps = json.load(json_file)

In [None]:
# Depth limit used later when forks are extracted: a claim may seed a fork only
# if len(id.split('.')) <= max_depth (the root id "1." splits into 2 parts).
# Nothing is truncated in this cell itself.

max_depth = 3

In [None]:
# Replace every newline character in the claim texts with a space.

# The loop variables deliberately avoid the names `map` and `id`: at notebook
# scope they would rebind those builtins for every later cell.
for kialo_map in tqdm(maps):
    claim_index = kialo_map['map']
    for claim_id in kialo_map['ids']:
        claim_index[claim_id]['txt'] = claim_index[claim_id]['txt'].replace('\n', ' ')

In [None]:
# Randomly assign maps to train/val/test splits.

# Randomly re-order examples (in place, using the seeded `random` module).
random.shuffle(maps)

# NOTE(review): this cell previously also shuffled and split a `revisions`
# object. Nothing in this notebook defines or later reads `revisions`, so those
# statements raised NameError before `maps` could be split; they were removed.
# Restore them only together with the cell that loads `revisions`.

# Assign to splits: the first 60% of the shuffled maps go to train, the next
# 20% to validate and the remainder to test.

maps_train, maps_validate, maps_test = np.split(maps, [int(.6 * len(maps)), int(.8 * len(maps))])
maps = {
    "train": maps_train,
    "validate": maps_validate,
    "test": maps_test
}

## Extract forks and branches

In [None]:
# Extract forks.

forks = {"train": [], "validate": [], "test": []}


def get_forks(map):
    """Collect the forks of one map.

    A fork is a claim together with its direct children. A claim yields a fork
    when it has at least one child and len(id.split('.')) does not exceed the
    notebook-level `max_depth`.

    Args:
        map: A map dict as stored in `maps` (keys "ids" and "map").

    Returns:
        A list of dicts with keys "seed" (the parent claim), "claims" (its
        child claims) and "map_title" (the text of claim "1.").
    """
    found = []

    for claim_id in map["ids"]:
        claim = map["map"][claim_id]
        child_ids = claim["claims"]

        # Skip childless claims and claims below the depth limit.
        if not child_ids or len(claim_id.split('.')) > max_depth:
            continue

        found.append({
            "seed": claim,
            "claims": [map["map"][child_id] for child_id in child_ids],
            "map_title": map['map']['1.']['txt']
        })

    return found


for split in maps:
    for m in tqdm(maps[split]):
        forks[split].extend(get_forks(m))

In [None]:
# Extract branches.

branches = {
    "train": [],
    "validate": [],
    "test": []
}

def get_branches(branch, map, split):
    """Recursively record every chain of supporting ("Pro") claims.

    Extends `branch` through each "Pro" child of its last claim. When the last
    claim has no "Pro" child (including when it has no children at all), the
    chain is complete and is appended to ``branches[split]``.

    Args:
        branch: List of claim dicts from the starting claim down to the
            current one.
        map: The map dict the claims belong to.
        split: Key of the `branches` entry to append to.

    Returns:
        None; results are accumulated in the notebook-level `branches`.
    """
    pro_ids = [
        child_id
        for child_id in branch[-1]["claims"]
        if map["map"][child_id]["type"] == "Pro"
    ]

    if not pro_ids:
        branches[split].append({ "branch": branch, "map_title": map['map']['1.']['txt'] })
        return

    for pro_id in pro_ids:
        get_branches(branch + [map["map"][pro_id]], map, split)

for split in maps.keys():
    for m in maps[split]:
        # Every chain starts at the root claim "1.".
        get_branches([m["map"]["1."]], m, split)

    # Replace the complete chains by all of their contiguous sub-chains of
    # length 3 or more (for a chain of 4+ claims this includes the chain
    # itself).
    # NOTE(review): chains shorter than 4 claims are skipped entirely, so a
    # complete chain of exactly 3 claims contributes nothing even though
    # 3-claim sub-chains of longer chains are kept — confirm this is intended.

    sub_branches = []
    for branch in branches[split]:
        claims_in_branch = branch['branch']
        if len(claims_in_branch) < 4:
            # The original used a bare `next` here, which does nothing; the
            # skip only worked because the rest was in an `else`.
            continue
        for start in range(0, len(claims_in_branch) - 2):
            for end in range(start + 3, len(claims_in_branch) + 1):
                sub_branches.append({ "branch": claims_in_branch[start:end], "map_title": branch['map_title']})

    branches[split] = sub_branches

## Generate task-specific datasets

### `suggest-intermediary-claims`

In [None]:
op = "suggest-intermediary-claims"

# Create dataset.

def make_datum(b):
    """Turn one branch into an example for `suggest-intermediary-claims`.

    The branch is stored root-first; the example runs the other way, from the
    last claim of the branch to its first.

    Args:
        b: Dict with "branch" (list of claim dicts) and "map_title".

    Returns:
        Dict with "in" (the "start" and "end" claims), "out" (the whole
        reversed branch) and "map_title".
    """
    reversed_branch = b['branch'][::-1]
    return {
        "in": {
            "start": reversed_branch[0],
            "end": reversed_branch[-1]
        },
        "out": reversed_branch,
        "map_title": b['map_title']
    }

for split in branches:
    ops[op][split] = [make_datum(b) for b in branches[split]]

# Define prompt-generation function(s).

def make_prompts(d):
    """Build the prompt variants for one `suggest-intermediary-claims` example.

    Args:
        d: Example dict with "in" ("start" and "end" claims), "out" (list of
            claims forming the full chain) and "map_title".

    Returns:
        Dict with "map_title" plus the prompt styles "manual", "few-shot" and
        "soft", each a dict with "in" (prompt text) and "out" (expected
        completion).
    """
    start_txt = d["in"]["start"]["txt"]
    end_txt = d["in"]["end"]["txt"]
    chain_txts = [claim["txt"] for claim in d["out"]]

    prompts = {}

    prompts['map_title'] = d["map_title"]

    prompts["manual"] = {
        "in": f'Input: {start_txt} -> {end_txt}\n\nOutput: ',
        "out": " -> ".join(chain_txts)
    }

    # Instruction plus three worked examples.
    shots = "Reason from the start claim to the end claim.\n\nStart and end claim: A cyclone hit Queensland, Australia. ~ The price of bananas increased.\nCompleted chain of reasoning:\n* A cyclone hit Queensland, Australia. ~ The cyclone destroyed banana crops. ~ Supply of bananas went down, whilst demand stayed constant. ~ The price of bananas increased.\n\nStart and end claim: Education levels improve. ~ Society becomes more politically polarised.\nCompleted chain of reasoning:\n* Education levels improve. ~ People become more skilled at finding high-quality justifications for their existing beliefs (confirmation bias). ~ Society becomes more politically polarised.\n\nStart and end claim: People move out of cities and into the countryside. ~ Greenhouse gas emissions increase.\nCompleted chain of reasoning:\n* People move out of cities and into the countryside. ~ Population density decreases. ~ Both people and products need to be transported further. ~ They are transported using vehicles that burn fossil fuels. ~ Greenhouse gas emissions increase.\n\n"

    prompts["few-shot"] = {
        # The closing cue now ends in ":" like the three examples above; the
        # colon was missing, so the query did not follow the shots' pattern.
        "in": shots + f"Start and end claim: {start_txt} ~ {end_txt}\nCompleted chain of reasoning:\n*",
        "out": " ~ ".join(chain_txts)
    }

    prompts["soft"] = {
        "in": f'{start_txt} -> {end_txt}\n\nAnswer: ',
        "out": " -> ".join(chain_txts)
    }

    return(prompts)

# Store the prompt builder next to this operation's data splits.
ops[op]["make_prompts"] = make_prompts

### `suggest-reasons`

In [None]:
op = "suggest-reasons"

# Create dataset.

def sub_lists(l):
    """Return every non-empty contiguous slice of `l`.

    Slices are ordered by their end position first and by their start position
    second, e.g. [1, 2, 3] -> [[1], [1, 2], [2], [1, 2, 3], [2, 3], [3]].
    Non-contiguous subsets are not produced.
    """
    return [
        l[start:stop]
        for stop in range(1, len(l) + 1)
        for start in range(stop)
    ]

def make_data(v):
    """Generate `suggest-reasons` examples from one fork.

    For every contiguous run of the fork's "Pro" claims, draws random orderings
    of that run. In each ordering the last claim is the target and the claims
    before it are the reasons already given.

    Args:
        v: Fork dict with "seed", "claims" and "map_title".

    Returns:
        List of example dicts with "in" ("seed" and "pros"), "out" and
        "map_title".
    """
    pros = [claim for claim in v["claims"] if claim["type"] == "Pro"]

    ds = []
    for sublist in sub_lists(pros):
        # Draw as many orderings as the run has permutations, but never more
        # than 10, so the runtime cannot explode factorially. The draws are
        # independent, so the same ordering may come up more than once.
        n_draws = min(10, factorial(len(sublist)))
        for _ in range(n_draws):
            *given, target = np.random.permutation(sublist)
            ds.append({
                "in": {
                    "seed": v["seed"],
                    "pros": given
                },
                "out": target,
                "map_title": v["map_title"]
            })

    return ds

for split in forks:
    # Only forks with at least one supporting claim can produce examples.
    forks_with_pros = [
        v for v in forks[split]
        if any(claim["type"] == "Pro" for claim in v["claims"])
    ]
    data = [make_data(v) for v in tqdm(forks_with_pros)]
    data = list(chain.from_iterable(data))
    ops[op][split] = data

# Define prompt-generation function(s).

def make_prompts(d):
    """Build the prompt variants for one `suggest-reasons` example.

    Args:
        d: Example dict with "in" ("seed" claim and list of already-given
            "pros"), "out" (the claim to be suggested) and "map_title".

    Returns:
        Dict with "map_title" plus the prompt styles "manual", "zero-shot",
        "few-shot" and "soft", each a dict with "in" (prompt text ending in an
        open list item) and "out" (expected completion).
    """
    seed_txt = d["in"]["seed"]["txt"]
    given_txts = [claim['txt'] for claim in d["in"]["pros"]]
    target_txt = d["out"]["txt"]

    prompts = {}

    prompts["map_title"] = d["map_title"]

    # One "- " bullet per given reason, then an open bullet for the answer.
    # Each bullet is built separately so that an empty `pros` list gives a
    # single open bullet; joining on "\n- " left an empty bullet before it.
    dash_prompt = seed_txt + "\n\nPros:" + ''.join(["\n- " + txt for txt in given_txts]) + "\n- "

    prompts["manual"] = {
        "in": dash_prompt,
        "out": target_txt
    }

    prompts["zero-shot"] = {
        # The quote opened before the claim is now closed after it.
        "in": "List reasons why: \"" + seed_txt + "\"\n\nReasons:" + ''.join(['\n* ' + txt for txt in given_txts]) + "\n* ",
        "out": target_txt
    }

    # NOTE(review): the first example reads "was conducted research"; the
    # wording is left as is because it is part of the prompt text.
    prompts["few-shot"] = {
        "in": "Suggest reasons why each claim is true.\n\nClaim: The COVID-19 pandemic was caused by a lab leak in Wuhan.\nReason 1: A virology lab in Wuhan was conducted research on coronaviruses.\nReason 2: Biosecurity leaks are relatively common.\n\nClaim: The world should transition away from fossil fuels.\nReason 1: The burning of fossil fuels releases greenhouse gases.\nReason 2: Greenhouse gases warm the planet and lead to sea level rise.\nReason 3: Burning fossil fuels leads to air pollution and preventable deaths.\n\nClaim: Socrates is mortal.\nReason 1: Socrates is human.\nReason 2: All humans are mortal.\n\nClaim: " + seed_txt + ''.join(["\nReason " + str(index + 1) + ": " + txt for index, txt in enumerate(given_txts)]) + "\nReason " + str(len(given_txts) + 1) + ": ",
        "out": target_txt
    }

    prompts["soft"] = {
        "in": dash_prompt,
        "out": target_txt
    }

    return(prompts)

# Store the prompt builder next to this operation's data splits.
ops[op]["make_prompts"] = make_prompts

### `suggest-objections`

In [None]:
op = "suggest-objections"

# Create dataset.

def make_data(v):
    """Generate `suggest-objections` examples from one fork.

    For every contiguous run of the fork's "Con" claims, draws random orderings
    of that run. In each ordering the last claim is the target and the claims
    before it are the objections already given.

    Args:
        v: Fork dict with "seed", "claims" and "map_title".

    Returns:
        List of example dicts with "in" ("seed" and "cons"), "out" and
        "map_title".
    """
    cons = [claim for claim in v["claims"] if claim["type"] == "Con"]

    ds = []
    for sublist in sub_lists(cons):
        # Draw as many orderings as the run has permutations, but never more
        # than 10, so the runtime cannot explode factorially. The draws are
        # independent, so the same ordering may come up more than once.
        n_draws = min(10, factorial(len(sublist)))
        for _ in range(n_draws):
            *given, target = np.random.permutation(sublist)
            ds.append({
                "in": {
                    "seed": v["seed"],
                    "cons": given
                },
                "out": target,
                "map_title": v["map_title"]
            })

    return ds

for split in forks:
    # Only forks with at least one opposing claim can produce examples.
    forks_with_cons = [
        v for v in forks[split]
        if any(claim["type"] == "Con" for claim in v["claims"])
    ]
    data = [make_data(v) for v in tqdm(forks_with_cons)]
    data = list(chain.from_iterable(data))
    ops[op][split] = data

# Define prompt-generation function(s).

def make_prompts(d):
    """Build the prompt variants for one `suggest-objections` example.

    Args:
        d: Example dict with "in" ("seed" claim and list of already-given
            "cons"), "out" (the claim to be suggested) and "map_title".

    Returns:
        Dict with "map_title" plus the prompt styles "manual", "zero-shot",
        "few-shot" and "soft", each a dict with "in" (prompt text ending in an
        open list item) and "out" (expected completion).
    """
    seed_txt = d["in"]["seed"]["txt"]
    given_txts = [claim['txt'] for claim in d["in"]["cons"]]
    target_txt = d["out"]["txt"]

    prompts = {}

    prompts["map_title"] = d["map_title"]

    # One "- " bullet per given objection, then an open bullet for the answer.
    # Each bullet is built separately so that an empty `cons` list gives a
    # single open bullet; joining on "\n- " left an empty bullet before it.
    dash_prompt = seed_txt + "\n\nCons:" + ''.join(["\n- " + txt for txt in given_txts]) + "\n- "

    prompts["manual"] = {
        "in": dash_prompt,
        "out": target_txt
    }

    prompts["zero-shot"] = {
        # The quote opened before the claim is now closed after it.
        "in": "List objections to the claim that: \"" + seed_txt + "\"\n\nObjections:" + ''.join(['\n* ' + txt for txt in given_txts]) + "\n* ",
        "out": target_txt
    }

    # NOTE(review): the last example spells "milennia"; the wording is left as
    # is because it is part of the prompt text.
    prompts["few-shot"] = {
        "in": "Suggest reasons why each claim is false.\n\nClaim: The COVID-19 pandemic was caused by a lab leak in Wuhan.\nReason 1: Viral incursions from animals to humans are common.\nReason 2: There was a wild animal market in Wuhan that sold bats for human consumption.\n\nClaim: The world should transition away from fossil fuels.\nReason 1: Fossil fuels provide cheap energy that helps lift people out of poverty.\nReason 2: There is ongoing debate about whether fossil fuels are responsible for climate change.\nReason 3: The fossil fuel industry is a major employer in disadvantaged rural areas.\n\nClaim: Socrates is mortal.\nReason 1: Socrates may be a fictional character.\nReason 2: Socrates is still spoken about milennia after he allegedly died.\n\nClaim: " + seed_txt + ''.join(["\nReason " + str(index + 1) + ": " + txt for index, txt in enumerate(given_txts)]) + "\nReason " + str(len(given_txts) + 1) + ": ",
        "out": target_txt
    }

    prompts["soft"] = {
        "in": dash_prompt,
        "out": target_txt
    }

    return(prompts)

# Store the prompt builder next to this operation's data splits.
ops[op]["make_prompts"] = make_prompts

### `suggest-copremise`

In [None]:
op = "suggest-copremise"

# Start every split empty; only "test" is populated below.
for split in maps:
    ops[op][split] = []

def make_datum(v):
    """Turn one fork into a `suggest-copremise` example.

    Args:
        v: Fork dict with "seed", "claims" and "map_title".

    Returns:
        Dict with "in" (the "seed" claim and its "Pro" children as "reasons")
        and "map_title". There is no "out" entry.
    """
    reasons = [claim for claim in v["claims"] if claim["type"] == "Pro"]
    return {
        "in": {
            "seed": v['seed'],
            "reasons": reasons
        },
        "map_title": v["map_title"]
    }

# Test forks that have at least one supporting claim.
forks_with_pros = [
    v for v in forks['test']
    if any(claim["type"] == "Pro" for claim in v["claims"])
]
ops[op]['test'] = [make_datum(v) for v in forks_with_pros]

def make_prompts(d):
    """Build the few-shot prompt for one `suggest-copremise` example.

    Args:
        d: Example dict with "in" ("seed" claim and list of "reasons") and
            "map_title".

    Returns:
        Dict with "map_title" and "few-shot"; the latter holds only an "in"
        prompt that lists the reasons as premises, the seed as conclusion, and
        ends with an open "Assumptions" bullet.
    """
    premises = '\n* '.join([claim['txt'] for claim in d['in']['reasons']])
    conclusion = d['in']['seed']['txt']

    # Instruction plus three worked examples, ending in the first open
    # premise bullet of the query.
    shots = "Identify assumptions with the following arguments.\n\nPremises:\n* All men are mortal.\nConclusion: Socrates is mortal.\nAssumptions:\n* Socrates is a man.\n\nPremises:\n* Digital literacy of the population improves.\nConclusion: Misinformation and conspiracy theories become less prevalent.\nAssumptions:\n* The only reason people believe misinformation is that they lack digital literacy.\n* Improved digital literacy will not exacerbate confirmation bias.\n\nPremises:\n* I want a PhD from a recognised university.\n* London has lots of organisations and networking opportunities in my field.\nConclusion: I should do a PhD in London.\nAssumptions:\n* I can afford to live in London.\n* I will have the flexibility to pursue research that interests me.\n\nPremises:\n* "

    return {
        "map_title": d["map_title"],
        "few-shot": {
            "in": shots + premises + "\nConclusion: " + conclusion + "\nAssumptions:\n*"
        }
    }

# Store the prompt builder next to this operation's data splits.
ops[op]["make_prompts"] = make_prompts

### `suggest-abstraction`

In [None]:
op = "suggest-abstraction"

# Start every split empty; only "test" is populated below.
for split in maps:
    ops[op][split] = []

def make_datum(b):
    """Turn one branch into a `suggest-abstraction` example.

    Args:
        b: Dict with "branch" (list of claim dicts) and "map_title".

    Returns:
        Dict with "in" (the first claim of the branch as "context" and the
        second as "reason") and "map_title". There is no "out" entry.
    """
    context, reason = b['branch'][0], b['branch'][1]
    return {
        "in": {
            "context": context,
            "reason": reason
        },
        "map_title": b["map_title"]
    }

ops[op]['test'] = [make_datum(b) for b in branches['test']]

def make_prompts(d):
    """Build the few-shot prompt for one `suggest-abstraction` example.

    Args:
        d: Example dict with "in" ("reason" claim and the "context" claim it
            supports) and "map_title".

    Returns:
        Dict with "map_title" and "few-shot"; the latter holds only an "in"
        prompt of the form "Too specific: <reason> => <context>" followed by
        an open "Better: " line.
    """
    prompts = {}

    prompts["map_title"] = d["map_title"]

    # Every example uses "=>" between reason and conclusion. The first
    # "Better:" line used "->", which broke the pattern; it now matches.
    prompts["few-shot"] = {
        "in": "In the following examples, notice how the first claim is rephrased to be more abstract.\n\nToo specific: Nuclear power has very low greenhouse gas emissions. => We should be building more nuclear power plants.\nBetter: Nuclear power is good for the environment. => We should be building more nuclear power plants.\n\nToo specific: School uniforms ensure that everyone is wearing the same clothes. => Schools should make students wear a uniform.\nBetter: Uniforms reduce class-based discrimination. => Schools should make students wear a uniform.\n\nToo specific: The Thames barrier has 5 backup generators. => The Thames barrier will not fail.\nBetter: The Thames barrier has been designed with lots of redundancy. => The Thames barrier will not fail.\n\nToo specific: " + d['in']['reason']['txt'] + " => " + d['in']['context']['txt'] + "\nBetter: "
    }

    return(prompts)

# Store the prompt builder next to this operation's data splits.
ops[op]["make_prompts"] = make_prompts

### `suggest-conclusion`

In [None]:
# Raise max_depth to 5 and extract the forks again, so that there is more
# training data for this operation. `forks` is rebuilt from scratch here.

max_depth = 5

forks = {"train": [], "validate": [], "test": []}


def get_forks(map):
    """Collect the forks of one map.

    A fork is a claim together with its direct children. A claim yields a fork
    when it has at least one child and len(id.split('.')) does not exceed the
    notebook-level `max_depth`.

    Args:
        map: A map dict as stored in `maps` (keys "ids" and "map").

    Returns:
        A list of dicts with keys "seed" (the parent claim), "claims" (its
        child claims) and "map_title" (the text of claim "1.").
    """
    found = []

    for claim_id in map["ids"]:
        claim = map["map"][claim_id]
        child_ids = claim["claims"]

        # Skip childless claims and claims below the depth limit.
        if not child_ids or len(claim_id.split('.')) > max_depth:
            continue

        found.append({
            "seed": claim,
            "claims": [map["map"][child_id] for child_id in child_ids],
            "map_title": map['map']['1.']['txt']
        })

    return found


for split in maps:
    for m in tqdm(maps[split]):
        forks[split].extend(get_forks(m))

In [None]:
op = "suggest-conclusion"

# Create dataset.

def make_datum(v):
    """Turn one fork into a `suggest-conclusion` example.

    Args:
        v: Fork dict with "seed", "claims" and "map_title".

    Returns:
        Dict with "in" (the fork's "Pro" claims), "out" (the seed claim they
        support) and "map_title".
    """
    supporting = [claim for claim in v["claims"] if claim["type"] == "Pro"]
    return {
        "in": supporting,
        "out": v["seed"],
        "map_title": v["map_title"]
    }

for split in forks:
    candidates = [make_datum(v) for v in forks[split]]
    # A conclusion needs at least one supporting claim to be inferred from.
    data = [datum for datum in candidates if datum["in"]]
    ops[op][split] = data

# Define prompt-generation function(s).

def make_prompts(d):
    """Build the prompt variants for one `suggest-conclusion` example.

    Args:
        d: Example dict with "in" (list of supporting claims), "out" (the
            conclusion claim) and "map_title".

    Returns:
        Dict with "map_title" plus the prompt styles "manual", "zero-shot",
        "few-shot" and "soft", each a dict with "in" (prompt text) and "out"
        (expected completion).
    """
    fact_txts = [claim['txt'] for claim in d["in"]]
    conclusion_txt = d["out"]["txt"]

    # The two list layouts used by the prompts below.
    dash_list = "\n".join([f"- {txt}" for txt in fact_txts])
    star_list = ''.join(['\n* ' + txt for txt in fact_txts])

    # Instruction plus three worked examples, ending in the query's heading.
    shots = "Suggest what we can conclude from each of the following sets of facts.\n\nFacts:\n* Viral incursions from animals to humans are common.\n* There was a wild animal market in Wuhan that sold bats for human consumption.\nConclusion: The SARS-CoV-2 coronavirus may have originated in bats.\n\nFacts:\n* Burning fossil fuels releases greenhouse gases that cause climate change.\n* Climate change may make parts of the planet uninhabitable for humans.\nConclusion: The world should transition away from fossil fuels.\n\nFacts:\n* Socrates is human.\n* All humans are mortal.\nConclusion: Socrates is mortal.\n\nFacts:"

    return {
        "map_title": d["map_title"],
        "manual": {
            "in": dash_list + "\n\nConclusion: ",
            "out": conclusion_txt
        },
        "zero-shot": {
            "in": "Consider the facts:" + star_list + "\n\nWe must conclude that: ",
            "out": conclusion_txt
        },
        "few-shot": {
            "in": shots + star_list + "\nConclusion: ",
            "out": conclusion_txt
        },
        "soft": {
            "in": dash_list + "\n\nConclusion: ",
            "out": conclusion_txt
        }
    }

# Store the prompt builder next to this operation's data splits.
ops[op]["make_prompts"] = make_prompts

In [None]:
# If desired, pre-compile all examples into prompt strings.
# Disabled by default (`if False`). When enabled, every example in every split
# is replaced by the dict its operation's make_prompts returns for it.

if False:
    for op in ops.keys():
        for split in maps.keys():
            if 'make_prompts' in ops[op].keys():
                make_prompts = ops[op]['make_prompts']
                ops[op][split] = [make_prompts(prompt) for prompt in ops[op][split]]

In [None]:
# Cap each split at a feasible size by random subsampling: at most 50,000
# training examples and 10,000 each for validation and test.

split_caps = {'train': 50_000, 'validate': 10_000, 'test': 10_000}

for op in ops.keys():
    for split in ops[op].keys():
        # Keys without a cap (e.g. "make_prompts") are left untouched.
        cap = split_caps.get(split)
        if cap is not None and len(ops[op][split]) > cap:
            ops[op][split] = random.sample(ops[op][split], cap)

In [None]:
# Create IDs for each example.
# An ID is the example's position within its split, so it is unique within an
# (operation, split) pair but repeats across the splits of one operation.

for op in ops.keys():
    for split in maps.keys():
        for k, prompt in enumerate(ops[op][split]):
            prompt['id'] = k

In [None]:
# Report the dataset sizes of every operation as "train / validate / test".
for key in ops.keys():
    print(key)
    print(f"{len(ops[key]['train'])} / {len(ops[key]['validate'])} / {len(ops[key]['test'])}")

## Export

In [None]:
import dill as pickle

# Write the whole `ops` registry (example lists plus the make_prompts
# functions) to disk. The file is opened in a `with` block so the handle is
# flushed and closed even if serialisation fails; it was previously left open.
# NOTE(review): dill is presumably used instead of the stdlib pickle because
# `ops` holds functions defined in this notebook — confirm.
with open(path_to_models + 'ops.pickle', 'wb') as out_file:
    pickle.dump(ops, out_file, protocol=4)