# LLaMa Training & Generation Code
### Alex Mous and Elliott Zackrone
All code is replicated from the [LLaMa-Train](https://github.com/alex-mous/LLaMa-Train/) repository, with changes to work with the new structure and to load files from Google Drive.

## Install Dependencies

In [None]:
!pip install xformers sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting xformers
  Downloading xformers-0.0.20-cp310-cp310-manylinux2014_x86_64.whl (109.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 109.1/109.1 MB 10.2 MB/s eta 0:00:00
Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 43.4 MB/s eta 0:00:00
Collecting pyre-extensions==0.0.29 (from xformers)
  Downloading pyre_extensions-0.0.29-py3-none-any.whl (12 kB)
Collecting typing-inspect (from pyre-extensions==0.0.29->xformers)
  Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)
Collecting mypy-extensions>=0.3.0 (from typing-inspect->pyre-extensions==0.0.29->xformers)
  Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)
Installing collected packages: sent

# Model (Llama folder)

## Xformers_model.py

In [None]:
"""
Reduced scale single GPU LLaMa model with XFormers efficient attention and rotary embedding
Based off of original LLaMa model
"""

from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F

import xformers.ops as xops
from xformers.components.positional_embedding import RotaryEmbedding


@dataclass
class ModelArgs:
    """Hyperparameters for the reduced-scale LLaMa model.

    ``load_llama`` builds this from its ``**model_args`` and then overwrites
    ``vocab_size`` with the tokenizer's vocabulary size.
    """
    dim: int = 512  # model width (embedding size); head_dim = dim // n_heads
    n_layers: int = 8  # number of TransformerBlocks
    n_heads: int = 8  # number of attention heads per block
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5  # epsilon passed to every RMSNorm

    # NOTE(review): the two fields below are stored but not read by the model
    # classes visible in this notebook (kept for parity with the original LLaMa args).
    max_batch_size: int = 32
    max_seq_len: int = 2048


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    The statistic is computed in float32; the result is cast back to the
    input dtype before the per-feature gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # per-feature gain, initialised to 1 so the layer starts as pure normalisation
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x):
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight


def xformers_attn(xq, xk, xv, is_causal):
    """Run xformers memory-efficient attention, with a lower-triangular mask when causal.

    ``Attention.forward`` passes q/k/v shaped (batch, seq, heads, head_dim).
    """
    if is_causal:
        attn_bias = xops.LowerTriangularMask()
    else:
        attn_bias = None
    return xops.memory_efficient_attention(xq, xk, xv, attn_bias=attn_bias)


class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and xformers kernels."""

    def __init__(self, args: ModelArgs):
        super(Attention, self).__init__()

        self.n_local_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads

        # single fused projection; its output is split into q, k and v in forward()
        self.in_proj = nn.Linear(
            args.dim,
            3 * args.n_heads * self.head_dim,
            bias=False
        )
        self.out_proj = nn.Linear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False
        )

        self.pos_embed = RotaryEmbedding(self.head_dim)
        self.attn_fn = xformers_attn

    def forward(self, x: torch.Tensor, is_causal: bool = False):
        """
        Attend over ``x`` of shape (batch, seq, dim) and return a tensor of the same shape.

        :param x: input activations
        :param is_causal: apply a lower-triangular (causal) attention mask
        :return: projected attention output, (batch, seq, dim)
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.in_proj(x).chunk(3, dim=-1)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        # xformers' RotaryEmbedding takes positions from dim -2 and broadcasts its
        # cos/sin tables as (1, 1, seq, head_dim), i.e. it expects (batch, heads, seq,
        # head_dim). Feeding it the (batch, seq, heads, head_dim) layout would rotate
        # by head index instead of token position, so transpose in and back out.
        xq, xk = self.pos_embed(xq.transpose(1, 2), xk.transpose(1, 2))
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)

        output = self.attn_fn(
            xq.to(xv.dtype),
            xk.to(xv.dtype),
            xv,
            is_causal=is_causal
        )

        # merge heads: (batch, seq, heads, head_dim) -> (batch, seq, heads * head_dim)
        output = output.reshape(bsz, seqlen, -1)

        return self.out_proj(output)


