# LLaMa Training & Generation Code
### Alex Mous and Elliott Zackrone
All code is replicated from the [LLaMa-Train](https://github.com/alex-mous/LLaMa-Train/) repository, with changes to fit this notebook's structure and to load data from Google Drive.

## Install Dependencies

In [None]:
!pip install xformers sentencepiece

# Model (Llama folder)

## Xformers_model.py

In [None]:
"""
Reduced scale single GPU LLaMa model with XFormers efficient attention and rotary embedding
Based off of original LLaMa model
"""

from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F

import xformers.ops as xops
from xformers.components.positional_embedding import RotaryEmbedding


@dataclass
class ModelArgs:
    """Hyperparameters consumed when the transformer modules in this file are built."""
    dim: int = 512  # embedding / residual-stream width
    n_layers: int = 8  # number of TransformerBlock layers stacked in the model
    n_heads: int = 8  # attention heads; each head gets dim // n_heads features
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5  # epsilon handed to every RMSNorm built from these args

    # NOTE(review): neither field below is read by the model code visible in this file — confirm they are still needed.
    max_batch_size: int = 32
    max_seq_len: int = 2048


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-feature gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at one so the layer is a pure normalisation at initialisation.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Multiply by the reciprocal RMS of the last axis; eps keeps rsqrt finite.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight


def xformers_attn(xq, xk, xv, is_causal):
    """
    Run xFormers memory-efficient attention on the given queries, keys and values.
    A lower-triangular mask is passed as the attention bias when is_causal is true,
    otherwise no bias is used.
    """
    attn_bias = None
    if is_causal:
        attn_bias = xops.LowerTriangularMask()
    return xops.memory_efficient_attention(xq, xk, xv, attn_bias=attn_bias)


class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and rotary position embedding."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        inner_dim = args.n_heads * self.head_dim

        # One bias-free matmul yields queries, keys and values side by side.
        self.in_proj = nn.Linear(args.dim, 3 * inner_dim, bias=False)
        self.out_proj = nn.Linear(inner_dim, args.dim, bias=False)

        self.pos_embed = RotaryEmbedding(self.head_dim)
        self.attn_fn = xformers_attn

    def forward(self, x: torch.Tensor, is_causal: bool = False):
        bsz, seqlen, _ = x.shape
        per_head_shape = (bsz, seqlen, self.n_local_heads, self.head_dim)

        # Split the fused projection into thirds and expose the head axis on each.
        xq, xk, xv = (
            part.view(per_head_shape)
            for part in self.in_proj(x).chunk(3, dim=-1)
        )

        # NOTE(review): q/k are laid out (batch, seq, heads, head_dim) here. The xFormers
        # RotaryEmbedding appears to take the sequence length from dim -2, which would be the
        # head axis in this layout — confirm positions are really encoded along the sequence.
        # Changing this would alter the behaviour of already-trained checkpoints.
        xq, xk = self.pos_embed(xq, xk)

        # The rotary step may change dtype; align q/k with v before attention.
        value_dtype = xv.dtype
        attended = self.attn_fn(
            xq.to(value_dtype),
            xk.to(value_dtype),
            xv,
            is_causal=is_causal
        )

        # Merge the heads back into a single feature axis.
        merged = attended.view(bsz, seqlen, -1)
        return self.out_proj(merged)


class FeedForward(nn.Module):
    """Gated feed-forward block: w2(silu(w1(x)) * w3(x)), all projections bias-free."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        # Take two thirds of the requested width, then round up to a multiple of `multiple_of`.
        two_thirds = int(2 * hidden_dim / 3)
        n_chunks = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * n_chunks

        # Creation order (w1, w2, w3) is kept so parameter initialisation is unchanged.
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then feed-forward, each wrapped in a residual."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        # Sub-modules are created in the original order so initialisation is unchanged.
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, is_causal: bool = True):
        # Residual around attention over the normalised input.
        attn_out = self.attention(self.attention_norm(x), is_causal)
        after_attn = x + attn_out
        # Residual around the feed-forward network over the normalised attention output.
        ffn_out = self.feed_forward(self.ffn_norm(after_attn))
        return after_attn + ffn_out


class XFormersTransformer(nn.Module):
    """Decoder stack: token embedding, n_layers TransformerBlocks, final RMSNorm, vocab projection."""

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.token_embeddings = nn.Embedding(
            num_embeddings=params.vocab_size,
            embedding_dim=params.dim
        )

        # Blocks are built in index order, matching the original append loop.
        self.layers = torch.nn.ModuleList(
            [TransformerBlock(layer_id, params) for layer_id in range(params.n_layers)]
        )

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            in_features=params.dim,
            out_features=params.vocab_size,
            bias=False
        )

    def forward(self, tokens: torch.Tensor, is_causal: bool = True):
        hidden = self.token_embeddings(tokens)
        for block in self.layers:
            hidden = block(hidden, is_causal)
        # Logits are produced for every position, not only the last one.
        logits = self.output(self.norm(hidden))
        return logits.float()


## Tokenizer.py

In [None]:
"""
SentencePieceProcessor-based Tokenizer, based off of original LLaMa tokenizer
"""

from sentencepiece import SentencePieceProcessor
from typing import List
import os


class Tokenizer:
    """Thin wrapper around a SentencePiece model that can add BOS/EOS ids when encoding."""

    def __init__(self, model_path: str):
        # Load the SentencePiece model from disk.
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # Mirror the vocabulary size and special-token ids exposed by the model.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Encode s into token ids, optionally wrapped with the BOS and/or EOS id."""
        assert type(s) is str
        body = self.sp_model.encode(s)
        prefix = [self.bos_id] if bos else []
        suffix = [self.eos_id] if eos else []
        return prefix + body + suffix

    def decode(self, t: List[int]) -> str:
        """Turn a list of token ids back into text."""
        text = self.sp_model.decode(t)
        return text

## Llama.py

In [None]:
import time
import torch
from torch import nn

from typing import Tuple, Optional
# from src.main.llama import Transformer, XFormersTransformer, Tokenizer, ModelArgs


def load_llama(
        tokenizer_path: str,
        initial_chkpt: Optional[str] = None,
        use_xformers: bool = False,
        new_chkpt: bool = False,  # New checkpointing method
        **model_args
) -> Tuple[nn.Module, Tokenizer]:
    """
    Load the tokenizer and build (or restore) the LLaMa model on the CPU.

    :param tokenizer_path: Path to the SentencePiece tokenizer model file
    :param initial_chkpt: Optional checkpoint to start from
    :param use_xformers: Must be True; only the XFormers model is supported
    :param new_chkpt: When True, initial_chkpt holds a state dict that is loaded into a
        freshly built model; when False, it holds a whole pickled model object
    :param model_args: Forwarded to ModelArgs (vocab_size is always taken from the tokenizer)
    :return: (model, tokenizer)
    """
    assert use_xformers, "Only XFormers model supported"
    start_time = time.time()
    print("Loading LLaMa model and tokenizer")
    tokenizer = Tokenizer(model_path=tokenizer_path)
    if initial_chkpt is not None and not new_chkpt:
        print(f"Loading initial checkpoint from {initial_chkpt}")
        # NOTE(security): this unpickles an arbitrary Python object; only use with trusted files.
        model = torch.load(initial_chkpt, map_location="cpu")
    else:
        model_params = ModelArgs(**model_args)
        model_params.vocab_size = tokenizer.n_words
        model = XFormersTransformer(model_params)
        if initial_chkpt is not None:
            print(f"Loading initial checkpoint from {initial_chkpt}")
            # Map to CPU like the branch above: checkpoints saved from a GPU would otherwise
            # fail to load on a CPU-only host and would needlessly allocate GPU memory.
            state_dict = torch.load(initial_chkpt, map_location="cpu")
            model.load_state_dict(state_dict)
    torch.set_default_tensor_type(torch.FloatTensor)
    print(f"Loaded model and tokenizer in {time.time() - start_time:.2f} seconds")
    return model, tokenizer


# Data Processing (Util folder)

## Data.py

In [None]:
"""
Data processing and loading
"""

from typing import Tuple, Optional
import json
import os
import time
from tqdm import tqdm
import math

import torch
from torch.utils.data import Dataset, DataLoader


# Default directory that load_pile_dataset joins data file names onto.
default_data_path: str = "."

# Directory where process_file caches tokenized datasets as .pt files.
artifacts_path: str = "artifacts/"


class PileDataset(Dataset):
    """
    Dataset for loading the Pile dataset.
    Loaded from an array of sequences, each of equal length.
    """
    def __init__(self, seqs: torch.Tensor):
        self.seqs = seqs

    def __getitem__(self, idx):
        """
        Return the sequence at idx.

        :raises IndexError: if an integer idx is out of range. Raising (instead of returning
            None) is what lets plain iteration over the dataset terminate and keeps None
            out of collated batches.
        """
        if isinstance(idx, int) and not -len(self) <= idx < len(self):
            raise IndexError(f"index {idx} is out of range for dataset of length {len(self)}")
        return self.seqs[idx]

    def __len__(self):
        return self.seqs.shape[0]


def _tokenize_line(line: str, tokenizer: "Tokenizer", max_seq_len: int, pad_id: int):
    """
    Tokenize a text line (with BOS, without EOS) and cut it into rows of max_seq_len tokens.

    Row k holds tokens k*max_seq_len .. (k+1)*max_seq_len - 1 of the line, in order; the
    tail of the last row is filled with pad_id. The previous implementation reshaped with
    view(max_seq_len, -1).t(), which interleaved tokens across rows whenever a line needed
    more than one row. Artifacts cached before this fix still contain that layout and must
    be regenerated.

    :return: Long tensor of shape (ceil(n_tokens / max_seq_len), max_seq_len)
    """
    line_tokens = torch.tensor(tokenizer.encode(line, bos=True, eos=False)).long()
    num_seqs = math.ceil(len(line_tokens) / max_seq_len)
    tokens = torch.full((num_seqs * max_seq_len, ), pad_id, dtype=torch.long)
    tokens[:len(line_tokens)] = line_tokens
    # Row-major split keeps consecutive tokens together inside each sequence.
    return tokens.view(num_seqs, max_seq_len)


def process_file(
        tokenizer: Tokenizer,
        data_file: str,
        max_seqs: int = 20000,
        max_seq_len: int = 2048
) -> torch.Tensor:
    """
    Process JSONL file into up to max_seqs seqs of tokens of length seq_len
    Returns tensor of sequences, each of seq_len with padding of tokenizer eos id

    The result is cached under artifacts_path as "<data file stem>.pt" and reused on later
    calls. A cached artifact is only reused when its sequence length equals max_seq_len.
    NOTE: the cache is not keyed on max_seqs, so an artifact built with a smaller max_seqs
    yields fewer rows than requested until it is deleted.

    :param tokenizer: Tokenizer used to encode each line's "text" field
    :param data_file: Path to a JSONL file with one object per line
    :param max_seqs: Maximum number of sequences to return
    :param max_seq_len: Length of every returned sequence
    :return: Tensor of dimension (up to max_seqs, seq_len)
    """
    if max_seqs <= 0:
        # Nothing requested: skip the file and, importantly, do not write an empty
        # artifact that later calls would pick up as a valid cache.
        return torch.empty((0, max_seq_len), dtype=torch.long)

    # check if corresponding artifact exists.
    artifact_dir = os.path.normpath(artifacts_path)
    artifact_path = os.path.join(artifact_dir, f"{os.path.splitext(os.path.basename(data_file))[0]}.pt")
    if os.path.isfile(artifact_path):
        print(f"Artifact tokens found. Loading tokenized dataset from {artifact_path}")
        cached = torch.load(artifact_path)
        if cached.dim() == 2 and cached.shape[1] == max_seq_len:
            return cached[:max_seqs]
        # A cache built for another sequence length would silently feed wrong shapes downstream.
        print(f"Artifact has shape {tuple(cached.shape)} but sequence length {max_seq_len} was requested. Re-tokenizing {data_file}.")
    else:
        # otherwise, parse file.
        print(f"No artifact found at {artifact_path}. Loading and tokenizing dataset from {data_file}.")

    # create artifacts dir if it doesn't exist
    os.makedirs(artifact_dir, exist_ok=True)

    pad_id = tokenizer.eos_id  # padding id
    seqs = torch.empty((max_seqs, max_seq_len), dtype=torch.long)  # sequences to parse

    # process data file into tokenized sequences padded to exactly max_seq_len
    curr = 0  # number of rows of seqs filled so far
    with open(data_file, "r", encoding="utf-8") as file, \
            tqdm(total=max_seqs, desc="Dataset loading: ") as p_bar:
        for jsonline in file:
            if curr >= max_seqs:
                break
            if not jsonline.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            raw = json.loads(jsonline)
            tokens = _tokenize_line(raw["text"], tokenizer, max_seq_len, pad_id)
            num_toks = min(tokens.shape[0], max_seqs - curr)
            seqs[curr:curr + num_toks] = tokens[:num_toks]
            curr += num_toks
            p_bar.update(num_toks)

    # Drop rows that were never written: torch.empty leaves them uninitialised, so keeping
    # them would put arbitrary token ids in the dataset. clone() makes torch.save store
    # only the kept rows rather than the whole backing buffer.
    seqs = seqs[:curr].clone()

    # save artifact and return
    torch.save(seqs, artifact_path)
    return seqs


def load_pile_dataset(
        tokenizer: Tokenizer,
        train_file: str,
        val_file: str,
        test_file: str = "",
        num_train: int = 20000,
        num_val: int = 10000,
        num_test: int = 0,
        max_seq_len: int = 2048,
        data_path: str = default_data_path
) -> Tuple[PileDataset, PileDataset, Optional[PileDataset]]:
    """
    Tokenize the train/val (and optionally test) files under data_path and wrap each in a
    PileDataset. The test dataset is only built when num_test > 0; otherwise None is
    returned in its place.
    """
    print("Loading Pile dataset...")
    started_at = time.time()

    def _tokens_for(file_name: str, count: int) -> torch.Tensor:
        # Resolve the file against data_path and tokenize (or load the cached artifact).
        return process_file(tokenizer, os.path.join(data_path, file_name), count, max_seq_len)

    # Files are processed in the order train, val, test.
    train = PileDataset(_tokens_for(train_file, num_train))
    val = PileDataset(_tokens_for(val_file, num_val))
    test = PileDataset(_tokens_for(test_file, num_test)) if num_test > 0 else None

    print(f"Loaded dataset in {time.time() - started_at:.2f} seconds")
    return train, val, test


def get_pile_dataloaders(
        train_set: "PileDataset",
        val_set: "PileDataset",
        test_set: "Optional[PileDataset]" = None,
        batch_size: int = 32
):
    """
    Get dataloaders for train, val, and test datasets with given batch size

    :param train_set: Training dataset; its loader shuffles
    :param val_set: Validation dataset; its loader keeps dataset order
    :param test_set: Optional test dataset; its loader keeps dataset order
    :param batch_size: Batch size shared by all loaders
    :return: (train_loader, val_loader, test_loader); test_loader is None when no test set
        is given, rather than a DataLoader wrapped around None that fails on first use
    """
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_loader = None
    if test_set is not None:
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader


## Metrics.py

In [None]:
"""
Compute metrics based on a model and dataset loader
"""

import torch
from torch import nn
from torch.utils.data import DataLoader

#from src.main.llama import Transformer


def evaluate(model: XFormersTransformer, eval_dataloader: DataLoader, loss_fn: nn.CrossEntropyLoss) -> float:
    """
    Compute the mean per-batch loss of model over eval_dataloader, without gradients.

    Batches are moved to the device the model's parameters live on, so the function does
    not depend on a module-level `device` defined in another cell. The model's
    train/eval mode is restored before returning, so calling this in the middle of a
    training epoch does not leave the model in eval mode.

    :return: Mean loss over batches, or NaN when the dataloader yields no batches
    """
    import gc  # local import: this cell does not import gc at the top

    num_batches = len(eval_dataloader)
    if num_batches == 0:
        return float("nan")  # no data to average over

    eval_device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    val_loss = 0.0
    try:
        with torch.no_grad():
            for batch in eval_dataloader:
                tokens = batch.to(eval_device)
                val_loss += compute_loss(model, tokens, loss_fn).item()
    finally:
        model.train(was_training)
        gc.collect()
        torch.cuda.empty_cache()
    return val_loss / num_batches


def compute_loss(model: "XFormersTransformer", tokens: torch.Tensor, loss_fn: nn.CrossEntropyLoss) -> torch.Tensor:
    """
    Compute the next-token loss on the input batch of tokens

    The model sees every token but the last; the logits at position t are scored against
    the token at position t + 1. Whether padding is ignored depends on how loss_fn was
    configured by the caller (e.g. its ignore_index).

    :param model: Module called as model(tokens, is_causal=True), returning per-position logits
    :param tokens: Batch of token ids, indexed as (batch, position)
    :param loss_fn: Loss applied to the flattened logits and flattened shifted targets
    :return: The loss value returned by loss_fn
    """
    # Call the module itself rather than .forward so registered hooks are honoured.
    logits = model(tokens[:, :-1], is_causal=True)
    # Flatten on the logits' own last dimension instead of reaching into model.params.
    flattened_logits = logits.reshape(-1, logits.size(-1))
    shift_tokens = tokens[:, 1:].reshape(-1)  # targets start at the second token
    return loss_fn(flattened_logits, shift_tokens)


def get_number_of_parameters(model: nn.Module):
    """Return the total element count of all parameters of model that require gradients."""
    return sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad
    )


