# Walking through some thinking here about a concept of "Dense-Tokens"

The primary goal is to introduce a concept of "Dense Tokens", where a model can pass a token between forward passes that is *not* a one-hot encoded vector.

The main challenge is to figure out how to train such a model effectively, and in an unsupervised manner if possible.



## Approach:
---

1. Phase 1: Generate a batch of training examples:
    1. Sample N different sequences from the corpus.
    2. Randomly insert X "dense" tokens into each sequence, initializing them to be empty.
    3. Run a forward pass of the model for each of these dense tokens to generate a candidate dense token.
    4. Keep the subset N' of the sequences where the dense token decreases the perplexity of the sequence the most.

2. Phase 2: Train the model:
    1. Run a single forward pass of the model on the N' sequences (using teacher forcing style, where the N' tokens are the targets).
    2. Calculate loss
        loss = normal_cross_entropy_loss + (lambda * dense_token_loss)
         ## -> Goal here is to move generated token A closer to target token B
        dense_token_loss = max P(sequence | X)
         ## -> Goal here is to move generated token A closer to theoretical best token B that minimises the perplexity of the sequence
         ## What is token B? We don't know what token B is, but maybe we can work out which direction token B is in????????
         ## Token B is the token that minimises the perplexity of the sequence.

---
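
A minimal sketch of the Phase 1 filtering step, assuming we already have a per-sequence mean loss measured without and with the candidate dense tokens. The helper name `select_most_improved` and all the numbers are made up for illustration; nothing here is the actual training code.

```python
import torch

def select_most_improved(loss_without, loss_with, n_keep):
    """Return indices of the n_keep sequences whose loss dropped the most.

    loss_without / loss_with: 1-D tensors of per-sequence mean cross-entropy
    (log-perplexity), measured without and with the candidate dense tokens.
    """
    improvement = loss_without - loss_with  # positive means the dense token helped
    return torch.topk(improvement, k=n_keep).indices

# Made-up numbers for N = 5 sequences; keep N' = 2.
loss_without = torch.tensor([3.0, 2.5, 4.0, 3.2, 2.9])
loss_with = torch.tensor([2.9, 2.6, 3.1, 3.0, 2.9])
keep = select_most_improved(loss_without, loss_with, n_keep=2)
print(keep.tolist())  # [2, 3]
```

Ranking by the drop in mean loss is equivalent to ranking by the ratio of perplexities, since perplexity is the exponential of the mean loss.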



What is the dense token loss?

Perplexity distillation says we could:
1. calculate the perplexity of the entire sequence


dense_token_loss = -log P(entire_sequence | dense_token_a)
normal_loss = -log P(token_a | tokens_before_a) - log P(token_b | tokens_before_b) - log P(token_c | tokens_before_c) ...

So in normal loss we're trying to predict token A given a sequence of tokens before A.
In dense token loss, we're trying to predict an entire sequence given a single token A.
So it's sort of the inverse?



# normal_loss = - (1/N) * Σ_{i=1}^{N} log(P(token_i | tokens_before_i))
# P(token_i | tokens_before_i) = softmax(M(tokens_before_i))[token_i]
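
As a sanity check on the formula above, here is a small sketch showing that the hand-written mean negative log-probability matches PyTorch's built-in `F.cross_entropy`. The random `logits` stand in for `M(tokens_before_i)`; the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, vocab = 6, 10
logits = torch.randn(N, vocab)           # stand-in for M(tokens_before_i), one row per position
targets = torch.randint(0, vocab, (N,))  # token_i at each position

# normal_loss = -(1/N) * sum_i log softmax(logits_i)[token_i]
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(N), targets].mean()

builtin = F.cross_entropy(logits, targets)
print(torch.allclose(manual, builtin))  # True
```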


# dense_token_loss = -log P(tokens_after_i | dense_token_i)
                   = -[log P(token_after_i_1 | dense_token_i) + log P(token_after_i_2 | dense_token_i) + log P(token_after_i_3 | dense_token_i) + ...]
                  -> we do a full forward pass and need to know the entire input sequence's probability before we can condition
                     on the dense token ???????

 we're strictly trying to nudge the dense token such that the rest of the sequence is more likely.

 I'm a little stuck here, but I'm ALMOST thinking that our normal cross entropy loss might do this for us??
 AHH wait a second, thinking about how autograd works....
   what we want to do is:
     1. Perform a forward pass to calculate the dense token (while tracking gradients)
     2. (freeze params) Perform a *second* forward pass to calculate probability of rest of the sequence once we have the dense token
     3. Backprop through these two layers (basically unrolling the model)
     4. THEN modify the first model...

Why freeze the params? We don't want to modify our model's ability to predict the second set of tokens at that stage (at least in this mental model; it might be possible to merge these two steps).
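
A rough sketch of the two-pass idea on a toy model, to check that autograd does what the steps above hope for. Everything here is an assumption for illustration: `TinyCausalLM` is a made-up stand-in (a causal running mean instead of attention), and the "frozen" second pass is just a deep copy with `requires_grad_(False)`.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, D, T = 11, 8, 5

class TinyCausalLM(nn.Module):
    """Hypothetical toy model: returns (logits, hidden) for (ids, dense)."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, D)
        self.mix = nn.Linear(D, D)
        self.head = nn.Linear(D, VOCAB)

    def forward(self, ids, dense):
        x = self.emb(ids) + dense  # dense vector is added to the token embedding
        # Causal running mean, so position t only sees positions <= t.
        steps = torch.arange(1, ids.shape[1] + 1).view(1, -1, 1)
        h = torch.tanh(self.mix(torch.cumsum(x, dim=1) / steps))
        return self.head(h), h

model = TinyCausalLM()
frozen = copy.deepcopy(model).requires_grad_(False)  # same weights, no parameter grads

ids = torch.randint(0, VOCAB, (1, T))
targets = torch.randint(0, VOCAB, (1, T))
dense_mask = torch.tensor([[0, 0, 1, 0, 0]])

# Pass 1 (gradients tracked): produce the dense vector at the masked position.
_, hidden = model(ids, torch.zeros(1, T, D))
dense = hidden * dense_mask[:, :, None]

# Pass 2 (frozen weights): score the sequence given that dense vector.
logits, _ = frozen(ids, dense)
loss = F.cross_entropy(logits.view(-1, VOCAB), targets.view(-1))

# Backprop through both passes; only the pass-1 parameters that produced
# the dense vector receive gradients.
loss.backward()
print(model.mix.weight.grad is not None, model.head.weight.grad is None)
```

In this toy setup the LM head of the first pass gets no gradient at all, because its logits are thrown away; only the path that produced the dense vector is updated.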

----

Flow2.0:
- Generate dense token (while tracking gradients)
- Calculate loss of sequence given dense token
- Backprop and optimize

(so main leap is that we need to be tracking gradients when we generate the dense tokens, as we use that to modify generation of the dense token itself)

question: say we fill in X dense tokens in the sequence, do we need to perform this step X times or can we perform it once given we're only doing a layer depth of 1?
we're sort of expecting then that the model can make use of a dense token without multiple recurrences, which might be smart?
Maybe we should be training it just like a good old-fashioned normal RNN, but with the control flow decided by the model rather than the architecture??????????? :mindblown:

---







We want to modify the dense tokens to be better. i.e. we don't just purely want to ask the model to predict the dense tokens, as we're expecting it to do that given it literally just generated that set of dense tokens.

Instead, we want to MODIFY the dense tokens.
That means we want to pose this situation:
- Here is dense token X
- Here is the rest of the sequence after X
- What is the best way to modify X to make the sequence after X more likely?
- We want to do max P(sequence | X)





Now the question I've got: given we're giving the model an input signal from the future, is that cheating?

Maybe? The intention is that the model is asked to think about the future and look ahead rather than just looking at the past...
 The main point is, the dense token is generated to reduce the future perplexity of the sequence. We modify the model such that it will produce
 a dense token most helpful in predicting the future of the sequence.

 The model autoregressively generates a dense token, then the model generates an output sequence given the dense token.
 During training then, we ask the model to specifically generate dense tokens that are more likely to help that future auto-regressive generation.

(look at UL2 training objective)





Interesting point: the output vector will be the size of the vocab, but that will get projected down to the embedding layer size of 768 immediately. So we could actually avoid storing a 30k vector for each dense token, and just store a 768 vector for each dense token (or 6144 for a large model).
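
Quick arithmetic on what that saves per dense token, assuming float32 storage (4 bytes per value). The widths are the rough figures from the note above, not measured values.

```python
BYTES_PER_FLOAT32 = 4

def dense_token_bytes(width, bytes_per_value=BYTES_PER_FLOAT32):
    """Memory needed to store one dense token of the given width."""
    return width * bytes_per_value

vocab_width = 30_000  # "30k vector" from the note
embed_width = 768     # embedding size from the note

print(dense_token_bytes(vocab_width))  # 120000
print(dense_token_bytes(embed_width))  # 3072
print(vocab_width / embed_width)       # 39.0625
```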

Thoughts from overnight:
- Yes, we should be able to train with multiple dense tokens in a single pass... but the ability for tokens to build off each other will be dependent on depth of the tree.
- Q: How do we deal with softmax obliterating some inputs, what do we do to avoid that...
- Q: Can I think of a more efficient training objective. This one requires potentially holding K optimizer states in memory where K is the recursive depth.
- Q: Can we get this to perform better than the encoding stage of an encoder????

# Experiment 1

- Have a single training example with a dense token initialized to be empty
- Run a forward pass to generate a dense token, then run a forward pass to generate the rest of the sequence
    - Calculate loss of entire model, and backprop through both layers

Goal is to optimize a single dense token for a single example, such that that dense token improves generation of the rest of the sequence. Ideally we freeze everything else and see if the dense token can be optimized to improve the rest of the sequence in isolation.
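
A minimal sketch of the "freeze everything else" version of this experiment, on a made-up toy model rather than the real GPT. The only trainable tensor is the dense vector itself; `sequence_loss`, the sizes, the learning rate and the step count are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, D, T = 11, 8, 6

# Hypothetical frozen toy model pieces (not the notebook's GPT).
emb = nn.Embedding(VOCAB, D).requires_grad_(False)
mix = nn.Linear(D, D).requires_grad_(False)
head = nn.Linear(D, VOCAB).requires_grad_(False)

def sequence_loss(ids, targets, dense):
    x = emb(ids) + dense
    # Causal running mean, so the dense position influences everything after it.
    steps = torch.arange(1, ids.shape[1] + 1).view(1, -1, 1)
    h = torch.tanh(mix(torch.cumsum(x, dim=1) / steps))
    return F.cross_entropy(head(h).view(-1, VOCAB), targets.view(-1))

ids = torch.randint(0, VOCAB, (1, T))
targets = torch.randint(0, VOCAB, (1, T))
dense_mask = torch.tensor([[0, 0, 1, 0, 0, 0]])[:, :, None]  # one dense position

dense_vec = nn.Parameter(torch.zeros(1, 1, D))  # the only trainable tensor
optimizer = torch.optim.Adam([dense_vec], lr=0.05)

losses = []
for _ in range(100):
    loss = sequence_loss(ids, targets, dense_vec * dense_mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

If the loss barely moves in a setup like this, that would suggest a single dense vector has little leverage over the rest of the sequence, which is worth knowing before unfreezing the model.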

# Experiment 2
Expand that to a batch of examples with multiple dense tokens, still only a tree depth of 1 though.
# Experiment 3
Expand that to a batch of examples with multiple dense tokens, and a tree depth of 2.

# Experiment 4
Implement random deletion of useless dense tokens, so we only train on the dense tokens that are actually useful.




In [1]:
from model import GPT
from transformers import GPTNeoXTokenizerFast
model = GPT.from_pretrained('EleutherAI/pythia-70m')
tokenizer = GPTNeoXTokenizerFast.from_pretrained('EleutherAI/pythia-70m')

loading weights from pretrained GPTNeoX: EleutherAI/pythia-70m
number of parameters: 70.43M


In [2]:
tokenizer.vocab_size, len(tokenizer), tokenizer.decode([50300])


# vocab_size is the original vocab
# len(tokenizer) is vocab + additional tokens
# tokenizer.decode([50300]) returns '' because that ID is beyond len(tokenizer), so no token is assigned to it

(50254, 50277, '')

In [3]:
tokenizer.add_tokens(['<|dense|>'])

tokenizer.vocab_size, len(tokenizer), tokenizer.encode('<|dense|>')

(50254, 50278, [50277])

We're trying to get a dense token into the model and ideally have the model make use of that dense token somehow.

So, what if the input to the transformer is:
- inputids = [1,2,4,5, 50000] # 50000 is ID of dense token
- dense    = [[],[],[],[],[0.1,0.2]] 

- inputids will map 50000 to a dense token embedding
- we'll add dense to that dense token embedding
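
A small sketch of that input path, assuming a plain `nn.Embedding` as the token-embedding table (the name `wte` and the toy sizes are made up). The dense payload is only added at positions whose ID is the dense token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
VOCAB, D = 50_001, 2   # toy sizes; 50000 plays the role of the dense token ID
DENSE_ID = 50_000

wte = nn.Embedding(VOCAB, D)  # hypothetical token-embedding table

input_ids = torch.tensor([[1, 2, 4, 5, DENSE_ID]])
dense = torch.zeros(1, 5, D)
dense[0, 4] = torch.tensor([0.1, 0.2])  # only the dense position carries a payload

dense_mask = (input_ids == DENSE_ID)[:, :, None]  # (B, T, 1)
x = wte(input_ids) + dense * dense_mask           # (B, T, D) input to the first block
print(x.shape)  # torch.Size([1, 5, 2])
```

Masking the payload means any junk stored at non-dense positions can never leak into ordinary token embeddings.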

For output:
- outputids = [[], [], [], [], []]
- dense = outputidsA iff outputidsA is a dense token by softmax?

iff can't be optimized though... 



So somehow the model needs to know that if it outputs softmax(50000), then it'll get to output the entire dense token and softmax won't kill its result...

Mask seems like the right idea...

# dense = softmax(logits) * logits

#LM head multiplies out to size 50000, which we don't actually want....
# we want the little 768 vector to comeout the other side.

So really, we want to at inference time have the model decide if it wants to keep the dense token or not.
We can have it output every single dense token, then keep only the ones designated by the softmax.
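
Something like the following could do that selection, assuming the final hidden state at each position is the candidate dense vector. The shapes, `DENSE_ID` and the hand-set logits are made up, and greedy argmax stands in for whatever sampling strategy is actually used.

```python
import torch

torch.manual_seed(0)
B, T, VOCAB, D = 1, 4, 12, 3
DENSE_ID = 11  # hypothetical ID of <|dense|> in this toy vocab

hidden = torch.randn(B, T, D)   # one candidate dense vector per position
logits = torch.zeros(B, T, VOCAB)
logits[0, 0, 3] = 10.0          # position 0 picks ordinary token 3
logits[0, 1, DENSE_ID] = 10.0   # position 1 picks the dense token
logits[0, 2, 7] = 10.0          # position 2 picks ordinary token 7
logits[0, 3, DENSE_ID] = 10.0   # position 3 picks the dense token

picked = logits.argmax(dim=-1)            # greedy choice; sampling would also work
keep = (picked == DENSE_ID)[:, :, None]   # (B, T, 1) hard mask
dense_out = hidden * keep                 # zero everywhere except dense positions
print(picked.tolist())  # [[3, 11, 7, 11]]
```

The hard mask itself carries no gradient, but gradients still flow through the hidden values that are kept, so the content of a dense token can be trained even though the keep/drop decision cannot.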


----

Ok, so we only want to keep the dense token if our sampling strategy chooses the dense token right?

Thinking:
- If we train the model with cross-entropy loss on the output token, the model will learn to output the dense token ID
 Simultaneously:
- We only pass the dense token into the model for times when dense token is sampled
  - Given we use teacher forcing in training, that will be always, and for inference it might be fancy or some shit



In [5]:
import torch
dense_token_id = tokenizer.encode('<|dense|>')[0]

example = "This is my training example,<|dense|> and I want to see how it works."
tokenized = tokenizer.encode(example, return_tensors='pt')[0]

dense_mask = (tokenized == dense_token_id).long()

dense = torch.rand((dense_mask.shape[0], model.config.n_embd))
dense = dense * dense_mask[:, None]


In [22]:

example = "This is my training example,<|dense|> and I want to see how it works."

weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
learning_rate = 6e-4 # max learning rate
device_type = 'cpu'
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)


tokenized = tokenizer(example, return_tensors='pt')
length = tokenized['input_ids'].shape[1]
X = tokenized['input_ids'][:, :-1]
Y = tokenized['input_ids'][:, 1:]
print(X)
dense_mask = (X == dense_token_id).long() # B, T
print(dense_mask)


for i in range(10):
    
    # B, T, n_embd
    init_dense = torch.zeros((X.shape[0], X.shape[1], model.config.n_embd))
    
    # calculate first pass through model to get dense layers.
    _, dense, _ = model(X, init_dense, Y)
    
    # apply mask    
    dense = dense * dense_mask[:, :, None] # B, T, n_embd
    
    # calculate loss
    _, _, loss = model(X, dense, Y)
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss)



using fused AdamW: False
tensor([[ 1552,   310,   619,  3733,  1650,    13, 50277,   285,   309,   971,
           281,   923,   849,   352,  2987]])
tensor([[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([ 7.6724e-01, -4.4579e+00,  1.9110e-01,  5.4319e-01, -7.6342e+00,
        -1.9871e+00, -6.2796e-01,  4.0221e+00,  4.0577e+00, -2.3395e-01,
         5.1495e+00, -2.8622e+00, -5.8175e+00, -3.3946e+00,  2.3529e+01,
        -1.2671e+00, -1.5010e+01, -1.3828e+01, -1.9582e+00, -6.4407e-01,
         6.0676e+00,  2.2795e+01,  1.5882e+00,  1.5995e+00,  3.7940e+00,
        -3.3146e+00, -3.2365e+01,  3.5258e+00,  6.6679e+00,  4.2526e-01,
         1.1754e+01,  8.3165e+00, -6.6724e+00,  1.0731e+01,  4.9631e+00,
        -2.0929e+00,  9.3262e+00, -5.8444e+00, -1.4233e+00,  1.7206e+00,
        -2.6799e+00, -7.0065e-02, -9.5309e+00, -1.8553e+00,  9.5008e+00,
         4.8936e+00,  1.4466e+01,  3.9460e+00, -3.4502e+00,  4.0477e+00,
        -2.1790e+00, -4.1906e-01, -9.7150e+00,  2.3500e+00, -3.8

In [29]:
inputs = tokenizer.encode("This is my training,", return_tensors='pt')
dense = torch.zeros((1, inputs.shape[1], model.config.n_embd))

logits, dense, loss = model(inputs, dense, inputs)


import torch.nn.functional as F
logits = logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx_next


tensor([[50277]])

# Question to answer
How do I know it's working? Is there a minimally reproducible result??

So I want to determine how I would know if this helped.

Thought 1:
- I need to figure out how to make loss scores comparable...
- I want to be able to run a validation set through the model, and know how well it worked.
- Unclear what sort of toy problem I could apply this all to such that I know that answer.
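
One possible way to make the loss scores comparable, offered as an assumption rather than a settled choice: compute perplexity only over ordinary target tokens and skip positions where the target is the dense token, so both models are scored on the same set of real tokens. The helper name and toy values are made up.

```python
import math
import torch
import torch.nn.functional as F

def perplexity_ignoring(logits, targets, ignore_id):
    """exp(mean cross-entropy) over target positions whose ID is not ignore_id."""
    loss = F.cross_entropy(
        logits.view(-1, logits.shape[-1]),
        targets.view(-1),
        ignore_index=ignore_id,  # these positions do not count towards the mean
    )
    return math.exp(loss.item())

VOCAB, DENSE_ID = 8, 7
targets = torch.tensor([[1, 2, DENSE_ID, 3]])
logits = torch.zeros(1, 4, VOCAB)  # uniform predictions over 8 tokens
ppl = perplexity_ignoring(logits, targets, DENSE_ID)
print(ppl)  # close to 8, the perplexity of a uniform guess over 8 tokens
```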

---

I mean, the goal is still obviously to autoregressively predict the next word. In that case, perplexity should still be a comparable metric, as the tokenization is the same, right?

Well, obviously at the outset the dense tokens will be confusing, meaning P(Y|X) will be lower as the input seems out of distribution. But after some training, the perplexities should be comparable, right?

---


What if I did something like this:
- Take a dataset of Q&A pairs, especially one where Chain Of Thought reasoning seems to perform well.
- Train two models:
  - Model A is vanilla GPT, with sequences like `Q: What is the meaning of life? A: 42`
  - Model B is GPT with dense tokens, with sequences like `Q: What is the meaning of life? <|dense|><|dense|> A: 42`
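
A tiny sketch of building both formats from the same Q&A pair, so the only difference between the two training sets is the dense tokens. The helper `format_example` is hypothetical.

```python
DENSE = "<|dense|>"

def format_example(question, answer, n_dense=0):
    """Build a training string; n_dense=0 gives the vanilla (Model A) format."""
    thinking = DENSE * n_dense
    middle = f" {thinking} " if thinking else " "
    return f"Q: {question}{middle}A: {answer}"

model_a = format_example("What is the meaning of life?", "42")
model_b = format_example("What is the meaning of life?", "42", n_dense=2)
print(model_a)  # Q: What is the meaning of life? A: 42
print(model_b)  # Q: What is the meaning of life? <|dense|><|dense|> A: 42
```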

Is that a fair comparison? Dense model is probably going to use more compute, but that's ok potentially if it's better by some metric.

Let's use the Hellaswag dataset for this.

# Aim is to make the model smarter by teaching it to think about the context of the question more.

We're going to get setup on lambda labs real quick, where I'm going to train the model to predict (1,2,3,4)

We're going to see if a double pass through the network allows it to 'think' more, and make better predictions of 1,2,3,4.
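
A sketch of how accuracy on the (1,2,3,4) prediction could be scored, assuming we only compare the logits of the four answer tokens at the answer position. The token IDs in `choice_ids` are made up, not real tokenizer IDs, and `choice_accuracy` is a hypothetical helper.

```python
import torch

def choice_accuracy(last_logits, choice_ids, labels):
    """Accuracy when the prediction is the highest-scoring allowed answer token.

    last_logits: (B, vocab) logits at the answer position.
    choice_ids:  token IDs for the answers "1".."4".
    labels:      (B,) index 0..3 of the correct answer.
    """
    choice_logits = last_logits[:, choice_ids]  # (B, 4)
    preds = choice_logits.argmax(dim=-1)
    return (preds == labels).float().mean().item()

choice_ids = torch.tensor([5, 6, 7, 8])  # made-up IDs
last_logits = torch.zeros(4, 10)
last_logits[0, 5] = 1.0  # predicts answer 1 (index 0)
last_logits[1, 7] = 1.0  # predicts answer 3 (index 2)
last_logits[2, 9] = 5.0  # highest logit is outside the choices, so it is ignored
last_logits[2, 6] = 1.0  # predicts answer 2 (index 1)
last_logits[3, 8] = 1.0  # predicts answer 4 (index 3)
labels = torch.tensor([0, 2, 0, 3])

acc = choice_accuracy(last_logits, choice_ids, labels)
print(acc)  # 0.75
```

Restricting to the four answer tokens means the score cannot drop below what the model's ranking of the choices deserves just because it put mass on unrelated tokens; random guessing sits at 25%.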

---

Question: does this best test the model's ability here? Is it proof?
Well, surely.

Hypothesis: the model will be able to learn to predict 1,2,3,4 better if it's given the opportunity to think about the context of the question more.

Possible failure points:
- The model won't get enough time to learn how dense tokens work, and will just get confused by them.
- The model won't use the dense tokens at all.
- Training is unstable with dense tokens, and we can't get a good test.

Hellaswag turned out to be too difficult for the model to learn. It never breached 25% (AKA random chance) accuracy, so I've switched to OpenBookQA.

We were able to get pythia-70m to about 30% accuracy.
I'm going to switch to pythia-160m now though, as it has 4x as many non-embedding parameters, which should in theory be a big improvement in performance on these questions, which are partly logic and partly memory.

I want a more reliable >30% result before I start adding dense tokens.

---
Still unable to get decent results. Going to reduce the batch size to 1 and see if that helps, as maybe my broken padding stuff is a problem?


---
Ok, still no luck. Maybe a maths dataset instead?

GSM8K looks hard but achievable...