class FeedForward(nn.Module):
    """SwiGLU feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    The requested ``hidden_dim`` is scaled by 2/3 and then rounded up to the
    next multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        scaled = int(2 * hidden_dim / 3)
        n_blocks = (scaled + multiple_of - 1) // multiple_of  # ceiling division
        inner_dim = n_blocks * multiple_of

        # modules are created in the order w1, w2, w3 so random initialisation is unchanged
        self.w1 = nn.Linear(dim, inner_dim, bias=False)  # gated branch (passed through SiLU)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)  # projection back to model width
        self.w3 = nn.Linear(dim, inner_dim, bias=False)  # linear branch multiplied by the gate

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention and RMSNorm -> SwiGLU, each with a residual."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        # submodules are created in the original order so initialisation and
        # state_dict key order are unchanged
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, is_causal: bool = True):
        attn_out = self.attention(self.attention_norm(x), is_causal)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out


class XFormersTransformer(nn.Module):
    """Decoder-only LLaMa-style language model built from ``TransformerBlock`` layers.

    ``forward`` maps integer token ids to float32 logits for every position.
    """

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.token_embeddings = nn.Embedding(
            num_embeddings=params.vocab_size,
            embedding_dim=params.dim,
        )
        self.layers = torch.nn.ModuleList(
            [TransformerBlock(idx, params) for idx in range(params.n_layers)]
        )
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            in_features=params.dim,
            out_features=params.vocab_size,
            bias=False,
        )

    def forward(self, tokens: torch.Tensor, is_causal: bool = True):
        hidden = self.token_embeddings(tokens)
        for block in self.layers:
            hidden = block(hidden, is_causal)
        # logits are produced for every position, not just the last one
        logits = self.output(self.norm(hidden))
        return logits.float()


## Tokenizer.py

In [None]:
"""
SentencePieceProcessor-based Tokenizer, based off of original LLaMa tokenizer
"""

from sentencepiece import SentencePieceProcessor
from typing import List
import os


class Tokenizer:
    """Thin wrapper over a SentencePiece model that optionally adds BOS/EOS ids."""

    def __init__(self, model_path: str):
        sp = SentencePieceProcessor(model_file=model_path)
        self.sp_model = sp

        # vocabulary size and special-token ids as reported by the SentencePiece model
        self.n_words: int = sp.vocab_size()
        self.bos_id: int = sp.bos_id()
        self.eos_id: int = sp.eos_id()
        self.pad_id: int = sp.pad_id()
        assert sp.vocab_size() == sp.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Encode ``s`` into token ids, optionally prefixed with BOS and suffixed with EOS."""
        assert type(s) is str
        head = [self.bos_id] if bos else []
        tail = [self.eos_id] if eos else []
        return head + self.sp_model.encode(s) + tail

    def decode(self, t: List[int]) -> str:
        """Decode a list of token ids back into text."""
        return self.sp_model.decode(t)

## Llama.py

In [None]:
import time
import torch
from torch import nn

from typing import Tuple, Optional
# from src.main.llama import Transformer, XFormersTransformer, Tokenizer, ModelArgs


def load_llama(
        tokenizer_path: str,
        initial_checkpoint: Optional[str],
        use_xformers: bool = False,
        **model_args
) -> Tuple[nn.Module, Tokenizer]:
    """
    Build the tokenizer and model, optionally restoring model weights from a checkpoint.

    :param tokenizer_path: path to the SentencePiece tokenizer model file
    :param initial_checkpoint: optional checkpoint file to restore weights from. Accepted
        contents: a pickled ``nn.Module`` (as written by ``torch.save(model, path)``), a dict
        with a ``'model'`` state dict (as written by ``save_checkpoint``), or a bare state dict
    :param use_xformers: build an ``XFormersTransformer``; otherwise no model is built and
        ``None`` is returned in its place
    :param model_args: overrides for ``ModelArgs`` fields; ``vocab_size`` is always taken
        from the tokenizer
    :return: (model, tokenizer)
    :raises ValueError: if ``initial_checkpoint`` is given but no model was built
    """
    start_time = time.time()
    print("Loading LLaMa model and tokenizer")
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_params = ModelArgs(**model_args)
    model_params.vocab_size = tokenizer.n_words
    model = XFormersTransformer(model_params) if use_xformers else None
    torch.set_default_tensor_type(torch.FloatTensor)
    if initial_checkpoint is not None:
        if model is None:
            raise ValueError("initial_checkpoint requires use_xformers=True (no model was built to load into)")
        # torch.load only returns the saved object; it must be applied to the model explicitly.
        checkpoint = torch.load(initial_checkpoint, map_location="cpu")
        if isinstance(checkpoint, nn.Module):
            state_dict = checkpoint.state_dict()
        elif isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    print(f"Loaded model and tokenizer in {time.time() - start_time:.2f} seconds")
    return model, tokenizer


