In [2]:
import os
from collections import Counter
import hashlib

Preprocessing for the original CNN/Daily Mail data

In [7]:
# Unicode closing quotation marks.
dm_single_close_quote = '\u2019'  # RIGHT SINGLE QUOTATION MARK
dm_double_close_quote = '\u201d'  # RIGHT DOUBLE QUOTATION MARK

# Strings accepted as a line ending by fix_missing_period.
END_TOKENS = [
    '.', '!', '?', '...', "'", "`", '"',
    dm_single_close_quote,
    dm_double_close_quote,
    ")",
]

# Sentence boundary markers.
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'

In [6]:
def fix_missing_period(line):
    """Return ``line`` with " ." appended unless it already looks terminated.

    The line is returned unchanged when it contains "@highlight", when it is
    empty, or when its last character is one of END_TOKENS.
    """
    # The checks short-circuit in this order, so line[-1] is never evaluated
    # for an empty line.
    needs_period = (
        "@highlight" not in line
        and line != ""
        and line[-1] not in END_TOKENS
    )
    return line + " ." if needs_period else line

def get_art_abs(lines):
    """Split the lines of one story into article text and abstract text.

    Every line is lower-cased and passed through fix_missing_period. Empty
    lines are skipped. Lines before the first "@highlight" marker belong to
    the article; every non-empty, non-marker line after it is a highlight.

    Returns:
        (article, abstract): article lines joined with ' ', highlight lines
        joined with '. '.
    """
    article_parts = []
    highlight_parts = []
    in_highlights = False
    for raw in lines:
        text = fix_missing_period(raw.lower())
        if text == "":
            continue
        if text.startswith("@highlight"):
            # The flag is never reset: everything after the first marker
            # is treated as highlight text.
            in_highlights = True
            continue
        bucket = highlight_parts if in_highlights else article_parts
        bucket.append(text)
    # NOTE(review): highlights normally already end in " ." after
    # fix_missing_period, so joining with '. ' produces ".." between
    # them -- confirm this is intended.
    return ' '.join(article_parts), '. '.join(highlight_parts)

def readlines(file):
    """Read the UTF-8 text file at ``file`` and return its lines, each stripped
    of leading and trailing whitespace."""
    with open(file, mode='r', encoding='utf-8') as handle:
        stripped = list(map(str.strip, handle))
    return stripped

In [27]:
# Count token frequencies over every story in both story directories and
# write each (article, abstract) pair as one tab-separated line.
voc = Counter()
story_dirs = (
    'data/cnn_stories_tokenized/cnn_stories_tokenized/',
    'data/dm_stories_tokenized/dm_stories_tokenized/',
)
with open('data/stories.tsv', 'w', encoding='utf-8') as fout:
    for story_dir in story_dirs:
        for story_name in os.listdir(story_dir):
            story_lines = readlines(os.path.join(story_dir, story_name))
            article, abstract = get_art_abs(story_lines)
            # Skip stories that lack either an article or an abstract.
            if not article.strip() or not abstract.strip():
                continue
            # Split on single spaces, trim each piece, and drop empty pieces.
            pieces = (piece.strip() for piece in article.split(' ') + abstract.split(' '))
            voc.update(piece for piece in pieces if piece != "")
            fout.write('<s> %s </s>\t<s> %s </s>\n' % (article, abstract))

In [8]:
import pickle
# Serialize the `voc` object to data/origin_voc.pickle.
with open('data/origin_voc.pickle', 'wb') as pickle_out:
    pickle.dump(voc, pickle_out)

In [4]:
def hashhex(s):
    """Return the hexadecimal SHA-1 digest of the string ``s``.

    The string is converted to bytes with ``str.encode()`` (UTF-8 by default).
    """
    return hashlib.sha1(s.encode()).hexdigest()

def get_url_hashes(url_list):
    """Return a list with hashhex applied to every entry of ``url_list``, in order."""
    hashes = []
    for url in url_list:
        hashes.append(hashhex(url))
    return hashes

