In [1]:
!pip install transformers==4.29.2
!pip install bitsandbytes==0.39.0
!pip install accelerate==0.19.0
!pip install torch==2.0.0
!pip install einops==0.6.1


[0m

In [None]:
from torch import cuda, bfloat16
import transformers

# Checkpoint identifier handed to `from_pretrained` below.
model_name = "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1"

# Label used in the log line at the end of this cell: the current CUDA device
# when one is available, otherwise the CPU.
if cuda.is_available():
    device = f'cuda:{cuda.current_device()}'
else:
    device = 'cpu'

# 4-bit quantization settings so the large model loads with less GPU memory
# (this requires the `bitsandbytes` library).
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=bfloat16,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
)

# Load the checkpoint with the quantization config; weight placement is left
# to `device_map='auto'`.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map='auto',
    trust_remote_code=True,
)
model.eval()  # put the module in evaluation mode
print(f"Model loaded on {device}")
     

Loading checkpoint shards:   0%|          | 0/18 [00:00<?, ?it/s]

In [None]:
import torch
from transformers import pipeline, BitsAndBytesConfig, AutoTokenizer


# Tokenizer for the same checkpoint as the model: `use_fast=False`,
# left-side padding, and repository-provided code allowed.
tokenizer_options = {
    "use_fast": False,
    "padding_side": "left",
    "trust_remote_code": True,
}
tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_options)

In [1]:
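# NOTE: unused alternative kept for reference -- it would build the pipeline
# directly from the model id; `model_kwargs` is not defined in the cells shown here.
# generate_text = pipeline(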
#     model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1",
#     tokenizer=tokenizer,
#     return_full_text=True,
#     trust_remote_code=True,
#     use_fast=False,
#     device_map={"": "cuda:0"},
#     model_kwargs=model_kwargs,
# )

In [None]:
# Wrap the already-loaded `model` and `tokenizer` in a text-generation pipeline.
generate_text = transformers.pipeline(
    model=model, tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text
    task='text-generation',
    trust_remote_code=True,
    use_fast=False,
    # we pass generation parameters here too
    # without this model rambles during chat
    # Greedy (deterministic) decoding. `temperature` only takes effect when
    # sampling and must be strictly positive, so the former `temperature=0.0`
    # is replaced by an explicit `do_sample=False`.
    do_sample=False,
    max_new_tokens=512,  # max number of tokens to generate in the output
    repetition_penalty=1.1  # without this output begins repeating
)

In [None]:
# Run one sample prompt through the pipeline and show the text of the first
# returned item.
prompt = "Explain to me the difference between nuclear fission and fusion."
res = generate_text(prompt)
first_result = res[0]
print(first_result["generated_text"])
     