# Data Processing (Util folder)

## Data.py

In [None]:
"""
Data processing and loading
"""

from typing import Tuple, Optional
import json
import os
import time
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader


# Default directory that load_pile_dataset joins the data file names onto.
default_data_path: str = "."

# Directory under which process_file caches tokenized datasets as .pt files.
artifacts_path: str = "artifacts/"


class PileDataset(Dataset):
    """
    Dataset for loading the Pile dataset.
    Loaded from an array of sequences, each of equal length.

    Wraps a tensor whose first dimension indexes sequences; item ``i`` is ``seqs[i]``.
    """
    def __init__(self, seqs: torch.Tensor):
        self.seqs = seqs

    def __getitem__(self, idx):
        # Raise IndexError past the end rather than returning None: Python's fallback
        # iteration protocol (``for x in dataset`` / ``list(dataset)``) only stops on
        # IndexError, so returning None made such iteration run forever.
        if idx >= len(self):
            raise IndexError(f"index {idx} is out of range for dataset of size {len(self)}")
        return self.seqs[idx]

    def __len__(self):
        return self.seqs.shape[0]


def _tokenize_line(line: str, tokenizer: "Tokenizer", max_seq_len: int, pad_id: int):
    """
    Tokenize one text line into a (num_seqs, max_seq_len) tensor of token ids.

    The line is encoded with a leading BOS token. If it fits in ``max_seq_len`` tokens it
    becomes a single row right-padded with ``pad_id``. Otherwise it is cut into consecutive,
    non-overlapping ``max_seq_len``-token chunks and the trailing partial chunk is dropped.
    """
    ids = torch.tensor(tokenizer.encode(line, bos=True, eos=False), dtype=torch.long)
    n_tokens = ids.numel()
    if n_tokens > max_seq_len:  # split into multiple sequences
        n_full = n_tokens // max_seq_len
        # Row-major view: row i holds tokens [i * max_seq_len, (i + 1) * max_seq_len).
        # (Viewing as (max_seq_len, n_full) and transposing would interleave tokens,
        # giving row i = ids[i::n_full] instead of a contiguous span.)
        return ids[: n_full * max_seq_len].view(n_full, max_seq_len).clone()
    tokens = torch.full((1, max_seq_len), pad_id, dtype=torch.long)
    tokens[0, :n_tokens] = ids
    return tokens


