In [1]:
import os
import copy
import transformers
import datasets
import torch
import pickle
from tqdm import tqdm
from torch.utils.data import IterableDataset
import math
import datasets

In [2]:
# Hugging Face access token. Using .get() means a missing variable no longer
# crashes the config cell; passing token=None lets the hub client fall back
# to a cached login.
HF_TOKEN = os.environ.get("HF_TOKEN")
BASE_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# Storage roots default to the original drive layout but can be redirected
# through environment variables without editing the notebook.
DATASET_PATH = os.environ.get("LAYERWISE_DATASET_PATH", "G:/data/layerwisetraining")
# Only query bf16 support when CUDA is actually usable, so this cell does not
# depend on how is_bf16_supported() behaves on a machine without a GPU.
BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
DTYPE = torch.bfloat16 if BF16 else torch.float16
DEVICE = "cuda:0"  # device used for the trainable one-layer model
DEVICE_ORIGINAL_MODEL = "cpu"  # device_map used when loading the full base model
CHECKPOINT_PATH = os.environ.get("LAYERWISE_CHECKPOINT_PATH", "F:/layerwisetraining")
LOG_DIR = os.environ.get("LAYERWISE_LOG_DIR", f"{CHECKPOINT_PATH}/tensorboard")
# Seed used when shuffling the text splits before they are chunked to disk.
RANDOM_SEED = 20250330

In [3]:
# Load the raw WikiText-103 corpus from the Hugging Face hub; the bare `ds`
# expression displays the resulting splits and their row counts.
ds = datasets.load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
ds

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1801350
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [4]:
# Load the full pretrained model in DTYPE onto DEVICE_ORIGINAL_MODEL. Later
# cells read its embedding table and copy individual modules out of it.
# Eager attention is requested explicitly.
base_model = transformers.AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    token=HF_TOKEN,
    torch_dtype=DTYPE,
    device_map=DEVICE_ORIGINAL_MODEL,
    attn_implementation="eager",
)
# Tokenizer matching BASE_MODEL; used to turn raw text into input ids.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    BASE_MODEL,
    token=HF_TOKEN,
)


In [5]:
# Plain Python list holding the base model's decoder layers, in order.
original_layers = list(base_model.model.layers)

In [6]:
def convert_to_embeddings_dataset(ds, tokenizer, directory, chunk_size, embed_tokens=None):
    """Tokenize and embed every text of ``ds`` and pickle the results in chunks.

    Each chunk file ``chunk-{i}.pkl`` holds a dict mapping the global sample
    index to ``{"attention_mask", "labels", "inputs_embeds"}``, with the batch
    dimension removed. Chunks whose file already exists are skipped, so an
    interrupted run can be resumed.

    Args:
        ds: object whose ``ds['text']`` is a sliceable sequence of strings.
        tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt")``
            that returns a mapping containing at least ``"input_ids"``.
        directory: output directory (created if missing).
        chunk_size: number of samples per chunk file.
        embed_tokens: callable mapping input ids to embeddings. Defaults to
            the module-level ``base_model.model.embed_tokens``.
    """
    if embed_tokens is None:
        embed_tokens = base_model.model.embed_tokens
    os.makedirs(directory, exist_ok=True)
    texts = ds['text']
    chunk_count = int(math.ceil(len(texts) / chunk_size))
    for i in tqdm(range(chunk_count)):
        fname = os.path.join(directory, f"chunk-{i}.pkl")
        if os.path.exists(fname):
            continue
        chunk_texts = texts[i * chunk_size : (i + 1) * chunk_size]
        chunk_batches = {}
        # The embeddings are stored as plain data: without no_grad() each
        # tensor would be produced with an autograd graph attached and be
        # pickled with requires_grad=True.
        with torch.no_grad():
            for j, text in enumerate(chunk_texts):
                global_index = i * chunk_size + j
                batch = tokenizer(text, return_tensors="pt")
                input_ids = batch["input_ids"]
                # Keep every tokenizer field except the ids themselves,
                # dropping the leading batch dimension of size 1.
                sample = {
                    key: value[0]
                    for key, value in batch.items()
                    if key != "input_ids"
                }
                sample["labels"] = input_ids[0]
                sample["inputs_embeds"] = embed_tokens(input_ids)[0]
                chunk_batches[global_index] = sample
        # File existence is the "chunk done" marker above, so write to a
        # temporary name and rename: an interrupted dump must not leave a
        # truncated chunk that later runs would skip.
        tmp_fname = fname + ".tmp"
        with open(tmp_fname, "wb") as f:
            pickle.dump(chunk_batches, f)
        os.replace(tmp_fname, fname)

