In [1]:
TECHNIQUE_OPTIONS = [
    "Scanning electron microscopy (SEM)",
    "Transmission electron microscopy (TEM)",
    "Atomic force microscopy (AFM)",
    "Reflected light microscopy"
]

CATEGORY_OPTIONS = [
    "Metal or alloy",
    "Ceramic",
    "Polymer",
    "Composite",
    "Fracture"
]


In [2]:
def technique_mcq_prompt():
    return (
        "Look carefully at the microscopy image and identify the imaging technique.\n\n"
        "A. Scanning electron microscopy (SEM)\n"
        "B. Transmission electron microscopy (TEM)\n"
        "C. Atomic force microscopy (AFM)\n"
        "D. Reflected light microscopy\n\n"
        "Answer with only one letter."
    )

def category_mcq_prompt():
    return (
        "Which material category best describes the image?\n\n"
        "A. Metal or alloy\n"
        "B. Ceramic\n"
        "C. Polymer\n"
        "D. Composite\n"
        "E. Fracture\n\n"
        "Answer with only one letter."
    )


In [3]:
def parse_answer(text, valid):
    text = text.upper()
    for v in valid:
        if v in text:
            return v
    return None

def gt_to_letter(gt, options):
    try:
        return chr(ord("A") + options.index(gt))
    except:
        return None


In [4]:
import torch
from transformers import LlavaForConditionalGeneration, LlavaProcessor

def load_model(path, device):
    model = LlavaForConditionalGeneration.from_pretrained(
        path,
        torch_dtype=torch.float16
    ).to(device)

    processor = LlavaProcessor.from_pretrained(path)

    model.eval()
    return model, processor


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "/mnt/d/Subham/model/llava-v1.6-mistral-7b-hf"

model, processor = load_model(MODEL_PATH, device)


`torch_dtype` is deprecated! Use `dtype` instead!
You are using a model of type llava_next to instantiate a model of type llava. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
def parse_answer(text, valid_letters):
    text = text.upper()
    for v in valid_letters:
        if v in text:
            return v
    return None

def gt_to_letter(gt_value, options):
    try:
        return chr(ord("A") + options.index(gt_value))
    except:
        return None


In [7]:
def run_vlm(model, processor, image, prompt, device):

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Apply official LLaVA chat template
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = processor(
        images=image,
        text=text,
        return_tensors="pt"
    )

    # LLaVA v1.6 compatibility fixes
    inputs.pop("image_sizes", None)

    if inputs["pixel_values"].ndim == 5:
        inputs["pixel_values"] = inputs["pixel_values"][:, 0]

    inputs = inputs.to(device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=5,
            do_sample=False,
            pad_token_id=model.config.eos_token_id
        )

    return processor.decode(output[0], skip_special_tokens=True)


In [24]:
import pandas as pd
from PIL import Image
from tqdm import tqdm
import time

CSV_PATH = "qa.csv"
df = pd.read_csv(CSV_PATH)

start_time = time.time()
y_true_tech = []
y_pred_tech = []

y_true_cat = []
y_pred_cat = []

for _, row in tqdm(df.iterrows(), total=len(df)):

    # --- Safe image loading ---
    try:
        image = Image.open(row["image_local_path"]).convert("RGB")
    except Exception:
        continue

    # -------- Technique MCQ --------
    tech_out = run_vlm(
        model,
        processor,
        image,
        technique_mcq_prompt(),
        device
    )

    tech_pred = parse_answer(tech_out, ["A", "B", "C", "D"])
    tech_gt   = gt_to_letter(row["technique"], TECHNIQUE_OPTIONS)

    if tech_pred is not None and tech_gt is not None:
        y_pred_tech.append(tech_pred)
        y_true_tech.append(tech_gt)

    # -------- Category MCQ --------
    cat_out = run_vlm(
        model,
        processor,
        image,
        category_mcq_prompt(),
        device
    )

    cat_pred = parse_answer(cat_out, ["A", "B", "C", "D", "E"])
    cat_gt   = gt_to_letter(row["categories"], CATEGORY_OPTIONS)

    if cat_pred is not None and cat_gt is not None:
        y_pred_cat.append(cat_pred)
        y_true_cat.append(cat_gt)
# Calculate total time
total_time = time.time() - start_time
print(f"\nTotal processing time: {total_time:.2f} seconds")
print(f"Average time per image: {total_time/len(df):.2f} seconds")

  0%|                                                                                | 0/51 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  2%|█▍                                                                      | 1/51 [00:00<00:39,  1.25it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  4%|██▊                                                                     | 2/51 [00:01<00:35,  1.36it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  6%|████▏                                                                   | 3/51 [00:02<00:33,  1.42it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  8%|█████▋             


Total processing time: 36.43 seconds
Average time per image: 0.71 seconds





In [25]:
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix
)

def print_metrics(y_true, y_pred, name):
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)

    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred,
        average="macro",
        zero_division=0
    )

    print(f"\n=== {name} ===")
    print(f"Samples            : {len(y_true)}")
    print(f"Accuracy           : {acc:.3f}")
    print(f"Balanced Accuracy  : {bal_acc:.3f}")
    print(f"Macro Precision    : {prec:.3f}")
    print(f"Macro Recall       : {rec:.3f}")
    print(f"Macro F1-score     : {f1:.3f}")


In [26]:
print_metrics(y_true_tech, y_pred_tech, "Technique Classification")
print_metrics(y_true_cat,  y_pred_cat,  "Category Classification")



=== Technique Classification ===
Samples            : 46
Accuracy           : 0.261
Balanced Accuracy  : 0.250
Macro Precision    : 0.065
Macro Recall       : 0.250
Macro F1-score     : 0.103

=== Category Classification ===
Samples            : 45
Accuracy           : 0.556
Balanced Accuracy  : 0.250
Macro Precision    : 0.139
Macro Recall       : 0.250
Macro F1-score     : 0.179
