<a href="https://colab.research.google.com/github/Vidushi-GitHub/gcn-nlp-test/blob/main/InformationExtractor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Information Extraction Model for NASA GCN Project**
The goal of this project is to extract key information (the redshift, GRB name and telescope) from NASA GCN Circulars in order to automate data entry into the archive, which is currently done manually.

# Step 1: Import Libraries

In [53]:
'''
Author: Ronit Agarwala
Last Modified: 09/12/2023
'''

import pandas as pd
import tarfile
import os
import re
import csv
import torch

# Step 2: Extract Data

In [54]:
'''
Extract table data from csv file.
This will contain our labels for the model.
'''

df = pd.read_csv('swift_redshift_data.csv', skiprows=0, header=0)
print(df)

         GRB  Time[UT] TriggerNumber        BAT RA(J2000)  \
0    230818A  23:27:34       1186032  285.888\n19:03:33.1   
1    230506C  17:09:19       1167288  134.371\n08:57:29.0   
2    230414B  16:14:21       1164180  181.052\n12:04:12.5   
3    230328B  14:54:48       1162001  291.037\n19:24:08.9   
4    230325A  03:15:35       1161390  296.854\n19:47:25.0   
..       ...       ...           ...                  ...   
415   050319  09:31:18        111622  154.172\n10:16:41.3   
416   050318  15:44:37        111529   49.695\n03:18:46.8   
417   050315  20:59:43        111063  306.476\n20:25:54.2   
418   050223  03:09:06        106709  271.394\n18:05:34.6   
419   050126  12:00:54        103780  278.115\n18:32:27.6   

           BAT Dec(J2000) BAT T90[sec]  BAT Fluence(15-150 keV)[10-7 erg/cm2]  \
0      40.888\n40:53:16.8         9.82                                  19.00   
1      45.131\n45:07:51.6        31.00                                  17.00   
2      53.179\n53:10:44.

In [55]:
'''
Extract GCN Circulars from .tar.gz file to a new unzipped folder.
Store the path for each circular in the unzipped folder in list 'dir'.
'''

#The context manager closes the archive even if extraction fails
with tarfile.open('./all_gcn_circulars.tar.gz') as file:
  file.extractall('./all_gcn_circulars')

#Store the full path of every circular in dir
#(note: the name `dir` shadows Python's built-in dir())
circ_folder = './all_gcn_circulars/gcn3'
dir = [os.path.join(circ_folder, name) for name in os.listdir(circ_folder)]

print(f"Number of circulars: {len(dir)}")
print(f"First circular path: {dir[0]}")

Number of circulars: 33653
First circular path: ./all_gcn_circulars/gcn3/17093.gcn3
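Extracting the archive writes all 33,653 circulars to disk before reading them back. As an alternative, here is a minimal sketch that streams the members straight out of the archive with `tarfile`'s `extractfile()`. The function names `iter_circulars` and `make_demo_archive`, the `.gcn3` suffix filter and the tiny in-memory demo archive are illustrative assumptions, not part of the original pipeline. (If `extractall()` is kept instead, Python 3.12+ also accepts `filter="data"` to reject unsafe member paths.)

```python
import io
import tarfile


def iter_circulars(fileobj, suffix=".gcn3", encoding="ISO-8859-1"):
    """Yield (member name, lower-cased text) for each circular in a tar archive.

    `fileobj` is any binary file object; mode "r:*" lets tarfile detect gzip.
    """
    with tarfile.open(fileobj=fileobj, mode="r:*") as tar:
        for member in tar:
            if member.isfile() and member.name.endswith(suffix):
                raw = tar.extractfile(member).read()
                yield member.name, raw.decode(encoding).lower()


def make_demo_archive(files):
    """Build a small in-memory .tar.gz from {name: text} (hypothetical demo data)."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
        for name, text in files.items():
            data = text.encode("ISO-8859-1")
            info = tarfile.TarInfo(name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    buffer.seek(0)
    return buffer


demo = make_demo_archive({
    "gcn3/17093.gcn3": "TITLE:   GCN CIRCULAR\nNUMBER:  17093\n",
    "gcn3/README.txt": "not a circular",
})
circulars_demo = list(iter_circulars(demo))
print(circulars_demo)
```

On the real data, `iter_circulars(open('./all_gcn_circulars.tar.gz', 'rb'))` would yield the same lower-cased texts without creating the `gcn3` folder.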


In [56]:
'''
Iterate through each path in dir, and store the text in lower case.
This will contain our features for the model.
'''

circulars = [] #List to store the text of each GCN
full_text = '' #String to store entire corpus of data
for gcn in dir:
  with open(gcn, encoding = "ISO-8859-1") as f:
    file_str = f.read().lower()
    circulars.append(file_str)
    full_text += file_str

print(circulars[0]) #Print first circular

title:   gcn circular
number:  17093
subject: grb 141121a: mondy optical observations
date:    14/11/24 16:56:28 gmt
from:    alexei pozanenko at iki, moscow  <apozanen@iki.rssi.ru>

e. mazaeva (iki), e. klunko (istp), a. volnova (iki), m. eselevich 
(istp), i. korobtsev (istp), a. pozanenko (iki) report on behalf of 
larger grb follow-up collaboration:

we observed the field of grb 141121a (lien et al., gcn 17075) with 
azt-33ik telescope of sayan observatory (mondy) on nov., 23  starting on 
(ut) 20:57:39.  we obtained several  images in r-filter. in a combined 
image we clearly  detect optical  afterglow (tanga et al., gcn 17078; 
perley et al., gcn 17081).

a preliminary photometry is based on nearby sdss  stars:

date       ut start   t-t0     filter   exp.    ot    ot_err
                               (mid, days)       (s)

2014-11-23 20:57:39   2.74025  r        39*120  20.53 0.04

we confirm re-brightening of the afterglow reported early (watson et al. 
gcn 17090; kuroda et al
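Later cells find the circular number with the regex `number:  (\d+)`, which requires exactly two spaces after the colon. A more forgiving option, sketched below under the assumption that every circular starts with `key: value` header lines that end at the first blank line (as in the example above), is to parse the whole header into a dict. `parse_header` is a hypothetical helper.

```python
import re

# One header line: a lower-case key, a colon, then the value
HEADER_LINE = re.compile(r"^([a-z]+):\s*(.*)$")


def parse_header(circular):
    """Return the header fields of a lower-cased circular as a dict.

    Assumes the header is the block of "key: value" lines before the first blank line.
    """
    header = {}
    for line in circular.splitlines():
        if not line.strip():  # the first blank line ends the header
            break
        match = HEADER_LINE.match(line)
        if match:
            header[match.group(1)] = match.group(2).strip()
    return header


sample = (
    "title:   gcn circular\n"
    "number:  17093\n"
    "subject: grb 141121a: mondy optical observations\n"
    "date:    14/11/24 16:56:28 gmt\n"
    "from:    alexei pozanenko at iki, moscow  <apozanen@iki.rssi.ru>\n"
    "\n"
    "we observed the field of grb 141121a (lien et al., gcn 17075) with\n"
)
print(parse_header(sample)["number"])  # 17093
```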

In [57]:
'''
Get a smaller dataframe with the relevant columns only.
This section converts the Swift GRB table to a Python dict.
The end result is a dict with key = GCN No. and value = (redshift text, GRB name, instrument name).
'''

sub_df = df[['References', 'Redshift', 'GRB']].copy()
removed_list = []
removed_indices=[]
gcn_format = re.compile(r'GCN ?\d+( *\(.*?\))?')
gcn_number_format = re.compile(r'GCN ?(\d+)')
bracket_format = re.compile(r' ?\([^)]*\)')
instrument_format = re.compile(r'\((.*?)\)')
instrument_format_1 = re.compile(r'\(([^):]*)')

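#Split the multi-line cells into lists of lines.
#Assigning to `row` inside iterrows() is not guaranteed to write back to the
#DataFrame, so the columns are transformed explicitly instead.
sub_df['Redshift'] = sub_df['Redshift'].apply(str.splitlines)

def keep_redshift_reference(references):
  #Return the last line that starts with 'Redshift', or all lines if there is none
  lines = references.splitlines()
  redshift_lines = [line for line in lines if line.startswith('Redshift')]
  return redshift_lines[-1] if redshift_lines else lines

sub_df['References'] = sub_df['References'].apply(keep_redshift_reference)

for index, row in sub_df.iterrows():
  if isinstance(row['References'], list): #Drop the non-uniform data for now
    removed_list.append((row['References'], row['Redshift'], row['GRB']))
    removed_indices.append(index)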

red_list = [] #List of tuples
sub_df = sub_df.drop(removed_indices)

#Create a list of tuples
#Each tuple contains a GCN No. with its corresponding redshift text, GRB name and instrument name
for index, row in sub_df.iterrows():
  gcn_iter = gcn_format.finditer(row['References'])
  gcns = list(gcn_iter)

  #References with only one circular don't need any further iteration
  if len(gcns) == 1 and len(row['Redshift']) == 1:
    gcn_num_match = gcn_number_format.search(gcns[0].group())

    instrument_match = instrument_format.search(gcns[0].group())
    instrument_match_1 = instrument_format_1.search(row['Redshift'][0])
    instrument_name = ''
    if instrument_match is None and instrument_match_1 is None: #Drop references with no instrument name for now
      removed_list.append((row['References'], row['Redshift'], row['GRB']))
      continue

    elif instrument_match is not None:
      instrument_name = instrument_match.group(1)

    elif instrument_match_1 is not None:
      instrument_name = instrument_match_1.group(1)

    gcn_num = gcn_num_match.group(1)
    redshift_text = bracket_format.sub('', row['Redshift'][0])
    grb_name = row['GRB']

    if len(redshift_text.split(',')) != 1: #Drop the non-uniform data for now
      removed_list.append((row['References'], row['Redshift'], row['GRB']))
      continue

    red_list.append((gcn_num, redshift_text.strip().lower(), grb_name.strip().lower(), instrument_name.strip().lower()))

  #If multiple circulars are present, iterate through them all
  else:
    for gcn in gcns:
      for redshift in row['Redshift']:

        #Extract instrument name from reference
        instrument_match = instrument_format.search(gcn.group())
        if instrument_match is None:
          continue
        instrument_name = instrument_match.group(1)

        #Extract instrument name from redshift text
        instrument_match_1 = instrument_format_1.search(redshift)
        if instrument_match_1 is None:
          continue
        instrument_name_1 = instrument_match_1.group(1)

        #Check to see if both instruments match. If yes, append tuple to list
        if instrument_name == instrument_name_1:
          gcn_num_match = gcn_number_format.search(gcn.group())
          gcn_num = gcn_num_match.group(1)
          redshift_text = bracket_format.sub('', redshift)
          grb_name = row['GRB']
          red_list.append((gcn_num, redshift_text.strip().lower(), grb_name.strip().lower(), instrument_name.strip().lower()))

redshift_dict = {} #Dict for redshift data

rem_circ = ['21209', '12542', '11997', '6651', '5946'] #Circular numbers whose extracted text needs manual re-entry
for tuple_ in red_list:
  if tuple_[0] in rem_circ:
    removed_list.append(tuple_)
  else:
    redshift_dict[tuple_[0]] = (tuple_[1], tuple_[2], tuple_[3])

for item in removed_list:
  print(item)
print(len(removed_list))
print(len(redshift_dict))

(['BAT: GCN 29677; GCN 29691', 'XRT: GCN 29678; GCN 29681; GCN 29686; Evans et al., 2009, MNRAS, 397, 1177', 'UVOT: GCN 29679; GCN 29697', 'Radio: GCN 29685 (OSIRIS/GTC)'], ['1.487 (OSIRIS/GTC)'], '210321A')
(['BAT: GCN 19645; GCN 19648', 'XRT: GCN 19655; Evans et al., 2009, MNRAS, 397, 1177', 'UVOT: GCN 19645; GCN 19656', 'Radio: GCN 19849 (GMRT)'], ['The detection in all UVOT filters is consistent with a redshift of z < ~1.5'], '160703A')
(['BAT: GCN 18076; GCN 18086', 'XRT: GCN 18076; GCN 18079; GCN 18082; Evans et al., 2009, MNRAS, 397, 1177', 'UVOT: GCN 18076; GCN 18084', 'Radio: GCN 18080 (VLT)'], ['0.313 (VLT: emission)'], '150727A')
('Redshift: GCN 33609', ['Minimal possible redshift value of z = 0.09'], '230328B')
('Redshift: GCN 26538 (UVOT); GCN (VLT)', ['1.19 (UVOT), 1.148 (VLT: absorption)'], '191221B')
('21209', 'consistent with gtc results', '170531b', 'wht')
('12542', 'consistent with gemini-south', '111107a', 'vlt')
('11997', '1.61, pairitel', '110503a', 'tng')
('6651'
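To make the reference-parsing regexes above concrete, the sketch below runs them on the `References` and `Redshift` strings of GRB 191221B from the printed output. It mirrors only the single-circular branch and shows why that row was set aside: the `GCN (VLT)` reference has no circular number, and stripping the bracketed instrument names leaves two comma-separated redshifts.

```python
import re

# Same patterns as in the cell above (written as raw strings)
gcn_format = re.compile(r'GCN ?\d+( *\(.*?\))?')
gcn_number_format = re.compile(r'GCN ?(\d+)')
bracket_format = re.compile(r' ?\([^)]*\)')
instrument_format = re.compile(r'\((.*?)\)')
instrument_format_1 = re.compile(r'\(([^):]*)')

references = 'Redshift: GCN 26538 (UVOT); GCN (VLT)'
redshift = '1.19 (UVOT), 1.148 (VLT: absorption)'

# Only "GCN 26538 (UVOT)" matches: the VLT reference has no circular number
gcns = [m.group() for m in gcn_format.finditer(references)]
gcn_num = gcn_number_format.search(gcns[0]).group(1)
instrument = instrument_format.search(gcns[0]).group(1)

# Dropping the bracketed names leaves two values, so the row is set aside
redshift_text = bracket_format.sub('', redshift)

# instrument_format_1 stops at ':' while instrument_format keeps the whole bracket
short_name = instrument_format_1.search('0.313 (VLT: emission)').group(1)
long_name = instrument_format.search('0.313 (VLT: emission)').group(1)

print(gcns, gcn_num, instrument, redshift_text, short_name, long_name)
```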

In [58]:
'''
Add back removed data points after manual lookup
'''

redshift_dict['29685'] = ('1.487', '210321a', 'osiris/gtc')
redshift_dict['18080'] = ('0.313', '150727a', 'vlt')
redshift_dict['26538'] = ('1.19', '191221b', 'uvot')
redshift_dict['21209'] = ('consistent with that obtained by de ugarte-postigo et al. (gcn 21177)', '170531b', 'wht')
redshift_dict['12542'] = ('consistent with that reported by chornock et al. (gcn 12537)', '111107a', 'vlt')
redshift_dict['11997'] = ('1.61', '110503a', 'tng')
redshift_dict['6651'] = ('3.626', '070721b', 'vlt')
redshift_dict['5946'] = ('0.41', '061210', 'keck')
redshift_dict['19656'] = ('z < ~1.5', '160703a', 'uvot')
redshift_dict['5319'] = ('z < 2.3', '060708', 'vlt')
redshift_dict['17758'] = ('z < 3', '150424a', 'gtc')
redshift_dict['9797'] = ('0.696 < z < 2.2', '090814a', 'vlt')
redshift_dict['12202'] = ('1.036 < z < 2.7', '110726a', 'gemini-north')
redshift_dict['23537'] = ('z < 2.4', '181213a', 'not')

print(len(redshift_dict))

547


In [59]:
'''
Augment the dataset with synthetic circulars.
Each synthetic circular copies a real one, swaps its redshift text for a new value
and replaces its circular number with a fake one (1000001 and up).
'''

circular_num_format = re.compile(r'number:  (\d+)')

#Original circular no. -> (original redshift text, GRB name, telescope,
#                          [(new redshift text, fake circular no.), ...])
augmentations = {
  '9797':  ('0.696 < z < 2.2', '090814a', 'vlt',
            [('0.73 < z < 2.91', '1000001'), ('1.4 < z < 3.456', '1000002'), ('0.5 < z < 4.5', '1000003')]),
  '12202': ('1.036 < z < 2.7', '110726a', 'gemini-north',
            [('1.4 < z < 2.6', '1000004'), ('0.98 < z < 3.55', '1000005'), ('1.345 < z < 2.9', '1000006')]),
  '19656': ('z < ~1.5', '160703a', 'uvot',
            [('z < ~3.21', '1000007'), ('z > 3.56', '1000008'), ('z > ~2', '1000009')]),
  '5319':  ('z < 2.3', '060708', 'vlt', [('z < 4.32', '1000010')]),
  '17758': ('z < 3', '150424a', 'gtc', [('z < 1.4', '1000011')]),
  '23537': ('z < 2.4', '181213a', 'not', [('z < 4', '1000012')]),
}

fake_circulars = []
for circular in circulars:
  circ_num_match = circular_num_format.search(circular)
  if circ_num_match is None:
    continue
  circ_num = circ_num_match.group(1) #Extract circular number from text
  if circ_num not in augmentations:
    continue

  old_text, grb_name, telescope, variants = augmentations[circ_num]
  for new_text, fake_num in variants:
    #Swap the redshift text, then the circular number
    #(str.replace substitutes every occurrence of the number in the text)
    new_circ = circular.replace(old_text, new_text).replace(circ_num, fake_num)
    fake_circulars.append(new_circ)
    redshift_dict[fake_num] = (new_text, grb_name, telescope)

circulars.extend(fake_circulars)
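The augmentation above renames a circular with `str.replace`, which changes every occurrence of the number in the text, including an in-text mention or a longer number that happens to contain it (e.g. `19797`). A narrower alternative, sketched here under the assumption that the number only needs to change on the `number:` header line, uses an anchored regex. `renumber_circular` is a hypothetical helper, not part of the original code.

```python
import re


def renumber_circular(text, old_num, new_num):
    """Replace the circular number only on the 'number:' header line."""
    pattern = re.compile(r'^(number:\s+)' + re.escape(old_num) + r'\b', re.MULTILINE)
    # A function replacement avoids ambiguous backreferences such as '\11000001'
    return pattern.sub(lambda m: m.group(1) + new_num, text, count=1)


circular = 'title:   gcn circular\nnumber:  9797\nsubject: grb 090814a\n\nsee also gcn 9797.\n'
print(renumber_circular(circular, '9797', '1000001'))
print(circular.replace('9797', '1000001'))  # str.replace also rewrites the in-text mention
```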

In [60]:
print(circulars[0])

title:   gcn circular
number:  17093
subject: grb 141121a: mondy optical observations
date:    14/11/24 16:56:28 gmt
from:    alexei pozanenko at iki, moscow  <apozanen@iki.rssi.ru>

e. mazaeva (iki), e. klunko (istp), a. volnova (iki), m. eselevich 
(istp), i. korobtsev (istp), a. pozanenko (iki) report on behalf of 
larger grb follow-up collaboration:

we observed the field of grb 141121a (lien et al., gcn 17075) with 
azt-33ik telescope of sayan observatory (mondy) on nov., 23  starting on 
(ut) 20:57:39.  we obtained several  images in r-filter. in a combined 
image we clearly  detect optical  afterglow (tanga et al., gcn 17078; 
perley et al., gcn 17081).

a preliminary photometry is based on nearby sdss  stars:

date       ut start   t-t0     filter   exp.    ot    ot_err
                               (mid, days)       (s)

2014-11-23 20:57:39   2.74025  r        39*120  20.53 0.04

we confirm re-brightening of the afterglow reported early (watson et al. 
gcn 17090; kuroda et al

In [129]:
'''
Create a list for our features.
Each element in the list is a tuple of (question, circular text, answer text),
where the answer text is expected to appear verbatim in the circular.
Three questions are generated per circular: the redshift value, the GRB name
and the name of the telescope used.
'''

context_test = ''
data = [] #List of tuples
circular_num_format = re.compile(r'number:  (\d+)')

for circular in circulars:
  circ_num_match = circular_num_format.search(circular)
  if circ_num_match is None:
    continue
  circ_num = circ_num_match.group(1) #Extract circular number from text
  if circ_num == '1000012': #Keep one synthetic circular for a sanity check later
    context_test = circular
  #Look up the circular number in the redshift dict (a direct dict lookup
  #instead of scanning every key) and append one data point per question
  if circ_num in redshift_dict:
    redshift_text, grb_name, telescope = redshift_dict[circ_num]
    data.append(('what is the redshift value?', circular, redshift_text))
    data.append(('what is the grb name?', circular, grb_name))
    data.append(('what is the name of the telescope used?', circular, telescope))

print(f"Current no. of data points: {len(data)}")
print(f"Question: {data[0][0]}")
print(f"Context: {data[0][1]}")
print(f"Answer: {data[0][2]}")
print(f"Question 2: {data[1][0]}")
print(f"Answer 2: {data[1][2]}")
print(f"Question 3: {data[2][0]}")
print(f"Answer 3: {data[2][2]}")

Current no. of data points: 1671
Question: what is the redshift value?
Context: title:   gcn grb observation report
number:  5052
subject: grb060502: gemini spectroscopy
date:    06/05/02 10:13:13 gmt
from:    antonino cucchiara at psu  <cucchiara@astro.psu.edu>

a. cucchiara (penn state), p.a. price (ifa, hawaii), d.b. fox, (penn
state), s.b. cenko (srl, caltech) and b.p. schmidt (rsaa, anu) report on
behalf of a larger collaboration:

we have observed the optical afterglow of grb 060502 (la parola et al., 
gcn  5047, cenko et al., 5048)
with the gmos instrument on gemini north telescope.  observations 
consisted of 2 x
1800 sec exposures with the r400 grating, commencing at 2006 may 2.34
utc.  in the summed spectrum, we identify an absorption system
consisting of fe ii, mg ii and mg i at a redshift of z ~ 1.51.  no other
absorption or emission line systems are apparent.
 

we acknowledge the rapid response effort of gemini personnel that
yielded these data.


Answer: ~1.51
Question 2

In [130]:
print(context_test)

title:   gcn circular
number:  1000012
subject: grb 181213a: not optical observations
date:    18/12/14 10:01:41 gmt
from:    kasper elm heintz at univ. of iceland and dawn/nbi  <keh14@hi.is>

k. e. heintz (univ. of iceland), d. b. malesani (dawn/nbi/dtu and dark/nbi), j. p. u. fynbo (dawn/nbi/dtu), a. de ugarte postigo (heth/iaa-csic and dark/nbi), l. balaguer-nuñez, j. carbajo (dept. fqa, univ. de barcelona), f. galindo and c. perez (not), report on behalf of a larger collaboration:

we observed the optical afterglow (lipunov et al., gcn 23526; hu et al., gcn 23527; siegel et al., gcn 23529; belkin et al., gcn 23530) of grb 181213a (evans et al., gcn 23525) with the 2.5-m nordic optical telescope (not) equipped with alfosc.

spectroscopy was secured: we obtained 4x900s exposures using grism 4 (with a wavelength coverage from 350 to 900 nm) starting at 23:27:14 ut on december 13 (i.e. 10.5 hr after trigger). the observations were obtained at a seeing around 1″ but at an airmass > 3

In [63]:
# '''
# Add data points where no answer is expected
# '''

# for circular in circulars[:500]:
#   circ_num_match = circular_num_format.search(circular)
#   if circ_num_match == None:
#     continue
#   circ_num = circ_num_match.group(1) #Extract circular number from text

#   #Find corresponding circular number in redshift dict
#   #Append its redshift value, along with the GCN text to our data list
#   if circ_num not in redshift_dict.keys():
#     if circular.find('redshift') == -1:
#       data.append(('what is the redshift value?', circular, ''))
#       print(data[-1])

In [64]:
'''
Make sure each answer (redshift, GRB name or telescope) actually appears verbatim
in its circular; extractive QA can only predict spans of the context.
'''
del_list = [tuple_ for tuple_ in data if tuple_[2] not in tuple_[1]]

print(f"Previous no. of data points: {len(data)}")
print(f"No. of points to be removed: {len(del_list)}")
data = [tuple_ for tuple_ in data if tuple_[2] in tuple_[1]]
print(f"Final no. of data points: {len(data)}")

Previous no. of data points: 1671
No. of points to be removed: 98
Final no. of data points: 1573
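The check above only asks whether the answer occurs somewhere in the circular, while the labelling code later uses `context.find(answer)`, i.e. the first occurrence. Short answers such as a telescope name or a GRB identifier can occur several times, and the first hit need not be the intended one. A small sketch for flagging such ambiguous data points (`answer_occurrences` is a hypothetical helper):

```python
def answer_occurrences(context, answer):
    """Return the start offsets of every (possibly overlapping) occurrence of answer."""
    if not answer:
        return []
    starts = []
    start = context.find(answer)
    while start != -1:
        starts.append(start)
        start = context.find(answer, start + 1)
    return starts


context = 'we observed grb 070306 with the vlt. the vlt spectrum of grb 070306 shows no lines.'
print(answer_occurrences(context, 'vlt'))
```

Applied to the notebook's `data`, `[t for t in data if len(answer_occurrences(t[1], t[2])) > 1]` would list the points whose label position depends on `find()` picking the right occurrence.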


In [65]:
'''
Write data to csv file.
'''

data = [(index, *row) for index, row in enumerate(data)]
with open('preprocessed_redshift_data.csv', 'w', newline='', encoding='utf-8') as f: #newline='' as recommended by the csv module docs
    writer = csv.writer(f)
    writer.writerow(['id', 'question', 'context', 'answer'])
    writer.writerows(data)
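Circular texts contain embedded newlines. `csv.writer` wraps such fields in quotes, and the `csv` module documentation recommends opening files with `newline=''` so those newlines survive a write/read round trip. A small in-memory check, using `io.StringIO` in place of the real file and a made-up row:

```python
import csv
import io

rows = [(0, 'what is the redshift value?',
         'title:   gcn circular\nnumber:  11230\n\nat a common redshift of z=1.727.\n',
         '1.727')]

buffer = io.StringIO(newline='')  # stands in for open(..., 'w', newline='')
writer = csv.writer(buffer)
writer.writerow(['id', 'question', 'context', 'answer'])
writer.writerows(rows)

buffer.seek(0)
read_back = list(csv.reader(buffer))
print(read_back[1][2] == rows[0][2])  # True: the multi-line context is intact
```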

# Step 3: Fine-Tuning our Model

In [66]:
'''
Download and import necessary Huggingface libraries.
'''

!pip install transformers datasets evaluate accelerate

from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DefaultDataCollator
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
from transformers import pipeline
import evaluate
from tqdm.auto import tqdm
import numpy as np
import collections



In [67]:
'''
Create a Huggingface DatasetDict object to store the training, validation and test sets.
We use an 80/10/10 split.
A fixed seed is used for reproducibility.
'''

dataset = load_dataset("csv", data_files="preprocessed_redshift_data.csv", split='train[:]')
dataset = dataset.train_test_split(test_size=0.2, seed=42)
temp_dataset = dataset.pop("test")
temp_dict = temp_dataset.train_test_split(test_size=0.5, seed=42)
dataset["validation"] = temp_dict["train"]
dataset["test"] = temp_dict["test"]
dataset

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'context', 'answer'],
        num_rows: 1258
    })
    validation: Dataset({
        features: ['id', 'question', 'context', 'answer'],
        num_rows: 157
    })
    test: Dataset({
        features: ['id', 'question', 'context', 'answer'],
        num_rows: 158
    })
})
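The printed sizes can be checked by hand. Assuming `train_test_split` rounds the test share up (`ceil`) and gives the remainder to the train share, which reproduces the numbers above, the 1,573 data points split as follows. `split_sizes` is a hypothetical helper for this arithmetic only.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size, rounding the test share up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


n_train, n_holdout = split_sizes(1573, 0.2)         # 80% train / 20% held out
n_validation, n_test = split_sizes(n_holdout, 0.5)  # held-out part split in two
print(n_train, n_validation, n_test)  # 1258 157 158
```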

In [68]:
dataset["train"][0]

{'id': 53,
 'question': 'what is the name of the telescope used?',
 'context': 'title:   gcn circular\nnumber:  14291\nsubject: grb 100424a: keck host detection and vlt/x-shooter redshift\ndate:    13/03/12 16:07:54 gmt\nfrom:    daniele malesani at dark cosmology centre, niels bohr inst  <malesani@dark-cosmology.dk>\n\nd. malesani, d. xu, j. p. u. fynbo, t. kruehler (dark/nbi), d. a. perley \n(caltech), s. d. vergani (cnrs/gepi), p. goldoni (apc/irfu-cea), report \non behalf of the grb gto x-shooter collaboration:\n\nwe observed the field of grb 100424a (hoversten et al., gcn 10667) using \nthe keck-i telescope equipped with the lris instrument. observations \nwere carried out on 2010 july 8.2 ut (74.5 days after the burst), \nsimultaneously in the g and i bands, for a total exposure time of 35 and \n32 min, respectively.\n\nconsistent with the position of the nir afterglow (cenko et al., gcns \n10682, 10690, 10692), we detected a source with g = 26 (ab) and i = 24.4 \n(vega). we cons

In [69]:
dataset["validation"][0]

{'id': 1155,
 'question': 'what is the grb name?',
 'context': 'title:   gcn circular\nnumber:  6202\nsubject: grb 070306: possible emission-line redshift\ndate:    07/03/12 20:37:34 gmt\nfrom:    andreas o. jaunsen at ita/u oslo  <ajaunsen@astro.uio.no>\n\na.o. jaunsen (univ. oslo), c.c. thoene, j.p.u. fynbo, j. hjorth (dark/ \nnbi),\np. vreeswijk (eso) report on behalf of a larger collaboration.\n\nwe observed the field of grb 070306 (pandey et al., gcn 6169 & gcn  \nreport 38.1)\nwith the eso/vlt equipped with fors2. observations started on 2007  \nmar 08.11 ut\n(about 34 hr after the grb) and three 1800-s spectra were acquired  \nwith the\n300v grism covering a wavelength range of 3500-9500a.\n\nthe spectrum consists of a largely featureless continuum from about  \n4000-9400a\nwith the exception of an apparent emission line at ~9310a. the  \nemission-line\nis very close to the prominent telluric skyline at 9313a, but is seen  \nin each\nof the three individual spectra. the lack of 

In [70]:
dataset["test"][0]

{'id': 593,
 'question': 'what is the redshift value?',
 'context': 'title:   gcn circular\nnumber:  11230\nsubject: grb 100906a: gemini-n/gmos redshift\ndate:    10/09/06 15:15:39 gmt\nfrom:    nial tanvir at u.leicester  <nrt3@star.le.ac.uk>\n\nn. r. tanvir, k. wiersema (u. leicester) and a. j. levan (u. warwick)\nreport on behalf of a larger collaboration:\n\nwe observed the location of grb 100906a (markwardt et al. gcn 11227;\nivanov et al. gcn 11228; melandri et al. gcn 11229) with\nthe gemini-north telescope on mauna kea using the gmos spectrograph.\nobservations began at 14:30 ut, approximately 40 minutes post burst.\n\nwe detect strong continuum from the afterglow, and identify numerous\nabsorption lines, including civ (1458, 1551a), feii (2344, 2374, 2383a)\nat a common redshift of z=1.727.\n\nfurther analysis is ongoing.\n\nwe acknowledge the support of chad trujillo in obtaining these observations.\n\n\n\n\n\n',
 'answer': '1.727'}

In [71]:
'''
Here is where we select the model and tokenizer to be used for question answering.
The checkpoint name is the name of the model as stored in the Huggingface Model Hub.
'''

checkpoint = "deepset/deberta-v3-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)

In [72]:
'''
Let's get some preliminary inferences to see how the model performs before training.
We'll use an example from the validation set.
'''

question = dataset["validation"][0]["question"]
context = dataset["validation"][0]["context"]
answer = dataset["validation"][0]["answer"]
answer #Real answer

'070306'

In [73]:
question_answerer = pipeline("question-answering", model=model, tokenizer=tokenizer)
question_answerer(question=question, context=context) #Predicted answer

{'score': 0.044039852917194366,
 'start': 354,
 'end': 365,
 'answer': ' grb 070306'}

Now let's evaluate the model's performance on our validation set before fine-tuning it.

In [74]:
'''
We start by preprocessing the validation data.
Code adapted from Huggingface NLP course page on Question Answering
'''

def preprocess_validation_examples(examples):
    inputs = tokenizer(
        examples["question"],
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs
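With `return_overflowing_tokens=True` and `stride=128`, a long circular is split into several overlapping chunks, so an answer near a chunk boundary still appears whole in at least one chunk. The sketch below models that windowing on context-token positions only. It is a simplification: in the real tokenizer call the per-chunk budget is `max_length=384` minus the question and special tokens, and `context_windows` is a hypothetical helper.

```python
def context_windows(n_tokens, window, stride):
    """Split n_tokens context positions into (start, end) chunks.

    Each chunk holds at most `window` tokens; consecutive chunks share `stride` tokens.
    """
    if stride >= window:
        raise ValueError("stride must be smaller than window")
    spans = []
    start = 0
    while True:
        end = min(start + window, n_tokens)
        spans.append((start, end))
        if end == n_tokens:
            return spans
        start = end - stride


print(context_windows(600, 300, 128))  # [(0, 300), (172, 472), (344, 600)]
```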

In [75]:
'''
Apply the above preprocessing function to the whole validation set using map()
'''

validation_dataset = dataset["validation"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=dataset["validation"].column_names,
)

Map:   0%|          | 0/157 [00:00<?, ? examples/s]

In [76]:
# '''
# Remove extra columns and plug dataset into model to get output logits.
# '''

# eval_set_for_model = validation_dataset.remove_columns(["example_id", "offset_mapping"])
# eval_set_for_model.set_format("torch")

# batch = {k: eval_set_for_model[k] for k in eval_set_for_model.column_names}

# with torch.no_grad():
#     outputs = model(**batch)

# start_logits = outputs.start_logits.cpu().numpy()
# end_logits = outputs.end_logits.cpu().numpy()

In [77]:
'''
The compute_metrics() function computes the SQuAD exact-match and F1 scores from the model's start and end logits.
'''

def compute_metrics(start_logits, end_logits, features, examples):
    metric = evaluate.load("squad")
    n_best = 20
    max_answer_length = 30
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": str(example_id), "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": str(example_id), "prediction_text": ""})

    theoretical_answers = [{"id": str(ex["id"]),
                            "answers": {"text": [ex["answer"]],
                                        "answer_start": [ex["context"].find(ex["answer"])]}} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)
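Before comparing strings, the SQuAD metric normalises both prediction and reference: lower-casing, removing punctuation and the articles a/an/the, and collapsing whitespace. The sketch below re-implements that normalisation and the token-level F1 following the official SQuAD v1.1 evaluation script, which, to our understanding, is what `evaluate.load("squad")` wraps. Because `<`, `~` and `.` count as punctuation, `' z < 4.'` and `'z < 4'` compare as equal and `'1.51'` becomes `'151'`, so exact match here is looser than a character-for-character check.

```python
import collections
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize_answer(s):
    """Lower-case, drop punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def exact_match(prediction, reference):
    return normalize_answer(prediction) == normalize_answer(reference)


def f1_score(prediction, reference):
    """Token-level F1 between the normalised prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    common = collections.Counter(pred_tokens) & collections.Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match(' z < 4.', 'z < 4'))  # True
print(normalize_answer('~1.51'))        # 151
print(f1_score('0.73 < z < 2.91', '0.696 < z < 2.2'))  # 1/3: only "z" is shared
```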

In [78]:
# pre_training_scores = compute_metrics(start_logits, end_logits, validation_dataset, dataset["validation"])
# pre_training_scores

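The evaluation cells above are commented out in this run, so no pre-fine-tuning scores are shown. The single validation example earlier already hints at the problem: the base model answered ' grb 070306' with a score of only about 0.04, whereas the expected answer is '070306'. Let's fine-tune our model now.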

In [79]:
'''
This function is for tokenizing and preprocessing the training data.
It is similar to the validation preprocessor, except that it also computes the labels:
the start and end token positions of the answer in each tokenized chunk.
'''

def preprocess_training_examples(examples):
    inputs = tokenizer(
        examples["question"],
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answer"]
    contexts = examples["context"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        context = contexts[sample_idx]
        start_char = 0
        end_char = 0
        if answer is not None:
            start_char = context.find(answer)
            end_char = start_char + len(answer)
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
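The label computation in `preprocess_training_examples` maps the answer's character span onto token positions using the offset mapping. Pulled out as a plain function and run on a toy, hand-written offset mapping (hypothetical values, no tokenizer involved), it behaves as follows: an answer that starts mid-token snaps to the token containing it, and an answer outside the chunk gets the `(0, 0)` label.

```python
def char_span_to_token_span(offsets, context_start, context_end, start_char, end_char):
    """Return (start_token, end_token) for a character span, or (0, 0) if it is not fully in the chunk.

    Same logic as the loop in preprocess_training_examples above.
    """
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


# Toy chunk: token 0 is a special token, tokens 1-4 cover the context "redshift of z=1.727."
offsets = [(0, 0), (0, 8), (9, 11), (12, 19), (19, 20)]
context = "redshift of z=1.727."
start_char = context.find("1.727")
print(char_span_to_token_span(offsets, 1, 4, start_char, start_char + len("1.727")))  # (3, 3)
```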

In [80]:
tokenized_dataset = dataset["train"].map(preprocess_training_examples, batched=True, remove_columns=dataset["train"].column_names)

Map:   0%|          | 0/1258 [00:00<?, ? examples/s]

In [81]:
torch.cuda.empty_cache()

In [82]:
'''
Training our model.
'''
torch.cuda.empty_cache()
training_args = TrainingArguments(
    output_dir="trained_results",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
)

trainer.train()

You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step  Training Loss
500   0.4476


TrainOutput(global_step=717, training_loss=0.35916298902184396, metrics={'train_runtime': 392.7076, 'train_samples_per_second': 14.553, 'train_steps_per_second': 1.826, 'total_flos': 1120003448624640.0, 'train_loss': 0.35916298902184396, 'epoch': 3.0})
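The training log also hints at how many tokenized chunks the 1,258 training examples produced. Assuming, as we understand Trainer's speed metrics, that `train_samples_per_second` is the number of chunks seen across all epochs divided by the runtime, the figures above imply roughly 1,905 chunks per epoch. With a batch size of 8 that is 239 steps per epoch, consistent with `global_step=717` over 3 epochs, and about 1.5 chunks per example on average because of the sliding window.

```python
import math

runtime = 392.7076           # 'train_runtime' from the TrainOutput above
samples_per_second = 14.553  # 'train_samples_per_second' (rounded in the log)
epochs = 3
batch_size = 8
n_examples = 1258            # rows in the training split

chunks_per_epoch = round(samples_per_second * runtime / epochs)
steps = math.ceil(chunks_per_epoch / batch_size) * epochs

print(chunks_per_epoch)                         # 1905
print(steps)                                    # 717, as in global_step
print(round(chunks_per_epoch / n_examples, 2))  # 1.51
```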

In [83]:
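model.save_pretrained("trained_model")
tokenizer.save_pretrained("trained_model") #Save the tokenizer too, so the folder can be reloaded on its own
model = AutoModelForQuestionAnswering.from_pretrained("trained_model")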

# Step 4: Inference and Evaluation after Fine-Tuning

In [131]:
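'''
Let's now test the fine-tuned model on the synthetic circular 1000012 (stored in context_test earlier),
whose redshift text was changed to 'z < 4'. Its questions may have landed in the training split,
so this is a sanity check rather than a held-out evaluation.
'''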

question = 'what is the redshift value?'
context = context_test
# answer = dataset["validation"][2]["answer"]
# answer #Real answer

In [132]:
question_answerer = pipeline("question-answering", model=model, tokenizer=tokenizer)
question_answerer(question=question, context=context, handle_impossible_answer=True) #Predicted answer

{'score': 0.9943572282791138, 'start': 1316, 'end': 1323, 'answer': ' z < 4.'}
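The raw span returned by the pipeline keeps a leading space and, here, the sentence's trailing period. A light post-processing step could tidy predictions before they are written to the archive. The sketch below is an assumption about what cleaning is appropriate: strip whitespace, and drop a final period only when it directly follows a digit, so values like `1.51` are untouched. `clean_prediction` is hypothetical.

```python
import re


def clean_prediction(text):
    """Strip surrounding whitespace and a sentence-final period that follows a digit."""
    text = text.strip()
    return re.sub(r'(?<=\d)\.$', '', text)


print(clean_prediction(' z < 4.'))      # z < 4
print(clean_prediction('~1.51'))        # ~1.51
print(clean_prediction(' grb 070306'))  # grb 070306
```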

In [133]:
print(question)
print(context)

what is the redshift value?
title:   gcn circular
number:  1000012
subject: grb 181213a: not optical observations
date:    18/12/14 10:01:41 gmt
from:    kasper elm heintz at univ. of iceland and dawn/nbi  <keh14@hi.is>

k. e. heintz (univ. of iceland), d. b. malesani (dawn/nbi/dtu and dark/nbi), j. p. u. fynbo (dawn/nbi/dtu), a. de ugarte postigo (heth/iaa-csic and dark/nbi), l. balaguer-nuñez, j. carbajo (dept. fqa, univ. de barcelona), f. galindo and c. perez (not), report on behalf of a larger collaboration:

we observed the optical afterglow (lipunov et al., gcn 23526; hu et al., gcn 23527; siegel et al., gcn 23529; belkin et al., gcn 23530) of grb 181213a (evans et al., gcn 23525) with the 2.5-m nordic optical telescope (not) equipped with alfosc.

spectroscopy was secured: we obtained 4x900s exposures using grism 4 (with a wavelength coverage from 350 to 900 nm) starting at 23:27:14 ut on december 13 (i.e. 10.5 hr after trigger). the observations were obtained at a seeing aroun

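The fine-tuned model answers ' z < 4.' with a score of about 0.99, which matches the expected 'z < 4' apart from the leading space and the sentence's trailing period. Now let's get the evaluation metrics.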

In [87]:
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
post_training_scores = compute_metrics(start_logits, end_logits, validation_dataset, dataset["validation"])
post_training_scores

  0%|          | 0/157 [00:00<?, ?it/s]

{'exact_match': 94.90445859872611, 'f1': 95.32908704883228}

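After fine-tuning, the model reaches an exact-match score of about 94.9 (149 of 157 validation examples) and an F1 score of about 95.3. We have successfully fine-tuned our model.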

From our investigations, "deepset/deberta-v3-base-squad2" appears to be the best model. Let's test its performance on the held-out test set now.

In [88]:
test_dataset = dataset["test"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=dataset["test"].column_names,
)

predictions, _, _ = trainer.predict(test_dataset)
start_logits, end_logits = predictions
post_training_scores = compute_metrics(start_logits, end_logits, test_dataset, dataset["test"])
post_training_scores

Map:   0%|          | 0/158 [00:00<?, ? examples/s]

  0%|          | 0/158 [00:00<?, ?it/s]

{'exact_match': 94.30379746835443, 'f1': 95.14767932489451}

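On the held-out test set, the fine-tuned model reaches an exact-match score of about 94.3 (149 of 158 examples) and an F1 score of about 95.1, i.e. roughly 95%.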