<a href="https://colab.research.google.com/github/MLFlexer/nlp-course/blob/Emma/bert_classification_with_lab6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install bpemb
# !pip install gensim
!pip install datasets
!pip install transformers
# !python -m spacy download en_core_web_sm



In [None]:
import os
import numpy as np
from collections import Counter
import torch
import torch.nn as nn
import datasets
datasets.logging.set_verbosity_error()
from datasets import load_metric, load_dataset
from google.colab import drive
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import f1_score
import pandas as pd
import random
from functools import partial

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score
from torch.optim.lr_scheduler import LambdaLR


# # uncomment if CAN'T CONNECT TO GPU (it happens...)
# import psutil
# import platform

In [None]:
!pip install transformers[torch] accelerate



In [None]:

# Mount Google Drive so model outputs can be saved and reloaded across Colab sessions.
# NOTE(review): `drive` is already imported in the main import cell; the repeat below is harmless.

from google.colab import drive
drive.mount('/content/drive')
# Drive folder intended for saved model outputs.
# NOTE(review): no later cell visible in this notebook reads `output_dir` (the
# TrainingArguments cell passes the literal "my_trainer") — confirm whether it should.
output_dir = '/content/drive/My Drive/Colab Notebooks/NLP/'

Mounted at /content/drive


In [None]:
def enforce_reproducibility(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN into deterministic mode.

    :param seed: Integer seed shared by every random number generator.
    """
    # Library-level generators: Python's `random` module and NumPy
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    # PyTorch generators, on the CPU and on every CUDA device
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: request deterministic kernels and disable the auto-tuner.
    # Atomic operations can still be non-deterministic, because the order
    # of parallel operations is not known.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG (default seed) before any data or model work happens.
enforce_reproducibility()

In [None]:
# Preamble
import sys
sys.path.append('..')


In [None]:
# NOTE(review): `dataset` is only assigned by the `load_dataset(...)` cell that follows,
# so on a fresh kernel this cell raises NameError unless that cell has been run first.
dataset

DatasetDict({
    train: Dataset({
        features: ['question_text', 'document_title', 'language', 'annotations', 'document_plaintext', 'document_url'],
        num_rows: 116067
    })
    validation: Dataset({
        features: ['question_text', 'document_title', 'language', 'annotations', 'document_plaintext', 'document_url'],
        num_rows: 13325
    })
})

In [None]:
# Download the answerable TyDi QA dataset and expose each split both as a
# Hugging Face Dataset and as a pandas DataFrame.
dataset = load_dataset("copenlu/answerable_tydiqa")

train_set, validation_set = dataset["train"], dataset["validation"]
df_train, df_val = train_set.to_pandas(), validation_set.to_pandas()

# Row counts: training split first, validation split second
print(len(df_train), len(df_val), sep="\n")

df_train.head()

116067
13325


Unnamed: 0,question_text,document_title,language,annotations,document_plaintext,document_url
0,Milloin Charles Fort syntyi?,Charles Fort,finnish,"{'answer_start': [18], 'answer_text': ['6. elo...",Charles Hoy Fort (6. elokuuta (joidenkin lähte...,https://fi.wikipedia.org/wiki/Charles%20Fort
1,“ダン” ダニエル・ジャドソン・キャラハンの出身はどこ,ダニエル・J・キャラハン,japanese,"{'answer_start': [35], 'answer_text': ['カリフォルニ...",“ダン”こと、ダニエル・ジャドソン・キャラハンは1890年7月26日、カリフォルニア州サンフ...,https://ja.wikipedia.org/wiki/%E3%83%80%E3%83%...
2,వేప చెట్టు యొక్క శాస్త్రీయ నామం ఏమిటి?,వేప,telugu,"{'answer_start': [12], 'answer_text': ['Azadir...","వేప (లాటిన్ Azadirachta indica, syn. Melia aza...",https://te.wikipedia.org/wiki/%E0%B0%B5%E0%B1%...
3,চেঙ্গিস খান কোন বংশের রাজা ছিলেন ?,চেঙ্গিজ খান,bengali,"{'answer_start': [414], 'answer_text': ['বোরজি...",চেঙ্গিজ খান (মঙ্গোলীয়: Чингис Хаан আ-ধ্ব-ব: ...,https://bn.wikipedia.org/wiki/%E0%A6%9A%E0%A7%...
4,రెయ్యలగడ్ద గ్రామ విస్తీర్ణత ఎంత?,రెయ్యలగడ్ద,telugu,"{'answer_start': [259], 'answer_text': ['27 హె...","రెయ్యలగడ్ద, విశాఖపట్నం జిల్లా, గంగరాజు మాడుగుల...",https://te.wikipedia.org/wiki/%E0%B0%B0%E0%B1%...


In [None]:
# Slice the training and validation frames into one frame per language of interest.
df_train_bengali, df_train_arabic, df_train_indonesian = (
    df_train[df_train['language'] == lang] for lang in ('bengali', 'arabic', 'indonesian')
)
df_val_bengali, df_val_arabic, df_val_indonesian = (
    df_val[df_val['language'] == lang] for lang in ('bengali', 'arabic', 'indonesian')
)

# English subsets, kept for testing
df_val_english, df_train_english = (
    frame[frame['language'] == 'english'] for frame in (df_val, df_train)
)

In [None]:
def _build_answerability_frame(df):
    """Build the classifier input frame for one language subset.

    :param df: DataFrame with 'document_plaintext', 'question_text' and
        'annotations' columns, where each annotation holds an 'answer_start' sequence.
    :return: DataFrame (same index as `df`) with a 'text' column holding the
        document followed by the question, and an 'answerable' column that is
        0 when 'answer_start' is exactly [-1] and 1 otherwise.
    """
    def _is_answerable(annotation):
        # list(...) makes the comparison work whether 'answer_start' is a Python
        # list or an array (an array `==` would be ambiguous for several elements).
        return 0 if list(annotation['answer_start']) == [-1] else 1

    return pd.DataFrame({
        # A space keeps the last word of the document from fusing with the
        # first word of the question.
        'text': df["document_plaintext"] + " " + df["question_text"],
        'answerable': df["annotations"].apply(_is_answerable),
    })


# Combined document + question text and answerability label, per language (training data)
df_train_bengali_merged = _build_answerability_frame(df_train_bengali)
df_train_arabic_merged = _build_answerability_frame(df_train_arabic)
df_train_indonesian_merged = _build_answerability_frame(df_train_indonesian)
df_train_english_merged = _build_answerability_frame(df_train_english)

## Same for validation data
df_val_bengali_merged = _build_answerability_frame(df_val_bengali)
df_val_arabic_merged = _build_answerability_frame(df_val_arabic)
df_val_indonesian_merged = _build_answerability_frame(df_val_indonesian)
df_val_english_merged = _build_answerability_frame(df_val_english)

df_val_english_merged.head()

Unnamed: 0,text,answerable
30,Wound care encourages and speeds wound healing...,1
47,Brothers Amos and Wilfrid Ayre founded Burntis...,1
59,"For species of mammals, larger brains (in abso...",1
77,"As from 31 March 1989, fishing vessel registra...",1
106,"When Quezon City was created in 1939, the foll...",1


In [None]:
from datasets import Dataset, DatasetDict


# Wrap every merged dataframe in a Hugging Face Dataset (training frame first,
# validation frame second). `Dataset` here is `datasets.Dataset`, imported just
# above, which shadows the torch `Dataset` from the main import cell.
train_english, val_english = (
    Dataset.from_pandas(frame) for frame in (df_train_english_merged, df_val_english_merged)
)
train_indonesian, val_indonesian = (
    Dataset.from_pandas(frame) for frame in (df_train_indonesian_merged, df_val_indonesian_merged)
)
train_arabic, val_arabic = (
    Dataset.from_pandas(frame) for frame in (df_train_arabic_merged, df_val_arabic_merged)
)
train_bengali, val_bengali = (
    Dataset.from_pandas(frame) for frame in (df_train_bengali_merged, df_val_bengali_merged)
)

# One train/validation DatasetDict per language
dataset_eng = DatasetDict({"train": train_english, "validation": val_english})
dataset_indonesian = DatasetDict({"train": train_indonesian, "validation": val_indonesian})
dataset_arabic = DatasetDict({"train": train_arabic, "validation": val_arabic})
dataset_bengali = DatasetDict({"train": train_bengali, "validation": val_bengali})

In [None]:
def get_train_features(tokenizer, samples):
  '''
  Tokenizes the 'text' of every sample, splitting inputs that are too long for the
  model across several features (overlapping windows, stride 128), and carries the
  'answerable' label of each sample over to every feature produced from it.

  :param tokenizer: Tokenizer exposing `batch_encode_plus`
  :param samples: Batch of samples (mapping of column name -> list of values) with a
    'text' column and, optionally, an 'answerable' column
  :return: The tokenizer output, plus a 'labels' list aligned with the features when
    the samples contain 'answerable'
  '''
  batch = tokenizer.batch_encode_plus(
        list(samples['text']),
        padding='max_length',
        # Every input is a single sequence, so there is no second sequence to
        # truncate: 'only_second' cannot apply here.
        truncation='only_first',
        stride=128,
        return_overflowing_tokens=True
    )

  # Maps each feature index to the index of the sample it came from (for split
  # inputs). E.g. if the batch has 4 samples and the second one is split into 3
  # features because it is very long, sample_mapping is [0, 1, 1, 1, 2, 3]
  sample_mapping = batch.pop('overflow_to_sample_mapping')
  # Character offsets are not needed for classification; drop them if present
  batch.pop('offset_mapping', None)

  # One label per feature: every window of a split sample inherits its label
  if 'answerable' in samples:
    sample_labels = samples['answerable']
    batch['labels'] = [sample_labels[i] for i in sample_mapping]

  return batch

def collate_fn(inputs):
  '''
  Combines individual samples into one batch of tensors.

  :param inputs: Non-empty list of samples, each a mapping with equal-length
    'input_ids' and 'attention_mask' lists and, optionally, a 'labels' entry
  :return: Dict with 'input_ids' and 'attention_mask' tensors cut down to the
    longest unpadded sequence in the batch, plus a 'labels' tensor when every
    sample provides one (needed for the model to compute a loss)
  '''
  input_ids = torch.tensor([sample['input_ids'] for sample in inputs])
  attention_mask = torch.tensor([sample['attention_mask'] for sample in inputs])

  # Drop the columns that are padding for every sample in the batch.
  # NOTE(review): assumes padding sits at the end of each sequence — confirm
  # against the tokenizer's padding side.
  max_len = int(attention_mask.sum(-1).max())
  batch = {
      'input_ids': input_ids[:, :max_len],
      'attention_mask': attention_mask[:, :max_len],
  }

  if all('labels' in sample for sample in inputs):
    batch['labels'] = torch.tensor([sample['labels'] for sample in inputs])

  return batch

In [None]:
# Tokenize the English training split in batches; the original columns are removed
# so only the fields returned by get_train_features remain.
# NOTE(review): `tokenizer` is not assigned in any cell visible in this notebook —
# it must be created before this cell runs; confirm where it comes from.
tokenized_dataset = dataset_eng['train'].map(partial(get_train_features, tokenizer), batched=True, remove_columns=dataset_eng['train'].column_names)



Map:   0%|          | 0/7389 [00:00<?, ? examples/s]

ValueError: ignored

In [None]:
# Load multilingual BERT with a 2-label sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-multilingual-cased", num_labels=2)
# Move the model to the `device` selected earlier (CUDA when available, else CPU)
# rather than calling .cuda() unconditionally, which fails on a runtime without a GPU.
model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [None]:
# Define the parameters for the Hugging Face Trainer: 3 epochs, 16 samples per
# device per training step, and step-based evaluation with eval_steps=500.
# NOTE(review): checkpoints go to the relative directory "my_trainer", not to the
# Google Drive `output_dir` defined earlier — confirm this is intended.
training_args = TrainingArguments(output_dir="my_trainer",
                                  evaluation_strategy="steps",
                                  num_train_epochs=3.0,
                                  per_device_train_batch_size=16,
                                  eval_steps=500
                                  )

In [None]:
# Metric objects shared by compute_metrics (handed to the Trainer)
metric_f1 = load_metric('f1')
metric_ac = load_metric('accuracy')

def compute_metrics(eval_pred):
    """Compute F1 and accuracy for one evaluation pass.

    :param eval_pred: Pair of (model outputs, reference labels); the predicted
        class is the argmax over the last axis of the outputs.
    :return: Single dict holding the entries of the F1 result followed by the
        entries of the accuracy result.
    """
    model_outputs, references = eval_pred
    predicted_classes = np.argmax(model_outputs, axis=-1)

    scores = {}
    for metric in (metric_f1, metric_ac):
        scores.update(metric.compute(predictions=predicted_classes, references=references))
    return scores

In [None]:
# Define the Trainer object that ties together the model, the training arguments,
# the train/eval datasets, the metric function and the tokenizer.
# NOTE(review): `train_data`, `val_data` and `tokenizer` are not assigned in any cell
# visible in this notebook — confirm they exist before this cell runs (presumably
# they should be the tokenized English train/validation splits).
trainer_eng = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)

In [None]:
# Optimizer and learning-rate schedule for the manual training loop:
# AdamW with a linear schedule and no warmup, sized for `epochs` passes over
# `train_dataloader` (which must already exist when this cell runs).
epochs = 4
total_steps = epochs * len(train_dataloader)

optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)




In [None]:
# Training loop: one full pass over train_dataloader per epoch, followed by a full
# evaluation pass over val_dataloader. Each batch is expected to be a sequence of
# (input_ids, attention_mask, labels) tensors.
for epoch in range(epochs):
    model.train()
    total_loss = 0.0       # Sum of the batch losses seen in this epoch
    num_train_batches = 0  # Number of batches actually processed

    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}"):
        # Make sure every tensor lives on the same device as the model
        batch = [t.to(device) for t in batch]
        # Input ids and attention masks
        inputs = batch[:2]
        # Labels
        labels = batch[2]

        model.zero_grad()
        outputs = model(*inputs, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()  # Accumulate the loss
        num_train_batches += 1
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
    # Average over the batches that were processed, so the figure stays correct
    # even if the dataloader is shortened or empty
    average_loss = total_loss / max(num_train_batches, 1)

    model.eval()
    predictions = []
    true_labels = []
    for batch in tqdm(val_dataloader, desc=f"Evaluating Epoch {epoch + 1}"):
        batch = [t.to(device) for t in batch]
        inputs = batch[:2]
        labels = batch[2]
        with torch.no_grad():
            outputs = model(*inputs)
        logits = outputs.logits
        predictions.extend(logits.argmax(dim=1).tolist())
        true_labels.extend(labels.tolist())

    accuracy = accuracy_score(true_labels, predictions)
    # Fix the label set so both classes are always reported, even when one of them
    # does not occur in the evaluated data
    report = classification_report(
        true_labels,
        predictions,
        labels=[0, 1],
        target_names=["Not Answerable", "Answerable"],
        zero_division=0,
    )
    print(f"Epoch {epoch + 1} - Accuracy: {accuracy:.4f} - Average Loss: {average_loss:.4f}")
    print(report)


Epoch 1:   0%|          | 0/231 [00:00<?, ?it/s]

inputs: [tensor([[  101, 10167, 10151,  ...,     0,     0,     0],
        [  101, 20469, 16025,  ...,     0,     0,     0],
        [  101, 10117, 10684,  ...,     0,     0,     0],
        ...,
        [  101, 11301, 10105,  ..., 41784, 56082,   102],
        [  101, 33939, 15381,  ..., 15459, 13034,   102],
        [  101, 21208, 10124,  ...,     0,     0,     0]], device='cuda:0'), tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')]
labels: tensor([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1,
        0, 0, 1, 1, 0, 0, 1, 0], device='cuda:0')
model: BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): E

Epoch 1:   0%|          | 0/231 [00:00<?, ?it/s]
Evaluating Epoch 1:   0%|          | 0/31 [00:00<?, ?it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1 - Accuracy: 0.4375 - Average Loss: 0.0026
                precision    recall  f1-score   support

Not Answerable       0.00      0.00      0.00         0
    Answerable       1.00      0.44      0.61        32

      accuracy                           0.44        32
     macro avg       0.50      0.22      0.30        32
  weighted avg       1.00      0.44      0.61        32



Epoch 2:   0%|          | 0/231 [00:00<?, ?it/s]

inputs: [tensor([[  101, 15006, 28849,  ...,     0,     0,     0],
        [  101, 12610, 10105,  ...,     0,     0,     0],
        [  101, 54127, 25019,  ..., 10238, 19423,   102],
        ...,
        [  101, 10117, 11324,  ...,     0,     0,     0],
        [  101, 10882, 10105,  ..., 52152, 10108,   102],
        [  101, 21980, 10134,  ...,     0,     0,     0]], device='cuda:0'), tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')]
labels: tensor([1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1,
        1, 1, 1, 1, 1, 0, 0, 1], device='cuda:0')
model: BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): E

Epoch 2:   0%|          | 0/231 [00:00<?, ?it/s]
Evaluating Epoch 2:   0%|          | 0/31 [00:00<?, ?it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2 - Accuracy: 0.4375 - Average Loss: 0.0024
                precision    recall  f1-score   support

Not Answerable       0.00      0.00      0.00         0
    Answerable       1.00      0.44      0.61        32

      accuracy                           0.44        32
     macro avg       0.50      0.22      0.30        32
  weighted avg       1.00      0.44      0.61        32



Epoch 3:   0%|          | 0/231 [00:00<?, ?it/s]

inputs: [tensor([[  101, 10167, 10105,  ...,   119,   164,   102],
        [  101, 21230, 39782,  ...,     0,     0,     0],
        [  101, 25059, 26134,  ..., 26134, 10537,   102],
        ...,
        [  101, 11301, 13677,  ...,     0,     0,     0],
        [  101, 29981, 21187,  ...,     0,     0,     0],
        [  101, 10117, 84104,  ...,     0,     0,     0]], device='cuda:0'), tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')]
labels: tensor([0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 0, 1, 0, 0, 0], device='cuda:0')
model: BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): E

Epoch 3:   0%|          | 0/231 [00:00<?, ?it/s]
Evaluating Epoch 3:   0%|          | 0/31 [00:00<?, ?it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3 - Accuracy: 0.4375 - Average Loss: 0.0026
                precision    recall  f1-score   support

Not Answerable       0.00      0.00      0.00         0
    Answerable       1.00      0.44      0.61        32

      accuracy                           0.44        32
     macro avg       0.50      0.22      0.30        32
  weighted avg       1.00      0.44      0.61        32



Epoch 4:   0%|          | 0/231 [00:00<?, ?it/s]

inputs: [tensor([[  101, 12716, 11939,  ..., 11358,   166,   102],
        [  101, 10117, 11486,  ...,     0,     0,     0],
        [  101, 10882, 10455,  ...,   119,   164,   102],
        ...,
        [  101, 10167, 11944,  ..., 10551, 18866,   102],
        [  101,   138, 43477,  ...,     0,     0,     0],
        [  101,   138, 10799,  ...,     0,     0,     0]], device='cuda:0'), tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')]
labels: tensor([0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1,
        1, 1, 0, 0, 1, 0, 1, 1], device='cuda:0')
model: BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): E

Epoch 4:   0%|          | 0/231 [00:00<?, ?it/s]
Evaluating Epoch 4:   0%|          | 0/31 [00:00<?, ?it/s]

Epoch 4 - Accuracy: 0.5000 - Average Loss: 0.0020
                precision    recall  f1-score   support

Not Answerable       0.00      0.00      0.00         0
    Answerable       1.00      0.50      0.67        32

      accuracy                           0.50        32
     macro avg       0.50      0.25      0.33        32
  weighted avg       1.00      0.50      0.67        32




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
def _batch_to_model_inputs(batch, device):
  """Move one batch to `device` and return it as keyword arguments for the model."""
  if hasattr(batch, 'keys'):
    # Dict-style batch, e.g. produced by a collate function
    fields = {key: batch[key] for key in ('input_ids', 'attention_mask', 'labels') if key in batch}
  else:
    # Sequence-style batch: (input_ids, attention_mask[, labels])
    input_ids, attention_mask, *rest = batch
    fields = {'input_ids': input_ids, 'attention_mask': attention_mask}
    if rest:
      fields['labels'] = rest[0]
  return {key: tensor.to(device) for key, tensor in fields.items()}


def train(
    model: nn.Module,
    train_dl: DataLoader,
    optimizer: torch.optim.Optimizer,
    schedule: LambdaLR,
    n_epochs: int,
    device: torch.device
):
  """
  The main training loop which will optimize a given model on a given dataset
  :param model: The model being optimized; it is called with `input_ids`,
    `attention_mask` and `labels` keyword arguments and its output must expose
    the loss under the key 'loss'
  :param train_dl: The training dataloader; batches may be dicts with 'input_ids',
    'attention_mask' and 'labels', or sequences in that order
  :param optimizer: The optimizer used to update the model parameters
  :param schedule: The learning-rate schedule, stepped once per batch (may be None)
  :param n_epochs: Number of epochs to train for
  :param device: The device to train on
  :return: List with the loss of every training step, in order
  :raises ValueError: If a batch carries no labels, since no loss can be computed
  """

  # Loss of every optimisation step across all epochs
  losses = []

  # Iterate through epochs
  for ep in range(n_epochs):
    # VERY IMPORTANT: make sure the model is in training mode, which turns on
    # things like dropout
    model.train()

    # Iterate through each batch in the dataloader
    for batch in tqdm(train_dl):
      # VERY IMPORTANT: zero out all of the gradients on each iteration -- PyTorch
      # accumulates them across backward passes otherwise
      optimizer.zero_grad()

      # Place each tensor on the target device
      model_inputs = _batch_to_model_inputs(batch, device)
      if 'labels' not in model_inputs:
        raise ValueError("train() needs labels in every batch to compute a loss")

      # Pass the inputs through the model and read the current loss
      outputs = model(**model_inputs)
      loss = outputs['loss']
      losses.append(loss.item())

      # Calculate all of the gradients for the model
      loss.backward()

      # Finally, update the weights of the model and advance the LR schedule that
      # was passed in (not a global one)
      optimizer.step()
      if schedule is not None:
        schedule.step()
  return losses

In [None]:
# Hyperparameters for the optimizer and the warmup schedule
lr = 2e-5
n_epochs = 3
weight_decay = 0.01
warmup_steps = 200

# Parameters whose name contains one of these substrings get no weight decay;
# every other parameter is decayed with `weight_decay`.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [p for n, p in model.named_parameters() if all(nd not in n for nd in no_decay)],
        'weight_decay': weight_decay,
    },
    {
        'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0,
    },
]

# AdamW over the two groups, with a linear schedule that warms up for
# `warmup_steps` steps and spans n_epochs passes over train_dataloader
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=n_epochs * len(train_dataloader),
)



In [None]:
# Run the training loop and keep the per-step losses it returns
losses = train(
    model=model,
    train_dl=train_dataloader,
    optimizer=optimizer,
    schedule=scheduler,
    n_epochs=n_epochs,
    device=device,
)

  0%|          | 0/231 [00:00<?, ?it/s]


TypeError: ignored