In [1]:
from multigpus_repl import init_multigpus_repl, multigpus

  import pynvml  # type: ignore[import]


In [2]:
# Start the multi-GPU REPL workers that execute the `%%multigpus` cells below.
# NOTE(review): print_on_rank=0 presumably limits echoed output to rank 0 (the
# cell outputs are prefixed "[GPU 0]") — confirm against multigpus_repl.
init_multigpus_repl(print_on_rank=0)

[GPU 0] Worker 0/8 initialized on GPU 0 on localhost:12355

  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 7 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [3]:
%%multigpus

from torch import nn
import torch.nn.functional as F
import torch.nn.init as init
import math
from torch.utils.data import Dataset, DataLoader
from functools import partial
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    PrepareModuleInput,
    SequenceParallel
)
from torch import distributed as dist
from torch.distributed.tensor import distribute_tensor
from torch.distributed.device_mesh import init_device_mesh
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from streaming import LocalDataset
from streaming.base.format.mds.encodings import Encoding, _encodings
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen3ForCausalLM
import transformers.models.qwen3.modeling_qwen3 as qwen3_modeling
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Checkpoint identifier passed to `Model.from_pretrained` in a later cell.
model_name = "Qwen/Qwen3-0.6B"

class UInt32(Encoding):
    """Codec that writes an array as raw bytes and reads the bytes back as uint32."""

    def encode(self, obj) -> bytes:
        """Return the raw byte representation of ``obj`` (via ``obj.tobytes()``)."""
        raw = obj.tobytes()
        return raw

    def decode(self, data: bytes):
        """Interpret ``data`` as a flat array of ``np.uint32`` values."""
        return np.frombuffer(data, dtype=np.uint32)


# Register the codec in the encodings table under the name 'uint32'.
_encodings['uint32'] = UInt32

class Dataset(Dataset):
    """Map-style dataset over a local shard folder that yields int64 arrays.

    Note that this class intentionally keeps the name of the torch base class it
    derives from, because later cells instantiate it as ``Dataset(...)``.
    """

    def __init__(self, folder):
        """Open the shards stored under ``folder``."""
        self.dataset = LocalDataset(local=folder)

    def __len__(self):
        """Number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` without its text fields, every value cast to int64.

        The sample dict obtained from the underlying dataset is modified in
        place and returned.
        """
        sample = self.dataset[idx]

        # Fields that are not fed to the model.
        for unwanted in ('text', 'token_type_ids'):
            sample.pop(unwanted, None)

        # Build all converted values first, then write them back under the same keys.
        sample.update({key: arr.astype(np.int64) for key, arr in sample.items()})
        return sample