## Checkpoints.py

In [None]:
"""
Load and save checkpoints for model and optimizer
"""

import os

import torch
from torch import nn, optim


def generate_checkpoint_name(checkpoints_base_path: str, epoch: int, new_type: bool):
    """
    Build the checkpoint file path for the given epoch label

    :param checkpoints_base_path: Path to directory to store checkpoints in
    :param epoch: Label interpolated into the file name as "chkpt-<epoch>"; the training
        loop in this file also passes strings such as "1-end" here
    :param new_type: When True, "-light" is appended before the ".pt" extension
    :return: Checkpoint path and name
    """
    suffix = "-light" if new_type else ""
    file_name = f"chkpt-{epoch}{suffix}.pt"
    return os.path.join(checkpoints_base_path, file_name)


# Main Scripts

In [None]:
# Mount Google Drive
# Ensure that folder LLaMaTrain exists in My Drive, with tokenizer.model, tiny_train.jsonl, and tiny_val.jsonl before running the following training code
from google.colab import drive
import locale
# Override the locale lookup so it always reports UTF-8.
# NOTE(review): presumably a workaround for a Colab encoding issue — confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"
drive.mount('/gdrive')

# Base folder on the mounted drive; later cells build tokenizer, data and checkpoint paths from it.
gdrive_path = "/gdrive/MyDrive/LLaMaTrain/"
# Rebinds the module-level artifacts_path read by process_file, so token caches are stored on Drive.
artifacts_path = gdrive_path + "artifacts/"

