<a href="https://colab.research.google.com/github/Vidushi-GitHub/gcn-nlp-test/blob/main/InformationExtractor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Information Extraction Model for NASA GCN Project**
The goal of this project is to extract important information from NASA GCN Circulars for the purpose of automated data entry into the archive, which is currently done manually.

# Step 1: Import Libraries

In [1]:
'''
Author: Ronit Agarwala
Last Modified: 09/12/2023
'''

import pandas as pd
import tarfile
import os
import re
import csv
import torch

# Step 2: Extract Data

In [2]:
'''
Extract table data from csv file.
This will contain our labels for the model.
'''

df = pd.read_csv('swift_redshift_data.csv', skiprows=0, header=0)
print(df)

         GRB  Time[UT] TriggerNumber        BAT RA(J2000)  \
0    230818A  23:27:34       1186032  285.888\n19:03:33.1   
1    230506C  17:09:19       1167288  134.371\n08:57:29.0   
2    230414B  16:14:21       1164180  181.052\n12:04:12.5   
3    230328B  14:54:48       1162001  291.037\n19:24:08.9   
4    230325A  03:15:35       1161390  296.854\n19:47:25.0   
..       ...       ...           ...                  ...   
415   050319  09:31:18        111622  154.172\n10:16:41.3   
416   050318  15:44:37        111529   49.695\n03:18:46.8   
417   050315  20:59:43        111063  306.476\n20:25:54.2   
418   050223  03:09:06        106709  271.394\n18:05:34.6   
419   050126  12:00:54        103780  278.115\n18:32:27.6   

           BAT Dec(J2000) BAT T90[sec]  BAT Fluence(15-150 keV)[10-7 erg/cm2]  \
0      40.888\n40:53:16.8         9.82                                  19.00   
1      45.131\n45:07:51.6        31.00                                  17.00   
2      53.179\n53:10:44.

In [3]:
'''
Extract GCN Circulars from .tar.gz file to a new unzipped folder.
Store the path for each circular in the unzipped folder in list 'dir'.
'''

file = tarfile.open('./all_gcn_circulars.tar.gz')
file.extractall('./all_gcn_circulars')
file.close()

dir = os.listdir('./all_gcn_circulars/gcn3') #Store all file names as strings in dir

#Add file path to beginning of file names in dir
for i in range(len(dir)):
  dir[i] = './all_gcn_circulars/gcn3/' + dir[i]

print(f"Number of circulars: {len(dir)}")
print(f"First circular path: {dir[0]}")

Number of circulars: 33653
First circular path: ./all_gcn_circulars/gcn3/12884.gcn3


In [4]:
'''
Iterate through each path in dir, and store the text in lower case.
This will contain our features for the model.
'''

circulars = [] #List to store the text of each GCN
full_text = '' #String to store entire corpus of data
for gcn in dir:
  with open(gcn, encoding = "ISO-8859-1") as f:
    file_str = f.read().lower()
    circulars.append(file_str)
    full_text += file_str

print(circulars[0]) #Print first circular

title:   gcn circular
number:  12884
subject: grb 120119a, swift-bat refined analysis
date:    12/01/20 22:42:20 gmt
from:    hans krimm at nasa-gsfc  <hans.a.krimm@nasa.gov>

m. stamatikos (osu), s. d. barthelmy (gsfc), w. h. baumgartner (gsfc/umbc),
a. p. beardmore (u leicester), j. r. cummings (gsfc/umbc),
e. e. fenimore (lanl), n. gehrels (gsfc), h. a. krimm (gsfc/usra),
c. b. markwardt (gsfc), d. m. palmer (lanl), t. sakamoto (gsfc/umbc),
j. tueller (gsfc), t. n. ukwatta (gwu) (i.e. the swift-bat team):

using the data set from t-239 to t+395 sec from the recent telemetry downlink,
we report further analysis of bat grb 120119a (trigger #512035)
(beardmore, et al., gcn circ. 12859).  the bat ground-calculated position is
ra, dec = 120.029, -9.076 deg which is
    ra(j2000)  =  08h 00m 06.9s
    dec(j2000) = -09d 04' 35.3"
with an uncertainty of 1.0 arcmin, (radius, sys+stat, 90% containment).
the partial coding was 100%.

the mask-weighted light curve shows a series of three overla
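
Each circular starts with a fixed header (title, number, subject, date, from) followed by a blank line and the body. The sketch below shows one way to split that header into a dict. `parse_header` is a hypothetical helper, not part of this notebook, and it assumes every header line has the form `key: value`; the sample body line is abbreviated for illustration.

```python
import re

# Hypothetical helper (not part of the original notebook): parse the
# "key: value" lines that precede the first blank line of a circular.
HEADER_LINE = re.compile(r'^(\w+):\s*(.*)$')

def parse_header(circular_text):
    header = {}
    for line in circular_text.splitlines():
        if not line.strip():  # the first blank line ends the header
            break
        match = HEADER_LINE.match(line)
        if match:
            header[match.group(1)] = match.group(2).strip()
    return header

sample = (
    "title:   gcn circular\n"
    "number:  12884\n"
    "subject: grb 120119a, swift-bat refined analysis\n"
    "date:    12/01/20 22:42:20 gmt\n"
    "\n"
    "using the data set from t-239 to t+395 sec ...\n"
)
header = parse_header(sample)
print(header["number"], "|", header["subject"])
```

The `number` field is the key that later links a circular to its entry in the Swift table.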

In [5]:
'''
Get smaller dataframe with relevant information only.
This section converts the SWIFT GRB Table to a python dict.
The end result is a dict with key = GCN No. and value = Redshift Text.
'''

sub_df = df[['References', 'Redshift', 'GRB']].copy()
removed_list = []
removed_indices=[]
#Raw strings keep the regex backslashes from being read as string escapes
gcn_format = re.compile(r'GCN ?\d+( *\(.*?\))?')
gcn_number_format = re.compile(r'GCN ?(\d+)')
bracket_format = re.compile(r' ?\([^)]*\)')
instrument_format = re.compile(r'\((.*?)\)')
instrument_format_1 = re.compile(r'\(([^):]*)')

#Extract just the Redshift References from the References column
for index, row in sub_df.iterrows():
  row['References'] = row['References'].splitlines()
  row['Redshift'] = row['Redshift'].splitlines()

  temp=''
  for line in row['References']: #Extract only redshift references
    if line[0:8] == 'Redshift':
      temp = line
  if temp != '':
    row['References'] = temp

  if isinstance(row['References'], list): #Drop the non-uniform data for now
    removed_list.append((row['References'], row['Redshift'], row['GRB']))
    removed_indices.append(index)

red_list = [] #List of tuples
sub_df = sub_df.drop(removed_indices)

#Create a list of tuples
#Each tuple contains a GCN No. with its corresponding Redshift Text and GRB Name
for index, row in sub_df.iterrows():
  gcn_iter = gcn_format.finditer(row['References'])
  gcns = list(gcn_iter)

  #References with only one circular don't need any further iteration
  if len(gcns) == 1 and len(row['Redshift']) == 1:
    gcn_num_match = gcn_number_format.search(gcns[0].group())
    gcn_num = gcn_num_match.group(1)
    redshift_text = bracket_format.sub('', row['Redshift'][0])
    grb_name = row['GRB']

    if len(redshift_text.split(',')) != 1: #Drop the non-uniform data for now
      removed_list.append((row['References'], row['Redshift'], row['GRB']))
      continue

    red_list.append((gcn_num, redshift_text.strip().lower(), grb_name.strip().lower()))

  #If multiple circulars are present, iterate through them all
  else:
    for gcn in gcns:
      for redshift in row['Redshift']:

        #Extract instrument name from reference
        instrument_match = instrument_format.search(gcn.group())
        if instrument_match is None:
          continue
        instrument_name = instrument_match.group(1)

        #Extract instrument name from redshift text
        instrument_match_1 = instrument_format_1.search(redshift)
        if instrument_match_1 is None:
          continue
        instrument_name_1 = instrument_match_1.group(1)

        #Check to see if both instruments match. If yes, append tuple to list
        if instrument_name == instrument_name_1:
          gcn_num_match = gcn_number_format.search(gcn.group())
          gcn_num = gcn_num_match.group(1)
          redshift_text = bracket_format.sub('', redshift)
          grb_name = row['GRB']
          red_list.append((gcn_num, redshift_text.strip().lower(), grb_name.strip().lower()))

redshift_dict = {} #Dict for redshift data

rem_idx = [94, 274, 284, 437, 455] #Indices that need manual re-entry
for i, tuple_ in enumerate(red_list):
  if i in rem_idx:
    removed_list.append(tuple_)
  else:
    redshift_dict[tuple_[0]] = (tuple_[1], tuple_[2])

for item in removed_list:
  print(item)
print(len(removed_list))
print(len(redshift_dict))

(['BAT: GCN 29677; GCN 29691', 'XRT: GCN 29678; GCN 29681; GCN 29686; Evans et al., 2009, MNRAS, 397, 1177', 'UVOT: GCN 29679; GCN 29697', 'Radio: GCN 29685 (OSIRIS/GTC)'], ['1.487 (OSIRIS/GTC)'], '210321A')
(['BAT: GCN 19645; GCN 19648', 'XRT: GCN 19655; Evans et al., 2009, MNRAS, 397, 1177', 'UVOT: GCN 19645; GCN 19656', 'Radio: GCN 19849 (GMRT)'], ['The detection in all UVOT filters is consistent with a redshift of z < ~1.5'], '160703A')
(['BAT: GCN 18076; GCN 18086', 'XRT: GCN 18076; GCN 18079; GCN 18082; Evans et al., 2009, MNRAS, 397, 1177', 'UVOT: GCN 18076; GCN 18084', 'Radio: GCN 18080 (VLT)'], ['0.313 (VLT: emission)'], '150727A')
('Redshift: GCN 26538 (UVOT); GCN (VLT)', ['1.19 (UVOT), 1.148 (VLT: absorption)'], '191221B')
('21209', 'consistent with gtc results', '170531b')
('12542', 'consistent with gemini-south', '111107a')
('11997', '1.61, pairitel', '110503a')
('6651', 'proposed redshift 3.626', '070721b')
('5946', '0.41 possible host redshift', '061210')
9
539
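
To see why some rows end up in `removed_list`, it helps to run the patterns from the cell above on one of the printed references. The sketch below reuses the same regular expressions (written as raw strings) on the GRB 191221B entry; the variable names `matches`, `number`, `instrument` and `cleaned` are introduced here for illustration only.

```python
import re

# Same patterns as in the cell above, written as raw strings.
gcn_format = re.compile(r'GCN ?\d+( *\(.*?\))?')
gcn_number_format = re.compile(r'GCN ?(\d+)')
bracket_format = re.compile(r' ?\([^)]*\)')
instrument_format = re.compile(r'\((.*?)\)')

reference = 'Redshift: GCN 26538 (UVOT); GCN (VLT)'
redshift = '1.19 (UVOT), 1.148 (VLT: absorption)'

# Only references that carry a circular number are matched,
# so the bare "GCN (VLT)" entry is skipped.
matches = [m.group() for m in gcn_format.finditer(reference)]
number = gcn_number_format.search(matches[0]).group(1)
instrument = instrument_format.search(matches[0]).group(1)

# Removing the bracketed instrument names leaves two comma-separated values,
# which is the non-uniform case that the cell above drops.
cleaned = bracket_format.sub('', redshift)

print(matches, number, instrument)
print(cleaned, len(cleaned.split(',')))
```

Because only one numbered circular is found but the cleaned redshift text still holds two values, this row is set aside for manual lookup.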


In [6]:
'''
Add back removed data points after manual lookup
'''

redshift_dict['29685'] = ('1.487', '210321a')
redshift_dict['19656'] = ('the detection in all uvot filters is consistent with a redshift of z < ~1.5', '160703a')
redshift_dict['18080'] = ('0.313', '150727a')
redshift_dict['26538'] = ('1.19', '191221b')
redshift_dict['21209'] = ('consistent with that obtained by de ugarte-postigo et al. (gcn 21177)', '170531b')
redshift_dict['12542'] = ('consistent with that reported by chornock et al. (gcn 12537)', '111107a')
redshift_dict['11997'] = ('1.61', '110503a')
redshift_dict['6651'] = ('3.626', '070721b')
redshift_dict['5946'] = ('0.41', '061210')

print(len(redshift_dict))

548


In [7]:
print(circulars[0])

title:   gcn circular
number:  12884
subject: grb 120119a, swift-bat refined analysis
date:    12/01/20 22:42:20 gmt
from:    hans krimm at nasa-gsfc  <hans.a.krimm@nasa.gov>

m. stamatikos (osu), s. d. barthelmy (gsfc), w. h. baumgartner (gsfc/umbc),
a. p. beardmore (u leicester), j. r. cummings (gsfc/umbc),
e. e. fenimore (lanl), n. gehrels (gsfc), h. a. krimm (gsfc/usra),
c. b. markwardt (gsfc), d. m. palmer (lanl), t. sakamoto (gsfc/umbc),
j. tueller (gsfc), t. n. ukwatta (gwu) (i.e. the swift-bat team):

using the data set from t-239 to t+395 sec from the recent telemetry downlink,
we report further analysis of bat grb 120119a (trigger #512035)
(beardmore, et al., gcn circ. 12859).  the bat ground-calculated position is
ra, dec = 120.029, -9.076 deg which is
    ra(j2000)  =  08h 00m 06.9s
    dec(j2000) = -09d 04' 35.3"
with an uncertainty of 1.0 arcmin, (radius, sys+stat, 90% containment).
the partial coding was 100%.

the mask-weighted light curve shows a series of three overla

In [8]:
'''
Create a list for our features.
Each element in the list shall be a tuple of three elements.
The first element will be the question.
The second element will be the circular text.
The third element will be the answer text supposedly contained in the circular.
Two tuples are added per circular: one asks for the redshift value, the other for the GRB name.
'''

data = [] #List of tuples
circular_num_format = re.compile(r'number:  (\d+)')

for circular in circulars:
  circ_num_match = circular_num_format.search(circular)
  if circ_num_match is None:
    continue
  circ_num = circ_num_match.group(1) #Extract circular number from text

  #Find corresponding circular number in redshift dict
  #Append its redshift value, along with the GCN text to our data list
  if circ_num in redshift_dict:
    redshift_text, grb_name = redshift_dict[circ_num]
    data.append(('what is the redshift value?', circular, redshift_text))
    data.append(('what is the grb name?', circular, grb_name))

print(f"Current no. of data points: {len(data)}")
print(f"Question: {data[0][0]}")
print(f"Context: {data[0][1]}")
print(f"Answer: {data[0][2]}")
print(f"Question 2: {data[1][0]}")
print(f"Answer 2: {data[1][2]}")

Current no. of data points: 1092
Question: what is the redshift value?
Context: title:   gcn circular
number:  9222
subject: grb 090423: refined tng analysis
date:    09/04/24 14:16:29 gmt
from:    cristiano guidorzi at ferrara u,italy  <guidorzi@fe.infn.it>

a. fernandez-soto (ifca-santander), f. mannucci (inaf-oaa), d. fugazza 
(inaf-oab), l.a. antonelli (inaf-oar), s. campana (inaf-oab), g. 
chincarini (univ. bicocca), s. covino (inaf-oab), p. d'avanzo 
(inaf-oab/u. bicocca), v. d'elia (inaf-oar), m. della valle 
(inaf-oaca/eso), a. fiorenzano (tng), c. guidorzi (univ. ferrara), e. 
maiorano (iasf-bo), j. mao (inaf-oab), r.  margutti (inaf-oab/univ. 
bicocca), s. marinoni (tng), e. palazzi (iasf-bo), c. c. thoene (inaf-oab)
report, on behalf of a larger collaboration (cibo):

we have performed an in-depth analysis of the tng spectrum of grb090423 
taken on apr 23 at 22:16 ut with the nics/amici combination (thoene et 
al, gcn 9216). we have corrected the wavelength calibration, now 

In [9]:
'''
Make sure the redshift text is actually in the circulars
'''
del_list = []
del_idx = []
for i, tuple_ in enumerate(data):
  if tuple_[2] not in tuple_[1]:
    del_list.append(tuple_)
    del_idx.append(i)

print(f"Previous no. of data points: {len(data)}")
print(f"No. of points to be removed: {len(del_list)}")
data = [tuple_ for i,tuple_ in enumerate(data) if i not in del_idx]
print(f"Final no. of data points: {len(data)}")

Previous no. of data points: 1092
No. of points to be removed: 56
Final no. of data points: 1036


In [10]:
'''
Write data to csv file.
'''

data = [(index, *row) for index, row in enumerate(data)]
with open('preprocessed_redshift_data.csv', 'w', newline='') as f: #newline='' is recommended by the csv module
    writer = csv.writer(f)
    writer.writerow(['id', 'question', 'context', 'answer'])
    writer.writerows(data)
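
The context column holds whole circulars, so each field contains line breaks and commas. The sketch below checks, on a single made-up row, that the `csv` module quotes such a field and restores it unchanged when read back. It writes to an in-memory `io.StringIO` buffer instead of a file, and the row content is illustrative rather than taken from the dataset.

```python
import csv
import io

# One illustrative row: (id, question, context, answer).
rows = [(0, 'what is the redshift value?', 'line one\nz=0.111, line two', '0.111')]

# An in-memory buffer stands in for the csv file.
buffer = io.StringIO(newline='')
writer = csv.writer(buffer)
writer.writerow(['id', 'question', 'context', 'answer'])
writer.writerows(rows)

# Read the rows back and compare the multi-line field.
buffer.seek(0)
records = list(csv.DictReader(buffer))
print(records[0]['context'] == rows[0][2])
```

Note that every value comes back as a string, including `id`.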

# Step 3: Fine-Tuning our Model

In [11]:
'''
Download and import necessary Huggingface libraries.
'''

!pip install transformers datasets evaluate accelerate

from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DefaultDataCollator
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
from transformers import pipeline
import evaluate
from tqdm.auto import tqdm
import numpy as np
import collections

Collecting transformers
  Downloading transformers-4.33.1-py3-none-any.whl (7.6 MB)
Collecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
Collecting accelerate
  Downloading accelerate-0.22.0-py3-none-any.whl (251 kB)
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers)
  Downloading huggingface_hub-0.17.1-py3-none-any.whl (294 kB)

In [12]:
'''
Create a Huggingface DatasetDict object to store the training, validation and test sets.
We use an 80/10/10 split.
Fixed seed is used for reproducibility.
'''

dataset = load_dataset("csv", data_files="preprocessed_redshift_data.csv", split='train[:]')
dataset = dataset.train_test_split(test_size=0.2, seed=42)
temp_dataset = dataset.pop("test")
temp_dict = temp_dataset.train_test_split(test_size=0.5, seed=42)
dataset["validation"] = temp_dict["train"]
dataset["test"] = temp_dict["test"]
dataset

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'context', 'answer'],
        num_rows: 828
    })
    validation: Dataset({
        features: ['id', 'question', 'context', 'answer'],
        num_rows: 104
    })
    test: Dataset({
        features: ['id', 'question', 'context', 'answer'],
        num_rows: 104
    })
})
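
The row counts above follow from applying the two splits in sequence to the 1036 data points. The arithmetic can be sketched as below, under the assumption that the held-out portion is rounded up and the remainder goes to the other split; `split_sizes` is a hypothetical helper, not a `datasets` function.

```python
import math

def split_sizes(n_rows, test_fraction=0.2, holdout_fraction=0.5):
    # Assumption: the held-out count is rounded up, the rest stays in train.
    n_holdout = math.ceil(n_rows * test_fraction)
    n_train = n_rows - n_holdout
    # The held-out rows are then halved into validation and test.
    n_test = math.ceil(n_holdout * holdout_fraction)
    n_validation = n_holdout - n_test
    return n_train, n_validation, n_test

print(split_sizes(1036))  # (828, 104, 104)
```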

In [13]:
dataset["train"][0]

{'id': 464,
 'question': 'what is the grb name?',
 'context': 'title:   gcn circular\nnumber:  5952\nsubject: grb 061201: magellan redshift of nearby galaxy\ndate:    06/12/21 21:13:16 gmt\nfrom:    edo berger at carnegie obs  <eberger@ociw.edu>\n\ne. berger (carnegie) reports:\n\n"starting on 2006 dec. 21.07 ut we used ldss3 on magellan to obtain an \n1800 sec spectrum of the galaxy located 17 arcsec nw of the optical \nafterglow of the short grb 061201 (see also gcn #5884).  this is the \nnearest bright galaxy (r~19 mag) to the grb position.  we detect several \nemission lines, which we identify as h-beta, [oiii], h-alpha, [nii], and \n[sii] at a redshift of z=0.111.  at this redshift the projected offset of \nthe burst is about 34 kpc, significantly smaller than about 1.9 mpc \nrelative to the center of abell 995 (gcn #5944)."\n\n',
 'answer': '061201'}

In [14]:
dataset["validation"][0]

{'id': 420,
 'question': 'what is the redshift value?',
 'context': 'title:   gcn circular\nnumber:  6663\nsubject: grb 060814 - keck host detection and redshift\ndate:    07/07/24 21:18:06 gmt\nfrom:    christina thoene at niels bohr institute,dark cosmo ctr  <cthoene@astro.ku.dk>\n\nchristina c. thoene (dark/uc berkeley), daniel a. perley and j. s. bloom\n(uc berkeley) report:\n\non 2007 april 15 (ut) we imaged the field of grb 060814 with keck i 10m\ntelescope + lris for 750s in i and 840s in v under poor seeing conditions.\nthe host galaxy reported by malesani et al. (gcn 5456) and others is\nwell-detected at a position consistent with the locations of the ir (levan\net al., gcn 5455) and x-ray (gcn 5451; [1]) transients. a finding chart of\nthe field can be found at:\n\nhttp://lyra.berkeley.edu/~dperley/060814/060814_lris_v.png\n\nwe also took spectra of the host galaxy on july 18 with lris, using grism\n600/4000 and dicroic 560, which covers the wavelength range between 5500\nand

In [15]:
dataset["test"][0]

{'id': 613,
 'question': 'what is the grb name?',
 'context': "title:   gcn circular\nnumber:  24916\nsubject: grb 190627a: vlt/fors2 spectroscopic redshift\ndate:    19/06/30 10:01:00 gmt\nfrom:    jure japelj at api,u of amsterdam  <japelj.jure@gmail.com>\n\nj. japelj (uni. amsterdam), d. a. kann (heth/iaa-csic), a. de ugarte postigo\n(heth/iaa-csic, dark/nbi), l. izzo (heth/iaa-csic), j. p. u. fynbo (dawn/nbi\nand dawn/dtu), d. b. malesani (dtu space), v. d'elia (ssdc), n. r. tanvir\n(univ. leicester), s. d. vergani (cnrs -gepi/observatorie de paris), g.\npugliese,\nl. kaper (uni. amsterdam) report on behalf of the stargate collaboration:\n\nwe observed the optical counterpart (siegel et al., gcn 24889; pozanenko\net al., gcn 24892) of grb 190627a (sonbas et al., gcn 24888) with the eso\nvlt ut1 equipped with the fors2 spectrograph. we obtained a 30 min spectrum\nwith the 600ri (512 - 845 nm) and a 30 min spectrum with the 600z (737-1070\nnm)\ngrism. observations started at 01:12:41

In [16]:
'''
Here is where we select the model and tokenizer to be used for question answering.
The checkpoint name is the name of the model as stored in the Huggingface Model Hub.
'''

checkpoint = "deepset/deberta-v3-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/379 [00:00<?, ?B/s]

Downloading spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/8.65M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/23.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/992 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/735M [00:00<?, ?B/s]

In [17]:
'''
Let's get some preliminary inferences to see how the model performs before training.
We'll use an example from the validation set.
'''

question = dataset["validation"][0]["question"]
context = dataset["validation"][0]["context"]
answer = dataset["validation"][0]["answer"]
answer #Real answer

'0.84'

In [18]:
question_answerer = pipeline("question-answering", model=model, tokenizer=tokenizer)
question_answerer(question=question, context=context) #Predicted answer

{'score': 0.9180389046669006, 'start': 1025, 'end': 1032, 'answer': ' z=0.84'}

Now let's evaluate the model's performance on our validation set before fine-tuning it.

In [19]:
'''
We start by preprocessing the validation data.
Code adapted from Huggingface NLP course page on Question Answering
'''

def preprocess_validation_examples(examples):
    inputs = tokenizer(
        examples["question"],
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs
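
The tokenizer call above splits long circulars into overlapping windows: `max_length=384` caps the length of each feature and `stride=128` is the number of tokens shared by consecutive windows. The sketch below illustrates that windowing on plain token positions. It ignores the question and special tokens that also occupy part of each window, and `sliding_windows` is a hypothetical helper, not the tokenizer's actual implementation.

```python
def sliding_windows(n_tokens, window, overlap):
    """Return (start, end) token spans that cover n_tokens with a fixed overlap."""
    if overlap >= window:
        raise ValueError("overlap must be smaller than the window size")
    step = window - overlap
    spans = []
    start = 0
    while True:
        end = min(start + window, n_tokens)
        spans.append((start, end))
        if end == n_tokens:  # the last window reaches the end of the text
            break
        start += step
    return spans

# A context of 900 tokens with the window and overlap used above.
print(sliding_windows(900, 384, 128))
```

One circular can therefore produce several features, which is why `overflow_to_sample_mapping` is needed to trace each feature back to its example `id`.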

In [20]:
'''
Apply the above preprocessing function to the whole validation set using map()
'''

validation_dataset = dataset["validation"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=dataset["validation"].column_names,
)

Map:   0%|          | 0/104 [00:00<?, ? examples/s]

In [21]:
'''
Remove extra columns and plug dataset into model to get output logits.
'''

eval_set_for_model = validation_dataset.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

batch = {k: eval_set_for_model[k] for k in eval_set_for_model.column_names}

with torch.no_grad():
    outputs = model(**batch)

start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

In [22]:
'''
The compute_metrics() function will compute the f1 and exact match score for a model
'''

def compute_metrics(start_logits, end_logits, features, examples):
    metric = evaluate.load("squad")
    n_best = 20
    max_answer_length = 30
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": str(example_id), "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": str(example_id), "prediction_text": ""})

    theoretical_answers = [{"id": str(ex["id"]),
                            "answers": {"text": [ex["answer"]],
                                        "answer_start": [ex["context"].find(ex["answer"])]}} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)
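
The `squad` metric loaded above reports exact match and a token-level F1. The sketch below is a simplified re-implementation for intuition only; it assumes the usual SQuAD v1 normalization (lowercase, strip punctuation and articles, collapse whitespace) and is not the code that `evaluate` runs. The function names are introduced here for illustration.

```python
import collections
import re
import string

def normalize(text):
    # Lowercase, drop punctuation and articles, collapse whitespace.
    text = text.lower()
    text = ''.join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())

def exact_match(prediction, truth):
    return float(normalize(prediction) == normalize(truth))

def f1_score(prediction, truth):
    pred_tokens = normalize(prediction).split()
    true_tokens = normalize(truth).split()
    # Multiset intersection counts the tokens shared by both answers.
    common = collections.Counter(pred_tokens) & collections.Counter(true_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return 2 * precision * recall / (precision + recall)

print(exact_match(' 130603b:', '130603b'))   # punctuation is ignored
print(exact_match(' z=0.84', '0.84'))        # the extra "z" breaks the match
print(f1_score('grb 130603b', '130603b'))    # partial credit for overlap
```

Under this normalization a span such as ` z=0.84` counts as a miss for the answer `0.84`, which is one way an untuned model can score low on exact match while still pointing at the right place in the text.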

In [23]:
pre_training_scores = compute_metrics(start_logits, end_logits, validation_dataset, dataset["validation"])
pre_training_scores

Downloading builder script:   0%|          | 0.00/4.53k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.32k [00:00<?, ?B/s]

  0%|          | 0/104 [00:00<?, ?it/s]

{'exact_match': 26.923076923076923, 'f1': 57.532051282051256}

An exact match of about 27% is poor, but expected, since the model has not been fine-tuned on GCN Circulars yet. Let's fine-tune our model now.

In [24]:
'''
This function is for tokenizing and preprocessing the training data.
It is similar to the validation preprocessor, except that we also compute the labels
(start and end token positions of the answer) here.
'''

def preprocess_training_examples(examples):
    inputs = tokenizer(
        examples["question"],
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answer"]
    contexts = examples["context"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        context = contexts[sample_idx]
        start_char = context.find(answer)
        end_char = start_char + len(answer)
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
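
The core of the function above is the conversion from a character span (where the answer sits in the context) to a token span (the labels the model is trained on). The sketch below isolates that step using the same two scans. The offsets are hand-written for a short example sentence rather than produced by the tokenizer, and `char_span_to_token_span` is a hypothetical helper.

```python
def char_span_to_token_span(offsets, start_char, end_char):
    """Map a character span to (start_token, end_token) given per-token offsets."""
    # If the answer is not fully inside this window, the label is (0, 0).
    if offsets[0][0] > start_char or offsets[-1][1] < end_char:
        return (0, 0)
    # Scan forward to the last token that starts at or before the answer start.
    idx = 0
    while idx < len(offsets) and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1
    # Scan backward to the first token that ends at or after the answer end.
    idx = len(offsets) - 1
    while idx >= 0 and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return (start_token, end_token)

text = "redshift of z=0.111 here"
# Hand-written (start, end) character offsets, one pair per illustrative token:
# "redshift", "of", "z", "=", "0.111", "here"
offsets = [(0, 8), (9, 11), (12, 13), (13, 14), (14, 19), (20, 24)]

start_char = text.find("0.111")
end_char = start_char + len("0.111")
print(char_span_to_token_span(offsets, start_char, end_char))
```

When a long circular is split into several windows, only the windows that contain the whole answer receive real labels; the others get `(0, 0)`.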

In [25]:
tokenized_dataset = dataset["train"].map(preprocess_training_examples, batched=True, remove_columns=dataset["train"].column_names)

Map:   0%|          | 0/828 [00:00<?, ? examples/s]

In [26]:
torch.cuda.empty_cache()

In [27]:
'''
Training our model.
'''
torch.cuda.empty_cache()
training_args = TrainingArguments(
    output_dir="trained_results",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
)

trainer.train()

You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss


TrainOutput(global_step=468, training_loss=0.20727547213562533, metrics={'train_runtime': 246.2297, 'train_samples_per_second': 15.144, 'train_steps_per_second': 1.901, 'total_flos': 730794901123584.0, 'train_loss': 0.20727547213562533, 'epoch': 3.0})

In [28]:
model.save_pretrained("trained_model")
model = AutoModelForQuestionAnswering.from_pretrained("trained_model")

# Step 4: Inference and Evaluation after Fine-Tuning

In [38]:
'''
Let's now test the fine-tuned model on an example from the validation set.
'''

question = dataset["validation"][2]["question"]
context = dataset["validation"][2]["context"]
answer = dataset["validation"][2]["answer"]
answer #Real answer

'130603b'

In [39]:
question_answerer = pipeline("question-answering", model=model, tokenizer=tokenizer)
question_answerer(question=question, context=context) #Predicted answer

{'score': 0.9999740123748779, 'start': 49, 'end': 58, 'answer': ' 130603b:'}

In [41]:
print(question)
print(context)

what is the grb name?
title:   gcn circular
number:  14757
subject: grb 130603b: vlt/x-shooter redshift confirmation
date:    13/06/04 10:38:37 gmt
from:    dong xu at dark/nbi  <dong.dark@gmail.com>

d. xu (dark/nbi), a. de ugarte postigo (iaa-csic, dark/nbi), d.
malesani (dark/nbi), s. schulze (puc and mcss), j. p. u. fynbo, d. j.
watson (dark/nbi), v. d'elia (asi-sdc, inaf oar), p. goldoni (apc,
cea/irfu), m. vestergaard (dark/nbi) report on behalf of the x-shooter
grb gto collaboration:

we observed the optical afterglow of the short-duration grb 130603b
(melandri et al., gcn 14735; levan et al., gcn 14742; de ugarte
postigo et al., gcn 14743) using the eso vlt equipped with the
x-shooter spectrograph. the observations started on 2013-06-04 at
00:00:28 ut (i.e., 8.187 hr after the burst). a total exposure of
4x600 s was obtained, covering the spectral range from ~300 to ~2100
nm.

a continuum is detected in all the uvb/vis/nir arms of the spectra. we
identify several absorption fea

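The fine-tuned model finds the GRB name with a score close to 1 (the pipeline's answer span also includes the adjacent space and colon). Now let's get the evaluation metrics.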

In [32]:
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
post_training_scores = compute_metrics(start_logits, end_logits, validation_dataset, dataset["validation"])
post_training_scores

  0%|          | 0/104 [00:00<?, ?it/s]

{'exact_match': 100.0, 'f1': 100.0}

This is a lot better. We have successfully fine-tuned our model.

From our investigations, the best model appears to be "deepset/deberta-v3-base-squad2". Let's test its performance on the test set now.

In [33]:
test_dataset = dataset["test"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=dataset["test"].column_names,
)

predictions, _, _ = trainer.predict(test_dataset)
start_logits, end_logits = predictions
post_training_scores = compute_metrics(start_logits, end_logits, test_dataset, dataset["test"])
post_training_scores

Map:   0%|          | 0/104 [00:00<?, ? examples/s]

  0%|          | 0/104 [00:00<?, ?it/s]

{'exact_match': 97.11538461538461, 'f1': 97.35576923076923}

That is an exact match of approximately 97% (and an F1 of about 97.4%) on the held-out test set!
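
Since both evaluation sets hold 104 examples, the reported percentages can be turned back into counts of exactly matched answers. A small sketch of that arithmetic follows; `matches_from_percentage` is a hypothetical helper.

```python
def matches_from_percentage(exact_match_percent, n_examples):
    # Convert an exact-match percentage back to a whole number of examples.
    return round(exact_match_percent * n_examples / 100)

# Validation set before fine-tuning, and test set after fine-tuning.
print(matches_from_percentage(26.923076923076923, 104))  # 28
print(matches_from_percentage(97.11538461538461, 104))   # 101
```

In other words, 28 of 104 validation answers matched exactly before fine-tuning, and 101 of 104 test answers matched after it.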