In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import pandas as pd
import json
from sklearn.model_selection import train_test_split
from pandas.io.json import json_normalize
import torch
from transformers import EncoderDecoderModel, AutoTokenizer
from typing import *

Download Data

In [None]:
import os
import urllib.request
from tqdm import tqdm

class DownloadProgressBar(tqdm):
  """tqdm progress bar whose `update_to` method can be handed to
  `urllib.request.urlretrieve` as its `reporthook` callback."""

  def update_to(self, b=1, bsize=1, tsize=None):
    """Move the bar forward so that it shows `b * bsize` units.

    Args:
      b: Multiplier applied to `bsize` (the block count when used as a reporthook).
      bsize: Size of a single block.
      tsize: Total size; when it is not None it becomes the bar's `total`.
    """
    if tsize is not None:
      self.total = tsize
    # `update` expects an increment, so subtract what the bar already shows.
    transferred = b * bsize
    self.update(transferred - self.n)

def download_url(url, output_path):
  """Fetch `url` into the file `output_path` while showing a progress bar.

  The bar counts bytes and is labelled with the last path component of `url`.
  """
  file_label = url.split('/')[-1]
  progress = DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=file_label)
  with progress as bar:
    urllib.request.urlretrieve(url, filename=output_path, reporthook=bar.update_to)

def download_data(data_path, url_path, suffix):
  """Download one CoQA split to `<data_path>/<suffix>.json` unless it already exists.

  Args:
    data_path: Directory that holds the downloaded splits; created if missing.
    url_path: URL of the JSON file to fetch.
    suffix: Split name, used as the file stem (e.g. 'train' -> 'train.json').

  The payload is first written to a temporary `.part` file and only renamed to
  its final name once the transfer has finished. An interrupted download can
  therefore never leave a truncated `<suffix>.json` behind that a later run
  would mistake for a complete file.
  """
  # exist_ok avoids the check-then-create race of a separate os.path.exists test.
  os.makedirs(data_path, exist_ok=True)

  data_path = os.path.join(data_path, f'{suffix}.json')
  if os.path.exists(data_path):
    return

  print(f"Downloading CoQA {suffix} data split... (it may take a while)")
  partial_path = data_path + '.part'
  try:
    download_url(url=url_path, output_path=partial_path)
    # Atomic rename: the final file name only ever refers to a complete download.
    os.replace(partial_path, data_path)
  finally:
    # Still present only when the download or the rename failed.
    if os.path.exists(partial_path):
      os.remove(partial_path)
  print("Download Completed!")

In [None]:
# Training split
train_url = "https://nlp.stanford.edu/data/coqa/coqa-train-v1.0.json"
# The CoQA dev file is stored locally under the name 'test'
test_url = "https://nlp.stanford.edu/data/coqa/coqa-dev-v1.0.json"

# Fetch both splits into ./coqa; splits that are already on disk are skipped.
for split_name, split_url in (('train', train_url), ('test', test_url)):
  download_data(data_path='coqa', url_path=split_url, suffix=split_name)

In [None]:
# Load the training split written by `download_data(data_path='coqa', ...)`.
# The relative path mirrors that call, so the cell does not depend on the
# notebook's working directory being /content. The context manager closes
# the file handle once the JSON has been parsed.
with open('coqa/train.json', encoding='utf-8') as train_file:
  train_data = json.load(train_file)

# `pd.json_normalize` is the supported entry point; importing the function
# from `pandas.io.json` is deprecated.
# One row per question, carrying the parent record's source, id and story text.
qas = pd.json_normalize(train_data['data'], ['questions'], ['source', 'id', 'story'])
# One row per answer, tagged with the parent record's id.
ans = pd.json_normalize(train_data['data'], ['answers'], ['id'])

# Pair every question with the answer of the same record and turn. Both frames
# have an `input_text` column, so the merge suffixes them:
# `input_text_x` is the question and `input_text_y` the answer.
train_df = pd.merge(qas, ans, on=['id', 'turn_id'])
train_df.loc[10:30, ['turn_id', 'input_text_x', 'input_text_y', 'span_text']]

  qas = json_normalize(train_data['data'], ['questions'], ['source', 'id', 'story'])
  ans = json_normalize(train_data['data'], ['answers'],['id'])


