In [8]:
from typing import Tuple
import os
import sys
import torch
import time
import json
import fire

from pathlib import Path

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

sys.path.append("/coc/pskynet6/dhe83/mice/src")
import config
from utils import *

sys.path.append(config.llama)
from llama import ModelArgs, Transformer, Tokenizer, LLaMA


def setup_model_parallel() -> Tuple[int, int]:
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def load(
    ckpt_dir: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    tokenizer_path = os.path.join(config.llama,"checkpoints",  "tokenizer.model")

    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator




In [10]:
ckpt_dir = os.path.join(config.llama, "checkpoints", "7B")

local_rank, world_size = setup_model_parallel()
if local_rank > 0:
    sys.stdout = open(os.devnull, "w")
model = load(ckpt_dir, local_rank, world_size, max_seq_len, max_batch_size)

ValueError: Error initializing torch.distributed using env:// rendezvous: environment variable RANK expected, but not set

In [None]:
results = model.generate(
    [premise], max_gen_len=256, temperature=temperature, top_p=top_p
)

for result in results:
    print(result)
    print("\n==================================\n")