In [7]:
# Build the layer-0 inputs (token embeddings) for the test split. Rows whose
# text is 100 characters or shorter are dropped, and the remainder is shuffled
# with a fixed seed before being written in chunks of 1000 samples.
convert_to_embeddings_dataset(
    ds['test'].filter(lambda x: len(x['text']) > 100).shuffle(seed=RANDOM_SEED),
    tokenizer,
    os.path.join(DATASET_PATH, "test-layer-0-inputs"),
    chunk_size=1000,
)

100%|██████████| 2/2 [00:00<?, ?it/s]


In [8]:
# Same conversion for the train split: filter out rows of 100 characters or
# fewer, shuffle with the fixed seed, and write chunks of 1000 samples.
convert_to_embeddings_dataset(
    ds['train'].filter(lambda x: len(x['text']) > 100).shuffle(seed=RANDOM_SEED),
    tokenizer,
    os.path.join(DATASET_PATH, "train-layer-0-inputs"),
    chunk_size=1000,
)

100%|██████████| 750/750 [00:00<00:00, 61336.97it/s]


In [9]:
class EmbeddingsDataset(IterableDataset):
    """Iterable view over a directory of ``chunk-{i}.pkl`` files.

    Every chunk file is a pickled dict; iterating the dataset yields the
    dict values chunk by chunk, in chunk-index order.

    NOTE(review): chunks are read with ``pickle.load``; only point this at
    directories you produced yourself.
    """

    def __init__(self, dirname):
        self.dirname = dirname
        sizes = self._get_sizes()
        self.chunk_size = sizes[0]
        self.total_size = sizes[1]

    def _read(self, fname):
        """Unpickle and return the file ``fname`` inside the dataset directory."""
        full_path = os.path.join(self.dirname, fname)
        with open(full_path, 'rb') as src:
            return pickle.load(src)

    def _get_sizes(self):
        """Return ``(size of chunk 0, total sample count)``.

        The total assumes every chunk except the last has the same size as
        chunk 0; only the first and last chunk files are actually read.
        """
        first_len = len(self._read('chunk-0.pkl'))
        file_count = sum(
            1
            for entry in os.listdir(self.dirname)
            if entry.startswith("chunk") and entry.endswith(".pkl")
        )
        tail_len = len(self._read(f'chunk-{file_count-1}.pkl'))
        return first_len, first_len * (file_count - 1) + tail_len

    def __len__(self):
        return self.total_size

    def __iter__(self):
        # Number of chunk files implied by the sizes measured at construction.
        n_chunks = int(math.ceil(self.total_size / self.chunk_size))
        for idx in range(n_chunks):
            yield from self._read(f'chunk-{idx}.pkl').values()


def get_embeddings_dataset(dirname):
    """Return an ``EmbeddingsDataset`` reading the chunk files in ``dirname``."""
    dataset = EmbeddingsDataset(dirname)
    return dataset


