In [6]:
import glob, os
import re
import bleach
import random
from nltk.tokenize import sent_tokenize, word_tokenize

In [2]:
# Speaker whose replies are collected as training targets (see make_pairs).
TARGET_CHARACTER = "ROSS"


class Line:
    """A single line of dialogue: who spoke and what was said."""

    def __init__(self, speaker, line):
        # speaker: the speaker label; line: the spoken text.
        self.speaker, self.line = speaker, line

    def __str__(self):
        # Rendered as "SPEAKER: text".
        return ": ".join((self.speaker, self.line))

In [3]:
def process_line(line):
    """Normalise one raw script line.

    Removes parenthesised stage directions such as "(laughs)" and then tidies
    the whitespace that the removal leaves behind. Runs of spaces/tabs are
    collapsed to a single space and both ends are stripped. A line made only
    of directions therefore comes back as "", so the caller's `!= ""` check
    filters it out.

    Directions that span several lines are not removed, because the pattern
    is applied to one line at a time.

    Args:
        line: one raw line of a script.

    Returns:
        The cleaned line (possibly "").
    """
    no_directions = re.sub(r'\([^)]*\)', "", line)
    # Collapse only spaces/tabs. str.split() would also treat "\x85" as
    # whitespace, and in these scripts "\x85" is a cp1252 ellipsis that
    # clean_line later turns into "...".
    return re.sub(r'[ \t]+', " ", no_directions).strip()

In [4]:
def _take_sentences(sentences, max_tokens):
    """Return the longest prefix of `sentences` whose word-token total is <= max_tokens.

    Stops at the first sentence that would overflow the budget, so the
    result is always a prefix of the iteration order (possibly empty).
    """
    kept = []
    used = 0
    for sentence in sentences:
        n_tokens = len(word_tokenize(sentence))
        if used + n_tokens > max_tokens:
            break
        kept.append(sentence)
        used += n_tokens
    return kept


def shorten_pair(pair, max_tokens=30):
    """Truncate a (query, reply) pair of Line objects to a word-token budget.

    The query keeps its *last* whole sentences (the context nearest the
    reply). The reply keeps its *first* whole sentences. Sentences are never
    cut part-way. If the first sentence considered on a side is already
    longer than `max_tokens`, that side becomes "" (write_to_file later
    skips pairs with an empty side).

    Args:
        pair: (query, reply) tuple of Line objects.
        max_tokens: per-side budget, counted with word_tokenize. Defaults to
            30, the original hard-coded limit.

    Returns:
        A new (query, reply) tuple of Line objects with the same speakers.
    """
    query, reply = pair

    # Walk the query backwards so the budget is spent on its ending, then
    # restore reading order.
    query_kept = _take_sentences(reversed(sent_tokenize(query.line)), max_tokens)
    new_input = " ".join(reversed(query_kept)).strip()

    reply_kept = _take_sentences(sent_tokenize(reply.line), max_tokens)
    new_reply = " ".join(reply_kept).strip()

    return (Line(query.speaker, new_input), Line(reply.speaker, new_reply))

In [5]:
def make_pairs(lines, target=None):
    """Turn one scene's processed lines into (query, reply) pairs for `target`.

    Each line is expected to look like "SPEAKER: words". Lines with no colon
    are skipped. A pair (previous line, this line) is emitted when:
      * this line is spoken by the target, or
      * this line's speaker is "ALL", or a label that contains the target's
        name (e.g. "ROSS AND RACHEL"), and the previous speaker was not the
        target. Such lines are re-attributed to the target.
    Every pair is passed through shorten_pair.

    Args:
        lines: processed, non-empty lines of a single scene.
        target: speaker to collect replies for. Defaults to the module-level
            TARGET_CHARACTER, read at call time. It is upper-cased to match
            the speaker labels.

    Returns:
        List of (query Line, reply Line) tuples. Returns [] when the target
        never speaks under exactly its own name in this scene.
    """
    if target is None:
        target = TARGET_CHARACTER
    target = target.upper()

    speakers = set()
    lines_structured = []
    for raw in lines:
        # Split on the first colon only, so dialogue such as "It's 3:30"
        # keeps everything after the speaker label.
        speaker, sep, words = raw.partition(":")
        if not sep:
            continue
        character = speaker.strip().upper()
        speakers.add(character)
        lines_structured.append(Line(character, words.strip()))

    if target not in speakers:
        return []

    line_pairs = []
    prev_line = lines_structured[0]
    for cur in lines_structured[1:]:
        if cur.speaker == target:
            # NOTE(review): unlike the group-label branch below, this also
            # pairs when prev_line is the target itself (a target line
            # "replying" to itself). Confirm that this is intended.
            line_pairs.append(shorten_pair((prev_line, cur)))
        elif prev_line.speaker != target and (cur.speaker == "ALL" or target in cur.speaker):
            line_pairs.append(shorten_pair((prev_line, Line(target, cur.line))))
        prev_line = cur

    return line_pairs
    