def process_file(
        tokenizer: "Tokenizer",
        data_file: str,
        max_seqs: int = 20000,
        max_seq_len: int = 2048
) -> torch.Tensor:
    """
    Process JSONL file into up to max_seqs seqs of tokens of length max_seq_len.
    Returns tensor of sequences, each of max_seq_len with padding of tokenizer eos id.

    Results are cached at ``os.path.join(artifacts_path, <data_file without extension>.pt)``
    (per os.path.join, an absolute data_file path ignores artifacts_path). A cached tensor is
    reused only when its sequence length equals ``max_seq_len``.

    :param tokenizer: tokenizer providing ``encode`` and ``eos_id`` (used as padding)
    :param data_file: path to a JSONL file whose records have a ``"text"`` field
    :param max_seqs: maximum number of sequences to return
    :param max_seq_len: length of every returned sequence
    :return: Tensor of dimension (up to max_seqs, max_seq_len)
    """

    # check if corresponding artifact exists.
    artifact_path = os.path.join(artifacts_path, f"{os.path.splitext(data_file)[0]}.pt")
    if os.path.isfile(artifact_path):
        print(f"Artifact tokens found. Loading tokenized dataset from {artifact_path}")
        cached = torch.load(artifact_path)
        # the cache file name does not encode max_seq_len, so validate before reuse
        if cached.dim() == 2 and cached.shape[1] == max_seq_len:
            return cached[:max_seqs]
        print(f"Cached artifact has shape {tuple(cached.shape)}, expected sequence length {max_seq_len}. Re-tokenizing.")
    else:
        # otherwise, parse file.
        print(f"No artifact found. Loading and tokenizing dataset from {data_file}.")

    # create artifacts dir, plus any sub-directory mirrored from data_file, if missing
    os.makedirs(artifacts_path, exist_ok=True)
    artifact_dir = os.path.dirname(artifact_path)
    if artifact_dir:
        os.makedirs(artifact_dir, exist_ok=True)

    pad_id = tokenizer.eos_id  # padding id
    chunks = []  # per-line (n, max_seq_len) tensors, concatenated once at the end
    n_seqs = 0

    # process data file into tokenized sequences padded to exactly max_seq_len
    with open(data_file, "r", encoding="utf-8") as file:
        with tqdm(total=max_seqs, desc="Dataset loading: ") as p_bar:
            for jsonline in file:
                if n_seqs >= max_seqs:
                    break
                if not jsonline.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                raw = json.loads(jsonline)
                tokens = _tokenize_line(raw["text"], tokenizer, max_seq_len, pad_id)
                chunks.append(tokens)
                n_seqs += tokens.shape[0]
                p_bar.update(tokens.shape[0])

    if chunks:
        # a long final line may overshoot max_seqs; clone so only the kept rows are saved
        seqs = torch.cat(chunks, dim=0)[:max_seqs].clone()
    else:
        seqs = torch.empty((0, max_seq_len), dtype=torch.long)

    # save artifact and return
    torch.save(seqs, artifact_path)
    return seqs


def load_pile_dataset(
        tokenizer: Tokenizer,
        train_file : str,
        val_file : str,
        test_file : str = "",
        num_train: int = 20000,
        num_val: int = 10000,
        num_test: int = 0,
        max_seq_len: int = 2048,
        data_path: str = default_data_path
) -> Tuple[PileDataset, PileDataset, Optional[PileDataset]]:
    """
    Load the Pile dataset into train, val and (optionally) test PileDatasets of tokens.

    Each split is tokenized (or loaded from cache) by ``process_file`` from
    ``os.path.join(data_path, <split file>)``.

    :param tokenizer: tokenizer used to encode the text
    :param train_file: training JSONL file name, joined onto data_path
    :param val_file: validation JSONL file name, joined onto data_path
    :param test_file: test JSONL file name, joined onto data_path; read only when num_test > 0
    :param num_train: maximum number of training sequences
    :param num_val: maximum number of validation sequences
    :param num_test: maximum number of test sequences; 0 skips the test split
    :param max_seq_len: length of every sequence
    :param data_path: directory the file names are joined onto
    :return: train, val, (optionally) test PileDatasets
    """
    print("Loading Pile dataset...")
    t0 = time.time()

    def _load(file_name, limit):
        # tokenize (or load cached) sequences for one split
        return process_file(tokenizer, os.path.join(data_path, file_name), limit, max_seq_len)

    train_toks = _load(train_file, num_train)
    val_toks = _load(val_file, num_val)
    test = PileDataset(_load(test_file, num_test)) if num_test > 0 else None
    train, val = PileDataset(train_toks), PileDataset(val_toks)

    print(f"Loaded dataset in {time.time() - t0:.2f} seconds")
    return train, val, test


def get_pile_dataloaders(train_set: "PileDataset", val_set: "PileDataset", test_set: "Optional[PileDataset]" = None, batch_size: int = 32):
    """
    Get dataloaders for train, val, and test datasets with given batch size.

    :param train_set: training dataset (shuffled each epoch)
    :param val_set: validation dataset (fixed order)
    :param test_set: optional test dataset (fixed order)
    :param batch_size: number of sequences per batch
    :return: (train_loader, val_loader, test_loader); test_loader is None when test_set is None
    """
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    # a DataLoader wrapping None is not a usable loader, so return None instead
    test_loader = None if test_set is None else DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader

## Metrics.py

