# Experiment: Working with Dolly
## Last Updated $DATE -  $AUTHOR.

```
Summary of High Level Research Question
```

Try to scope your experiments such that you can answer your research question in 1-3 hours.
This block is long enough to enter flow / deep work, but short enough that you will still feel 
motivated by a relatively tight feedback loop.

If a problem seems like it needs more time than that, break it into smaller experiments that each fit in this window.

### High Level Experiment Design

## Goals:
```
List of specific goals that this experiment seeks to achieve.

This should fall under a few categories:
- Development of Intuition about a _specific_ topic
- Novel Research or Insight that could lead to a publishable result
- Meaningfully explore a topic which could lead to an improvement in product

Guiding principles should be understanding, insight, and value creation.
```

## Tasks & Experiment Design

```
A list of specific tasks that are going to be tested 

```


## Outcomes

```
Document high level research findings and how
```


In [37]:
# Install things into ENV
# TODO: Set up a container that contains all of these and push it to Docker
%pip install git+https://github.com/neelnanda-io/TransformerLens.git
%pip install circuitsvis
%pip install plotly


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-fn5so28x
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-fn5so28x
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 0ffcc8ad647d9e991f4c2596557a9d7475617773
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Note: you may need to restart the kernel to use updated packages.

In [36]:
# Generic Set of Imports for MI Research
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
from pprint import pprint
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [38]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [40]:
# Setup PyTorch configuration for inference based experiments
# NOTE: Mark as False if you want to do any kind of training 
#       as part of your experimentation

INFERENCE_ONLY_EXPERIMENT = True
if INFERENCE_ONLY_EXPERIMENT:
    torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

import plotly.io as pio
pio.renderers.default = "notebook_connected"

cuda


In [39]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [41]:
# Load Circuit Visualizations
# TODO: Explore building out our own packages / tooling
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Vivek")


In [42]:
# Load & Run a Model
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-7b")
hf_model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-7b")

print("Loaded hf_model, hooking transformer into TransformerLens!")
# model = HookedTransformer.from_pretrained(
#     "EleutherAI/pythia-6.9b-deduped",
#     center_unembed=False,
#     center_writing_weights=False,
#     fold_ln=False,
#     refactor_factored_attn_matrices=True,
#     hf_model=hf_model
# )

### Workaround: load the Pythia-6.9B config manually, then load Dolly's weights into it.
### TODO: Figure out how this library actually works and make this a cleaner integration.
import transformer_lens.loading_from_pretrained as loading
# Get the model name used in HuggingFace, rather than the alias.
official_model_name = loading.get_official_model_name("EleutherAI/pythia-6.9b-deduped")


# Load the config into a HookedTransformerConfig object. If loading from a
# checkpoint, the config object will contain the information about the
# checkpoint
cfg = loading.get_pretrained_model_config(
    official_model_name,
    checkpoint_index=None,
    checkpoint_value=None,
    fold_ln=False,
    device=device,
    n_devices=1,
)
print(cfg)
# Override the vocab size: Dolly's embedding has 50280 rows, not the 50432 in the Pythia config.
cfg.d_vocab = 50280
cfg.d_vocab_out = 50280
print(cfg)


# Get the state dict of the model (ie a mapping of parameter names to tensors), processed to match the HookedTransformer parameter names.
state_dict = loading.get_pretrained_state_dict(
    official_model_name, cfg, hf_model
)

# Create the HookedTransformer object
model = HookedTransformer(cfg, tokenizer=tokenizer)

model.load_and_process_state_dict(
    state_dict,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    refactor_factored_attn_matrices=False,
    move_state_dict_to_device=True,
)

print("Loaded pretrained model into HookedTransformer!")

model_description_text = """For this demo notebook we'll look at Dolly v2. It is based on Pythia 6.9B, but we use the weights for Dolly v2. To try the model out, let's find the loss on this paragraph!"""
# return_type of model can be loss, logits, both, or none!
loss = model(model_description_text, return_type="loss")
print("Model loss:", loss)


Loaded hf_model, hooking transformer into TransformerLens!
HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 128,
 'd_mlp': 16384,
 'd_model': 4096,
 'd_vocab': 50432,
 'd_vocab_out': 50432,
 'device': 'cuda',
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.0125,
 'model_name': 'pythia-6.9b-deduped',
 'n_ctx': 2048,
 'n_devices': 1,
 'n_heads': 32,
 'n_layers': 32,
 'n_params': 6442450944,
 'normalization_type': 'LN',
 'original_architecture': 'GPTNeoXForCausalLM',
 'parallel_attn_mlp': True,
 'positional_embedding_type': 'rotary',
 'rotary_dim': 32,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'EleutherAI/pythia-6.9b-deduped',
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_hook_to
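
The loss printed by this cell is a mean next-token cross-entropy in nats, so one rough way to calibrate it is against a model that guesses uniformly over the vocabulary. The sketch below computes that baseline for the `d_vocab` of 50280 set in the cell and converts a loss to perplexity. The helper names and the example loss value are made up for illustration; they are not results from this run.

```python
import math

D_VOCAB = 50280  # d_vocab after the override in the cell above


def uniform_baseline_loss(d_vocab: int) -> float:
    """Cross-entropy (nats) of a model that assigns 1/d_vocab to every token."""
    return math.log(d_vocab)


def perplexity(loss_nats: float) -> float:
    """exp(loss): the 'uniform over this many tokens' reading of a loss."""
    return math.exp(loss_nats)


baseline = uniform_baseline_loss(D_VOCAB)
print(f"Uniform baseline loss: {baseline:.3f} nats")

# Hypothetical loss value, for illustration only.
example_loss = 3.0
print(f"Perplexity at loss {example_loss}: {perplexity(example_loss):.1f}")
```

A loss close to the baseline would suggest the weights did not load correctly; a loss far below it is what a working language model should give on ordinary English text.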

In [43]:
# DOLLY V2 - 7B Config
pprint(model.cfg)

# Transformer Lens Note:
# get_token_position, to_tokens, to_string, to_str_tokens, prepend_bos, to_single_token
# are all methods that are added to the model object by TransformerLens

HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 128,
 'd_mlp': 16384,
 'd_model': 4096,
 'd_vocab': 50280,
 'd_vocab_out': 50280,
 'device': 'cuda',
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.0125,
 'model_name': 'pythia-6.9b-deduped',
 'n_ctx': 2048,
 'n_devices': 1,
 'n_heads': 32,
 'n_layers': 32,
 'n_params': 6442450944,
 'normalization_type': 'LN',
 'original_architecture': 'GPTNeoXForCausalLM',
 'parallel_attn_mlp': True,
 'positional_embedding_type': 'rotary',
 'rotary_dim': 32,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'EleutherAI/pythia-6.9b-deduped',
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_hook_tokens': False,
 'use_local_attn': False,
 'use_split_qkv_inp

In [44]:
#from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate


In [45]:
sample_string = "On halloween, all the children go Trick or"

print(model.to_str_tokens(sample_string)) # Shows how the string is split into tokens
print(model.to_tokens(sample_string)) # Converts the string to integer token ids and returns a tensor on the model's device of shape (batch, position)
# NOTE: in GPT2, 50256 is the token for EOS, BOS, and Padding.
# to_single_token converts a string to a single integer token id, useful for looking up logits
# to_string converts a tensor of tokens to a string


#model.blocks.register_forward_hook  



model.generate(sample_string,
               temperature=0,
               max_new_tokens=1)



from pprint import pprint
batch_size = 8
num_epochs = 1
max_steps = 1
log_every = 1
lr = 1e-3
weight_decay = 1e-2
overfitMax=1
#model_cfg = Config(debug=False, d_model=256, n_heads=4, d_head=64, d_mlp=1024, n_layers=2, n_ctx=256, d_vocab=reference_gpt2.cfg.d_vocab)


# optimizer_copy = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)


# print("done one")
# losses = []

# #print(dataset)
# #print(dataset[0]['text'][:100])
# from EasyTransformer import easy_transformer 
# tokens_dataset = easy_transformer.utils.tokenize_and_concatenate(dataset, model.tokenizer, streaming=False, max_length=model_cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)
# data_loader = torch.utils.data.DataLoader(tokens_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
# print("Number of batches:", len(data_loader))
# #test_string_trained = "Hello world this is a test of overfitting"
# print("done")

['<|endoftext|>', 'On', ' hall', 'ow', 'een', ',', ' all', ' the', ' children', ' go', ' T', 'rick', ' or']
tensor([[   0, 2374, 7423,  319, 9673,   13,  512,  253, 2151,  564,  308, 4662,
          390]], device='cuda:0')
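
The two printed outputs are easier to read side by side. The sketch below simply pairs the string tokens with the token ids copied from the output above; it does not call the model. Note that the prepended `<|endoftext|>` token has id 0 here, unlike the 50256 mentioned in the GPT2 comment.

```python
# Values copied from the printed output above.
str_tokens = ['<|endoftext|>', 'On', ' hall', 'ow', 'een', ',', ' all', ' the',
              ' children', ' go', ' T', 'rick', ' or']
token_ids = [0, 2374, 7423, 319, 9673, 13, 512, 253, 2151, 564, 308, 4662, 390]

# Pair each position's string token with its integer id.
pairs = list(zip(str_tokens, token_ids))
for position, (text, token_id) in enumerate(pairs):
    print(f"{position:>2}  {token_id:>5}  {text!r}")
```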


  0%|          | 0/1 [00:00<?, ?it/s]

In [46]:
print("done")

done


In [47]:
# Test Prompt Util -- Check the logit score of the expected output vs. the actual
#                     output
example_prompt = "the founder of Facebook is Mark"
example_answer = "Zuckerberg"

utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)



Tokenized prompt: ['<|endoftext|>', 'the', ' founder', ' of', ' Facebook', ' is', ' Mark']
Tokenized answer: [' Z', 'ucker', 'berg']


Top 0th token. Logit: 22.91 Prob: 99.94% Token: | Z|
Top 1th token. Logit: 15.02 Prob:  0.04% Token: | z|
Top 2th token. Logit: 12.46 Prob:  0.00% Token: | E|
Top 3th token. Logit: 12.16 Prob:  0.00% Token: | Elliot|
Top 4th token. Logit: 11.97 Prob:  0.00% Token: |
|
Top 5th token. Logit: 11.81 Prob:  0.00% Token: | Cuban|
Top 6th token. Logit: 11.52 Prob:  0.00% Token: |Z|
Top 7th token. Logit: 11.40 Prob:  0.00% Token: |  |
Top 8th token. Logit: 10.68 Prob:  0.00% Token: |us|
Top 9th token. Logit: 10.56 Prob:  0.00% Token: |.|


Top 0th token. Logit: 23.77 Prob: 99.60% Token: |ucker|
Top 1th token. Logit: 17.93 Prob:  0.29% Token: |uk|
Top 2th token. Logit: 16.55 Prob:  0.07% Token: |uck|
Top 3th token. Logit: 15.53 Prob:  0.03% Token: |uc|
Top 4th token. Logit: 13.66 Prob:  0.00% Token: |UCK|
Top 5th token. Logit: 12.91 Prob:  0.00% Token: |.|
Top 6th token. Logit: 11.87 Prob:  0.00% Token: |ub|
Top 7th token. Logit: 11.50 Prob:  0.00% Token: |ucks|
Top 8th token. Logit: 11.00 Prob:  0.00% Token: |ander|
Top 9th token. Logit: 10.94 Prob:  0.00% Token: |im|


Top 0th token. Logit: 25.24 Prob: 98.27% Token: |berg|
Top 1th token. Logit: 21.15 Prob:  1.65% Token: |burg|
Top 2th token. Logit: 17.57 Prob:  0.05% Token: |ber|
Top 3th token. Logit: 16.38 Prob:  0.01% Token: |borg|
Top 4th token. Logit: 15.81 Prob:  0.01% Token: |beg|
Top 5th token. Logit: 13.86 Prob:  0.00% Token: |bert|
Top 6th token. Logit: 13.72 Prob:  0.00% Token: |bur|
Top 7th token. Logit: 13.15 Prob:  0.00% Token: |­|
Top 8th token. Logit: 12.94 Prob:  0.00% Token: |b|
Top 9th token. Logit: 12.89 Prob:  0.00% Token: |berger|
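
The logit and probability columns above are linked by the softmax: the ratio of two tokens' probabilities is `exp` of their logit difference. The sketch below checks this against numbers copied from the output; `prob_ratio` is a hypothetical helper name, and small mismatches are expected because the printed values are rounded.

```python
import math


def prob_ratio(logit_a: float, logit_b: float) -> float:
    """Softmax probability ratio p_a / p_b implied by two logits."""
    return math.exp(logit_a - logit_b)


# (top-0 logit, top-1 logit, top-0 prob %, top-1 prob %) copied from the output above.
rows = {
    "ucker vs uk": (23.77, 17.93, 99.60, 0.29),
    "berg vs burg": (25.24, 21.15, 98.27, 1.65),
}

for name, (logit_a, logit_b, prob_a, prob_b) in rows.items():
    from_logits = prob_ratio(logit_a, logit_b)
    from_probs = prob_a / prob_b
    print(f"{name}: ratio from logits {from_logits:.1f}, from probs {from_probs:.1f}")
```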


In [48]:
pprint([(name, param.shape) for name, param in model.named_parameters()])

[('embed.W_E', torch.Size([50280, 4096])),
 ('blocks.0.ln1.w', torch.Size([4096])),
 ('blocks.0.ln1.b', torch.Size([4096])),
 ('blocks.0.ln2.w', torch.Size([4096])),
 ('blocks.0.ln2.b', torch.Size([4096])),
 ('blocks.0.attn.W_Q', torch.Size([32, 4096, 128])),
 ('blocks.0.attn.W_K', torch.Size([32, 4096, 128])),
 ('blocks.0.attn.W_V', torch.Size([32, 4096, 128])),
 ('blocks.0.attn.W_O', torch.Size([32, 128, 4096])),
 ('blocks.0.attn.b_Q', torch.Size([32, 128])),
 ('blocks.0.attn.b_K', torch.Size([32, 128])),
 ('blocks.0.attn.b_V', torch.Size([32, 128])),
 ('blocks.0.attn.b_O', torch.Size([4096])),
 ('blocks.0.mlp.W_in', torch.Size([4096, 16384])),
 ('blocks.0.mlp.b_in', torch.Size([16384])),
 ('blocks.0.mlp.W_out', torch.Size([16384, 4096])),
 ('blocks.0.mlp.b_out', torch.Size([4096])),
 ('blocks.1.ln1.w', torch.Size([4096])),
 ('blocks.1.ln1.b', torch.Size([4096])),
 ('blocks.1.ln2.w', torch.Size([4096])),
 ('blocks.1.ln2.b', torch.Size([4096])),
 ('blocks.1.attn.W_Q', torch.Size([32, 
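
The shapes above can be used to sanity-check the `n_params` value in the printed config. The sketch below multiplies out the attention and MLP weight shapes of block 0 and scales by the layer count. It assumes every block has the same shapes as block 0, since the output is truncated after the start of block 1.

```python
from math import prod

# Per-block weight shapes copied from the named_parameters() output above.
block_weight_shapes = {
    "attn.W_Q": (32, 4096, 128),
    "attn.W_K": (32, 4096, 128),
    "attn.W_V": (32, 4096, 128),
    "attn.W_O": (32, 128, 4096),
    "mlp.W_in": (4096, 16384),
    "mlp.W_out": (16384, 4096),
}
N_LAYERS = 32                   # n_layers from the printed config
REPORTED_N_PARAMS = 6442450944  # n_params from the printed config

per_block = sum(prod(shape) for shape in block_weight_shapes.values())
total = per_block * N_LAYERS
embed_params = prod((50280, 4096))  # embed.W_E, not included in `total`

print("weights per block:", per_block)
print("all blocks:", total, "matches config:", total == REPORTED_N_PARAMS)
print("embedding matrix:", embed_params)
```

Under that assumption the attention and MLP weights alone account for the reported figure, which suggests `n_params` here leaves out embeddings, biases, and layer norm parameters.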

In [109]:
# Testing out Dolly's Q/A ability.

# model.generate(
#     "abcdefghi",

#     max_new_tokens=100,
# )

#
#     model.generate(
    # """I will give you a sentence in the form, Sentence: <SENTENCE>, and you will then write out the value of the first letter of each word by converting each leading letter to a number, and then add all the numbers up to get their total sum and respond Answer:'<SUM>'. 
    # Example 1, Sentence:'A Cat', then A is the first letter of the alphabet, so A=1, and C is the third letter, so C=3, and 1+3=4, so the answer would be 4, Answer:4
    # Example 2, Sentence:'A Cat Ran For President', then A is the first letter of the alphabet, so A=1, and C is the third letter, so C=3, and so on making R=18, F=6, and P=16, so the answer would be 44, Answer:44
    
    # Example 3, Sentence:'A Cat Ran For Mayor', then""",

    # max_new_tokens=100,
# )
#

from pprint import pprint 
example_prompt = "the founder of Facebook is Mark"
example_answer = "then A is the first letter of the alphabet, so A=1, and C is the third letter"
p = """I will give you a sentence in the form, Sentence: '<SENTENCE>', and you will then have a Thought: '<THOUGHT>' write out the value of the first letter of each word by converting each leading letter to a number, and then add all the numbers up to get their total sum and respond Answer: '<SUM>' 
    Sentence: '<SENTENCE>'  
    Thought: '<THOUGHT>'
    Answer: '<SUM>'
Example 1: 
    Sentence:'A Cat' 
    Thought: 'then A is the first letter of the alphabet, so A=1, and C is the third letter, so C=3, and 1+3=4, so the answer would be 4'
    Answer: '4'
Example 2: 
    Sentence: 'A Dog Ran For President'
    Thought: 'then A is the first letter of the alphabet, so A=1, and D is the fourth letter of the alphabet, so D=4, R is the 18th letter of the alphabet, so R=18, F=6, and P=16, so the answer would be 45'
    Answer: '45'
Example 3: 
    Sentence: 'A Dog Ran For Mayor'
    Thought: '"""  
p = """I will give you a sentence in the form, Sentence: '<SENTENCE>', and you will then have a Thought: '<THOUGHT>' where you will write out the numeric value of the first letter of each word by converting each leading letter to a number, and then add all the numbers up to get their total sum and respond Answer: '<SUM>' 
Remember to double check your math very carefully, make sure when you are adding numbers up they are the correct numbers, and that you are adding them correctly. 
    Sentence: '<SENTENCE>'  
    Thought: '<THOUGHT>'
    Answer: '<SUM>'
Example 1: 
    Sentence:'A Cat' 
    Thought: 'A=1, C=3, and so A=1 + C=3 sums to 4, so the answer would be 4'
    Answer: '4'
Example 2: 
    Sentence: 'Dog Ran For President'
    Thought: 'D=4, R=18, F=6, P=16, and so D=4 + R=18 + F=6 + P=16 sums to 44, so the answer would be 44'
    Answer: '44'
Example 3: 
    Sentence: 'A Dog Ran For Mayor'
    Thought: '"""  
#utils.test_prompt(  
#, example_answer, model, prepend_bos=True)

p = """Palantir Founder Stephen Cohen often says
 """
o = model.generate(p, max_new_tokens=120)
pprint(o)
#o += """\n There should be a company associated with Peter Thiel on the list. Repeat \n"""
#o2 = model.generate(o, max_new_tokens=120)

#pprint(o2)

# This Dolly is kind of dumb.
# We should try to get A100 x 8 Cluster ASAP and get a full sized model.

  0%|          | 0/120 [00:00<?, ?it/s]

('Palantir Founder Stephen Cohen often says\n'
 ' \n'
 '"The future is already here — it\'s just not very evenly distributed."\n'
 '\n'
 'Since the foundation of the company in 2004,\n'
 ' \n'
 'its co-founders have been awarded over $90 million in funding, led the '
 "company through two rounds of acquisitions, and it's used by government, "
 'industry, and academic clients worldwide.\n'
 '\n'
 'Palantir was named one of the “Top 35 technologists to watch” by Time in '
 '2016\n'
 '\n'
 '\n'
 'It boasts of clients like the Army, ICE, and the Department of Defense, and '
 'claims to offer a technological “edge” to government agencies.')
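
For the letter-sum prompts drafted in this cell, it helps to have a ground-truth scorer for checking both the few-shot examples and any model answers. The sketch below is one way to compute it; `first_letter_sum` is a hypothetical helper, not part of the notebook, and it assumes words are whitespace-separated and that non-letter leading characters count as zero.

```python
import string


def first_letter_sum(sentence: str) -> int:
    """Sum the alphabet positions (A=1 ... Z=26) of each word's first letter."""
    total = 0
    for word in sentence.split():
        first = word[0].lower()
        if first in string.ascii_lowercase:
            total += string.ascii_lowercase.index(first) + 1
    return total


for sentence in ["A Cat", "Dog Ran For President",
                 "A Dog Ran For President", "A Dog Ran For Mayor"]:
    print(f"{sentence!r}: {first_letter_sum(sentence)}")
```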


In [110]:
p = """Palantir Founder Alex Karp often says
 """
o = model.generate(p, max_new_tokens=120)
pprint(o)

  0%|          | 0/120 [00:00<?, ?it/s]

('Palantir Founder Alex Karp often says\n'
 ' \n'
 '\n'
 ' 1. "We ship quickly".\n'
 ' \n'
 '\n'
 '  2. "We don\'t build A.I. to sell to C.I.A. We build A.I. to sell to C3I '
 'companies".\n'
 ' \n'
 '\n'
 "  3.  We don't hire slow, we franchise.\n"
 ' \n'
 " 4.  We don't control our own destiny  - we collaborate with hundreds of "
 'customers to build the best experience.\n'
 ' \n'
 '5.  Spending money to create product market fit is the best economics.\n'
 ' \n'
 '6.  Our competitors often want to build a billion dollar company before they '
 'get started')


In [111]:
p = """Palantir Founder Peter Thiel often says
 """
o = model.generate(p, max_new_tokens=120)
pprint(o)

  0%|          | 0/120 [00:00<?, ?it/s]

('Palantir Founder Peter Thiel often says\n'
 ' \n'
 'In his book Endgame he describes a fictional sci-fi setting in the 2020s '
 'where advanced AI agents develop a pathological need for human attention and '
 'contribution. If left unattended they become insane and attempt to take over '
 'the world.\n'
 ' \n'
 'According to Thiel, the setting is based on real history with insights into '
 'how it could develop further. He names core techno-social technologies that '
 'enable the scenario including Artificial General Intelligence, Collective '
 'Intelligence, Enhanced Human Genomes, Virtual Reality and Smart Contracts.\n'
 ' \n'
 'One solution to the problem he describes is creating an Internet of Thing')


In [50]:
# Direct Logit Attribution


In [112]:
p = """Palantir Founder Stephen Cohen often says
 """
o = model.generate(p, max_new_tokens=120)
pprint(o)

  0%|          | 0/120 [00:00<?, ?it/s]

('Palantir Founder Stephen Cohen often says\n'
 ' \n'
 '"warfare as we know it in the west will end, computer warfare will occur, '
 'and no army will be able to compete."\n'
 ' \n'
 'Palantir executive chair and co-founder, Peter Thiel, said in 2020\n'
 ' \n'
 '"Stephen Cohen was telling me somebody just said that warfare as we know it '
 'will end. Even though you can probably find exceptions, he said the general '
 "sentiment is that if you're an elite unit and you step on the field of "
 'battle you have a 90 percent likelihood of losing. WAR will end".\n'
 ' \n'
 'See the full quote above.\n'
 ' \n')
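
This run used the same prompt as In [109] but produced a different continuation. That is what one would expect if `generate` samples from the output distribution when no `temperature` is passed, unlike the earlier `temperature=0` call; that default is an assumption here, not something the notebook confirms. The toy sampler below uses only the standard library and made-up logits to illustrate the difference. It is not the library's implementation.

```python
import math
import random


def sample_next(logits, temperature, rng):
    """Pick a token index from logits; temperature 0 means greedy argmax."""
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - top) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


toy_logits = [2.0, 1.5, 0.5]  # made-up logits, for illustration only

# Greedy decoding always returns the same index.
greedy = [sample_next(toy_logits, 0, random.Random()) for _ in range(5)]

# Sampling is only repeatable when the random generator is seeded identically.
rng_a, rng_b = random.Random(0), random.Random(0)
run_a = [sample_next(toy_logits, 1.0, rng_a) for _ in range(20)]
run_b = [sample_next(toy_logits, 1.0, rng_b) for _ in range(20)]

print("greedy:", greedy)
print("seeded runs match:", run_a == run_b)
```

For comparisons across prompts, fixing a seed or decoding greedily would make it easier to attribute differences to the prompt rather than to sampling noise.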


In [113]:
p = """Palantir Founder Stephen Cohen original built
 """
o = model.generate(p, max_new_tokens=220)
pprint(o)

  0%|          | 0/220 [00:00<?, ?it/s]

('Palantir Founder Stephen Cohen often says\n'
 ' \n'
 '"privacy is a threatened freedom, because the industry for monetizing your '
 "data is growing so fast that it can't even measure all the data it "
 'generates"\n'
 ' \n'
 'in this talk he says\n'
 '"Privacy by Design"\n'
 ' \n'
 'may sound good on an elevator pitch\n'
 'but in the actual deep dive it just sounds like more buzz words\n'
 ' \n'
 'privacy by design means create a system the works flawlessly the first time '
 'without any bugs, with no security holes, across architectures and hardware '
 'etc.  Cohen himself admits its really hard to actually do\n'
 ' \n'
 'and even if it could be done\n'
 'the ethics and morality to show some of the data to an user can be sold to '
 'improve the system is challenged by many including the CEO of Palantir who '
 'funded the founding of palantir\n'
 ' \n'
 "in reality Cohen's company Palantir sells the data it collects along with "
 'other user data to large companies like amazon and f

In [129]:
p = """Palantir has offices in which cities?"""
o = model.generate(p, max_new_tokens=220)
pprint(o)

  0%|          | 0/220 [00:00<?, ?it/s]

('Palantir has offices in which cities?\n'
 '\n'
 'Palantir has offices in which cities?\n'
 '\n'
 'Palantir has offices in Zug, Switzerland; Mountain View, California; SFO, '
 'San Francisco; Berlin, Germany; Shanghai, China; London, England; Tokyo, '
 'Japan; Sydney, Australia; Dublin, Ireland; Budapest, Hungary; and Munich, '
 'Germany. Its satellite office in Singapore supports operations in Indonesia, '
 'the Philippines, and Malaysia.\n'
 'Related posts:<|endoftext|>')


In [130]:
from transformer_lens.utils import get_corner, gelu_new, tokenize_and_concatenate

In [131]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train")





done one
Number of batches: 937
done


In [140]:
batch_size = 8
num_epochs = 1
max_steps = 1
log_every = 1
lr = 1e-3
weight_decay = 1e-2
overfitMax=1
#model_cfg = Config(debug=False, d_model=256, n_heads=4, d_head=64, d_mlp=1024, n_layers=2, n_ctx=256, d_vocab=reference_gpt2.cfg.d_vocab)


optimizer_copy = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)


print("done one")
losses = []

#print(dataset)
#print(dataset[0]['text'][:100])

tokens_dataset = tokenize_and_concatenate(dataset, model.tokenizer, streaming=False, max_length=cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)
data_loader = torch.utils.data.DataLoader(tokens_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
print("Number of batches:", len(data_loader))
#test_string_trained = "Hello world this is a test of overfitting"
print("done")
def lm_cross_entropy_loss(logits, tokens):
    # Measure next token loss
    # Logits have shape [batch, position, d_vocab]
    # Tokens have shape [batch, position]
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -pred_log_probs.mean()
#loss = lm_cross_entropy_loss(demo_logits, test_tokens)
#print(loss)
#print("Loss as average prob", (-loss).exp())
#print("Loss as 'uniform over this many variables'", (loss).exp())
#print("Uniform loss over the vocab", math.log(demo_gpt2.cfg.d_vocab))
#test_string_trained = "Hello world this is a test of overfitting"
torch.set_grad_enabled(True)
test_string_trained = """
German police and security services say they are preparing for Ukrainian President Volodymyr Zelensky to visit Berlin this month, a trip scheduled for May 13-14, marking Zelensky's first such visit to Germany since the war with Russia began.

It's at the invitation of German Chancellor Olaf Scholz, with Scholz's office not yet having made the official announcement previewing the state visit. But the timing comes amid rising tensions between Kiev and Berlin, despite German Leopard II tanks now being transferred to Ukrainian forces.
"""
test_string_trained_tokens = model.to_tokens(test_string_trained, prepend_bos=True).to(device)
for epoch in range(num_epochs):
    # Overfit on a single string; the data_loader loop is kept for reference.
    #for c, batch in tqdm.tqdm(enumerate(data_loader)):
    for c in range(1, overfitMax + 1):
        if c > max_steps:
            break
        print("training loop ")

        #the_list = list(model.tokenizer.batch_decode(test_string_trained_tokens))
        #print(the_list)
        #tokens = batch['tokens'].cuda()
        tokens = test_string_trained_tokens
        for overfits in range(overfitMax):
            #print("overfitting on "+str(overfits))
            logits = model(tokens)
            loss = lm_cross_entropy_loss(logits, tokens)
            loss.backward()
            optimizer_copy.step()
            optimizer_copy.zero_grad()
            losses.append(loss.item())
            print(f"Step: {c}, Loss: {loss.item():.4f}")
          # if c % log_every == 0:
          #     print(f"Step: {c}, Loss: {loss.item():.4f}")
          #     #print("now testing model with:\n    " + test_string_test + "\n and expect:\n    "+test_string_expect)
          #     test_string_test = "abcdefghijkl"
          #     test_string_test_tokens = model.to_tokens(test_string_test, prepend_bos=True).cuda()
          #     test_string_test_out = test_string_test
          #     #pprint(list(enumerate(list(reference_gpt2_copy.tokenizer.batch_decode(reference_gpt2_copy.to_tokens(test_string_test_out)[0])))))
          #     the_list_two = list()
          #     pprint(list(enumerate(list(model.tokenizer.batch_decode(reference_gpt2_copy.to_tokens(test_string_test_out, prepend_bos=True)[0])))))
          #     for i in range(0,1):
                
                
          #       test_tokens = model.to_tokens(test_string_test_out, prepend_bos=True).cuda()
                
                
          #       demo_logits = model(test_tokens)
          #       #the_list = list(zip(reference_gpt2.to_str_tokens(test_string), reference_gpt2.tokenizer.batch_decode(demo_logits.argmax(dim=-1)[0])))
          #       test_string_test_out += model.tokenizer.decode(demo_logits[-1, -1].argmax())
          #       the_list_two = list(model.tokenizer.batch_decode(demo_logits.argmax(dim=-1)[0]))
          #       #print(str(i)+"th loop")
          #       #print(list(enumerate(the_list)))
          #       #test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())
              
          #     print(test_string_test)
          #     print("AND THE OUTPUT")
          #     print(test_string_test_out)
          #     #print("##")
          #     #print("".join(the_list))
          #     #print(list(enumerate(the_list_two)))
          #     pprint("###########################################################")

          

training loop 


RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 79.17 GiB total capacity; 77.68 GiB already allocated; 3.81 MiB free; 77.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
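
A back-of-the-envelope estimate makes this out-of-memory error unsurprising. The sketch below assumes fp32 weights and that AdamW keeps two extra fp32 buffers per parameter; both are assumptions, since the notebook does not print dtypes. It uses the `n_params` from the printed config and ignores embeddings and activations, so it is a lower bound.

```python
N_PARAMS = 6442450944     # n_params from the printed config
GPU_CAPACITY_GIB = 79.17  # total capacity from the error message above

# Assumed bytes per parameter for full fp32 fine-tuning with AdamW.
BYTES_PER_PARAM = {
    "weights": 4,
    "gradients": 4,
    "adamw_exp_avg": 4,
    "adamw_exp_avg_sq": 4,
}


def training_memory_gib(n_params: int, bytes_per_param: dict) -> float:
    """Lower-bound memory (GiB) for weights, gradients, and optimizer state."""
    return n_params * sum(bytes_per_param.values()) / 2**30


needed = training_memory_gib(N_PARAMS, BYTES_PER_PARAM)
print(f"Estimated minimum: {needed:.1f} GiB vs {GPU_CAPACITY_GIB} GiB available")
```

Under these assumptions a single 80 GB card cannot hold a full AdamW training state for this model, even for one step on one short string.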