In [10]:
def collate_fn(batch, dtype=None):
    """Right-pad a list of variable-length samples into one batch.

    Args:
        batch: non-empty list of dicts with ``inputs_embeds`` of shape
            (seq_len, hidden_size) and ``attention_mask`` / ``labels`` of
            shape (seq_len,).
        dtype: dtype of the padded ``inputs_embeds`` tensor. Defaults to the
            module-level ``DTYPE``, so existing single-argument callers
            (e.g. ``DataLoader``) behave as before.

    Returns:
        dict with ``inputs_embeds`` (batch, max_len, hidden_size),
        ``attention_mask`` (batch, max_len) and ``labels`` (batch, max_len).
        Padding positions hold zeros in the embeddings and the mask, and
        -100 in the labels.
    """
    if dtype is None:
        dtype = DTYPE
    with torch.no_grad():
        # Get shapes
        max_len = max(item['inputs_embeds'].shape[0] for item in batch)
        batch_size = len(batch)
        hidden_size = batch[0]['inputs_embeds'].shape[1]

        # Initialize padded tensors. Zeros instead of random noise: padding
        # is masked out anyway, but noise made batches non-deterministic and
        # cost a random draw per element.
        padded_input_embeds = torch.zeros((batch_size, max_len, hidden_size), dtype=dtype)
        padded_attention_mask = torch.zeros((batch_size, max_len))
        padded_labels = torch.full((batch_size, max_len), fill_value=-100)

        # Fill padded tensors with actual values
        for i, item in enumerate(batch):
            seq_len = item['inputs_embeds'].shape[0]
            padded_input_embeds[i, :seq_len] = item['inputs_embeds']
            padded_attention_mask[i, :seq_len] = item['attention_mask']
            padded_labels[i, :seq_len] = item['labels']

        return {
            'inputs_embeds': padded_input_embeds,
            'attention_mask': padded_attention_mask,
            'labels': padded_labels
        }

In [11]:
def create_one_layer_model(base_model, attention_implementation):
    """Build a single-decoder-layer Llama model that shares frozen weights with ``base_model``.

    The config is a deep copy of the base config with one hidden layer, the
    requested attention implementation and ``rms_norm_eps`` set to 1e-5. The
    token embeddings, final norm, rotary embedding and LM head are loaded
    from ``base_model`` and frozen; the decoder layer itself keeps whatever
    initialisation ``LlamaForCausalLM(config)`` gave it.

    Returns:
        The new model, in ``DTYPE`` on ``DEVICE``.
    """
    one_layer_config = copy.deepcopy(base_model.config)
    one_layer_config.num_hidden_layers = 1
    one_layer_config._attn_implementation_autoset = False
    one_layer_config._attn_implementation = attention_implementation
    one_layer_config.rms_norm_eps = 1e-5

    model = transformers.LlamaForCausalLM(one_layer_config)
    model = model.to(dtype=DTYPE)
    model = model.to(device=DEVICE)

    # (source module in the base model, matching module in the new model)
    shared_modules = (
        (base_model.model.embed_tokens, model.model.embed_tokens),
        (base_model.model.norm, model.model.norm),
        (base_model.model.rotary_emb, model.model.rotary_emb),
        (base_model.lm_head, model.lm_head),
    )
    for source, target in shared_modules:
        target.load_state_dict(source.state_dict())
    # Everything copied from the base model stays fixed during training.
    for _, target in shared_modules:
        for weight in target.parameters():
            weight.requires_grad = False
    return model


In [12]:
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [13]:
BATCH_SIZE = 6  # samples per micro-batch
GRADIENT_ACCUMULATION_STEPS = 20  # micro-batches accumulated per optimizer update
LR = 1e-4
WEIGHT_DECAY = 0.01
SAVE_STEPS = 500
SAVE_LIMIT = 3  # number of checkpoint files kept on disk
WARMUP_STEPS = 500
MAX_GRAD_NORM = 1.0  # gradient clipping threshold
# Misspelled legacy name, kept as an alias because train() reads it.
MAX_GRAD_NORN = MAX_GRAD_NORM
# NOTE(review): this rebinds RANDOM_SEED, which the configuration cell near the
# top sets to 20250330 and the dataset shuffling uses. Re-running the
# conversion cells after this cell would shuffle with 42 instead. Consider a
# separate name for the training seed.
RANDOM_SEED = 42
EARLY_STOPPING_PATIENCE = 1
GRADIENT_NORM_LOGGING_FREQUENCY = 20
#MAX_STEPS = 6000
MAX_EPOCHS = 1

In [14]:
# Index of the base-model decoder layer handled by this run. It selects the
# `*-layer-{layer_idx}-inputs` dataset directories and the per-layer
# checkpoint and log directories in the cells below.
layer_idx = 1

In [15]:
torch.manual_seed(RANDOM_SEED)
ds_embeddings_train = get_embeddings_dataset(
    os.path.join(DATASET_PATH, f"train-layer-{layer_idx}-inputs")
)
# The test split gets its own name; it used to be assigned to
# `ds_embeddings_train`, silently replacing the training data with test data.
ds_embeddings_test = get_embeddings_dataset(
    os.path.join(DATASET_PATH, f"test-layer-{layer_idx}-inputs")
)

