In [12]:
import transformers 
import datasets
import torch
from torch.utils.data import Dataset
import logging
from transformers import TrainingArguments, Trainer

In [2]:
# Preprocessing
dataset_name = "quoref"
model_type="electra"
model_name= "damapika/electra-base-discriminator_squad_mod"
models_dir = "saved_models/electra-base-discriminator_mod_quoref"
checkpoint = 'electra'
max_input_length = 308


# ## Training
learning_rate = 3e-5
num_epochs = 3

In [3]:
dataset=datasets.load_dataset(dataset_name)

Found cached dataset quoref (C:/Users/dama_/.cache/huggingface/datasets/quoref/default/0.1.0/82bb58a6b25cd8dbb4625a7ba6a5d0a224af1f4d392ca0de8b9e0c23e78557fe)
100%|██████████| 2/2 [00:00<00:00, 30.06it/s]


In [21]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'context', 'title', 'url', 'answers'],
        num_rows: 19399
    })
    validation: Dataset({
        features: ['id', 'question', 'context', 'title', 'url', 'answers'],
        num_rows: 2418
    })
})

In [4]:
# calculate max context length for dataset
def calc_max_len(dataset):
  context_length_max=len(dataset[0]['context'])
  for i in range(len(dataset)):
    con_len=len(dataset[i]['context'])
    if(con_len<context_length_max):
      context_length_max=con_len
      print(context_length_max)
      print(dataset[i]['context'])
  return context_length_max

In [5]:
calc_max_len(dataset['train'])

1321
In 1919, the Chicago White Sox are considered one of the greatest baseball teams ever assembled; however, the team's stingy owner, Charles Comiskey, gives little inclination to reward his players for a spectacular season.
Gamblers "Sleepy" Bill Burns and Billy Maharg get wind of the players' discontent, asking shady player Chick Gandil to convince a select group of Sox—including star knuckleball pitcher Eddie Cicotte, who led the majors with a 29–7 win–loss record and an earned run average of 1.82—that they could earn more money by playing badly and throwing the series than they could earn by winning the World Series against the Cincinnati Reds . Cicotte was motivated because Comiskey refused him a promised $10,000 should he win 30 games for the season. Cicotte was nearing the milestone until Comiskey ordered team manager Kid Gleason to bench him for 2 weeks (missing 5 starts) with the excuse that the 35-year-old veteran's arm needed a rest before the series.
A number of players, 

308

In [6]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)

In [7]:
def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_input_length ,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [8]:
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset['train'].column_names)

                                                                   

In [9]:
data_collator = transformers.DefaultDataCollator()

In [10]:
torch.cuda.empty_cache()


In [11]:
# Check if CUDA is available
if torch.cuda.is_available():
    # Set the device to CUDA
    device = torch.device('cuda')
    print('gpu')
else:
    # If CUDA is not available, fall back to CPU
    device = torch.device('cpu')
    print('cpu')

gpu


In [13]:
training_args = TrainingArguments(
    output_dir=models_dir,
    evaluation_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

Cloning https://huggingface.co/damapika/electra-base-discriminator_mod_quoref into local empty directory.


In [14]:
trainer.train()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdamapika[0m. Use [1m`wandb login --relogin`[0m to force relogin


 14%|█▎        | 500/3639 [04:16<28:05,  1.86it/s]

{'loss': 2.0232, 'learning_rate': 2.5877988458367684e-05, 'epoch': 0.41}


 27%|██▋       | 1000/3639 [08:47<20:18,  2.17it/s] 

{'loss': 1.6037, 'learning_rate': 2.1755976916735367e-05, 'epoch': 0.82}


                                                   
 33%|███▎      | 1213/3639 [10:53<16:17,  2.48it/s]

{'eval_loss': 1.5460432767868042, 'eval_runtime': 22.6556, 'eval_samples_per_second': 106.729, 'eval_steps_per_second': 6.709, 'epoch': 1.0}


 41%|████      | 1500/3639 [13:06<16:10,  2.20it/s]  

{'loss': 1.2728, 'learning_rate': 1.763396537510305e-05, 'epoch': 1.24}


 55%|█████▍    | 2000/3639 [17:13<13:25,  2.04it/s]

{'loss': 1.0994, 'learning_rate': 1.3511953833470735e-05, 'epoch': 1.65}


                                                   
 67%|██████▋   | 2426/3639 [21:05<08:28,  2.38it/s]

{'eval_loss': 1.5725703239440918, 'eval_runtime': 22.8523, 'eval_samples_per_second': 105.81, 'eval_steps_per_second': 6.651, 'epoch': 2.0}


 69%|██████▊   | 2500/3639 [21:42<09:12,  2.06it/s]  

{'loss': 1.0568, 'learning_rate': 9.389942291838417e-06, 'epoch': 2.06}


 82%|████████▏ | 3000/3639 [25:56<04:52,  2.18it/s]  

{'loss': 0.8181, 'learning_rate': 5.267930750206101e-06, 'epoch': 2.47}


 96%|█████████▌| 3500/3639 [29:59<01:08,  2.03it/s]

{'loss': 0.8029, 'learning_rate': 1.145919208573784e-06, 'epoch': 2.89}


                                                   
100%|██████████| 3639/3639 [31:35<00:00,  1.92it/s]

{'eval_loss': 1.772193431854248, 'eval_runtime': 25.7213, 'eval_samples_per_second': 94.008, 'eval_steps_per_second': 5.91, 'epoch': 3.0}
{'train_runtime': 1898.1318, 'train_samples_per_second': 30.66, 'train_steps_per_second': 1.917, 'train_loss': 1.223366385405127, 'epoch': 3.0}





TrainOutput(global_step=3639, training_loss=1.223366385405127, metrics={'train_runtime': 1898.1318, 'train_samples_per_second': 30.66, 'train_steps_per_second': 1.917, 'train_loss': 1.223366385405127, 'epoch': 3.0})

In [16]:
import wandb
wandb.init() 

0,1
eval/loss,▁▂█
eval/runtime,▁▁█
eval/samples_per_second,█▇▁
eval/steps_per_second,█▇▁
train/epoch,▁▂▃▃▄▅▅▇███
train/global_step,▁▂▃▃▄▅▅▇███
train/learning_rate,█▇▆▄▃▂▁
train/loss,█▆▄▃▂▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,1.77219
eval/runtime,25.7213
eval/samples_per_second,94.008
eval/steps_per_second,5.91
train/epoch,3.0
train/global_step,3639.0
train/learning_rate,0.0
train/loss,0.8029
train/total_flos,9147772860046128.0
train/train_loss,1.22337


In [17]:
trainer.push_to_hub()

Several commits (2) will be pushed upstream.
The progress bars may be unreliable.
To https://huggingface.co/damapika/electra-base-discriminator_mod_quoref
   17999ff..2abe38d  main -> main