In [None]:
def load_model_and_data(
    storage_base_path: str,
    tokenizer_path: str,
    train_path: str,
    val_path: str,
    initial_chkpt: Optional[str] = None,
    num_train: int = 20000,
    num_val: int = 10000,
    max_seq_len: int = 512,
    batch_size: int = 16,
    new_chkpt_format: bool = False,
    **model_args
) -> Tuple[XFormersTransformer, Tokenizer, DataLoader, DataLoader]:
    """
    Build the model, tokenizer and train/val dataloaders used for training, evaluation
    or generation.

    tokenizer_path, train_path and val_path are appended to storage_base_path by plain
    string concatenation, so storage_base_path is expected to end with a path separator.
    Extra keyword arguments are forwarded to the model configuration.
    """
    tokenizer_file = storage_base_path + tokenizer_path
    train_file = storage_base_path + train_path
    val_file = storage_base_path + val_path

    assert os.path.isfile(tokenizer_file), "LLaMa tokenizer pretrained model file required"
    assert os.path.isfile(train_file), "Train data subset in JSONL format required"
    assert os.path.isfile(val_file), "Validation data subset in JSONL format required"

    # Free cached GPU memory before building the model.
    torch.cuda.empty_cache()
    model, tokenizer = load_llama(
        tokenizer_path=tokenizer_file,
        initial_chkpt=initial_chkpt,
        use_xformers=True,
        new_chkpt=new_chkpt_format,
        max_seq_len=max_seq_len,
        **model_args
    )

    # Tokenize (or load cached) datasets; the optional test split is not requested here.
    datasets = load_pile_dataset(
        tokenizer,
        train_file,
        val_file,
        num_train=num_train,
        num_val=num_val,
        max_seq_len=max_seq_len,
    )
    train_set, val_set = datasets[0], datasets[1]

    loaders = get_pile_dataloaders(train_set, val_set, batch_size=batch_size)
    train_dataloader, val_dataloader = loaders[0], loaders[1]

    return model, tokenizer, train_dataloader, val_dataloader