In [16]:
# Seed torch, open the cached inputs for the selected layer, derive the
# per-layer checkpoint/log directories and build the one-layer model to train.
torch.manual_seed(RANDOM_SEED)
ds_embeddings_train = get_embeddings_dataset(
    os.path.join(DATASET_PATH, f"train-layer-{layer_idx}-inputs")
)
ds_embeddings_test = get_embeddings_dataset(
    os.path.join(DATASET_PATH, f"test-layer-{layer_idx}-inputs")
)
train_checkpoint_path = os.path.join(CHECKPOINT_PATH, f"layer-{layer_idx}-train-checkpoints")
train_logging_path = os.path.join(CHECKPOINT_PATH, f"layer-{layer_idx}-train-logs")
train_model = create_one_layer_model(base_model, "eager") # Flash Attention 2 returns NaN gradients
train_model.model.layers[0].load_state_dict(base_model.model.layers[layer_idx].state_dict()) # Initialise the single decoder layer from base-model layer `layer_idx`

<All keys matched successfully>

In [17]:
# Batch the cached embeddings. Both datasets are IterableDataset instances
# that yield samples in chunk order, so no shuffling is requested here;
# collate_fn pads each batch to its longest sequence.
train_dataloader = DataLoader(
    ds_embeddings_train,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    ds_embeddings_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [18]:
# Display the configured number of warmup steps.
WARMUP_STEPS

500

In [19]:
# Number of complete gradient-accumulation windows in the run; the floor
# division drops a trailing partial window.
(MAX_EPOCHS * len(train_dataloader)) // GRADIENT_ACCUMULATION_STEPS

6247

In [20]:
# AdamW over the model's parameters; parameters that never receive a gradient
# (the frozen modules) are left untouched by the optimizer.
optimizer = torch.optim.AdamW(
    train_model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY
)
# Linear warmup followed by linear decay to zero. train() performs one
# optimizer update per accumulation window plus one for the trailing partial
# window of each epoch, so the update count per epoch is a ceiling, not a
# floor. With the floor the schedule hit zero one update early and the final
# update ran with a learning rate of 0.
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=MAX_EPOCHS * math.ceil(len(train_dataloader) / GRADIENT_ACCUMULATION_STEPS),
)

In [21]:
# Move the model to the training device and switch it to training mode.
train_model = train_model.to(DEVICE).train()

In [22]:
def evaluate(model, dataloader, verbose=False):
    """Return the unweighted mean of the per-batch losses over ``dataloader``.

    The model is switched to eval mode and left there; the forward passes run
    under ``torch.no_grad()``. With ``verbose=True`` a progress bar is shown.
    Each batch must be a dict of tensors accepted by ``model(**batch)``, and
    the model output must expose ``.loss``.
    """
    model.eval()
    batches = tqdm(dataloader) if verbose else dataloader
    with torch.no_grad():
        per_batch_losses = [
            model(
                **{name: tensor.to(model.device) for name, tensor in batch.items()}
            ).loss.item()
            for batch in batches
        ]
    return sum(per_batch_losses) / len(per_batch_losses)
    

