In [1]:
import torch
import transformers
import numpy as np
import os
import pandas as pd
from datasets import *
from tqdm import tqdm
import json
import os
import random
import time
import urllib.request
from copy import deepcopy
from sklearn.model_selection import GroupShuffleSplit
from transformers import AutoTokenizer, PreTrainedTokenizerFast
import plotly.express as px
from torch.nn.utils.rnn import pack_padded_sequence 

In [2]:
def set_reproducibility(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    transformers.set_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.environ['TF_DETERMINISTIC_OPS'] = '1'

In [3]:
def check_tokens(tokenizer):
    special_tokens = tokenizer.special_tokens_map
    special_ids = tokenizer.convert_tokens_to_ids(list(special_tokens.values()))
    print('Special Tokens:')
    for token_type, token_list in special_tokens.items():
        print(f'{token_type}: {token_list}')
    for token, id in zip(special_tokens.keys(), special_ids):
        print(f'{token}: {id}')

In [4]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


In [5]:
def add_history_to_context(df):
    print('Adding the history to each entry of the dataframe...')
    # Sort the dataframe by 'id' and 'turn_id'
    df = df.sort_values(['id', 'turn_id'])
    # Group the dataframe by 'id'
    groups = df.groupby('id')
    # Create an empty list to store the updated rows
    new_rows = []
    # Iterate over each group
    for _, group in tqdm(groups):
        # Initialize an empty string for the history 
        history = ''
        # Iterate over each row in the group
        for i, row in group.iterrows():
            # Concatenate 'input_text_x' and 'input_text_y' for each row
            if row['turn_id'] > 1: # only consider previous turn_ids
                prev_rows = group.loc[group['turn_id'] < row['turn_id'], ['input_text_x', 'input_text_y']]
                history = ''.join(prev_rows['input_text_x'] + prev_rows['input_text_y'] + ';')
            else:
                history = ''
            # Update the 'history_context' column for the current row
            new_row = row.copy()
            if history == '':
                new_row['history_context'] = row['story']
            else:
                new_row['history_context'] = history+'</s>'+row['story']
            # Append the updated row to the list of new rows
            new_rows.append(new_row)
    # Create a new dataframe with the updated rows
    result_df = pd.DataFrame(new_rows)
    print('History added.')
    return result_df

In [6]:
import collections
import re
import string
from typing import Callable, Sequence, TypeVar, Tuple


def make_qid_to_has_ans(dataset):
    qid_to_has_ans = {}
    for article in dataset:
        for p in article["paragraphs"]:
            for qa in p["qas"]:
                qid_to_has_ans[qa["id"]] = bool(qa["answers"])
    return qid_to_has_ans


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_pred: str, a_gold: str) -> int:
    return int(normalize_answer(a_pred) == normalize_answer(a_gold))


def compute_f1(a_pred: str, a_gold: str) -> float:
    pred_toks = get_tokens(a_pred)
    gold_toks = get_tokens(a_gold)
    common = collections.Counter(pred_toks) & collections.Counter(gold_toks)  # type: ignore[var-annotated]
    num_same = sum(common.values())
    if len(pred_toks) == 0 or len(gold_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return float(pred_toks == gold_toks)
    if num_same == 0:
        return 0.0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


_P = TypeVar("_P")
_G = TypeVar("_G")
_T = TypeVar("_T", int, float, Tuple[int, ...], Tuple[float, ...])


def metric_max_over_ground_truths(
    metric_fn: Callable[[_P, _G], _T], prediction: _P, ground_truths: Sequence[_G]
) -> _T:
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def get_metric_score(prediction: str, gold_answers: Sequence[str]) -> Tuple[int, float]:
    exact_scores = metric_max_over_ground_truths(compute_exact, prediction, gold_answers)
    f1_scores = metric_max_over_ground_truths(compute_f1, prediction, gold_answers)
    return exact_scores, f1_scores

In [7]:
class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)
        
