https://github.com/hkproj/mistral-llm-notes
    
    
https://github.com/neobundy/Deep-Dive-Into-AI-With-MLX-PyTorch/blob/master/deep-dives/001-mistral-7b/README.md


https://github.com/DongmingShenDS/Mistral_From_Scratch


In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login widget for this notebook session.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Sliding Window attention - Receptive Field

Initially, the sequence is a list of sets, all containing a single token.

Layer 1 input:

    0: ['the']
    1: ['cat']
    2: ['is']
    3: ['on']
    4: ['a']
    5: ['chair']

After the first layer, considering a sliding window size of 3, the output of the attention mechanism is:

Layer 1 output:

    0: ['the']
    1: ['the', 'cat']
    2: ['the', 'cat', 'is']
    3: ['cat', 'is', 'on']
    4: ['is', 'on', 'a']
    5: ['on', 'a', 'chair']

The output of the first layer becomes the input of the second layer. The output of the second layer is:

Layer 2 output:

    0: ['the']
    1: ['the', 'cat']
    2: ['the', 'cat', 'is']
    3: ['the', 'cat', 'is', 'on']
    4: ['the', 'cat', 'is', 'on', 'a']
    5: ['cat', 'is', 'on', 'a', 'chair']

As we can see, even with a sliding window of size 3, after just two layers the last token's output already depends on 5 tokens. With a window of size W, each additional layer extends the receptive field by W − 1 tokens, so k layers cover up to k·(W − 1) + 1 tokens. This happens because the output of the first layer is used as the input of the second layer, and the attention mechanism is applied again, so information can travel beyond a single window. This is similar to stacking multiple CNN layers to increase the receptive field.

In [None]:
# Token order of the example sentence; the print helpers also use it as the sort key for token sets.
print_order = ['the', 'cat', 'is', 'on', 'a', 'chair']
# Layer-1 input: each position holds a singleton set containing its own token.
sequence = [{token} for token in print_order]
sequence

[{'the'}, {'cat'}, {'is'}, {'on'}, {'a'}, {'chair'}]

In [None]:
# Window size W (read as a global by print_layer): each query position i sees key positions j with i - W < j <= i.
sliding_window_size = 3

def sliding_window_attention(seq: list[set[str]], w: int):
    """Build the toy attention matrix for a sequence of token sets.

    Entry ``[i][j]`` is the union of the token sets at query position ``i``
    and key position ``j`` when ``j`` lies in the causal window
    ``i - w < j <= i``; every other entry stays ``None``.

    Args:
        seq: one set of tokens per position.
        w: sliding-window size (a query sees itself and the ``w - 1`` previous keys).

    Returns:
        A ``len(seq) x len(seq)`` nested list holding sets or ``None``.
    """
    size = len(seq)
    scores: list[list[set]] = [[None] * size for _ in range(size)]
    for row in range(size):
        # Only visit keys inside the window: from the earliest visible key
        # (clamped at the sequence start) up to the query itself.
        for col in range(max(0, row - w + 1), row + 1):
            scores[row][col] = set(seq[row]).union(seq[col])
    return scores

def multiple_by_v(attention_scores: list[list[set]], v_sequence: list[set[str]]) -> list[set[str]]:
    """Mix value sets through the toy attention matrix.

    Output position ``i`` is the union, over every key ``j`` whose score
    ``attention_scores[i][j]`` is not ``None``, of ``v_sequence[j]`` and that
    score set. A position with no visible score produces an empty set.
    """
    n = len(v_sequence)
    mixed_rows = []
    for i in range(n):
        mixed = set()
        for j in range(n):
            score = attention_scores[i][j]
            if score is not None:
                # Fold in the value tokens first, then the score tokens.
                mixed.update(v_sequence[j], score)
        mixed_rows.append(mixed)
    return mixed_rows

def print_attention(attention_scores: list[list[set[str]]]):
    """Print the attention matrix row by row as tab-separated cells.

    ``None`` cells print as ``None``; set cells print as a list sorted by the
    global ``print_order``.
    """
    for row in attention_scores:
        for cell in row:
            if cell is not None:
                cell_text = f'{sorted(cell, key=lambda token: print_order.index(token))}'
            else:
                cell_text = 'None'
            print(cell_text, end='\t')
        print()

def print_sequence(seq: list[set[str]]):
    """Print one line per position: its index, then its tokens sorted by the global ``print_order``."""
    for position, tokens in enumerate(seq):
        ordered = sorted(tokens, key=lambda token: print_order.index(token))
        print(f'{position}: {ordered}')

def print_layer(input: list[set[str]], layer_num: int) -> list[set[str]]:
    """Run one toy sliding-window attention layer and print every stage.

    Prints the layer input, the attention matrix built with the global
    ``sliding_window_size``, and the mixed output. Returns the output so it
    can be fed into the next layer.
    """
    title = f'Layer {layer_num}'

    print(f'{title} input:')
    print_sequence(input)
    scores = sliding_window_attention(input, sliding_window_size)

    print()
    print(f'{title} attention scores:')
    print_attention(scores)
    layer_output = multiple_by_v(scores, input)

    print()
    print(f'{title} output:')
    print_sequence(layer_output)
    return layer_output

In [None]:
# Layer 1: push the singleton token sets through one sliding-window attention pass.
output_layer_1 = print_layer(input=sequence, layer_num=1)

Layer 1 input:
0: ['the']
1: ['cat']
2: ['is']
3: ['on']
4: ['a']
5: ['chair']

Layer 1 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['cat']	None	None	None	None	
['the', 'is']	['cat', 'is']	['is']	None	None	None	
None	['cat', 'on']	['is', 'on']	['on']	None	None	
None	None	['is', 'a']	['on', 'a']	['a']	None	
None	None	None	['on', 'chair']	['a', 'chair']	['chair']	

Layer 1 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['cat', 'is', 'on']
4: ['is', 'on', 'a']
5: ['on', 'a', 'chair']


In [None]:
# Layer 2: the output of layer 1 becomes this layer's input.
output_layer_2 = print_layer(input=output_layer_1, layer_num=2)

Layer 2 input:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['cat', 'is', 'on']
4: ['is', 'on', 'a']
5: ['on', 'a', 'chair']

Layer 2 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['cat', 'is', 'on', 'a']	['is', 'on', 'a']	None	
None	None	None	['cat', 'is', 'on', 'a', 'chair']	['is', 'on', 'a', 'chair']	['on', 'a', 'chair']	

