In [1]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
)
from datasets import load_dataset

dataset = load_dataset("ashaba1in/small_openwebtext")
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

print(tokenizer.pad_token)

[PAD]


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from llama_modeling import Llama
from llama_modeling import MultiheadAttention


In [3]:
import torch
from torch.utils.data import DataLoader, Dataset


device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
print(device)

class TextDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length):
        self.dataset = dataset['train']
        #self.dataset = [item for item in dataset['train'] if len(item["text"]) > 2000]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        tokens = self.tokenizer(
            text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"
        ) 

        input_ids = tokens["input_ids"].squeeze(0)
        labels = input_ids[1:].clone()
        input_ids = input_ids[:-1]
        return input_ids, labels


batch_size = 50
max_length = 256
vocab_size = tokenizer.vocab_size + 5
embed_dim = 800
num_heads = 16
head_dim = 50
num_layers = 16

model = Llama(vocab_size, embed_dim, num_heads, head_dim, max_length, num_layers, device).to(device)
#optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)


train_dataset = TextDataset(dataset, tokenizer, max_length)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


cuda:5
llama3


In [4]:
import torch
import torch.optim as optim
import math

optimizer = optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.1)

total_steps = 1_000_000
warmup_steps = int(0.001 * total_steps)  

def lr_lambda(current_step):
    if current_step < warmup_steps:
        return current_step / warmup_steps 
    else:
        progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress)) 

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [5]:


total_params = sum(p.numel() for p in model.parameters())
print(f"Всего параметров: {total_params :,}")
def print_param_count(model):
    total_params = 0
    for name, param in model.named_parameters():
        param_count = param.numel() 
        total_params += param_count
        print(param.dtype)
        print(f"Параметры слоя {name}: {param_count:,}")

    print(f"\nВсего параметров: {total_params:,}")
print_param_count(model)

Всего параметров: 153,665,605
torch.float32
Параметры слоя embed.weight: 25,604,000
torch.float32
Параметры слоя decoder_blocks.0.norm1.weight: 800
torch.bfloat16
Параметры слоя decoder_blocks.0.MLP.gate_proj.weight: 1,280,000
torch.bfloat16
Параметры слоя decoder_blocks.0.MLP.up_proj.weight: 1,280,000
torch.bfloat16
Параметры слоя decoder_blocks.0.MLP.down_proj.weight: 1,280,000
torch.float32
Параметры слоя decoder_blocks.0.norm2.weight: 800
torch.bfloat16
Параметры слоя decoder_blocks.0.q_proj.weight: 640,000
torch.bfloat16
Параметры слоя decoder_blocks.0.k_proj.weight: 640,000
torch.bfloat16
Параметры слоя decoder_blocks.0.v_proj.weight: 640,000
torch.bfloat16
Параметры слоя decoder_blocks.0.out_proj.weight: 640,000
torch.float32
Параметры слоя decoder_blocks.1.norm1.weight: 800
torch.bfloat16
Параметры слоя decoder_blocks.1.MLP.gate_proj.weight: 1,280,000
torch.bfloat16
Параметры слоя decoder_blocks.1.MLP.up_proj.weight: 1,280,000
torch.bfloat16
Параметры слоя decoder_blocks.1.MLP.

In [6]:
import torch
import time
from tqdm import tqdm

def train():
    start_time = time.time()
    num_epochs = 1  
    log_interval = 10
    savelog_interval = 10000
    save_interval = 100000
    samples_processed = 0 
    samples_log_processed = 0
    avg_loss = 100
    i = 1
    for epoch in range(num_epochs):
        loss_accum = 0.0
        with tqdm(enumerate(train_loader, start=1), total=len(train_loader)) as pbar:
            for batch_idx, (x, y) in pbar:
                x, y = x.to(device), y.to(device)

                logits, loss = model(x, y)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for i in range(len(x)):
                    scheduler.step()

                loss_accum += loss.item()
                samples_processed += len(x)  
                samples_log_processed += len(x)
                
                if batch_idx % log_interval == 0:
                    avg_loss = loss_accum / log_interval
                    pbar.set_postfix({
                        "Avg Loss": f"{avg_loss:.6f}", 
                        "LR": f"{optimizer.param_groups[0]['lr']:.2e}", 
                        "Samples": samples_processed
                    })
                    loss_accum = 0.0  
                    
                if samples_log_processed >= savelog_interval:
                    with open("loss_log42.txt", "a") as f:
                        f.write(f"Epoch {epoch}, Batch {batch_idx}, Avg Loss: {avg_loss:.6f}, LR: {optimizer.param_groups[0]['lr']:.2e}\n")
                    samples_log_processed = 0
                    
                if samples_processed >= save_interval:
                    checkpoint = {
                        "epoch": epoch,
                        "batch_idx": batch_idx,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict() if scheduler else None,
                        "loss": avg_loss,
                    }
                    torch.save(checkpoint, f"checkpoint_epoch42_{epoch}_batch_{batch_idx}.pth")
                                        
                    samples_processed = 0
                i += 1
    end_time = time.time()
    return end_time - start_time

elapsed_time = train()
print(f"Время выполнения {elapsed_time:.2f} секунд")


 84%|████████▍ | 12022/14229 [2:47:29<30:44,  1.20it/s, Avg Loss=4.678125, LR=6.89e-06, Samples=1000]   


KeyboardInterrupt: 

In [None]:
import torch

def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.8, top_k=25, device="cuda"):
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(input_ids)
            logits = logits[:, -1, :]
            
            logits = logits / temperature
            if top_k > 0:

                top_k_values, top_k_indices = torch.topk(logits, top_k)
                if (top_k_indices < 0).any() or (top_k_indices >= vocab_size).any():
                    print("Invalid indices detected in top_k_indices!")
                    print("Min index:", top_k_indices.min().item())
                    print("Max index:", top_k_indices.max().item())

                probs = torch.softmax(top_k_values, dim=-1)
                sampled_index = torch.multinomial(probs, 1)
                next_token = top_k_indices.gather(-1, sampled_index)

            else:
                probs = torch.softmax(logits, dim=-1)
                
                sampled_index = torch.multinomial(probs, 1)
                next_token = top_k_indices.gather(-1, sampled_index)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True)

prompt = "I think"
generated_text = generate_text(model, tokenizer, prompt, max_length=200, device = device)
print(generated_text)


In [None]:
import torch

checkpoint_path = "checkpoint_epoch4_0_batch_14000.pth"
checkpoint = torch.load(checkpoint_path, map_location="cuda:5" if torch.cuda.is_available() else "cpu")

epoch = checkpoint["epoch"]
batch_idx = checkpoint["batch_idx"]
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])