In [1]:
import torch
import pandas
from Dataset.Dataset import get_tokenizer, get_dataset_loader
from Model.Model import SLMModel
import torchinfo
from tqdm import tqdm
from Model.GradViewer import GradViewer
import yaml

In [2]:
# Project settings (data paths, tokenizer path) used by the cells below.
# YAML is a Unicode format, so decode it as UTF-8 explicitly instead of relying
# on the platform's locale encoding (which breaks non-ASCII content on Windows).
with open("config.yaml", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Hyperparameters shared by the data, model and training cells.
BATCH_SIZE = 64          # sequences per batch for both train and val loaders
MAX_SEQ_LEN = 256        # sequence length handed to get_tokenizer / used in the summary
EMBEDDING_DIM = 512      # token embedding width
VOCAB_SIZE = 10_000      # embedding rows and output logits per position
LR_RATE = 1e-4           # initial Adam learning rate

In [None]:
# Parquet locations for the two splits come from config.yaml.
data_section = config["data"]
train_ds_path = data_section["train_path"]
val_ds_path = data_section["val_path"]

train_df = pandas.read_parquet(train_ds_path)
val_df = pandas.read_parquet(val_ds_path)

tokenizer = get_tokenizer(config["tokenizer_path"], MAX_SEQ_LEN)

# NOTE(review): the trailing positional arguments (True/False, 4/2, 2) are defined by
# Dataset.Dataset.get_dataset_loader — presumably shuffle / worker / prefetch settings;
# confirm there. The .pth names look like per-split cache files.
train_dataloader = get_dataset_loader(
    train_df, tokenizer, "train_dataset_cache.pth", BATCH_SIZE, True, 4, 2
)
val_dataloader = get_dataset_loader(
    val_df, tokenizer, "val_dataset_cache.pth", BATCH_SIZE, False, 2, 2
)

In [None]:
device = torch.device("cuda")

# NOTE(review): the positional hyperparameters after EMBEDDING_DIM are defined by
# Model.Model.SLMModel — confirm their meaning there before changing them.
model = SLMModel(VOCAB_SIZE, EMBEDDING_DIM, 4, 8, 64, 64, 0.1, 1024, 0.2).to(device)

# `compiled_model` is what train() runs forward/backward through; the plain `model`
# is what train() saves via state_dict().
compiled_model = torch.compile(model, mode="default", dynamic=False)

In [6]:
# Layer-by-layer shapes/param counts for one full-size batch of random token ids.
summary_batch = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, MAX_SEQ_LEN), device=device)
torchinfo.summary(model, input_data=summary_batch)

Layer (type:depth-idx)                        Output Shape              Param #
SLMModel                                      [64, 256, 10000]          --
├─Embedding: 1-1                              [64, 256, 512]            5,120,000
├─SinusoidalPositionalEmbedding: 1-2          [64, 256, 512]            --
├─ModuleList: 1-3                             --                        --
│    └─TransformerBlock: 2-1                  [64, 256, 512]            --
│    │    └─RMSNorm: 3-1                      [64, 256, 512]            512
│    │    └─FlashMultiHeadAttention: 3-2      [64, 256, 512]            1,050,624
│    │    └─RMSNorm: 3-3                      [64, 256, 512]            512
│    │    └─RMSNorm: 3-4                      [64, 256, 512]            512
│    │    └─FeedForwardBlock: 3-5             [64, 256, 512]            1,575,424
│    │    └─RMSNorm: 3-6                      [64, 256, 512]            512
│    └─TransformerBlock: 2-2                  [64, 256, 512]          

In [7]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)

# train() steps this with the mean validation loss each epoch: the LR is halved
# once the loss stops improving for more than `patience` epochs, floored at 1e-6.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=1,
    threshold=0.0001,
    cooldown=0,
    min_lr=1e-6,
)

# Loss scaling for the float16 autocast used in train().
grad_scaler = torch.GradScaler()
# train() calls grad_viewer.view_grad() after each scaled backward pass;
# presumably it tracks per-layer gradient L2 norms (per its label) — confirm in Model.GradViewer.
grad_viewer = GradViewer(model, "SLM L2 norm grads")



In [None]:
def train(model: torch.nn.Module, saved_model: torch.nn.Module, num_epochs, train_loader, val_loader,
          optimizer, loss_fn, device, scheduler=None, grad_scaler=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after every epoch.

    Each epoch prints the mean train/val loss and the current learning rate,
    steps ``scheduler`` (if given) with the mean validation loss, and saves
    ``saved_model.state_dict()`` to ``best_model_st{epoch}.pt`` whenever the
    mean validation loss improves on the best seen so far.

    Args:
        model: Module run forward/backward (in this notebook, the compiled wrapper).
        saved_model: Module whose ``state_dict`` is checkpointed (here, the plain model).
        num_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(inputs, targets)`` batches with ``len()``.
        val_loader: Iterable of ``(inputs, targets)`` batches with ``len()``.
        optimizer: Optimizer over the model's parameters.
        loss_fn: Loss taking flattened logits ``(N, C)`` and targets ``(N,)``.
        device: Device (or device string) batches are moved to.
        scheduler: Optional scheduler whose ``step`` takes the mean validation loss.
        grad_scaler: Optional ``GradScaler``; when given, the notebook-level
            ``grad_viewer`` inspects the (unscaled) gradients every step.

    Returns:
        The best (lowest) mean validation loss observed.
    """
    device_type = torch.device(device).type
    # float16 autocast is only used on CUDA; other devices run in full precision.
    use_amp = device_type == "cuda"
    # Start at +inf so the first epoch always counts as an improvement and is saved.
    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
                preds = model(inputs)
                # Flatten logits to (N, vocab) and targets to (N,) for the loss; the
                # vocab size is read from the logits instead of a global constant.
                loss = loss_fn(preds.reshape(-1, preds.size(-1)), targets.reshape(-1))

            if grad_scaler is not None:
                grad_scaler.scale(loss).backward()
                # Unscale in place before inspecting gradients, otherwise the viewer
                # sees them multiplied by the current loss scale. step() will not
                # unscale a second time.
                grad_scaler.unscale_(optimizer)
                grad_viewer.view_grad()
                grad_scaler.step(optimizer)
                grad_scaler.update()
            else:
                loss.backward()
                optimizer.step()

            train_loss += loss.item()

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                preds = model(inputs)
                loss = loss_fn(preds.reshape(-1, preds.size(-1)), targets.reshape(-1))
                val_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train loss: {avg_train_loss:.4f}")
        print(f"Val loss: {avg_val_loss:.4f}")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")

        if scheduler is not None:
            scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            torch.save(saved_model.state_dict(), f"best_model_st{epoch+1}.pt")
            # Message: "Model saved at epoch: N"
            print(f"Модель сохранена на эпохе: {epoch+1}")
            best_val_loss = avg_val_loss

    return best_val_loss


In [None]:
NUM_EPOCHS = 50

train(
    compiled_model,      # forward/backward through the compiled graph
    model,               # checkpoint the plain module's weights
    NUM_EPOCHS,
    train_dataloader,
    val_dataloader,
    optimizer,
    loss_fn,
    device,
    scheduler=scheduler,
    grad_scaler=grad_scaler,
)