In [1]:
import matplotlib.pyplot as plt

import sys
import time
import warnings
from pathlib import Path
from typing import Optional

import lightning as L
import torch
import torch.nn as nn
from lit_llama import model
import pandas as pd
import os
import gc # chasing mmeory leaks
import random
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup

In [2]:
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model")

In [3]:
fabric = L.Fabric(devices=1)

tokenizer = Tokenizer(tokenizer_path)

In [4]:
import json
with open('datasets/alpaca_data_cleaned.json') as f:
    alpaca_json = json.load(f)

# Create tokenized j
alpaca_json_tokens = []

for item in alpaca_json:
    #print(item)
    alpaca_json_tokens.append(
        {
            'instruction': tokenizer.encode(item['instruction'], bos=True, eos=False, device=fabric.device),
            'input': tokenizer.encode(item['input'], bos=False, eos=False, device=fabric.device),
            'output':tokenizer.encode(item['output'], bos=False, eos=True, device=fabric.device)
        }
    )

In [5]:
def get_batch(batch_size=10):
    batch_indices = random.sample(range(len(alpaca_json_tokens)), k=batch_size)

    # IST tokens
    IST_tokens = []
    for index in batch_indices:
        llama_input = torch.cat([alpaca_json_tokens[index]['instruction'], alpaca_json_tokens[index]['input']]).unsqueeze(0)
        IST_tokens.append(IST_generator(LLamaModel(llama_input)[1])[:,-1,:])

    # get shortest
    shortest_output_len = 1000
    for item in batch_indices:
        if(len(alpaca_json_tokens[item]['output']) < shortest_output_len):
            shortest_output_len = len(alpaca_json_tokens[item]['output'])


    length = random.randint(0,shortest_output_len-1)
    inputs = []
    targets = []

    for item in batch_indices:
        inputs.append(alpaca_json_tokens[item]['output'][:length])
        targets.append(alpaca_json_tokens[item]['output'][:length+1])
    
    return torch.stack(inputs), torch.stack(targets), torch.stack(IST_tokens)


In [6]:
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth")
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model")


def load_LLaMA(checkpoint_path):
    with lazy_load(checkpoint_path) as checkpoint:
        name = llama_model_lookup(checkpoint)

        with EmptyInitOnDevice(
                device=fabric.device, dtype=dtype, quantization_mode=None # We won't quantize the weights
        ):
            model = LLaMA.from_name(name)

        model.load_state_dict(checkpoint)
    return model

In [7]:

dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

LLaMA_config = model.LLaMAConfig.from_name('7B')
print('Loading models...')
# Load the LLaMa model and the IST generator (also a LLaMA model)
LLamaModel = load_LLaMA(checkpoint_path)
#LLamaModel.load_state_dict(torch.load("checkpoints/lit-llama/7B/lit-llama.pth"))
print('Finished loading the first model')
print('Finished loading models')
tokenizer = Tokenizer(tokenizer_path)

IST_schemes = ['vanilla', 'last 4', '2nd to last', 'all layers']
#IST_schemes = ['vanilla',]
scheme_losses = {}

IST_generator = model.Block(LLaMA_config).to(fabric.device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(IST_generator.parameters(), lr=1e-4)

Loading models...
Finished loading the first model
Finished loading models


In [8]:
'''
LLamaModel = LLamaModel.to(fabric.device)
IST_generator = IST_generator.to(fabric.device)
'''

'\nLLamaModel = LLamaModel.to(fabric.device)\nIST_generator = IST_generator.to(fabric.device)\n'

In [9]:
losses = []

In [10]:
for param in LLamaModel.parameters():
    param.requires_grad=False

In [11]:
# Training loop
LLamaModel.train()
for epoch in range(100):
    inputs, targets, IST_tokens = get_batch(1)
    inputs = inputs.to(fabric.device)
    targets = targets.to(fabric.device)
    IST_tokens = IST_tokens.to(fabric.device).type(torch.bfloat16)
    
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16) as autocast, torch.backends.cuda.sdp_kernel(enable_flash=False) as disable:
        predicted_logits = LLamaModel(inputs, IST_tokens)[0]
    
    loss = loss_fn(predicted_logits.permute(0,2,1).to(fabric.device), targets.type(torch.LongTensor).to(fabric.device))
    optimizer.zero_grad()
    loss.backward()
    print(f'loss: {loss.item()}')
    losses.append(loss.item())
    optimizer.step()

loss: 9.1875


OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB (GPU 0; 15.74 GiB total capacity; 15.51 GiB already allocated; 10.69 MiB free; 15.52 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
!pip list

In [None]:
plt.plot(losses)