## Example Batches

In [None]:
# OPTIONAL CODE - NOT REQUIRED FOR TRAINING
# Show example batches from Pile dataset

# Builds a model with default ModelArgs (no checkpoint) plus train/val loaders with 4 sequences per batch.
# NOTE: this rebinds the notebook-level names model, train_dataloader and val_dataloader.
model, _, train_dataloader, val_dataloader = load_model_and_data(
    gdrive_path, 
    tokenizer_path="tokenizer.model",
    train_path="data-750k.jsonl",
    val_path="tiny_val.jsonl",
    num_train=750000,
    num_val=10000,
    batch_size=4
)

# The train loader shuffles, so the printed train batch differs between runs.
print("Example train batch")
print(next(iter(train_dataloader)))
# The val loader keeps dataset order, so this is always the first val batch.
print("Example val batch")
print(next(iter(val_dataloader)))

## Training

In [None]:
from typing import Tuple, Optional
import torch
import time
import gc
import os
from tqdm import tqdm

from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch import nn
from torch.optim import AdamW

# from src.main.llama import XFormersTransformer, Tokenizer, load_llama
# from src.main.util import get_pile_dataloaders, load_pile_dataset
# from src.main.util import compute_loss
# from src.main.util import save_checkpoint, load_checkpoint, generate_checkpoint_name