In [None]:
"""
Compute metrics based on a model and dataset loader
"""

import torch
from torch import nn
from torch.utils.data import DataLoader

#from src.main.llama import Transformer


def accuracy(model: nn.Module, dataset_loader: DataLoader, device: torch.device,
             ignore_index: "Optional[int]" = None) -> float:
    """
    Calculate next-token prediction accuracy of model over a dataset loader.

    Mirrors the shift used by ``compute_loss``: the model sees ``tokens[:, :-1]`` and its
    arg-max prediction at each position is compared with ``tokens[:, 1:]``.

    :param model: language model called as ``model(tokens, is_causal=True)``, returning
        logits whose last dimension is the vocabulary
    :param dataset_loader: iterable of 2-D token-id tensors
    :param device: device each batch is moved to
    :param ignore_index: target id excluded from the count (e.g. the padding id);
        None counts every position
    :return: fraction of correctly predicted tokens; 0.0 if nothing was counted
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in dataset_loader:
                tokens = batch.to(device)
                logits = model(tokens[:, :-1], is_causal=True)
                preds = logits.argmax(dim=-1)
                targets = tokens[:, 1:]
                if ignore_index is None:
                    mask = torch.ones_like(targets, dtype=torch.bool)
                else:
                    mask = targets != ignore_index
                correct += int((preds[mask] == targets[mask]).sum().item())
                total += int(mask.sum().item())
    finally:
        model.train(was_training)  # restore the caller's train/eval mode
    return correct / total if total else 0.0


def compute_loss(model: XFormersTransformer, tokens: torch.Tensor, loss_fn : nn.CrossEntropyLoss) -> torch.Tensor:
    """
    Next-token loss for a batch of token sequences.

    The model runs on every token except the last, and the logits at each position
    are scored against the token that follows it.

    :param model: language model whose ``params.vocab_size`` equals the logits' last dimension
    :param tokens: 2-D tensor of token ids, (batch, seq_len)
    :param loss_fn: criterion applied to the flattened logits and targets
    :return: the tensor returned by ``loss_fn``
    """
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    logits = model.forward(inputs, is_causal=True)
    vocab = model.params.vocab_size
    # flatten positions into the batch axis as cross-entropy expects (N, C) and (N,)
    return loss_fn(logits.reshape(-1, vocab), targets.reshape(-1))


def get_number_of_parameters(model: nn.Module):
    """Return the total number of scalar values in ``model``'s trainable (requires_grad) parameters."""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


## Checkpoints.py

In [None]:
"""
Load and save checkpoints for model and optimizer
"""

import os

import torch
from torch import nn, optim


def save_checkpoint(optimizer: optim.Optimizer, model: nn.Module, checkpoint_path: str):
    """
    Write optimizer and model state dicts to ``checkpoint_path``.

    The file holds a dict with keys ``'optimizer'`` and ``'model'``. Nothing is written
    when ``checkpoint_path`` is None.

    :param optimizer: Optimizer used during training
    :param model: PyTorch model
    :param checkpoint_path: Checkpoint path and name
    """
    if checkpoint_path is None:
        return
    state = {
        'optimizer': optimizer.state_dict(),
        'model': model.state_dict(),
    }
    torch.save(state, checkpoint_path)


def load_checkpoint(optimizer: optim.Optimizer, model: nn.Module, checkpoint_path: str, map_location=None):
    """
    Load checkpoint into optimizer and model from checkpoint_path.

    Accepts the dict format written by ``save_checkpoint`` (keys ``'model'`` and
    ``'optimizer'``). It also accepts a whole pickled ``nn.Module`` (as written by
    ``torch.save(model, path)``), which carries no optimizer state. Nothing happens
    when ``checkpoint_path`` is None.

    :param optimizer: Optimizer used during training, or None to skip restoring it
    :param model: PyTorch model
    :param checkpoint_path: Checkpoint path and name
    :param map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to load a GPU
        checkpoint on a machine without CUDA); None keeps torch's default
    """
    if checkpoint_path is None:
        return
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, nn.Module):
        model.load_state_dict(checkpoint.state_dict())
        return  # a whole-module checkpoint has no optimizer state to restore
    model.load_state_dict(checkpoint['model'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])