In [6]:
def pairs_to_string(pairs):
    """Serialise (query, reply) pairs as text.

    Each pair becomes one line, "<query text> <+++++> <reply text>\n", using
    the `.line` attribute of each element. An empty input gives "".
    """
    rows = [
        str(query.line) + " <+++++> " + str(reply.line) + "\n"
        for query, reply in pairs
    ]
    return "".join(rows)

In [7]:
def process_file(file_name, encoding="ISO-8859-1"):
    """Extract all (query, target-reply) pairs from one script file.

    The file is split into scenes on bracketed spans. Each scene's lines are
    cleaned with process_line, paired with make_pairs, and serialised with
    pairs_to_string.

    Args:
        file_name: path of the script to read.
        encoding: text encoding used to decode the file. Defaults to the
            original ISO-8859-1.

    Returns:
        The concatenated "query <+++++> reply\n" lines for the whole file
        (possibly "").
    """
    with open(file_name, 'r', encoding=encoding) as f:
        f_contents = f.read()

    # Raw string: "\[" in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    # NOTE(review): "." does not cross newlines, but any bracketed span on a
    # single line splits the text, including inline asides like
    # "[to Monica]". The greedy ".*" also swallows everything between the
    # first "[" and the last "]" on that line. Confirm this is acceptable.
    scenes = re.split(r"\[.*\]", f_contents)

    chunks = []
    for scene in scenes:
        if scene.strip() == "":
            continue
        processed_lines = [
            cleaned
            for cleaned in map(process_line, scene.split("\n"))
            if cleaned != ""
        ]
        chunks.append(pairs_to_string(make_pairs(processed_lines)))
    return "".join(chunks)

In [9]:
# Build Ross_all.txt: every (query, Ross reply) pair from every script in
# scripts/, one "query <+++++> reply" line each.
# os.listdir() order is arbitrary and platform-dependent. Sorting makes the
# concatenation order, and hence the seeded shuffle in split_Ross,
# reproducible.
scripts = sorted(os.listdir("scripts/"))
# Write ISO-8859-1, the same encoding split_Ross reads the file back with.
# Every character comes from ISO-8859-1-decoded scripts, so it is encodable.
with open("Ross_all.txt", 'w', encoding="ISO-8859-1") as all_data:
    for s in scripts:
        all_data.write(process_file(os.path.join("scripts/", s)))

# Partition data into test / validation / training splits

In [2]:
# cp1252 punctuation bytes that appear as C1 control characters when the
# scripts are decoded as ISO-8859-1, mapped to plain ASCII. "Â" is deleted.
# It is presumably an artifact of UTF-8 bytes (e.g. C2 A0) being re-read as
# ISO-8859-1 (verify against the input files).
_CP1252_PUNCTUATION = str.maketrans({
    "\x85": "...",  # horizontal ellipsis
    "\x91": "'",    # left single quotation mark
    "\x92": "'",    # right single quotation mark
    "\x93": '"',    # left double quotation mark (was not mapped before)
    "\x94": '"',    # right double quotation mark (was not mapped before)
    "\x96": "-",    # en dash
    "\x97": "-",    # em dash
    "Â": None,
})