Layer 2 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['cat', 'is', 'on', 'a', 'chair']


In [None]:
# Layer 3: the output of layer 2 becomes this layer's input.
output_layer_3 = print_layer(input=output_layer_2, layer_num=3)

Layer 3 input:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['cat', 'is', 'on', 'a', 'chair']

Layer 3 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	None	
None	None	None	['the', 'cat', 'is', 'on', 'a', 'chair']	['the', 'cat', 'is', 'on', 'a', 'chair']	['cat', 'is', 'on', 'a', 'chair']	

Layer 3 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['the', 'cat', 'is', 'on', 'a', 'chair']


In [None]:
# Install xformers, which provides the attention-bias mask classes imported in the next cell.
!pip install -q xformers

In [None]:
from xformers.ops.fmha.attn_bias import (
    AttentionBias,
    BlockDiagonalCausalMask,
    BlockDiagonalCausalWithOffsetPaddedKeysMask,
    BlockDiagonalMask,
)

import pandas as pd

  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")


In [None]:
# Background colour per mask value: 0.0 (not masked) and -inf (masked).
col_dict = {0.0: '#90EE90', float('-inf'): '#FA8072'}
def colour_cell(val):
    """Return the CSS background rule for a mask value, or '' when the value has no colour."""
    colour = col_dict.get(val)
    return '' if colour is None else f'Background-color: {colour}'


# Example sentences; their word counts must match seqlens = [7, 5, 6] used for the block-diagonal masks,
# so that `labels` has exactly one entry per token position (18 in total).
sentences = [
    "The cat sat on the mat, purring.",  # 7 words
    "The dog ran fast today.",           # 5 words
    "The quick brown fox jumps over."    # 6 words
]

def get_flattened_words(sentences):
    """Return all whitespace-separated words of ``sentences``, concatenated in order."""
    return [word for sentence in sentences for word in sentence.split()]

# flattened list of words (one label per token position)
labels = get_flattened_words(sentences)

In [None]:
## BlockDiagonalCausalMask

seqlens = [7, 5, 6]
sliding_window_size = 3

# Causal block-diagonal mask (one block per sequence in seqlens), restricted to a local window of sliding_window_size.
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(sliding_window_size)

batch_size = 1
total_seq_len = sum(seqlens)
mask_tensor = mask.materialize((batch_size, total_seq_len, total_seq_len))

df = pd.DataFrame(mask_tensor[0, :, :].numpy())
# Styler.applymap is deprecated in newer pandas in favour of Styler.map; use map when it exists.
styler = df.style
(styler.map if hasattr(styler, 'map') else styler.applymap)(colour_cell)

  df.style.applymap(colour_cell)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
1,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
2,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
3,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
4,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
5,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
6,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
7,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
8,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
9,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf


In [None]:
## BlockDiagonalMask

q_seqlens = [3, 5]
kv_seqlens = [10, 8] # (3 + 7, 5 + 3)
sliding_window_size = 3

# Block-diagonal mask with separate query and key lengths per block; the local window
# is anchored at the bottom-right corner of each block.
mask = BlockDiagonalMask.from_seqlens(q_seqlens, kv_seqlens).make_local_attention_from_bottomright(sliding_window_size)

batch_size = 1
total_seq_len = sum(q_seqlens)
total_kv_seq_len = sum(kv_seqlens)
mask_tensor = mask.materialize((batch_size, total_seq_len, total_kv_seq_len))

df = pd.DataFrame(mask_tensor[0, :, :].numpy())
# Styler.applymap is deprecated in newer pandas in favour of Styler.map; use map when it exists.
styler = df.style
(styler.map if hasattr(styler, 'map') else styler.applymap)(colour_cell)

  df.style.applymap(colour_cell)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
1,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
2,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
3,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf,-inf
4,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf,-inf
5,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf,-inf
6,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,-inf
7,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0


In [None]:
## BlockDiagonalCausalWithOffsetPaddedKeysMask

# We use this mask with padding because the overall size of the KV-Cache is the same for all the prompts, but for each KV-Cache we may need to use only some of the items.

q_seqlen = [1, 1]
kv_seq_len = [3, 5]
kv_padding = 6

# Each sequence owns a kv_padding-sized slot of keys; only the first kv_seq_len[i] keys of slot i are visible.
mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(q_seqlen=q_seqlen, kv_padding=kv_padding, kv_seqlen=kv_seq_len)

batch_size = 1
total_seq_len = sum(q_seqlen)
total_kv_seq_len = kv_padding * len(kv_seq_len)

mask_tensor = mask.materialize((batch_size, total_seq_len, total_kv_seq_len))

df = pd.DataFrame(mask_tensor[0, :, :].numpy())
# Styler.applymap is deprecated in newer pandas in favour of Styler.map; use map when it exists.
styler = df.style
(styler.map if hasattr(styler, 'map') else styler.applymap)(colour_cell)

  df.style.applymap(colour_cell)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,0.0,0.0,0.0,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf,-inf
1,-inf,-inf,-inf,-inf,-inf,-inf,0.0,0.0,0.0,0.0,0.0,-inf


```python

llama_model.py --> without use of fairscale parallelism in Attention mechanism

Transformer
==================================
1. nn.Embedding(vocab,dim) : token_embeddings
2. nn.ModuleList(TransformerBlock(args) X n_layers) # different LLM architecture implemented
3. RMSNorm(dim,eps)
4. nn.Linear(dim,vocab) : output
5. precompute_freqs_pos_frequencies(dim//n_heads,max_seq*2,device,rope_theta) : freqs_complex # per attention head
# i.e dim // n_heads = head_dim , for Llama rope_theta is not passed

@property # helpful for training-inference device switch error
dtype() : return data type of parameter of the model
device() : device on which model parameters are stored


forward()
----------------------------------------------
"""
note that with the KV Cache, only need the latest tokens, no need all tokens: info about previous tokens are saved in the cache
NOTE: this is only for inference, not training (in training there's no KV cache)
"""
1. get hidden_state = token_embeddings(tokens) # (B,seq) -> (B,seq,dim) [1 token at a time --> map to dim]
2. retrieve pairs(m,theta) corresponding to position [start_pos, start_pos + seq_len] from freqs_complex
3. hidden_states = layer(hidden_states,start_pos,freqs_complex) for layer in layers
# apply precomputed frequencies to the encoding layers for positional encoding , each layer is (Nx transformer blocks)
4. apply RMSNorm on combined hidden_states from step 3
5. pass through output


TransformerBlock
===============================================================
1. head_dim = dim//n_heads
2. SelfAttention(args) : attention # Decoder only with causal attention (only works for inference)
3. FeedForward(args) : feed_forward
4. RMSNorm(dim,eps) : attention_norm # RMSNorm before attention
5. RMSNorm(dim,eps) : ffn_norm # RMSNorm before feed_forward


forward()
-----------------------------------------------------------------
1. hidden_states = x + attention.forward(attention_norm(x),start_pos,freqs_complex)
# (B, seq_Len, dim) + (B, seq_Len, dim) => (B, seq_Len, dim) we're dealing
# with 1 token at a time start_pos : current token we're dealing
2. out = hidden_states + feed_forward(ffn_norm(hidden_states)) # (B, seq_Len, dim) + (B, seq_Len, dim) => (B, seq_Len, dim)
3. return out

```

