## Imports

In [None]:
# Colab environment setup. transformers is pinned to 4.28.0; datasets and accelerate are
# installed unpinned (pip's --upgrade flag applies to the whole command).
# TODO(review): pin datasets/accelerate as well so a fresh Run All reproduces this environment.
%pip install -q datasets transformers==4.28.0 --upgrade accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m61.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m219.1/219.1 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m34.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Mount Google Drive (Colab only). None of the cells shown in this section read from it.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


## Constants

In [None]:
import torch

# Hugging Face Hub id used to load the tokenizer.
model_id = "gpt2"

# Prefer the GPU but fall back to CPU, so the notebook still runs (slowly) on a machine
# without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG before the model is initialised and random data is drawn, so
# weights, dropout masks and synthetic tokens repeat across runs
# (up to nondeterministic GPU kernels).
SEED = 0
torch.manual_seed(SEED)

## Load Dataset

In [None]:
from datasets import load_dataset

# WikiText-2, "wikitext-2-v1" config (not the "-raw-" variant). This is the same data as
# torchtext's WikiText2: https://pytorch.org/text/stable/datasets.html#id22
raw_dataset = load_dataset(path="wikitext", name="wikitext-2-v1")

Downloading builder script:   0%|          | 0.00/8.48k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.25k [00:00<?, ?B/s]

Downloading and preparing dataset wikitext/wikitext-2-v1 to /root/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...


Downloading data:   0%|          | 0.00/4.48M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Dataset wikitext downloaded and prepared to /root/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

## Load Tokenizer

In [None]:
from transformers import DataCollatorForLanguageModeling, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
# Reuse the EOS token as the pad token so that padding='max_length' in preprocessing
# has a pad id to fill with.
tokenizer.pad_token = tokenizer.eos_token

# Causal-LM collator (mlm=False). Not used by the cells shown in this section.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

## Load Untrained Model

In [None]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

# Build a randomly initialised GPT-2 from the architecture config of `model_id`
# (only the config is fetched; no pretrained weights are loaded).
# Preprocessing truncates/pads every row to 128 tokens, so size the context to match.
context_length = 128

config = AutoConfig.from_pretrained(
    model_id,
    vocab_size=len(tokenizer),
    # GPT2Config takes the context size from `n_positions`. `n_ctx` is not a GPT2Config
    # parameter: passed alone it is stored as an inert attribute, and the model keeps 1024
    # position embeddings (which is why this cell used to print 124.4M). Setting
    # n_positions makes the 128-token context take effect: 896 * 768 fewer
    # position-embedding parameters. n_ctx is kept for anything that still reads it.
    n_positions=context_length,
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

model = GPT2LMHeadModel(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

GPT-2 size: 124.4M parameters


## Data Preprocessing

In [None]:
def _tokenize_batch(batch):
    """Tokenize a batch of text lines into fixed-length (128) input_ids / attention_mask rows."""
    return tokenizer(
        batch['text'],
        max_length=128,
        truncation=True,
        padding='max_length',
        return_tensors="pt",
        return_attention_mask=True,
    )


# Drop empty lines, tokenize in batches, then keep only the model inputs, exposed as torch tensors.
non_empty_dataset = raw_dataset.filter(lambda row: len(row['text']) > 0)
processed_dataset = non_empty_dataset.map(_tokenize_batch, batched=True)
processed_dataset = processed_dataset.remove_columns('text')
processed_dataset.set_format('torch')

Filter:   0%|          | 0/4358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/36718 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3760 [00:00<?, ? examples/s]

Map:   0%|          | 0/2891 [00:00<?, ? examples/s]

Map:   0%|          | 0/23767 [00:00<?, ? examples/s]

Map:   0%|          | 0/2461 [00:00<?, ? examples/s]

In [None]:
from torch.utils.data import DataLoader
import torch

# Gradient of the LM loss on one real training batch w.r.t. every model parameter.
# Presumably used as the target that synthetic-data gradients are compared against
# (TODO confirm in later cells).
model = model.to(DEVICE)
model_parameters = list(model.parameters())

real_train_dataloader = DataLoader(processed_dataset['train'], batch_size=16)
batch_real = next(iter(real_train_dataloader))

x_real = batch_real['input_ids'].to(DEVICE)
attn_mask_real = batch_real['attention_mask'].to(DEVICE)

# Rows were padded to max_length with pad_token == eos_token, so the padded tail consists
# of real vocabulary ids. Left in the labels, they would make the loss (and these reference
# gradients) reward predicting EOS over the padding. -100 is the label value the
# Hugging Face LM loss ignores.
y_real = x_real.clone()
y_real[attn_mask_real == 0] = -100

out_real = model(x_real, attention_mask=attn_mask_real, labels=y_real)
gradient_weights_real = torch.autograd.grad(out_real.loss, model_parameters)



### Synthetic Data

In [None]:
# Draw a random synthetic dataset of N_SYN sequences, each SYN_SEQ_LEN token ids long.
# config.vocab_size is len(tokenizer) (set in the config cell), replacing the hard-coded 50257.
# torch.randint's `high` is exclusive, so ids come from [0, config.vocab_size - 1): the last
# vocabulary id (presumably GPT-2's <|endoftext|>) is never sampled.
# NOTE(review): confirm that exclusion is intended; pass `config.vocab_size` as `high`
# to sample the full vocabulary.
N_SYN, SYN_SEQ_LEN = 120, 128

# A dedicated, seeded generator keeps the synthetic batch reproducible and independent of
# how much of the global RNG earlier cells consumed.
syn_generator = torch.Generator().manual_seed(0)
syn_tokens = torch.randint(
    0, config.vocab_size - 1, (N_SYN, SYN_SEQ_LEN), generator=syn_generator
)

In [None]:
# Forward pass on the first synthetic batch, fed as raw token ids. The targets are a copy
# of the inputs, as for the real batch. No attention mask is passed: the synthetic rows
# contain no padding.
syn_train_dataloader = DataLoader(dataset=syn_tokens, batch_size=16)
batch_syn = next(iter(syn_train_dataloader))

x_syn = batch_syn.to(DEVICE)
y_syn = x_syn.clone()
out_syn = model(input_ids=x_syn, labels=y_syn)

In [None]:
# Gradients (not embeddings) of the synthetic loss w.r.t. every model parameter, the
# embedding matrix included. create_graph=True keeps these gradients differentiable, so a
# loss built from them can be back-propagated again (second order).
gradient_weights_syn = torch.autograd.grad(
    outputs=out_syn.loss,
    inputs=model_parameters,
    create_graph=True,
)

In [None]:
# Same synthetic batch, but fed through `inputs_embeds`. The input embeddings then become
# a standalone leaf tensor that gradients can be taken with respect to (the vector space
# we wish to update).
syn_train_dataloader = DataLoader(dataset=syn_tokens, batch_size=16)
batch_syn = next(iter(syn_train_dataloader))

x_syn = batch_syn.to(DEVICE)
y_syn = x_syn.clone()

# Look up the token embeddings, cut them off from the embedding layer's autograd graph,
# and mark the copy as requiring grad.
syn_embed = model.get_input_embeddings()(x_syn).detach().requires_grad_(True)
out_syn = model(inputs_embeds=syn_embed, labels=y_syn)

In [None]:
# Loss of the untrained model on the synthetic batch. For reference, a uniform prediction
# over 50257 tokens would give ln(50257) ≈ 10.82.
out_syn.loss

tensor(10.9905, device='cuda:0', grad_fn=<NllLossBackward0>)