In [None]:
# Clone the `performers` repository and install it in editable mode with its
# dev extras (the `transformers` imports below are expected to come from it).
!git clone https://github.com/Muennighoff/performers.git
!cd performers; pip install -q -e '.[dev]'

# Additional packages installed for this notebook.
!pip install -q datasets
!pip install -q ml_collections
!pip install -q --upgrade jax jaxlib

Cloning into 'performers'...
remote: Enumerating objects: 7, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 56020 (delta 0), reused 2 (delta 0), pack-reused 56013[K
Receiving objects: 100% (56020/56020), 42.40 MiB | 28.39 MiB/s, done.
Resolving deltas: 100% (39274/39274), done.
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 2.9MB 8.0MB/s 
[K     |████████████████████████████████| 890kB 53.4MB/s 
[K     |████████████████████████████████| 163kB 55.2MB/s 
[K     |████████████████████████████████| 163kB 48.8MB/s 
[K     |████████████████████████████████| 6.4MB 60.6MB/s 
[K     |████████████████████████████████| 8.2MB 56.8MB/s 
[K     |████████████████████████████████| 2.9MB 59.0MB/s 
[K     |████████████████████████████████| 102kB 13.8MB/s 
[K     |█████

In [None]:
import os
import sys
import tensorflow as tf
import math
from datasets import load_dataset

# Make the cloned repository's `src` directory importable. The membership check
# keeps re-runs of this cell from appending duplicate entries. The path is
# appended (not prepended), so entries already on sys.path take precedence.
module_path = os.path.abspath(os.path.join('./performers/src/'))
if module_path not in sys.path:
    sys.path.append(module_path)

from transformers.models.distilbert import DistilBertConfig, DistilBertForSequenceClassification
from transformers.models.distilbert import TFDistilBertForSequenceClassification
from transformers.models.bert import BertConfig, TFBertForSequenceClassification

##### IMDB Data Preparation

In [None]:
"""Input pipeline for the imdb dataset."""

from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

AUTOTUNE = tf.data.experimental.AUTOTUNE


def preprocess_dataset(file_path, batch_size):
  """Read a comma-delimited file into a dataset of single examples.

  Only the 'Source' (string) and 'Target' (int32) columns are kept. The file
  is read once, in order, with its first row treated as a header.

  Args:
    file_path: path of the file to read.
    batch_size: batch size handed to the CSV reader; the batching is undone
      before returning.

  Returns:
    A tf.data.Dataset yielding one {'Source', 'Target'} row at a time.
  """
  tf.logging.info(file_path)
  csv_batches = tf.data.experimental.make_csv_dataset(
      file_pattern=[file_path],
      batch_size=batch_size,
      column_defaults=[tf.string, tf.int32],
      select_columns=['Source', 'Target'],
      field_delim=',',
      header=True,
      shuffle=False,
      num_epochs=1)
  # The CSV reader always emits batches; flatten them back to single rows.
  return csv_batches.unbatch()


def get_imdb_dataset():
  """Load the `imdb_reviews` tfds dataset as Source/Target examples.

  Returns:
    A (train, valid, test) tuple of datasets whose elements are dicts with
    'Source' (review text) and 'Target' (label). `valid` and `test` are both
    built from the tfds 'test' split.
  """
  splits = tfds.load('imdb_reviews')
  train_split = splits['train']
  # IMDb has no validation split, so the test split is used for both.
  held_out_split = splits['test']

  # Log one raw training example (after skipping four) as a sanity check.
  sample = next(iter(tfds.as_numpy(train_split.skip(4))))
  logging.info('Data sample: %s', sample)

  def to_src_tgt(record):
    return {'Source': record['text'], 'Target': record['label']}

  train = train_split.map(to_src_tgt)
  valid = held_out_split.map(to_src_tgt)
  test = held_out_split.map(to_src_tgt)
  return train, valid, test


def get_yelp_dataset():
  """Load the `yelp_polarity_reviews` tfds dataset as Source/Target examples.

  Returns:
    A (train, valid, test) tuple of datasets whose elements are dicts with
    'Source' (review text) and 'Target' (label). `valid` and `test` are both
    built from the tfds 'test' split.
  """
  splits = tfds.load('yelp_polarity_reviews')
  train_split = splits['train']
  # Yelp has no validation split, so the test split is used for both.
  held_out_split = splits['test']

  # Log one raw training example (after skipping four) as a sanity check.
  sample = next(iter(tfds.as_numpy(train_split.skip(4))))
  logging.info('Data sample: %s', sample)

  def to_src_tgt(record):
    return {'Source': record['text'], 'Target': record['label']}

  train = train_split.map(to_src_tgt)
  valid = held_out_split.map(to_src_tgt)
  test = held_out_split.map(to_src_tgt)
  return train, valid, test


def get_agnews_dataset():
  """Load the `ag_news_subset` tfds dataset as Source/Target examples.

  Returns:
    A (train, valid, test) tuple of datasets whose elements are dicts with
    'Source' (the article description) and 'Target' (label). `valid` and
    `test` are both built from the tfds 'test' split.
  """
  splits = tfds.load('ag_news_subset')
  train_split = splits['train']
  # AG News has no validation split, so the test split is used for both.
  held_out_split = splits['test']

  # Log one raw training example (after skipping four) as a sanity check.
  sample = next(iter(tfds.as_numpy(train_split.skip(4))))
  logging.info('Data sample: %s', sample)

  def to_src_tgt(record):
    return {'Source': record['description'], 'Target': record['label']}

  train = train_split.map(to_src_tgt)
  valid = held_out_split.map(to_src_tgt)
  test = held_out_split.map(to_src_tgt)
  return train, valid, test


def get_tc_datasets(n_devices,
                    task_name,
                    data_dir=None,
                    batch_size=256,
                    fixed_vocab=None,
                    max_length=512,
                    tokenizer='char'):
  """Get tokenized, padded and batched text classification datasets.

  Args:
    n_devices: number of devices; `batch_size` must be a multiple of it.
    task_name: 'imdb_reviews', 'yelp_reviews' or 'agnews' to load from tfds.
      Any other value is read from
      `<data_dir><task_name>_{train,val,test}.tsv`.
    data_dir: path prefix for the file-based tasks (concatenated as-is, so it
      should end with a path separator). Required only for those tasks.
    batch_size: number of examples per batch; incomplete batches are dropped.
    fixed_vocab: optional iterable of tokens for the word-level encoder.
      Duplicates are removed and the given order is preserved.
    max_length: sequences are truncated to, and padded up to, this length.
    tokenizer: 'char' selects the byte-level encoder; any other value selects
      a word-level encoder built from `fixed_vocab` or from the training data.

  Returns:
    A (train_dataset, val_dataset, test_dataset, encoder) tuple. Dataset
    elements are dicts with 'input_ids' of shape [batch_size, max_length] and
    'targets' of shape [batch_size].

  Raises:
    ValueError: if `batch_size` is not divisible by `n_devices`, or if a
      file-based task is requested without `data_dir`.
  """
  if batch_size % n_devices:
    raise ValueError("Batch size %d isn't divided evenly by n_devices %d" %
                     (batch_size, n_devices))

  if task_name == 'imdb_reviews':
    train_dataset, val_dataset, test_dataset = get_imdb_dataset()
  elif task_name == 'yelp_reviews':
    train_dataset, val_dataset, test_dataset = get_yelp_dataset()
  elif task_name == 'agnews':
    train_dataset, val_dataset, test_dataset = get_agnews_dataset()
  else:
    if data_dir is None:
      # Fail with a clear message instead of a TypeError from `None + str`.
      raise ValueError(
          'data_dir must be provided for task %r, which is read from '
          '<data_dir><task_name>_{train,val,test}.tsv' % (task_name,))
    train_path = data_dir + task_name + '_train.tsv'
    val_path = data_dir + task_name + '_val.tsv'
    test_path = data_dir + task_name + '_test.tsv'

    train_dataset = preprocess_dataset(train_path, batch_size)
    val_dataset = preprocess_dataset(val_path, batch_size)
    test_dataset = preprocess_dataset(test_path, batch_size)

  tf.logging.info('Finished preprocessing')

  tf.logging.info(val_dataset)

  if tokenizer == 'char':
    logging.info('Using char/byte level vocab')
    encoder = tfds.deprecated.text.ByteTextEncoder()
  else:
    if fixed_vocab is None:
      tf.logging.info('Building vocab')
      # Collect every distinct token that occurs in the training sources.
      vocab_set = set()
      # Named differently from the `tokenizer` argument to avoid shadowing it.
      word_tokenizer = tfds.deprecated.text.Tokenizer()
      for i, data in enumerate(train_dataset):
        examples = data['Source']
        examples = word_tokenizer.tokenize(examples.numpy())
        examples = np.reshape(examples, (-1)).tolist()
        vocab_set.update(examples)
        if i % 1000 == 0:
          tf.logging.info('Processed {}'.format(i))
      tf.logging.info(len(vocab_set))
      # Sort so token ids are reproducible: plain set iteration order depends
      # on string hashing, which changes from one Python process to the next.
      vocab_list = sorted(vocab_set)
      tf.logging.info('Finished processing vocab size={}'.format(
          len(vocab_list)))
    else:
      # De-duplicate while keeping the caller's token order, so a fixed vocab
      # always maps to the same ids.
      vocab_list = list(dict.fromkeys(fixed_vocab))
    encoder = tfds.deprecated.text.TokenTextEncoder(vocab_list)

  def tf_encode(x):
    # The encoders are plain Python, so they are wrapped in py_function to be
    # usable inside Dataset.map.
    result = tf.py_function(lambda s: tf.constant(encoder.encode(s.numpy())), [
        x,
    ], tf.int32)
    # py_function loses static shape information; the result is a 1-D
    # sequence of unknown length.
    result.set_shape([None])
    return result

  def tokenize(d):
    return {
        # Truncate to max_length; shorter sequences are padded when batching.
        'input_ids': tf_encode(d['Source'])[:max_length],
        'targets': d['Target']
    }

  train_dataset = train_dataset.map(tokenize, num_parallel_calls=AUTOTUNE)
  val_dataset = val_dataset.map(tokenize, num_parallel_calls=AUTOTUNE)
  test_dataset = test_dataset.map(tokenize, num_parallel_calls=AUTOTUNE)

  max_shape = {'input_ids': [max_length], 'targets': []}
  # drop_remainder=True keeps every batch at exactly batch_size examples.
  train_dataset = train_dataset.shuffle(
      buffer_size=256, reshuffle_each_iteration=True).padded_batch(
          batch_size, padded_shapes=max_shape, drop_remainder=True)
  val_dataset = val_dataset.padded_batch(
      batch_size, padded_shapes=max_shape, drop_remainder=True)
  test_dataset = test_dataset.padded_batch(
      batch_size, padded_shapes=max_shape, drop_remainder=True)

  return train_dataset, val_dataset, test_dataset, encoder

In [None]:
### Helper functions ###

import ml_collections

def get_config():
  """Build the default hyperparameter configuration.

  Returns:
    An ml_collections.ConfigDict holding the training, schedule, sampling and
    model-size settings used by this notebook.
  """
  defaults = {
      # Training / evaluation. Batch size 8 (instead of 32) gives 3125 batches
      # in each of the train and val datasets.
      'batch_size': 8,
      'eval_frequency': 100,
      'num_train_steps': 20000,
      'num_eval_steps': -1,
      'learning_rate': 0.05,
      'weight_decay': 1e-1,
      'max_target_length': 200,
      'max_eval_target_length': 200,
      'sampling_temperature': 0.6,
      'sampling_top_k': 20,
      'max_predict_token_length': 50,
      'save_checkpoints': True,
      'restore_checkpoints': True,
      'checkpoint_freq': 10000,
      'random_seed': 0,
      'prompt': "",
      # Learning-rate schedule.
      'factors': "constant * linear_warmup * rsqrt_decay",
      'warmup': 8000,
      'classifier_pool': "CLS",
      # Sequence length (4K in paper?).
      'max_length': 1000,
      # Model size.
      'emb_dim': 256,
      'num_heads': 4,
      'num_layers': 4,
      'qkv_dim': 256,
      'mlp_dim': 1024,
      # Dummy field for repeated runs.
      'trial': 0,
  }
  return ml_collections.ConfigDict(defaults)

In [None]:
# Build the hyperparameter config and the batched, byte-level IMDb datasets.
lra_config = get_config()

train_ds, eval_ds, test_ds, encoder = get_tc_datasets(
    n_devices=1, # Single device; was jax.local_device_count()
    task_name="imdb_reviews", # Used in paper
    data_dir="/content/", # Only read for file-based tasks, not for the tfds ones
    batch_size=lra_config.batch_size,
    fixed_vocab=None, # If we already have a vocab
    max_length=lra_config.max_length-1) # Leave one position free for the <s> token prepended in a later cell

INFO:absl:No config specified, defaulting to first: imdb_reviews/plain_text
INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: imdb_reviews/plain_text/1.0.0
INFO:absl:Load dataset info from /tmp/tmppdtw_p3rtfds
INFO:absl:Field info.config_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.config_description from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Generating dataset imdb_reviews (/root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0)


[1mDownloading and preparing dataset imdb_reviews/plain_text/1.0.0 (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

INFO:absl:Downloading http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz into /root/tensorflow_datasets/downloads/ai.stanfor.edu_amaas_sentime_aclImdb_v1PaujRp-TxjBWz59jHXsMDm5WiexbxzaFQkEnXc3Tvo8.tar.gz.tmp.ce1226337232419a8f1de8491373c6bc...
INFO:absl:Generating split train








HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete49EGZB/imdb_reviews-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=25000.0), HTML(value='')))

INFO:absl:Done writing /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete49EGZB/imdb_reviews-train.tfrecord. Shard lengths: [25000]
INFO:absl:Generating split test




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete49EGZB/imdb_reviews-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=25000.0), HTML(value='')))

INFO:absl:Done writing /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete49EGZB/imdb_reviews-test.tfrecord. Shard lengths: [25000]
INFO:absl:Generating split unsupervised




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete49EGZB/imdb_reviews-unsupervised.tfrecord


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))

INFO:absl:Done writing /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete49EGZB/imdb_reviews-unsupervised.tfrecord. Shard lengths: [50000]
INFO:absl:Skipping computing stats for mode ComputeStatsMode.SKIP.
INFO:absl:Constructing tf.data.Dataset for split None, from /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0


[1mDataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.[0m


INFO:absl:Data sample: {'label': 1, 'text': b'As others have mentioned, all the women that go nude in this film are mostly absolutely gorgeous. The plot very ably shows the hypocrisy of the female libido. When men are around they want to be pursued, but when no "men" are around, they become the pursuers of a 14 year old boy. And the boy becomes a man really fast (we should all be so lucky at this age!). He then gets up the courage to pursue his true love.'}


INFO:tensorflow:Finished preprocessing


INFO:tensorflow:Finished preprocessing


INFO:tensorflow:<MapDataset shapes: {Source: (), Target: ()}, types: {Source: tf.string, Target: tf.int64}>


INFO:tensorflow:<MapDataset shapes: {Source: (), Target: ()}, types: {Source: tf.string, Target: tf.int64}>
INFO:absl:Using char/byte level vocab


In [None]:
# Create attn masks

def add_attn_mask(example):
    """
    Attach an int32 `attention_mask` to the example.

    The mask is 1 wherever `input_ids` is greater than 0 and 0 elsewhere, so
    zero (padding) positions are not attended to. The example dict is updated
    in place and returned.
    """
    is_real_token = tf.greater(example["input_ids"], 0)
    example["attention_mask"] = tf.cast(is_real_token, tf.int32)
    return example

# Attach the attention mask to every split.
train_ds, eval_ds, test_ds = [
    split.map(add_attn_mask) for split in (train_ds, eval_ds, test_ds)
]

In [None]:
# Add BOS (& EOS)
# The special tokens get the first ids past the encoder's own vocabulary.

vocab_size = encoder.vocab_size
print('Vocab Size: ', vocab_size)

# One column of ids per batch. The shape is tied to the configured batch size,
# so concatenation only works for batches of exactly that size.
bos = tf.constant(vocab_size, shape=(lra_config.batch_size,1))
# NOTE: `eos` is defined here but not referenced by `add_tokens` below.
eos = tf.constant(vocab_size+1, shape=(lra_config.batch_size,1))

# Column of ones, to be concatenated onto the attention mask for the BOS slot.
attn_addon = tf.constant(1, shape=(lra_config.batch_size,1))

def add_tokens(example):
    """
    Prepend the <s> (bos) token to every sequence in the batch.

    Only bos is added: the original does not add an eos token, and the
    sequences are already padded. The attention mask gets a matching leading 1
    so the new token is attended to. The example dict is updated in place and
    returned.
    """
    ids_with_bos = tf.concat([bos, example["input_ids"]], axis=1)
    mask_with_bos = tf.concat([attn_addon, example["attention_mask"]], axis=1)

    example["input_ids"] = ids_with_bos
    example["attention_mask"] = mask_with_bos
    return example


# Prepend the BOS token (and extend the mask) on every split.
train_ds, eval_ds, test_ds = [
    split.map(add_tokens) for split in (train_ds, eval_ds, test_ds)
]

Vocab Size:  257


In [None]:
### Turn TF Dataset to torch dataset ###

import torch

# Wrap the TF datasets in torch DataLoaders. Each dataset is first materialised
# as a list of numpy batches, because the DataLoader needs `__len__`; this
# holds the whole split in memory.
# No batch_size is passed, so each DataLoader item is one already-batched TF
# batch with an extra leading dimension of size 1 (squeezed in the training loop).
train_torch = torch.utils.data.DataLoader(list(tfds.as_numpy(train_ds)))
val_torch = torch.utils.data.DataLoader(list(tfds.as_numpy(eval_ds)))

##### Modeling & Training

In [None]:
# Build the Hugging Face DistilBERT config for the Performer model.
config = DistilBertConfig()

# Defaults, printed for comparison with the adjusted config below.
print(config)

config.attention_type = 'performer'

# Encoder ids occupy [0, vocab_size); <s> is vocab_size and </s> is
# vocab_size + 1. The embedding table therefore needs vocab_size + 2 rows so
# that both special-token ids declared below are valid indices.
config.vocab_size = vocab_size + 2
# One more position than the longest sequence fed to the model.
config.max_position_embeddings = lra_config.max_length + 1

config.bos_token_id = vocab_size
config.eos_token_id = vocab_size + 1  # Not used for now
config.pad_token_id = 0

# Model size. In DistilBertConfig, `dim` is the model/embedding width and
# `hidden_dim` is the feed-forward (intermediate) width. `intermediate_size`
# and `hidden_size` are BERT-style names that DistilBERT does not read, so
# setting them would leave the model at its default width of 768.
config.n_heads = lra_config.num_heads
config.n_layers = lra_config.num_layers
config.dim = lra_config.emb_dim         # 256
config.hidden_dim = lra_config.mlp_dim  # 1024

print(config)

DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "attention_type": "softmax",
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "performer_attention_config": null,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "vocab_size": 30522
}

DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "attention_type": "performer",
  "bos_token_id": 257,
  "dim": 768,
  "dropout": 0.1,
  "eos_token_id": 258,
  "hidden_dim": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "max_position_embeddings": 1001,
  "model_type": "distilbert",
  "n_heads": 4,
  "n_layers": 4,
  "pad_token_id": 0,
  "performer_attention_config": null,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "vocab_size": 258
}



In [None]:
# Randomly initialised classifier built from the adjusted config (no pretrained weights are loaded).
model = DistilBertForSequenceClassification(config)

# AdamW with decoupled weight decay. The `lr` passed here is only an initial
# value: `adjust_optim` overwrites it from the schedule on every training step.
optimizer = torch.optim.AdamW(model.parameters(), lr=lra_config.learning_rate, weight_decay=lra_config.weight_decay, betas=(0.9, 0.98), eps=1e-9)

# Cross-entropy over the classifier logits.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
def create_learning_rate_scheduler(
    factors='constant * linear_warmup * rsqrt_decay',
    base_learning_rate=0.5,
    warmup_steps=1000,
    decay_factor=0.5,
    steps_per_decay=20000,
    steps_per_cycle=100000):
  """Build a step -> learning-rate function from a '*'-separated factor string.

  The learning rate starts at 1.0 and each named factor is applied in the
  order given. Supported factors:
  * constant: multiply by base_learning_rate.
  * linear_warmup: multiply by min(1, step / warmup_steps).
  * rsqrt_decay: divide by sqrt(max(step, warmup_steps)).
  * rsqrt_normalized_decay: multiply by sqrt(warmup_steps), then divide by
    sqrt(max(step, warmup_steps)).
  * decay_every: multiply by decay_factor once per steps_per_decay steps.
  * cosine_decay: cyclic cosine decay with period steps_per_cycle, starting
    after warmup_steps.

  Args:
    factors: string, factor names separated by '*'.
    base_learning_rate: float, value of the 'constant' factor.
    warmup_steps: int, warmup length used by the warmup and rsqrt factors.
    decay_factor: float, multiplier used by 'decay_every'.
    steps_per_decay: int, interval between 'decay_every' decays.
    steps_per_cycle: int, cycle length for 'cosine_decay'.

  Returns:
    A function step_fn(step) returning the learning rate as a float. It
    raises ValueError when it reaches a factor name it does not know; the
    names are not checked when the scheduler is created.
  """

  # Each applier takes the running value and the step and returns the new
  # running value.
  def _constant(value, step):
    return value * base_learning_rate

  def _linear_warmup(value, step):
    return value * min(1.0, step / warmup_steps)

  def _rsqrt_decay(value, step):
    return value / math.sqrt(max(step, warmup_steps))

  def _rsqrt_normalized_decay(value, step):
    value = value * math.sqrt(warmup_steps)
    return value / math.sqrt(max(step, warmup_steps))

  def _decay_every(value, step):
    return value * (decay_factor**(step // steps_per_decay))

  def _cosine_decay(value, step):
    progress = max(0.0, (step - warmup_steps) / float(steps_per_cycle))
    return value * max(
        0.0, 0.5 * (1.0 + math.cos((math.pi) * (progress % 1.0))))

  appliers = {
      'constant': _constant,
      'linear_warmup': _linear_warmup,
      'rsqrt_decay': _rsqrt_decay,
      'rsqrt_normalized_decay': _rsqrt_normalized_decay,
      'decay_every': _decay_every,
      'cosine_decay': _cosine_decay,
  }
  factor_names = [part.strip() for part in factors.split('*')]

  def step_fn(step):
    """Step to learning rate function."""
    rate = 1.0
    for factor_name in factor_names:
      apply_factor = appliers.get(factor_name)
      if apply_factor is None:
        raise ValueError('Unknown factor %s.' % factor_name)
      rate = apply_factor(rate, step)
    return rate

  return step_fn

In [None]:
# Learning-rate schedule built from the config: the factor string, the base
# learning rate and the warmup length all come from `lra_config`.
step_fn = create_learning_rate_scheduler(
                                        factors=lra_config.factors,
                                        base_learning_rate=lra_config.learning_rate,
                                        warmup_steps=lra_config.warmup
                                        )

def adjust_optim(optimizer, step, total_steps):
    """
    Set the optimizer's learning rate from the schedule.

    The rate is `step_fn(total_steps)` divided by 4 and is written to the
    optimizer's first parameter group. `step` is only used to decide when to
    print the current rate (every 200 steps).
    """
    scheduled_lr = step_fn(total_steps) / 4 # By 4 as 1/4 their BS
    optimizer.param_groups[0]['lr'] = scheduled_lr

    if not step % 200:
        print("Current LR: ", scheduled_lr)

In [None]:
import time


def _batch_to_cuda(batch):
    """Move one DataLoader item to the GPU as (input_ids, attn_mask, targets).

    The DataLoader wraps each pre-batched item in an extra leading dimension of
    size 1. Only that dimension is removed (`squeeze(0)`): a bare
    `torch.squeeze` would also collapse the batch dimension when the batch
    size is 1.
    """
    input_ids = batch["input_ids"].squeeze(0).to(torch.int64).cuda()
    attn_mask = batch["attention_mask"].squeeze(0).to(torch.int64).cuda()
    targets = batch["targets"].squeeze(0).cuda()
    return input_ids, attn_mask, targets


LOG_EVERY = 200    # Print the training loss every this many steps of an epoch.
EVAL_EVERY = 2000  # Run validation every this many global steps.

# Running counters for the accuracy of the current reporting window.
correct = 0
total = 0

epochs = 10 # Upper bound; enough epochs to exceed the step limit below
max_train_steps = lra_config.num_train_steps # Stop at 20K as in original
total_steps = 0
finished = False

# Move to cuda
model = model.cuda()

for epoch in range(epochs):

    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, batch in enumerate(train_torch, 0):

        total_steps += 1
        adjust_optim(optimizer, step, total_steps)

        # zero the parameter gradients
        optimizer.zero_grad()

        input_ids, attn_mask, targets = _batch_to_cuda(batch)

        # forward + backward + optimize
        logits = model(input_ids, attn_mask).logits

        loss = criterion(logits, targets)
        loss.backward()

        optimizer.step()

        # Logits have shape (batch_size, num_labels); take the argmax per row.
        _, predicted = torch.max(logits, 1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()

        # Log the loss periodically.
        if step % LOG_EVERY == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss))
            )

        # Validate on the first global step and then every EVAL_EVERY steps.
        if (total_steps - 1) % EVAL_EVERY == 0:
            # Training accuracy accumulated since the previous validation run.
            print("Training acc over 2K Steps: %.4f" % (100 * correct / total))

            model.eval()

            # Reuse the counters for the validation pass.
            total, correct = 0, 0

            with torch.no_grad():
                for val_batch in val_torch:

                    input_ids, attn_mask, targets = _batch_to_cuda(val_batch)

                    # forward
                    logits = model(input_ids, attn_mask).logits

                    _, predicted = torch.max(logits, 1)

                    total += targets.size(0)
                    correct += (predicted == targets).sum().item()

            model.train()

            print("Validation acc: %.4f" % (100 * correct / total))
            # Time elapsed since the start of the current epoch.
            print("Time taken: %.2fs" % (time.time() - start_time))
            # Start a fresh window for the training accuracy.
            total, correct = 0, 0

        if total_steps >= max_train_steps:
            print("{:,} steps finished".format(max_train_steps))
            print("=====================")
            finished = True
            break
    if finished:
        break


Start of epoch 0
Current LR:  1.7469281074217106e-08
Training loss (for one batch) at step 0: 0.6707
Training acc over 2K Steps: 62.5000
Validation acc: 49.8680
Time taken: 334.30s
Current LR:  3.511325495917639e-06
Training loss (for one batch) at step 200: 0.6989
Current LR:  7.00518171076106e-06
Training loss (for one batch) at step 400: 0.6784
Current LR:  1.049903792560448e-05
Training loss (for one batch) at step 600: 0.6360
Current LR:  1.3992894140447903e-05
Training loss (for one batch) at step 800: 0.6888
Current LR:  1.7486750355291324e-05
Training loss (for one batch) at step 1000: 0.7045
Current LR:  2.0980606570134746e-05
Training loss (for one batch) at step 1200: 0.7161
Current LR:  2.4474462784978165e-05
Training loss (for one batch) at step 1400: 0.7745
Current LR:  2.796831899982159e-05
Training loss (for one batch) at step 1600: 0.6166
Current LR:  3.1462175214665007e-05
Training loss (for one batch) at step 1800: 0.4999
Current LR:  3.495603142950843e-05
Training 

In [None]:
# Temporarily reaches 61% in between, which is on par with the 64% reported (they took the best of 4 different runs in the paper)