In [22]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import json
from pt_model import instantiate_model, predict, predict_raw, read_config
from pt_postprocessing import aggregate_results, output_to_json, agg_results_segments_row, process_labels_map, PersuationResults, postprocess_scores
import pt_constants
# !conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
# TODO: !conda update -n base -c defaults conda

# ignore warnings
import warnings
warnings.filterwarnings('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
import torch

# Report the installed PyTorch build and whether CUDA is usable.
print('Pytorch version: ', torch.__version__)
print('Check availability: ', torch.cuda.is_available())

# Pick the compute device: fall back to the CPU unless CUDA is available.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('GPU is:', torch.cuda.get_device_name(0))

Pytorch version:  2.0.0
Check availability:  False
No GPU available, using the CPU instead.


## Import the data

In [4]:
# Alternative input (first 200 rows of the nuclear dataset), kept for reference:
# df = pd.read_csv('../../data/nuclear/nuclear_mstr.csv').iloc[:200]
# Work on the first row of the sample file only.
df = pd.read_csv('test_sample.csv').head(1)
text = df['text'].tolist()
guid = df['guid'].tolist()

In [5]:
# Show the loaded document text(s).
text

['The UK is no exception either; our iconic wildlife, including hedgehogs, water voles and many of our once familiar butterflies, is increasingly hard to find and this summer wildfires scorched our desiccated countryside. It can be hard not to feel hopeless in the face of this shocking picture, but the truth is we do not have time for despair. To protect life as we know it on our planet, we must act now. Failing to do so would represent an unforgiveable betrayal of future generations. World leaders hold the key to solving this crisis - they can press fast-forward on the all-important changes we need to see, from a shift to sustainable food production and renewable energy, to investing in nature, and the people and communities who depend upon it, as our greatest ally in the fight against climate change.Leaders on the global stage, including our Prime Minister, must find the courage to act to bring our world back to life, before our natural riches are gone forever. The COP15 Biodiversity

## Instantiate the model and get results to JSON

In [11]:
# Load the configuration and instantiate the model (xlm-roberta-large by default).
# Model parameters are read from the local defaults file.
cfg = read_config('./defaults.cfg')
# Build the model from that configuration.
model = instantiate_model(cfg)

# Full prediction (inference and postprocessing) with a score threshold of 0.5,
# followed by aggregation of the same output at sentence and at paragraph level.
output =  predict(model, text, threshold=0.5)
results_sentence = aggregate_results(text, output, level='sentence')
results_paragraph = aggregate_results(text, output, level='paragraph',granularity='fine', detailed_results=True)

None
cpu
Using threshold:  0.5


In [20]:
# Inspect the sentence-level result (keys 'char_offsets' and 'labels').
results_sentence

{'char_offsets': [array([[ 220,  342],
         [ 344,  404],
         [ 406,  486],
         [ 811,  974],
         [1203, 1357]])],
 'labels': [array([14,  1,  1,  1, 12])]}

In [24]:
# create function to process the results
def get_annotated_snippets(results_sentence, guid, text):
    """Build a table of the labelled text snippets found in each document.

    Parameters
    ----------
    results_sentence
        Aggregated prediction results; passed unchanged to ``output_to_json``.
    guid : list
        Document identifiers, positionally aligned with ``text``.
    text : list of str
        Document texts; ``text[i]`` is the text of document ``guid[i]``.

    Returns
    -------
    pandas.DataFrame
        One row per labelled segment with the columns ``doc_id``, ``label``,
        ``start``, ``end`` and ``snippet`` (``snippet`` is the document text
        sliced as ``[start:end]``). When no segment is returned the frame is
        empty but still carries these columns.

    Raises
    ------
    ValueError
        If ``output_to_json`` returns a document id that is not in ``guid``.
    """
    columns = ['doc_id', 'label', 'start', 'end', 'snippet']
    json_output = output_to_json(results_sentence, document_ids=guid, map_to_labels=True)

    # Collect plain records and build the frame once at the end:
    # DataFrame.append was removed in pandas 2.0 and, where it exists,
    # copies the whole frame on every call.
    records = []
    for doc_id, segments in json_output.items():
        # ids and texts share positions, so the id's index locates its text
        doc_text = text[guid.index(doc_id)]
        for segment in segments:
            start, end = segment['start'], segment['end']
            records.append({
                'doc_id': doc_id,
                'label': segment['label'],
                'start': start,
                'end': end,
                'snippet': doc_text[start:end],
            })

    return pd.DataFrame(records, columns=columns)

In [25]:
# Tabulate every labelled segment: doc_id, label, start/end offsets and the matching text snippet.

get_annotated_snippets(results_sentence, guid, text)

Unnamed: 0,doc_id,label,start,end,snippet
0,express-876ce9f4436857a8936a890ef2898b8e,Loaded_Language,220,342,It can be hard not to feel hopeless in the fac...
1,express-876ce9f4436857a8936a890ef2898b8e,Appeal_to_Fear-Prejudice,344,404,"To protect life as we know it on our planet, w..."
2,express-876ce9f4436857a8936a890ef2898b8e,Appeal_to_Fear-Prejudice,406,486,Failing to do so would represent an unforgivea...
3,express-876ce9f4436857a8936a890ef2898b8e,Appeal_to_Fear-Prejudice,811,974,"Leaders on the global stage, including our Pri..."
4,express-876ce9f4436857a8936a890ef2898b8e,Flag_Waving,1203,1357,"At this Summit, world leaders must raise their..."


## Export to json file

In [7]:
import json

# Convert the sentence-level results to a JSON-serialisable structure keyed by document id.
js = output_to_json(results_sentence, document_ids=guid, map_to_labels=True)
# Write through a context manager so the file is flushed and closed
# deterministically instead of leaving the handle open.
with open('../nuclear-pt_sent.json', 'w') as json_file:
    json.dump(js, json_file)

## Instantiate the model and get raw results with custom threshold

In [198]:
# Read the configuration and build the model again (same calls as the earlier prediction cell).
cfg = read_config('./defaults.cfg')
model = instantiate_model(cfg)

# Two-step prediction: obtain the raw output first, then apply the threshold
# in a separate postprocessing call. The threshold used here is 0.5; change
# that argument to apply a custom cut-off.
output_raw = predict_raw(model, text)
output = postprocess_scores(output_raw, threshold=0.5)

loading configuration file /eos/jeodpp/data/projects/EMM/Disinfo/pt_model_multi_fine/config.json
Model config XLMRobertaConfig {
  "_name_or_path": "xlm-roberta-large",
  "architectures": [
    "XLMRobertaForTokenClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9",
    "10": "LABEL_10",
    "11": "LABEL_11",
    "12": "LABEL_12",
    "13": "LABEL_13",
    "14": "LABEL_14",
    "15": "LABEL_15",
    "16": "LABEL_16",
    "17": "LABEL_17",
    "18": "LABEL_18",
    "19": "LABEL_19",
    "20": "LABEL_20",
    "21": "LABEL_21",
    "22": "LABEL_22"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "label2

## Create PersuationResults objects from results

In [176]:
# Wrap the thresholded output together with the texts and their ids,
# then aggregate at paragraph, sentence and word level.
results = PersuationResults(output, text, uids=guid)
results = results.aggregate_results(levels=['paragraph', 'sentence', 'word'])

In [177]:
# get results as dict {guid -> label -> [sentence spans (start, end)]}
results.to_dict(level='sentence', orient='labels',return_spans=True)

{'salzburger-f05f37d68053c917be4b96d8a59a9c31': {<Persuation Techniques.Appeal_to_Authority: 0>: [(0,
    120)],
  <Persuation Techniques.Appeal_to_Popularity: 3>: [(122, 213)],
  <Persuation Techniques.Exaggeration-Minimisation: 10>: [(472, 530)],
  <Persuation Techniques.Name_Calling-Labeling: 15>: [(533, 632),
   (634, 712),
   (1011, 1180)]},
 'salzburger-f7dfd334111dc0580ed9f8773799dc93': {<Persuation Techniques.Appeal_to_Authority: 0>: [(0,
    120)],
  <Persuation Techniques.Appeal_to_Popularity: 3>: [(122, 213)],
  <Persuation Techniques.Exaggeration-Minimisation: 10>: [(472, 530)],
  <Persuation Techniques.Name_Calling-Labeling: 15>: [(533, 632),
   (634, 712),
   (1011, 1180)]},
 'neue-52b12e6de521fb6cd63599aafaf561d9': {<Persuation Techniques.Appeal_to_Authority: 0>: [(0,
    120)],
  <Persuation Techniques.Appeal_to_Popularity: 3>: [(122, 213)],
  <Persuation Techniques.Name_Calling-Labeling: 15>: [(575, 653)]},
 'neue-febf0985ac176bfacc14d0c40c23713b': {<Persuation Techniq

In [178]:
# get results as dict {guid -> sentence index -> [labels]}
results.to_dict(level='sentence', orient='segments')

{'salzburger-f05f37d68053c917be4b96d8a59a9c31': {0: [<Persuation Techniques.Appeal_to_Popularity: 3>],
  1: [<Persuation Techniques.Appeal_to_Authority: 0>],
  5: [<Persuation Techniques.Name_Calling-Labeling: 15>],
  6: [<Persuation Techniques.Exaggeration-Minimisation: 10>],
  7: [<Persuation Techniques.Name_Calling-Labeling: 15>],
  11: [<Persuation Techniques.Appeal_to_Popularity: 3>]},
 'salzburger-f7dfd334111dc0580ed9f8773799dc93': {0: [<Persuation Techniques.Appeal_to_Popularity: 3>],
  1: [<Persuation Techniques.Appeal_to_Authority: 0>],
  5: [<Persuation Techniques.Name_Calling-Labeling: 15>],
  6: [<Persuation Techniques.Exaggeration-Minimisation: 10>],
  7: [<Persuation Techniques.Name_Calling-Labeling: 15>],
  11: [<Persuation Techniques.Appeal_to_Popularity: 3>]},
 'neue-52b12e6de521fb6cd63599aafaf561d9': {0: [<Persuation Techniques.Appeal_to_Popularity: 3>],
  1: [<Persuation Techniques.Appeal_to_Authority: 0>],
  6: [<Persuation Techniques.Name_Calling-Labeling: 15>]},
 

In [179]:
# get word-level results keyed by span: {guid -> word span (start, end) -> [labels]}
results.to_dict(level='word', orient='segments', return_spans=True)

{'salzburger-f05f37d68053c917be4b96d8a59a9c31': {(4,
   12): [<Persuation Techniques.Appeal_to_Popularity: 3>],
  (40, 52): [<Persuation Techniques.Appeal_to_Popularity: 3>],
  (53, 64): [<Persuation Techniques.Appeal_to_Popularity: 3>],
  (65, 75): [<Persuation Techniques.Appeal_to_Popularity: 3>],
  (161, 178): [<Persuation Techniques.Appeal_to_Authority: 0>],
  (179, 184): [<Persuation Techniques.Appeal_to_Authority: 0>],
  (185, 190): [<Persuation Techniques.Appeal_to_Authority: 0>],
  (502, 515): [<Persuation Techniques.Name_Calling-Labeling: 15>],
  (534, 540): [<Persuation Techniques.Exaggeration-Minimisation: 10>],
  (541, 545): [<Persuation Techniques.Exaggeration-Minimisation: 10>],
  (546, 549): [<Persuation Techniques.Exaggeration-Minimisation: 10>],
  (550, 553): [<Persuation Techniques.Exaggeration-Minimisation: 10>],
  (566, 569): [<Persuation Techniques.Exaggeration-Minimisation: 10>],
  (570, 575): [<Persuation Techniques.Exaggeration-Minimisation: 10>],
  (576, 581): 

In [180]:
# get the paragraph spans per document: {guid -> [[start, end], ...]}
# (word spans are available the same way through results.words)
results.paragraphs

{'salzburger-f05f37d68053c917be4b96d8a59a9c31': [[0, 531],
  [533, 714],
  [716, 1180]],
 'salzburger-f7dfd334111dc0580ed9f8773799dc93': [[0, 531],
  [533, 714],
  [716, 1180]],
 'neue-52b12e6de521fb6cd63599aafaf561d9': [[0, 320],
  [322, 472],
  [474, 655],
  [657, 1121]],
 'neue-febf0985ac176bfacc14d0c40c23713b': [[0, 471], [473, 654], [656, 1120]],
 'vorarlberg-8f33bf9eff6c34dec0ed0b2a7a37d94c': [[0, 471],
  [473, 654],
  [656, 1120]],
 'krone-b2f335306b0d91fb805dc8770a566c5e': [[0, 416], [418, 720], [722, 1225]],
 'news-at-498feab7c89e7c228f8524ebae3216d0': [[0, 653], [655, 1119]],
 'volksblatt-c3cd01d91180f4198159f5ce0d81ff04': [[0, 653], [655, 1119]],
 'vorarlberg-64b3650bc1377c8638a339f879e68dd5': [[0, 321],
  [323, 363],
  [365, 749],
  [751, 1576],
  [1578, 1639],
  [1641, 1950],
  [1952, 2606],
  [2608, 2978]],
 'orf-d3919bb9892b09db0b6566444c34dbf1': [[0, 282]],
 'orf-9b567c53ef9f64f2f943834af023519a': [[0, 287]],
 'sn-at-1e41583d692396e05acfa5bd77fd8154': [[0, 579],
  [581,

In [181]:
# 'char_offsets' entry of results.output(): a list of arrays of [start, end] pairs
results.output()['char_offsets']

[array([[   4,   12],
        [  40,   44],
        [  44,   48],
        [  48,   52],
        [  53,   57],
        [  57,   61],
        [  61,   64],
        [  65,   70],
        [  70,   75],
        [ 161,  168],
        [ 168,  169],
        [ 169,  172],
        [ 172,  177],
        [ 177,  178],
        [ 179,  184],
        [ 185,  188],
        [ 188,  190],
        [ 506,  511],
        [ 534,  537],
        [ 537,  540],
        [ 541,  545],
        [ 546,  549],
        [ 550,  553],
        [ 566,  569],
        [ 570,  575],
        [ 576,  581],
        [ 582,  584],
        [ 585,  588],
        [ 588,  591],
        [ 592,  595],
        [ 596,  599],
        [ 600,  603],
        [ 604,  608],
        [ 608,  613],
        [ 678,  684],
        [ 684,  687],
        [ 687,  691],
        [ 692,  695],
        [ 696,  700],
        [ 700,  703],
        [ 704,  712],
        [1024, 1025],
        [1025, 1027],
        [1027, 1030],
        [1030, 1036],
        [1

In [None]:
# get the word spans per document: {guid -> [[start, end], ...]}
results.words

{'salzburger-f05f37d68053c917be4b96d8a59a9c31': [[0, 3],
  [4, 12],
  [13, 15],
  [16, 26],
  [27, 33],
  [34, 39],
  [40, 52],
  [53, 64],
  [65, 75],
  [76, 79],
  [80, 89],
  [90, 94],
  [95, 97],
  [98, 101],
  [102, 114],
  [115, 120],
  [120, 121],
  [122, 126],
  [127, 132],
  [133, 148],
  [149, 156],
  [157, 160],
  [161, 178],
  [179, 184],
  [185, 190],
  [190, 191],
  [192, 196],
  [197, 200],
  [201, 202],
  [202, 205],
  [206, 213],
  [213, 214],
  [215, 218],
  [219, 224],
  [224, 225],
  [226, 228],
  [229, 241],
  [242, 251],
  [251, 252],
  [253, 261],
  [262, 266],
  [267, 269],
  [270, 277],
  [278, 283],
  [284, 293],
  [293, 294],
  [294, 301],
  [302, 304],
  [305, 315],
  [316, 319],
  [319, 320],
  [321, 323],
  [324, 331],
  [332, 337],
  [338, 339],
  [339, 342],
  [343, 349],
  [350, 354],
  [354, 356],
  [357, 359],
  [360, 367],
  [368, 369],
  [369, 373],
  [374, 379],
  [379, 381],
  [382, 390],
  [391, 393],
  [394, 401],
  [402, 407],
  [408, 412],
  [