In [23]:
def train(train_model, train_dataloader, optimizer, lr_scheduler, logger):
    """Train with gradient accumulation, pausing at checkpoint boundaries.

    This is a generator: each time the number of optimizer updates reaches a
    multiple of ``SAVE_STEPS`` it yields that number, so the caller can
    evaluate and save a checkpoint, and then resumes in training mode.

    ``step`` counts optimizer updates. Previously it was incremented on every
    batch whose index was *not* a multiple of the accumulation length, so it
    advanced 19 times per 20 batches and did not match the scheduler's or the
    checkpoint naming's notion of a step.

    Args:
        train_model: model called as ``train_model(**batch)``; its output
            must expose ``.loss``.
        train_dataloader: sized iterable of dict batches of tensors.
        optimizer: optimizer stepped once per accumulation window.
        lr_scheduler: scheduler stepped right after the optimizer.
        logger: callable ``logger(metric_name, step, value)``.

    Yields:
        int: number of optimizer updates completed so far.
    """
    step = 0  # optimizer updates completed so far
    batches_per_epoch = len(train_dataloader)
    for epoch in range(MAX_EPOCHS):
        print(f"EPOCH {epoch}")
        print("TRAINING")
        train_model.train()
        window_losses = []  # micro-batch losses of the current accumulation window
        for batch_idx, batch in enumerate(tqdm(train_dataloader)):
            inputs = {k: v.to(train_model.device) for k, v in batch.items()}
            outputs = train_model(**inputs)
            loss = outputs.loss
            # Scale so the accumulated gradient is the mean over the window.
            (loss / GRADIENT_ACCUMULATION_STEPS).backward()
            window_losses.append(loss.item())

            window_complete = (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0
            last_batch = batch_idx == batches_per_epoch - 1
            if not (window_complete or last_batch):
                continue

            step += 1
            # One loss point per optimizer update: the mean over its window.
            logger("loss", step, sum(window_losses) / len(window_losses))
            window_losses = []
            torch.nn.utils.clip_grad_norm_(train_model.parameters(), MAX_GRAD_NORN)
            # Per-parameter gradient norms (after clipping), logged every
            # GRADIENT_NORM_LOGGING_FREQUENCY optimizer updates.
            if step % GRADIENT_NORM_LOGGING_FREQUENCY == 0:
                with torch.no_grad():
                    for name, param in train_model.named_parameters():
                        if param.grad is not None:
                            param_norm = param.grad.norm(2).item()
                            logger(f"grad_norm/{name}", step, param_norm)
            # Log the learning rate used by this update
            current_lr = lr_scheduler.get_last_lr()[0]
            logger("learning_rate", step, current_lr)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            if step % SAVE_STEPS == 0:
                yield step
                # The caller may have switched the model to eval mode.
                train_model.train()


In [24]:
def save_checkpoint(train_model, optimizer, lr_scheduler, step):
    """Save a training checkpoint and prune old ones down to ``SAVE_LIMIT`` files.

    The checkpoint (model, optimizer and scheduler state plus ``step``) is
    written to ``checkpoint-{step}.bin`` inside ``train_checkpoint_path``.
    Only after the write succeeded are older checkpoints removed, keeping the
    ``SAVE_LIMIT - 1`` with the highest step index next to the new file, so a
    failed save never costs an existing checkpoint. The directory therefore
    briefly holds one file more than ``SAVE_LIMIT``.

    Returns:
        str: path of the checkpoint file just written.
    """
    os.makedirs(train_checkpoint_path, exist_ok=True)
    checkpoint = {
        'model_state_dict': train_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
        'step': step,
    }
    fname = os.path.join(train_checkpoint_path, f"checkpoint-{step}.bin")
    torch.save(checkpoint, fname)

    new_name = os.path.basename(fname)
    older_checkpoints = []
    for existing in os.listdir(train_checkpoint_path):
        if existing == new_name or not existing.startswith("checkpoint"):
            continue
        try:
            index = int(existing.split("-")[-1].replace(".bin", ""))
        except ValueError:
            # Not a file this function wrote; leave it alone.
            continue
        older_checkpoints.append((index, existing))
    older_checkpoints.sort()

    # Explicit count instead of a negative slice: `[:-(SAVE_LIMIT - 1)]`
    # becomes `[:0]` for SAVE_LIMIT == 1 and would delete nothing.
    keep_older = max(SAVE_LIMIT - 1, 0)
    stale_count = max(len(older_checkpoints) - keep_older, 0)
    for _, existing in older_checkpoints[:stale_count]:
        os.remove(os.path.join(train_checkpoint_path, existing))
    return fname

In [25]:
# Display the layer index the training run below will use.
layer_idx

1

In [None]:
# Training driver: evaluate once before training, then evaluate and checkpoint
# every time train() yields. Training stops early once the eval loss has
# failed to improve on more than EARLY_STOPPING_PATIENCE consecutive
# evaluations.
log_dir = os.path.join(LOG_DIR, f"layer-{layer_idx}")
writer = SummaryWriter(log_dir=log_dir)
try:
    initial_loss = evaluate(train_model, test_dataloader, verbose=True)
    writer.add_scalar('eval/loss', initial_loss, 0)
    eval_without_improvement = 0
    best_loss = None
    for save_step in train(
        train_model=train_model,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        logger=lambda metric, step, value: writer.add_scalar(f'train/{metric}', value, step)
    ):
        loss = evaluate(train_model, test_dataloader, verbose=False)
        writer.add_scalar('eval/loss', loss, save_step)
        if (best_loss is None) or (loss < best_loss):
            best_loss = loss
            eval_without_improvement = 0
        else:
            eval_without_improvement += 1
        if eval_without_improvement > EARLY_STOPPING_PATIENCE:
            break
        save_checkpoint(
            train_model=train_model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            step=save_step,
        )
finally:
    # Flush and release the event file even when training stops early, is
    # interrupted or raises.
    writer.close()

100%|██████████| 306/306 [00:22<00:00, 13.38it/s]


EPOCH 0
TRAINING


  8%|▊         | 10366/124957 [51:26<3:04:28, 10.35it/s]   

In [None]:
# NOTE(review): presumably a deliberate barrier so "Run All" stops here and
# the dataset-conversion cells below are only executed by hand. Confirm, and
# consider replacing it with an explicit flag or a message.
raise ValueError()

In [53]:
def convert_to_next_layer_dataset(ds, model, path, chunk_size):
    """Run every sample of ``ds`` through ``model`` and store the last hidden state.

    For each sample the model is called with a batch of one and
    ``output_hidden_states=True``; the sample's ``inputs_embeds`` is replaced
    by ``hidden_states[-1]`` (moved to CPU, batch dimension removed) while all
    other fields are carried over unchanged. Results are pickled to
    ``chunk-{i}.pkl`` files of ``chunk_size`` samples inside ``path``, each a
    dict keyed by the running sample index.

    The model is switched to eval mode and left there.

    NOTE(review): in Hugging Face Llama implementations ``hidden_states[-1]``
    appears to be taken after the model's final norm rather than directly at
    the decoder layer's output. Confirm this is the tensor the next layer
    should be trained on.

    Args:
        ds: iterable of dicts of tensors without a batch dimension.
        model: model accepting the sample fields as keyword arguments.
        path: output directory (created if missing).
        chunk_size: number of samples per chunk file.
    """
    model.eval()
    os.makedirs(path, exist_ok=True)

    def _write_chunk(chunk_index, chunk):
        # Persist one finished chunk under its sequential file name.
        full_path = os.path.join(path, f"chunk-{chunk_index}.pkl")
        with open(full_path, "wb") as dst:
            pickle.dump(chunk, dst)

    chunk = {}
    chunk_index = 0
    with torch.no_grad():
        for i, item in enumerate(tqdm(ds)):
            # Labels only drive the loss, which is not needed here; leaving
            # them out of the forward pass skips that computation.
            batch = {
                key: value.unsqueeze(0).to(model.device)
                for key, value in item.items()
                if key != "labels"
            }
            output = model(**batch, output_hidden_states=True)
            hidden_state = output.hidden_states[-1].squeeze(0)
            # Build a new dict instead of mutating the caller's sample.
            chunk[i] = {**item, "inputs_embeds": hidden_state.to("cpu")}
            if len(chunk) == chunk_size:
                _write_chunk(chunk_index, chunk)
                chunk_index += 1
                chunk = {}
    if len(chunk) > 0:
        _write_chunk(chunk_index, chunk)

In [None]:
# Store this layer's outputs as the inputs of the *next* layer. The target
# directory is derived from DATASET_PATH and layer_idx: the previous
# hard-coded "layer-1" path is the very directory `ds_embeddings_test` reads
# from when layer_idx == 1, so the run would overwrite its own input.
convert_to_next_layer_dataset(
    ds_embeddings_test,
    train_model,
    os.path.join(DATASET_PATH, f"test-layer-{layer_idx + 1}-inputs"),
    1000,
)

In [None]:
# Same for the train split: write the outputs of layer `layer_idx` to the
# input directory of layer `layer_idx + 1` rather than to a hard-coded path
# that can coincide with the directory being read.
convert_to_next_layer_dataset(
    ds_embeddings_train,
    train_model,
    os.path.join(DATASET_PATH, f"train-layer-{layer_idx + 1}-inputs"),
    1000,
)