# Map "<source><split>" (e.g. "cnntraining") to the list of hashes of the
# URLs listed in data/url_lists/<source>_wayback_<split>_urls.txt.
urllist = {}
for a in ['cnn','dailymail']:
    for b in ['training','test','validation']:
        # Decode explicitly as UTF-8: hashhex re-encodes each URL as UTF-8, so
        # relying on the platform default encoding could change the hashes.
        with open('data/url_lists/'+a+'_wayback_'+b+'_urls.txt','r',encoding='utf-8') as f:
            urls = [line.strip() for line in f]
        # Drop blank lines; otherwise they are hashed as the empty string and
        # yield an entry that corresponds to no URL.
        urllist[a+b] = get_url_hashes([url for url in urls if url])

In [6]:
def _ctf_encode(words, vocab, max_len, oov, grow_oov):
    """Encode ``words`` as (ids, extended_ids), truncating after ``max_len`` tokens.

    ``ids`` uses the '<unk>' id for every out-of-vocabulary word.
    ``extended_ids`` uses the per-example id stored in ``oov`` when the word
    has one, and the '<unk>' id otherwise. When ``grow_oov`` is true, unseen
    out-of-vocabulary words are added to ``oov`` (mutated in place).
    """
    ids = []
    extended_ids = []
    for position, word in enumerate(words):
        if position == max_len:
            # Truncate here and close the sequence with the '</s>' id.
            ids.append(vocab['</s>'])
            extended_ids.append(vocab['</s>'])
            break
        if word in vocab:
            ids.append(vocab[word])
            extended_ids.append(vocab[word])
            continue
        # Assign an id only the first time a word is seen, so that every
        # occurrence of the same word shares one id and later new words
        # cannot be given an id that is already in use.
        if grow_oov and word not in oov:
            oov[word] = len(vocab) + len(oov)
        unk_id = vocab['<unk>']
        ids.append(unk_id)
        extended_ids.append(oov.get(word, unk_id))
    return ids, extended_ids


def ctf_write(f,article,abstract,id,vocab=None,max_article_len=400,max_abstract_len=100):
    """Write one article/abstract pair to ``f`` as CTF rows and return its sizes.

    Both texts are stripped, wrapped in '<s>' ... '</s>' and split on
    whitespace. Each output row starts with ``id`` and carries up to four
    one-hot fields:

        S0  article word id ('<unk>' id for out-of-vocabulary words)
        S1  article word id, with out-of-vocabulary words numbered from
            len(vocab) upwards in order of first appearance
        S2  abstract word id ('<unk>' id for out-of-vocabulary words)
        S3  abstract word id, reusing the article's out-of-vocabulary ids
            when the word occurred in the article, else the '<unk>' id

    Row i holds the i-th article token (S0/S1) and the i-th abstract token
    (S2/S3); a row omits the pair whose sequence has already ended, so an
    abstract that is longer than its article is still written in full.

    Args:
        f: writable text file object.
        article: article text.
        abstract: abstract text.
        id: integer sequence id written at the start of every row.
        vocab: word -> id mapping containing '<unk>' and '</s>'; defaults to
            the module-level ``voc``.
        max_article_len: number of article tokens kept before the sequence is
            cut and terminated with the '</s>' id.
        max_abstract_len: same limit for the abstract.

    Returns:
        (number of article ids, number of abstract ids, number of distinct
        out-of-vocabulary article words).
    """
    if vocab is None:
        vocab = voc
    article_words = ('<s> '+article.strip()+' </s>').split()
    abstract_words = ('<s> '+abstract.strip()+' </s>').split()
    oov = {}
    article_input, article_extended_input = _ctf_encode(
        article_words, vocab, max_article_len, oov, True)
    # The abstract may reuse the article's OOV ids but never adds new ones.
    abstract_input, abstract_extended_input = _ctf_encode(
        abstract_words, vocab, max_abstract_len, oov, False)
    n_article = len(article_input)
    n_abstract = len(abstract_input)
    for i in range(max(n_article, n_abstract)):
        fields = ["%d" % id]
        if i < n_article:
            fields.append("|S0 %d:1" % article_input[i])
            fields.append("|S1 %d:1" % article_extended_input[i])
        if i < n_abstract:
            fields.append("|S2 %d:1" % abstract_input[i])
            fields.append("|S3 %d:1" % abstract_extended_input[i])
        f.write("\t".join(fields) + "\n")
    return n_article, n_abstract, len(oov)