In [None]:
####

In [None]:
####


# Transformer

```python

1. nn.Embedding(vocab,dim) : token_embeddings
2. nn.ModuleList(TransformerBlock(args) X n_layers) # different LLM architecture implemented
3. RMSNorm(dim,eps)
4. nn.Linear(dim,vocab) : output
5. precompute_freqs_pos_frequencies(dim//n_heads,max_seq*2,device,rope_theta) : freqs_complex # per attention head
# i.e dim // n_heads = head_dim , for Llama rope_theta is not passed

@property # helpful for training-inference device switch error
dtype() : return data type of parameter of the model
device() : device on which model parameters are stored


forward()
----------------------------------------------
"""
note that with the KV Cache, only need the latest tokens, no need all tokens: info about previous tokens are saved in the cache
NOTE: this is only for inference, not training (in training there's no KV cache)
"""
1. get hidden_state = token_embeddings(tokens) # (B,seq) -> (B,seq,dim) [1 token at a time --> map to dim]
2. retrieve pairs(m,theta) corresponding to position [start_pos, start_pos + seq_len] from freqs_complex
3. hidden_states = layer(hidden_states,start_pos,freqs_complex) for layer in layers
# apply precomputed frequencies to the encoding layers for positional encoding , each layer is (Nx transformer blocks)
4. apply RMSNorm on combined hidden_states from step 3
5. pass through output

```

# TransformerBlock

```python


""" a single transformer block (different for Llama & Mistral) """
1. head_dim = dim//n_heads
2. RMSNorm(dim,eps) : rms_norm # before attention & feed_forward
3. SelfAttention(args) : attention # Decoder only with causal attention (only works for inference)
4. MOE(experts=[FeedForward(args) X moe.n_experts], gate=nn.Linear(dim,moe.n_experts), moe_args=args.moe)
# Feed Forward Layer (with MoE support), otherwise FeedForward(args) : feed_forward


forward()
-----------------------------------------------------------------
1. hidden_states = x + attention.forward(rms_norm(x),start_pos,freqs_complex)
# (B, seq_Len, dim) + (B, seq_Len, dim) => (B, seq_Len, dim) we're dealing
# with 1 token at a time start_pos : current token we're dealing
2. out = hidden_states + feed_forward(rms_norm(hidden_states)) # (B, seq_Len, dim) + (B, seq_Len, dim) => (B, seq_Len, dim)
3. return out


```

# RMSNorm
```python

"""same for both"""
1. nn.Parameter(torch.ones(dim)) : weight
# gamma(g) parameter trainable to perform rescaling on the norm

_norm()
-------------------------------------------------------------------------
# RMSNorm stat : (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
1. return x * 1/rms

forward()
-------------------------------------------------------------------------
# (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
# auto-broadcasting expands (Dim) to (1, 1, Dim) to multiplied to the last dimension of (B, Seq_Len, Dim)
# recall: Automatic broadcasting in PyTorch occurs when dimensions match or are broadcastable starting from the trailing dimensions (i.e., from right to left)
1. weight * _norm(x) # typecast to float & verify if it's same type as x

```

# FeedForward

```python

"""same for both but Llama uses multiple_of parameter for hidden_dim"""
1. hidden_dim = multiple_of * ((int(hidden_dim) + multiple_of - 1) // multiple_of)
# round hidden_dim up to the nearest multiple of the args.multiple_of parameter (bigger or equal)
# just a design choice to look cool :)
2. gate : w1(dim,hidden_dim) , down : w2(hidden_dim,dim), up : w3(dim,hidden_dim)

forward()
---------------------------------------------------------------
1. xw1 = w1(x)
# (SiLU(X W1) * (X W3)) W2; goal shape: (B, seq_len, Dim) => (B, seq_len, Dim)
# (B, seq_len, Dim) w1=> (B, seq_len, Hidden_Dim)
2. sxw1 = silu(xw1)
# (B, seq_len, Hidden_Dim) => (B, seq_len, Hidden_Dim) # silu
3. xv = w3(x)
# (B, seq_len, Dim) w3=> (B, seq_len, Hidden_Dim)
4. sxw1xv = sxw1 * xv
# (B, seq_len, Hidden_Dim) * (B, seq_len, Hidden_Dim) = (B, seq_len, Hidden_Dim) = element wise multiplication
5. return w2(sxw1xv)
# (B, seq_len, Hidden_Dim) w2=> (B, seq_len, Dim)


```

