In [5]:
import pandas as pd 
df = pd.read_csv('../../discourse2020/data/news-discourse-training-data.csv')
df = pd.read_csv('../models/discourse-model/data/news-discourse-training-data.csv')

In [11]:
# df['name'].drop_duplicates()

In [61]:
import glob
docs = (
    glob.glob('../models/discourse-model/data/Discourse_Profiling/data/*/*.txt') + 
    glob.glob('../models/discourse-model/data/Discourse_Profiling/data/*/*/*.txt')
)
anns = (
    glob.glob('../models/discourse-model/data/Discourse_Profiling/data/*/*.ann') + 
    glob.glob('../models/discourse-model/data/Discourse_Profiling/data/*/*/*.ann')     
)

In [107]:
def format_key(x):
    if 'test' in x:
        return f'test/{x[-2]}/{x[-1]}'
    if 'train' in x: 
        return f'train/{x[-2]}/{x[-1]}'
    if 'validation' in x: 
        return f'validation/{x[-1]}'
    

ann_df = (
    pd.DataFrame({'ann_fn': anns})
        .assign(key=lambda df: 
            df['ann_fn']
                .str.split('/')
                .apply(format_key)
                .str.strip()
                .str.replace('.ann', '')
    )
)
docs_df = (
    pd.DataFrame({'doc_fn': docs})
        .assign(key=lambda df: 
            df['doc_fn']
                .str.split('/')
                .apply(format_key)
                .str.strip()
                .str.replace('.txt', '')
    )
)

In [108]:
file_df = docs_df.merge(ann_df, right_on='key', left_on='key')

In [199]:
import numpy as np 
def get_line_or_null(lines, indicator):
    match = list(filter(lambda x: x.startswith('URL'), lines))
    if len(match) == 1:
        return match[0].replace(indicator, '').strip()
    if len(match) == 0:
        return np.nan
    else:
        raise ValueError
    

all_parsed_annotations = []
for _, (doc_fn, key, ann_fn) in file_df.iterrows():

    ann = open(ann_fn).read().split('\n')

    # annotation lines
    ann_lines = list(map(lambda x: x.split('\t'), ann))
    ann_lines = list(filter(lambda x: len(x) == 3, ann_lines))
    ann_lines = list(filter(lambda x: 'speech' not in x[1].lower(), ann_lines))

    ann_df = pd.DataFrame(
        sorted(ann_lines, key=lambda x: int(x[2].replace('S', ''))), 
        columns=['indicator', 'label', 'sent_id']
    )

    ann_df = (ann_df
     .assign(label=lambda df: df['label'].str.split().str.get(0).str.strip())
    )

    # doc lines 
    doc_lines = open(doc_fn).read().split('\n')
    
    doc_lines = list(filter(lambda x: len(x) > 0, doc_lines))
    
    url = get_line_or_null(doc_lines, indicator='URL ')
    headline = get_line_or_null(doc_lines, indicator='H ')
    sentences = list(filter(lambda x: x.startswith('S'), doc_lines))
    sentences_df = pd.DataFrame(
        list(map(lambda x: [x.split()[0], ' '.join(x.split()[1:])], sentences)), 
        columns=['sent_id', 'sentence']
    )

    # merge together 
    full_df = sentences_df.merge(ann_df, on='sent_id', ).drop(columns='indicator')
    full_df['headline'] = headline
    full_df['key'] = key
    full_df['url'] = url
    
    all_parsed_annotations.append(full_df)

In [200]:
len(all_parsed_annotations)

802

In [205]:
full_newsworthiness_df = pd.concat(all_parsed_annotations)

In [208]:
full_newsworthiness_df.to_csv('../models/discourse-model/data/reparsed-newsworthiness-df.csv')

In [217]:
( 
  full_newsworthiness_df
     .sort_values(['key', 'sent_id'])
     .groupby('key')
     [['sentence', 'label']]
     .aggregate(list)
     .reset_index()
     .assign(split=lambda df: df['key'].apply(lambda x: 'train' if x.split('/')[0] in ['train', 'test'] else 'test'))
     .drop(columns='key')
     .to_json('../models/discourse-model/data/reparsed-newsworthiness-df.jsonl', orient='records', lines=True)    
)

In [242]:
from sklearn.pipeline import Pipeline 
from sklearn.linear_model import LogisticRegressionCV
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import f1_score

In [223]:
lr_pipeline = Pipeline([
    ('cv', CountVectorizer(min_df=.005, max_df=.5, stop_words='english') ),
    ('lr', LogisticRegressionCV(max_iter=2000))
])

In [233]:
train_df = full_newsworthiness_df.loc[lambda df: df['key'].str.split('/').str.get(0).isin(['train', 'test'])]
test_df = full_newsworthiness_df.loc[lambda df: df['key'].str.split('/').str.get(0).isin(['validation'])]

In [239]:
lr_pipeline.fit(X=train_df['sentence'], y=train_df['label'])

In [241]:
y_pred = lr_pipeline.predict(test_df['sentence'])

In [246]:
f1_score(test_df['label'], y_pred, average='macro')

0.2913223731230926

In [19]:
from transformers import AutoModel, AutoConfig, AutoTokenizer
import sys
sys.path.insert(0, '../models/discourse-model/src/')
from sentence_model import SentenceClassificationModel
import pandas as pd 

In [77]:
from util import label_mapper

In [54]:
from torch.nn.utils.rnn import pad_sequence

In [62]:
config = AutoConfig.from_pretrained('alex2awesome/newsdiscourse-model')
model = SentenceClassificationModel.from_pretrained('alex2awesome/newsdiscourse-model')
tokenizer = AutoTokenizer.from_pretrained('roberta-base')

Downloading pytorch_model.bin:   0%|          | 0.00/715M [00:00<?, ?B/s]

In [63]:
import jsonlines
with open('../models/discourse-model/data/reparsed-newsworthiness-df.jsonl') as f:
    val_data = list(jsonlines.Reader(f))

In [65]:
one_doc = tokenizer(val_data[0]['sentence'], return_tensors='pt', padding=True)

In [66]:
one_doc['input_ids']

tensor([[    0,  5499, 10227,  ...,   615,     4,     2],
        [    0,   530, 13760,  ...,     1,     1,     1],
        [    0,    17,    48,  ...,     1,     1,     1],
        ...,
        [    0,    17,    48,  ...,     1,     1,     1],
        [    0,    17,    48,  ...,     1,     1,     1],
        [    0, 22816, 16514,  ...,     1,     1,     1]])

In [68]:
_, l = model.process_one_doc(**one_doc)

In [94]:
y_preds = l.reshape(22, 9).argmax(dim=1).detach().cpu().tolist()
idx2label_mapper = {v:k for k,v in label_mapper.items()}
pd.concat([
    pd.Series(val_data[0]['label']).to_frame('y_true'),
    pd.Series(list(map(idx2label_mapper.get, y_preds))).to_frame('y_pred'),
    pd.Series(val_data[0]['sentence']).to_frame('sentence')
], axis=1)