def generate_checkpoint_name(checkpoints_base_path: str, epoch: int):
    """
    Build the checkpoint file path for the given epoch.

    :param checkpoints_base_path: Path to directory to store checkpoints in
    :param epoch: Epoch number embedded in the file name (``chkpt-<epoch>.pt``)
    :return: Checkpoint path and name
    """
    file_name = f"chkpt-{epoch}.pt"
    return os.path.join(checkpoints_base_path, file_name)


## Example Batches

In [None]:
# OPTIONAL CODE - NOT REQUIRED FOR TRAINING
# Show example batches from Pile dataset (paths are relative to the working directory)
tokenizer_path = "tokenizer.model"
train_path = "tiny_train.jsonl"
val_path = "tiny_val.jsonl"
for _required_path, _missing_msg in (
    (tokenizer_path, "LLaMa tokenizer pretrained model file required"),
    (train_path, "Train data subset in JSONL format required"),
    (val_path, "Validation data subset in JSONL format required"),
):
    assert os.path.isfile(_required_path), _missing_msg
batch_size = 4
max_seq_len = 512

torch.cuda.empty_cache()
# the model is built too, although only the tokenizer is used below
model, tokenizer = load_llama(
    tokenizer_path=tokenizer_path,
    initial_checkpoint=None,
    use_xformers=True,
    max_seq_len=max_seq_len,
)

train_set, val_set, _ = load_pile_dataset(
    tokenizer,
    train_path,
    val_path,
    num_train=20000,
    num_val=10000,
    max_seq_len=max_seq_len,
)
train_dataloader, val_dataloader, _ = get_pile_dataloaders(
    train_set,
    val_set,
    batch_size=batch_size,
)

# print each label, then evaluate and print the corresponding sample
for _label, _sample in (
    ("Example data point", lambda: train_set[0]),
    ("Example train batch", lambda: next(iter(train_dataloader))),
    ("Example val batch", lambda: next(iter(val_dataloader))),
):
    print(_label)
    print(_sample())

# Main

In [None]:
# Mount Google Drive
# Ensure that the folder LLaMaTrain exists in My Drive and contains tokenizer.model,
# tiny_train.jsonl, and tiny_val.jsonl before running the training code below.
from google.colab import drive
import locale
# Force "UTF-8" as the preferred encoding (presumably to work around the Colab locale
# report). The replacement keeps the optional ``do_setlocale`` parameter of
# locale.getpreferredencoding so callers passing it, e.g. getpreferredencoding(False),
# do not fail with TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"
drive.mount('/gdrive')

# Root folder (with trailing slash) that the training file names are appended to.
persistent_storage_path = "/gdrive/MyDrive/LLaMaTrain/"

Mounted at /gdrive


## Training

In [None]:
from typing import Tuple, Optional
import torch
import time
import gc
import os
from tqdm import tqdm

from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# from src.main.llama import XFormersTransformer, Tokenizer, load_llama
# from src.main.util import get_pile_dataloaders, load_pile_dataset
# from src.main.util import compute_loss
# from src.main.util import save_checkpoint, load_checkpoint, generate_checkpoint_name

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")  # first GPU if present


