In [1]:
! pip install transformers[torch] datasets evaluate

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import transformers

# Print the installed transformers version so results can be matched to it.
print(transformers.__version__)

4.46.3


## Define the model we fine-tune

In [3]:
model_checkpoint = "bert-base-uncased"  # Hugging Face Hub id of the pretrained model to fine-tune
batch_size = 16  # per-device batch size for both the Trainer run and the custom training loop
seed = 42  # one seed for every random number generator used in this notebook

# Seed Python, NumPy and PyTorch so "Restart & Run All" reproduces the random example
# picks, the DataLoader shuffling and the freshly initialised classifier weights.
import random

import numpy as np
import torch

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Load Dataset

In [4]:
from datasets import load_dataset

# Download CommonsenseQA from the Hugging Face Hub (train / validation / test splits).
dataset = load_dataset("tau/commonsense_qa")

In [5]:
# Show each split's columns and number of rows.
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 9741
    })
    validation: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 1221
    })
    test: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 1140
    })
})

In [6]:
# First training example: question, labelled choices, and the answer letter.
dataset["train"][0]

{'id': '075e483d21c29a511267ef62bedc0461',
 'question': 'The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?',
 'question_concept': 'punishing',
 'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
  'text': ['ignore', 'enforce', 'authoritarian', 'yell at', 'avoid']},
 'answerKey': 'A'}

In [7]:
# First validation example; same structure as the training data.
dataset["validation"][0]

{'id': '1afa02df02c908a558b4036e80242fac',
 'question': 'A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?',
 'question_concept': 'revolving door',
 'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
  'text': ['bank', 'library', 'department store', 'mall', 'new york']},
 'answerKey': 'A'}

Note: every `answerKey` in the test split is an empty string (`""`), so the test split cannot be used to measure accuracy.

In [8]:
# The test example's answerKey is empty, so this split cannot be used to measure accuracy.
dataset["test"][0]

{'id': '90b30172e645ff91f7171a048582eb8b',
 'question': 'The townhouse was a hard sell for the realtor, it was right next to a high rise what?',
 'question_concept': 'townhouse',
 'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
  'text': ['suburban development',
   'apartment building',
   'bus stop',
   'michigan',
   'suburbs']},
 'answerKey': ''}

#### The following function shows some randomly picked examples from the dataset to illustrate what the data looks like

In [9]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    Args:
        dataset: a `datasets.Dataset` (needs `len`, indexing by a list of row
            indices that returns a dict of columns, and a `.features` mapping).
        num_examples: how many distinct rows to show; must not exceed `len(dataset)`.

    ClassLabel columns are shown with their string names instead of integer ids.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws distinct indices in one call, replacing the retry-until-unique
    # loop whose `pick in picks` check was quadratic and slowed down as picks filled up.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [10]:
# Show 10 random training rows (the default num_examples).
show_random_elements(dataset["train"])

Unnamed: 0,id,question,question_concept,choices,answerKey
0,4af505b05f7ebf0d26a8e8005b09d1a5,"The person was having a difficult time understanding the computer program, they were beginning to what?",program,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['stare at computer screen', 'compile', 'get frustrated', 'write code', 'think logically']}",C
1,2f90b3ff30ef328a6e69e5f7100e8c26,"John heard a language that he could not understand. He thought that the door was shut, but he eventually realized that there was no door, and that the light source that was blinding his eyes was very familiar. He was on his back looking at what?",light source,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['sky', 'lamp', 'hallway', 'dard', 'closed room']}",A
2,433c513e36799605b6d80fb3aebc28ea,What does a person sometimes do after they finish secondary education?,person,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['enter college', 'feel lonely', 'cross street', 'pass exams', 'graduate from high school']}",A
3,a9424d6cdb772a61e8d2a61d987faad5,In what place could you find air that has been breathed by many people recently?,air,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['charming', 'space shuttle', 'house', 'train station', 'surface of earth']}",D
4,fbcf10712db557876083b6256c55948b,What will a person have when very happy?,person,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['own car', 'be rich', 'catch cold', 'believe in god', 'experience joy']}",E
5,f4cbde41c795a3f32f0e1ebb243c1b1e,Cats are laying by the refrigerator why do they do that?,cats,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['lie down', 'like heat', 'eating fish', 'drink water', 'come to dinner']}",B
6,555fba06a60a28eeaf3f7e97b46f30db,What are people in a library likely doing?,people,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['talk to each other', 'board ships', 'study books', 'suffer hunger', 'playing games']}",C
7,5cc458ffab4d95c112ccef14ae75504b,Steve was surprised to find an underground map while he was shopping and carrots and paperback books. Where might he have found it?,underground map,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['subway station', 'library', 'super market', 'county engineer's office', 'a friend']}",C
8,9cdc71eb5fc6b8b76b4768564ea3fe1e,"He was eating too much and the doctors warned him, what was his condition?",eating too much,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['gain weight', 'obesity', 'getting fit', 'getting sick', 'gas']}",B
9,ddd867663ebf4ed94d6844b6d1699c95,"The archaeologist was seeing artifacts that he knew were fake, how did he feel?",seeing artifacts,"{'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['pleasure', 'awe inspiring', 'angry', 'thinking', 'painful memories']}",C


#### Define a function to check the ground truth of a specific question

In [11]:
def show_one(example):
    """Print a CommonsenseQA example: its question, every labelled choice, and the answer key.

    Args:
        example: one dataset row, a dict with "question", "choices" (a dict of
            parallel "label" and "text" lists) and "answerKey".
    """
    print(f"Question: {example['question']}")
    # Iterate over however many choices the example has instead of hard-coding five.
    choices = example["choices"]
    for label, text in zip(choices["label"], choices["text"]):
        print(f"  {label}:  {text}")
    print(f"\nGround truth: option {example['answerKey']}")

In [12]:
# Print the first training example in a readable form.
show_one(dataset["train"][0])

Question: The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?
  A:  ignore
  B:  enforce
  C:  authoritarian
  D:  yell at
  E:  avoid

Ground truth: option A


In [13]:
# Print another training example (index 12).
show_one(dataset["train"][12])

Question: Johnny sat on a bench and relaxed after doing a lot of work on his hobby.  Where is he?
  A:  state park
  B:  bus depot
  C:  garden
  D:  gym
  E:  rest area

Ground truth: option C


# Data Processing

In [14]:
from transformers import AutoTokenizer

# Load the tokenizer matching the checkpoint; use_fast=True requests the fast implementation (BertTokenizerFast here).
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

Test the pretrained tokenizer

In [15]:
# Encode a sentence pair; token_type_ids mark tokens of the first (0) and second (1) sentence.
tokenizer("Hello, this is a sentence!", "This is another sentence.")

{'input_ids': [101, 7592, 1010, 2023, 2003, 1037, 6251, 999, 102, 2023, 2003, 2178, 6251, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

#### Define the function for batch encoding, which pairs each answer choice with its question sentence and then tokenizes the pairs.

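This function expects a batch of examples, as passed by `Dataset.map(..., batched=True)` or a slice such as `dataset["train"][:6]`. A single unbatched example does not work, because its "question" is a plain string rather than a list. For a batch, the tokenizer returns a list of lists of lists for each key. The outer list has one entry per example (here 6). Each entry is a list of all choices (5). Each choice is a list of input IDs, whose length varies because we did not apply any padding.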

In [16]:
def preprocess_function(examples):
    """Tokenize a batch of CommonsenseQA examples as one (question, choice) pair per choice.

    Args:
        examples: a batch dict, as passed by `Dataset.map(..., batched=True)` or a
            slice such as `dataset["train"][:6]`, with "question" (list of str) and
            "choices" (list of dicts, each holding a "text" list).

    Returns:
        A dict mapping each tokenizer output key (e.g. "input_ids", "attention_mask")
        to a list with one entry per example. Each entry holds one token sequence
        per choice. Sequences are truncated but not padded.
    """
    choice_texts = [choice_dict["text"] for choice_dict in examples["choices"]]
    # Number of choices per example, read from the data instead of assuming 5.
    counts = [len(texts) for texts in choice_texts]

    # Flatten: repeat each question once per choice and pair it with that choice.
    first_sentences = [
        question
        for question, count in zip(examples["question"], counts)
        for _ in range(count)
    ]
    second_sentences = [text for texts in choice_texts for text in texts]

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)

    # Un-flatten: regroup the flat sequences into per-example lists of choices.
    grouped_inputs = {}
    for key, values in tokenized_examples.items():
        grouped, start = [], 0
        for count in counts:
            grouped.append(values[start:start + count])
            start += count
        grouped_inputs[key] = grouped

    return grouped_inputs


Try it on only 6 examples to check that it works correctly.

In [17]:
# Preprocess a 6-example slice and print the number of examples, the number of
# choices for the first example, and the token length of each of its choices.
examples = dataset["train"][:6]
features = preprocess_function(examples)
print(len(features["input_ids"]), len(features["input_ids"][0]), [len(x) for x in features["input_ids"][0]])

6 5 [29, 29, 29, 30, 29]


To make sure we didn't do anything wrong when grouping all possibilities and unflattening, we look at the decoded inputs for a given example, decoding the encoded choices back into sentences.

In [18]:
# Decode the five (question, choice) encodings of example 3 back to text.
idx = 3
[tokenizer.decode(features["input_ids"][idx][i]) for i in range(5)]

['[CLS] google maps and other highway and street gps services have replaced what? [SEP] united states [SEP]',
 '[CLS] google maps and other highway and street gps services have replaced what? [SEP] mexico [SEP]',
 '[CLS] google maps and other highway and street gps services have replaced what? [SEP] countryside [SEP]',
 '[CLS] google maps and other highway and street gps services have replaced what? [SEP] atlas [SEP]',
 '[CLS] google maps and other highway and street gps services have replaced what? [SEP] oceans [SEP]']

Then, we compare it with the ground truth from the original dataset

In [19]:
# Compare with the raw example to confirm the grouping kept the choices in order.
show_one(dataset["train"][3])

Question: Google Maps and other highway and street GPS services have replaced what?
  A:  united states
  B:  mexico
  C:  countryside
  D:  atlas
  E:  oceans

Ground truth: option D


#### They look correct. Then we can go to encode the entire dataset, including our training, validation and testing data.

In [20]:
# Tokenize every split in batches; the tokenizer columns are added next to the original columns.
encoded_dataset = dataset.map(preprocess_function, batched=True)

### Important! some postprocessing to our encoded_dataset
Before using the Trainer API or defining the dataloaders for training loops,
we have to apply a bit of postprocessing to our encoded_dataset, to take care of some things that the Trainer did for us automatically. Specifically, we need to:
1. Remove the columns corresponding to values the model does not expect (like the question, choices and question_concept columns).
2. Rename the column 'answerKey' to 'labels' (because the model expects the argument to be named 'labels').
3. Set the format of the datasets so they return PyTorch tensors instead of lists.

Rename the column 'answerKey' to 'labels'

In [21]:
# The model expects the target as "labels"; the values are still answer letters ("A"-"E", or "" in test).
encoded_dataset = encoded_dataset.rename_column("answerKey", "labels")

In [22]:
# Confirm the rename and the added tokenizer columns.
encoded_dataset["train"]

Dataset({
    features: ['id', 'question', 'question_concept', 'choices', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 9741
})

Set the format of the datasets so they return PyTorch tensors instead of lists

In [23]:
# Ask the datasets library to return PyTorch tensors when rows are accessed.
encoded_dataset.set_format("torch")

Remove the columns corresponding to values the model does not expect

In [24]:
# List the current columns to decide which ones to drop.
encoded_dataset["train"].column_names

['id',
 'question',
 'question_concept',
 'choices',
 'labels',
 'input_ids',
 'token_type_ids',
 'attention_mask']

In [25]:
# Drop the raw-text columns; only the labels and the tokenizer outputs are model inputs.
raw_columns = ["id", "question", "question_concept", "choices"]
encoded_dataset = encoded_dataset.remove_columns(raw_columns)

In [26]:
# Only the labels and the tokenizer outputs should remain.
encoded_dataset["train"].column_names

['labels', 'input_ids', 'token_type_ids', 'attention_mask']

### Define DataCollatorForMultipleChoice for batch padding
We need to add batch padding to the tokenized data using data collator.

Hugging Face Transformers (version 4.46.3, used here) doesn't ship a data collator for multiple choice, so we write our own, modelled on DataCollatorWithPadding, to build batches of examples. It's more efficient to dynamically pad the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.

DataCollatorForMultipleChoice flattens all the model inputs, applies padding, and then unflattens the results

In [27]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch

@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads multiple-choice inputs.

    Each feature is a dict whose model inputs (e.g. "input_ids", "attention_mask")
    hold one token sequence per choice. It may also carry a "labels" entry that is
    an answer letter ("A"-"E") or an integer choice index.

    All choices of the batch are flattened, padded together with `tokenizer.pad`,
    and reshaped to (batch_size, num_choices, sequence_length). When the features
    have labels, they are returned as an int64 tensor of choice indices. The
    caller's feature dicts are not modified.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    # Answer letters in choice order; a letter's position is its class index.
    # (Not annotated, so it is a plain class constant rather than a dataclass field.)
    _LABEL_LETTERS = ("A", "B", "C", "D", "E")

    def __call__(self, features):
        # Labels are optional so unlabelled features (e.g. for prediction) can be collated.
        has_labels = "labels" in features[0]

        # Determine batch size and number of choices
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # Flatten to one dict per (example, choice). Labels are skipped rather than
        # popped, so the same features can be collated more than once.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != "labels"}
            for feature in features
            for i in range(num_choices)
        ]

        # Pad all choices of the batch to a common length
        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Un-flatten to restore batch structure (batch_size, num_choices, sequence_length)
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}

        if has_labels:
            batch["labels"] = torch.tensor(
                [self._label_to_index(feature["labels"]) for feature in features],
                dtype=torch.int64,
            )

        return batch

    def _label_to_index(self, label):
        """Convert an answer letter, or an integer index (int or 0-d tensor), to an int choice index.

        Raises:
            ValueError: if `label` is a string that is not one of the known answer letters.
        """
        if isinstance(label, str):
            if label not in self._LABEL_LETTERS:
                raise ValueError(
                    f"Unknown answer label {label!r}; expected one of {self._LABEL_LETTERS}."
                )
            return self._LABEL_LETTERS.index(label)
        return int(label)


When called on a list of examples, it will flatten all the inputs/attentions masks etc. in big lists that it will pass to the tokenizer.pad method. This will return a dictionary with big tensors (of shape (batch_size * 5) x seq_length) that we then unflatten.

We can check this data collator works on a list of features, we just have to make sure to remove all features that are not inputs accepted by our model (something the Trainer will do automatically for us after)

In [28]:
# Smoke-test the collator on 10 training rows, keeping only the keys listed here
# (token_type_ids is left out in this check).
accepted_keys = ["input_ids", "attention_mask", "labels"]
# pick out only 10 data examples
features = [{k: v for k, v in encoded_dataset["train"][i].items() if k in accepted_keys} for i in range(10)]
collator = DataCollatorForMultipleChoice(tokenizer)
batch = collator(features)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Check the data collator works on a question

In [29]:
# Decode the 5 padded (question, choice) encodings of example 7 in the collated batch;
# shorter choices end in [PAD] tokens.
[tokenizer.decode(batch["input_ids"][7][i].tolist()) for i in range(5)]

['[CLS] the forgotten leftovers had gotten quite old, he found it covered in mold in the back of his what? [SEP] carpet [SEP] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] the forgotten leftovers had gotten quite old, he found it covered in mold in the back of his what? [SEP] refrigerator [SEP] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] the forgotten leftovers had gotten quite old, he found it covered in mold in the back of his what? [SEP] breadbox [SEP] [PAD] [PAD] [PAD]',
 '[CLS] the forgotten leftovers had gotten quite old, he found it covered in mold in the back of his what? [SEP] fridge [SEP] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] the forgotten leftovers had gotten quite old, he found it covered in mold in the back of his what? [SEP] coach [SEP] [PAD] [PAD] [PAD] [PAD]']

Compare it with the ground truth

In [30]:
# Ground truth for example 7, to compare with the decoded batch above.
show_one(dataset["train"][7])

Question: The forgotten leftovers had gotten quite old, he found it covered in mold in the back of his what?
  A:  carpet
  B:  refrigerator
  C:  breadbox
  D:  fridge
  E:  coach

Ground truth: option B


# Fine-tune the BERT model - with Trainer API
Next, we download the pretrained model and fine-tune it on the CommonsenseQA dataset. Since our task is multiple choice, we use the AutoModelForMultipleChoice class. Like the tokenizer, the from_pretrained method downloads and caches the model for us.

In [31]:
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

# Load BERT with a multiple-choice head; the classifier weights are newly initialised and must be trained.
model = AutoModelForMultipleChoice.from_pretrained(model_checkpoint)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Evaluation function for training

In [32]:
import evaluate

# Load the accuracy metric used by compute_metrics during Trainer evaluation.
accuracy = evaluate.load("accuracy")

In [33]:
import numpy as np

def compute_metrics(eval_pred):
    """Compute multiple-choice accuracy for the Trainer.

    Args:
        eval_pred: (predictions, labels) pair from the Trainer. The predictions are
            logits with one score per choice in the last dimension. If the model
            returns extra outputs, the predictions may be a tuple whose first
            element is the logits.

    Returns:
        The dict returned by `accuracy.compute`, e.g. {"accuracy": 0.58}.
    """
    predictions, labels = eval_pred

    # Keep only the logits when the Trainer hands over several model outputs.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # The predicted answer is the highest-scoring choice.
    predictions = np.argmax(predictions, axis=-1)

    return accuracy.compute(predictions=predictions, references=labels)

### Fine-tune using Hugging Face Trainer API
Here we define TrainingArguments for the Trainer, which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional. Here we set the evaluation to be done at the end of each epoch, and adjust the learning rate, using the batch_size defined at the top of the notebook and customize the number of epochs for training, as well as the weight decay.

In [34]:
model_name = model_checkpoint.split("/")[-1]
training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-csQA",
    eval_strategy="epoch",
    save_strategy="epoch",  # checkpoints must line up with evaluations for load_best_model_at_end
    load_best_model_at_end=True,
    # Pick the best checkpoint by validation accuracy (the task metric). Without this,
    # it is picked by validation loss, which can keep rising while accuracy still improves.
    metric_for_best_model="accuracy",
    greater_is_better=True,
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

Then we just need to pass all of this along with our datasets to the Hugging Face Trainer API

In [36]:
# Combine the model, arguments, data splits, tokenizer, collator and metric function into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    processing_class=tokenizer,
    # pads each batch dynamically and converts answer letters to choice indices
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Start the fine-tune process and retrain the model on our commonsenseQA dataset

In [37]:
# Fine-tune for the configured epochs, evaluating on the validation split after each epoch.
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.074509,0.578215
2,1.073800,1.197831,0.587224
3,1.073800,1.418299,0.58231




TrainOutput(global_step=915, training_loss=0.8009680576011783, metrics={'train_runtime': 269.3162, 'train_samples_per_second': 108.508, 'train_steps_per_second': 3.397, 'total_flos': 2979924348904950.0, 'train_loss': 0.8009680576011783, 'epoch': 3.0})

# Fine-tune the BERT model - with self-defined training loops

### Define the dataloaders
We need to define the dataloaders that we will use to iterate over batches. Before that, we create an instance of the DataCollatorForMultipleChoice defined earlier, which the dataloaders will use to build batches.

In [38]:
# Same collator the Trainer used, now for plain PyTorch DataLoaders.
data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)

In [39]:
from torch.utils.data import DataLoader

# Shuffle only the training data; evaluation order does not affect accuracy.
train_dataloader = DataLoader(
    encoded_dataset["train"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    encoded_dataset["validation"], batch_size=batch_size, collate_fn=data_collator
)

To quickly check there is no mistake in the data processing, we can inspect a batch like this.

The shapes will probably be slightly different after each time running the code, since we set shuffle=True for the training dataloader and we are padding to the maximum length inside the batch.

In [40]:
# Pull a single batch from the training dataloader to sanity-check tensor shapes.
batch = next(iter(train_dataloader))
{name: tensor.shape for name, tensor in batch.items()}

{'input_ids': torch.Size([16, 5, 50]),
 'token_type_ids': torch.Size([16, 5, 50]),
 'attention_mask': torch.Size([16, 5, 50]),
 'labels': torch.Size([16])}

### Load the model

In [41]:
from transformers import AutoModelForMultipleChoice

# Reload a fresh pretrained model so the custom loop does not start from the Trainer-fine-tuned weights.
model = AutoModelForMultipleChoice.from_pretrained(model_checkpoint)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


To make sure that everything will go smoothly during training, we pass a batch to this model for checking.
As stated in the Hugging Face documentation, all Hugging Face Transformers models return the loss when labels are provided. We also get the logits: one score per choice for each example in the batch, so a tensor of size 16 x 5 (batch_size x number of choices) in our case.

In [42]:
# Forward one batch; because it contains "labels", the output includes a loss as well as the logits.
outputs = model(**batch)
print(outputs.loss, outputs.logits.shape)

tensor(1.5778, grad_fn=<NllLossBackward0>) torch.Size([16, 5])


### Define the optimizer
We use AdamW, the optimizer the Hugging Face Trainer uses by default, with the same learning rate as in the Trainer run.

In [43]:
# AdamW over all model parameters, with the same learning rate as the Trainer run (5e-5).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

### Define the learning rate scheduler
The learning rate scheduler used by default is a linear decay from the maximum value (5e-5) to 0. To define it, we need the number of training steps. That is the number of epochs we want to run multiplied by the number of training batches (the length of our training dataloader). The Trainer uses three epochs by default, so we do the same.

In [44]:
from transformers import get_scheduler

num_epochs = 3
# One scheduler step per training batch, for every epoch.
num_training_steps = num_epochs * len(train_dataloader)
# Linear decay from the optimizer's learning rate to 0, with no warm-up.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

1827


### Define the training loop

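Prepare the training objects with Hugging Face Accelerate. In this notebook it still runs on a single GPU; see the note after the evaluation loop for multi-GPU training.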

In [45]:
from accelerate import Accelerator

# instantiates an Accelerator object that will look at the environment and initialize the proper distributed setup.
accelerator = Accelerator()

# this will wrap those objects in the proper container to make sure our distributed training works as intended.
# NOTE(review): lr_scheduler (built in an earlier cell from the unwrapped optimizer) is not passed to
# prepare(), and num_training_steps was computed from len(train_dataloader) before preparation.
# Both are fine for this single-process run but should be revisited for multi-GPU launches.
train_dl, eval_dl, model, optimizer = accelerator.prepare(
    train_dataloader, eval_dataloader, model, optimizer
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


**The training loop**

In [46]:
from tqdm.auto import tqdm

# add a progress bar over our number of training steps
progress_bar = tqdm(range(num_training_steps))

# Put dropout etc. in training mode once; nothing switches to eval mode inside the loop.
model.train()
for epoch in range(num_epochs):
    for batch in train_dl:
        outputs = model(**batch)
        loss = outputs.loss
        # Accelerate's replacement for loss.backward()
        accelerator.backward(loss)

        # Update weights, advance the linear LR schedule, then clear gradients for the next batch.
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

  0%|          | 0/1827 [00:00<?, ?it/s]

### Define the evaluation loop
Metrics can actually accumulate batches for us as we go over the prediction loop with the method add_batch(). So, once we have accumulated all the batches, we can get the final result with metric.compute().

In [47]:
import evaluate

# Accumulate predictions batch by batch with add_batch(), then compute accuracy once at the end.
metric = evaluate.load("accuracy")

model.eval()
for batch in eval_dl:
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    # The predicted answer is the highest-scoring choice.
    predictions = torch.argmax(logits, dim=-1)
    # Gather predictions and labels from every process, dropping samples duplicated to even
    # out the last batch. This keeps the metric correct when the loop runs distributed too.
    predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
    metric.add_batch(predictions=predictions, references=references)

metric.compute()

{'accuracy': 0.5798525798525799}

### Important! The code above still uses a single GPU!
To use multiple GPUs, we need to either put the code into a train.py script and run `accelerate config` followed by `accelerate launch train.py`, or use `notebook_launcher` when training on multiple GPUs from inside a notebook.

However, running it from this notebook always gives the following error:
ValueError: To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized inside your training function. Restart your notebook and make sure no cells initializes an `Accelerator`.

**Important!!!**
**Important!!!**
**Important!!!**


Because the `Accelerator` must be created only inside the training function — not in any other cell, not even in a comment — we moved the multi-GPU version of the self-defined training loop into a separate notebook, `BERT_finetune_CSQA_notebook_launcher.ipynb`. Please refer to that notebook for it.