In [1]:
import pandas as pd 
import string

In [19]:
# Location of the annotated quote-attribution data: tab-separated, read without a header row,
# so the columns are addressed by integer position (0..3) in the cells that follow.
training_dat_file = '../models_neural/quote_attribution/data/our-annotated-data__stage-2.tsv'
real_training_data = pd.read_csv(
    training_dat_file,
    sep='\t',
    header=None,
)

In [26]:
# Number of distinct documents per split: column 3 holds ids such as '/train/902',
# so element 1 of the '/'-split id is the split name.
(real_training_data[3]
 .drop_duplicates()
 .str.split('/')
 .str[1]
 .value_counts()
)

train    251
test      84
Name: 3, dtype: int64

In [36]:
# Sentence-level balance: column 1 is either the literal string 'None' or an annotated source.
(real_training_data[1]
 .map(lambda annotated: 'Source' if annotated != 'None' else 'None')
 .value_counts()
)

None      8451
Source    7342
Name: 1, dtype: int64

In [18]:
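# Make fake data: a synthetic sanity-check dataset with the same four columns and the same number of documents as the real annotations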

In [54]:
from random import random, randint, choice

In [28]:
# Document counts for the fake dataset; these equal the train/test document counts
# reported for the real annotated data earlier in this notebook (251 / 84).
num_train_docs, num_test_docs = 251, 84

In [97]:
# Speaker names/roles that the generation loop in this cell picks from for attributed sentences.
sources = ['Robert Smith', 'Jenna Bellman', 'officials', 'Joseph Biden', 'Donald Trump', 'authorities']
# Sentences that get wrapped as '"<sentence>" said <source>' when a row is attributed to a source.
source_sentences = [
    'He doesn’t teach math.',
    'You don’t have children.',
    'Does he play tennis?',
    'The train leaves every morning at 18 AM.',
    'Water freezes at 0°C',
    'I love my new pets.',
    'She doesn’t use a computer.',
    'It snows a lot in winter in Russia.',
    'We live in Texas.',
    'You go to holiday every summer.',
    'Do you like spaghetti?',
    'My daughter does the laundry.',
    'The train does not leave at 12 AM.',
    'I have no money at the moment.',
    'Do they talk a lot?',
    'Tomorrow early morning first I go to morning walk.',
    'We drink coffee every morning.',
    'I and my sister don’t see each other anymore.',
    'They don’t go to school tomorrow.',
    'Every child likes an ice cream.',
]

# Sentences used verbatim for rows whose source is the string 'None' (no attribution).
background_sentences = [
    'They have been known to say that.',
    'The path forward is unclear.',
    'This occurred as the opposite was being discussed.',
    'No known follow up occurred.',
    'For all of history, this has been said.',
    'The reasons were not apparent.',
    'Updates will follow.',
    'Research has shown the opposite.',
    'Daily modes are hard to establish.',
    'All sources declined comment',
]




from random import Random

# A dedicated, seeded generator makes the synthetic file identical on every
# Restart & Run All, without touching the global `random` module state.
FAKE_DATA_SEED = 0
rng = Random(FAKE_DATA_SEED)

# One row per sentence: [sentence text, source (or 'None'), sentence index, document id].
# This is the same column order as the real annotated TSV loaded above.
output_docs = []
for i in range(num_train_docs + num_test_docs):
    # The first `num_train_docs` documents go under /train/, the rest under /test/.
    # The running counter `i` is used as the document number, so ids stay unique.
    doc_idx = ('/train/%s' if i < num_train_docs else '/test/%s') % i
    num_sents = rng.randint(20, 80)  # both bounds inclusive
    for s in range(num_sents):
        # Roughly half of the sentences are attributed to a source.
        has_source = rng.random() > .5
        if has_source:
            source = rng.choice(sources)
            sent = rng.choice(source_sentences)
            sent = '"%s" said %s' % (sent, source)
        else:
            # Unattributed rows carry the literal string 'None', as in the real data.
            source = 'None'
            sent = rng.choice(background_sentences)

        output_docs.append([sent, source, s, doc_idx])

In [98]:
# List the contents of the data directory.
# NOTE(review): a bare `ls` is not Python; presumably this relies on IPython automagic — confirm, or spell it `%ls` / `!ls`.
ls ../models_neural/quote_attribution/data

our-annotated-data__stage-2.tsv
quote-attribution-classification__sanity-check-data.tsv


In [106]:
# Persist the synthetic rows as a TSV with no header row and no index column,
# i.e. the same layout the real annotated file is read with (sep='\t', header=None).
output_docs_df = pd.DataFrame(output_docs)
output_docs_df.to_csv(
    '../models_neural/quote_attribution/data/quote-attribution-classification__sanity-check-data.tsv',
    sep='\t',
    header=False,  # documented type is bool or list of str; False is the explicit "write no header"
    index=False,
)

# Examine data that's been generated

In [8]:
import sys
# Put the project directories at the front of the module search path so the project-local
# import below resolves. The paths are relative, so they are resolved against the
# current working directory at import time.
# Each insert goes to position 0, so the final search order is '../', '../models_neural',
# then '../models_neural/quote_attribution/utils_dataset'.
sys.path.insert(0, '../models_neural/quote_attribution/utils_dataset')
sys.path.insert(0, '../models_neural')
sys.path.insert(0, '../')

# Project-local data module (requires the sys.path entries above) plus third-party helpers.
from quote_attribution.utils_dataset import SourceClassificationDataModule
from attrdict import AttrDict
import pandas as pd

In [2]:
# Path to the real annotated TSV; this is the same value assigned in the data-loading cell near the top of the notebook.
training_dat_file = '../models_neural/quote_attribution/data/our-annotated-data__stage-2.tsv'

In [3]:
# Empty attribute-style container; its options are filled in by attribute assignment in the next cell.
config = AttrDict()

In [4]:
# Options handed to SourceClassificationDataModule through its `config=` argument.
# NOTE(review): the meaning of each flag is defined in the project module, not in this notebook — confirm there.
config.downsample_negative_data = 1
config.local = True
config.shuffle_data = False
config.num_documents = 5  # presumably caps how many documents are processed — TODO confirm
config.include_nones_as_positives = False

In [5]:
import os

# Build the project data module over the real annotated TSV.
# The pretrained-model directory is resolved under the current user's home directory
# instead of one machine's absolute path. For the original author this expands to the
# same '/Users/alex/.cache/...' location, and on other machines it points at that
# user's own cache.
# NOTE(review): the meaning of split_type / split_perc / max_length_seq is defined by
# SourceClassificationDataModule, not visible here — confirm there before changing them.
data_module = SourceClassificationDataModule(
    config=config,
    data_fp=training_dat_file,
    pretrained_model_path=os.path.expanduser(
        '~/.cache/torch/transformers/named-models/roberta-base-expanded-embeddings'
    ),
    num_cpus=1,
    split_type='key',
    split_perc=.95,
    model_type='roberta',
    batch_size=1,
    max_length_seq=200,
    spacy_path='en_core_web_lg',
)

In [6]:
# Ask the data module for its dataset; the later cells index it (dataset[0]) and iterate over it.
# NOTE(review): what get_dataset() reads and computes is defined in the project module — confirm there.
dataset = data_module.get_dataset()

  0%|          | 0/335 [00:00<?, ?it/s]

In [103]:
# Decode the first example's 'doc_tokens' ids back into text to eyeball how a document is encoded.
data_module.tokenizer.decode(dataset[0]['doc_tokens'])

'<s>journalist passive - voice BANGKOK</s><s>—</s><s>A plane carrying key senior Laotian government officials crashed Saturday morning, leaving at least four people dead, Laotian diplomats said Saturday.</s><s>Killed in the crash were two top figures in the security apparatus of the authoritarian Lao government : the deputy prime minister, Douangchay Phichit, and Thongbane Sengaphone, the minister of public security, according to two Lao diplomats.</s><s>For a Communist party that relies on force and intimidation to stay in power, the loss of what were arguably the two most powerful people in the security apparatus was a significant blow.</s><s>The governor of Vientiane province was also killed in the crash.</s><s>In addition to his post as deputy prime minister Mr. Douangchay was defense minister and a member of the Politburo, the highest decision - making body of the Communist party.</s><s>Mr. Thongbane, the public security head, was feared in the country and was said to be one of th

In [9]:
# Materialise every example of the dataset and load them into a DataFrame, one row per example.
processed_real_dataset_df = pd.DataFrame([example for example in dataset])

In [10]:
def _flagged_tokens(row, flag_col):
    """Return the entries of row['doc_tokens'] whose parallel entry in row[flag_col] equals 1."""
    return [tok for tok, flag in zip(row['doc_tokens'], row[flag_col]) if flag == 1]


# Add two list-valued columns: the token ids flagged by the sentence indicator and
# the token ids flagged by the source indicator.
processed_real_dataset_df = processed_real_dataset_df.assign(
    sentence_toks=lambda df: df.apply(
        lambda row: _flagged_tokens(row, 'sentence_ind_tokens'), axis=1
    ),
    source_toks=lambda df: df.apply(
        lambda row: _flagged_tokens(row, 'source_ind_tokens'), axis=1
    ),
)

In [11]:
# Decode each row's selected source token ids into a readable string.
processed_real_dataset_df['source_name'] = (
    processed_real_dataset_df['source_toks']
    .apply(lambda toks: data_module.tokenizer.decode(toks))
)

In [12]:
# Display the processed frame (pandas truncates the middle rows in the rendered output).
processed_real_dataset_df

Unnamed: 0,source_ind_tokens,sentence_ind_tokens,doc_tokens,label,sent_lens,doc_idx,sent_idx,annotated_source,candidate_source,sentence_toks,source_toks,source_name
0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 15545, 661, 18718, 111, 2236, 163, 9298, 5...",False,"[11, 3, 30, 57, 37, 16, 40, 39, 42, 24, 36, 28...",,,,,"[0, 250, 1587, 139, 622, 1842, 969, 3156, 9, 7...",[949],senior
1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 15545, 661, 18718, 111, 2236, 163, 9298, 5...",False,"[11, 3, 30, 57, 37, 16, 40, 39, 42, 24, 36, 28...",,,,,"[0, 133, 340, 1218, 2156, 4319, 10, 445, 31, 5...","[13360, 2384]",State television
2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 15545, 661, 18718, 111, 2236, 163, 9298, 5...",True,"[11, 3, 30, 57, 37, 16, 40, 39, 42, 24, 36, 28...",/train/902,,Laotian diplomats,diplomats,"[0, 250, 3286, 3406, 762, 949, 1587, 1242, 811...",[10055],diplomats
3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 15545, 661, 18718, 111, 2236, 163, 9298, 5...",True,"[11, 3, 30, 57, 37, 16, 40, 39, 42, 24, 36, 28...",/train/902,,The Lao news agency,lao news agency,"[0, 250, 1587, 139, 622, 1842, 969, 3156, 9, 7...","[1587, 139, 340, 1218]",Lao news agency
4,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 15545, 661, 18718, 111, 2236, 163, 9298, 5...",True,"[11, 3, 30, 57, 37, 16, 40, 39, 42, 24, 36, 28...",/train/902,,The Lao news agency,lao news agency,"[0, 463, 24, 56, 6050, 11, 234, 5416, 3375, 21...","[1587, 139, 340, 1218]",Lao news agency
...,...,...,...,...,...,...,...,...,...,...,...,...
143,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 15545, 661, 18718, 111, 2236, 436, 16, 213...",True,"[61, 37, 31, 34, 34, 26, 35, 41, 19, 4, 48, 46...",/test/903,,Lauri Myllyvirta,Lauri Myllyvirta,"[0, 133, 762, 631, 16, 14, 4420, 2156, 436, 34...","[25722, 118, 1308, 13861, 42660, 102]",Lauri Myllyvirta
144,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 15545, 661, 18718, 111, 2236, 436, 16, 213...",True,"[61, 37, 31, 34, 34, 26, 35, 41, 19, 4, 48, 46...",/test/903,,Lin Boqiang,Lin Boqiang,"[0, 17, 48, 2]","[9252, 3542, 1343, 14607]",Lin Boqiang
145,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 15545, 661, 18718, 111, 2236, 436, 16, 213...",False,"[61, 37, 31, 34, 34, 26, 35, 41, 19, 4, 48, 46...",,,,,"[0, 15545, 661, 18718, 111, 2236, 436, 16, 213...",[168],government
146,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 15545, 661, 18718, 111, 2236, 436, 16, 213...",False,"[61, 37, 31, 34, 34, 26, 35, 41, 19, 4, 48, 46...",,,,,"[0, 17, 48, 2]",[285],public


In [13]:
# Class balance of the processed examples (True = positive, False = negative).
(processed_real_dataset_df['label']
 .value_counts()
)

False    74
True     74
Name: label, dtype: int64

In [24]:
# Raw annotated rows that belong to the processed documents and carry a real source
# (column 3 = document id, column 1 = annotated source or the string 'None').
real_training_data[
    real_training_data[3].isin(processed_real_dataset_df['doc_idx'])
    & real_training_data[1].ne('None')
]

Unnamed: 0,0,1,2,3
2,A plane carrying key senior Laotian government...,Laotian diplomats,2,/train/902
10,The Lao news agency posted photos to a web sit...,The Lao news agency,10,/train/902
11,"The news agency , citing a statement from the ...",The Lao news agency,11,/train/902
12,"and it had crashed in Nadi village , west of t...",The Lao news agency,12,/train/902
13,The authorities were “ helping to rescue the s...,The Lao news agency,13,/train/902
...,...,...,...,...
386,I like to be optimistic and hope for the best ...,Ngoc Huynh,69,/train/905
392,"We ’re thinking about this together , ” Ms. Hu...",Ngoc Huynh,75,/train/905
393,“,Ngoc Huynh,76,/train/905
394,That ’s the beauty of it .,Ngoc Huynh,77,/train/905


In [16]:
# Annotated source of the first positive example.
# Use positional `.iloc[0]`: after filtering, the integer index keeps its original labels,
# so `[0]` would look up the *label* 0 and raise KeyError whenever row 0 is not a positive.
(processed_real_dataset_df
 .loc[lambda df: df['label'] == True]
 ['annotated_source']
 .iloc[0]
)

'Foreign Ministry'

In [17]:
# Class balance of the processed examples.
# NOTE(review): this is the same expression as an earlier cell, yet the recorded output differs,
# so the saved outputs presumably come from runs with different settings — re-run top to bottom to confirm.
(processed_real_dataset_df['label']
 .value_counts()
)

True     150
False     25
Name: label, dtype: int64

In [170]:
# Optional cache step (disabled): uncomment the next line to pickle the processed dataset to disk.
# processed_real_dataset_df.to_pickle('cache/2022-10-06__pickled-dataset.pkl')

In [153]:
# Side-by-side view of the decoded source span and the candidate source string for
# positive examples that have a real candidate, both lower-cased and stripped.
# To count how often they agree, append:
#   .pipe(lambda df: df['source_name'] == df['candidate_source']).value_counts()
(processed_real_dataset_df
 .loc[
     processed_real_dataset_df['label'].eq(True)
     & processed_real_dataset_df['candidate_source'].ne('None'),
     ['source_name', 'candidate_source'],
 ]
 .apply(lambda col: col.str.lower().str.strip())
)

Unnamed: 0,source_name,candidate_source
1,sek wannamethee,sek wannamethee
3,state television,state television
6,lao news agency,lao news agency
9,state television,state television
11,state television,state television
...,...,...
29891,the wall street journal,the wall street journal
29898,theranos,theranos
29901,elizabeth holmes,elizabeth holmes
29902,elizabeth holmes,elizabeth holmes


In [159]:
# (rows, columns) of the positive examples only.
processed_real_dataset_df[processed_real_dataset_df['label'].eq(True)].shape

(14953, 12)

In [157]:
# (rows, columns) of the raw annotated data.
real_training_data.shape

(15793, 4)

In [162]:
# (rows, columns) of the positive examples whose candidate source is not the string 'None'.
processed_real_dataset_df[
    processed_real_dataset_df['label'].eq(True)
    & processed_real_dataset_df['candidate_source'].ne('None')
].shape

(6967, 12)

In [155]:
# (rows, columns) of the raw annotated rows that carry a real source (column 1 != 'None').
real_training_data[real_training_data[1].ne('None')].shape

(7342, 4)

In [140]:
# Number of attributed sentences per (source, document) pair:
# tag each attributed row with a 1, then sum the tags within each (column 1, column 3) group.
(real_training_data[real_training_data[1].ne('None')]
 .assign(c=1)
 .groupby([1, 3])['c']
 .agg('sum')
)

1                          3         
\tBambang Soelistyo        /test/776     2
\tCard industry bloggers   /train/186    1
\tEnvironmental officials  /test/172     1
\tFire investigators       /train/153    2
\tNate Silver              /train/937    2
                                        ..
workers                    /train/830    1
Ángel Luis Torres          /train/5      3
Édouard Philippe           /test/789     1
“ Today ” show             /test/549     8
” Ed Gonzalez              /train/715    1
Name: c, Length: 2611, dtype: int64