# Load the word -> id vocabulary from voc_50k.txt, which holds one
# "word<TAB>id" pair per line.
voc = {}
with open('./data/SELF_DATA/voc_50k.txt','r',encoding='utf-8') as f:
    for line in f:
        if line.strip() == "":
            continue  # tolerate blank lines, e.g. a trailing one
        # Split on the last tab only, so a word that itself contains a tab
        # still parses. The names avoid rebinding the builtin `id` at
        # module level.
        word, word_id = line.rstrip('\n').rsplit('\t', 1)
        voc[word] = int(word_id)

In [None]:
# Convert every listed story into CTF rows, one output file per split.
# `extended` collects the tuple returned by ctf_write for each story.
extended = []
story_sources = (
    ('cnn', 'data/cnn_stories_tokenized/cnn_stories_tokenized/'),
    ('dailymail', 'data/dm_stories_tokenized/dm_stories_tokenized/'),
)
for split in ['training','test','validation']:
    seq_id = 0  # sequence ids restart at 0 for every split
    print('create '+split)
    with open('data/stories_'+split+'.ctf','w') as fout:
        for source, story_dir in story_sources:
            for url_hash in urllist[source+split]:
                story_lines = readlines(story_dir+url_hash+'.story')
                article, abstract = get_art_abs(story_lines)
                # Stories missing either part are skipped and consume no id.
                if not article.strip() or not abstract.strip():
                    continue
                extended.append(ctf_write(fout,article,abstract,seq_id))
                seq_id += 1
    print('%d %s data created'%(seq_id,split))

Preprocessing for MSRA data

In [4]:
# Count token frequencies over the training descriptions and headlines.
voc = Counter()
for extend in ['desc','headline']:
    with open('./data/SELF_DATA/train.'+extend,'r',encoding='utf-8') as f:
        for line in f:
            # Tokenize with split() (any whitespace, no empty strings), the
            # same way ctf_write tokenizes its input, so that every counted
            # token is one ctf_write can actually look up.
            voc.update(line.split())

In [5]:
# Write the vocabulary file, one "word<TAB>id" pair per line: the three
# reserved tokens first, then up to 50000 of the most frequent words, and
# '<unk>' last.
reserved_tokens = ('<s>', '</s>', '<pad>', '<unk>')
with open('./data/SELF_DATA/voc_50k.txt','w',encoding='utf-8') as f:
    f.write('<s>\t0\n')
    f.write('</s>\t1\n')
    f.write('<pad>\t2\n')
    i = 3
    for word, _count in voc.most_common():
        if i == 3 + 50000:
            break
        # A corpus token equal to a reserved token would be written twice;
        # the loaded dict would then be smaller than the largest id, and
        # ctf_write numbers out-of-vocabulary words from len(voc).
        if word in reserved_tokens:
            continue
        # Empty tokens and tokens containing whitespace would break the
        # tab-separated layout and can never match a token from split().
        if len(word.split()) != 1:
            continue
        f.write('%s\t%d\n'%(word,i))
        i += 1
    f.write('<unk>\t%d\n'%i)

In [7]:
# Convert each split's parallel .desc/.headline files into a .ctf file.
# `extended` collects the third value returned by ctf_write for every pair.
extended = []
for split in ['test','valid','train']:
    with open('data/SELF_DATA/%s.desc'%split,'r',encoding='utf-8') as desc_in, \
         open('data/SELF_DATA/%s.headline'%split,'r',encoding='utf-8') as head_in, \
         open('data/SELF_DATA/%s.ctf'%split,'w') as fout:
        seq_id = 0
        article_tokens = 0
        head_tokens = 0
        print('create '+split)
        # NOTE(review): zip stops at the shorter file, so a line-count
        # mismatch between the two inputs goes unnoticed.
        for desc, head in zip(desc_in, head_in):
            n_article, n_head, n_oov = ctf_write(fout, desc, head, seq_id)
            article_tokens += n_article
            head_tokens += n_head
            extended.append(n_oov)
            seq_id += 1
        print('%d %s data created, total %d article tokens, %d headline tokens'%(seq_id,split,article_tokens,head_tokens))

create test
9961 test data created, total 701716 article tokens, 108871 headline tokens
create valid
10000 valid data created, total 704482 article tokens, 109053 headline tokens
create train
9094344 train data created, total 638924687 article tokens, 99021828 headline tokens


In [8]:
# Largest per-example count of distinct out-of-vocabulary article words
# (the third value returned by ctf_write, as collected in `extended` by the
# MSRA conversion loop above).
max(extended)

108