### 导入

In [1]:
import json
import os
import math
import itertools
import numpy as np
import pandas as pd

In [2]:
os.chdir('/data1/su/app/text_forecast/data/datasets/nyt_new%20structure/')
f1 = open('%20train.demo.json', mode='r')
train_demo = json.load(f1)
f2 = open('crisis_new.json', mode='r')
crisis = json.load(f2)
f3 = open('entities_new.json', mode='r')
entities = json.load(f3)

In [3]:
print('nyt:')
print(type(train_demo))
print(len(train_demo))
print(train_demo[0].keys())

print('crisis:')
print(type(crisis))
print(len(crisis))
print(crisis[0].keys())

print('entities:')
print(type(entities))
print(len(entities))
print(entities[0].keys())

nyt:
<class 'list'>
110
dict_keys(['name', 'src', 'page', 'page_mean', 'time', 'taxo'])
crisis:
<class 'list'>
4
dict_keys(['tgt5', 'tgt_date5', 'tgt6', 'tgt_date6', 'tgt7', 'tgt_date7', 'tgt8', 'tgt_date8', 'tgt9', 'tgt_date9', 'name', 'src', 'src_date'])
entities:
<class 'list'>
46
dict_keys(['tgt', 'tgt_date', 'name', 'src', 'src_date'])


### 修改后的taxo距离计算函数

In [4]:
def taxostat_distance(timeline, depth) -> list:
    """
    params:
    timeline: json字典，.keys() = dict_keys(['name', 'src', 'page', 'page_mean', 'time', 'taxo'])
    
    depth: taxonomic classifier的最大追索深度,一般最大距离为depth-1

    return: list,每一个时间节点距离基准taxonomic classifier的平均距离
    """
    # 取出raw_taxostr
    raw_taxostr_lst = timeline.get('taxo')
    
    # 划分taxostr
    taxostr_lst = [raw_taxostr.split("|") if type(raw_taxostr) == str else '' for raw_taxostr in raw_taxostr_lst ]
    # taxostr分段
    taxo_unit_lst = [[taxostr.split("/") for taxostr in unit] if unit != '' else '' for unit in taxostr_lst ]

    # 计算距离
    # 以出现频次最高的taxo为基准
    try:
        base_taxo = (
            pd.value_counts(
                [
                    "/".join(taxo[0 : min(depth, len(taxo))])
                    for unit in taxo_unit_lst
                    for taxo in unit
                ]
            )
            .index[0]
            .split("/")
        )
    except:
        return []
    base_len = len(base_taxo)
    # 计算每个时间节点内taxo的平均距离
    taxo_distance_lst = []
    for taxo_unit in taxo_unit_lst:
        curr_scores = []
        for taxo in taxo_unit:  # 计算每一个taxo距离base_taxo的距离
            minus = 1
            for i in range(min(base_len, len(taxo))):
                if taxo[i] != base_taxo[i]:
                    minus = 0
                    break

            score = depth - minus - i
            curr_scores.append(score)

        
        #taxo_distance_lst.append(sum(curr_scores) * 1.0 / len(taxo_unit)
        
        if len(taxo_unit)!= 0:
            taxo_distance_lst.append(sum(curr_scores) * 1.0 / len(taxo_unit))                         
        else:
            taxo_distance_lst.append(4.0) #taxo为空的距离取4.0

    return taxo_distance_lst


### 计算nyt的taxo距离(坑点:有小部分文章的taxo为空)

In [5]:
#e.g
print(train_demo[2].get('name'))
print(train_demo[2].get('taxo')[449]) #这里为nan
print(train_demo[2].get('time')[449])
print(train_demo[2].get('src')[449])
print(taxostat_distance(train_demo[2],4)[449])

nato_and
nan
19980703T000000
['Crews', 'from', 'six', 'NATO', 'ships', 'battle', 'in', 'tug', '-', 'of', '-', 'war', 'competition', 'next', 'to', 'Intrepid', 'Sea', '-', 'Air', '-', 'Space', 'Museum', ',', 'NYC', ';', 'photo', '(', 'S', ')']
4.0


In [6]:
Taxo_scores = [taxostat_distance(topic,4) for topic in train_demo]
for i in range(0,len(train_demo)): train_demo[i]['taxo_score'] = Taxo_scores[i] 

