In [111]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from math import sqrt

In [28]:
output

tensor([[  140,    94, 16843,   140,   111, 25443,   112, 22177, 40623, 12466,
           123, 21169, 16843, 31583, 21169, 16142, 21727, 22177, 45035,   140,
           117, 12466,   112, 16843, 22177, 45367,    11, 12466,   123, 15166,
         20375, 25443,   120, 35072,   220,   141,   229, 20375, 15166, 12466,
           116, 12466,   121, 16843, 12466,   110, 12466,   120, 16843, 21169,
           140,   123, 25443,   110, 16142, 22177, 18849, 16843,   220, 21727,
         20375, 21169, 15166, 18849, 12466,   118, 15166, 21169, 43108, 16142,
         20375, 45367, 12466,   122,   140,   109, 21169, 25443,   111, 15166,
           220, 20375, 16843, 30143, 16843, 20375, 12466,   111, 21169, 35072,
           140,   114, 16843, 43108, 12466,   109, 45035, 30143, 18849,   220]])

In [104]:
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
input_text = "я сегодня иду в школу"
input_ids = tokenizer.encode(input_text, return_tensors="pt")

In [None]:
# Генерация текста
output = model.generate(input_ids, max_length=100, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95, temperature=0.7)

# Декодирование и вывод результата
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

In [14]:
generated_text

'Сегодня прекрасный день, потому что и не в мерпование строи кормать оброго телет гружем были '

In [5]:
tok = GPT2Tokenizer.from_pretrained('gpt2')



In [18]:
input_ids.int()

tensor([[  140,    94, 16843,   140,   111, 25443,   112, 22177, 40623, 12466,
           123, 21169, 16843, 31583, 21169, 16142, 21727, 22177, 45035,   140,
           117, 12466,   112, 16843, 22177, 45367,    11, 12466,   123, 15166,
         20375, 25443,   120, 35072,   220,   141,   229, 20375, 15166]],
       dtype=torch.int32)

In [124]:
class GLAAttention(nn.Module):
    def __init__(self, Q, K, V, hidden_dim=768, c=5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.C = c
        self.Q = nn.Parameter(Q)
        self.K = nn.Parameter(K)
        self.V = nn.Parameter(V)
        self.attention_weights = None  # Для сохранения attention weights
        self.S = torch.zeros(768, 768)
        self.register_buffer('base_mask', torch.tril(torch.ones(c, c)))

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, 
                use_cache=False, output_attentions=False):
        batch_size, seq_len, _ = x.shape
        
        # Проецирование входных данных
        Q = torch.matmul(x, self.Q)  # [batch, seq, hidden]
        K = torch.matmul(x, self.K)
        V = torch.matmul(x, self.V)
        
        # Разделение на блоки
        num_blocks = seq_len // self.C
        remainder = seq_len % self.C
        
        # Основные блоки
        S = torch.zeros(batch_size, self.hidden_dim, self.hidden_dim, 
                       device=x.device)
        outputs = []
        
        for i in range(num_blocks):
            start = i * self.C
            end = (i+1) * self.C
            
            Q_block = Q[:, start:end]  # [batch, C, hidden]
            K_block = K[:, start:end]
            V_block = V[:, start:end]
            K_block_T = K_block.transpose(-1, -2)  # [batch, hidden, C]

            # Вычисление внимания
            attn_scores = torch.matmul(Q_block, K_block_T)  # [batch, C, C]
            attn_scores = attn_scores * self.base_mask
            attn_scores = attn_scores / sqrt(self.hidden_dim)
            
            # Применение масок
            if attention_mask is not None:
                attn_scores += attention_mask[:, start:end, start:end]
            
            attn_weights = torch.softmax(attn_scores, dim=-1)
            
            # Обновление состояния
            S_update = torch.matmul(K_block_T, V_block)  # [batch, hidden, hidden]
            S = S + S_update
            
            # Вычисление выхода
            output = torch.matmul(Q_block, S) + torch.matmul(attn_weights, V_block)
            outputs.append(output)
        
        # Обработка остатка
        if remainder > 0:
            start = num_blocks * self.C
            Q_remain = Q[:, start:]  # [batch, rem, hidden]
            K_remain = K[:, start:]
            V_remain = V[:, start:]
            
            mask = torch.tril(torch.ones(remainder, remainder, device=x.device))
            attn_scores = torch.matmul(Q_remain, K_remain.mT) * mask
            attn_scores = attn_scores / sqrt(self.hidden_dim)
            
            if attention_mask is not None:
                attn_scores += attention_mask[:, start:, start:]
            
            attn_weights = torch.softmax(attn_scores, dim=-1)
            output = torch.matmul(Q_remain, S) + torch.matmul(attn_weights, V_remain)
            outputs.append(output)
        
        # Сборка выходов
        O = torch.cat(outputs, dim=1)
        return (O,)
        


