In [1]:
import os
import requests
import tarfile
from tqdm import tqdm

In [3]:
import requests
import logging
import tarfile
import urllib
from tqdm import tqdm_notebook as tqdm

logger = logging.getLogger()

import sys
import os
from pathlib import Path

def download_url(url:str, dest:str, overwrite:bool=True, show_progress=True, 
                 chunk_size=1024*1024, timeout=4, retries=5)->None:
    "Download `url` to `dest` unless it exists and not `overwrite`."
    dest = Path(dest)/os.path.basename(url)
    if os.path.exists(dest) and not overwrite: 
        print("File already existing")
        return

    s = requests.Session()
    s.mount('http://',requests.adapters.HTTPAdapter(max_retries=retries))
    u = s.get(url, stream=True, timeout=timeout)
    try: file_size = int(u.headers["Content-Length"])
    except: show_progress = False
    print(f"Downloading {url}")
    with open(dest, 'wb') as f:
        nbytes = 0
        if show_progress: 
            pbar = tqdm(range(file_size), leave=False)
        try:
            for chunk in u.iter_content(chunk_size=chunk_size):
                nbytes += len(chunk)
                if show_progress: pbar.update(nbytes)
                f.write(chunk)
        except requests.exceptions.ConnectionError as e:
            print(f"Download failed after {retries} retries.")
            import sys;sys.exit(1)
        finally:
            return str(dest)
        
def untar(file_path, dest:str):
    print(f"Untar {os.path.basename(file_path)} to {dest}")
    with tarfile.open(file_path) as tf:
        tf.extractall(path=str(dest))
    os.remove(file_path)
    return str(dest)

In [4]:
from pathlib import Path

DATA_DIR = Path('./data').resolve()
url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

In [5]:

file_path = download_url(url, '/tmp', overwrite=True)
untar(file_path, DATA_DIR)

Downloading https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=84125825.0), HTML(value='')))

Untar aclImdb_v1.tar.gz to C:\Users\Алексей\Desktop\Монтаж\Jupyter notebook\Notebooks\Transformer_001\data


'C:\\Users\\Алексей\\Desktop\\Монтаж\\Jupyter notebook\\Notebooks\\Transformer_001\\data'

In [6]:
import pandas as pd
import re

def clean_html(raw):
    cleanr = re.compile('<.*?>')
    clean = re.sub(cleanr, '  ', raw)
    return re.sub(' +', ' ', clean)


def read_imdb(imdb_dir: str, text_col='text', label_col='label'):
    "Read imdb data to {'label', 'text'} format"
    imdb_dir = Path(imdb_dir)
    data = {}
    for t in ['train', 'test']:
        texts, labels = [], []
        for p in ['pos', 'neg']:
            for file in tqdm((imdb_dir/'train'/p).glob("*.txt"), desc=f'reading {t}/{p}'):
                with open(file, 'r') as fin:
                    text = fin.readlines()[0].replace(r'\n', ' ')
                    text = clean_html(text).strip()
                    texts +=  [text]
                    labels += [0 if p=='neg' else 1]
        df = pd.DataFrame(
        {label_col: labels, text_col: texts})
        data[t] = df.sample(frac=1)

    return tuple(data.values())

def save_bertify(df: pd.DataFrame, fname: str):
    # https://medium.com/swlh/a-simple-guide-on-using-bert-for-text-classification-bbf041ac8d04
    fname = str(fname)
    assert fname.endswith('.tsv'), "fname has to be a tsv file!"
    
    df_bert = pd.DataFrame({
        'id': range(len(df)),
        'label': df['label'],
        'alpha': ['a'] * len(df),
        'text': df['text']})
    df_bert.to_csv(fname, sep='\t', index=False, header=False)
    print(f"saved {len(df_bert)} bertified samples to {fname}")

In [9]:
import pandas as pd
import re

# text and label column names
TEXT_COL = "text"
LABEL_COL = "label"

