# Ok, so we're going to implement the episodic learning algorithm I devised...

High level steps:
1. Sample 2N sequences from the dataset
2. Generate 2N * X insertions of thought tokens into each sequence.
3. Compute the difference in log likelihood of the sequence with and without the thought tokens
4. Remove the worst N sequences

Q: Could we do this on a per-token basis instead of a per-sequence basis?
i.e. could we insert groups of thought tokens and then remove the worst ones from the sequence as we go?




Some thoughts:
1. Say we're going to *strictly* insert 32 Thought Tokens per sequence.
2. First, we run forward pass, then the 32 locations with highest likelihood of predicting a thought token are selected.
3.  - In theory, one token being inserted will affect the likelihood of other tokens being inserted. This would influence doubles or triples for example...
    - Another method would be to 'run' through the sequence, inserting thought tokens one at a time as if the model were properly generating them as it went.
        - This is probably the *correct* way I should do things...


Idea:
1. Run 32 sequences through model
2. Find 32 first-candidates for a thought token
3. Insert thought tokens and run through model again
4. Compute the difference in log likelihood of the sequence with and without the thought tokens
5. Remove the worst 16 thought-tokens that were inserted
6. Repeat
---
So, for 16 forward passes, we'll have 16*16=256 thought tokens inserted in total, average 8 per sequence....

This phase we can do very quickly. If we're not masking, and not computing gradients, we can hit ~200k tokens per second processed, so we can do about 15 forward passes in a single second in theory.

Then, the actual foward-backward passes...
- We can use the pre-computed thought tokens as a strong baseline instead of an empty initial state.
- Choose depth to unroll, run forward pass, compute loss, compute gradients, update parameters.

---

Ok, this flow actually seems pretty good to me:
1. We're not throwing away sequences, only individual thought tokens, which means we're hopefully not gonna delete 'hard to learn' sequences.
2. By doing 'per-token', we're "reinforcing" good thought token predictions and punishing bad ones.
3. By running-through entire sequence for insertion, we're actually taking into account each token that's added.
4. Finally, by using a strong initialization for our actual forward-backward pass ones, while gradients aren't flowing super deeply backwards we're getting some of the benefit of many iterated thought tokens.


The two hard bits:
1. The 'fastfoward' generation algorithm, where we skip through but insert thought-tokens along the way. Gotta keep that performant but that's not going to be easy.
2. Choosing where to insert a thought token if the model really doesn't want to insert one. Ideally, we start by choosing the first location where a thought token is actually predicted as #1, but if that doesn't occur we still need to choose a thought token for somewhere....


In [None]:
# First, 2

In [53]:
import torch
from core.model import THOUGHT_TOKEN_ID
# choose a location to insert a thought token
THOUGHT_TOKEN_ID = 4

# hopefully with nograd this won't be too slow or memory intensive
@torch.no_grad()
def insert_a_thought_token(logits: torch.Tensor, tokens: torch.Tensor):
    sorted_probs, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    ranks = torch.where(sorted_indices == THOUGHT_TOKEN_ID)[1]
    # insert a thought token at the location of the lowest rank.
    index = torch.argmin(ranks)

    tokens_out = torch.cat([tokens[:, :index], THOUGHT_TOKEN_ID, tokens[:, index:]], dim=1)
    
    return tokens_out, index

In [69]:
@torch.no_grad()
def insert_a_thought_token(logits: torch.Tensor, tokens: torch.Tensor):
    batch_size, seq_len = tokens.size()
    sorted_probs, sorted_indices = torch.sort(logits, descending=True, dim=-1)

    # find the ranks of THOUGHT_TOKEN_ID for each element in the batch
    ranks = (sorted_indices == THOUGHT_TOKEN_ID).nonzero(as_tuple=True)[1]

    # insert a thought token at the location of the lowest rank.
    min_ranks, min_indices = torch.min(ranks, dim=-1)

    # build a tensor of THOUGHT_TOKEN_ID repeated along batch_size
    tokens_out = []
    for b in range(batch_size):
        index = min_indices[b]
        tokens_out.append(torch.cat([tokens[b, :index], THOUGHT_TOKEN_ID, tokens[b, index:]], dim=0))
    
    # pad sequences to the same length
    tokens_out = torch.nn.utils.rnn.pad_sequence(tokens_out, batch_first=True)
    
    return tokens_out, min_indices


In [68]:
tokens_original = torch.tensor([[0, 1, 3, 2, 3, 1, 0, 1]])
logits = torch.randn(1, 8, 5)

tokens_trial, index = insert_a_thought_token(logits, tokens_original)

tensor(0)


IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number

In [None]:
tokens_trial, index

(tensor([[4, 0, 1, 3, 2, 3, 1, 0, 1]]), tensor(0))

In [None]:
# the flow



    tokens_original = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

    logits_original, losses = model(tokens_original)

    episode_trial_tokens, episode_token_insertion_index = insert_a_thought_token(logits_original, tokens_original)

    logits_episode_trial, losses_trial = model(episode_trial_tokens)

    completion_perplexity_original = torch.exp(-torch.mean(losses_trial[episode_token_insertion_index:]))
    completion_perplexity_trial = torch.exp(-torch.mean(losses_trial[episode_token_insertion_index + 1:]))

    diff = completion_perplexity_trial - completion_perplexity_original




