In [1]:
import os
import copy
import transformers
import datasets
import torch
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset

In [2]:
HF_TOKEN = os.environ["HF_TOKEN"]
BASE_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
DATASET_PATH = "G:/data/layerwisetraining"
BF16 = torch.cuda.is_bf16_supported()
DTYPE = torch.bfloat16 if BF16 else torch.float16
DEVICE = "cuda:0"
DEVICE_ORIGINAL_MODEL = "cpu"

In [3]:
ds = datasets.load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
ds

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1801350
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [4]:
base_model = transformers.AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    token=HF_TOKEN,
    torch_dtype=DTYPE,
    device_map=DEVICE_ORIGINAL_MODEL,
    attn_implementation="eager",
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    BASE_MODEL,
    token=HF_TOKEN,
)


In [5]:
original_layers = [
    layer
    for layer in base_model.model.layers
]

In [6]:
def convert_to_embeddings_dataset(ds, tokenizer, directory):
    os.makedirs(directory, exist_ok=True)
    texts = ds['text']
    for i, text in enumerate(tqdm(texts)):
        fname = os.path.join(directory, f"record-{i}.pkl")
        if os.path.exists(fname):
            continue
        batch = tokenizer(text, return_tensors="pt")
        batch["labels"] = batch["input_ids"]
        batch["inputs_embeds"] = base_model.model.embed_tokens(batch["input_ids"])
        batch_flatten = {
            key: value[0]
            for key, value in batch.items()
            if key not in ["input_ids"]
        }
        with open(fname, "wb") as f:
            pickle.dump(batch_flatten, f)

In [7]:
convert_to_embeddings_dataset(
    ds['test'].filter(lambda x: len(x['text']) > 100),
    tokenizer,
    os.path.join(DATASET_PATH, "test-layer-0-inputs")
)

100%|██████████| 1835/1835 [00:00<00:00, 57342.35it/s]


In [8]:
convert_to_embeddings_dataset(
    ds['train'].filter(lambda x: len(x['text']) > 100),
    tokenizer,
    os.path.join(DATASET_PATH, "train-layer-0-inputs")
)

100%|██████████| 749740/749740 [00:24<00:00, 30213.24it/s]


In [None]:
convert_to_embeddings_dataset(
    ds['train'].filter(lambda x: len(x['text']) > 100).map(lambda x: {'text': x['text'][:128]}),
    tokenizer,
    os.path.join(DATASET_PATH, "train-layer-0-inputs-pretrain")
)

In [9]:
class EmbeddingsDataset(Dataset):
    def __init__(self, directory):
        self.files = [
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.endswith(".pkl")
        ]
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        with open(self.files[idx], "rb") as f:
            return pickle.load(f)

In [10]:
ds_embeddings_train = EmbeddingsDataset(
    os.path.join(DATASET_PATH, "train-layer-0-inputs")
)
ds_embeddings_test = EmbeddingsDataset(
    os.path.join(DATASET_PATH, "test-layer-0-inputs")
)


In [11]:
ds_embeddings_train[0].keys()

dict_keys(['attention_mask', 'labels', 'inputs_embeds'])

In [12]:
def collate_fn(batch):
    # Get shapes
    max_len = max(item['inputs_embeds'].shape[0] for item in batch)
    batch_size = len(batch)
    hidden_size = batch[0]['inputs_embeds'].shape[1]
    
    # Initialize padded tensors
    padded_input_embeds = torch.rand((batch_size, max_len, hidden_size), dtype=DTYPE)
    padded_attention_mask = torch.zeros((batch_size, max_len))
    padded_labels = torch.full((batch_size, max_len), fill_value=-100)
    
    # Fill padded tensors with actual values
    for i, item in enumerate(batch):
        seq_len = item['inputs_embeds'].shape[0]
        padded_input_embeds[i, :seq_len] = item['inputs_embeds']
        padded_attention_mask[i, :seq_len] = item['attention_mask']
        padded_labels[i, :seq_len] = item['labels']
        
    return {
        'inputs_embeds': padded_input_embeds,
        'attention_mask': padded_attention_mask,
        'labels': padded_labels
    }


In [13]:
config = copy.deepcopy(base_model.config)
config.num_hidden_layers = 1
config._attn_implementation_autoset = False
config._attn_implementation = "eager"
config.rms_norm_eps = 1e-5
model = transformers.LlamaForCausalLM(config).to(
    dtype=DTYPE
).to(
    device=DEVICE
)
for original_module, new_module in [
    (base_model.model.embed_tokens, model.model.embed_tokens),
    (base_model.model.norm, model.model.norm),
    (base_model.model.rotary_emb, model.model.rotary_emb),
    (base_model.lm_head, model.lm_head),
]:
    new_module.load_state_dict(original_module.state_dict())
for freeze_module in [model.model.embed_tokens, model.model.norm, model.model.rotary_emb, model.lm_head]:
    for param in freeze_module.parameters():
        param.requires_grad = False

In [14]:
#model.model.layers[0].load_state_dict(
#    original_layers[0].state_dict()
#)

In [15]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

In [16]:
# Add gradient clipping to RMSNorm layers
for layer in model.model.layers:
    layer.input_layernorm.weight.data.fill_(1.0)
    layer.post_attention_layernorm.weight.data.fill_(1.0)

for parameter in model.parameters():
    if parameter.requires_grad:
        parameter.register_hook(lambda grad: torch.clamp(grad, -1.0, 1.0))