# Notebook-wide compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def train(
        model: XFormersTransformer,
        tokenizer: Tokenizer,
        train_loader: DataLoader,
        val_loader: DataLoader,
        epochs: int,
        lr: float,
        weight_decay: float,
        grad_clip: float = 1.0,  # clipping for all gradients
        chkpt_dir: Optional[str] = None,  # checkpoint directory
        batch_save_freq: int = -1,  # save after this many batches. -1 means only saving at the end of an epoch
):
    """
    Main training loop, including checkpointing

    Trains with AdamW and cross-entropy (tokenizer.eos_id is ignored as padding), clips
    the gradient norm to grad_clip, and prints train/val loss after every epoch.

    :param chkpt_dir: Directory for checkpoints; created if missing. When None, no
        checkpoints are written at all.
    :param batch_save_freq: When > 0, checkpoint, evaluate and print a summary every this
        many batches within an epoch
    """
    # Use the directory that was passed in, not a notebook-level global.
    if chkpt_dir is not None:
        os.makedirs(chkpt_dir, exist_ok=True)

    def save_checkpoint(tag: str) -> None:
        # Write the model state dict under the given epoch/batch tag, if checkpointing is on.
        if chkpt_dir is None:
            return
        torch.save(model.state_dict(), generate_checkpoint_name(chkpt_dir, tag, True))

    model.to(device)
    loss_fn = CrossEntropyLoss(ignore_index=tokenizer.eos_id)  # ignore padding
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(epochs):
        # Re-enter training mode every epoch; evaluation may have switched it off.
        model.train()
        train_loss = 0.0
        try:
            for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
                # compute loss between predictions (logits) and tokens
                # each prediction corresponds to the next token, so we shift tokens by one
                tokens = batch.to(device)
                optimizer.zero_grad()
                loss = compute_loss(model, tokens, loss_fn)  # compute logits on all using all but last token
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                train_loss += loss.item()

                if batch_save_freq > 0 and (i+1) % batch_save_freq == 0:
                    # File names keep the zero-based batch index used by existing checkpoints.
                    save_checkpoint(f"{epoch+1}-batch-{i}")

                    val_loss = evaluate(model, val_loader, loss_fn)
                    model.train()  # evaluation may leave the model in eval mode

                    # i + 1 batches have been accumulated so far (dividing by i was off by
                    # one and divided by zero on the very first batch).
                    print(f"Epoch {epoch+1}. Batch {i}. Train loss: {train_loss/(i+1)}. Val loss: {val_loss}")
        finally:
            # Save checkpoint even when the epoch was interrupted, so progress is not lost
            save_checkpoint(f"{epoch+1}-end")
            # garbage collect to process next batch
            gc.collect()
            torch.cuda.empty_cache()

        val_loss = evaluate(model, val_loader, loss_fn)

        # Print summary
        train_loss /= max(len(train_loader), 1)  # guard against an empty loader
        print(f"Epoch {epoch+1}. Train loss: {train_loss}. Val loss: {val_loss}")


