In [1]:
import functools

import numpy as np
import rich
import torch
import transformers

import general_utils as gu

In [2]:
# Prompt corpus: a pasted paper abstract, one prompt per source line.
# NOTE(review): the literal opens with four quote characters, so the first prompt
# begins with a stray `"`; the trailing page residue (DOI link, "Focus to learn
# more") also ends up as two prompts. Both are kept as-is so recorded outputs
# stay valid -- decide deliberately whether to clean them up.
max_len = 50  # per-line character cap; the same name is later passed as the tokenizer's max_length

data = [
    line.strip()[:max_len]
    for line in """"The task of building general agents 
that perform well over a wide
 range of tasks has been an important 
 goal in reinforcement learning since its inception. 
The problem has been subject of research
 of a large body of work, with performance 
frequently measured by observing scores over the
 wide range of environments contained in the Atari 57 benchmark. 
Agent57 was the first agent to 
surpass the human benchmark on all 57 games, 
but this came at the cost of poor data-efficiency, 
requiring nearly 80 billion frames of experience to achieve. 
Taking Agent57 as a starting point,
 we employ a diverse set of strategies to 
achieve a 200-fold reduction of experience 
needed to out perform the human baseline. 
We investigate a range of instabilities
 and bottlenecks we encountered while 
reducing the data regime, and propose
effective solutions to build a more robust and efficient agent. 
We also demonstrate competitive performance
 with high-performing methods such
 as Muesli and MuZero. The four key components 
 to our approach are (1) an approximate trust 
region method which enables stable bootstrapping from the online network, 
(2) a normalisation scheme for the loss and 
priorities which improves robustness when learning 
a set of value functions with a wide range
 of scales, (3) an improved architecture
 employing techniques from NFNets in order 
 to leverage deeper networks without the need 
for normalization layers, and (4) a policy 
distillation method which serves
 to smooth out the instantaneous greedy policy overtime.
https://doi.org/10.48550/arXiv.2209.07550
Focus to learn more""".strip().split("\n")
]

# Sanity check: character count of every prompt, then the number of prompts.
print(list(map(len, data)))
print(len(data))

[36, 29, 36, 50, 40, 41, 48, 50, 30, 44, 50, 50, 35, 40, 42, 41, 39, 36, 37, 50, 43, 33, 45, 44, 50, 43, 50, 42, 39, 41, 44, 42, 32, 50, 41, 19]
36


In [3]:
# GPT-2 family checkpoints compared throughout the notebook.
model_names = ["distilgpt2", "gpt2", "gpt2-large"]

# Load every checkpoint once and move it to the GPU; the mapping is keyed by
# checkpoint name and keeps the order of `model_names`.
models = dict(
    (checkpoint, transformers.GPT2LMHeadModel.from_pretrained(checkpoint).to("cuda"))
    for checkpoint in model_names
)

In [4]:
import rich.table as table

# Build a table that lists only the config fields on which the loaded models
# disagree (different values, or the field is absent from at least one model).
table_ = table.Table("Key", *model_names)

# Attribute dict of each model's config, looked up once instead of on every comparison.
config_vars = {model_name: vars(model.config) for model_name, model in models.items()}
reference = config_vars[model_names[0]]  # values of the first model, used as the baseline

all_keys = set().union(*config_vars.values())
to_ignore = {"id2label", "label2id", "_num_labels"}

# Iterate in sorted order: plain set iteration over strings changes between
# interpreter sessions, which would shuffle the notices and table rows on re-run.
for k in sorted(all_keys - to_ignore):

    # Report every model whose config lacks this field.
    missing_from = [model_name for model_name, cfg in config_vars.items() if k not in cfg]
    for model_name in missing_from:
        print(k, "not in", model_name)

    # A field missing anywhere always counts as a difference; otherwise compare
    # each model's value against the baseline (skipped when a key is missing,
    # so the direct indexing below cannot raise KeyError).
    if missing_from or not all(cfg[k] == reference[k] for cfg in config_vars.values()):
        table_.add_row(k, *[str(cfg.get(k, "<Not present>")) for cfg in config_vars.values()])

rich.print(table_)

In [5]:
# A single GPT-2 tokenizer is used for all models. It pads on the left and
# reuses the end-of-sequence token as the padding token.
tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Tokenise all prompts as one padded batch and move every tensor to the GPU.
# Float32 tensors would be cast to half precision first; any other dtype is
# moved unchanged.
tokenized = {
    key: (tensor.half() if tensor.dtype == torch.float else tensor).to("cuda")
    for key, tensor in tokenizer(
        data,
        padding=True,
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
        add_special_tokens=False,
    ).items()
}

# Give every model the same padding token id as the tokenizer.
for model in models.values():
    model.config.pad_token_id = tokenizer.pad_token_id

In [6]:
num_beams = 5     # beam count; also used as the group count and the number of returned sequences
max_gen_len = 10  # generation budget, passed to generate() as max_new_tokens

# Keyword arguments for each decoding strategy to time, in execution order.
# NOTE(review): "warmup_sampled" has the same settings as "sampled" and runs
# first -- presumably to absorb one-time start-up cost before the runs that
# matter; confirm.
configs = {
    "warmup_sampled": {
        "do_sample": True,
    },
    "sampled": {
        "do_sample": True,
    },
    "beam_search": {
        "do_sample": False,
        "num_beams": num_beams,
        "num_return_sequences": num_beams,
    },
    "sampled_beam_search": {
        "do_sample": True,
        "num_beams": num_beams,
        "num_return_sequences": num_beams,
    },
    "group_beam_search": {
        "do_sample": False,
        "num_beams": num_beams,
        "num_beam_groups": num_beams,
        "num_return_sequences": num_beams,
        "diversity_penalty": 0.25,
    },
}

# Shapes of the tokenised batch that every run receives.
for k, v in tokenized.items():
    print(k, v.shape)

# Time model.generate for every (decoding configuration, model) pair.
for config_name, kwargs in configs.items():
    for name, model in models.items():
        model.eval()
        with torch.inference_mode():
            # Shallow copy so the per-config tweak below never touches `tokenized`.
            inputs = dict(tokenized)
            if config_name == "sampled":
                # The "sampled" config sets no num_return_sequences, so each prompt
                # row is repeated num_beams times; the batch then holds as many
                # sequences as the beam-based configs return.
                inputs["input_ids"] = inputs["input_ids"].repeat_interleave(num_beams, dim=0)
                inputs["attention_mask"] = inputs["attention_mask"].repeat_interleave(num_beams, dim=0)

            with gu.cuda_timeit(f"generation with {name} and {config_name}"):
                output = model.generate(
                    **inputs,
                    **kwargs,
                    # The key/value-cache flag of generate() is `use_cache`; the
                    # previous `cache=True` is not a generate() argument and was
                    # forwarded as an unknown model kwarg.
                    use_cache=True,
                    constraints=None,
                    max_new_tokens=max_gen_len,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )


input_ids torch.Size([36, 20])
attention_mask torch.Size([36, 20])




In [7]:
# Shell escape: print the current GPU status (memory usage, utilisation) using the NVIDIA command-line tool.
!nvidia-smi

Mon Sep 19 17:58:10 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  On   | 00000000:47:00.0 Off |                    0 |
| N/A   34C    P0    78W / 400W |  39302MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:BD:00.0 Off |                    0 |
| N/A   37C    P0    54W / 400W |      3MiB / 40536MiB |      0%      Default |
|       