In [1]:
import matplotlib.pyplot as plt

from pathlib import Path

import lightning as L
import torch
import torch.nn as nn
from lit_llama import model
import random
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
fabric = L.Fabric(devices=1)
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model")
tokenizer = Tokenizer(tokenizer_path)



In [3]:
import json
with open('datasets/alpaca_data_cleaned.json') as f:
    alpaca_json = json.load(f)

# Create tokenized j
alpaca_json_tokens = []

for item in alpaca_json:
    alpaca_json_tokens.append(
        {
            'instruction': tokenizer.encode(item['instruction'], bos=True, eos=False, device=fabric.device),
            'input': tokenizer.encode(item['input'], bos=False, eos=False, device=fabric.device),
            'output':tokenizer.encode(item['output'], bos=False, eos=True, device=fabric.device)
        }
    )

In [None]:
def get_batch(batch_size=10):
    batch_indices = random.sample(range(len(alpaca_json_tokens)), k=batch_size)

    # IST tokens
    IST_tokens = []
    for index in batch_indices:
        llama_input = torch.cat([alpaca_json_tokens[index]['instruction'], alpaca_json_tokens[index]['input']]).unsqueeze(0)
        IST_tokens.append(IST_generator(LLamaModel(llama_input)[1])[:,-1,:])

    # get shortest
    shortest_output_len = 1000
    for item in batch_indices:
        if(len(alpaca_json_tokens[item]['output']) < shortest_output_len):
            shortest_output_len = len(alpaca_json_tokens[item]['output'])


    length = random.randint(0,shortest_output_len-1)
    inputs = []
    targets = []

    for item in batch_indices:
        inputs.append(alpaca_json_tokens[item]['output'][:length])
        targets.append(alpaca_json_tokens[item]['output'][:length+1])
    
    return torch.stack(inputs), torch.stack(targets), torch.stack(IST_tokens)


In [None]:
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth")
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model")


def load_LLaMA(checkpoint_path):
    with lazy_load(checkpoint_path) as checkpoint:
        name = llama_model_lookup(checkpoint)

        with EmptyInitOnDevice(
                device=fabric.device, dtype=dtype, quantization_mode=None # We won't quantize the weights
        ):
            model = LLaMA.from_name(name)

        model.load_state_dict(checkpoint)
    return model

In [None]:

dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

LLaMA_config = model.LLaMAConfig.from_name('tiny')
print('Loading models...')
# Load the LLaMa model and the IST generator (also a LLaMA model)
#LLamaModel = load_LLaMA(checkpoint_path).to(fabric.device)
LLamaModel = LLaMA(LLaMA_config)
print('Finished loading the first model')
print('Finished loading models')
tokenizer = Tokenizer(tokenizer_path)

IST_schemes = ['vanilla', 'last 4', '2nd to last', 'all layers']
scheme_losses = {}

IST_generator = model.Block(LLaMA_config)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(IST_generator.parameters(), lr=1e-4)

Loading models...
Finished loading the first model
Finished loading models


In [None]:
LLamaModel = LLamaModel.to(fabric.device)
IST_generator = IST_generator.to(fabric.device)

In [None]:
losses = []

In [None]:
# Training loop
LLamaModel.train()
for epoch in range(100):
    inputs, targets, IST_tokens = get_batch(32)
    inputs = inputs.to(fabric.device)
    targets = targets.to(fabric.device)
    IST_tokens = IST_tokens.to(fabric.device)
    predicted_logits = LLamaModel(inputs, IST_tokens)[0]
    loss = loss_fn(predicted_logits.permute(0,2,1), targets.type(torch.LongTensor))
    optimizer.zero_grad()
    loss.backward()
    print(f'loss: {loss.detach().numpy()}')
    losses.append(loss.detach().numpy())
    optimizer.step()

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 15.74 GiB total capacity; 15.49 GiB already allocated; 2.69 MiB free; 15.54 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
plt.plot(losses)