def clean_html(text: str):
    "remove html tags and whitespaces"
    cleanr = re.compile('<.*?>')
    text = re.sub(cleanr, '  ', text)
    return re.sub(' +', ' ', text)

def read_imdb(data_dir, max_lengths={"train": None, "test": None}):
    datasets = {}
    for t in ["train", "test"]:
        df = pd.read_csv(os.path.join(data_dir, f"imdb5k_{t}.csv"))
        if max_lengths.get(t) is not None:
            df = df.sample(n=max_lengths.get(t))
            df[TEXT_COL] = df[TEXT_COL].apply(lambda t: clean_html(t))
        datasets[t] = df
    return datasets    

# read data
datasets = read_imdb(IMDB_DIR)

# list of labels
labels = list(set(datasets["train"][LABEL_COL].tolist()))

# labels to integers mapping
label2int = {label: i for i, label in enumerate(labels)}

In [10]:
import torch
from torch.utils.data import TensorDataset, random_split, DataLoader
import numpy as np
import warnings
from tqdm import tqdm_notebook as tqdm
from typing import Tuple

NUM_MAX_POSITIONS = 256
BATCH_SIZE = 32

class TextProcessor: 
    # special tokens for classification and padding
    CLS = '[CLS]'
    PAD = '[PAD]'
    
    def __init__(self, tokenizer, label2id: dict, num_max_positions:int=512):
        self.tokenizer=tokenizer
        self.label2id = label2id
        self.num_labels = len(label2id)
        self.num_max_positions = num_max_positions     
    
    def process_example(self, example: Tuple[str, str]):
        "Convert text (example[0]) to sequence of IDs and label (example[1] to integer"
        assert len(example) == 2
        label, text = example[0], example[1]
        assert isinstance(text, str)
        tokens = self.tokenizer.tokenize(text)

        # truncate if too long
        if len(tokens) >= self.num_max_positions:
            tokens = tokens[:self.num_max_positions-1] 
            ids =  self.tokenizer.convert_tokens_to_ids(tokens) + [self.tokenizer.vocab[self.CLS]]
        # pad if too short
        else:
            pad = [self.tokenizer.vocab[self.PAD]] * (self.num_max_positions-len(tokens)-1)
            ids =  self.tokenizer.convert_tokens_to_ids(tokens) + [self.tokenizer.vocab[self.CLS]] + pad
        
        return ids, self.label2id[label]

# download the 'bert-base-cased' tokenizer
from pytorch_transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