# MOE
```python
"""the only difference in Mixtral MOE: after attention, instead of RMS=>MLP, it has RMS=>MOE"""
1. nn.ModuleList(experts) : experts
2. gate , moe_args

forward()
------------------------------------------------------------------------------------
"""NOTE: in the mistral paper, all input/output size used are (B * seq_len, Dim) instead of (B, seq_len, Dim)"""
# goal shape: (B, seq_len, Dim) => (B, seq_len, Dim)
1. flat_x = reshape input shape to (B * seq_len, Dim)
2. gate_logits = gate(flat_x)  # recall gate is linear with (Dim, n_experts)
# (B * seq_len, Dim) gate=> (B * seq_len, n_experts) for each input token
3. weights, selected_experts = torch.topk(gate_logits, self.moe_args.n_experts_per_tok)
# Get the top k experts for each input token, using torch.topk
# weights=logits, selected_experts=indices
# (B * seq_len, n_experts) => (B * seq_len, n_experts_per_tok)
4. weights = F.softmax(weights, dim=1, dtype=torch.float).to(x.dtype)
# # Normalize the weights with softmax, to get the selected top k experts' weights on the tokens
5. results = torch.zeros_like(flat_x)
# init results: (B * seq_len, Dim)
6. # Iterate over each expert to compute the weighted sum of the outputs from each selected top k experts,
for i, expert in enumerate(self.experts):
    # for each expert: retrieves only the batch_idx & selected_exp_idx this expert is responsible for
    batch_idx, selected_exp_idx = torch.where(selected_experts == i)
    # (K, Dim) => (K, Dim), where K is how many tokens this expert is responsible for
    expert_out = expert(flat_x[batch_idx])  # recall expert is FFN with Dim=>Dim
    # (K, 1), where K is how many tokens this expert is responsible for
    expert_w = weights[batch_idx, selected_exp_idx, None]
    # add the experts' weighted sum output to the corresponding tokens
    # expert_w * expert_out: (K, 1) * (K, Dim) => (K, Dim)
    # results: still (B * seq_len, Dim), where the corresponding tokens are updated
    results[batch_idx] += expert_w * expert_out
7. results = results.view(B, seq_len, dim)    
# reshape results: (B * seq_len, Dim) => (batch_size, seq_len, dim)
8. return results

```

# SelfAttention

```python
"""Decoder only with causal attention (only work for inference)
only care about current token and its corresponding attention (with support from the KV Cache)
Extended support for GQA (grouped query attention)"""

1. n_kv_heads : n_heads (for MHA) , n_kv_heads (for GQA)
2. n_heads_q = n_heads # no. of Q heads should be n_heads
3. n_rep = n_heads_q // n_kv_heads
4. head_dim = dim // n_heads # part of the embedding each head is responsible for
5. attn_window = max_seq_len if not specified
6. wq = (dim,n_heads_q * head_dim) , wk,wv = (dim,n_kv_heads * head_dim), wo = (n_heads_q * head_dim,dim)
7. kv_cache = RollingBufferKVCache(max_batch_size, attn_window, n_kv_heads, head_dim)
# KV Cache with support of Sliding Window Attention & Rolling Buffer Cache
# this is modified from the Llama implementation which does not support rolling buffer


repeat_kv
------------------------------------------------------------------------------
# in GQA, each Q group shares the same KV heads, thus just repeat KV heads for the Q in the same group
# goal shape: (B, prefix_seq_len, n_kv_heads, Head_Dim) => (B, prefix_seq_len, n_heads_q, Head_Dim)
1. n_rep == 1 then return kv # (Q and KV are 1-to-1 (just a normal MHA))
2. return # for GQA
kv[:, :, :, None, :]
## (B, prefix_seq_len, n_kv_heads, 1, Head_Dim)
.expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
.reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim) #just copy n_rep times
#n_kv_heads * n_rep = n_heads_q

forward
-------------------------------------------------------------------------------
1. get Q,K,V : xq,xk,xv using wq,wk,wv
# (B, 1, Dim) => (B, 1, n_heads_q * Head_Dim)
2. reshape Q K V to get individual single heads (Qi, Ki, Vi)
# (B, 1, n_heads_q * Head_Dim) => (B, 1, n_heads_q, Head_Dim)
3. apply_rotary_embeddings on xq,xk , Q,K must have same shape before & after RoPE
# (B, 1, n_heads_q, Head_Dim)
4. kv_cache.update_cache(xk, xv, batch_size, start_pos)
# replace the entry in the KV cache's respective position (aka update KV Cache)
# fill (:B, idx) part of the (max_B, max_seq_len, n_kv_heads, Head_Dim) cache with (B, 1, n_kv_heads, Head_Dim)
5. keys, values = kv_cache.retrieve_cache(batch_size, start_pos + seq_len)
  # retrieve complete K and V from KV Cache
  # (B, prefix_seq_len, n_kv_heads, Head_Dim)
6. keys, values =  repeat_kv(keys),repeat_kv(values)     
# in GQA, each Q group shares the same KV heads, thus just repeat KV heads for the Q in the same group
# (B, prefix_seq_len, n_kv_heads, Head_Dim) => (B, prefix_seq_len, n_heads_q, Head_Dim)
7. xq # transpose : (B, 1, n_heads_q, Head_Dim) => (B, n_heads_q, 1, Head_Dim)
   keys,values # transpose :  (B, prefix_seq_len, n_heads_q, Head_Dim) => (B, n_heads_q, prefix_seq_len, Head_Dim)
8. scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
   # (B, n_heads_q, 1, Head_Dim) @ (B, n_heads_q, Head_Dim, prefix_seq_len) => (B, n_heads_q, 1, prefix_seq_len)
"""NOTE about MATMUL: for tensors with more than 2 dimensions, torch.matmul treats the last two dimensions as matrices and performs batch matrix multiplication on the other dimensions. The result is a tensor where each batch element is the result of matrix multiplication on the corresponding batch elements of the input tensors"""
   scores = F.softmax(scores.float(), dim=-1).type_as(xq)  # dim=-1 means softmax along last dimension (sum=1)
   # softmax(QK/sqrt(dk)): (B, n_heads_q, 1, prefix_seq_len) => (B, n_heads_q, 1, prefix_seq_len)
9. output = torch.matmul(scores, values)
   # (B, n_heads_q, 1, prefix_seq_len) @ (B, n_heads_q, prefix_seq_len, Head_Dim) => (B, n_heads_q, 1, Head_Dim)
    output = output.transpose(1, 2).contiguous()
   # (B, n_heads_q, 1, Head_Dim) => (B, 1, n_heads_q, Head_Dim) and make sure contiguous in memory
    output = output.view(batch_size, seq_len, -1)   # -1 means infer the last dimension's shape
   # (B, 1, n_heads_q, Head_Dim) => (B, 1, n_heads_q * Head_Dim) = (B, 1, dim)
    output = wo(output)# apply the attention's output layer
   # (B, 1, dim) => (B, 1, dim)


apply_rotary_embeddings
-------------------------------------------------------------------------------
1. x_complex # Group the last dim into pairs of 2 values (the real and imaginary parts of a complex number) => make complex. Each pair of 2 consecutive values in head_dim becomes a single complex number (thus head_dim / 2)
# (B, seq_len, H=n_heads, head_dim) => (B, seq_len, H, head_dim / 2)
2. freqs_complex # reshape freqs_complex to match the shape of the x_complex tensor
# (seq_len, head_dim / 2) => (1, seq_len, 1, head_dim / 2)
3. x_rotated # Element-wise multiplication with broadcasting
# (B, seq_len, H, head_dim / 2) * (1, seq_len, 1, head_dim / 2) => (B, seq_len, H, head_dim / 2)
4. x_real # Convert the complex numbers back to real numbers: the additional 2 in the final dim holds the real and imaginary parts
# (B, seq_len, H, head_dim / 2) => (B, seq_len, H, head_dim / 2, 2)
5. x_out # Flatten the last two dimensions back into head_dim (same shape as the input x)
# (B, seq_len, H, head_dim / 2, 2) => (B, seq_len, H, head_dim)

```