def download_url(url, output_path):
    with DownloadProgressBar(unit='B', unit_scale=True,
                             miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)

def download_data(data_path, url_path, suffix):    
    if not os.path.exists(data_path):
        os.makedirs(data_path)
        
    data_path = os.path.join(data_path, f'{suffix}.json')

    if not os.path.exists(data_path):
        print(f"Downloading CoQA {suffix} data split... (it may take a while)")
        download_url(url=url_path, output_path=data_path)
        urllib.request.urlretrieve(url_path, filename=data_path)
        print("Download completed!")

In [8]:
# Creating Dataframes and removing unanswerable questions
train_data = json.load((open('coqa/train.json')))
test_data = json.load((open('coqa/test.json')))

qas = pd.json_normalize(train_data['data'], ['questions'], ['source', 'id', 'story'])
ans = pd.json_normalize(train_data['data'], ['answers'],['id'])
train_val_df = pd.merge(qas,ans, left_on=['id','turn_id'], right_on=['id','turn_id'])
train_val_df = train_val_df.loc[train_val_df['input_text_y']!='unknown']

qas = pd.json_normalize(test_data['data'], ['questions'], ['source', 'id', 'story'])
ans = pd.json_normalize(test_data['data'], ['answers'],['id'])
test_df = pd.merge(qas,ans, left_on=['id','turn_id'], right_on=['id','turn_id'])
test_df = test_df.loc[test_df['input_text_y']!='unknown']

In [9]:
# Removing bad turns
train_val_df = train_val_df.loc[(train_val_df['bad_turn_x'] != 'True') & (train_val_df['bad_turn_y'] != 'True')]

# Removing equal text/answer entries
train_val_df = train_val_df[train_val_df.story != train_val_df.input_text_y]
test_df = test_df[test_df.story != test_df.input_text_y]

# Removing enties with empty answers
train_val_df = train_val_df[train_val_df['input_text_y'].str.len()>0]
test_df = test_df[test_df['input_text_y'].str.len()>0]

In [10]:
# Text preprocess
def preprocess(ds,columns):
    ds = ds.replace(r'\n',' ', regex=True)
#     ds = ds.replace(r'[^\w\s]+', ' ', regex=True)
#     for feature in columns:
#         ds[feature] = ds[feature].str.lower().str.strip()
        
    return ds

columns = ['story', 'input_text_x', 'span_text', 'input_text_y']

train_val_df = preprocess(train_val_df,columns)
test_df = preprocess(test_df,columns)

In [11]:
set_reproducibility(42)

train_inds, val_inds = next(GroupShuffleSplit(test_size=.20, n_splits=2, random_state = 42).split(train_val_df, groups=train_val_df['id']))

train_df = train_val_df.iloc[train_inds]
val_df = train_val_df.iloc[val_inds].reset_index()

# Add the histoy_context column to the datasets
val_df = add_history_to_context(val_df)
#display(val_df.head())

Adding the history to each entry of the dataframe...


100%|██████████| 1439/1439 [00:29<00:00, 48.72it/s]


History added.


In [12]:
# Train/Validation Split
seed = 42

set_reproducibility(seed)

train_inds, val_inds = next(GroupShuffleSplit(test_size=.20, n_splits=2, random_state = 42).split(train_val_df, groups=train_val_df['id']))

train_df = train_val_df.iloc[train_inds]
val_df = train_val_df.iloc[val_inds].reset_index()

# Add the histoy_context column to the datasets
train_df = add_history_to_context(train_df)
val_df = add_history_to_context(val_df)
test_df = add_history_to_context(test_df)

Adding the history to each entry of the dataframe...


100%|██████████| 5754/5754 [01:58<00:00, 48.58it/s]


History added.
Adding the history to each entry of the dataframe...


100%|██████████| 1439/1439 [00:29<00:00, 47.99it/s]


History added.
Adding the history to each entry of the dataframe...