# Checkpoints
# chkpt_base: directory this run writes its checkpoints to.
# NOTE(review): the run names say "heads-8-layers-4" while the model below is built with
# n_layers=8 and n_heads=4 — confirm which one is intended.
chkpt_base = gdrive_path + "checkpoints/" + "dim-256-heads-8-layers-4-huge-run-part-3/"
ichkpt_path = gdrive_path + "checkpoints/" + "dim-256-heads-8-layers-4-huge-run-part-2/" + "chkpt-1-batch-17999-light.pt" # initial checkpoint, if any

# Training parameters
epochs = 2
lr = 3e-4
weight_decay = 0.01

# Build the model from the state-dict checkpoint above and load the train/val data.
# NOTE(review): num_train asks for 1,500,000 sequences from a file named data-750k.jsonl —
# confirm the file can actually supply that many.
model, tokenizer, train_dataloader, val_dataloader = load_model_and_data(
    gdrive_path, 
    tokenizer_path="tokenizer.model",
    train_path="data-750k.jsonl",
    val_path="tiny_val.jsonl",
    num_train=1500000,
    num_val=10000,
    batch_size=32,
    dim=256,
    n_layers=8,
    n_heads=4,
    initial_chkpt=ichkpt_path,
    new_chkpt_format=True
)

