
REASON FOR CHOOSING MISTRAL 7B

For fine-tuning on a domain-specific question-answer dataset, Mistral 7B is a stronger starting point than BERT for several reasons. It is a 7-billion-parameter decoder-only language model that is reported to outperform larger models such as Llama 2 13B on many benchmarks, including question answering. It uses Grouped-Query Attention for faster inference and Sliding Window Attention to handle long inputs efficiently. Its context window spans thousands of tokens, whereas BERT is limited to 512, which matters for complex QA tasks with long questions or supporting passages. It has also been reported to perform well in domain-specific applications such as medical question answering.

By contrast, BERT was a pioneering model for QA and remains efficient for smaller tasks, but it is older and more constrained: its context window is short, and its encoder-only architecture is pretrained with masked language modeling, which suits extractive QA (selecting an answer span from a passage) rather than autoregressive generation of free-form answers.

In short, as a single recommendation for a relatively small model to fine-tune on domain-specific QA data, Mistral 7B is the better choice because of its stronger performance, efficient inference, and robustness on longer sequences and instruction-following tasks.

In [6]:
!nvidia-smi



Thu Nov 27 18:31:05 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:04:00.0 Off |                  Off |
|  0%   27C    P8             19W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import torch
import bitsandbytes as bnb

print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))



2.4.1+cu121
True
NVIDIA GeForce RTX 4090
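
The 24564 MiB reported by `nvidia-smi` is the budget that motivates 4-bit loading in the next cell. The following is a rough back-of-the-envelope sketch, not a measurement: it assumes a nominal 7 billion parameters and counts only the weights, ignoring quantization constants, layers kept in higher precision, activations, and optimizer state. The helper name `weight_memory_mib` is hypothetical.

```python
MIB = 1024 ** 2  # bytes per MiB


def weight_memory_mib(n_params, bits_per_param):
    """Approximate memory needed to hold the weights alone, in MiB."""
    return n_params * bits_per_param / 8 / MIB


n_params = 7_000_000_000   # assumption: nominal "7B" parameter count
gpu_total_mib = 24564      # total memory reported by nvidia-smi above

bf16_mib = weight_memory_mib(n_params, 16)
int4_mib = weight_memory_mib(n_params, 4)

print(f"bf16 weights : {bf16_mib:8.0f} MiB ({bf16_mib / gpu_total_mib:.0%} of GPU)")
print(f"4-bit weights: {int4_mib:8.0f} MiB ({int4_mib / gpu_total_mib:.0%} of GPU)")
```

Under these assumptions the bf16 weights alone take more than half of the card, while 4-bit weights leave most of it free for activations and the LoRA training state.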


In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
import torch

model_name = "mistralai/Mistral-7B-Instruct-v0.2"

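print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The Mistral tokenizer ships without a pad token, so reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

print("Loading model in 4-bit (QLoRA compatible)...")
# Note: `load_in_4bit` and `torch_dtype` still work here but are deprecated in
# recent transformers releases (see the warnings in the output below).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

print("Applying QLoRA adapters...")
# No `target_modules` given, so PEFT uses its default set for this architecture.
lora_config = LoraConfig(
    r=64,               # rank of the low-rank update matrices
    lora_alpha=16,      # the update is scaled by lora_alpha / r
    lora_dropout=0.1,   # dropout applied on the LoRA branch
    bias="none",        # bias terms stay frozen
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)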

print("✅ Mistral 7B successfully loaded in 4-bit with QLoRA!")


Loading tokenizer...


Loading model in 4-bit (QLoRA compatible)...


`torch_dtype` is deprecated! Use `dtype` instead!
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.15s/it]


Applying QLoRA adapters...
✅ Mistral 7B successfully loaded in 4-bit with QLoRA!
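
The deprecation warning above asks for a `BitsAndBytesConfig` passed through `quantization_config`. A minimal sketch of that form follows. The NF4 quantization type and double quantization are assumptions taken from the usual QLoRA recipe, not settings used in this notebook, and the helper names `build_4bit_quant_kwargs` and `load_4bit_model` are hypothetical.

```python
def build_4bit_quant_kwargs():
    """Keyword arguments for transformers.BitsAndBytesConfig."""
    return {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",          # assumption: NF4, as in the QLoRA recipe
        "bnb_4bit_use_double_quant": True,     # assumption: also quantize the quantization constants
        "bnb_4bit_compute_dtype": "bfloat16",  # matches the bfloat16 dtype requested in the notebook
    }


def load_4bit_model(model_name="mistralai/Mistral-7B-Instruct-v0.2"):
    """Load a causal LM in 4-bit through an explicit quantization config."""
    # Heavy imports stay inside the function so the settings above can be
    # inspected on a machine without a GPU or the libraries installed.
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(**build_4bit_quant_kwargs())
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quant_config,
        device_map="auto",
    )
```

Calling `model = load_4bit_model()` would stand in for the `from_pretrained(..., load_in_4bit=True)` call above; the LoRA step that follows is unchanged.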


In [5]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj
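
The shapes printed for `q_proj` are enough to work out how small the adapter is. The sketch below uses only the numbers visible above (4096 input and output features, `r=64`, `lora_alpha=16`, 32 decoder layers) and assumes PEFT's default scaling of `lora_alpha / r`; the function name `lora_param_count` is hypothetical.

```python
def lora_param_count(in_features, out_features, r):
    """Trainable parameters added by one LoRA pair (lora_A plus lora_B, no bias)."""
    return in_features * r + r * out_features


r, lora_alpha = 64, 16
n_layers = 32

q_proj_base = 4096 * 4096                         # frozen 4-bit base weight
q_proj_lora = lora_param_count(4096, 4096, r)     # trainable adapter weights
scaling = lora_alpha / r                          # factor applied to the LoRA update

print(q_proj_lora)                  # 524288
print(q_proj_lora / q_proj_base)    # 0.03125
print(q_proj_lora * n_layers)       # 16777216
print(scaling)                      # 0.25
```

Each adapted `q_proj` therefore trains about 3% as many parameters as the frozen matrix it wraps, and its contribution is multiplied by 0.25 before being added to the base output.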

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

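# Base checkpoint (not the Instruct-v0.2 model loaded earlier). A LoRA adapter
# should be attached to the same checkpoint it was trained on.
base = "mistralai/Mistral-7B-v0.1"
lora_path = "/workspace/LoRA_finetuning/model/mistral-medical-lora"

tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base, device_map="auto", load_in_4bit=True)
# Attach the fine-tuned LoRA weights on top of the quantized base model.
model = PeftModel.from_pretrained(model, lora_path)

prompt = """### Question:
What are the symptoms of glaucoma?

### Answer:
"""

# Move the inputs to wherever device_map placed the model.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)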

output = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(output[0], skip_special_tokens=True))


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.64s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


### Question:
What are the symptoms of glaucoma?

### Answer:
The most common symptom of glaucoma is a gradual loss of peripheral or side vision. As glaucoma remains untreated, central vision may also be lost. The early, painless stages of open-angle glaucoma often go unnoticed. Without treatment, people with glaucoma will slowly lose their peripheral (side) vision. As glaucoma remains untreated, central vision may also be lost. The early, painless stages of open-angle glaucoma often go unnoticed. Without treatment, people with glaucoma will slowly lose their peripheral (side) vision. As glaucoma remains untreated, central vision may also be lost. The early, painless stages of open-angle glaucoma often go unnoticed. Without treatment, people with glaucoma will slowly lose their peripheral (side) vision. As glaucoma remains untreated, central
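
The answer above loops over the same three sentences until the 200-token budget runs out, a common failure mode of greedy decoding. On the generation side, `generate` accepts `repetition_penalty` and `no_repeat_ngram_size`, which may reduce this. As a complementary and deliberately naive post-processing sketch, the helper below (the name `trim_repeated_sentences` is hypothetical) keeps the first occurrence of each sentence and drops a trailing fragment cut off by the token limit. It splits on sentence-ending punctuation, so abbreviations such as "e.g." would be split incorrectly.

```python
import re


def trim_repeated_sentences(text, drop_incomplete=True):
    """Keep the first occurrence of each sentence, preserving order."""
    # Split after '.', '!' or '?' followed by whitespace.
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    seen = set()
    kept = []
    for sentence in sentences:
        if not sentence:
            continue
        # Only the last piece can lack end punctuation (truncated output).
        if drop_incomplete and sentence[-1] not in ".!?":
            continue
        key = sentence.lower()
        if key in seen:
            continue
        seen.add(key)
        kept.append(sentence)
    return " ".join(kept)


sample = "A is b. C is d. A is b. C is d. A is"
print(trim_repeated_sentences(sample))  # A is b. C is d.
```

Applied to the glaucoma answer, this would leave four distinct sentences; it hides the loop rather than fixing its cause, so it is best treated as a safety net.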


In [6]:
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./mistral-medical-merged")
tokenizer.save_pretrained("./mistral-medical-merged")




('./mistral-medical-merged/tokenizer_config.json',
 './mistral-medical-merged/special_tokens_map.json',
 './mistral-medical-merged/tokenizer.model',
 './mistral-medical-merged/added_tokens.json',
 './mistral-medical-merged/tokenizer.json')
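
One caveat about the merge above: `merge_and_unload()` was called on a base model loaded in 4-bit, so the LoRA update is folded into quantized weights, which can introduce rounding error. A commonly suggested alternative is to reload the base model in half precision, attach the adapter, and merge there. The sketch below is a hedged outline of that approach, not what this notebook ran; the function name `merge_adapter_in_bf16` is hypothetical, and it assumes enough memory to hold the bf16 weights.

```python
def merge_adapter_in_bf16(base_name, lora_path, out_dir):
    """Merge a LoRA adapter into an unquantized bf16 copy of the base model."""
    # Heavy imports are local so the function can be defined anywhere.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # `torch_dtype` matches the notebook; newer transformers releases prefer `dtype`.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Attach the adapter, then fold its weights into the base layers.
    merged = PeftModel.from_pretrained(base_model, lora_path).merge_and_unload()

    merged.save_pretrained(out_dir)
    AutoTokenizer.from_pretrained(base_name).save_pretrained(out_dir)
    return merged
```

With the paths used in this notebook, the call would be `merge_adapter_in_bf16("mistralai/Mistral-7B-v0.1", lora_path, "./mistral-medical-merged")`.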

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

merged_path = "/workspace/LoRA_finetuning/model/mistral-medical-merged"

tokenizer = AutoTokenizer.from_pretrained(merged_path)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    merged_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)



Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


In [2]:
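def ask(question, max_new_tokens=200):
    """Return the model's answer to a single question (greedy decoding)."""
    prompt = f"""### Question:
{question}

### Answer:
"""

    # Move the inputs to the device the model was placed on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,     # greedy decoding, so answers are reproducible
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Stripping the prompt with
    # str.replace fails if decoding does not reproduce the prompt exactly.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    return answer.strip()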


In [3]:
print(ask("What is diabetic retinopathy?"))


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Diabetic retinopathy is a complication of diabetes that affects the eyes. It is caused by damage to the blood vessels of the light-sensitive tissue at the back of the eye (retina). The retina changes blood sugar levels into a form that the brain can understand. The retina then sends the information to the brain through the optic nerve. Diabetic retinopathy can cause vision loss and blindness. The longer a person has diabetes, the more likely he or she will get diabetic retinopathy. If you have diabetes and have had it for a number of years, you have a greater risk of having diabetic retinopathy. Diabetic retinopathy may be more common in people with type 1 diabetes than in people with type 2 diabetes. Diabetic retinopathy is classified as nonproliferative or proliferative. Nonproliferative diabetic retinopathy is the most
