In [1]:
!pip install --quiet cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

[K     |████████████████████████████████| 144.6MB 85kB/s 
[K     |████████████████████████████████| 61kB 2.6MB/s 
[31mERROR: earthengine-api 0.1.264 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0m
[?25h

In [2]:
!pip install --quiet lineflow
!pip install --quiet transformers
!pip install --quiet pytorch-lightning
!pip install --quiet json_lines

# Albert requires SentencePiece
!pip install --quiet SentencePiece

  Building wheel for lineflow (setup.py) ... [?25l[?25hdone
  Building wheel for arrayfiles (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 2.3MB 4.9MB/s 
[K     |████████████████████████████████| 901kB 42.8MB/s 
[K     |████████████████████████████████| 3.3MB 41.0MB/s 
[K     |████████████████████████████████| 808kB 7.0MB/s 
[K     |████████████████████████████████| 829kB 43.3MB/s 
[K     |████████████████████████████████| 645kB 38.0MB/s 
[K     |████████████████████████████████| 112kB 44.7MB/s 
[K     |████████████████████████████████| 276kB 39.7MB/s 
[K     |████████████████████████████████| 1.3MB 38.4MB/s 
[K     |████████████████████████████████| 296kB 39.1MB/s 
[K     |████████████████████████████████| 143kB 45.9MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
[31mERROR: earthengine-api 0.1.264 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0

In [3]:
from typing import Dict
from pathlib import Path
from functools import partial
from collections import OrderedDict
from argparse import ArgumentParser

import lineflow as lf
from transformers import AlbertForMultipleChoice, AlbertTokenizer, AdamW
import pytorch_lightning as pl

import torch
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
import json_lines
import pickle



In [4]:
MAX_LEN = 256
NUM_LABELS = 4
label_map = {"A": 0, "B": 1, "C": 2, "D": 3}
BATCH_SIZE = 8

In [8]:
def load_dataloader_from_cache(cachedir :str):
    cachedir = Path(cachedir)

    train_file_name = "train_race.cache"
    train_path = Path(cachedir / train_file_name)
    if train_path.exists():
        print(f'Loading data from {train_file_name}...')
        with train_path.open('rb') as f:
            train_cache = pickle.load(f)

    train_dataloader = DataLoader(
            lf.core.CacheDataset(train_cache),
            batch_size=BATCH_SIZE
            )

    val_file_name = "val_race.cache"
    val_path = Path(cachedir / val_file_name)
    if val_path.exists():
        print(f'Loading data from {val_file_name}...')
        with val_path.open('rb') as f:
            val_cache = pickle.load(f)

    val_dataloader = DataLoader(
            lf.core.CacheDataset(val_cache),
            batch_size=BATCH_SIZE
            )

    
    test_file_name = "test_race.cache"
    test_path = Path(cachedir / test_file_name)
    if test_path.exists():
        print(f'Loading data from {test_file_name}...')
        with test_path.open('rb') as f:
            test_cache = pickle.load(f)

    test_dataloader = DataLoader(
            lf.core.CacheDataset(test_cache),
            batch_size=BATCH_SIZE
            )

    return train_dataloader, val_dataloader, test_dataloader

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
!cd /content/drive/MyDrive/RACE/AlbertCache/ && ls

test_race.cache  train_race.cache  val_race.cache


In [12]:
train_dataloader, val_dataloader, test_dataloader = load_dataloader_from_cache('/content/drive/MyDrive/RACE/AlbertCache')

Loading data from train_race.cache...
Loading data from val_race.cache...
Loading data from test_race.cache...


In [13]:
print(len(train_dataloader))
print(len(val_dataloader))
print(len(test_dataloader))

10984
611
617


In [14]:
sample = next(iter(test_dataloader))

In [15]:
# type of sample
print(type(sample))
# keys of sample
print(sample.keys())
# ids of sample
print(sample['id'])
# label of sample
print(sample['label'])

<class 'dict'>
dict_keys(['id', 'label', 'input_ids', 'attention_mask', 'token_type_ids'])
['middle3797.txt', 'middle3797.txt', 'middle3797.txt', 'middle3797.txt', 'middle3474.txt', 'middle3474.txt', 'middle3474.txt', 'middle3474.txt']
tensor([0, 1, 1, 3, 1, 3, 0, 2])


In [16]:
# tokenised context and question
print(sample['input_ids'][0].size())
print(sample['input_ids'][0][0])

torch.Size([4, 128])
tensor([   2,  382,   21, 5825,   23,   19,   21, 6257,   13,    9,  651,   21,
         254,  169,  296,   34,    9,   28,  651,   28,   14, 5825,  441,   14,
         169,   15,   14, 5825,  260,   20, 3687,    9,   13,    7, 6744,  187,
        2247,  187,    7,   14, 5825,  227,    9,   13,    7, 6744,   15,  408,
          55,   70,    9,    7,   13,    7,  251,   15,    7,   87,   14,  254,
         169,    9,   13,    7,  821,   31,  107,   15,   42,  129, 2749,   55,
           9,    7,   13,    7,   49,  129,   52, 2749,   42,   15,    7,   14,
        5825,   87,    9,   13,    7, 6744,  408,   55,   70,    9,    7,   14,
         254,  169, 1570,   14,    3,   76,   14, 5825,  441,   14,  169,   15,
          24,  260,   20, 3687,   15,  185,   24,  417,   14,  169,   20,  448,
          61,   70,   16,   14, 6257,   13,    9,    3])


In [17]:
tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2", do_lower_case=True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760289.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1312669.0, style=ProgressStyle(descript…




In [18]:
de_tokenizer = tokenizer
de_tokenizer.decode(sample['input_ids'][0][1])

'[CLS] once a tiger was in a cage. soon a good man went by. as soon as the tiger saw the man, the tiger began to cry. "please! please!" the tiger called. "please, let me out." "no," said the good man. "if i do, you will eat me." "i will not eat you," the tiger said. "please let me out." the good man believed the tiger. he opened the[SEP] when the tiger saw the man, he began to cry, because he wanted to eat the man.[SEP]'

In [20]:
from pytorch_lightning.metrics import functional as FM
from pytorch_lightning.callbacks import ModelCheckpoint

In [19]:
class Model(pl.LightningModule):

    def __init__(self):
        super(Model, self).__init__()

        model = AlbertForMultipleChoice.from_pretrained("albert-base-v2", num_labels=NUM_LABELS)
        self.model = model

        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader
        self._test_dataloader = test_dataloader

    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.weight']
        weight_decay = 0.0
        adam_epsilon = 1e-8

        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                'weight_decay': weight_decay
                },
            {
                'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
                }
            ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, eps=adam_epsilon)

        return optimizer

    def training_step(self, batch, batch_idx):
        labels = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]

        outputs = self.model(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=labels
                )
        
        labels_hat = torch.argmax(outputs.logits, dim=1)

        # print(labels.size())

        acc = FM.accuracy(labels_hat, labels)

        self.log('train_loss', outputs.loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)

        return outputs.loss
  
    def validation_step(self, batch, batch_idx):
        labels = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]

        outputs = self.model(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=labels
                )
        
        labels_hat = torch.argmax(outputs.logits, dim=1)

        acc = FM.accuracy(labels_hat, labels)

        self.log('val_loss', outputs.loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log('val_acc', acc, on_epoch=True, on_step=True, prog_bar=True, logger=True)
                
        return acc

    def test_step(self, batch, batch_idx):
        acc = self.validation_step(batch, batch_idx)
        self.log('test_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader

    def test_dataloader(self):
        return self._test_dataloader

In [22]:
# saves a file like: my/path/albert-openbook-epoch=02-val_loss_epoch=0.32.ckpt
# if you don't want to save checkpoint into google drive, change dirpath!!!
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss_epoch',
    dirpath='/content/drive/MyDrive/RACE/AlbertModel',
    # dirpath='/your/path/',
    filename='e1-albert-race-{epoch:02d}-{val_loss_epoch:.2f}',
    save_top_k=3,
    mode='min',
)

trainer = pl.Trainer(tpu_cores=8, max_epochs=10, callbacks=[checkpoint_callback])

INFO:pytorch_lightning.utilities.distributed:GPU available: False, used: False
INFO:pytorch_lightning.utilities.distributed:TPU available: True, using: 8 TPU cores


In [23]:
pl_model = Model()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=684.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=47376696.0, style=ProgressStyle(descrip…




Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForMultipleChoice: ['predictions.dense.bias', 'predictions.dense.weight', 'predictions.decoder.bias', 'predictions.bias', 'predictions.decoder.weight', 'predictions.LayerNorm.bias', 'predictions.LayerNorm.weight']
- This IS expected if you are initializing AlbertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForMultipleChoice were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on

In [24]:
trainer.fit(pl_model)

INFO:pytorch_lightning.core.lightning:
  | Name  | Type                    | Params
--------------------------------------------------
0 | model | AlbertForMultipleChoice | 11.7 M
--------------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.737    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…






In [26]:
result = trainer.test(test_dataloaders=test_dataloader)


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.0,
 'test_acc_epoch': 0.44870516657829285,
 'val_acc': 0.0,
 'val_acc_epoch': 0.44870516657829285,
 'val_loss': 1.6161187887191772,
 'val_loss_epoch': 1.2099839448928833}
--------------------------------------------------------------------------------




In [27]:
checkpoint_callback.best_model_path

'/content/drive/MyDrive/RACE/AlbertModel/e1-albert-race-epoch=02-val_loss_epoch=1.16.ckpt'

In [28]:
checkpoint = torch.load(checkpoint_callback.best_model_path)

In [29]:
# important !!! 
# please read it !!!
# in general you can use the code below to reload the model, but some keys in checkpoint[state_dict'] is a little
# different from the trainer.model.model.state_dict(). So we have to adjust it manually.

# from transformers import AlbertConfig
# config = AlbertConfig.from_pretrained('albert-base-v2')
# m = AlbertForMultipleChoice.from_pretrained(pretrained_model_name_or_path= None, config=config, state_dict=trainer.model.model.state_dict())

new_checkpoint = {}

for key in checkpoint['state_dict'].keys():
  if 'model' in key:
    new_key = key[6:]
    new_checkpoint[new_key] = checkpoint['state_dict'][key]
  else:
    new_checkpoint[key] = checkpoint['state_dict'][key]

In [30]:
from transformers import AlbertConfig
config = AlbertConfig.from_pretrained('albert-base-v2')
m = AlbertForMultipleChoice.from_pretrained(pretrained_model_name_or_path= None, config=config, state_dict=new_checkpoint)

In [31]:
class TestModel(pl.LightningModule):

    def __init__(self, model, test_dataloader):
        super(TestModel, self).__init__()

        self.model = model
        self._test_dataloader = test_dataloader

    def test_step(self, batch, batch_idx):
        labels = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]

        outputs = self.model(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=labels
                )
        
        labels_hat = torch.argmax(outputs.logits, dim=1)

        acc = FM.accuracy(labels_hat, labels)
        self.log('test_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def test_dataloader(self):
        return self._test_dataloader

In [32]:
trainer_for_test = pl.Trainer(tpu_cores=8)
model_for_test = TestModel(m, test_dataloader)

INFO:pytorch_lightning.utilities.distributed:GPU available: False, used: False
INFO:pytorch_lightning.utilities.distributed:TPU available: True, using: 8 TPU cores


In [33]:
trainer_for_test.test(model=model_for_test)


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.0, 'test_acc_epoch': 0.44870516657829285}
--------------------------------------------------------------------------------




[{'test_acc': 0.0, 'test_acc_epoch': 0.44870516657829285}]