def collator(batch):
    """Concatenate a list of samples into one flat sequence plus segment offsets.

    Args:
        batch: list of sample dicts; ``None`` entries are skipped. Each sample
            provides ``'input_ids'``, ``'position_ids'`` and ``'attention_mask'``.
            The ``'attention_mask'`` values are used as segment lengths: they
            are concatenated and cumulatively summed to build the offsets.

    Returns:
        dict with
          * ``'input_ids'``, ``'position_ids'``, ``'labels'``: tensors with a
            leading dimension of 1; ``labels`` is an independent copy of
            ``input_ids``;
          * ``'cu_seq_lens_q'`` / ``'cu_seq_lens_k'``: int32 cumulative offsets
            starting at 0;
          * ``'max_length_q'`` / ``'max_length_k'``: the largest segment length
            as a Python int.

    Raises:
        ValueError: if no usable (non-``None``) sample is left in ``batch``.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        # np.concatenate would raise the same exception type with an opaque message.
        raise ValueError("collator received no usable samples (batch is empty or all None)")

    input_ids = np.concatenate([sample['input_ids'] for sample in samples])
    position_ids = np.concatenate([sample['position_ids'] for sample in samples])
    query_lens = np.concatenate([sample['attention_mask'] for sample in samples])

    cumsum = [0] + np.cumsum(query_lens).tolist()
    max_seqlen = int(np.max(query_lens))

    input_ids_t = torch.tensor(input_ids)[None]
    cu_seq_lens = torch.tensor(cumsum, dtype=torch.int32)

    # NOTE(review): labels are a verbatim copy of input_ids; nothing marks the
    # first token of each packed segment as ignored, so once the model shifts
    # the labels the last token of a segment is scored against the first token
    # of the next one — confirm this is intended.
    return {
        'input_ids': input_ids_t,
        'position_ids': torch.tensor(position_ids)[None],
        'labels': input_ids_t.clone(),
        'cu_seq_lens_q': cu_seq_lens,
        'cu_seq_lens_k': cu_seq_lens.clone(),
        'max_length_q': max_seqlen,
        'max_length_k': max_seqlen,
    }

class LinearLoRA(nn.Module):
    """Wrap an ``nn.Linear`` and add a scaled low-rank branch to its output.

    ``forward`` returns ``linear(x) + lora_B(lora_A(x)) * (alpha / r)``. The two
    low-rank projections are stored in ``nn.ModuleDict`` containers under the
    key ``'e'``, which makes their module paths ``lora_A.e`` and ``lora_B.e``.

    Args:
        linear: the layer to wrap; it is kept as ``self.linear``.
        r: rank of the low-rank branch.
        alpha: numerator of the branch scaling factor ``alpha / r``.
    """

    def __init__(self, linear: nn.Linear, r=4, alpha=1.0):
        super().__init__()
        self.linear = linear
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        # The new projections are created with the wrapped weight's device and dtype.
        factory_kwargs = {
            'device': self.linear.weight.device,
            'dtype': self.linear.weight.dtype,
        }
        down = nn.Linear(linear.in_features, r, bias=False, **factory_kwargs)
        up = nn.Linear(r, linear.out_features, bias=False, **factory_kwargs)

        self.lora_A = nn.ModuleDict({'e': down})
        self.lora_B = nn.ModuleDict({'e': up})

        # Explicitly mark the adapter weights as trainable.
        down.requires_grad_(True)
        up.requires_grad_(True)

        # Reference for the initialisation scheme:
        # https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py#L260
        # B starts at zero, so the branch initially adds nothing to the output.
        init.kaiming_uniform_(down.weight, a=math.sqrt(5))
        init.zeros_(up.weight)

    def forward(self, x):
        """Apply the wrapped layer and add the scaled low-rank update, cast back to ``x.dtype``."""
        down, up = self.lora_A['e'], self.lora_B['e']
        base_out = self.linear(x)
        delta = up(down(x.to(down.weight.dtype))) * self.scaling
        return base_out + delta.to(x.dtype)

class Model(Qwen3ForCausalLM):
    """Qwen3 causal LM whose loss comes from a fused linear + cross-entropy kernel.

    When ``labels`` are supplied, the final hidden states and ``lm_head.weight``
    are handed to ``LigerFusedLinearCrossEntropyLoss`` and only the loss is
    returned. Without labels, the backbone output is returned as-is (no
    ``lm_head`` projection is applied in this class).
    """

    def __init__(self, config):
        super().__init__(config)
        self.loss = LigerFusedLinearCrossEntropyLoss(reduction="mean")

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, **kwargs):
        """Run the backbone and, if ``labels`` is given, compute the next-token loss.

        Args:
            input_ids: token ids passed to the backbone.
            attention_mask: forwarded to the backbone unchanged.
            position_ids: forwarded to the backbone unchanged.
            labels: optional targets; position ``i`` of the hidden states is
                scored against ``labels[..., i + 1]``.
            **kwargs: forwarded to the backbone unchanged.

        Returns:
            ``{'loss': loss}`` when ``labels`` is not ``None``, otherwise the
            backbone's output object.
        """
        # Invoke the module itself rather than `.forward`: `nn.Module.__call__`
        # is what runs registered forward pre-/post-hooks, and calling
        # `.forward` directly silently skips any hook attached to `self.model`.
        backbone_out = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
        if labels is None:
            return backbone_out

        hidden = backbone_out.last_hidden_state
        # Drop the last position (it has no next token) and flatten to (tokens, dim).
        hidden = hidden[:, :-1].reshape(-1, hidden.shape[-1])
        # Drop the first label so that targets line up with the shifted hidden states.
        targets = labels[..., 1:].contiguous().reshape(-1)
        loss = self.loss(self.lm_head.weight, hidden, targets)
        return {'loss': loss}

In [4]:
%%multigpus

# Accelerator backend of this worker; its `.type` string selects the mesh backend.
device_type = torch.accelerator.current_accelerator()

# 2-D mesh: 1 data-parallel group x 4 tensor-parallel ranks.
# NOTE(review): the worker log earlier in the notebook reports 8 workers, while
# this mesh only spans 4 ranks — confirm that leaving the rest out is intended.
device_mesh = init_device_mesh(
    device_type.type,
    mesh_shape=(1, 4),
    mesh_dim_names=("dp", "tp"),
)
tp_mesh, dp_mesh = device_mesh["tp"], device_mesh["dp"]

# Data-parallel coordinates, used later to shard the dataset.
dp_world_size = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()

In [5]:
%%multigpus

# Samples are read from the local 'multipacking' folder.
dataset = Dataset('multipacking')

# Indices are split along the data-parallel dimension only (dp_rank / dp_world_size).
sampler = DistributedSampler(dataset, num_replicas=dp_world_size, rank=dp_rank, shuffle=True)

train_loader = DataLoader(
    dataset,
    sampler=sampler,
    collate_fn=collator,
    batch_size=5,
    num_workers=5,
    prefetch_factor=5,
    pin_memory=True,
)

# A persistent iterator, so later cells can pull batches one at a time with next().
iter_train_loader = iter(train_loader)

In [6]:
%%multigpus

# Load the checkpoint in bf16. `dtype` is used instead of `torch_dtype`, which the
# installed transformers reports as deprecated ("`torch_dtype` is deprecated! Use
# `dtype` instead!" appears in this notebook's worker output).
model = Model.from_pretrained(
    model_name,
    attn_implementation='flash_attention_3',
    dtype=torch.bfloat16,
)

# Name fragments of the linear projections that receive a LoRA branch.
selected = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Wrap every matching nn.Linear child in LinearLoRA. Both iterations are
# snapshotted with list() so the module tree is not modified while a lazy
# generator is still walking it.
for module in list(model.modules()):
    for child_name, child in list(module.named_children()):
        if isinstance(child, nn.Linear) and any(key in child_name for key in selected):
            setattr(module, child_name, LinearLoRA(child, r=128, alpha=256))

# (path of the wrapped projection inside a decoder layer, sharding of its base linear)
tp_projections = [
    ("self_attn.q_proj", "colwise"),
    ("self_attn.k_proj", "colwise"),
    ("self_attn.v_proj", "colwise"),
    ("self_attn.o_proj", "rowwise"),
    ("mlp.gate_proj", "colwise"),
    ("mlp.up_proj", "colwise"),
    ("mlp.down_proj", "rowwise"),
]


def build_layer_tp_plan():
    """Return a fresh tensor-parallel plan for one decoder layer.

    Column-sharded projections get: base linear column-wise, ``lora_A`` row-wise
    fed from a replicated input, ``lora_B`` column-wise. Row-sharded projections
    get: base linear and ``lora_A`` row-wise; their ``lora_B`` is not part of the
    plan and therefore stays an ordinary, unsharded ``nn.Linear``.
    """
    plan = {}
    for path, sharding in tp_projections:
        if sharding == "colwise":
            plan[f"{path}.linear"] = ColwiseParallel()
            plan[f"{path}.lora_A.e"] = RowwiseParallel(input_layouts=Replicate())
            plan[f"{path}.lora_B.e"] = ColwiseParallel()
        else:
            plan[f"{path}.linear"] = RowwiseParallel()
            plan[f"{path}.lora_A.e"] = RowwiseParallel()
    return plan


# Only the projections inside each decoder layer are parallelized. The token
# embedding, the final norm and the per-layer norms are not in any plan (earlier
# experiments that sharded them were disabled), so they remain regular modules.
for block in model.model.layers:
    layer_tp_plan = build_layer_tp_plan()
    parallelize_module(
        module=block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )

# NOTE(review): `rank` is not defined in any visible cell; presumably it is
# injected into the worker namespace by multigpus_repl — confirm.
device = torch.device(f"{device_type}:{rank}")
model = model.to(device)

In [7]:
%%multigpus

# Pull one collated batch and move its tensor entries onto this worker's device;
# entries that are not tensors are left as they are.
b = next(iter_train_loader)
for key, value in b.items():
    if isinstance(value, torch.Tensor):
        b[key] = value.to(device, non_blocking=True)

# Single forward pass over the batch with the KV cache disabled.
out = model(**b, use_cache=False)

In [8]:
%%multigpus

# Display the result of the forward pass from the previous cell.
out

[GPU 0] {'loss': tensor(3.3200, device='cuda:0',
       grad_fn=<LigerFusedLinearCrossEntropyFunctionBackward>)}
