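## Download MobileLLM (125M) weights from Hugging Face

The next cell fetches only `model.safetensors`. The model-loading cell also reads `../data/MobileLLM/config.json`, so that file must already be present in the same directory.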

In [None]:
!curl -L -o ../data/MobileLLM/model.safetensors https://huggingface.co/mia-llm/MobileLLM-125M-wikitext2raw-hosein/resolve/main/model.safetensors

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1116  100  1116    0     0   5406      0 --:--:-- --:--:-- --:--:--  5417
100  475M  100  475M    0     0  52.7M      0  0:00:09  0:00:09 --:--:-- 56.0M0  0:00:09  0:00:08  0:00:01 55.6M


In [1]:
from copy import copy

import safetensors
import safetensors.torch
from transformers import LlamaConfig

from attention_approximation.modeling_llama import LlamaForCausalLM as TeacherModel
from attention_approximation.modeling_llama_approximated import LlamaModel
from attention_approximation.pytorch import intersect_dicts
from attention_approximation.utils import LOGGER


# Read teacher
# Builds the teacher from the MobileLLM config, loads the pretrained tensors into it
# and freezes every parameter.
model_config_path = "../data/MobileLLM/config.json"
model_weights_path = "../data/MobileLLM/model.safetensors"

# from_json_file is a classmethod: call it on the class instead of on a throwaway default instance.
config = LlamaConfig.from_json_file(model_config_path)
teacher = TeacherModel(config)
checkpoint = safetensors.torch.load_file(model_weights_path)
# Only the entries returned by intersect_dicts are loaded
# (presumably those whose name and shape match the model — TODO confirm against the helper).
csd = intersect_dicts(checkpoint, teacher.state_dict())  # intersect
teacher.load_state_dict(csd, strict=False)  # load
LOGGER.info(f"Transferred {len(csd)}/{len(teacher.state_dict())} items from pretrained weights")

# Freezing weights
for param in teacher.parameters():
    param.requires_grad = False


# Read student
# The student starts from a copy of the teacher config; a shallow copy is enough here
# because only top-level attributes are assigned on it below.
student_config = copy(config)
student_config.factorization_rank = config.hidden_size // 4  # Low-rank factorization
student_config.layer_sharing = False
student_config.seq_length = 512

student = LlamaModel(student_config)
# Strip only a *leading* "model." so checkpoint names line up with the bare LlamaModel.
# str.replace would also rewrite any later "model." occurrence inside a key.
key_prefix = "model."
new_state_dict = {
    (k[len(key_prefix):] if k.startswith(key_prefix) else k): v
    for k, v in checkpoint.items()
}  # fixes key mismatch
csd = intersect_dicts(new_state_dict, student.state_dict())  # intersect
student.load_state_dict(csd, strict=False)  # load
LOGGER.info(f"Transferred {len(csd)}/{len(student.state_dict())} items from pretrained weights")


# Every parameter whose name does not contain "attn_approx" is frozen.
LOGGER.info("Freezing student transferred weights")
for k, v in student.named_parameters():
    if "attn_approx" not in k:
        v.requires_grad = False

LlamaForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Transferred 272/272 items from pretrained weights
Transferred 152/212 items from pretrained weights
Freezing student transferred weights


In [2]:
from pathlib import Path
from attention_approximation.pytorch import WORLD_SIZE, LOCAL_RANK, RANK
from attention_approximation.data import DataLoaderLite
# Pull one training batch and run it through the student.
# Pass this process's own rank instead of a fixed value, so it stays consistent with
# num_processes=WORLD_SIZE.
# NOTE(review): max(..., 0) assumes RANK can be negative (e.g. -1) when the notebook is not
# started by a distributed launcher — confirm against attention_approximation.pytorch.
process_rank = max(RANK, 0)
dataloader = DataLoaderLite(
    path=Path("../data/edu_fineweb10B"),
    batch_size=8,
    seq_len=student_config.seq_length,  # same length the student config was given
    process_rank=process_rank,
    num_processes=WORLD_SIZE,
    split='train',
)
x, _ = dataloader.next_batch()
# NOTE(review): the recorded run of this call ends with
# "ValueError: Factor 0 has shape torch.Size([4608, 144]), but expected (512, 144)".
# The check that raises is not part of this notebook, so it has to be resolved in the model code.
student(x)

Found 36 shards for split train


torch.Size([8, 512, 576]) torch.float32
dim, w shape 512 torch.Size([4608, 144])


ValueError: Factor 0 has shape torch.Size([4608, 144]), but expected (512, 144).