try:
    # Train model
    train(
        model,
        tokenizer,
        train_dataloader,
        val_dataloader,
        lr=lr,
        epochs=epochs,
        weight_decay=weight_decay,
        chkpt_dir=chkpt_base,
        batch_save_freq=3000
    )
finally:
    # Ensure model is on CPU so it can be garbage collected
    # Runs whether training finished or raised, so GPU memory is released either way.
    print("Cleaning up...")
    model.cpu()
    del model, tokenizer, train_dataloader, val_dataloader
    gc.collect()
    torch.cuda.empty_cache()


## Evaluation

In [None]:
# Load model and dataloaders

# Checkpoint to evaluate: a state-dict ("-light") file written at the end of epoch 1.
chkpt_base = gdrive_path + "checkpoints/" + "dim-256-heads-8-layers-8-big-run/"
chkpt_name = "chkpt-1-end-light.pt"

assert os.path.isfile(chkpt_base + chkpt_name), "Initial checkpoint required to eval"

# The training loader is discarded; num_train=0 requests no training sequences.
model, tokenizer, _, val_dataloader = load_model_and_data(
    gdrive_path, 
    tokenizer_path="tokenizer.model",
    train_path="tiny_train.jsonl",
    val_path="tiny_val.jsonl",
    num_train=0,
    num_val=10000,
    batch_size=64,
    dim=256,
    n_layers=8,
    n_heads=8,
    initial_chkpt=chkpt_base + chkpt_name,
    new_chkpt_format=True
)


# Perform evaluation
model.to(device)
model.eval()
# Same loss configuration as training: targets equal to the EOS id (used as padding) are ignored.
loss_fn = CrossEntropyLoss(ignore_index=tokenizer.eos_id)
print(f"Validation loss on model: {evaluate(model, val_dataloader, loss_fn)}")
# Move the model off the GPU and drop the reference so its memory can be reclaimed.
model.cpu()
torch.cuda.empty_cache()
del model

## Generation