In [7]:
f1_n = open('%20train.demo+taxo.json', mode='w+')
json.dump(train_demo ,f1_n)

### 弱监督选择

In [8]:
def weak_supervision_selection(
    doc_sent_list, doc_date_list, doc_page_list, doc_taxoscore_list, abstract_size = 8, page_weight = 1, taxo_weight = 10, use_date = True, date_size = 2
):

    origin_tuple = tuple(zip(range(len(doc_page_list)), doc_page_list, doc_taxoscore_list, doc_date_list))
    
    #oracle选取初步想法：
    #1.先由page, taxo加权排序得到sorted_tuple(page越大越好，taxo越小越好), 
    #2.再取sorted_tuple中前n*abstract_size个(n待定)组成小集合selected_tuple
    #3.再对小集合里所有的timeline组合取'date方差'最小的一组得到result_tuple
    
    sorted_tuple = sorted(
        origin_tuple, key=lambda x: page_weight*x[1]+taxo_weight*(4-x[2]), reverse=True #w_page * page + w_taxo * (4-taxo)
    )
    
    if len(sorted_tuple) <= abstract_size: use_date=False #当总文章数小于或等于时间线size数时，无法使用date筛选
    
    if use_date == True:
        select_size = min(math.ceil(date_size*abstract_size), len(sorted_tuple))
        selected_tuple = sorted_tuple[0:select_size-1]
        min_Var_date = float("inf")
        for timeline in itertools.combinations(selected_tuple, abstract_size):
            dates = list(zip(*timeline))[3]
            datenum = [int(date[0:8]) for date in dates]
            Var_date = np.var(datenum)
            if Var_date < min_Var_date :
                min_Var_date = Var_date
                result_tuple = timeline
    
    else:
        result_tuple = sorted_tuple[0:min(abstract_size, len(sorted_tuple))]
        
        
    result_tuple = sorted(result_tuple,key = lambda x: x[3]) #result_tuple按时间顺序排序
    
    
    abstract = []
    abstract_date = []
    oracle_ids = []
    for i in range(len(result_tuple)):
        idx = result_tuple[i][0]
        oracle_ids.append(idx)
        abstract.append(doc_sent_list[idx])
        abstract_date.append(doc_date_list[idx])
    
    return abstract, abstract_date, oracle_ids

### 示例

In [9]:
sent_lists = [topic.get('src') for topic in train_demo]
date_lists = [topic.get('time') for topic in train_demo]
page_lists = [topic.get('page') for topic in train_demo]
taxo_lists = [topic.get('taxo_score') for topic in train_demo]

参数取默认值

In [10]:
weak_supervision_selection(sent_lists[0], date_lists[0], page_lists[0], taxo_lists[0], 
                           abstract_size=8)[1]

['20010225T000000',
 '20010620T000000',
 '20010914T000000',
 '20010914T000000',
 '20020329T000000',
 '20020329T000000',
 '20030608T000000',
 '20030921T000000']

不考虑date

In [11]:
weak_supervision_selection(sent_lists[0], date_lists[0], page_lists[0], taxo_lists[0], 
                           abstract_size=8,
                          use_date=False)[1]

['19960817T000000',
 '19961212T000000',
 '19970928T000000',
 '19971214T000000',
 '19981213T000000',
 '20010620T000000',
 '20010914T000000',
 '20030921T000000']

不考虑page

In [12]:
weak_supervision_selection(sent_lists[0], date_lists[0], page_lists[0], taxo_lists[0],
                           abstract_size=8,
                           use_date=False,
                           page_weight=0, 
                           )[1]

['19970928T000000',
 '19980802T000000',
 '20000402T000000',
 '20000730T000000',
 '20010620T000000',
 '20011006T000000',
 '20040613T000000',
 '20060423T000000']

不考虑taxo

In [13]:
weak_supervision_selection(sent_lists[0], date_lists[0], page_lists[0], taxo_lists[0],
                           abstract_size=8,
                           use_date=False,
                           taxo_weight=0, 
                           )[1]

['19960824T000000',
 '19960902T000000',
 '19961212T000000',
 '19970222T000000',
 '19971214T000000',
 '19981213T000000',
 '20030608T000000',
 '20030921T000000']

In [14]:
#试一下数据能不能跑通
for k in range(0,len(sent_lists)):
    weak_supervision_selection(sent_lists[k], date_lists[k], page_lists[k], taxo_lists[k])