In [1]:
!pip install bitsandbytes



In [2]:
from transformers import pipeline
from datasets import load_dataset
from transformers import AutoTokenizer

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
import gc

quantization_config = BitsAndBytesConfig(load_in_8bit=True)


model_pipeline = pipeline(
    model = "mistralai/Mistral-7B-v0.1",
    device = "cuda",
    # quantization_config=quantization_config

)

result = model_pipeline("What is the preferred HIV ART second line?")
print(result)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[{'generated_text': 'What is the preferred HIV ART second line?\n\nIt is important to clarify the differences between the different second line antiretroviral (ART) regimens, as there is a lot of variation. Many second line antiretroviral therapy (ART) regimens are used outside of South Africa. However, the regimens used in South Africa are different to those used in other countries.\n\nIn most other countries the second line ART regimen is based on two nucleoside reverse transcriptase inhibitor (NRTI) backbones:\n\n- Combivir + abacavir+ lamivudine (Kaletra (lopinavir/ritonavir) or Nelfinavir.\n- Zidovudine+ lamivudine (Kaletra or Nelfinavir)\n\nHowever, these regimens are not used in South Africa.\n\nIn South Africa the second line ART regimen is based on a three NRTI backbone:\n\n- Combivir + tenofovir (TDF) + efavirenz (EFV) (Efavirenz and tenofovir are combined in Tru'}]


In [3]:
print(result[0]['generated_text'])

What is the preferred HIV ART second line?

It is important to clarify the differences between the different second line antiretroviral (ART) regimens, as there is a lot of variation. Many second line antiretroviral therapy (ART) regimens are used outside of South Africa. However, the regimens used in South Africa are different to those used in other countries.

In most other countries the second line ART regimen is based on two nucleoside reverse transcriptase inhibitor (NRTI) backbones:

- Combivir + abacavir+ lamivudine (Kaletra (lopinavir/ritonavir) or Nelfinavir.
- Zidovudine+ lamivudine (Kaletra or Nelfinavir)

However, these regimens are not used in South Africa.

In South Africa the second line ART regimen is based on a three NRTI backbone:

- Combivir + tenofovir (TDF) + efavirenz (EFV) (Efavirenz and tenofovir are combined in Tru


In [4]:
del model_pipeline
del result

gc.collect()

torch.cuda.empty_cache()


In [5]:
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1"
)

tokenizer.pad_token = tokenizer.eos_token

In [6]:
train_data = load_dataset("json", data_files="training_data.json")


In [7]:
def preprocess(sample):

    sample = sample["prompt"]+" \n "+sample["completion"]
    tokenized = tokenizer(
            sample,
            max_length=128,
            truncation=True,
            # return_tensors="pt",
            # padding=True
            padding="max_length",
            # tokenizer.pad_token = tokenizer.eos_token
        )

    tokenized["labels"] = tokenized["input_ids"]
    return tokenized

In [8]:
data = train_data.map(preprocess)

Map:   0%|          | 0/73 [00:00<?, ? examples/s]

In [9]:
tokenizer.save_pretrained("./guide_summary_mistral")
del tokenizer
del train_data
gc.collect()
torch.cuda.empty_cache()

In [10]:
print(data['train'][10])

{'prompt': 'Can rifepentine/Isoniazid be used for TB preventive therapy?', 'completion': '3 months of Rifapentine/Isoniazid (3HP) is the recommended TPT option except in pregnant and lactating women and children.', 'input_ids': [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2418, 12950, 615, 308, 473, 28748, 28737, 1265, 25939, 313, 347, 1307, 354, 320, 28760, 5297, 495, 12238, 28804, 28705, 13, 28705, 28770, 3370, 302, 399, 335, 377, 308, 473, 28748, 28737, 1265, 25939, 313, 325, 28770, 22106, 28731, 349, 272, 11572, 320, 6316, 3551, 3741, 297, 15446, 304, 543, 310, 1077, 2525, 304, 2436, 28723], 'attention_mask': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [11]:
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    device_map = "cuda",
    torch_dtype = torch.float16,
    quantization_config = quantization_config,
    # tokenizer.pad_token = tokenizer.eos_token
)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules = ["q_proj", "k_proj", "v_proj"]
)

model = get_peft_model(model, lora_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
from transformers import TrainingArguments, Trainer
train_args = TrainingArguments(
    num_train_epochs = 15,
    learning_rate = 0.001,
    logging_steps = 25,
    fp16 = True,
    report_to="none"

)

In [13]:
trainer = Trainer(
    args=train_args,
    model=model,
    train_dataset=data["train"]
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [14]:
trainer.train()

TypeError: device() received an invalid combination of arguments - got (NoneType), but expected one of:
 * (torch.device device)
      didn't match because some of the arguments have invalid types: (!NoneType!)
 * (str type, int index = -1)


In [None]:
trainer.save_model("./guide_summary_mistral")
# tokenizer.save_pretrained("./guide_summary_qwen")

In [None]:
from transformers import pipeline

ask_summary = pipeline(
    model="./guide_summary_mistral",
    tokenizer="./guide_summary_mistral",
    device="cuda"
)

In [None]:
ask_summary("What is the preferred HIV ART second line?")[0]["generated_text"]

In [None]:
ask_summary("Is ATV/r better than LPV/r?")[0]["generated_text"]