100%|██████████| 500/500 [00:10<00:00, 45.62it/s]


History added.


In [13]:
set_train = set(train_df['id'])
set_val = set(val_df['id'])

overlap = False
for i in set_train:
    if i in set_val:
        overlap = True
        break

print('Overlap' if overlap else 'No overlap')

No overlap


In [14]:
# Relevant features for the dataset
features = ['story', 'input_text_x', 'span_text', 'input_text_y', 'history_context','source']

# Dataframes to Datasets
train_df_to_ds = train_df[features]
val_df_to_ds = val_df[features]
test_df_to_ds = test_df[features]

train_df_to_ds = train_df_to_ds.rename(columns={'input_text_x': 'question', 'story': 'context',\
                                               'input_text_y': 'answer', 'span_text': 'text'})
val_df_to_ds = val_df_to_ds.rename(columns={'input_text_x': 'question', 'story': 'context',\
                                               'input_text_y': 'answer', 'span_text': 'text'})
test_df_to_ds = test_df_to_ds.rename(columns={'input_text_x': 'question', 'story': 'context',\
                                               'input_text_y': 'answer', 'span_text': 'text'})

In [15]:
model_checkpoint = 'distilroberta-base'

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
assert isinstance(tokenizer, PreTrainedTokenizerFast)

tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token
check_tokens(tokenizer)

Special Tokens:
bos_token: <s>
eos_token: </s>
unk_token: <unk>
sep_token: </s>
pad_token: <pad>
cls_token: <s>
mask_token: <mask>
bos_token: 0
eos_token: 2
unk_token: 3
sep_token: 2
pad_token: 1
cls_token: 0
mask_token: 50264


In [16]:
story_lengths = [len(tokenizer(x,y)['input_ids']) for (x,y) in zip(train_val_df['story'],train_val_df['input_text_x'])]

puts = px.box(list(story_lengths), title='Tokenized Stories and Question Lengths Distribution')
puts.show()

Token indices sequence length is longer than the specified maximum sequence length for this model (715 > 512). Running this sequence through the model will result in indexing errors


In [17]:
max_len_input=516
print(f'Max Len:{max_len_input}')

Max Len:516


In [18]:
answer_lengths = [len(tokenizer(x)['input_ids']) for x in train_val_df['input_text_y']]

puts = px.box(list(answer_lengths), title='Tokenized Answers Lengths Distribution', color_discrete_sequence=['red'])
puts.show()

In [19]:
max_len_answer=13
print(f'Max Len Ans:{max_len_answer}')

Max Len Ans:13


In [20]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

def plot_sources(train_df, val_df):
    sources = train_df["source"].unique()

    fig = make_subplots(rows=1, cols=2, subplot_titles=('Sources of stories in train set', 'Sources of stories in validation set'))

    fig.add_trace(go.Bar(x=train_df["source"].value_counts().index.tolist(), 
                         y=train_df["source"].value_counts().values.tolist(),
                         name='Train Data',
                         marker=dict(color='#1f77b4')), 
                  row=1, col=1)

    fig.add_trace(go.Bar(x=val_df["source"].value_counts().index.tolist(), 
                         y=val_df["source"].value_counts().values.tolist(),
                         name='Validation Data',
                         marker=dict(color='#ff7f0e')), 
                  row=1, col=2)

    fig.update_layout(height=500, width=1300, showlegend=True, 
                      legend=dict(x=0.8, y=1.15, orientation="h"))

    fig.show()

plot_sources(train_df, val_df)

In [21]:
#BATCH SPLIT
dataset_ratio = 100
batch_size = 42