def clean_line(text):
    """Replace mis-decoded cp1252 punctuation in `text` with ASCII equivalents.

    Args:
        text: one line of data.

    Returns:
        The line with ellipses, curly quotes and dashes normalised, and "Â"
        removed.
    """
    return text.translate(_CP1252_PUNCTUATION)

In [5]:
def write_to_file(target_file_name, data, base="../data/"):
    """Write cleaned "query <+++++> reply" lines to parallel query/reply files.

    Creates `<base>/<target_file_name>_query.en` and `..._reply.en`. Line i
    of one file pairs with line i of the other. Each line is terminated by
    "\n".

    Args:
        target_file_name: prefix for the two output file names.
        data: iterable of "query <+++++> reply" strings.
        base: output directory. Defaults to the original "../data/".

    Lines without the separator, or where either side is blank after
    clean_line, are skipped rather than raising.
    """
    query_path = os.path.join(base, target_file_name + "_query.en")
    reply_path = os.path.join(base, target_file_name + "_reply.en")
    # Explicit UTF-8 so the output does not depend on the platform locale.
    # Some locales cannot encode every character left after clean_line.
    with open(query_path, 'w', encoding="utf-8") as f_query, \
            open(reply_path, 'w', encoding="utf-8") as f_reply:
        for d in data:
            query, sep, reply = clean_line(d).partition(" <+++++> ")
            if not sep or query.strip() == "" or reply.strip() == "":
                continue
            # Terminate each line. Prefixing "\n" (as before) left an empty
            # first example in both files and no final newline.
            f_query.write(query + "\n")
            f_reply.write(reply + "\n")


In [13]:
def split_Ross(test_size=1500, valid_size=1000, seed=1776):
    """Shuffle Ross_all.txt reproducibly and write test/valid/train splits.

    The first `test_size` shuffled lines go to Ross_test, the next
    `valid_size` lines go to Ross_valid, and the rest go to Ross_train, each
    via write_to_file. The defaults reproduce the original 1500/1000/rest
    split with seed 1776.
    """
    with open("Ross_all.txt", 'r', encoding="ISO-8859-1") as f:
        # split("\n"), not splitlines(): splitlines() also breaks on "\x85",
        # which in this data is a cp1252 ellipsis.
        data = f.read().split("\n")
    # A trailing newline leaves one empty last element. Drop only that, so a
    # final line without a newline is not lost (the old [:-1] dropped it).
    if data and data[-1] == "":
        data.pop()
    random.Random(seed).shuffle(data)

    valid_end = test_size + valid_size
    write_to_file("Ross_test", data[:test_size])
    write_to_file("Ross_valid", data[test_size:valid_end])
    write_to_file("Ross_train", data[valid_end:])


split_Ross()

In [12]:
def split_Cornell(seed=1776):
    """Shuffle Cornell_all.txt reproducibly and write its splits via write_to_file.

    Ranges are indices into the shuffled lines:
      * Cornell_test                      - [0, 1500)
      * Cornell_valid_2 / Cornell_train_2 - [1500, 2500) / [2500, 8719)
      * Cornell_valid   / Cornell_train   - [8719, 45000) / [45000, end)

    NOTE(review): the 8719 and 45000 boundaries are unexplained magic
    numbers. TODO: document where they come from.
    """
    with open("Cornell_all.txt", 'r', encoding="ISO-8859-1") as f:
        # split("\n"), not splitlines(): splitlines() also breaks on "\x85".
        data = f.read().split("\n")
    # Drop only the empty element left by a trailing newline, so a final line
    # without a newline is not lost (the old [:-1] dropped it).
    if data and data[-1] == "":
        data.pop()
    random.Random(seed).shuffle(data)

    test = data[:1500]
    valid_2 = data[1500:2500]
    train_2 = data[2500:8719]

    valid_1 = data[8719:45000]
    train_1 = data[45000:]

    write_to_file("Cornell_test", test)
    write_to_file("Cornell_valid_2", valid_2)
    write_to_file("Cornell_train_2", train_2)

    write_to_file("Cornell_valid", valid_1)
    write_to_file("Cornell_train", train_1)


split_Cornell()