# initialize a TextProcessor
processor = TextProcessor(tokenizer, label2int, num_max_positions=NUM_MAX_POSITIONS)



  0%|                                                | 0/213450 [00:00<?, ?B/s][A[A

  8%|██▋                              | 17408/213450 [00:00<00:02, 75077.96B/s][A[A

 16%|█████▍                           | 34816/213450 [00:00<00:02, 75078.80B/s][A[A

 24%|████████                         | 52224/213450 [00:00<00:01, 83146.30B/s][A[A

 40%|████████████▉                   | 86016/213450 [00:00<00:01, 103559.12B/s][A[A

 49%|███████████████▋                | 104448/213450 [00:01<00:01, 83521.05B/s][A[A

 65%|████████████████████▉           | 139264/213450 [00:01<00:00, 80368.49B/s][A[A

 73%|███████████████████████▍        | 156672/213450 [00:01<00:00, 90487.83B/s][A[A

 82%|██████████████████████████      | 174080/213450 [00:01<00:00, 79446.99B/s][A[A

100%|███████████████████████████████| 213450/213450 [00:02<00:00, 101397.36B/s][A[A


In [11]:

from collections import namedtuple
import torch

LOG_DIR = "./logs/"
CACHE_DIR = "./cache/"

device = "cuda" if torch.cuda.is_available() else "cpu"

FineTuningConfig = namedtuple('FineTuningConfig',
      field_names="num_classes, dropout, init_range, batch_size, lr, max_norm,"
                  "n_warmup, valid_pct, gradient_acc_steps, device, log_dir, dataset_cache")

finetuning_config = FineTuningConfig(
                2, 0.1, 0.02, BATCH_SIZE, 6.5e-5, 1.0,
                10, 0.1, 1, device, LOG_DIR, 
                CACHE_DIR+'dataset_cache.bin')

finetuning_config

FineTuningConfig(num_classes=2, dropout=0.1, init_range=0.02, batch_size=32, lr=6.5e-05, max_norm=1.0, n_warmup=10, valid_pct=0.1, gradient_acc_steps=1, device='cpu', log_dir='./logs/', dataset_cache='./cache/dataset_cache.bin')

In [12]:
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from itertools import repeat

num_cores = cpu_count()

def process_row(processor, row):
    return processor.process_example((row[1][LABEL_COL], row[1][TEXT_COL]))

def create_dataloader(df: pd.DataFrame,
                      processor: TextProcessor,
                      batch_size: int = 32,
                      shuffle: bool = False,
                      valid_pct: float = None,
                      text_col: str = "text",
                      label_col: str = "label"):
    "Process rows in `df` with `num_cores` workers using `processor`."

    with ProcessPoolExecutor(max_workers=num_cores) as executor:
        result = list(
            tqdm(executor.map(process_row,
                              repeat(processor),
                              df.iterrows(),
                              chunksize=len(df) // 10),
                 desc=f"Processing {len(df)} examples on {num_cores} cores",
                 total=len(df)))

    features = [r[0] for r in result]
    labels = [r[1] for r in result]

    dataset = TensorDataset(torch.tensor(features, dtype=torch.long),
                            torch.tensor(labels, dtype=torch.long))

    if valid_pct is not None:
        valid_size = int(valid_pct * len(df))
        train_size = len(df) - valid_size
        valid_dataset, train_dataset = random_split(dataset,
                                                    [valid_size, train_size])
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=False)
        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)
        return train_loader, valid_loader

    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=0,
                             shuffle=shuffle,
                             pin_memory=torch.cuda.is_available())
    return data_loader

In [15]:
import torch.nn as nn
import torch

class Transformer(nn.Module):
    "Adopted from https://github.com/huggingface/naacl_transfer_learning_tutorial"
  
    def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions, 
                 num_heads, num_layers, dropout, causal):
        super().__init__()
        self.causal = causal
        self.tokens_embeddings = nn.Embedding(num_embeddings, embed_dim)
        self.position_embeddings = nn.Embedding(num_max_positions, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.layer_norms_1, self.layer_norms_2 = nn.ModuleList(), nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))
            self.feed_forwards.append(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                    nn.ReLU(),
                                                    nn.Linear(hidden_dim, embed_dim)))
            self.layer_norms_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.layer_norms_2.append(nn.LayerNorm(embed_dim, eps=1e-12))

    def forward(self, x, padding_mask=None):
        "x, padding_mask - shape: [S, B]"
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        h = self.tokens_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)
        attn_mask = None
        if self.causal:
            attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)

        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(self.layer_norms_1, self.attentions,
                                                                       self.layer_norms_2, self.feed_forwards):
            h = layer_norm_1(h)
            x, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False, key_padding_mask=padding_mask)
            x = self.dropout(x)
            h = x + h
            h = layer_norm_2(h)
            x = feed_forward(h)
            x = self.dropout(x)
            h = x + h
        return h


class TransformerWithClfHead(nn.Module):
    "Adopted from https://github.com/huggingface/naacl_transfer_learning_tutorial"
    
    def __init__(self, config, fine_tuning_config):
        super().__init__()
        self.config = fine_tuning_config
        self.transformer = Transformer(config.embed_dim, config.hidden_dim, config.num_embeddings,
                                       config.num_max_positions, config.num_heads, config.num_layers,
                                       fine_tuning_config.dropout, causal=not config.mlm)
        
        self.classification_head = nn.Linear(config.embed_dim, fine_tuning_config.num_classes)
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, clf_tokens_mask, clf_labels=None, padding_mask=None):
        hidden_states = self.transformer(x, padding_mask)
        clf_tokens_states = (hidden_states * clf_tokens_mask.unsqueeze(-1).float()).sum(dim=0)
        clf_logits = self.classification_head(clf_tokens_states)

        if clf_labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(clf_logits.view(-1, clf_logits.size(-1)), clf_labels.view(-1))
            return clf_logits, loss
        return clf_logits

In [16]:
from pytorch_transformers import cached_path

# download pre-trained model and config
state_dict = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                                    "naacl-2019-tutorial/model_checkpoint.pth"), map_location='cpu')

config = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                                        "naacl-2019-tutorial/model_training_args.bin"))

# init model: Transformer base + classifier head
#model = TransformerWithClfHead(config=config, fine_tuning_config=finetuning_config).to(finetuning_config.device)
#model.load_state_dict(state_dict, strict=False)




  0%|                                             | 0/201626725 [00:00<?, ?B/s][A[A

  0%|                             | 1024/201626725 [00:00<20:10:55, 2775.07B/s][A[A

  0%|                            | 17408/201626725 [00:00<14:46:06, 3792.01B/s][A[A

  0%|                            | 34816/201626725 [00:01<10:37:03, 5274.06B/s][A[A

  0%|                             | 51200/201626725 [00:01<7:52:54, 7104.18B/s][A[A

  0%|                             | 54272/201626725 [00:02<7:34:44, 7387.92B/s][A[A

  0%|                             | 68608/201626725 [00:02<5:42:14, 9815.38B/s][A[A

  0%|                             | 71680/201626725 [00:02<6:07:17, 9145.94B/s][A[A

  0%|                           | 104448/201626725 [00:03<4:28:06, 12527.52B/s][A[A

  0%|                           | 121856/201626725 [00:03<3:35:08, 15610.04B/s][A[A

  0%|                           | 139264/201626725 [00:04<2:56:47, 18994.99B/s][A[A

  0%|                           | 174080/

 12%|██▉                      | 23900160/201626725 [00:31<02:21, 1255999.59B/s][A[A

 12%|███                      | 24239104/201626725 [00:31<01:54, 1548051.97B/s][A[A

 12%|███                      | 24440832/201626725 [00:31<02:06, 1401244.76B/s][A[A

 12%|███                      | 24615936/201626725 [00:31<02:05, 1409304.44B/s][A[A

 12%|███                      | 24850432/201626725 [00:32<02:33, 1149898.37B/s][A[A

 12%|███                      | 25194496/201626725 [00:32<02:17, 1286352.64B/s][A[A

 13%|███▏                     | 25346048/201626725 [00:32<02:26, 1202289.12B/s][A[A

 13%|███▎                      | 25483264/201626725 [00:32<03:54, 751505.93B/s][A[A

 13%|███▎                      | 25735168/201626725 [00:33<03:39, 802527.49B/s][A[A

 13%|███▎                      | 25964544/201626725 [00:33<03:32, 825993.92B/s][A[A

 13%|███▍                      | 26243072/201626725 [00:33<03:16, 893933.61B/s][A[A

 13%|███▍                      | 26537984/2

 22%|█████▋                    | 44024832/201626725 [00:55<06:01, 435595.46B/s][A[A

 22%|█████▋                    | 44101632/201626725 [00:55<06:48, 385837.74B/s][A[A

 22%|█████▋                    | 44151808/201626725 [00:56<09:58, 262923.97B/s][A[A

 22%|█████▋                    | 44249088/201626725 [00:56<09:41, 270680.47B/s][A[A

 22%|█████▋                    | 44396544/201626725 [00:56<08:34, 305867.25B/s][A[A

 22%|█████▋                    | 44560384/201626725 [00:57<07:37, 343551.89B/s][A[A

 22%|█████▊                    | 44724224/201626725 [00:58<10:30, 248996.99B/s][A[A

 22%|█████▊                    | 44904448/201626725 [00:58<09:20, 279552.22B/s][A[A

 22%|█████▊                    | 45084672/201626725 [00:59<08:28, 307840.10B/s][A[A

 22%|█████▊                    | 45264896/201626725 [00:59<07:34, 343942.29B/s][A[A

 23%|█████▊                    | 45461504/201626725 [00:59<06:23, 407639.58B/s][A[A

 23%|█████▉                    | 45641728/2

 30%|███████▉                  | 61353984/201626725 [01:24<04:29, 520757.14B/s][A[A

 31%|███████▉                  | 61550592/201626725 [01:24<04:46, 489114.90B/s][A[A

 31%|███████▉                  | 61763584/201626725 [01:25<04:51, 479523.37B/s][A[A

 31%|███████▉                  | 61861888/201626725 [01:25<06:39, 350066.32B/s][A[A

 31%|███████▉                  | 62025728/201626725 [01:26<06:51, 338960.06B/s][A[A

 31%|████████                  | 62124032/201626725 [01:26<08:05, 287046.03B/s][A[A

 31%|████████                  | 62287872/201626725 [01:27<07:24, 313470.89B/s][A[A

 31%|████████                  | 62468096/201626725 [01:27<06:25, 360524.55B/s][A[A

 31%|████████                  | 62648320/201626725 [01:27<05:35, 414269.43B/s][A[A

 31%|████████                  | 62844928/201626725 [01:28<04:57, 466741.59B/s][A[A

 31%|████████                  | 62899200/201626725 [01:28<07:04, 326823.10B/s][A[A

 31%|████████                  | 62992384/2

 41%|██████████▎              | 82980864/201626725 [01:55<01:10, 1693558.77B/s][A[A

 41%|██████████▎              | 83374080/201626725 [01:55<01:05, 1794115.68B/s][A[A

 42%|██████████▍              | 83783680/201626725 [01:55<01:05, 1805461.01B/s][A[A

 42%|██████████▍              | 84193280/201626725 [01:55<01:01, 1899719.63B/s][A[A

 42%|██████████▍              | 84471808/201626725 [01:55<01:10, 1660159.99B/s][A[A

 42%|██████████▌              | 84914176/201626725 [01:56<01:10, 1663193.92B/s][A[A

 42%|██████████▌              | 85143552/201626725 [01:56<01:37, 1195910.31B/s][A[A

 42%|██████████▌              | 85520384/201626725 [01:56<01:38, 1175783.93B/s][A[A

 43%|██████████▋              | 85848064/201626725 [01:57<01:35, 1214070.60B/s][A[A

 43%|██████████▋              | 86208512/201626725 [01:57<01:33, 1240942.34B/s][A[A

 43%|██████████▋              | 86568960/201626725 [01:57<01:37, 1181007.89B/s][A[A

 43%|██████████▊              | 86962176/20

 52%|█████████████            | 105820160/201626725 [02:15<01:41, 941532.18B/s][A[A

 53%|█████████████▏           | 106000384/201626725 [02:15<01:37, 980898.25B/s][A[A

 53%|████████████▋           | 106196992/201626725 [02:15<01:31, 1042970.27B/s][A[A

 53%|████████████▋           | 106377216/201626725 [02:15<01:30, 1050524.11B/s][A[A

 53%|█████████████▏           | 106486784/201626725 [02:15<01:50, 858005.74B/s][A[A

 53%|█████████████▏           | 106672128/201626725 [02:16<01:50, 858049.72B/s][A[A

 53%|█████████████▎           | 106885120/201626725 [02:16<01:45, 899595.40B/s][A[A

 53%|█████████████▎           | 107081728/201626725 [02:16<01:37, 968089.65B/s][A[A

 53%|████████████▊           | 107294720/201626725 [02:16<01:26, 1088121.18B/s][A[A

 53%|████████████▊           | 107524096/201626725 [02:16<01:17, 1218187.53B/s][A[A

 53%|████████████▊           | 107737088/201626725 [02:16<01:12, 1292272.87B/s][A[A

 54%|████████████▊           | 107966464/20

 62%|███████████████▍         | 124677120/201626725 [02:39<09:08, 140197.28B/s][A[A

 62%|███████████████▍         | 124792832/201626725 [02:40<07:06, 180229.69B/s][A[A

 62%|███████████████▍         | 124841984/201626725 [02:40<07:01, 182285.04B/s][A[A

 62%|███████████████▍         | 124989440/201626725 [02:40<05:38, 226165.29B/s][A[A

 62%|███████████████▌         | 125087744/201626725 [02:42<11:16, 113204.03B/s][A[A

 62%|███████████████▌         | 125202432/201626725 [02:43<11:11, 113819.06B/s][A[A

 62%|███████████████▌         | 125317120/201626725 [02:43<08:56, 142211.92B/s][A[A

 62%|███████████████▌         | 125448192/201626725 [02:44<07:08, 177732.68B/s][A[A

 62%|███████████████▌         | 125562880/201626725 [02:44<05:33, 228394.49B/s][A[A

 62%|███████████████▌         | 125710336/201626725 [02:44<04:15, 297262.10B/s][A[A

 62%|███████████████▌         | 125841408/201626725 [02:44<03:26, 367104.09B/s][A[A

 62%|███████████████▌         | 125907968/2

 71%|█████████████████▋       | 143028224/201626725 [03:07<02:11, 444040.44B/s][A[A

 71%|█████████████████▊       | 143208448/201626725 [03:08<01:48, 537846.12B/s][A[A

 71%|█████████████████▊       | 143388672/201626725 [03:08<01:31, 639653.60B/s][A[A

 71%|█████████████████▊       | 143585280/201626725 [03:08<01:15, 765117.75B/s][A[A

 71%|█████████████████▊       | 143716352/201626725 [03:08<01:12, 798010.13B/s][A[A

 71%|█████████████████▊       | 143863808/201626725 [03:08<01:10, 821686.39B/s][A[A

 71%|█████████████████▊       | 144027648/201626725 [03:08<01:04, 889090.15B/s][A[A

 71%|█████████████████▊       | 144158720/201626725 [03:09<01:07, 850087.86B/s][A[A

 72%|█████████████████▉       | 144252928/201626725 [03:09<01:26, 665632.44B/s][A[A

 72%|█████████████████▉       | 144388096/201626725 [03:09<01:28, 644662.52B/s][A[A

 72%|█████████████████▉       | 144568320/201626725 [03:09<01:18, 725548.37B/s][A[A

 72%|█████████████████▉       | 144748544/2

 78%|███████████████████▌     | 157741056/201626725 [03:27<01:51, 394645.85B/s][A[A

 78%|███████████████████▌     | 157839360/201626725 [03:27<01:41, 432564.46B/s][A[A

 78%|███████████████████▌     | 157954048/201626725 [03:28<01:29, 488439.96B/s][A[A

 78%|███████████████████▌     | 158009344/201626725 [03:28<02:12, 328850.05B/s][A[A

 78%|████████████████████▍     | 158085120/201626725 [03:30<08:54, 81482.17B/s][A[A

 78%|███████████████████▌     | 158199808/201626725 [03:31<06:47, 106539.91B/s][A[A

 79%|███████████████████▋     | 158314496/201626725 [03:31<05:07, 141008.43B/s][A[A

 79%|███████████████████▋     | 158445568/201626725 [03:31<03:55, 183613.74B/s][A[A

 79%|███████████████████▋     | 158576640/201626725 [03:31<03:05, 231472.30B/s][A[A

 79%|███████████████████▋     | 158724096/201626725 [03:32<02:26, 293692.76B/s][A[A

 79%|███████████████████▋     | 158871552/201626725 [03:32<01:56, 368527.93B/s][A[A

 79%|███████████████████▋     | 159019008/2

 88%|█████████████████████▏  | 178139136/201626725 [03:51<00:14, 1612954.35B/s][A[A

 89%|█████████████████████▏  | 178499584/201626725 [03:51<00:12, 1838284.53B/s][A[A

 89%|█████████████████████▎  | 178860032/201626725 [03:51<00:11, 2039932.36B/s][A[A

 89%|█████████████████████▎  | 179220480/201626725 [03:51<00:10, 2041949.99B/s][A[A

 89%|█████████████████████▍  | 179597312/201626725 [03:51<00:11, 1849056.77B/s][A[A

 89%|█████████████████████▍  | 179957760/201626725 [03:52<00:11, 1812224.18B/s][A[A

 89%|█████████████████████▍  | 180318208/201626725 [03:52<00:14, 1495954.98B/s][A[A

 90%|█████████████████████▌  | 180678656/201626725 [03:52<00:14, 1400456.92B/s][A[A

 90%|█████████████████████▌  | 181055488/201626725 [03:52<00:14, 1452513.48B/s][A[A

 90%|█████████████████████▌  | 181268480/201626725 [03:53<00:14, 1369014.02B/s][A[A

 90%|█████████████████████▌  | 181530624/201626725 [03:53<00:15, 1268021.44B/s][A[A

 90%|█████████████████████▋  | 181776384/20

AttributeError: 'FineTuningConfig' object has no attribute 'initializer_range'

In [21]:
# init model: Transformer base + classifier head
model = TransformerWithClfHead(config=config, fine_tuning_config=finetuning_config).to(finetuning_config.device)
model.load_state_dict(state_dict, strict=False)


AttributeError: 'FineTuningConfig' object has no attribute 'initializer_range'

In [18]:
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage, Accuracy 
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import CosineAnnealingScheduler, PiecewiseLinear, create_lr_scheduler_with_warmup, ProgressBar
import torch.nn.functional as F
from pytorch_transformers.optimization import AdamW

# Bert optimizer
optimizer = AdamW(model.parameters(), lr=finetuning_config.lr, correct_bias=False) 

def update(engine, batch):
    "update function for training"
    model.train()
    inputs, labels = (t.to(finetuning_config.device) for t in batch)
    inputs = inputs.transpose(0, 1).contiguous() # [S, B]
    _, loss = model(inputs, 
                    clf_tokens_mask = (inputs == tokenizer.vocab[processor.CLS]), 
                    clf_labels=labels)
    loss = loss / finetuning_config.gradient_acc_steps
    loss.backward()
    
    torch.nn.utils.clip_grad_norm_(model.parameters(), finetuning_config.max_norm)
    if engine.state.iteration % finetuning_config.gradient_acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()

def inference(engine, batch):
    "update function for evaluation"
    model.eval()
    with torch.no_grad():
        batch, labels = (t.to(finetuning_config.device) for t in batch)
        inputs = batch.transpose(0, 1).contiguous()
        logits = model(inputs,
                       clf_tokens_mask = (inputs == tokenizer.vocab[processor.CLS]),
                       padding_mask = (batch == tokenizer.vocab[processor.PAD]))
    return logits, labels

trainer = Engine(update)
evaluator = Engine(inference)

# add metric to evaluator 
Accuracy().attach(evaluator, "accuracy")

# add evaluator to trainer: eval on valid set after each epoch
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(valid_dl)
    print(f"validation epoch: {engine.state.epoch} acc: {100*evaluator.state.metrics['accuracy']}")
          
# lr schedule: linearly warm-up to lr and then to zero
scheduler = PiecewiseLinear(optimizer, 'lr', [(0, 0.0), (finetuning_config.n_warmup, finetuning_config.lr),
                                              (len(train_dl)*finetuning_config.n_epochs, 0.0)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)


# add progressbar with loss
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
ProgressBar(persist=True).attach(trainer, metric_names=['loss'])

# save checkpoints and finetuning config
checkpoint_handler = ModelCheckpoint(finetuning_config.log_dir, 'finetuning_checkpoint', 
                                     save_interval=1, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'imdb_model': model})

# save config to logdir
torch.save(finetuning_config, os.path.join(finetuning_config.log_dir, 'fine_tuning_args.bin'))          

NameError: name 'model' is not defined