In [1]:
import glob, os
import re
import bleach
import random
from nltk.tokenize import sent_tokenize, word_tokenize

In [2]:
# Speaker label whose lines become the reply side of each extracted pair.
# Compared against upper-cased speaker names in make_pairs, so keep it upper-case.
TARGET_CHARACTER = "ROSS"
class Line:
    """A single utterance: the speaker's label and the text that was spoken."""

    def __init__(self, speaker, line):
        self.speaker, self.line = speaker, line

    def __str__(self):
        # Rendered as "<speaker>: <line>".
        return ": ".join((self.speaker, self.line))

In [3]:
def process_line(line):
    """Clean one raw script line.

    Every parenthesised span ``(...)`` is removed, then surrounding
    whitespace is trimmed.

    The trim happens *after* the removal so that whitespace left behind by a
    removed span at either end of the line does not survive; in particular a
    line consisting only of parenthesised spans yields ``""``.

    Args:
        line: raw text of one line.

    Returns:
        The cleaned line (possibly the empty string).
    """
    without_parentheticals = re.sub(r'\([^)]*\)', "", line)
    return without_parentheticals.strip()

In [4]:
def _take_sentences(sentences, max_tokens):
    """Return the leading sentences of `sentences` whose combined token count fits in `max_tokens`."""
    kept = []
    total_tokens = 0
    for sentence in sentences:
        n_tokens = len(word_tokenize(sentence))
        if total_tokens + n_tokens > max_tokens:
            # Stop at the first sentence that would exceed the budget.
            break
        kept.append(sentence)
        total_tokens += n_tokens
    return kept


def shorten_pair(pair, max_tokens=30):
    """Truncate a (query, reply) pair to whole sentences within a token budget.

    The query keeps its *last* sentences (those closest to the reply); the
    reply keeps its *first* sentences. Sentences are never cut in the middle,
    so if the first sentence considered already exceeds the budget, that side
    becomes the empty string.

    Args:
        pair: 2-tuple ``(query, reply)`` of objects exposing ``.speaker`` and
            ``.line``.
        max_tokens: maximum number of ``word_tokenize`` tokens kept on each
            side. Defaults to 30.

    Returns:
        A 2-tuple of new ``Line`` objects with the original speakers and the
        shortened texts.
    """
    query, reply = pair
    query_sentences = sent_tokenize(query.line)
    reply_sentences = sent_tokenize(reply.line)

    # Walk the query backwards to keep its tail, then restore reading order.
    kept_query = _take_sentences(reversed(query_sentences), max_tokens)
    kept_query.reverse()
    new_input = " ".join(kept_query).strip()

    kept_reply = _take_sentences(reply_sentences, max_tokens)
    new_reply = " ".join(kept_reply).strip()

    return (Line(query.speaker, new_input), Line(reply.speaker, new_reply))

In [5]:
def make_pairs(lines):
    """Build (previous line, target-character reply) pairs for one scene.

    Each input line is expected to look like ``"Speaker: words"``; lines
    without a colon are ignored. Speaker labels are stripped and upper-cased.

    A pair is produced for a line when:
      * its speaker is exactly ``TARGET_CHARACTER``; or
      * its speaker is ``"ALL"`` or merely contains ``TARGET_CHARACTER``, and
        the previous line was not spoken by ``TARGET_CHARACTER``. In this
        case the reply is relabelled as ``TARGET_CHARACTER``.

    Args:
        lines: iterable of cleaned script lines belonging to one scene.

    Returns:
        List of ``(query, reply)`` tuples as returned by ``shorten_pair``;
        empty when no line in the scene is spoken by exactly
        ``TARGET_CHARACTER``.
    """
    speakers_seen = set()
    lines_structured = []
    for raw in lines:
        # Split on the FIRST colon only, so colons inside the dialogue
        # (e.g. "It's 5:30") no longer truncate the spoken text.
        speaker_part, colon, words_part = raw.partition(":")
        if not colon:
            continue

        character = speaker_part.strip().upper()
        speakers_seen.add(character)
        lines_structured.append(Line(character, words_part.strip()))

    if TARGET_CHARACTER not in speakers_seen:
        return []

    line_pairs = []
    prev_line = lines_structured[0]
    for current in lines_structured[1:]:
        if current.speaker == TARGET_CHARACTER:
            # NOTE(review): unlike the branch below, this one also pairs the
            # target with their own previous line -- confirm that is intended.
            line_pairs.append(shorten_pair((prev_line, current)))
        elif prev_line.speaker != TARGET_CHARACTER and (
            current.speaker == "ALL" or TARGET_CHARACTER in current.speaker
        ):
            # Group lines count as a reply by the target character.
            relabelled = Line(TARGET_CHARACTER, current.line)
            line_pairs.append(shorten_pair((prev_line, relabelled)))
        prev_line = current

    return line_pairs
    

