# Fachpraktikum AI University Stuttgart
Author: Siyu Chen


## Install dependencies

In [1]:
# use tpu in pytorch
!pip install --quiet cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

[K     |████████████████████████████████| 144.6MB 76kB/s 
[K     |████████████████████████████████| 61kB 2.8MB/s 
[31mERROR: earthengine-api 0.1.260 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0m
[?25h

In [2]:
!pip install --quiet lineflow
!pip install --quiet transformers
!pip install --quiet pytorch-lightning
!pip install --quiet json_lines

# Albert requires SentencePiece
!pip install --quiet SentencePiece

  Building wheel for lineflow (setup.py) ... [?25l[?25hdone
  Building wheel for arrayfiles (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 2.1MB 5.4MB/s 
[K     |████████████████████████████████| 3.3MB 23.3MB/s 
[K     |████████████████████████████████| 901kB 35.1MB/s 
[K     |████████████████████████████████| 849kB 6.1MB/s 
[K     |████████████████████████████████| 112kB 15.7MB/s 
[K     |████████████████████████████████| 184kB 10.4MB/s 
[K     |████████████████████████████████| 829kB 13.1MB/s 
[K     |████████████████████████████████| 276kB 21.0MB/s 
[K     |████████████████████████████████| 1.3MB 22.7MB/s 
[K     |████████████████████████████████| 143kB 41.8MB/s 
[K     |████████████████████████████████| 296kB 33.5MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
  Building wheel for PyYAML (setup.py) ... [?25l[?25hdone
[31mERROR: earthengine-api 0.1.260 has requirement google-api-python-client<2,>=1.12.1, but you'll h

## Import libraries which are needed for fine-tuning

In [3]:
from typing import Dict
from pathlib import Path
import json
from functools import partial
from collections import OrderedDict
from argparse import ArgumentParser

import lineflow as lf
from transformers import AlbertForMultipleChoice, AlbertTokenizer, AdamW
import pytorch_lightning as pl

import torch
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
import json_lines



## Define constant variables

In [4]:
MAX_LEN = 64
NUM_LABELS = 4
label_map = {"A": 0, "B": 1, "C": 2, "D": 3}
BATCH_SIZE = 8

## Dataset

### Useful function to process dataset

In [5]:
def raw_samples_to_dataset(samples):
    datas = []
    for sample in samples:
        _id = sample["id"]
        _article = sample["fact1"]
        _question = sample["question"]['stem']
        _options = []
        _answer = sample["answerKey"]
        for idx in range(len(sample['question']['choices'])): 
            _options.append(sample["question"]['choices'][idx]['text'])

        data = {
                "id": _id,
                "article": _article,
                "options": _options,
                "question": _question,
                "answer": _answer
                }
        datas.append(data)
    return lf.Dataset(datas)


def preprocess(tokenizer: AlbertTokenizer, x: Dict) -> Dict:

    choices_features = []

    option: str
    for option in x["options"]:
        text_a = x["article"]
        text_b = x["question"] + " " + option

        inputs = tokenizer.encode_plus(
                text_a,
                text_b,
                add_special_tokens=True,
                max_length=MAX_LEN
                )
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
        attention_mask = [1] * len(input_ids)

        pad_token_id = tokenizer.pad_token_id
        padding_length = MAX_LEN - len(input_ids)
        input_ids = input_ids + ([pad_token_id] * padding_length)
        attention_mask = attention_mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_id] * padding_length)

        assert len(input_ids) == MAX_LEN, "Error with input length {} vs {}".format(len(input_ids), MAX_LEN)
        assert len(attention_mask) == MAX_LEN, "Error with input length {} vs {}".format(len(attention_mask), MAX_LEN)
        assert len(token_type_ids) == MAX_LEN, "Error with input length {} vs {}".format(len(token_type_ids), MAX_LEN)

        choices_features.append({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            })

    labels = label_map.get(x["answer"], -1)
    label = torch.tensor(labels).long()

    return {
            "id": x["id"],
            "label": label,
            "input_ids": torch.tensor([cf["input_ids"] for cf in choices_features]),
            "attention_mask": torch.tensor([cf["attention_mask"] for cf in choices_features]),
            "token_type_ids": torch.tensor([cf["token_type_ids"] for cf in choices_features]),
            }


def get_dataloader(tokenizer, datadir: str, cachedir: str = "./"):
    datadir = Path(datadir)
    cachedir = Path(cachedir)
    

    preprocessor = partial(preprocess, tokenizer)

    train_samples = []
    with open(datadir / "train_complete.jsonl") as f:
        for item in json_lines.reader(f):
            train_samples.append(item)
    train = raw_samples_to_dataset(train_samples)
    print(train)
    train_dataloader = DataLoader(
            train.map(preprocessor).save(cachedir / "train_openbook.cache"),
            sampler=RandomSampler(train),
            batch_size=BATCH_SIZE
            )

    val_samples = []
    with open(datadir / "dev_complete.jsonl") as f:
        for item in json_lines.reader(f):
            val_samples.append(item)
    val = raw_samples_to_dataset(val_samples)
    val_dataloader = DataLoader(
            val.map(preprocessor).save(cachedir / "val_openbook.cache"),
            sampler=SequentialSampler(val),
            batch_size=BATCH_SIZE
            )

    
    test_samples = []
    with open(datadir / "test_complete.jsonl") as f:
        for item in json_lines.reader(f):
            test_samples.append(item)
    test = raw_samples_to_dataset(test_samples)
    test_dataloader = DataLoader(
            test.map(preprocessor).save(cachedir / "test_openbook.cache"),
            sampler=SequentialSampler(test),
            batch_size=BATCH_SIZE
            )

    return train_dataloader, val_dataloader, test_dataloader

### Download OpenBook dataset and unzip dataset

#### Download dataset

In [6]:
!ls 
!wget https://ai2-public-datasets.s3.amazonaws.com/open-book-qa/OpenBookQA-V1-Sep2018.zip

sample_data
--2021-05-05 11:38:22--  https://ai2-public-datasets.s3.amazonaws.com/open-book-qa/OpenBookQA-V1-Sep2018.zip
Resolving ai2-public-datasets.s3.amazonaws.com (ai2-public-datasets.s3.amazonaws.com)... 52.218.136.211
Connecting to ai2-public-datasets.s3.amazonaws.com (ai2-public-datasets.s3.amazonaws.com)|52.218.136.211|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1446098 (1.4M) [binary/octet-stream]
Saving to: ‘OpenBookQA-V1-Sep2018.zip’


2021-05-05 11:38:23 (3.75 MB/s) - ‘OpenBookQA-V1-Sep2018.zip’ saved [1446098/1446098]



#### Unzip dataset

In [7]:
!ls 
!unzip OpenBookQA-V1-Sep2018.zip && ls 

OpenBookQA-V1-Sep2018.zip  sample_data
Archive:  OpenBookQA-V1-Sep2018.zip
   creating: OpenBookQA-V1-Sep2018/
   creating: OpenBookQA-V1-Sep2018/Data/
   creating: OpenBookQA-V1-Sep2018/Data/Additional/
  inflating: OpenBookQA-V1-Sep2018/Data/Additional/test_complete.jsonl  
  inflating: OpenBookQA-V1-Sep2018/Data/Additional/train_complete.jsonl  
  inflating: OpenBookQA-V1-Sep2018/Data/Additional/crowdsourced-facts.txt  
  inflating: OpenBookQA-V1-Sep2018/Data/Additional/dev_complete.jsonl  
   creating: OpenBookQA-V1-Sep2018/Data/Main/
  inflating: OpenBookQA-V1-Sep2018/Data/Main/train.jsonl  
  inflating: OpenBookQA-V1-Sep2018/Data/Main/test.jsonl  
  inflating: OpenBookQA-V1-Sep2018/Data/Main/train.tsv  
  inflating: OpenBookQA-V1-Sep2018/Data/Main/dev.tsv  
  inflating: OpenBookQA-V1-Sep2018/Data/Main/dev.jsonl  
  inflating: OpenBookQA-V1-Sep2018/Data/Main/openbook.txt  
  inflating: OpenBookQA-V1-Sep2018/Data/Main/test.tsv  
OpenBookQA-V1-Sep2018  OpenBookQA-V1-Sep2018.zip  sam

In [8]:
!cd OpenBookQA-V1-Sep2018/Data/Additional && ls && pwd

crowdsourced-facts.txt	test_complete.jsonl
dev_complete.jsonl	train_complete.jsonl
/content/OpenBookQA-V1-Sep2018/Data/Additional


### Call functions to handle raw dataset

In [9]:
tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2", do_lower_case=True)
train_dataloader, val_dataloader, test_dataloader = get_dataloader(tokenizer, '/content/OpenBookQA-V1-Sep2018/Data/Additional')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760289.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1312669.0, style=ProgressStyle(descript…

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.



<lineflow.core.Dataset object at 0x7f0594b68610>
Saving data to train_openbook.cache...
Saving data to val_openbook.cache...
Saving data to test_openbook.cache...


In [14]:
print(len(train_dataloader))
print(len(val_dataloader))
print(len(test_dataloader))

620
63
63


## Play around with `DataLoader`

### Get a batch of dataloader

In [15]:
sample = next(iter(test_dataloader))

### Print some information of dataloader

In [16]:
# type of sample
print(type(sample))
# keys of sample
print(sample.keys())
# ids of sample
print(sample['id'])
# label of sample
print(sample['label'])

<class 'dict'>
dict_keys(['id', 'label', 'input_ids', 'attention_mask', 'token_type_ids'])
['8-343', '1129', '880', '7-999', '8-464', '9-794', '9-1163', '9-322']
tensor([1, 0, 2, 2, 2, 2, 2, 1])


In [17]:
# tokenised context and question
print(sample['input_ids'][0].size())
print(sample['input_ids'][0][0])

torch.Size([4, 64])
tensor([    2,   568,   787,  2566,   951,  4047,   875,    20,    44,  4377,
            3,    21,   840,  2846,    20,   799,  7599,   875,    86,    30,
           59,    92,  7829,    21,  2210, 10057,    35,    14,   241,    16,
           14,   159,     9,    75,   699,    84,    66,  3391,    17, 12640,
           15,    59,  4073,    14,   246,   161,    20,  2079,   875,    25,
           20,   233,    91,  1132,  3029,     3,     0,     0,     0,     0,
            0,     0,     0,     0])


### Decode tokenised context and question

In [18]:
de_tokenizer = tokenizer

In [19]:
de_tokenizer.decode(sample['input_ids'][0][1])

'[CLS] using less resources usually causes money to be saved[SEP] a person wants to start saving money so that they can afford a nice vacation at the end of the year. after looking over their budget and expenses, they decide the best way to save money is to quit eating lunch out[SEP]<pad><pad><pad><pad><pad><pad><pad><pad>'

## Fine tuning on AlbertModel

### Set up connection to Google drive in order to save checkpoint

In [10]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Define a model

In [11]:
from pytorch_lightning.metrics import functional as FM
from pytorch_lightning.callbacks import ModelCheckpoint

In [12]:
class Model(pl.LightningModule):

    def __init__(self):
        super(Model, self).__init__()

        model = AlbertForMultipleChoice.from_pretrained("albert-base-v2", num_labels=NUM_LABELS)
        self.model = model

        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader
        self._test_dataloader = test_dataloader

    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.weight']
        weight_decay = 0.0
        adam_epsilon = 1e-8

        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                'weight_decay': weight_decay
                },
            {
                'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
                }
            ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, eps=adam_epsilon)

        return optimizer

    def training_step(self, batch, batch_idx):
        labels = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]

        outputs = self.model(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=labels
                )
        
        self.log('train_loss', outputs.loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        labels = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]

        outputs = self.model(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=labels
                )
        
        labels_hat = torch.argmax(outputs.logits, dim=1)

        acc = FM.accuracy(labels_hat, labels)

        self.log('val_loss', outputs.loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log('val_acc', acc, on_epoch=True, on_step=True, prog_bar=True, logger=True)
                
        return acc

    def test_step(self, batch, batch_idx):
        acc = self.validation_step(batch, batch_idx)
        self.log('test_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader

    def test_dataloader(self):
        return self._test_dataloader

### Use checkpointcallback to save the model with minimal val_loss

In [13]:
# saves a file like: my/path/albert-openbook-epoch=02-val_loss_epoch=0.32.ckpt
# if you don't want to save checkpoint into google drive, change dirpath!!!
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss_epoch',
    dirpath='/content/drive/My Drive/AlbertModel/',
    # dirpath='/your/path/',
    filename='albert-openbook-{epoch:02d}-{val_loss_epoch:.2f}',
    save_top_k=3,
    mode='min',
)

trainer = pl.Trainer(tpu_cores=8, max_epochs=100, callbacks=[checkpoint_callback])

GPU available: False, used: False
TPU available: True, using: 8 TPU cores


In [14]:
pl_model = Model()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=684.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=47376696.0, style=ProgressStyle(descrip…




Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForMultipleChoice: ['predictions.bias', 'predictions.LayerNorm.weight', 'predictions.LayerNorm.bias', 'predictions.dense.weight', 'predictions.dense.bias', 'predictions.decoder.weight', 'predictions.decoder.bias']
- This IS expected if you are initializing AlbertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForMultipleChoice were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on

### Train the model on TPU

In [15]:
trainer.fit(pl_model)


  | Name  | Type                    | Params
--------------------------------------------------
0 | model | AlbertForMultipleChoice | 11.7 M
--------------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.737    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…






1

### Test model on test dataset and print accuaracy

In [16]:
result = trainer.test(test_dataloaders=test_dataloader)


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.4285714626312256,
 'test_acc_epoch': 0.510617733001709,
 'val_acc': 0.4285714626312256,
 'val_acc_epoch': 0.510617733001709,
 'val_loss': 1.1871181726455688,
 'val_loss_epoch': 1.220659613609314}
--------------------------------------------------------------------------------




### Print the best model path with `checkpoint_callback`

In [17]:
checkpoint_callback.best_model_path

'/content/drive/My Drive/AlbertModel/albert-openbook-epoch=08-val_loss_epoch=1.19.ckpt'

### Save the last model into Google drive and load it again

In [18]:
trainer.model.model.save_pretrained("/content/drive/My Drive/AlbertModel/lastModel")

In [19]:
reload_last_model = AlbertForMultipleChoice.from_pretrained("/content/drive/My Drive/AlbertModel/lastModel")

### Load model from checkpoint which is just saved 

In [20]:
checkpoint = torch.load(checkpoint_callback.best_model_path)

In [21]:
# important !!! 
# please read it !!!
# in general you can use the code below to reload the model, but some keys in checkpoint[state_dict'] is a little
# different from the trainer.model.model.state_dict(). So we have to adjust it manually.

# from transformers import AlbertConfig
# config = AlbertConfig.from_pretrained('albert-base-v2')
# m = AlbertForMultipleChoice.from_pretrained(pretrained_model_name_or_path= None, config=config, state_dict=trainer.model.model.state_dict())

new_checkpoint = {}

for key in checkpoint['state_dict'].keys():
  if 'model' in key:
    new_key = key[6:]
    new_checkpoint[new_key] = checkpoint['state_dict'][key]
  else:
    new_checkpoint[key] = checkpoint['state_dict'][key]

In [22]:
from transformers import AlbertConfig
config = AlbertConfig.from_pretrained('albert-base-v2')
m = AlbertForMultipleChoice.from_pretrained(pretrained_model_name_or_path= None, config=config, state_dict=new_checkpoint)

In [None]:
pred_after_tuning = torch.tensor([])
labels = torch.tensor([])

i = 1
for data_batch in test_dataloader:
    if i %10 == 0:
        print(str(i) + "/"+str(len(test_dataloader)))
    i = i + 1
    input_ids = data_batch["input_ids"]
    attention_mask = data_batch["attention_mask"]
    token_type_ids = data_batch["token_type_ids"]

    outputs = m(input_ids, token_type_ids=token_type_ids,attention_mask=attention_mask)
    pred_after_tuning = torch.cat((pred_after_tuning,torch.argmax(outputs.logits, dim=1)))
    labels = torch.cat((labels, data_batch['label']))

In [24]:
acc = FM.accuracy(pred_after_tuning.to(torch.int16), labels.to(torch.int16))
print("acc score is:")
print(acc)

acc score is:
tensor(0.5020)
