# Fine-tuning GPT-2 for text classification
Stefan/Yuzhao Heng
Since Wed. Feb. 9th, 2022


This notebook reproduces the results of the paper [Zero-shot Text Classification With Generative Language Models](https://arxiv.org/abs/1912.10165),
since the authors did not release their code.

It also serves as infrastructure and a baseline for a project on an efficient and accurate encoder for text classification with many labels.


## Notebook Setup



In [25]:
%load_ext autoreload
%autoreload 2



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Colab Setup



In [26]:
import os
import sys


if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive')

    ! pip3 install sty icecream transformers datasets

    # base_path = '/content/drive/My Drive//Research/'
    # os.chdir(os.path.join(base_path, 'Unified Encoder/Unified-Encoder'))

    # Absolute path, so the import below works regardless of the current working directory
    sys.path.append(os.path.join('/content', 'drive', 'My Drive', 'Research', 'Unified Encoder', 'Unified-Encoder'))


from unified_encoder.util import *
print(PATH_BASE)  # Sanity check, should be the path appended if Colab



/Users/stefanh/Documents/UMich/Research/Clarity Lab/Unified Encoder


## Setup



In [27]:
import random

import numpy as np
import torch
import transformers
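from icecream import ic
from matplotlib import rcParams  # Explicit import; otherwise relies on the wildcard import from `unified_encoder.util`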

from unified_encoder.baseline import gpt2


rcParams['figure.dpi'] = 200
rcParams['font.size'] = 6



## Seed setup



In [28]:
if torch.cuda.is_available():
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # Required for deterministic cuBLAS operations on CUDA >= 10.2
    ! echo $CUBLAS_WORKSPACE_CONFIG


seed = config('random-seed')
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True)

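def seed_worker(worker_id):
    # Seed NumPy and `random` inside each DataLoader worker from that worker's torch seed
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated generator for reproducible shuffling, e.g. passed as `generator=g` to a DataLoader
g = torch.Generator()
g.manual_seed(seed)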
transformers.set_seed(seed)
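
The cell above defines `seed_worker` and `g` but does not show where they are consumed. A minimal sketch of the usual PyTorch pattern follows, assuming they are meant to be passed to a `DataLoader`; whether `gpt2.get_all_setup` wires them up this way is not shown in this notebook. The toy dataset and the helper `shuffled_order` are hypothetical and exist only for illustration.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Same helper as in the cell above: derive the NumPy and `random` seeds
    # from the torch seed of the current worker process.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def shuffled_order(seed):
    # Hypothetical helper: returns the batch order produced by a seeded loader.
    dataset = TensorDataset(torch.arange(8))  # toy stand-in for real samples
    g = torch.Generator()
    g.manual_seed(seed)
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        num_workers=0,  # seed_worker is only called when num_workers > 0
        worker_init_fn=seed_worker,
        generator=g,
    )
    return [batch[0].tolist() for batch in loader]


order_a = shuffled_order(seed=42)
order_b = shuffled_order(seed=42)
print(order_a == order_b)  # the two runs shuffle identically
```

Passing a freshly seeded generator to each loader is what makes the shuffling order repeatable across runs, independent of how much global random state was consumed beforehand.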



## Prep Model & Dataset for training


In [29]:
nm, n = 'debug', 8
# nm, n = 'debug-large', 128
model, tokenizer, data_collator, train_args, dset_tr, dset_vl, trainer = gpt2.get_all_setup(
    nm, n_sample=n, random_seed=seed
)



Using custom data configuration default
Reusing dataset ag_news (/Users/stefanh/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
Using custom data configuration default
Reusing dataset ag_news (/Users/stefanh/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/stefanh/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-76a45c5155a2ee95.arrow
Loading cached processed dataset at /Users/stefanh/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-9b3097dc8e0c748c.arrow
Loading cached shuffled indices for dataset at /Users/stefanh/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-ce92bb656b71aa1d.arrow
Loading cached shuffled indices for dataset at /Users/stefanh/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-fb8626521cf20fea.arrow
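
The logs above show that the AG News dataset is loaded. To fine-tune a generative model such as GPT-2 for classification, each sample has to be cast as text that ends in the label. The sketch below illustrates one way to do that; the template wording, the function `to_prompt`, and the example sentence are assumptions made for illustration and may differ from both the paper and `unified_encoder.baseline.gpt2`.

```python
# AG News label names, in label-id order.
AG_NEWS_LABELS = ['World', 'Sports', 'Business', 'Sci/Tech']


def to_prompt(text, labels, answer=None):
    # Hypothetical template: a question listing the candidate labels,
    # then the text to classify, then the answer slot.
    question = 'To which category does this text belong? ' + ', '.join(labels)
    prompt = f'Question: {question}\nText: {text}\nAnswer:'
    if answer is not None:
        # Training samples include the gold label; at inference time the
        # prompt stops at "Answer:" and the model generates the label.
        prompt += f' {answer}'
    return prompt


# Made-up example sentence, not taken from the dataset.
sample = to_prompt('Stocks rallied after the earnings report.', AG_NEWS_LABELS, 'Business')
print(sample)
```

Listing the candidate labels inside the prompt is what allows the same model to be queried on label sets it was not trained on.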


## Train


In [31]:
trainer.train()
trainer.save_model(os.path.join(trainer.args.output_dir, now(sep='-')))



[Output: the integers 0 through 69 printed on five consecutive lines; the last line is cut off after `67 6`]

AssertionError: 

## Evaluate


In [None]:
trainer.evaluate()
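
`trainer.evaluate()` returns a dictionary of metrics such as `eval_loss`; the cell above was not executed, so no values are available here. For a generative classifier, a natural additional metric is whether the generated label string matches the gold label. A minimal sketch of such an exact-match accuracy follows; the function name and the example predictions are hypothetical and are not results from this notebook.

```python
def exact_match_accuracy(predictions, references):
    # Fraction of generated label strings that equal the gold label,
    # ignoring surrounding whitespace and letter case.
    if len(predictions) != len(references):
        raise ValueError('predictions and references must have the same length')
    if not references:
        return 0.0
    hits = sum(
        pred.strip().lower() == ref.strip().lower()
        for pred, ref in zip(predictions, references)
    )
    return hits / len(references)


# Made-up strings for illustration only.
preds = ['Sports', ' business', 'World', 'Sports']
golds = ['Sports', 'Business', 'Sci/Tech', 'Sports']
acc = exact_match_accuracy(preds, golds)
print(acc)  # 0.75
```

Normalizing whitespace and case before comparing avoids penalizing the model for tokenization artifacts such as a leading space.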