In [6]:
def pairs_to_string(pairs):
    """Serialise (query, reply) pairs as text, one pair per line.

    Each output line has the form ``<query text> <+++++> <reply text>`` and is
    terminated by a newline; only the ``.line`` attribute of each element is
    used. An empty input yields the empty string.
    """
    rendered = [
        "".join((str(first.line), " <+++++> ", str(second.line), "\n"))
        for (first, second) in pairs
    ]
    return "".join(rendered)

In [7]:
def process_file(file_name):
    """Extract all query/reply pairs from one script file.

    The file is read as ISO-8859-1 and split into scenes at every
    ``[...]`` marker (the marker text itself is discarded; ``.`` does not
    match newlines, so a marker never spans more than one line). Each
    non-blank scene is cleaned line by line with ``process_line``, turned into
    pairs with ``make_pairs`` and serialised with ``pairs_to_string``.

    Args:
        file_name: path of the script file to read.

    Returns:
        The concatenated pair text for the whole file ("" if nothing matched).
    """
    # `with` guarantees the handle is closed even if reading fails.
    with open(file_name, 'r', encoding="ISO-8859-1") as f:
        f_contents = f.read()

    # Raw string: identical pattern, without the invalid "\[" string escape.
    scenes = re.split(r"\[.*\]", f_contents)

    file_chunks = []
    for scene in scenes:
        if scene.strip() == "":
            continue
        processed_lines = []
        for raw_line in scene.split("\n"):
            cleaned = process_line(raw_line)
            if cleaned != "":
                processed_lines.append(cleaned)
        file_chunks.append(pairs_to_string(make_pairs(processed_lines)))
    return "".join(file_chunks)

In [9]:
# Build Ross_responses.txt from every script file found in scripts/.
# os.listdir returns entries in arbitrary order; sorting makes the output
# file (and anything shuffled from it with a fixed seed) reproducible.
scripts = sorted(os.listdir("scripts/"))
# NOTE(review): the output uses the platform default encoding while inputs are
# read as ISO-8859-1 -- confirm what downstream readers expect.
with open("Ross_responses.txt", 'w') as all_data:
    for s in scripts:
        script_path = os.path.join("scripts/", s)
        # Skip sub-directories (e.g. notebook checkpoint folders).
        if not os.path.isfile(script_path):
            continue
        file_data = process_file(script_path)
        if file_data is None:
            continue
        all_data.write(file_data)

# Partition data into 3 sets

In [15]:
def divide_data(source_path="Cornell_data_all.txt", output_prefix="Cornell",
                test_size=44321, valid_size=35456, seed=1776):
    """Shuffle a one-record-per-line file and split it into test/valid/train files.

    The records are shuffled with a dedicated ``random.Random(seed)`` so the
    split is reproducible, then written to ``<output_prefix>_test.txt``,
    ``<output_prefix>_valid.txt`` and ``<output_prefix>_train.txt``. The
    number of records read is printed.

    Sizes noted by the original author:
      * 8719 pieces of Ross data    -> (test, valid, train) 1000, 2000, rest
      * 221605 pieces of Cornell data -> (test, valid, train) 44321, 35456, rest

    Args:
        source_path: input file, read as ISO-8859-1.
        output_prefix: prefix (may include a directory) of the three outputs.
        test_size: number of shuffled records written to the test file.
        valid_size: number of following records written to the valid file;
            all remaining records go to the train file.
        seed: seed for the shuffle.
    """
    # `with` closes the input handle, which was previously left open.
    with open(source_path, 'r', encoding="ISO-8859-1") as f:
        data = f.read().split("\n")
    # Drop only the empty string produced by a trailing newline; a file
    # without one no longer loses its last record.
    if data and data[-1] == "":
        data.pop()
    print(len(data))
    random.Random(seed).shuffle(data)

    valid_end = test_size + valid_size
    partitions = (
        ("_test.txt", data[:test_size]),
        ("_valid.txt", data[test_size:valid_end]),
        ("_train.txt", data[valid_end:]),
    )
    for suffix, records in partitions:
        with open(output_prefix + suffix, 'w') as out_file:
            for record in records:
                out_file.write(record + "\n")

# Run the split; the number printed below is the count of records read.
divide_data()

221605


In [13]:
# Scratch arithmetic: 177284 = 221605 - 44321 (records left after the test slice).
# Appears to derive the validation size (35456) as one fifth of that remainder.
177284/5

35456.8

In [14]:
# Scratch arithmetic: validation size + test size = 79777, the index where the
# validation slice ends and the train slice begins in divide_data.
35456+44321

79777