In [1]:
import torch
device = "cuda:0"
torch.cuda.set_device(device)

In [2]:
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

#from modules.lm_head import LMHeadModel
from modules.modeling_phi import PhiForCausalLM
from utils.config import Config
teacher_model = PhiForCausalLM.from_pretrained(
    "microsoft/phi-1.5", attn_implementation="eager"
).to(device)
teacher_model.eval()
teacher_model.requires_grad_(False)

data = load_dataset("stas/openwebtext-10k")["train"]

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")



  from .autonotebook import tqdm as notebook_tqdm
PhiForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [3]:
split = data.train_test_split(test_size = 0.2, seed=42)

train_dataset = split["train"]
test_dataset = split["test"]

In [4]:
def init_weights(module):
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, torch.nn.LayerNorm):
        torch.nn.init.ones_(module.weight)
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, torch.nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


In [5]:

model_config = Config.from_json("assets/sample_config.json")
model_config = AutoConfig.from_pretrained("/common-repos/xLSTM-7b")
model_config.num_blocks = 16
model_config.head_dim = 256
model_config.num_heads = 8
model_config.embedding_dim = 2048
model_config.vocab_size = teacher_model.config.vocab_size
student_model = AutoModelForCausalLM.from_pretrained("/common-repos/xLSTM-7b", config=model_config, ignore_mismatched_sizes=True).to(device)



#student_model.apply(init_weights)
student_model.to(device)
student_model.requires_grad_(True)


