In [1]:
import logging
import time
from typing import Optional, List, Dict

import os
import sys
import json
import gc 

import torch
from torch.utils.data import Dataset
import torch.distributed._shard.checkpoint as dist_cp
from torch.distributed.checkpoint import FileSystemReader
from tqdm import tqdm

from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig

# Send every log record to stdout (so it appears inline in the cell output)
# with a pipe-separated format. Verbosity comes from the LOGLEVEL environment
# variable and falls back to INFO.
logging.basicConfig(
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_model(model_name, sharded_model_path=None):
    """Build a ``LlamaForCausalLM`` for inference.

    Args:
        model_name: Identifier passed to ``from_pretrained``. It supplies the
            weights when ``sharded_model_path`` is None. Otherwise it supplies
            only the ``LlamaConfig``.
        sharded_model_path: Optional directory of a ``torch.distributed``
            checkpoint. When given, a freshly constructed model is populated
            from it. The model is then moved to CUDA if available, else CPU.

    Returns:
        The loaded model. ``eval()`` is NOT called here; the caller does that.
    """
    if sharded_model_path is not None:
        logging.info(f"Loading sharded model from {sharded_model_path} with {model_name} config")
        config = LlamaConfig.from_pretrained(model_name, return_dict=True)
        model = LlamaForCausalLM(config=config)

        # The checkpoint is read into a dict keyed by "model" (matching how it
        # was saved), then copied into the module's parameters.
        checkpoint = {"model": model.state_dict()}
        dist_cp.load_state_dict(
            state_dict=checkpoint,
            storage_reader=FileSystemReader(sharded_model_path),
            no_dist=True,
        )
        model.load_state_dict(checkpoint["model"])

        target = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(torch.device(target))
    else:
        logging.info(f"Loading model {model_name}")
        model = LlamaForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            return_dict=True,
            low_cpu_mem_usage=True,
        )

    logging.info(f"model device before prediction: {next(model.parameters()).device}")
    return model

def write_json(instructions: List[Dict[str, str]], file_path: str):
    """Serialise ``instructions`` to ``file_path`` as pretty-printed JSON.

    The file is opened for writing (any existing content is overwritten) with
    UTF-8 encoding and a 4-space indent. Values the ``json`` module cannot
    encode natively are converted with ``str``.
    """
    dump_options = {"indent": 4, "default": str}
    with open(file_path, mode="w", encoding="utf-8") as handle:
        json.dump(instructions, handle, **dump_options)

In [4]:
# Drop any model left over from a previous run of this cell *before* collecting
# garbage. While the global `model` still references it, neither gc.collect()
# nor torch.cuda.empty_cache() can release its memory. Re-running the cell
# would then hold two copies of the weights while the new one loads.
model = None
gc.collect()
torch.cuda.empty_cache()

seed = 42
model_name = "meta-llama/Llama-2-7b-hf"
# Sharded fine-tuned checkpoint directory. Set SHARDED_MODEL_PATH to point at a
# different checkpoint without editing the notebook.
sharded_model_path = os.environ.get(
    "SHARDED_MODEL_PATH",
    "/gpfs/space/projects/nlpgroup/llms/checkpoints/full-finetune-llama-7b-translation-alpaca-alpacaest-bs16/checkpoint-8605/pytorch_model_0",
)

torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

model = load_model(model_name, sharded_model_path)
model.eval()

# Left padding makes every prompt in a batch end at the same column, so the
# generated tokens follow the prompt directly.
# NOTE(review): pad id 0 is reused because the base tokenizer presumably
# defines no pad token; confirm id 0 is the intended (unk) token.
tokenizer = LlamaTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token_id = 0

2023-10-09 18:33:06 | INFO | root | Loading sharded model from /gpfs/space/projects/nlpgroup/llms/checkpoints/full-finetune-llama-7b-translation-alpaca-alpacaest-bs16/checkpoint-8605/pytorch_model_0 with meta-llama/Llama-2-7b-hf config
2023-10-09 18:36:20 | INFO | root | model device before prediction: cuda:0