def custom_rmsnorm_forward(self, hidden_states):
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    # Add value clipping here
    hidden_states = torch.clamp(hidden_states, -10.0, 10.0)
    return self.weight * hidden_states

# Monkey patch the RMSNorm forward method
for layer in model.model.layers:
    layer.input_layernorm.forward = lambda x: custom_rmsnorm_forward(layer.input_layernorm, x)
    layer.post_attention_layernorm.forward = lambda x: custom_rmsnorm_forward(layer.post_attention_layernorm, x)

In [17]:
from transformers.models.llama.modeling_llama import LlamaAttention


def custom_attention_forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
        past_key_value,
        cache_position,
        **kwargs,
    ):
    # Call the original forward
    outputs = LlamaAttention.forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
        past_key_value,
        cache_position,
        **kwargs
    )
    
    # If outputs is a tuple (for causal attention), get the first element
    if isinstance(outputs, tuple):
        attn_output = outputs[0]
    else:
        attn_output = outputs
        
    # Clip the attention output
    attn_output = torch.clamp(attn_output, -10.0, 10.0)
    
    # Return the clipped output in the same format as the original
    if isinstance(outputs, tuple):
        return (attn_output,) + outputs[1:]
    return attn_output

def build_custom_attention_forward(layer):
    return lambda *args, **kwargs: custom_attention_forward(layer, *args, **kwargs)

# Monkey patch the attention forward method
for layer in model.model.layers:
    layer.self_attn.forward = build_custom_attention_forward(layer.self_attn)



In [18]:

# Add gradient clipping to attention parameters
for layer in model.model.layers:
    for name, param in layer.self_attn.named_parameters():
        if param.requires_grad:
            param.register_hook(lambda grad: torch.clamp(grad, -0.1, 0.1))

In [19]:
#BATCH_SIZE = 4
#GRADIENT_ACCUMULATION_STEPS = 16
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 1
LR = 1e-5
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 1
SAVE_STEPS = 2000
WARMUP_STEPS = 500
MAX_GRAD_NORN = 0.05
RANDOM_SEED = 42

In [20]:
#class PrintGradNormCallback(transformers.TrainerCallback):
#    def on_pre_optimizer_step(self, args, state, control, model, **kwargs):
#        for name, param in model.named_parameters():
#            print(name)
#            if param.grad is not None:
#                param_norm = param.grad.detach().data.norm(2)
#                print(name, param_norm)

In [21]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRo

In [22]:
training_args = transformers.TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=WARMUP_STEPS,
    save_steps=SAVE_STEPS,
    max_grad_norm=MAX_GRAD_NORN,
    seed=RANDOM_SEED,
    bf16=DTYPE == torch.bfloat16,
    fp16=DTYPE == torch.float16,
    remove_unused_columns=False,
    logging_dir="./logs",
    logging_steps=1,
    report_to="tensorboard",
)

trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_embeddings_test,
    data_collator=collate_fn,
    #callbacks=[
    #    PrintGradNormCallback()
    #],
)
torch.manual_seed(RANDOM_SEED)
with torch.autograd.set_detect_anomaly(True):
    trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


before dtype conversion: hidden_states.min() tensor(-3.9062, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MinBackward1>)
before dtype conversion: hidden_states.max() tensor(3.9375, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
after dtype conversion: hidden_states.min() tensor(-3.9062, device='cuda:0', grad_fn=<MinBackward1>)
after dtype conversion: hidden_states.max() tensor(3.9375, device='cuda:0', grad_fn=<MaxBackward1>)
after variance calculation: variance.min() tensor(0.7309, device='cuda:0', grad_fn=<MinBackward1>)
after variance calculation: variance.max() tensor(1.0755, device='cuda:0', grad_fn=<MaxBackward1>)
after rsqrt calculation: hidden_states.min() tensor(-4.2817, device='cuda:0', grad_fn=<MinBackward1>)
after rsqrt calculation: hidden_states.max() tensor(4.2964, device='cuda:0', grad_fn=<MaxBackward1>)
after dtype backward conversion: hidden_states.min() tensor(-4.2812, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MinBackward1>)
after dtype backwar

Step,Training Loss
1,13.7626
2,14.2554
3,12.9706
4,13.8306
5,14.3144
6,13.5982
7,13.6947
8,13.4728
9,13.8389
10,12.7069


before dtype conversion: hidden_states.min() tensor(-4.2188, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MinBackward1>)
before dtype conversion: hidden_states.max() tensor(4.1250, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
after dtype conversion: hidden_states.min() tensor(-4.2188, device='cuda:0', grad_fn=<MinBackward1>)
after dtype conversion: hidden_states.max() tensor(4.1250, device='cuda:0', grad_fn=<MaxBackward1>)
after variance calculation: variance.min() tensor(0.7755, device='cuda:0', grad_fn=<MinBackward1>)
after variance calculation: variance.max() tensor(1.1381, device='cuda:0', grad_fn=<MaxBackward1>)
after rsqrt calculation: hidden_states.min() tensor(-4.5180, device='cuda:0', grad_fn=<MinBackward1>)
after rsqrt calculation: hidden_states.max() tensor(4.4037, device='cuda:0', grad_fn=<MaxBackward1>)
after dtype backward conversion: hidden_states.min() tensor(-4.5312, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MinBackward1>)
after dtype backwar