Loading checkpoint shards: 100%|██████████| 6/6 [00:01<00:00,  5.43it/s]
Some weights of the model checkpoint at /common-repos/xLSTM-7b were not used when initializing xLSTMForCausalLM: ['backbone.blocks.16.ffn.proj_down.weight', 'backbone.blocks.16.ffn.proj_up.weight', 'backbone.blocks.16.ffn.proj_up_gate.weight', 'backbone.blocks.16.mlstm_layer.fgate_preact.bias', 'backbone.blocks.16.mlstm_layer.fgate_preact.weight', 'backbone.blocks.16.mlstm_layer.igate_preact.bias', 'backbone.blocks.16.mlstm_layer.igate_preact.weight', 'backbone.blocks.16.mlstm_layer.k.weight', 'backbone.blocks.16.mlstm_layer.multihead_norm.weight', 'backbone.blocks.16.mlstm_layer.ogate_preact.weight', 'backbone.blocks.16.mlstm_layer.out_proj.weight', 'backbone.blocks.16.mlstm_layer.q.weight', 'backbone.blocks.16.mlstm_layer.v.weight', 'backbone.blocks.16.norm_ffn.weight', 'backbone.blocks.16.norm_mlstm.weight', 'backbone.blocks.17.ffn.proj_down.weight', 'backbone.blocks.17.ffn.proj_up.weight', 'backbone.blocks.17.

xLSTMForCausalLM(
  (backbone): xLSTMModel(
    (embeddings): Embedding(51200, 2048)
    (blocks): ModuleList(
      (0-15): 16 x mLSTMBlock(
        (norm_mlstm): RMSNorm()
        (mlstm_layer): mLSTMLayer(
          (q): Linear(in_features=2048, out_features=1024, bias=False)
          (k): Linear(in_features=2048, out_features=1024, bias=False)
          (v): Linear(in_features=2048, out_features=2048, bias=False)
          (ogate_preact): Linear(in_features=2048, out_features=2048, bias=False)
          (igate_preact): Linear(in_features=2048, out_features=8, bias=True)
          (fgate_preact): Linear(in_features=2048, out_features=8, bias=True)
          (ogate_act_fn): Sigmoid()
          (mlstm_backend): mLSTMBackend(mLSTMBackendConfig(chunkwise_kernel='chunkwise--triton_xl_chunk', sequence_kernel='native_sequence__triton', step_kernel='triton', mode='inference', chunk_size=64, return_last_states=True, autocast_kernel_dtype='bfloat16', eps=1e-06, inference_state_dtype='float32

In [6]:
path = "checkpoints/1stage_epoch_1_idx_200.pt"
checkpoint = torch.load(path, map_location=device, weights_only=True)
model_state_to_load = checkpoint['model_state_dict']
start_epoch = checkpoint['epoch']
start_idx = checkpoint['idx']


student_model.load_state_dict(model_state_to_load)
student_model.requires_grad_(True)

xLSTMForCausalLM(
  (backbone): xLSTMModel(
    (embeddings): Embedding(51200, 2048)
    (blocks): ModuleList(
      (0-15): 16 x mLSTMBlock(
        (norm_mlstm): RMSNorm()
        (mlstm_layer): mLSTMLayer(
          (q): Linear(in_features=2048, out_features=1024, bias=False)
          (k): Linear(in_features=2048, out_features=1024, bias=False)
          (v): Linear(in_features=2048, out_features=2048, bias=False)
          (ogate_preact): Linear(in_features=2048, out_features=2048, bias=False)
          (igate_preact): Linear(in_features=2048, out_features=8, bias=True)
          (fgate_preact): Linear(in_features=2048, out_features=8, bias=True)
          (ogate_act_fn): Sigmoid()
          (mlstm_backend): mLSTMBackend(mLSTMBackendConfig(chunkwise_kernel='chunkwise--triton_xl_chunk', sequence_kernel='native_sequence__triton', step_kernel='triton', mode='inference', chunk_size=64, return_last_states=True, autocast_kernel_dtype='bfloat16', eps=1e-06, inference_state_dtype='float32

In [7]:
print(start_epoch, start_idx)

1 200


In [7]:
def collate_fn(batch):
    texts = [item['text'] for item in batch]
    tokenizer.pad_token = tokenizer.eos_token
    encodings = tokenizer(
        texts,
        return_tensors="pt",
        truncation=True,
        padding='longest'
    )
    return {
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"]
    }

In [8]:

tokenizer.pad_token = tokenizer.eos_token
batch_size = 4
dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, collate_fn = collate_fn)

In [9]:
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR

In [10]:
import math
num_teacher_layers = len(teacher_model.model.layers)
num_student_layers = len(student_model.backbone.blocks)

vocab_size = teacher_model.vocab_size


In [12]:
parameters = student_model.parameters()
optimizer = optim.Adam(params = parameters, lr = 1e-4)

scheduler = StepLR(optimizer, step_size=2, gamma=0.5)

In [13]:
import os
checkpoints_path = "./checkpoints"

os.makedirs(checkpoints_path, exist_ok = True)

In [13]:
print(len(dataloader))

2000


In [14]:
freeze_mlp = True


#start_epoch = 0
#start_idx = 0
num_epoch = 10
for epoch in range(start_epoch, num_epoch):
    
    epoch_loss = 0.0
    for idx, batch in enumerate(dataloader):
        if idx < start_idx:
                continue
        start_idx = 0
        
        batch = {k: v.to(device) for k, v in batch.items()}
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
       
        teacher_outputs = teacher_model(
        input_ids=input_ids,
        output_hidden_states=True,
        use_cache=False,
        output_attention_results=freeze_mlp,
        )
        total_loss = 0.0
        for block_idx, student_block in enumerate(student_model.backbone.blocks):
            student_input = teacher_outputs.all_hidden_states[block_idx]
            
            teacher_idx = round(block_idx * (num_teacher_layers-1)/(num_student_layers-1))
            
            # Forward pass
            student_output = student_block(student_input)
            teacher_hstate = (
                teacher_outputs.all_attn_outputs[teacher_idx]
                if freeze_mlp
                else teacher_outputs.all_hidden_states[teacher_idx + 1]
            )
            assert student_output[0].size() == teacher_hstate.size()

            loss = torch.norm(
                student_output[0] - teacher_hstate, p=2, dim=(-1,)
            ).mean()
            
            total_loss += loss


            
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()  
        loss_avg = total_loss.item()/(num_student_layers*batch_size)
        epoch_loss += loss_avg
        if idx % 5 == 0:
            print(f"Epoch: {epoch}, Iter: {idx}, Loss: {loss_avg}")
            
            if idx % 100 == 0:
                checkpoint_path = os.path.join(checkpoints_path, f"1stage_epoch_{epoch}_idx_{idx}.pt")
                torch.save({'idx': idx ,
                'epoch': epoch, 
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss_avg,
                }, checkpoint_path) 
                
    scheduler.step()                         
    final_loss = epoch_loss/len(dataloader)
    checkpoint_path = os.path.join(checkpoints_path, f"1stage_epoch_{epoch}.pt")
    torch.save({'epoch': epoch , 
                    'model_state_dict': student_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': final_loss,
                    }, checkpoint_path) 

Epoch: 0, Iter: 0, Loss: 5.182876110076904
Epoch: 0, Iter: 5, Loss: 9.467247009277344
Epoch: 0, Iter: 10, Loss: 10.98474407196045


KeyboardInterrupt: 