Loading the tokenizer from the `special_tokens_map.json` and the `added_tokens.json` will be removed in `transformers 5`,  it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again. You will see the new `added_tokens_decoder` attribute that will store the relevant information.


In [56]:
# Alpaca-style prompt templates; "{instruction}" and "{input}" are filled from
# each example dict via str.format_map.
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}


class ValidationDataset(Dataset):
    """Map-style dataset that renders instruction records as prompts.

    Each element of ``data`` must be a mapping with an ``"instruction"`` key.
    It may also have ``"input"`` (extra context) and ``"output"`` (reference
    answer) keys.

    ``__getitem__`` returns ``{"prompts": <prompt str>}``. When the record has
    an ``"output"`` key, a ``"labels"`` entry holding that value is added.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]

        # A missing, empty or None input all select the no-input template.
        # Comparing against "" alone would render input=None as the literal
        # text "None" under "### Input:".
        template_key = "prompt_input" if item.get("input") else "prompt_no_input"
        prompt = PROMPT_DICT[template_key].format_map(item)

        if "output" in item:
            return {"prompts": prompt, "labels": item["output"]}
        return {"prompts": prompt}

# A single hand-written Estonian multiple-choice example used as a smoke test.
val_data = ValidationDataset([
    {
        "instruction": "Tüdruk tahtis oma matemaatikaõpetajat tänada.\nMis oli selle tagajärg?\nValik 1: Tüdruk jäeti peale tunde.\nValik 2: Tüdruk tõi õpetajale õuna.\n\nVasta õige numbriga.",
        "input": "",
    }
])
# Batches contain only Python strings; tokenisation happens later, in predict().
# That makes pinned memory a no-op and a worker process pure start-up overhead.
# Loading in the main process also avoids unpickling failures for this
# notebook-defined Dataset under the "spawn" multiprocessing start method
# (Windows, and macOS by default).
val_dataloader = torch.utils.data.DataLoader(
    val_data,
    batch_size=1,
    num_workers=0,
    pin_memory=False,
    drop_last=False,
)

In [74]:
def predict(val_dataloader, model=None, tokenizer=None, max_new_tokens=100):
    """Greedily generate a completion for every prompt in ``val_dataloader``.

    Args:
        val_dataloader: Iterable of batches. Each batch is a dict whose
            ``"prompts"`` entry is a list of prompt strings, e.g. from a
            ``DataLoader`` over ``ValidationDataset``.
        model: Causal LM exposing ``generate`` and ``device``. Defaults to the
            notebook-global ``model``.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
            Defaults to the notebook-global ``tokenizer``.
        max_new_tokens: Maximum number of tokens generated per prompt.

    Yields:
        One decoded continuation per prompt, in dataloader order. Prompt tokens
        are stripped and special tokens skipped.
    """
    if model is None:
        model = globals()["model"]
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]

    # Put inputs on the model's own device rather than hard-coding "cuda", so
    # CPU-only runs also work.
    device = model.device

    for data_batch in tqdm(val_dataloader):
        # Pad to the longest prompt in the batch and never truncate; prompts
        # are assumed to have a reasonable length.
        batch = tokenizer(data_batch["prompts"], padding=True, truncation=False, return_tensors="pt")
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model.generate(
                **batch,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                top_p=1.0,
                temperature=1.0,
                min_length=None,
                use_cache=True,
                top_k=50,
                repetition_penalty=1.0,
                length_penalty=1,
                pad_token_id=tokenizer.pad_token_id,
                num_beams=1,
            )

        # Each output row starts with the (padded) prompt ids, so dropping the
        # first `prompt_len` columns leaves only the generated tokens. Rows
        # are decoded one by one so each prediction can be handled separately.
        prompt_len = batch["input_ids"].shape[1]
        for output in outputs:
            yield tokenizer.decode(output[prompt_len:], skip_special_tokens=True)

# Stream predictions from the generator and print each one as it is yielded.
# Note: `prediction` remains bound at notebook level after the loop finishes.
for prediction in predict(val_dataloader):
    print(prediction)

100%|██████████| 1/1 [00:01<00:00,  1.55s/it]

Valik 2: Tüdruk tõi õpetajale õuna.



