# Readme first

**What this notebook does:**
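- Run the trained NR-Span token-tagging model (`roberta-large` backbone) over the annotated headlines and return, for each headline, the predicted reason text as a table with columns `headline` and `pred_reasons`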

**What this notebook DOES NOT do:**
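- Return the "type" of the reason (e.g., External Shock). The model does predict typed BIO labels, but the type is discarded: every token not tagged `O` is treated as part of the reason text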

In [6]:
import datatable as dt
import numpy as np
import pyarrow as pa
import torch
import pandas as pd

from collections import Counter
from IPython.utils import io
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from src.datamodules.datamodules import NrSpanDataModule
from src.models.models import Model
from tqdm import tqdm
from transformers import AutoTokenizer
from typing import List

# Load deepspeed model

In [10]:
def load_checkpoint(zero_ckpt_dir, ckpt_resave_path, map_location=None):
    '''Consolidate a DeepSpeed ZeRO checkpoint into one fp32 file and load it.

    Args:
        zero_ckpt_dir (str): path to the DeepSpeed ZeRO checkpoint directory
            (e.g. the ``epoch=11.ckpt`` folder written during training).
        ckpt_resave_path (str): file path where the consolidated fp32
            checkpoint is (re)written on every call before being loaded.
        map_location: forwarded unchanged to ``torch.load``; e.g. pass
            ``'cpu'`` to load tensors on the CPU. ``None`` (default) keeps
            ``torch.load``'s own default behaviour.

    Returns:
        The checkpoint object loaded by ``torch.load`` from ``ckpt_resave_path``.
    '''
    # DeepSpeed shards parameters/optimizer state across ranks; merge them
    # into a single fp32 state dict file that torch.load can read.
    convert_zero_checkpoint_to_fp32_state_dict(zero_ckpt_dir, ckpt_resave_path)

    return torch.load(ckpt_resave_path, map_location=map_location)

def init_model(device: str,
               zero_ckpt_dir: str = '/home/yu/OneDrive/NewsReason/local-dev/checkpoints/epoch=11.ckpt',
               ckpt_resave_path: str = '/home/yu/OneDrive/NewsReason/local-dev/checkpoints/ckpt_saved.pt'):
    '''Restore the trained NR-Span model from a DeepSpeed checkpoint.

    Args:
        device (str): torch device the model is moved to, e.g. 'cuda:0'.
        zero_ckpt_dir (str): DeepSpeed ZeRO checkpoint directory. The default
            is the original author's local path; override it on other machines.
        ckpt_resave_path (str): where the consolidated fp32 checkpoint is saved.

    Returns:
        tuple: ``(model, datamodule_cfg, model_cfg)``: the model in eval mode
        on ``device``, plus the ``datamodule_cfg`` and ``model_cfg`` entries
        stored in the checkpoint's ``hyper_parameters``.

    Warns:
        UserWarning: if the non-strict state-dict load skipped any keys.
    '''
    import warnings

    # ------------------
    # load ckpt
    # ------------------
    ckpt = load_checkpoint(zero_ckpt_dir, ckpt_resave_path)

    # ------------------
    # collect hparams
    # ------------------
    hparams = ckpt['hyper_parameters']
    datamodule_cfg = hparams['datamodule_cfg']
    model_cfg = hparams['model_cfg']

    # ------------------
    # init model
    # ------------------
    state_dict = ckpt['state_dict']

    # Suppress anything printed while building/loading the model.
    with io.capture_output():
        model = Model(**hparams)
        # strict=False tolerates key mismatches, but the mismatches are
        # returned; keep them so they can be reported below.
        incompatible = model.load_state_dict(state_dict, strict=False)
        model.to(device)
        model.eval()

    # Report outside capture_output, otherwise the warning would be swallowed.
    missing = list(getattr(incompatible, 'missing_keys', []))
    unexpected = list(getattr(incompatible, 'unexpected_keys', []))
    if missing or unexpected:
        warnings.warn(
            f'load_state_dict(strict=False) skipped {len(missing)} missing and '
            f'{len(unexpected)} unexpected keys; e.g. missing={missing[:5]}, '
            f'unexpected={unexpected[:5]}')

    return model, datamodule_cfg, model_cfg

# init model
# Use the first GPU when available; fall back to CPU so the notebook still
# runs (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model, datamodule_cfg, model_cfg = init_model(device)

Processing zero checkpoint '/home/yu/OneDrive/NewsReason/local-dev/checkpoints/epoch=11.ckpt/checkpoint'
Detected checkpoint of type zero stage 2, world_size: 2
Parsing checkpoint created by deepspeed==0.7.3
Reconstructed fp32 state dict with 395 params 355387417 elements
Saving fp32 state dict to /home/yu/OneDrive/NewsReason/local-dev/checkpoints/ckpt_saved.pt


# NR-Span

## init dataloader

In [11]:
def init_dataloader(pretrained_model, tx_path, coarse, n_unique_labels, ignore_index):
    '''Create an ``NrSpanDataModule`` configured for inference and set it up.

    Args:
        pretrained_model (str): Hugging Face model name passed to the
            datamodule (e.g. 'roberta-large').
        tx_path (str): path to the annotated data file (feather).
        coarse (bool): forwarded as ``coarse`` (coarse vs. fine label set).
        n_unique_labels (int): number of label ids the datamodule should use.
        ignore_index (int): label id given to the special tokens.

    Returns:
        NrSpanDataModule: the datamodule after ``setup()`` has been called.

    Note:
        ``inference=True`` and ``train_val_test_split=[1, 0, 0]`` are fixed;
        how they interact is defined inside ``NrSpanDataModule``.
    '''
    # special tokens of both BERT-style and RoBERTa-style tokenizers
    special_tokens = ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']

    dm = NrSpanDataModule(
        pretrained_model=pretrained_model,
        tx_path=tx_path,
        coarse=coarse,
        n_unique_labels=n_unique_labels,
        ignore_index=ignore_index,
        special_tokens=special_tokens,
        inference=True,
        use_biolu=False,  # plain BIO tags, not BIOLU
        train_val_test_split=[1, 0, 0],
    )
    dm.setup()
    return dm

# init dataloader
pretrained_model = 'roberta-large'
tx_path = '/home/yu/OneDrive/NewsReason/local-dev/data/annotation/batch-4/2-annotated/annotated_agreed_full_batch3_4.feather'
coarse = True
# Number of label ids: [BIO]: 25 (coarse), 49 (fine) | [BIOLU]: 49 (coarse), 97 (fine).
# init_dataloader always uses BIO, so derive the count from `coarse` to keep
# the two settings from drifting apart.
n_unique_labels = 25 if coarse else 49
ignore_index = -100  # label id for special tokens (excluded from loss/predictions)

datamodule = init_dataloader(
    pretrained_model, tx_path, coarse, n_unique_labels, ignore_index)

Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/roberta-large/resolve/main/config.json from cache at /home/yu/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373
Model config RobertaConfig {
  "_name_or_path": "roberta-large",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.19.2",
  "type_vocab_size": 1,
  "us

N unique ids=25
label_to_id={'[CLS]': -100, '[SEP]': -100, '[PAD]': -100, '<s>': -100, '</s>': -100, '<pad>': -100, 'O': 0, 'B-Firm Action': 1, 'I-Firm Action': 2, 'B-Contrast/Confusion': 3, 'I-Contrast/Confusion': 4, 'B-Demand & Trading': 5, 'I-Demand & Trading': 6, 'B-Operation Outcome': 7, 'I-Operation Outcome': 8, 'B-External Shock': 9, 'I-External Shock': 10, 'B-Labor': 11, 'I-Labor': 12, 'B-Technical': 13, 'I-Technical': 14, 'B-Others': 15, 'I-Others': 16, 'B-Litigation': 17, 'I-Litigation': 18, 'B-Financing': 19, 'I-Financing': 20, 'B-Third Party': 21, 'I-Third Party': 22, 'B-Fraud & Investigation': 23, 'I-Fraud & Investigation': 24}


## predict

In [15]:
def inference(datamodule, max_batches=5):
    '''Predict the reason text for each headline served by ``datamodule``.

    Every non-special token whose predicted label is neither ``0`` ('O') nor
    ``ignore_index`` is treated as part of the reason. Those tokens are decoded
    back into one string per headline ('' when no token qualifies). The reason
    *type* carried by the label is discarded.

    Reads the notebook globals ``model``, ``device`` and ``ignore_index``.

    Args:
        datamodule: a set-up ``NrSpanDataModule``; its ``test_dataloader()``
            and ``tokenizer`` are used.
        max_batches (int | None): stop after this many batches. The default 5
            keeps the original quick-check limit; pass ``None`` to run over the
            whole dataloader.

    Returns:
        pd.DataFrame: columns ``headline`` and ``pred_reasons``, one row per
        processed example.
    '''
    dataloader = datamodule.test_dataloader()
    tokenizer = datamodule.tokenizer

    # ids of every special token ([CLS], </s>, <pad>, ...). `all_special_ids`
    # also flattens list-valued entries such as `additional_special_tokens`,
    # which converting `special_tokens_map.values()` one by one does not.
    special_token_ids = set(tokenizer.all_special_ids)
    non_reason_labels = {0, ignore_index}

    headlines = []
    pred_reasons = []

    for batch_ix, batch in enumerate(tqdm(dataloader)):
        if max_batches is not None and batch_ix >= max_batches:
            break

        input_ids_batch = batch['input_ids'].tolist()  # List[List[int]]
        texts = batch['texts']  # List[str]

        # model inputs only: drop the non-tensor metadata before moving
        model_inputs = {k: v.to(device) for k, v in batch.items()
                        if k not in ['texts', 'headline_classes']}

        with torch.no_grad():
            # assumes predict() returns (batch, seq_len, n_labels) scores -- TODO confirm
            preds_batch = model.predict(model_inputs).argmax(-1).cpu().tolist()

        # handle every example in the batch, not only the first one
        for text, input_ids, preds in zip(texts, input_ids_batch, preds_batch):
            reason_token_ids = [
                input_id for input_id, token_pred in zip(input_ids, preds)
                if token_pred not in non_reason_labels
                and input_id not in special_token_ids
            ]
            reason = tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(reason_token_ids))

            headlines.append(text)
            pred_reasons.append(reason)

    return pd.DataFrame({'headline': headlines, 'pred_reasons': pred_reasons})
    
# run inference on the datamodule's test dataloader
inference_output = inference(datamodule=datamodule)


  0%|          | 5/1200 [00:00<00:47, 25.25it/s]


In [16]:
# one row per processed headline with its predicted reason text ('' when none)
inference_output

Unnamed: 0,headline,pred_reasons
0,Thor Shares Fall As Company Sees 'headwinds' F...,Company Sees 'headwinds' From Steel Tariffs
1,Mettler-Toledo International Inc. (MTD) Reache...,
2,Reed's shares are trading higher after the com...,the company announced an expanded distributio...
3,Sector Update: Tech Stocks Ending Near Session...,Q3 Revenue Beat
4,Office Depot Shares Spike To Near Session Low ...,1.1M Share Block Trade At $3.06/Share Crosses