train_samples = (round(train_df_to_ds.shape[0]*dataset_ratio /100)//batch_size)* batch_size
val_samples  = (round(val_df_to_ds.shape[0]*dataset_ratio / 100)//batch_size)*batch_size
test_samples = (round(test_df_to_ds.shape[0]*dataset_ratio / 100) // batch_size)*batch_size

train_dataset = Dataset.from_dict(train_df_to_ds.iloc[:train_samples])
val_dataset = Dataset.from_dict(val_df_to_ds.iloc[:val_samples])
test_dataset = Dataset.from_dict(test_df_to_ds.iloc[:test_samples])

dataset_COQA = DatasetDict({'train': train_dataset, 'validation': val_dataset, 'test': test_dataset})
print(dataset_COQA)

DatasetDict({
    train: Dataset({
        features: ['context', 'question', 'text', 'answer', 'history_context', 'source'],
        num_rows: 85806
    })
    validation: Dataset({
        features: ['context', 'question', 'text', 'answer', 'history_context', 'source'],
        num_rows: 21420
    })
    test: Dataset({
        features: ['context', 'question', 'text', 'answer', 'history_context', 'source'],
        num_rows: 7896
    })
})


In [22]:
def prepare_features(batch, tokenizer, max_len_input, max_len_ans, history=False, test=False):
     
    if not test:
        truncated_contexts = []
        for context, span, ans in zip(batch['context'], batch['text'],batch['answer']):
            max_len_context = max_len_input - len(ans)
            start = context.find(span)
            end = start + len(span)

            start = max(0,start-(max_len_context-len(span))//2)
            end = min(len(context), end+(max_len_context - len(span))// 2)

            truncated_context = context[start:end]
            truncated_contexts.append(truncated_context)
    if history:
        encoded_batch_inputs = tokenizer(batch['question'], batch['history_context'], max_length=max_len_input, truncation='only_second', padding='max_length', return_tensors='pt')
    else:
        encoded_batch_inputs = tokenizer(batch['question'], batch['context'], max_length=max_len_input, truncation='only_second', padding='max_length', return_tensors='pt')

    encoded_batch_labels = tokenizer(batch['answer'], max_length=max_len_ans, padding='max_length', truncation=True, return_tensors='pt')

    encoded_batch_inputs['labels']=encoded_batch_labels.input_ids

    return encoded_batch_inputs

In [23]:
tokenized_datasets = DatasetDict()

tokenized_datasets['train'] = dataset_COQA['train'].map(lambda batch: prepare_features(batch, tokenizer, max_len_input, max_len_answer), batched=True, batch_size=batch_size,remove_columns=dataset_COQA['train'].column_names)

tokenized_datasets['val'] = dataset_COQA['validation'].map(lambda batch: prepare_features(batch, tokenizer, max_len_input, max_len_answer), batched=True, batch_size=batch_size, remove_columns=dataset_COQA['validation'].column_names)

tokenized_datasets['train'] = dataset_COQA['train'].map(lambda batch: prepare_features(batch, tokenizer, max_len_input, max_len_answer, test=True), batched=True, batch_size=batch_size, remove_columns=dataset_COQA['test'].column_names)

print(tokenized_datasets)

Map:   0%|          | 0/85806 [00:00<?, ? examples/s]

Map:   0%|          | 0/21420 [00:00<?, ? examples/s]

Map:   0%|          | 0/85806 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 85806
    })
    val: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 21420
    })
})


In [62]:
from transformers import RobertaModel
model = RobertaModel.from_pretrained(model_checkpoint, output_hidden_states=True)

model.eval()

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-5): 6 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout)

In [88]:
import torch.nn as nn
import torch.functional as F
class EncoderRNN(nn.Module):
    
    def __init__(self, checkpoint):
        super().__init__()

        self.bert = RobertaModel.from_pretrained(checkpoint)

        #BERT CONFIG
        self.bert.config.eos_token_id = tokenizer.sep_token_id
        self.bert.config.pad_token_id = tokenizer.pad_token_id
        self.bert.config.vocab_size = tokenizer.vocab_size
        self.bert.config.max_length = max_len_answer
        self.bert.config.min_length = 1

        self.hidden_size = self.bert.config.hidden_size
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size, bidirectional=True)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, inputs):
        encoded_layers, pooled_output = self.bert(input_ids=inputs[0], attention_mask=inputs[1])
        encoded_layers = encoded_layers.permute(1,0,2)
        enc_hidden, (last_hidden, last_cell) = self.lstm(pack_padded_sequence(encoded_layers, inputs[2]))
        output_hidden = torch.cat((last_hidden[0], last_hidden[1]), dim=1)
        output_hidden = F.dropout(output_hidden, 0.2)
        return enc_hidden, output_hidden



