# LongLLaMA: Focused Transformer Training for Context Scaling
**LongLLaMA is a large language model capable of handling long contexts of 256k tokens or even more**.

It is built upon the foundation of [OpenLLaMA](https://github.com/openlm-research/open_llama) and fine-tuned using the [Focused Transformer (FoT)](https://arxiv.org/abs/2307.03170) method.

This notebook is a demo of [LongLLaMA-Instruct-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_instruct), a [LongLLaMA-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_v1_1) model fine-tuned on the [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [ShareGPT-Processed](https://huggingface.co/datasets/zetavg/ShareGPT-Processed) datasets. Note that LongLLaMA-Instruct 3B is licensed differently due to the use of responses from GPT-4/GPT-3.5.
As with [LongLLaMA 3B](https://huggingface.co/syzymon/long_llama_3b), the model weights can serve as a drop-in replacement for LLaMA in existing implementations (for short contexts of up to 2048 tokens).

This notebook is a research preview of [LongLLaMA-Instruct-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_instruct).
For more, see the [FoT paper](https://arxiv.org/abs/2307.03170) and [GitHub repository](https://github.com/CStanKonrad/long_llama).

# Initial steps

In [None]:
# Upgrade pip, then install the pinned transformers release together with
# sentencepiece and accelerate (-q keeps pip's output quiet).
!pip install --upgrade pip
!pip install transformers==4.30.0  sentencepiece accelerate -q

In [None]:
import numpy as np
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM, TextStreamer, PreTrainedModel, PreTrainedTokenizer
from typing import List

In [None]:
# Hub identifier of the instruction-tuned checkpoint; the tokenizer is
# loaded from the same location.
MODEL_PATH = "syzymon/long_llama_3b_instruct"
TOKENIZER_PATH = MODEL_PATH

# Reduced precision so that the model fits into a Colab GPU.
TORCH_DTYPE = torch.bfloat16

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the tokenizer from TOKENIZER_PATH.
tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)

# Load the model weights in TORCH_DTYPE and place them on `device`.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_DTYPE,
    device_map=device,
    # NOTE(review): trust_remote_code=True allows code shipped in the model
    # repository to run locally — only use it with repositories you trust.
    trust_remote_code=True,
    # mem_attention_grouping is used
    # to trade speed for memory usage
    # for details, see the section Additional configuration
    mem_attention_grouping=(1, 2048),
)
# Put the model into evaluation mode; this notebook only runs inference.
model.eval()

# The demo

## Summarization and question answering
We used the [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) dataset to instruction tune the model.
Here we show the ability of the model to answer questions about long documents.
Note that as the model used in the demo has only 3B parameters, it can have trouble understanding complex documents.

In [None]:
import urllib.request
import tempfile
import shutil
import os


def get_paper(url: str, main_file: str) -> str:
    """Download an archive and return the text of one file stored inside it.

    The archive is fetched into a temporary directory, unpacked there, and the
    directory is removed again before the function returns.

    Args:
        url: Location of the archive. It is saved under a ``.tar.gz`` name, so
            it is unpacked as a gzip-compressed tarball.
        main_file: Path of the file to read, relative to the archive root.

    Returns:
        The content of ``main_file`` decoded as UTF-8.

    Raises:
        FileNotFoundError: If ``main_file`` does not exist after unpacking.
        UnicodeDecodeError: If ``main_file`` is not valid UTF-8.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The ".tar.gz" suffix is what lets shutil pick the archive format.
        archive_path = os.path.join(tmp_dir, "_archive.tar.gz")
        urllib.request.urlretrieve(url, archive_path)

        # NOTE(review): the archive comes from the network; unpacking is only
        # safe when the source is trusted.
        shutil.unpack_archive(archive_path, tmp_dir)

        # Decode explicitly so the result does not depend on the platform's
        # default text encoding.
        with open(os.path.join(tmp_dir, main_file), "r", encoding="utf-8") as f:
            return f.read()


@torch.no_grad()
def load_to_memory(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, text: str):
    """Run ``text`` through the model once and return the resulting cache.

    Args:
        model: Model that is called on the token ids; its ``device`` attribute
            decides where the ids are moved to.
        tokenizer: Tokenizer used to turn ``text`` into token ids.
        text: Document to encode.

    Returns:
        The ``past_key_values`` of the forward pass, which can be passed back
        to the model as ``past_key_values`` in later calls.
    """
    encoded = tokenizer(text, return_tensors="pt")
    token_ids = encoded.input_ids.to(model.device)
    # Reset the global torch RNG to a fixed seed before the forward pass.
    torch.manual_seed(0)
    return model(input_ids=token_ids).past_key_values


@torch.no_grad()
def generate_with_memory(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    memory,
    prompt: str,
    temperature=0.6,
    max_new_tokens=None,
):
    """Sample a continuation of ``prompt`` on top of ``memory`` and stream it.

    Tokens are drawn one at a time from the temperature-scaled distribution
    over the last position and handed to a ``TextStreamer`` as they appear.
    Generation ends when the end-of-sequence token is sampled (it is still
    passed to the streamer) or when ``max_new_tokens`` tokens were produced.

    Args:
        model: Model called with the token ids and the running cache.
        tokenizer: Tokenizer used for the prompt, the streamer, and the
            end-of-sequence id.
        memory: Value passed as ``past_key_values`` on the first model call.
        prompt: Text that is tokenized and fed to the model first.
        temperature: Divisor applied to the logits before sampling.
        max_new_tokens: Optional upper bound on the number of sampled tokens.
            ``None`` (the default) keeps sampling until end-of-sequence.

    Returns:
        None. The generated text is only written through the streamer.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    streamer = TextStreamer(tokenizer, skip_prompt=False)

    cache = memory
    num_generated = 0
    while max_new_tokens is None or num_generated < max_new_tokens:
        output = model(input_ids, past_key_values=cache)
        cache = output.past_key_values
        # Sanity checks: logits must be rank 3 with a batch of exactly one.
        assert len(output.logits.shape) == 3
        assert output.logits.shape[0] == 1
        # Only the logits of the last position decide the next token.
        last_logit = output.logits[[0], [-1], :]
        dist = torch.distributions.Categorical(logits=last_logit / temperature)
        next_token = dist.sample()
        num_generated += 1

        # The sampled token is streamed and becomes the next model input.
        input_ids = next_token[None, :]
        streamer.put(input_ids)
        if next_token[0] == tokenizer.eos_token_id:
            break

    streamer.end()


# System-style preamble placed in front of the document that is loaded into memory.
# NOTE(review): "User will you give you a task" reads like a typo, but it may
# deliberately mirror the system prompt seen during instruction tuning —
# confirm against the fine-tuning data before changing the wording.
PROMPT_PREFIX = "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can.\n"


def construct_question_prompt(question: str) -> str:
    """Wrap ``question`` in the question-answering prompt template.

    The template refers to "the text above", i.e. it is meant to follow a
    document that the model has already processed.

    Args:
        question: The question to ask.

    Returns:
        The prompt text, ending with ``"Answer: "`` so that the model
        continues with the answer.
    """
    return (
        "\nAnswer the following question briefly using information from the text above."
        f"\nQuestion: {question}\nAnswer: "
    )


def ask_model(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, memory, seed=0):
    """Answer ``prompt`` on top of ``memory`` with a fixed random seed.

    Args:
        model: Model forwarded to ``generate_with_memory``.
        tokenizer: Tokenizer forwarded to ``generate_with_memory``.
        prompt: Prompt text; it is tokenized by ``generate_with_memory``.
        memory: Cache forwarded to ``generate_with_memory``.
        seed: Value used to seed the global torch RNG before sampling.

    Returns:
        None. The answer is streamed by ``generate_with_memory``.
    """
    # The prompt is tokenized inside generate_with_memory, so tokenizing it
    # here as well would only be wasted work.
    torch.manual_seed(seed)
    generate_with_memory(model, tokenizer, memory, prompt)

In [None]:
# Download the arXiv e-print archive of paper 2307.03170 and read its main .tex file.
fot_paper = get_paper(url="https://arxiv.org/e-print/2307.03170", main_file="neurips_2023.tex")
# Run the prompt prefix followed by the whole paper through the model once;
# the resulting cache is what the questions below are asked against.
fot_memory = load_to_memory(model, tokenizer, PROMPT_PREFIX + fot_paper)

In [None]:
# Ask for a short summary of the paper held in fot_memory.
prompt = construct_question_prompt("What is the paper above about? Summarize it briefly.")
ask_model(model, tokenizer, prompt, fot_memory)

In [None]:
# Ask which method the paper held in fot_memory introduces.
prompt = construct_question_prompt("What method is introduced in the paper?")
ask_model(model, tokenizer, prompt, fot_memory)

In [None]:
# Ask for a specific detail: the name the authors give to the 3B model.
prompt = construct_question_prompt("How is the 3B model called by the authors?")
ask_model(model, tokenizer, prompt, fot_memory)

In [None]:
# Ask for an author name from the paper held in fot_memory.
prompt = construct_question_prompt("Name at least one author of the presented paper.")
ask_model(model, tokenizer, prompt, fot_memory)

In [None]:
# Ask about a concept the paper discusses: the distraction issue.
prompt = construct_question_prompt("What is the distraction issue?")
ask_model(model, tokenizer, prompt, fot_memory)

## Chat
We have also used the [ShareGPT-Processed](https://huggingface.co/datasets/zetavg/ShareGPT-Processed) dataset to enhance the model's conversational abilities. The chat prompt was inspired by [LongChat](https://github.com/DachengLi1/LongChat).

In [None]:
class ChatOutputBuffer:
    """
    Streams generated tokens to the console while watching for stop strings.

    Token ids are collected in ``output_buffer``. The most recent
    ``max_stop_seq`` tokens are held back instead of being printed, so that a
    stop string (one of ``stop_text``) can be cut from the output once it has
    been generated in full.
    """

    def __init__(self, stop_text: List[str], tokenizer: PreTrainedTokenizer):
        """
        Args:
            stop_text: Strings that end the generation when the decoded output
                ends with one of them.
            tokenizer: Used to decode buffered token ids and by the streamer.
        """
        self.tokenizer = tokenizer
        self.streamer = TextStreamer(tokenizer, skip_prompt=False)
        # Length of the longest stop string, measured in characters.
        # NOTE(review): this value is used below as a number of tokens —
        # presumably relying on each token decoding to at least one character.
        self.max_stop_seq = 0
        self.stop_seq = []
        for st in stop_text:
            self.stop_seq.append(st)
            self.max_stop_seq = max(self.max_stop_seq, len(st))

        # Token ids that were generated but not yet handed to the streamer.
        self.output_buffer = np.empty((0,), dtype=np.int64)

    def reset_output_buffer(self) -> None:
        """Drop all buffered token ids without printing them."""
        self.output_buffer = np.empty((0,), dtype=np.int64)

    def advance_output(self) -> None:
        """Print every buffered token except the last ``max_stop_seq`` ones."""
        beg = 0
        # The tail of the buffer is held back: it may still turn out to be
        # the beginning of a stop string.
        end = len(self.output_buffer) - self.max_stop_seq

        if end > beg:
            output = self.output_buffer[beg:end]
            self.streamer.put(output)
            self.output_buffer = self.output_buffer[end:]

    def flush_buffer(self) -> None:
        """Print whatever is left in the buffer and finish the stream."""
        if len(self.output_buffer) > 0:
            self.streamer.put(self.output_buffer)
            # Leaves an empty buffer behind.
            self.output_buffer = self.output_buffer[len(self.output_buffer) :]
        self.streamer.end()

    def generation_too_long(self, text: str) -> int:
        """Return the number of stop strings that ``text`` ends with."""
        end_requests = 0
        for st in self.stop_seq:
            if text.endswith(st):
                end_requests += 1
        return end_requests

    def update_buffer(self, next_tok: int) -> bool:
        """
        Append one generated token and report whether generation should stop.

        Args:
            next_tok: Id of the newly generated token (must be a Python int).

        Returns:
            True if the decoded tail of the buffer ends with a stop string; in
            that case the tokens forming the stop string are dropped, the rest
            is printed and the stream is finished. False otherwise.
        """
        assert isinstance(next_tok, int)

        array_next_tok = np.array([next_tok], dtype=np.int64)
        self.output_buffer = np.concatenate([self.output_buffer, array_next_tok], axis=0)

        # Only the most recent tokens can contain a freshly completed stop string.
        suffix = self.output_buffer[-self.max_stop_seq :]
        decoded = self.tokenizer.decode(suffix)
        end_requests = self.generation_too_long(decoded)
        if end_requests > 0:
            # Shrink the suffix from the left for as long as the remainder
            # still ends with the same number of stop strings; what is left is
            # the shortest run of trailing tokens that still matches.
            decoded = self.tokenizer.decode(suffix[1:])
            while self.generation_too_long(decoded) == end_requests:
                suffix = suffix[1:]
                decoded = self.tokenizer.decode(suffix[1:])

            # Everything before that run is regular output and is kept.
            left_intact = len(self.output_buffer) - len(suffix)

            self.output_buffer = self.output_buffer[:left_intact]
            self.flush_buffer()
            return True

        self.advance_output()
        return False


class SimpleChatBot:
    """Minimal multi-turn chat loop on top of a causal language model.

    The conversation state lives in ``past_key_values``: the first call to
    ``ask`` sends the system prompt, later calls only send the new user turn
    and reuse the stored cache.
    """

    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
        """
        Args:
            model: Model used for generation.
            tokenizer: Tokenizer used for encoding prompts and decoding output.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.prompt = "A chat between a user (denoted as USER:) and an artificial intelligence assistant (denoted as ASSISTANT:). The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
        self.model_name = "\nASSISTANT: "
        self.user_name = "\nUSER: "
        # Cache of the conversation so far; None until the first turn is stored.
        self.past_key_values = None

        # Sampling temperature.
        self.t = 0.8
        # Generation stops when the model starts a new speaker tag or emits EOS.
        self.output_buffer = ChatOutputBuffer(
            [self.model_name.strip(), self.user_name.strip(), self.tokenizer.eos_token], self.tokenizer
        )

    @torch.no_grad()
    def ask(self, text: str):
        """Send one user message and stream the sampled assistant reply.

        Args:
            text: The user's message.

        Returns:
            None. The reply is printed through the output buffer's streamer.
        """
        first_turn = self.past_key_values is None
        base_prompt = self.prompt if first_turn else ""
        prompt = base_prompt + self.user_name + text + self.model_name
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
        if first_turn:
            # Special tokens are disabled above, so the BOS token is added
            # by hand, and only at the very start of the conversation.
            bos = torch.tensor([self.tokenizer.bos_token_id], dtype=torch.long).reshape(1, 1)
            input_ids = torch.cat([bos, input_ids], dim=-1)

        self.output_buffer.reset_output_buffer()
        # Print the assistant tag before the sampled reply.
        output_text = self.model_name
        output_ids = self.tokenizer.encode(output_text, return_tensors="pt", add_special_tokens=first_turn)
        self.output_buffer.streamer.put(output_ids)

        is_writing = True

        while is_writing:
            # Use the bot's own model rather than a notebook-level global.
            input_ids = input_ids.to(self.model.device)
            output = self.model(input_ids=input_ids, past_key_values=self.past_key_values)

            logits = output.logits
            # Sanity checks: logits must be rank 3 with a batch of exactly one.
            assert len(logits.shape) == 3
            assert logits.shape[0] == 1
            # Only the logits of the last position decide the next token.
            last_logit = logits[[0], [-1], :]

            dist = torch.distributions.Categorical(logits=last_logit / self.t)
            next_token = dist.sample()
            # Note that parts of cut out text may remain in model memory
            # this is implemented in this way for performance reasons
            past_key_values = output.past_key_values
            assert len(next_token.shape) == 1
            should_stop = self.output_buffer.update_buffer(next_token[0].cpu().item())
            if should_stop:
                # NOTE(review): the cache of this final step is not stored, so
                # a stop on the very first sampled token leaves
                # past_key_values unchanged — confirm this is intended.
                is_writing = False
            else:
                input_ids = next_token[None, :]
                self.past_key_values = past_key_values

Feel free to try the chat yourself:

In [None]:
# Interactive chat: each line typed by the user is sent to the bot and the
# reply is streamed back. The loop has no exit condition and runs until the
# cell is interrupted.
chatbot = SimpleChatBot(model=model, tokenizer=tokenizer)
while True:
    user_text = input("USER: ")
    chatbot.ask(user_text)