# In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt


# Default number of positions in every model input.
MAX_SEQUENCE_LENGTH = 512

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def debug_print(*args, flag=False, **kwargs):
    """Forward ``args``/``kwargs`` to ``print`` only when ``flag`` is truthy; otherwise do nothing."""
    if not flag:
        return
    print(*args, **kwargs)

# Based on PyTorch's sinusoidal positional encoding implementation, extended with a learned per-thought position embedding.
class DualPositionalEncoding(nn.Module):
    """Sinusoidal token-position encoding plus a learned per-thought-slot embedding.

    The input layout is read as ``real_token_count`` consecutive groups of
    ``thoughts_taken + 1`` slots. Every slot in group ``i`` receives the sinusoidal
    encoding of token position ``i``. Slot ``j`` within a group additionally
    receives the learned embedding ``thought_position_encoding(j)``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 512, max_thought_len: int = 4):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_embed = d_model

        # Standard sinusoidal table: even channels get sin, odd channels get cos.
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd channel than even channel.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe)
        self.max_len, self.max_thought_len = max_len, max_thought_len
        # One learned vector per slot index 0 .. max_thought_len inside a token's group.
        self.thought_position_encoding = nn.Embedding(max_thought_len + 1, self.d_embed)

    def forward(self, x: torch.Tensor, thoughts_taken: int, real_token_count: int) -> torch.Tensor:
        """Add token- and slot-position encodings to the used part of ``x``.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``. Only the first
                ``real_token_count * (thoughts_taken + 1)`` positions are read.
            thoughts_taken: thoughts already appended per token (slots per group - 1).
            real_token_count: number of token groups in the layout.

        Returns:
            Tensor ``[batch_size, max_len, embedding_dim]``. Positions past the used
            layout are zero-padded (then passed through dropout).
        """
        batch_size = x.size(0)
        slots = thoughts_taken + 1
        used = real_token_count * slots

        # [B, used, D] -> [B, tokens, slots, D] so both encodings broadcast per axis.
        grouped = x[:, :used, :].reshape(batch_size, real_token_count, slots, self.d_embed)
        grouped = grouped + self.pe[:, :real_token_count].unsqueeze(2)
        # Build slot indices on x's device so this also works when the module is on GPU.
        slot_ids = torch.arange(slots, device=x.device)
        grouped = grouped + self.thought_position_encoding(slot_ids).unsqueeze(0).unsqueeze(0)

        # Flatten back and zero-pad up to max_len (not to the input's seq_len).
        out = grouped.reshape(batch_size, -1, self.d_embed)
        padding_size = self.max_len - used
        if padding_size > 0:
            out = F.pad(out, (0, 0, 0, padding_size), mode='constant', value=0)

        return self.dropout(out)


class ThoughtCausalTransformer(nn.Module):
    """Transformer encoder applied to the interleaved token/thought layout.

    Adds ``DualPositionalEncoding`` to the input and restricts attention with the
    mask built by ``generate_thoughtsformer_mask``.
    """

    def __init__(self, max_thought_len, max_sequence_length, num_layers, n_head=8, d_embed=768, feed_forward_dim=2048, dropout=0.1):
        super().__init__()
        self.max_sequence_length = max_sequence_length

        self.positional_encoding = DualPositionalEncoding(d_embed, dropout, max_sequence_length, max_thought_len)
        # NOTE: nn.TransformerEncoder deep-copies `encoder_layer` for each of its
        # num_layers layers, so the parameters of self.layer itself are registered
        # but not used in forward. Kept so existing state_dict keys still load.
        self.layer = nn.TransformerEncoderLayer(
            d_model=d_embed,
            nhead=n_head,
            dim_feedforward=feed_forward_dim,
            dropout=dropout,
            activation=F.gelu,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=self.layer,
            num_layers=num_layers,
            norm=nn.LayerNorm(d_embed),
        )

    def generate_thoughtsformer_mask(self, thoughts_taken, real_tokens, device=None):
        """Build the boolean attention mask for the thought layout (True = blocked).

        The first ``real_tokens * (thoughts_taken + 1)`` positions form
        ``real_tokens`` groups of ``thoughts_taken + 1`` slots. A query inside that
        region may attend to:
          * the first slot of every group at or before it, and
          * slots of its own group at or before it.
        Every query or key outside the region is blocked.

        Args:
            thoughts_taken: thoughts already appended per token.
            real_tokens: number of groups.
            device: device to build the mask on (default: CPU, as before).

        Returns:
            Bool tensor ``[max_sequence_length, max_sequence_length]`` indexed
            ``[query, key]``.
        """
        size = self.max_sequence_length
        block_size = thoughts_taken + 1
        used = int(real_tokens) * block_size

        positions = torch.arange(size, device=device)
        query = positions.unsqueeze(1)
        key = positions.unsqueeze(0)

        # Causal, and only inside the used region (key <= query < used implies key < used).
        causal = (key <= query) & (query < used)
        same_group = (query // block_size) == (key // block_size)
        key_is_group_start = (key % block_size) == 0

        allowed = causal & (same_group | key_is_group_start)
        return ~allowed

    def generate_normal_causal_mask(self, *args):
        """Plain causal mask: True strictly above the diagonal blocks future keys.

        Returned as bool. A float 0/1 tensor would be *added* to the attention scores
        by PyTorch rather than masking them.
        """
        size = self.max_sequence_length
        return torch.triu(torch.ones(size, size), diagonal=1).bool()

    def forward(self, x, padding_mask, thoughts_taken, real_token_count):
        """Encode ``x`` under the thought-aware causal mask.

        Args:
            x: ``[batch, seq_len, d_embed]`` embeddings in the interleaved layout.
            padding_mask: forwarded as ``src_key_padding_mask`` (True = ignore key).
            thoughts_taken: thoughts already appended per token.
            real_token_count: number of real tokens (groups) in the layout.

        Returns:
            Encoder output ``[batch, max_sequence_length, d_embed]``.
        """
        x = self.positional_encoding(x, thoughts_taken, real_token_count)
        # Build the mask directly on x's device instead of building on CPU and copying.
        causal_mask = self.generate_thoughtsformer_mask(thoughts_taken, real_token_count, device=x.device)
        # NOTE(review): query rows outside the used layout are fully blocked. On some
        # PyTorch versions fully-masked rows yield NaN, which can leak into later
        # layers — verify on the target version.
        return self.transformer(x, mask=causal_mask, src_key_padding_mask=padding_mask)


class ThoughtsFormer(nn.Module):
    """Model that lets every token "think" for extra slots before predicting.

    Starting from one slot per token, each pass runs ``ThoughtCausalTransformer``
    over the current layout. Except on the final pass, the output at each token's
    most recent slot is appended as that token's next thought. Logits are taken
    from the final pass.

    NOTE(review): ``token_positions`` and ``n_real_tokens`` are pooled over the
    whole batch, and the layout assumes real tokens come first. This effectively
    requires batch size 1 — confirm before batching.
    """

    def __init__(self, thought_length, vocab_size, max_sequence_length, num_layers, n_head=8, d_embed=768, feed_forward_dim=2048, dropout=0.1, verbose=False):
        super().__init__()

        self.seq_len, self.d_embed = max_sequence_length, d_embed
        self.thought_length = thought_length
        self.transformer = ThoughtCausalTransformer(thought_length, max_sequence_length, num_layers, n_head, d_embed, feed_forward_dim, dropout)
        self.out = nn.Linear(d_embed, vocab_size)
        self.debug = verbose

    def forward(self, embeddings, padding_mask):
        """Run ``thought_length + 1`` passes and return logits for every position.

        Args:
            embeddings: token embeddings ``[batch, seq_len, d_embed]`` (real tokens
                assumed to come first — confirm at call sites).
            padding_mask: ``[batch, seq_len]`` token-level mask, True marks padding.

        Returns:
            Logits ``[batch, max_sequence_length, vocab_size]`` computed from the
            output of the last transformer pass, in the final interleaved layout.
        """
        self._init_token_state(padding_mask)
        debug_print("original embeddings \n", embeddings, flag=self.debug)

        for thoughts_taken in range(self.thought_length + 1):
            # The caller's mask describes the token-only layout. Once thoughts are
            # interleaved, real content extends to n_real_tokens * (thoughts_taken + 1).
            layout_mask = self._layout_padding_mask(padding_mask, thoughts_taken)
            next_embeddings = self.transformer(embeddings, layout_mask, thoughts_taken, self.n_real_tokens)
            debug_print(next_embeddings, next_embeddings.shape, flag=self.debug)
            # No thought needs inserting after the last pass.
            if thoughts_taken != self.thought_length:
                embeddings = self.insert_thoughts(next_embeddings, embeddings, thoughts_taken + 1)
            debug_print("updated embeddings\n", embeddings, flag=self.debug)

        # Project the final pass's *output*. Projecting its input would discard the
        # last transformer call entirely.
        return self.out(next_embeddings)

    def forward_ppo(self, embeddings, padding_mask, thoughts_taken):
        """Run a single thought step on an existing layout.

        Args:
            embeddings: current layout with ``thoughts_taken`` thoughts per token.
            padding_mask: the token-level padding mask (True = padding).
            thoughts_taken: thoughts already appended per token.

        Returns:
            Embeddings (not logits) with one more thought appended per token. When
            ``thoughts_taken == thought_length``, the input embeddings are returned
            unchanged.
        """
        self._init_token_state(padding_mask)
        layout_mask = self._layout_padding_mask(padding_mask, thoughts_taken)

        next_embeddings = self.transformer(embeddings, layout_mask, thoughts_taken, self.n_real_tokens)

        # No thought needs inserting after the last step.
        if thoughts_taken != self.thought_length:
            embeddings = self.insert_thoughts(next_embeddings, embeddings, thoughts_taken + 1)

        # NOTE(review): on the final step the transformer output above is discarded —
        # confirm this is what the PPO caller expects.
        return embeddings

    def _init_token_state(self, padding_mask):
        """Record real-token positions and count from the token-level padding mask."""
        self.token_positions = torch.where(padding_mask.logical_not())[1]
        self.n_real_tokens = int(self.seq_len - torch.sum(padding_mask))

    def _layout_padding_mask(self, padding_mask, thoughts_taken):
        """Key padding mask for the interleaved layout: True past the used positions."""
        used = self.n_real_tokens * (thoughts_taken + 1)
        positions = torch.arange(padding_mask.size(-1), device=padding_mask.device)
        return (positions >= used).expand_as(padding_mask)

    def insert_thoughts(self, next_embeddings, original_embeddings, iter):
        """Append one new thought after each token's current group of slots.

        Args:
            next_embeddings: transformer output for the current layout.
            original_embeddings: current layout ``[batch, seq_len, d_embed]`` whose
                first ``n_tokens * iter`` positions are ``n_tokens`` groups of
                ``iter`` slots.
            iter: slots per token in the current layout (thoughts taken + 1).

        Returns:
            New layout of ``n_tokens`` groups of ``iter + 1`` slots, zero-padded to
            ``seq_len``. Also advances ``self.token_positions`` as a side effect.
        """
        n_tokens = self.token_positions.size(0)
        debug_print("Debugging here ", self.n_real_tokens, n_tokens, flag=self.debug)
        n_elements = n_tokens * iter
        n_element_next = n_tokens * (iter + 1)
        batch_size, seq_len, d_embed = original_embeddings.shape
        # Reshape and concat, going from
        # 1, t            # 1, t, t
        # 2, t    --->    # 2, t, t
        # 3, t.flatten()  # 3, t, t.flatten()
        groups = original_embeddings[:, :n_elements, :].reshape(batch_size, -1, iter, d_embed)

        # Last slot of every group, i.e. the position right before the next token.
        last_slot = (torch.arange(n_tokens, device=next_embeddings.device) + 1) * iter - 1
        new_thoughts = next_embeddings[:, last_slot, :].unsqueeze(2)

        # Concatenate the new thought as an extra slot, then flatten back.
        final_embeds = torch.cat((groups, new_thoughts), dim=2)
        final_embeds = final_embeds.reshape(batch_size, -1, d_embed)
        debug_print("final embedding shape", flag=self.debug)
        debug_print(final_embeds.shape, seq_len, n_element_next, flag=self.debug)
        # new_zeros keeps the padding on the same device/dtype as the embeddings.
        padding = final_embeds.new_zeros(batch_size, seq_len - n_element_next, d_embed)
        final_embeds = torch.cat((final_embeds, padding), dim=1)

        self.token_positions = self.get_next_token_count(self.token_positions)

        return final_embeds

    def get_next_token_count(self, token_positions):
        '''
        Return where each real token sits after one more thought is appended per token.

        Assumes each thought train will have the same length: token ``i`` has ``i``
        earlier tokens that each gain one slot, so it shifts right by ``i``.
        '''
        offsets = torch.arange(token_positions.shape[0], device=token_positions.device)
        return token_positions + offsets