In [101]:
class DecoderRNN(nn.Module):
    def __init__(self, checkpoint):
        super().__init__()

        self.bert = RobertaModel.from_pretrained(checkpoint)

        self.bert.config.decoder_start_token_id = tokenizer.cls_token_id
        self.bert.config.eos_token_id = tokenizer.sep_token_id
        self.bert.config.pad_token_id = tokenizer.pad_token_id
        self.bert.config.vocab_size = tokenizer.vocab_size
        self.bert.config.max_length = max_len_answer
        self.bert.config.min_length = 1

        self.hidden_size = self.bert.config.hidden_size
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size, bidirectional=True)
        self.clf = nn.Linear(self.hidden_size*2,1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, inputs, hidden, cell):

        embedded, pooled_output = self.dropout(self.bert(input_ids=inputs[0], attention_mask=inputs[1]))
        embedded= embedded.permute(1,0,2)
        output, (last_hidden, last_cell) = self.lstm(pack_padded_sequence(embedded,(hidden, cell)))
        output_hidden = torch.cat((last_hidden[0], last_hidden[1]), dim=1)
        output_hidden = F.dropout(output_hidden, 0.2)
        output = self.clf(output_hidden)

        return F.sigmoid(output)


[A

In [102]:
class Seq2SeqRNN(nn.Module):

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, inputs):
        outputs = torch.zeros(max_len_answer, batch_size, tokenizer.vocab_size).to(self.device)
        hidden, cell = self.encoder(inputs)
        for t in range(1, max_len_answer):

            output = self.decoder(inputs, hidden, cell)

            outputs[t] = output
    
        return outputs

            

In [103]:
enc = EncoderRNN(model_checkpoint)
dec = DecoderRNN(model_checkpoint)

model = Seq2SeqRNN(enc,dec,device).to(device)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']
- This IS expected if you are initializin

In [104]:
print(model)