# RollingBufferKVCache

```python
1. initialize cache_k,cache_v with zeros with shape (max_batch_size, attn_window, n_kv_heads, head_dim)


update_cache
----------------------------------------------------------------------
1. cache_position = start_pos % attn_window # position wraps around within the attn_window size
# circular buffer, meaning it can overwrite old entries when the window's limit is reached.
2. cache_k[:batch_size, cache_position:cache_position + 1] = xk
   cache_v[:batch_size, cache_position:cache_position + 1] = xv
# fill (:B, idx) part of the (max_B, max_seq_len, n_kv_heads, Head_Dim) cache with (B, 1, n_kv_heads, Head_Dim)
# shape of xk and xv: (batch_size, 1, n_kv_heads, head_dim)
# only the active sequences in the batch are updated. Since the sequence length is 1
# as indicated by the shape (batch_size, 1, n_kv_heads, head_dim), we update just a single position.

Example:

attn_window = 5
start_pos = 7
cache_position = 7 % 5 = 2
This means the new key-value pair will be placed at position 2 of the cache (for the specified batch).

update_cache_multiple
-----------------------------------------------------------------------
for i in range(seq_len):
    update_cache(xk[:, i:i+1, :, :], xv[:, i:i+1, :, :], batch_size, start_pos + i)

# updating the cache used when seq_len > 1, yet in inference we only care about the seq_len = 1 case
# can be optimized in the future to support Mistral's pre-fill and chunking (to handle prompts)
# For each token i in the sequence, it slices the tensors xk and xv to extract the keys and values corresponding to that token, and then calls update_cache with these values and the adjusted starting position (start_pos + i).

Example:
seq_len = 3
start_pos = 7
For each i in the range [0, 2]:
    The method extracts xk[:, i:i+1, :, :] and xv[:, i:i+1, :, :], corresponding to the keys and values for the
    i-th token in the sequence.
    It then calls update_cache with these values and updates the cache at positions 7, 8, and 9 (% attn_window).


retrieve_cache
------------------------------------------------------------------------
# calculate the effective start position considering the rolling buffer's nature
# NOTE: start_pos should be updated to be start_pos + seq_len when called after update_cache
1. effective_start_pos = start_pos % attn_window
# retrieve KV from the cache, split into 2 parts to handle the wrap-around
2. keys = torch.cat([cache_k[:batch_size, effective_start_pos:, :, :],
                  cache_k[:batch_size, :effective_start_pos, :, :]], dim=1) # same for values
# select the last seq_len tokens from the concatenated keys and values (to handle when < attn_window)
3. keys = keys[:, -start_pos:, :, :] # same for values
4. return keys, values


Example:
attn_window = 5
batch_size = 2
start_pos = 7
effective_start_pos = 7 % 5 = 2

cache_k for batch 0: [K0, K1, K2, K3, K4]
cache_k for batch 1: [L0, L1, L2, L3, L4]
After splitting:

First Part (batch 0): [K2, K3, K4]
Second Part (batch 0): [K0, K1]
After concatenation:

Combined for batch 0: [K2, K3, K4, K0, K1]

```