def train(
        model: XFormersTransformer,
        tokenizer: Tokenizer,
        train_loader: DataLoader,
        val_loader: DataLoader,
        epochs: int,
        lr: float,
        weight_decay: float,
        grad_clip: float = 1.0,
        checkpoints_dir: str = None
):
    """
    Train ``model`` with AdamW, reporting validation loss and checkpointing after each epoch.

    :param model: language model to train (moved to the module-level ``device``)
    :param tokenizer: tokenizer whose ``eos_id`` is the padding id ignored by the loss
    :param train_loader: loader of 2-D token tensors used for optimisation
    :param val_loader: loader of 2-D token tensors used for the validation loss
    :param epochs: number of passes over ``train_loader``
    :param lr: AdamW learning rate
    :param weight_decay: AdamW weight decay
    :param grad_clip: maximum global gradient norm
    :param checkpoints_dir: directory for per-epoch checkpoints (written with
        ``save_checkpoint``); None disables checkpointing
    """
    model.to(device)
    loss_fn = CrossEntropyLoss(ignore_index=tokenizer.eos_id)  # ignore padding
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    #lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs*len(train_loader))

    for epoch in range(epochs):
        # switch back to training mode every epoch: validation below puts the model in eval mode
        model.train()
        train_loss = 0.0
        try:
            for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
                # compute loss between predictions (logits) and tokens;
                # each prediction corresponds to the next token (shift handled in compute_loss)
                tokens = batch.to(device)
                optimizer.zero_grad()
                loss = compute_loss(model, tokens, loss_fn)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                #lr_scheduler.step()  # update lr
                train_loss += loss.item()
        finally:
            # release cached memory before validation
            gc.collect()
            torch.cuda.empty_cache()

        model.eval()
        val_loss = 0.0
        try:
            with torch.no_grad():
                for batch in val_loader:
                    tokens = batch.to(device)
                    val_loss += compute_loss(model, tokens, loss_fn).item()
        finally:
            gc.collect()
            torch.cuda.empty_cache()

        # Print summary (guard against empty loaders)
        train_loss /= max(len(train_loader), 1)
        val_loss /= max(len(val_loader), 1)
        print(f"Epoch {epoch+1}. Train loss: {train_loss}. Val loss: {val_loss}")

        # Save checkpoint in the same format as the Checkpoints.py helpers
        if checkpoints_dir is not None:
            chkpt_path = generate_checkpoint_name(checkpoints_dir, epoch+1)
            save_checkpoint(optimizer, model, chkpt_path)


def main():
    """Train the reduced-scale LLaMa on the tiny Pile subsets stored under persistent_storage_path."""
    # File names, relative to persistent_storage_path
    tokenizer_path = "tokenizer.model"
    train_path = "tiny_train.jsonl"
    val_path = "tiny_val.jsonl"
    checkpoint_base_path = persistent_storage_path + "checkpoints/"
    checkpoint_run_name = "dim-512-heads-8-layers-8"  # relative to base path
    load_checkpoint_path = None  # e.g. "run2/chkpt-3.pt", relative to base path
    for rel_path, missing_msg in (
        (tokenizer_path, "LLaMa tokenizer pretrained model file required"),
        (train_path, "Train data subset in JSONL format required"),
        (val_path, "Validation data subset in JSONL format required"),
    ):
        assert os.path.isfile(persistent_storage_path + rel_path), missing_msg

    # Optimisation hyperparameters
    epochs = 15
    batch_size = 16
    lr = 3e-4
    weight_decay = 0.01
    # Model shape (forwarded to ModelArgs through load_llama)
    max_seq_len = 512
    model_shape = dict(max_seq_len=max_seq_len, dim=512, n_layers=8, n_heads=8)

    torch.cuda.empty_cache()
    if load_checkpoint_path is None:
        initial_checkpoint = None
    else:
        initial_checkpoint = os.path.join(checkpoint_base_path, load_checkpoint_path)
    model, tokenizer = load_llama(
        tokenizer_path=persistent_storage_path + tokenizer_path,
        initial_checkpoint=initial_checkpoint,
        use_xformers=True,
        **model_shape,
    )

    train_set, val_set, _ = load_pile_dataset(
        tokenizer,
        persistent_storage_path + train_path,
        persistent_storage_path + val_path,
        num_train=20000,
        num_val=10000,
        max_seq_len=max_seq_len,
    )
    train_dataloader, val_dataloader, _ = get_pile_dataloaders(train_set, val_set, batch_size=batch_size)

    checkpoints_dir = os.path.join(checkpoint_base_path, checkpoint_run_name)
    os.makedirs(checkpoints_dir, exist_ok=True)

    try:
        train(
            model,
            tokenizer,
            train_dataloader,
            val_dataloader,
            lr=lr,
            epochs=epochs,
            weight_decay=weight_decay,
            checkpoints_dir=checkpoints_dir,
        )
    finally:
        # move the model off the GPU so its memory can be reclaimed
        model.cpu()


# Script entry point: run training, then always run the Python garbage collector and
# release cached CUDA memory, even if main() raised.
if __name__ == "__main__":
    try:
        main()
    finally:
        # Cleanup before exiting
        gc.collect()
        torch.cuda.empty_cache()

## Generation

In [None]:
# TODO