Seq2SeqRNN(
  (encoder): EncoderRNN(
    (bert): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50265, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0-5): 6 x RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
                (dense): Linear(in_features=768, out_

In [105]:
def compute_metrics(pred,tokenizer):
    labels = pred.label_ids
    preds = pred.predictions
    
    labels_text = tokenizer.batch_decode(labels, skip_special_tokens=True)
    preds_text = tokenizer.batch_decode(preds, skip_special_tokens=True)
    
    squad_scores=[]
    for i in range(len(preds_text)):
        squad_scores.append(compute_f1(str(preds_text[i]), str(labels_text[i])))
    mean_squad_f1 = sum(squad_scores)/len(squad_scores)

    return {"squad_f1_score": mean_squad_f1}

In [106]:
from ignite.utils import convert_tensor
from ignite.engine.engine import Engine

def _prepare_batch(batch, device=None,non_blocking=False):
    x,y=batch
    return (convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking))


def create_supervised_trainer_1(model, optimizer, loss_fn, metrics={}, device = None):
    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = _prepare_batch(batch, device=device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y.float())
        loss.backward()
        optimizer.step()
        return loss.item(), y_pred, y
    
    def _metrics_transform(output):
        return output[1], output[2]

    engine=Engine(_update)

    for name, metric in metrics.items():
        metric._output_transform = _metrics_transform
        metric.attach(engine, name)
    return engine

def create_supervised_evaluator_1(model, metrics=None, device=None, non_blocking=False, prepare_batch=_prepare_batch, output_transform=lambda x,y,y_pred: (y_pred,y, )):
    metrics = metrics or {}

    if device:
        model
    
    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            y_pred = model(x)
            return output_transform(x,y.float(), y_pred)
    
    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine

In [119]:
from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss
from ignite.engine.engine import Events
from torch.optim.lr_scheduler import ExponentialLR


def train(log_interval = 100, epochs=2, lr= 0.0001):
    print("in1")
    criterion = nn.BCELoss()    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    lr_scheduler = ExponentialLR(optimizer, gamma=0.90)
    trainer = create_supervised_trainer_1(model.to(device), optimizer, criterion, device=device)
    evaluator = create_supervised_evaluator_1(model.to(device), metrics={'BCELoss ': Loss(criterion)}, device=device)
    print("in1")
    if log_interval is None:
        e = Events.ITERATION_COMPLETED
        log_interval = 1
    else:
        e = Events.ITERATION_COMPLETED(every=log_interval)
    print("in1")
    desc = 'Loss: {:.4f} | lr: {:.4f}'
    print("in1")
    pbar = tqdm(initial=0, leave=False, total=len(tokenized_datasets['train'][0]), desc=desc.format(0,lr))
    print("in1")
    @trainer.on(e)
    def log_training_loss(engine):
        pbar.refresh()
        lr=optimizer.param_groups[0]['lr']
        pbar.desc = desc.format(engine.state.output[0],lr)
        pbar.update(log_interval)
    
    @trainer.on(Events.EPOCH_COMPLETED)
    def update_lr_scheduler(engine):
        lr_scheduler.step()


    @trainer.on(Events.EPOCH_COMPLETED) 
    def log_training_results(engine):
        evaluator.run(tokenized_datasets['train'])
        metrics = evaluator.state.metrics
        avg_loss = metrics['BCELoss']
        tqdm.write('Train Epoch: {} BCE loss: {:.2f}'.format(engine.state.epoch, avg_loss))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.refresh()
        evaluator.run(tokenized_datasets['val'])
        metrics=evaluator.state.metrics
        avg_loss=metrics['BCELoss']
        tqdm.write(
            "Valid Epoch: {} BCE loss: {:.2f}".format(engine.state.epoch, avg_loss)
        )
        pbar.n = pbar.last_print_n = 0

    
    try:
        trainer.run(tokenized_datasets['train'], max_epochs=epochs)

    except Exception as e:
        import traceback
        print(traceback.format_exc())
    return model




In [120]:
model=train()

in1
in1
in1
in1


Current run is terminating due to exception: too many values to unpack (expected 2)
Engine run is terminating due to exception: too many values to unpack (expected 2)


in1
Traceback (most recent call last):
  File "C:\Users\rullo\AppData\Local\Temp\ipykernel_14284\1070021302.py", line 57, in train
    trainer.run(tokenized_datasets['train'], max_epochs=epochs)
  File "d:\envs\tf\lib\site-packages\ignite\engine\engine.py", line 892, in run
    return self._internal_run()
  File "d:\envs\tf\lib\site-packages\ignite\engine\engine.py", line 935, in _internal_run
    return next(self._internal_run_generator)
  File "d:\envs\tf\lib\site-packages\ignite\engine\engine.py", line 993, in _internal_run_as_gen
    self._handle_exception(e)
  File "d:\envs\tf\lib\site-packages\ignite\engine\engine.py", line 638, in _handle_exception
    raise e
  File "d:\envs\tf\lib\site-packages\ignite\engine\engine.py", line 959, in _internal_run_as_gen
    epoch_time_taken += yield from self._run_once_on_dataset_as_gen()
  File "d:\envs\tf\lib\site-packages\ignite\engine\engine.py", line 1087, in _run_once_on_dataset_as_gen
    self._handle_exception(e)
  File "d:\envs\tf\lib