In [None]:
class LLaMAInference:
    """Batched autoregressive text generation with a trained model and its tokenizer."""

    def __init__(
        self,
        model: XFormersTransformer,
        tokenizer: Tokenizer
    ):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int = 512,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        """
        Continue each prompt until the sequence reaches max_gen_len tokens.

        :param prompts: Prompt strings; each is encoded with a leading BOS token
        :param max_gen_len: Total sequence length (prompt plus generated tokens); every
            encoded prompt must be shorter than this
        :param temperature: Softmax temperature; 0 or less selects greedy decoding
        :param top_p: Nucleus-sampling threshold used when temperature > 0
        :return: One decoded string per prompt (prompt text included), cut at the first
            EOS token
        """
        if not prompts:
            return []

        self.model.to(device)
        self.model.eval()
        eos_id = self.tokenizer.eos_id
        batch_size = len(prompts)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        min_prompt_size = min(len(t) for t in prompt_tokens)
        max_prompt_size = max(len(t) for t in prompt_tokens)
        assert max_prompt_size < max_gen_len

        # Allocate on the same device as the model instead of hard-coding .cuda().
        tokens = torch.full((batch_size, max_gen_len), eos_id, dtype=torch.long, device=device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
        input_text_mask = (tokens != eos_id)  # True where a prompt token is already present

        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        with torch.no_grad():  # inference only: do not build an autograd graph
            # Generate up to max_gen_len (this bound used to be hard-coded to 100).
            for cur_pos in range(min_prompt_size, max_gen_len):
                # The logits at the last fed position predict the token at cur_pos, so feed
                # only the prefix; the pad token sitting at cur_pos must not be an input.
                logits = self.model(tokens[:, :cur_pos])[:, -1, :]
                if temperature > 0:
                    probs = torch.softmax(logits / temperature, dim=-1)
                    next_token = sample_top_p(probs, top_p)
                else:
                    next_token = torch.argmax(logits, dim=-1)
                next_token = next_token.reshape(-1)

                # Rows whose prompt extends past cur_pos keep their prompt token.
                next_token = torch.where(
                    input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
                )
                tokens[:, cur_pos] = next_token

                # Stop early once every row has emitted EOS outside its prompt.
                finished |= (~input_text_mask[:, cur_pos]) & (next_token == eos_id)
                if bool(finished.all()):
                    break

        decoded = []
        for row in tokens.tolist():
            # Drop EOS and everything after it (unused padding or post-EOS tokens).
            if eos_id in row:
                row = row[:row.index(eos_id)]
            decoded.append(self.tokenizer.decode(row))
        return decoded


def sample_top_p(probs, p):
    """
    Nucleus (top-p) sampling over the last axis of probs.

    A token is kept when the probability mass of all more-likely tokens does not exceed p;
    the kept probabilities are renormalised and one token index is sampled per row.

    :return: Tensor of sampled indices into the original (unsorted) last axis, with a
        trailing dimension of size 1
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Mass strictly before each token, i.e. the cumulative sum excluding the token itself.
    outside_nucleus = (cumulative - sorted_probs) > p
    kept = sorted_probs.masked_fill(outside_nucleus, 0.0)
    kept = kept / kept.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(kept, num_samples=1)
    # Translate positions in the sorted order back to vocabulary indices.
    return torch.gather(sorted_idx, -1, choice)

In [None]:
# Load model and dataloaders

# Checkpoint to generate from: a state-dict ("-light") file written at the end of epoch 1.
chkpt_base = gdrive_path + "checkpoints/" + "dim-256-heads-8-layers-8-big-run/"
chkpt_name = "chkpt-1-end-light.pt"

assert os.path.isfile(chkpt_base + chkpt_name), "Initial checkpoint required to eval"

# Only the model and tokenizer are needed; both dataloaders are discarded and no sequences are requested.
model, tokenizer, _, _ = load_model_and_data(
    gdrive_path, 
    tokenizer_path="tokenizer.model",
    train_path="tiny_train.jsonl",
    val_path="tiny_val.jsonl",
    num_train=0,
    num_val=0,
    batch_size=64,
    dim=256,
    n_layers=8,
    n_heads=8,
    initial_chkpt=chkpt_base + chkpt_name,
    new_chkpt_format=True
)

inference_model = LLaMAInference(model, tokenizer)

# Prompts of different lengths, generated together as one batch.
prompts = ["The history of Spain can be summarized through the reigns of various kings and queens which", "Our Golf Umbrellas will score a Hole-in-One with your loyal golf-loving customers and clients. Every time they hit the greens with one of your promotional golf umbrellas", "The meaning of life", "The cat sat"]

try:
  responses = inference_model.generate(prompts, temperature=0.8)
finally:
  # Cleanup before exiting
  # Runs even when generation raises, so GPU memory is released either way.
  inference_model.model.cpu()
  gc.collect()
  torch.cuda.empty_cache()

# Only reached when generation succeeded; otherwise the exception has already propagated.
for response in responses:
  print(response)