In [16]:
import torch
import torch.nn as nn
from math import sqrt

from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import load_dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR



In [17]:
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Устанавливаем pad_token
input_text = "I will by this food for"
input_ids = tokenizer.encode(input_text, return_tensors="pt")



In [18]:
# Загрузка датасета
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Предобработка данных
class WikiTextDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=64):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Фильтрация пустых текстов
        self.data = [item for item in dataset if item['text'].strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]['text']
        encoding = self.tokenizer(
            text,
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',  # Дополнение до max_length
            truncation=True
        )
        input_ids = encoding['input_ids'].squeeze(0)  # Убираем batch dimension
        return input_ids

train_dataset = WikiTextDataset(dataset['train'], tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [19]:
class GLAAttention(nn.Module):
    def __init__(self, hidden_dim=768, c=5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.C = c
        # Инициализация обучаемых параметров
        self.Q = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.K = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.V = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.W1 = nn.Parameter(torch.Tensor(hidden_dim, 16))
        self.W2 = nn.Parameter(torch.Tensor(16, hidden_dim))
        self.b = nn.Parameter(torch.Tensor(hidden_dim))
        
        # Инициализация параметров
        nn.init.xavier_normal_(self.Q)
        nn.init.xavier_normal_(self.K)
        nn.init.xavier_normal_(self.V)
        nn.init.xavier_normal_(self.W1)
        nn.init.xavier_normal_(self.W2)
        nn.init.zeros_(self.b)

        

        # Инициализация параметров
        nn.init.xavier_normal_(self.Q)
        nn.init.xavier_normal_(self.K)
        nn.init.xavier_normal_(self.V)
        self.S = torch.zeros(768, 768)
        self.register_buffer('base_mask', torch.tril(torch.ones(c, c)))

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, 
                use_cache=False, output_attentions=False):
        batch_size, seq_len, _ = x.shape
        
        # Проецирование входных данных
        Q = torch.matmul(x, self.Q)  # [batch, seq, hidden]
        K = torch.matmul(x, self.K)
        V = torch.matmul(x, self.V)
        
        # Разделение на блоки
        num_blocks = seq_len // self.C
        remainder = seq_len % self.C
        
        # Основные блоки
        S = torch.zeros(batch_size, self.hidden_dim, self.hidden_dim, 
                       device=x.device)
        outputs = []
        
        for i in range(num_blocks):
            start = i * self.C
            end = (i+1) * self.C
            
            Q_block = Q[:, start:end]  # [batch, C, hidden]
            K_block = K[:, start:end]
            V_block = V[:, start:end]
            K_block_T = K_block.transpose(-1, -2)  # [batch, hidden, C]

            # Вычисление внимания
            attn_scores = torch.matmul(Q_block, K_block_T)  # [batch, C, C]
            attn_scores = attn_scores * self.base_mask
            attn_scores = attn_scores / sqrt(self.hidden_dim)
            
            # Применение масок
            if attention_mask is not None:
                attn_scores += attention_mask[:, start:end, start:end]
            
            attn_weights = torch.softmax(attn_scores, dim=-1)
            
            # Обновление состояния
            S_update = torch.matmul(K_block_T, V_block)  # [batch, hidden, hidden]
            alpha = torch.matmul(x, self.W1)
            alpha = torch.matmul(alpha, self.W2) + self.b  # [batch, seq, hidden]
            alpha = torch.sigmoid(alpha) ** (1/16)  # Применение сигмоида и возведение в степень
            alpha_avg = alpha.mean(dim=1, keepdim=True)  # [1,1,768]

            # Создаём матрицу через повторение
            alpha_matrix = alpha_avg.repeat_interleave(768, dim=1)  # [1,768,768]
            S = S * alpha_matrix + S_update            
            # Вычисление выхода
            output = torch.matmul(Q_block, S) + torch.matmul(attn_weights, V_block)
            outputs.append(output)
        
        # Обработка остатка
        if remainder > 0:
            start = num_blocks * self.C
            Q_remain = Q[:, start:]  # [batch, rem, hidden]
            K_remain = K[:, start:]
            V_remain = V[:, start:]
            
            mask = torch.tril(torch.ones(remainder, remainder, device=x.device))
            attn_scores = torch.matmul(Q_remain, K_remain.mT) * mask
            attn_scores = attn_scores / sqrt(self.hidden_dim)
            
            if attention_mask is not None:
                attn_scores += attention_mask[:, start:, start:]
            
            attn_weights = torch.softmax(attn_scores, dim=-1)
            output = torch.matmul(Q_remain, S) + torch.matmul(attn_weights, V_remain)
            outputs.append(output)
        
        # Сборка выходов
        O = torch.cat(outputs, dim=1)
        return (O,)
        

In [20]:
class GLAMLP(nn.Module):
    def __init__(self, hidden_dim=768, c=5):
        super().__init__()
        self.Wr = nn.Linear(768, 768, bias = True)
        self.Wo = nn.Linear(768, 768, bias = False)
    def forward(self, x, o):
        r = nn.SiLU()(self.Wr(x))
        o = self.Wo(r*o)
        return o

In [21]:
class GLA(nn.Module):
    def __init__(self, c=5):
        super().__init__()
        self.gpt2_lmhead = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt2 = self.gpt2_lmhead.transformer
        self.lm_head = self.gpt2_lmhead.lm_head
        self.config = self.gpt2.config
        self.wte = self.gpt2.wte
        self.wpe = self.gpt2.wpe
        self.drop = self.gpt2.drop
        self.ln_f = self.gpt2.ln_f
        self.gpt2_layers = []
        for i in range(12):
            tmp = self.gpt2.h[i]
            tmp.attn = GLAAttention()
            tmp.mlp = GLAMLP()
            self.gpt2_layers.append(tmp)

    def layers(self):
        return self.gpt2_layers, self.config

    def forward(self, X):
        X_int =X
        position_ids = torch.arange(0, X_int.shape[-1], dtype=torch.long)
        position_ids = position_ids.unsqueeze(0)
        X = self.wte(X_int)
        X_p = self.wpe(position_ids)
        X+=X_p
        X = self.drop(X)
        for el in self.gpt2_layers:
            X_init = X
            X = el.ln_1(X)
            X = el.attn(X)
            X = el.ln_2(X[0])
            X = el.mlp(X_init, X)
        X = self.ln_f(X)
        X = self.lm_head(X)
        argmax_indices = torch.argmax(X, dim=2)
        return X,argmax_indices  # Возвращаем результат после прохождения через все слои


model = GLA(5)
logits, output = model(input_ids.int())
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
print(logits)

earingearingearingearingearingearing
tensor([[[ 2.7423,  1.9723,  2.3502,  ..., -1.2254, -0.1439,  4.4315],
         [ 2.7423,  1.9723,  2.3502,  ..., -1.2254, -0.1439,  4.4315],
         [ 2.7423,  1.9723,  2.3502,  ..., -1.2254, -0.1439,  4.4315],
         [ 2.7423,  1.9723,  2.3502,  ..., -1.2254, -0.1439,  4.4315],
         [ 2.7423,  1.9723,  2.3502,  ..., -1.2254, -0.1439,  4.4315],
         [ 2.7301,  1.9514,  2.3378,  ..., -1.2144, -0.1548,  4.4135]]],
       grad_fn=<UnsafeViewBackward0>)


In [31]:
# Определение оптимизатора и функции потерь
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = StepLR(optimizer, step_size=1, gamma=0.9)
criterion = torch.nn.CrossEntropyLoss(reduction='mean', ignore_index = 50256)

model.train()

num_epochs = 3
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, input_ids in enumerate(train_loader):
        print(batch_idx)
        input_ids = input_ids
        print(input_ids.shape)

        # Создание input и target
        inputs = input_ids[:, :-1]  # Все токены, кроме последнего
        targets = input_ids[:, 1:]  # Все токены, кроме первого

        # Forward pass
        logits, outputs = model(inputs)
        # Вычисление потерь
        print(logits.shape, targets.shape)
        loss = criterion(
    logits.reshape(-1, logits.size(-1)),  # [8*63, 50257]
    targets.reshape(-1)                   # [8*63]
)
        total_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        print(f"Epoch {epoch + 1}, Batch {batch_idx}, Loss: {loss.item()}")

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss}")

# Сохранение модели
model.save_pretrained("./gpt2-wikitext")
tokenizer.save_pretrained("./gpt2-wikitext")

0
torch.Size([8, 64])
torch.Size([8, 63, 50257]) torch.Size([8, 63])
Epoch 1, Batch 0, Loss: 13.851832389831543
1
torch.Size([8, 64])
torch.Size([8, 63, 50257]) torch.Size([8, 63])
Epoch 1, Batch 1, Loss: 13.253299713134766
2
torch.Size([8, 64])
torch.Size([8, 63, 50257]) torch.Size([8, 63])
Epoch 1, Batch 2, Loss: 12.894468307495117
3
torch.Size([8, 64])
torch.Size([8, 63, 50257]) torch.Size([8, 63])
Epoch 1, Batch 3, Loss: 12.35001277923584
4
torch.Size([8, 64])
torch.Size([8, 63, 50257]) torch.Size([8, 63])



KeyboardInterrupt