Unnamed: 0,turn_id,input_text_x,input_text_y,span_text
10,11,when were the Secret Archives moved from the r...,at the beginning of the 17th century;,atican Secret Archives were separated from the...
11,12,how many items are in this secret collection?,150000,Vatican Secret Archives were separated from t...
12,13,Can anyone use this library?,anyone who can document their qualifications a...,The Vatican Library is open to anyone who can...
13,14,what must be requested to view?,unknown,unknown
14,15,what must be requested in person or by mail?,Photocopies,Photocopies for private study of pages from bo...
15,16,of what books?,only books published between 1801 and 1990,hotocopies for private study of pages from boo...
16,17,What is the Vat the library of?,the Holy See,"simply the Vat, is the library of the Holy See,"
17,18,How many books survived the Pre Lateran period?,a handful of volumes,"Pre-Lateran period, comprising the initial day..."
18,19,what is the point of the project started in 2014?,digitising manuscripts,Vatican Library began an initial four-year pro...
19,20,what will this allow?,them to be viewed online.,"manuscripts, to be made available online."


In [None]:
# Lower-case the questions once so both extractions ignore capitalisation.
questions_lower = train_df['input_text_x'].str.lower()

# First run of word characters in the question.
train_df['q_first_word'] = questions_lower.str.extract(r'(\w+)')
# First two whitespace-separated tokens; questions without a second token get NaN.
train_df['q_first_two_words'] = questions_lower.str.extract(r'^((?:\S+\s+){1}\S+).*')

# Non-null counts per first word, most frequent first (top 30).
(
    train_df
    .groupby('q_first_word')
    .count()
    .sort_values(by='input_text_x', ascending=False)
    .head(30)
)

Unnamed: 0_level_0,input_text_x,turn_id,bad_turn_x,source,id,story,span_start,span_end,span_text,input_text_y,bad_turn_y,q_first_two_words
q_first_word,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
what,32092,32092,114,32092,32092,32092,32092,32092,32092,32092,611,31711
who,15684,15684,45,15684,15684,15684,15684,15684,15684,15684,301,15075
how,10946,10946,37,10946,10946,10946,10946,10946,10946,10946,224,10662
did,7381,7381,19,7381,7381,7381,7381,7381,7381,7381,137,7381
where,7214,7214,21,7214,7214,7214,7214,7214,7214,7214,121,6305
was,5121,5121,30,5121,5121,5121,5121,5121,5121,5121,121,5121
when,4530,4530,10,4530,4530,4530,4530,4530,4530,4530,83,3614
is,3431,3431,16,3431,3431,3431,3431,3431,3431,3431,76,3431
why,2921,2921,13,2921,2921,2921,2921,2921,2921,2921,65,1885
does,2110,2110,5,2110,2110,2110,2110,2110,2110,2110,33,2110


In [None]:
# Non-null counts per two-word question opener, most frequent first (top 30).
opener_counts = train_df.groupby('q_first_two_words').count()
opener_counts.sort_values(by='input_text_x', ascending=False).head(30)

Unnamed: 0_level_0,input_text_x,turn_id,bad_turn_x,source,id,story,span_start,span_end,span_text,input_text_y,bad_turn_y,q_first_word
q_first_two_words,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
what did,5622,5622,27,5622,5622,5622,5622,5622,5622,5622,97,5622
what was,5079,5079,10,5079,5079,5079,5079,5079,5079,5079,100,5079
what is,4800,4800,15,4800,4800,4800,4800,4800,4800,4800,101,4800
how many,3692,3692,12,3692,3692,3692,3692,3692,3692,3692,108,3692
who was,3390,3390,9,3390,3390,3390,3390,3390,3390,3390,74,3390
who is,2409,2409,11,2409,2409,2409,2409,2409,2409,2409,29,2409
did he,2366,2366,5,2366,2366,2366,2366,2366,2366,2366,40,2366
where did,1988,1988,3,1988,1988,1988,1988,1988,1988,1988,40,1988
when did,1810,1810,5,1810,1810,1810,1810,1810,1810,1810,38,1810
what does,1797,1797,5,1797,1797,1797,1797,1797,1797,1797,37,1797


In [None]:
# Keep only the rows whose answer text is not the literal string 'unknown'.
has_known_answer = train_df['input_text_y'] != 'unknown'
train_df = train_df.loc[has_known_answer]

