In [1]:
import gc
import os
from dataclasses import dataclass, field
from typing import Optional

import huggingface_hub
import torch
from accelerate import Accelerator
from datasets import load_dataset
from huggingface_hub import login
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm
from transformers import (
    Adafactor,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    pipeline,
)
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler

from redditqa.dataset import load_reddit_dataset

  from .autonotebook import tqdm as notebook_tqdm
2023-08-15 13:13:45.533187: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
model_name = "meta-llama/Llama-2-7b-chat-hf"

In [3]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    #torch_dtype=torch.bfloat16,
    load_in_8bit=True,
    device_map={"": 0},
    peft_config=lora_config,
)

#model = get_peft_model(model, lora_config)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [5]:
model = model.cuda(0)

In [6]:
question = "Question: What is 1+1?\nAnswer: "

question_tokenized = tokenizer(question, return_tensors='pt')

input_ids = question_tokenized['input_ids'].cuda(0)

input_ids

tensor([[    1,   894, 29901,  1724,   338, 29871, 29896, 29974, 29896, 29973,
            13, 22550, 29901, 29871]], device='cuda:0')

In [7]:
result = model.generate(input_ids=input_ids, max_length=128)
result



tensor([[    1,   894, 29901,  1724,   338, 29871, 29896, 29974, 29896, 29973,
            13, 22550, 29901, 29871,  1133, 19933, 29896, 12372,  9728, 31254,
         28675, 15842, 14181, 27611,  5790,  2035, 14170, 10312, 14709,  6198,
          1943, 22349, 10612, 13375, 28758, 21576,   605,  9892, 23929, 18500,
         18415,  2425,  5521, 27079, 10221, 18964,  8765, 14248,  8395, 31582,
           616,  2492,  2963,  8854, 11059,  2608, 15880,  1265,  2207, 11054,
         19933,  8732, 27917,  1129,  1656, 20809, 12613, 10474,   374,   983,
          4927, 27536, 18817,  2963, 23332,  3708, 22499,  1264, 15232,   434,
          4125, 29582, 27456, 29664, 22453, 24573,  5848,  4680, 18956,  1675,
           948,  7172, 15831,  4101, 11566,   147,  2806, 11000, 10363, 12184,
          4471,  1929,  4471,  9940,  1619,  7654, 27537, 22818, 12443,   189,
          7651, 28100,  1377, 25946,  8467, 15991, 27151,  5166, 16447,  7820,
         28716, 13816,  6567,  9265,  1617, 21480, 2

In [8]:
tokenizer.batch_decode(result)

['<s> Question: What is 1+1?\nAnswer: cedimen1 Hamilton romanば Touch FORhell Prefakerietumenagr einges relativeadealufeldouxbahGMlass banragma unwIdentity loopbal Bentrachlingsфеelteenda민ialimgolaaroburylev Guylaceoraindaimenовано Ritterpoythonnaleedaretriicarij Gonz Ezola dece pur címurrent arcueomaignonershellusta Zarassarumumer Graybjectyn familjenpoleatz Dun�uth pairsRCendorubyrafubyaki My beskrerouestoregh� släktetBehaviorrite baron Gl Baron Jules Handpiciloalusrile handsurbron Havignoniz']