In [10]:
! pip install datasets transformers evaluate
! pip install cloud-tpu-client==0.10 torch==2.0.0
! pip install https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
! pip install git+https://github.com/huggingface/accelerate

Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[K     |████████████████████████████████| 81 kB 8.1 MB/s  eta 0:00:01
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
[31mERROR: responses 0.18.0 has requirement urllib3>=1.25.10, but you'll have urllib3 1.25.8 which is incompatible.[0m
Installing collected packages: responses, evaluate
Successfully installed evaluate-0.4.0 responses-0.18.0
[31mERROR: torch_xla-2.0-cp310-cp310-linux_x86_64.whl is not a supported wheel on this platform.[0m
Collecting git+https://github.com/huggingface/accelerate
  Cloning https://github.com/huggingface/accelerate to /tmp/pip-req-build-a2a0c9__
  Running command git clone -q https://github.com/huggingface/accelerate /tmp/pip-req-build-a2a0c9__
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h    Preparing wheel metadata ... [?25ldone
Building wheels for collected packages: acceler

In [None]:
import os
from accelerate.utils import write_basic_config

# Persist a basic Accelerate configuration file.
write_basic_config()
# Terminate the kernel process immediately (exit status 0) so the notebook restarts.
os._exit(0)

  from .autonotebook import tqdm as notebook_tqdm


: 

: 

: 

: 

In [1]:
from torch.utils.data import DataLoader, Dataset

import torch 

class RandomIntDataset(Dataset):
    """Map-style dataset that returns a freshly drawn random token id on every access.

    Args:
        vocab_size: Exclusive upper bound of the sampled ids; each id is drawn
            from ``[0, vocab_size)``.
        num_samples: Length reported by ``len()``. Defaults to 10000, the value
            that used to be hard-coded.
    """

    def __init__(self, vocab_size, num_samples=10000):
        self.vocab_size = vocab_size
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # ``idx`` is not used: every item is an independent random draw of shape (1,).
        return {"input_ids": torch.randint(0, self.vocab_size, (1,))}

def create_dataloader(vocab_size, batch_size=8):
    """Build a ``DataLoader`` over a new ``RandomIntDataset``.

    Args:
        vocab_size: Forwarded to ``RandomIntDataset``.
        batch_size: Number of samples per batch (default 8).

    Returns:
        The ``DataLoader`` wrapping the dataset.
    """
    random_ids = RandomIntDataset(vocab_size)
    loader = DataLoader(random_ids, batch_size=batch_size)
    return loader

# Smoke test: build a loader and print the tensor shapes of its first batch.
dataloader = create_dataloader(32000)
for batch in dataloader:
    # Only the first batch is inspected; the loop stops right after printing.
    print(dict((name, tensor.shape) for name, tensor in batch.items()))
    break

{'input_ids': torch.Size([8, 1])}


In [2]:
from accelerate import Accelerator

import datasets
import transformers
from tqdm.auto import tqdm
from transformers import (
    AdamW,
    get_cosine_schedule_with_warmup,
    set_seed,
)
from torch.optim import AdamW


# Settings read by `training_loop` (optimizer, schedule length, data and seeding).
hyperparameters = {
    "learning_rate": 2e-5,
    "num_epochs": 3,
    "steps_per_epoch": 100,
    "validation_steps": 50,
    "batch_size": 8, # Per-process batch size; the actual batch size will be this x 8 when 8 processes are used
    "seed": 42,
    "vocab_size": 32000,
}

def training_loop(model):
    """Train the ``auxiliary_outputs`` of ``model`` on self-generated token sequences.

    Settings are read from the module-level ``hyperparameters`` dict. Each epoch
    starts from one batch of random token ids, performs ``steps_per_epoch``
    optimizer steps while appending one sampled token to the batch after every
    step, and then prints an evaluation loss averaged over ``validation_steps``
    forward passes.

    Args:
        model: Model exposing ``lm_head``, ``model`` and ``auxiliary_outputs``
            submodules, and whose forward pass returns an object with ``loss``
            and ``logits`` attributes. Only the parameters under
            ``auxiliary_outputs`` are left trainable; the model is modified in place.
    """
    accelerator = Accelerator()

    # To get a single message per log of Transformers or Datasets (instead of one per process), the main
    # process logs at WARNING (Datasets) / INFO (Transformers) and every other process at ERROR only.
    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    dataloader = create_dataloader(hyperparameters["vocab_size"], hyperparameters["batch_size"])

    set_seed(hyperparameters["seed"])

    # Freeze everything except the auxiliary outputs once, on the raw model and before `prepare`:
    # the choice never changes between epochs, and attribute access on the raw model keeps working
    # even if `prepare` returns a wrapped module.
    model.lm_head.requires_grad_(False)
    model.model.requires_grad_(False)
    model.auxiliary_outputs.requires_grad_(True)

    optimizer = AdamW(model.parameters(), lr=hyperparameters["learning_rate"])

    # Prepare everything.
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave
    # them to the prepare method.
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    num_epochs = hyperparameters["num_epochs"]
    steps_per_epoch = hyperparameters["steps_per_epoch"]
    validation_steps = hyperparameters["validation_steps"]
    total_steps = num_epochs * steps_per_epoch

    # The schedule length comes from the configured step counts, not from the dataloader length.
    # "num_warmup_steps" is optional and falls back to the previous fixed value of 100.
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=hyperparameters.get("num_warmup_steps", 100),
        num_training_steps=total_steps,
    )
    progress_bar = tqdm(range(total_steps), disable=not accelerator.is_main_process)

    for epoch in range(num_epochs):
        model.train()
        # Each epoch restarts from the first batch of a new iterator over the dataloader.
        batch = next(iter(dataloader))["input_ids"]
        for step in range(steps_per_epoch):
            outputs = model(batch)
            loss = outputs.loss
            # Presumably the last entry of `outputs.logits` holds the LM-head logits — TODO confirm.
            lm_head_logits = outputs.logits[-1]
            accelerator.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_description(f"Epoch {epoch} loss: {loss.item()}")

            # Sample one token from the distribution at the last position and append it, so the
            # sequence grows by one token per step. Sampling needs no gradient, hence the detach.
            next_token_probs = torch.softmax(lm_head_logits[:, -1, :].detach(), dim=-1)
            next_token = torch.multinomial(next_token_probs, 1)
            batch = torch.cat((batch, next_token), dim=-1)

        model.eval()
        batch = next(iter(dataloader))["input_ids"]
        eval_loss = 0.0
        # Evaluation needs no autograd graph; accumulating Python floats (instead of loss tensors)
        # avoids keeping every forward pass's graph alive.
        with torch.no_grad():
            # NOTE(review): `batch` is not extended here, so every pass sees the same input —
            # confirm whether evaluation should sample tokens the way training does.
            for step in range(validation_steps):
                outputs = model(batch)
                eval_loss += outputs.loss.item()
        loss = eval_loss / validation_steps

        accelerator.print(f"Epoch {epoch} loss: {loss}")
            

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from src.branchymodel import BranchyLlama

# Load the checkpoint's config, enable self-supervision and set the number of hidden
# layers to 2, then instantiate the model from the same checkpoint with that config.
checkpoint = "openlm-research/open_llama_3b_v2"
branchyllamaconf = BranchyLlama.config_class.from_pretrained(checkpoint)
branchyllamaconf.self_supervision = True
branchyllamaconf.num_hidden_layers = 2
model = BranchyLlama.from_pretrained(checkpoint, config=branchyllamaconf)

print(model)



Some weights of the model checkpoint at openlm-research/open_llama_3b_v2 were not used when initializing BranchyLlama: ['model.layers.5.self_attn.q_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.23.mlp.up_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.21.mlp.up_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.13.in

BranchyLlama(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 3200, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (k_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (v_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (o_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3200, out_features=8640, bias=False)
          (down_proj): Linear(in_features=8640, out_features=3200, bias=False)
          (up_proj): Linear(in_features=3200, out_features=8640, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )


In [5]:
from accelerate import notebook_launcher

# Run `training_loop(model)` in a single process with bf16 mixed precision.
notebook_launcher(
    training_loop,
    args=(model,),
    num_processes=1,
    mixed_precision="bf16",
)



Launching training on CPU.


Epoch 0 loss: 0.7097679376602173:   1%|          | 2/300 [00:01<02:47,  1.78it/s]

KeyboardInterrupt: 