# Same slice as shown after loading, to make the dropped rows visible.
preview_columns = ['turn_id', 'input_text_x', 'input_text_y', 'span_text']
train_df.loc[10:30, preview_columns]

Unnamed: 0,turn_id,input_text_x,input_text_y,span_text
10,11,when were the Secret Archives moved from the r...,at the beginning of the 17th century;,atican Secret Archives were separated from the...
11,12,how many items are in this secret collection?,150000,Vatican Secret Archives were separated from t...
12,13,Can anyone use this library?,anyone who can document their qualifications a...,The Vatican Library is open to anyone who can...
14,15,what must be requested in person or by mail?,Photocopies,Photocopies for private study of pages from bo...
15,16,of what books?,only books published between 1801 and 1990,hotocopies for private study of pages from boo...
16,17,What is the Vat the library of?,the Holy See,"simply the Vat, is the library of the Holy See,"
17,18,How many books survived the Pre Lateran period?,a handful of volumes,"Pre-Lateran period, comprising the initial day..."
18,19,what is the point of the project started in 2014?,digitising manuscripts,Vatican Library began an initial four-year pro...
19,20,what will this allow?,them to be viewed online.,"manuscripts, to be made available online."
20,1,Where was the Auction held?,Hard Rock Cafe,Hard Rock Cafe in New York's Times Square


In [None]:
# Load the split stored as 'test' by `download_data(data_path='coqa', ...)`.
# The relative path mirrors that call, so the cell does not depend on the
# notebook's working directory being /content. The context manager closes
# the file handle once the JSON has been parsed.
with open('coqa/test.json', encoding='utf-8') as test_file:
  test_data = json.load(test_file)

# `pd.json_normalize` is the supported entry point; importing the function
# from `pandas.io.json` is deprecated.
# One row per question, carrying the parent record's source, id and story text.
qas = pd.json_normalize(test_data['data'], ['questions'], ['source', 'id', 'story'])
# One row per answer, tagged with the parent record's id.
ans = pd.json_normalize(test_data['data'], ['answers'], ['id'])

# Pair every question with the answer of the same record and turn. Both frames
# have an `input_text` column, so the merge suffixes them:
# `input_text_x` is the question and `input_text_y` the answer.
test_df = pd.merge(qas, ans, on=['id', 'turn_id'])
test_df.loc[10:30, ['turn_id', 'input_text_x', 'input_text_y', 'span_text']]

  qas = json_normalize(test_data['data'], ['questions'], ['source', 'id', 'story'])
  ans = json_normalize(test_data['data'], ['answers'],['id'])


Unnamed: 0,turn_id,input_text_x,input_text_y,span_text
10,11,What did the other cats do when Cotton emerged...,licked her face,Her sisters licked her face
11,12,Did they want Cotton to change the color of he...,no,We would never want you to be any other way
12,1,what was the name of the fish,Asta.,Asta.
13,2,What looked like a birds belly,a bottle,a bottle
14,3,who said that,Asta.,"""It looks like a bird's belly,"" said Asta."
15,4,Was Sharkie a friend?,Yes,Asta's friend Sharkie
16,5,did they get the bottle?,Yes,So they caught the bottle
17,6,What was in it,a note,It was a note.
18,7,Did a little boy write the note,No,This note is from a little girl
19,8,Who could read the note,Asta's papa,Asta's papa read the note


In [None]:
# Keep only the rows whose answer text is not the literal string 'unknown'.
has_known_answer = test_df['input_text_y'] != 'unknown'
test_df = test_df.loc[has_known_answer]

preview_columns = ['turn_id', 'input_text_x', 'input_text_y', 'span_text']
test_df.loc[10:30, preview_columns]

Unnamed: 0,turn_id,input_text_x,input_text_y,span_text
10,11,What did the other cats do when Cotton emerged...,licked her face,Her sisters licked her face
11,12,Did they want Cotton to change the color of he...,no,We would never want you to be any other way
12,1,what was the name of the fish,Asta.,Asta.
13,2,What looked like a birds belly,a bottle,a bottle
14,3,who said that,Asta.,"""It looks like a bird's belly,"" said Asta."
15,4,Was Sharkie a friend?,Yes,Asta's friend Sharkie
16,5,did they get the bottle?,Yes,So they caught the bottle
17,6,What was in it,a note,It was a note.
18,7,Did a little boy write the note,No,This note is from a little girl
19,8,Who could read the note,Asta's papa,Asta's papa read the note