Transformer
==================================
1. nn.Embedding(vocab,dim) : token_embeddings
2. nn.ModuleList(TransformerBlock(args) X n_layers) # different LLM architecture implemented
3. RMSNorm(dim,eps)
4. nn.Linear(dim,vocab) : output
5. precompute_freqs_pos_frequencies(dim//n_heads,max_seq*2,device,rope_theta) : freqs_complex # per attention head
# i.e dim // n_heads = head_dim , for Llama rope_theta is not passed

@property # helpful for training-inference device switch error
dtype() : return data type of parameter of the model
device() : device on which model parameters are stored


forward()
----------------------------------------------
"""
note that with the KV Cache, only need the latest tokens, no need all tokens: info about previous tokens are saved in the cache
NOTE: this is only for inference, not training (in training there's no KV cache)
"""
1. get hidden_state = token_embeddings(tokens) # (B,seq) -> (B,seq,dim) [1 token at a time --> map to dim]
2. retrieve pairs(m,theta) corresponding to position [start_pos, start_pos + seq_len] from freqs_complex
3. hidden_states = layer(hidden_states,start_pos,freqs_complex) for layer in layers
# apply precomputed frequencies to the encoding layers for positional encoding , each layer is (Nx transformer blocks)
4. apply RMSNorm on combined hidden_states from step 3
5. pass through output


TransformerBlock
===============================================================
""" a single transformer block (different for Llama & Mistral) """
1. head_dim = dim//n_heads
2. RMSNorm(dim,eps) : rms_norm # before attention & feed_forward
3. SelfAttention(args) : attention # Decoder only with causal attention (only works for inference)
4. MOE(experts=[FeedForward(args) X moe.n_experts], gate=nn.Linear(dim,moe.n_experts), moe_args=args.moe)
# Feed Forward Layer (with MoE support), otherwise FeedForward(args) : feed_forward


forward()
-----------------------------------------------------------------
1. hidden_states = x + attention.forward(rms_norm(x),start_pos,freqs_complex)
# (B, seq_Len, dim) + (B, seq_Len, dim) => (B, seq_Len, dim) we're dealing
# with 1 token at a time start_pos : current token we're dealing
2. out = hidden_states + feed_forward(rms_norm(hidden_states)) # (B, seq_Len, dim) + (B, seq_Len, dim) => (B, seq_Len, dim)
3. return out



RMSNorm
========================================================================
"""same for both"""
1. nn.Parameter(torch.ones(dim)) : weight
# gamma(g) parameter trainable to perform rescaling on the norm

_norm()
-------------------------------------------------------------------------
# RMSNorm stat : (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
1. return x * 1/rms

forward()
-------------------------------------------------------------------------
# (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
# auto-broadcasting expands (Dim) to (1, 1, Dim) to multiplied to the last dimension of (B, Seq_Len, Dim)
# recall: Automatic broadcasting in PyTorch occurs when dimensions match or are broadcastable starting from the trailing dimensions (i.e., from right to left)
1. weight * _norm(x) # typecast to float & verify if it's same type as x


FeedForward
==========================================================================
"""same for both but Llama uses multiple_of parameter for hidden_dim"""
1. hidden_dim = multiple_of * ((int(hidden_dim) + multiple_of - 1) // multiple_of)
# round hidden_dim up to the nearest multiple of the args.multiple_of parameter (bigger or equal)
# just a design choice to look cool :)
2. gate : w1(dim,hidden_dim) , down : w2(hidden_dim,dim), up : w3(dim,hidden_dim)

forward()
---------------------------------------------------------------
1. xw1 = w1(x)
# (SiLU(X W1) * (X W3)) W2; goal shape: (B, seq_len, Dim) => (B, seq_len, Dim)
# (B, seq_len, Dim) w1=> (B, seq_len, Hidden_Dim)
2. sxw1 = silu(xw1)
# (B, seq_len, Hidden_Dim) => (B, seq_len, Hidden_Dim) # silu
3. xv = w3(x)
# (B, seq_len, Dim) w3=> (B, seq_len, Hidden_Dim)
4. sxw1xv = sxw1 * xv
# (B, seq_len, Hidden_Dim) * (B, seq_len, Hidden_Dim) = (B, seq_len, Hidden_Dim) = element wise multiplication
5. return w2(sxw1xv)
# (B, seq_len, Hidden_Dim) w2=> (B, seq_len, Dim)


MOE
================================================================================
"""the only difference in Mixtral MOE: after attention, instead of RMS=>MLP, it has RMS=>MOE"""
1. nn.ModuleList(experts) : experts
2. gate , moe_args

forward()
------------------------------------------------------------------------------------
# NOTE: in the mistral paper, all input/output size used are (B * seq_len, Dim) instead of (B, seq_len, Dim)
# goal shape: (B, seq_len, Dim) => (B, seq_len, Dim)
1. flat_x = reshape input shape to (B * seq_len, Dim)
2. gate_logits = gate(flat_x)  # recall gate is linear with (Dim, n_experts)
# (B * seq_len, Dim) gate=> (B * seq_len, n_experts) for each input token
3. weights, selected_experts = torch.topk(gate_logits, self.moe_args.n_experts_per_tok)
# Get the top k experts for each input token, using torch.topk
# weights=logits, selected_experts=indices
# (B * seq_len, n_experts) => (B * seq_len, n_experts_per_tok)
4. weights = F.softmax(weights, dim=1, dtype=torch.float).to(x.dtype)
# # Normalize the weights with softmax, to get the selected top k experts' weights on the tokens
5. results = torch.zeros_like(flat_x)
# init results: (B * seq_len, Dim)
6. # Iterate over each expert to compute the weighted sum of the outputs from each selected top k experts,
for i, expert in enumerate(self.experts):
    # for each expert: retrieves only the batch_idx & selected_exp_idx this expert is responsible for
    batch_idx, selected_exp_idx = torch.where(selected_experts == i)
    # (K, Dim) => (K, Dim), where K is how many tokens this expert is responsible for
    expert_out = expert(flat_x[batch_idx])  # recall expert is FFN with Dim=>Dim
    # (K, 1), where K is how many tokens this expert is responsible for
    expert_w = weights[batch_idx, selected_exp_idx, None]
    # add the experts' weighted sum output to the corresponding tokens
    # expert_w * expert_out: (K, 1) * (K, Dim) => (K, Dim)
    # results: still (B * seq_len, Dim), where the corresponding tokens are updated
    results[batch_idx] += expert_w * expert_out
7. results = results.view(B, seq_len, dim)    
# reshape results: (B * seq_len, Dim) => (batch_size, seq_len, dim)
8. return results


In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login for this notebook session; presumably
# needed so the (possibly gated) Mixtral weights loaded later can be downloaded.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import torch

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Permit PyTorch's fused scaled-dot-product-attention kernels
# (FlashAttention and memory-efficient attention) to be selected.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
device

device(type='cuda', index=0)

In [None]:
import pandas as pd

# Small in-memory payments table that the tool functions further down query.
data = {
    "transaction_id": ["T1001", "T1002", "T1003", "T1004", "T1005"],
    "customer_id": ["C001", "C002", "C003", "C002", "C001"],
    "payment_amount": [125.50, 89.99, 120.00, 54.30, 210.20],
    "payment_date": ["2021-10-05", "2021-10-06", "2021-10-07", "2021-10-05", "2021-10-08"],
    "payment_status": ["Paid", "Unpaid", "Paid", "Paid", "Pending"],
}

# One row per transaction; columns follow the dict's insertion order.
df = pd.DataFrame.from_dict(data)
df

Unnamed: 0,transaction_id,customer_id,payment_amount,payment_date,payment_status
0,T1001,C001,125.5,2021-10-05,Paid
1,T1002,C002,89.99,2021-10-06,Unpaid
2,T1003,C003,120.0,2021-10-07,Paid
3,T1004,C002,54.3,2021-10-05,Paid
4,T1005,C001,210.2,2021-10-08,Pending


In [None]:
def retrieve_payment_status(df: pd.DataFrame, transaction_id: str) -> str:
    """Look up the payment status of one transaction and return it as JSON.

    Args:
        df: payments table with ``transaction_id`` and ``payment_status`` columns.
        transaction_id: id to look up, e.g. ``"T1001"``.

    Returns:
        ``'{"status": ...}'`` when the id is present, otherwise
        ``'{"error": "transaction id not found."}'``.

    Raises:
        ValueError: if the id matches more than one row (``Series.item``).
    """
    # json is not imported in the visible cells; import locally so the tool
    # function works on its own.
    import json

    matches = df[df.transaction_id == transaction_id]
    if matches.empty:
        return json.dumps({'error': 'transaction id not found.'})
    return json.dumps({'status': matches.payment_status.item()})

def retrieve_payment_date(df: pd.DataFrame, transaction_id: str) -> str:
    """Look up the payment date of one transaction and return it as JSON.

    Args:
        df: payments table with ``transaction_id`` and ``payment_date`` columns.
        transaction_id: id to look up, e.g. ``"T1001"``.

    Returns:
        ``'{"date": ...}'`` when the id is present, otherwise
        ``'{"error": "transaction id not found."}'``.

    Raises:
        ValueError: if the id matches more than one row (``Series.item``).
    """
    # json is not imported in the visible cells; import locally so the tool
    # function works on its own.
    import json

    matches = df[df.transaction_id == transaction_id]
    if matches.empty:
        return json.dumps({'error': 'transaction id not found.'})
    return json.dumps({'date': matches.payment_date.item()})

In [None]:
import functools
import json

# JSON-schema descriptions of the callable tools ("function" tool format).
tools = [
    {
        "type": "function",
        "function": {
            "name": "retrieve_payment_status",
            "description": "Get payment status of a transaction",
            "parameters": {
                "type": "object",
                "properties": {
                    "transaction_id": {
                        "type": "string",
                        "description": "The transaction id.",
                    }
                },
                "required": ["transaction_id"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "retrieve_payment_date",
            "description": "Get payment date of a transaction",
            "parameters": {
                "type": "object",
                "properties": {
                    "transaction_id": {
                        "type": "string",
                        "description": "The transaction id.",
                    }
                },
                "required": ["transaction_id"],
            },
        },
    }
]

# Tool name -> Python callable with the payments DataFrame already bound, so a
# parsed call only needs its keyword arguments (e.g. transaction_id=...).
names_to_functions = {
    'retrieve_payment_status': functools.partial(retrieve_payment_status, df=df),
    'retrieve_payment_date': functools.partial(retrieve_payment_date, df=df),
}

messages = [{"role": "user", "content": "What's the status of my transaction T1001?"}]

B_FUNC, E_FUNC = "You have access to the following functions. Use them if required:\n\n", "\n\n"
B_INST, E_INST = "GPT4 Correct User: ", "<|end_of_turn|>GPT4 Correct Assistant:\n\n" #OpenChat style
# NOTE(review): the model loaded later is Mixtral-8x7B-Instruct, whose chat
# template uses [INST] ... [/INST] rather than OpenChat turn markers, and the
# parser expects "Call function name(args)", which this prompt never asks for.
# Confirm both are intended.

# Describe the tools with their JSON schemas. Interpolating names_to_functions
# here would embed the repr of the functools.partial objects (including the
# whole DataFrame) and never tell the model what parameters exist.
user_content = "\n".join(m["content"] for m in messages if m["role"] == "user")
prompt = f"{B_INST}{B_FUNC}{json.dumps(tools)}{E_FUNC}{user_content}{E_INST}\n\n"
prompt

'GPT4 Correct User: You have access to the following functions. Use them if required:\n\n{\'retrieve_payment_status\': functools.partial(<function retrieve_payment_status at 0x7f238ff1d1b0>, df=  transaction_id customer_id  payment_amount payment_date payment_status\n0          T1001        C001          125.50   2021-10-05           Paid\n1          T1002        C002           89.99   2021-10-06         Unpaid\n2          T1003        C003          120.00   2021-10-07           Paid\n3          T1004        C002           54.30   2021-10-05           Paid\n4          T1005        C001          210.20   2021-10-08        Pending), \'retrieve_payment_date\': functools.partial(<function retrieve_payment_date at 0x7f238ff1d240>, df=  transaction_id customer_id  payment_amount payment_date payment_status\n0          T1001        C001          125.50   2021-10-05           Paid\n1          T1002        C002           89.99   2021-10-06         Unpaid\n2          T1003        C003          1

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Tokenizers return CPU tensors and take no placement argument; the previous
# device="auto" is not a documented tokenizer option and was silently kept as
# an init kwarg.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# bfloat16 weights, sharded across the available devices (device_map="auto").
# Mixtral is a built-in transformers architecture, so trust_remote_code is not
# needed; enabling it would allow Python code from the hub repo to run locally.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
)

# text = "Hello my name is"
# inputs = tokenizer(text, return_tensors="pt")

# outputs = model.generate(**inputs, max_new_tokens=20)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))



Downloading shards:   0%|          | 0/19 [00:00<?, ?it/s]

model-00001-of-00019.safetensors:  61%|######1   | 3.00G/4.89G [00:00<?, ?B/s]

model-00002-of-00019.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00019.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00019.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00005-of-00019.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00006-of-00019.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00007-of-00019.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00008-of-00019.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

In [None]:
# prompt = tokenizer.apply_chat_template(prompt,  return_dict=True, return_tensors="pt", add_generation_prompt=True,tokenize=False)
# prompt

In [None]:
# Tokenize the prompt and move the tensors to the model's (first) device;
# with device_map="auto" the weights are not on the CPU tokenizer output's device.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Without max_new_tokens, generate() uses the generation config's max_length,
# which counts the prompt tokens. The default GenerationConfig value is 20,
# unless the model's config overrides it, and the prompt already exceeds that.
outputs = model.generate(**inputs, max_new_tokens=256)
# Decode only the newly generated tokens so the prompt text is not parsed as
# if the model had produced it.
prompt_len = inputs["input_ids"].shape[-1]
response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
print(response)

In [None]:
# Function to extract function calls from the model output
def extract_function_call(response):
    """Parse the first ``Call function <name>(<args>)`` found in *response*.

    ``<args>`` may be the body of a JSON object (``"transaction_id": "T1001"``)
    or a complete JSON object (``{"transaction_id": "T1001"}``).

    Returns:
        ``(func_name, params)`` with ``params`` a dict, or ``(None, None)`` when
        no call is present or its arguments are not valid JSON object content.
    """
    # Neither module is imported in the visible cells; import locally.
    import json
    import re

    # Example extraction logic; modify according to actual response format
    match = re.search(r'Call function (\w+)\((.*)\)', response)
    if not match:
        return None, None
    func_name = match.group(1)
    raw_args = match.group(2).strip()
    # Only wrap in braces when the model did not already emit a JSON object.
    if not raw_args.startswith("{"):
        raw_args = "{" + raw_args + "}"
    try:
        params = json.loads(raw_args)
    except json.JSONDecodeError:
        # Free-text model output: malformed arguments mean "no usable call".
        return None, None
    return func_name, params

# Parse the model response and, if it names a known tool, run that tool with
# the parsed keyword arguments.
func_name, params = extract_function_call(response)

handler = names_to_functions.get(func_name)
if handler is None:
    print("No valid function call found.")
else:
    result = handler(**params)
    print(f"Function call result: {result}")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Feed-forward expert network: Linear -> ReLU -> Linear.

    Maps the last dimension from ``input_dim`` to ``output_dim`` through a
    hidden layer of width ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

class Router(nn.Module):
    """Gating network: a linear layer followed by a softmax over experts.

    For every input vector it returns a probability distribution of length
    ``num_experts`` along the last dimension.
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        return torch.softmax(self.fc(x), dim=-1)

class MoELayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k token routing.

    Each token is sent to the ``top_k`` experts with the highest router
    probability. Its output is the sum of those experts' outputs, each scaled
    by that expert's router probability. Probabilities are used as-is, not
    renormalised over the top-k, so with ``top_k=1`` the router still receives
    a gradient.

    Args:
        input_dim: feature size of the input tokens.
        hidden_dim: hidden width of each expert.
        output_dim: feature size produced by each expert and by this layer.
        num_experts: number of experts.
        top_k: experts per token; clamped to ``num_experts``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, top_k=1):
        super(MoELayer, self).__init__()
        self.experts = nn.ModuleList([Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)])
        self.router = Router(input_dim, num_experts)
        self.num_experts = num_experts
        self.output_dim = output_dim
        self.top_k = min(top_k, num_experts)  # Ensure top_k does not exceed num_experts

    def forward(self, x):
        """Route each token of ``x`` (batch, seq, input_dim) to its top-k experts.

        Returns:
            Tensor of shape (batch, seq, output_dim).
        """
        batch_size, seq_length, input_dim = x.size()
        probs = self.router(x)  # (batch, seq, num_experts)
        topk_probs, topk_indices = torch.topk(probs, self.top_k, dim=-1)  # (batch, seq, top_k)

        # Work on a flat list of tokens so each expert only processes the
        # tokens actually routed to it.
        flat_x = x.reshape(-1, input_dim)                      # (N, input_dim)
        flat_probs = topk_probs.reshape(-1, self.top_k)        # (N, top_k)
        flat_indices = topk_indices.reshape(-1, self.top_k)    # (N, top_k)
        # Sized by output_dim (not input_dim) and matching x's dtype/device.
        outputs = flat_x.new_zeros(flat_x.size(0), self.output_dim)

        for expert_id, expert in enumerate(self.experts):
            # token_idx: tokens that chose this expert; slot: its rank among their top-k.
            token_idx, slot = torch.where(flat_indices == expert_id)
            if token_idx.numel() == 0:
                continue
            weighted = flat_probs[token_idx, slot, None] * expert(flat_x[token_idx])
            # topk indices are distinct per token, so each token appears at most
            # once here; index_add_ accumulates across experts.
            outputs.index_add_(0, token_idx, weighted)

        return outputs.view(batch_size, seq_length, self.output_dim)



class DecoderBlock(nn.Module):
    """Post-norm decoder layer: causal self-attention, then an MoE feed-forward.

    Args:
        input_dim: model width; must be divisible by ``num_heads``.
        hidden_dim: hidden width of each expert.
        output_dim: expert output width; must equal ``input_dim`` because the
            MoE output is added back onto the residual stream.
        num_experts: number of experts in the MoE layer.
        top_k: experts evaluated per token.
        num_heads: attention heads (default 8, as before).
        dropout: dropout on both sublayer outputs (default 0.1, as before).
        is_causal: if True, each position attends only to itself and earlier
            positions, as autoregressive decoding requires.

    Raises:
        ValueError: if ``output_dim != input_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, top_k=1,
                 num_heads=8, dropout=0.1, is_causal=True):
        super(DecoderBlock, self).__init__()
        if output_dim != input_dim:
            raise ValueError(
                f"output_dim ({output_dim}) must equal input_dim ({input_dim}) "
                "for the residual connection around the MoE layer"
            )
        # batch_first=True because inputs are (batch, seq, dim); the default
        # (batch_first=False) would treat dim 0 as the sequence axis and attend
        # across different examples in the batch.
        self.self_attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, batch_first=True)
        self.moelayer = MoELayer(input_dim, hidden_dim, output_dim, num_experts, top_k)
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.layer_norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)
        self.is_causal = is_causal

    def forward(self, x, memory):
        """Apply the block to ``x`` of shape (batch, seq, input_dim).

        ``memory`` is accepted for interface compatibility but is not used;
        this block has no cross-attention.
        """
        attn_mask = None
        if self.is_causal:
            seq_len = x.size(1)
            # True above the diagonal = "may not attend" (future positions).
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
            )
        # Self-Attention
        attn_output, _ = self.self_attention(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = self.layer_norm1(x + self.dropout(attn_output))  # residual connection

        # MoE Layer
        moe_output = self.moelayer(x)
        x = self.layer_norm2(x + self.dropout(moe_output))
        return x

class TransformerDecoder(nn.Module):
    """A stack of ``num_layers`` DecoderBlocks applied one after another.

    Every block receives the running hidden state and the same ``memory``.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, num_experts, top_k=1):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(DecoderBlock(input_dim, hidden_dim, output_dim, num_experts, top_k))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, memory):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory)
        return hidden

# Usage: build a 6-layer MoE decoder and run it on random data.
input_dim, hidden_dim, output_dim = 128, 256, 128
num_experts, top_k = 4, 2
num_layers = 6
batch_size, seq_length = 32, 10

# Initialize the Transformer Decoder with MoE
decoder = TransformerDecoder(num_layers, input_dim, hidden_dim, output_dim, num_experts, top_k)
x = torch.randn(batch_size, seq_length, input_dim)
# Stand-in for an encoder output; the DecoderBlock defined above never reads it.
memory = torch.randn(batch_size, seq_length, input_dim)
output = decoder(x, memory)
for shown in (x, output):
    print(shown.shape)