class GLA(nn.Module):
    def __init__(self, c=5):
        super().__init__()
        self.gpt2_lmhead = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt2 = self.gpt2_lmhead.transformer
        self.lm_head = self.gpt2_lmhead.lm_head
        self.config = self.gpt2.config
        self.wte = self.gpt2.wte
        self.wpe = self.gpt2.wpe
        self.drop = self.gpt2.drop
        self.ln_f = self.gpt2.ln_f
        self.gpt2_layers = []
        for i in range(12):
            tmp = self.gpt2.h[i]
            tmp.attn = GLAAttention(tmp.attn.c_attn.weight[:, :self.config.n_embd], 
                                     tmp.attn.c_attn.weight[:, self.config.n_embd:2*self.config.n_embd],
                                     tmp.attn.c_attn.weight[:, 2*self.config.n_embd:3*self.config.n_embd], c=c)
            self.gpt2_layers.append(tmp)

    def layers(self):
        return self.gpt2_layers, self.config

    def forward(self, X):
        X_int =X
        position_ids = torch.arange(0, X_int.shape[-1], dtype=torch.long)
        position_ids = position_ids.unsqueeze(0)
        X = self.wte(X_int)
        X_p = self.wpe(position_ids)
        X+=X_p
        X = self.drop(X)
        for el in self.gpt2_layers:
            X = el.ln_1(X)
            X = el.attn(X)
            X = el.ln_2(X[0])
            X = el.mlp(X)
        X = self.ln_f(X)
        X = self.lm_head(X)
        argmax_indices = torch.argmax(X, dim=2)
        return X,argmax_indices  # Возвращаем результат после прохождения через все слои


model = GLA(5)
logits, output = model(input_ids.int())
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
print(logits)

 the the the the the the the the the the the the the the the the the the the the the the the
tensor([[[-73.5670, -71.8578, -78.1044,  ..., -85.1023, -81.5467, -73.6291],
         [-73.5670, -71.8578, -78.1044,  ..., -85.1023, -81.5467, -73.6291],
         [-73.5670, -71.8578, -78.1044,  ..., -85.1024, -81.5467, -73.6291],
         ...,
         [-73.5423, -71.8281, -78.0751,  ..., -85.0805, -81.5228, -73.6000],
         [-73.5423, -71.8281, -78.0751,  ..., -85.0805, -81.5228, -73.5999],
         [-73.5423, -71.8281, -78.0751,  ..., -85.0805, -81.5228, -73.5999]]],
       grad_fn=<UnsafeViewBackward0>)


In [46]:
i =0
while i<5:
    i+=1
i+=1
print(i)

6


In [3]:
layers = model.layers()
layers

([GPT2Block(
    (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): GLAAttention()
    (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): GPT2MLP(
      (c_fc): Conv1D()
      (c_proj): Conv1D()
      (act): NewGELUActivation()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  ),
  GPT2Block(
    (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): GLAAttention()
    (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): GPT2MLP(
      (c_fc): Conv1D()
      (c_proj): Conv1D()
      (act): NewGELUActivation()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  ),
  GPT2Block(
    (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): GLAAttention()
    (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): GPT2MLP(
      (c_fc): Conv1D()
      (c_proj): Conv1D()
      (act): NewGELUActivation()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  ),
  GPT2B

In [4]:
model.gpt2

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GLAAttention()
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [59]:
layers[1].attn.Q_weights

tensor([[-0.2906,  0.3057,  0.0302,  ...,  0.0132, -0.2652, -0.2058],
        [-0.3272,  0.2420,  0.2140,  ..., -0.0861,  0.0558,  0.1168],
        [-0.2679,  0.1188, -0.2670,  ..., -0.1061,  0.2167,  0.1066],
        ...,
        [-0.0284,  0.4304, -0.1394,  ..., -0.1750,  0.0154, -0.0614],
        [ 0.1730,  0.0967,  0.0262,  ...,  0.1744,  0.3897, -0.2129],
        [ 0.0422,  0.1598, -0.2512,  ..., -0.0259,  0.2618,  0.0779]],
       grad_fn=<SliceBackward0>)

In [145]:
for el in model.parameters():
    print(el.shape)

torch.Size([50257, 768])
torch.Size([1024, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768

In [147]:
for el in model.parameters():
    print(el.shape)

torch.Size([50257, 768])
torch.Size([1024, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Siz