In [None]:
# Row-level 80/20 split of the training frame; the fixed seed makes it repeatable.
# NOTE(review): rows that share the same `id` (same story) can end up on both
# sides of the split — confirm this is acceptable for validation.
train, val = train_test_split(
    train_df,
    random_state=42,
    test_size=0.2,
)
train.head()

Unnamed: 0,input_text_x,turn_id,bad_turn_x,source,id,story,span_start,span_end,span_text,input_text_y,bad_turn_y,q_first_word,q_first_two_words
54860,So how did they get to 28?,20,,race,39dd6s19jpbtyxnmal6qgea8wr2ze3,Where did that number come from? Eleven and Tw...,1639,1740,he took one day from each of the 30-day months...,he took one day from each of the 30-day months...,,so,so how
69607,How much was the package in value?,9,,cnn,3ii4upycoj7fsz8vructj3gjsr7qdt,"Abidjan, Ivory Coast (CNN) -- The European Uni...",80,98,180 million euros,180 million euros,,how,how much
94456,Did she think Adams was untrustworthy?,6,,cnn,3wq3b2kge8gywyqusjv8nckbhrp1bi,"ATLANTA, Georgia (CNN) -- Michele Trobaugh reg...",426,462,She says she trusted him right away.,No,,did,did she
94333,Who was he talking to?,3,,gutenberg,3qapzx2qn4d41w5gd7yx8eyxhj320q,"CHAPTER V--""BLOODY AS THE HUNTER"" \n\nThe lads...",1208,1244,"""Ye but deride me,"" answered Matcham",Matcham,,who,who was
47220,What does Pleistocene mean literally?,15,,wikipedia,3nvc2eb65qzqj9xkpfnbjgx90ke3yk,"The Pleistocene (, often colloquially referred...",1410,1420,"""Most New""","""Most New.""",,what,what does


In [None]:
# The same checkpoint initialises the encoder, the decoder and the tokenizer.
checkpoint = 'prajjwal1/bert-tiny'

model_berttiny = EncoderDecoderModel.from_encoder_decoder_pretrained(checkpoint, checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Downloading:   0%|          | 0.00/285 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertLMHeadModel: ['cls.seq_re

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [None]:
# Encode a one-word sample as a PyTorch tensor of token ids for the encoder.
# NOTE(review): 'train' is the literal string here, not the `train` DataFrame —
# presumably a smoke-test input; confirm.
encoded_source = tokenizer('train', return_tensors='pt')
input_ids = encoded_source['input_ids']

In [None]:
# Target token ids: the same one-word sample, encoded as a PyTorch tensor.
# NOTE(review): 'train' is the literal string here, not the `train` DataFrame —
# presumably a smoke-test target; confirm.
encoded_target = tokenizer('train', return_tensors='pt')
labels = encoded_target['input_ids']

In [None]:
# The model builds its decoder inputs by shifting `labels` one step to the
# right; for that it must know which token starts a decoded sequence and
# which token is padding.
model_berttiny.config.decoder_start_token_id = tokenizer.cls_token_id
model_berttiny.config.pad_token_id = tokenizer.pad_token_id

# Pass only `labels`. Supplying the unshifted labels as `decoder_input_ids`
# as well would show every decoder position the very token it is scored on,
# because EncoderDecoderModel computes the loss against `labels` itself.
loss = model_berttiny(input_ids=input_ids, labels=labels).loss



In [None]:
# Back-propagate the loss to compute gradients; no optimizer step is taken in this cell.
loss.backward()

In [None]:
# Switch the model to evaluation mode (e.g. turns off the Dropout layers).
# As the cell's last expression, the returned module is displayed below.
model_berttiny.eval()

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_af

In [None]:
# Decoding is seeded with the decoder's padding token.
# NOTE(review): the decoder inputs used for the loss come from the tokenizer and
# presumably begin with [CLS] rather than padding — confirm which start token is intended.
start_token_id = model_berttiny.config.decoder.pad_token_id
greedy_output = model_berttiny.generate(
    input_ids,
    decoder_start_token_id=start_token_id,
)