In [None]:
import os
import sys

# Hardware target: NVIDIA_TESLA_A100 x 16 -- expose device indices 0..15.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15"

# On Colab Enterprise, redirect the Hugging Face home directory to /content/hf.
running_on_colab_enterprise = os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE"
if running_on_colab_enterprise:
    os.environ["HF_HOME"] = "/content/hf"

# Authenticate with Hugging Face, prompting only when no token is stored yet.
from huggingface_hub import get_token

if get_token() is None:
    from huggingface_hub import notebook_login

    notebook_login()

In [None]:
! pip install --upgrade --quiet bitsandbytes datasets evaluate peft tensorboard transformers trl

In [None]:
import kagglehub
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

# Download (or reuse the cached copy of) the Kaggle dataset. The returned
# `path` is the local directory containing the files, so the CSV location is
# derived from it instead of a hard-coded cache directory and dataset version.
path = kagglehub.dataset_download("colewelkins/cardiovascular-disease")
raw_df = pd.read_csv(os.path.join(path, "cardio_data_processed.csv"))

# Build three DISJOINT, stratified partitions (80% train / 10% val / 10% test).
# Splitting the full frame twice would make the partitions overlap and leak
# evaluation rows into training, so the second split only touches the hold-out.
train_df, holdout_df = train_test_split(
    raw_df,
    test_size=0.2,
    stratify=raw_df["cardio"],
    random_state=42,
    shuffle=True,
)
val_df, test_df = train_test_split(
    holdout_df,
    test_size=0.5,
    stratify=holdout_df["cardio"],
    random_state=42,
    shuffle=True,
)

# train_test_split keeps the original row index, so disjoint indices mean
# disjoint rows.
assert train_df.index.intersection(val_df.index).empty
assert train_df.index.intersection(test_df.index).empty
assert val_df.index.intersection(test_df.index).empty
assert len(train_df) + len(val_df) + len(test_df) == len(raw_df)

# `test_df` stays a DataFrame; it is converted to a Dataset in the evaluation
# section further down.
data = DatasetDict({
    "train": Dataset.from_pandas(train_df.reset_index(drop=True)),
    "validation": Dataset.from_pandas(val_df.reset_index(drop=True)),
})
data

In [None]:
from typing import Any

# Answer strings; the list position is the integer value of the `cardio` label.
CLASSES = ["0: No risk", "1: Risk of cardiovascular disease"]

# The answers joined one per line, for embedding at the end of each prompt.
options = "\n".join(CLASSES)

def format_data(sample: dict[str, Any]) -> dict[str, Any]:
    """Attach a two-turn chat conversation to one patient record.

    The user turn describes the patient in prose and asks for one of the
    answers in ``options``; the assistant turn is the entry of ``CLASSES``
    selected by the record's ``cardio`` value.

    Args:
        sample: One dataset row. The fields read are ``age_years``, ``gender``,
            ``height``, ``weight``, ``bmi``, ``ap_hi``, ``ap_lo``,
            ``bp_category``, ``cholesterol``, ``gluc``, ``smoke``, ``alco``,
            ``active`` and ``cardio``.

    Returns:
        The same ``sample`` object, with a ``"messages"`` key added.
    """
    # The leading/trailing whitespace of this template is part of the prompt.
    question = f"""
    The patient is {sample['age_years']} years old, identified as {sample['gender']}, with a height of {sample['height']} cm and weight of {sample['weight']} kg (BMI: {sample['bmi']:.1f}).
    Their blood pressure is {sample['ap_hi']}/{sample['ap_lo']} mmHg, categorized as {sample['bp_category']}.
    Cholesterol level is {sample['cholesterol']}, glucose level is {sample['gluc']}.
    Lifestyle indicators: smoking = {sample['smoke']}, alcohol consumption = {sample['alco']}, physically active = {sample['active']}.
    Is there a risk of cardiovascular disease? Answer only by one of the following options:\n{options}
    """
    answer = CLASSES[sample["cardio"]]

    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": question}],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": answer}],
    }

    sample["messages"] = [user_turn, assistant_turn]
    return sample

In [None]:
# Run format_data over every row so each record gains its "messages" conversation.
data = data.map(format_data)

In [None]:
import evaluate

# Metric objects loaded by name through the `evaluate` library; used by
# compute_metrics below.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(predictions) -> dict[str, float]:
    """Compute accuracy and weighted F1 for a ``(preds, labels)`` pair.

    NOTE: a second ``compute_metrics`` with a different signature is defined
    later in this notebook and replaces this one once its cell runs; the only
    reference to this version (in the ``SFTTrainer`` call) is commented out.

    Args:
        predictions: A two-element sequence unpacked as ``(preds, labels)``.

    Returns:
        The accuracy result merged with the weighted-F1 result.
    """
    preds, labels = predictions

    accuracy_result = accuracy_metric.compute(predictions=preds, references=labels)
    f1_result = f1_metric.compute(
        predictions=preds,
        references=labels,
        average="weighted",
    )
    return {**accuracy_result, **f1_result}

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Checkpoint to fine-tune (the author's noted alternative: "google/medgemma-27b-text-it").
model_id = "google/medgemma-4b-it"

# Everything below uses bfloat16, so refuse to continue on a GPU whose compute
# capability major version is below 8.
major_capability = torch.cuda.get_device_capability()[0]
if major_capability < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

compute_dtype = torch.bfloat16

model_kwargs = {
    "attn_implementation": "eager",
    "torch_dtype": compute_dtype,
    "device_map": "auto",
    # 4-bit NF4 quantization with double quantization; compute and storage
    # dtypes follow the model dtype.
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_storage=compute_dtype,
    ),
}

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

# Use right padding to avoid issues during training
processor.tokenizer.padding_side = "right"


In [None]:
from peft import LoraConfig

# LoRA adapter configuration handed to the SFTTrainer.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter hyper-parameters.
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # Attach adapters to every linear layer ...
    target_modules="all-linear",
    # ... and list these modules to be saved in full alongside the adapter.
    modules_to_save=["lm_head", "embed_tokens"],
)

In [None]:
from typing import Any


def collate_fn(samples: list[dict[str, Any]]):
    """Collate chat-formatted samples into one padded training batch.

    Each sample's ``"messages"`` conversation is rendered to text with the
    processor's chat template, the texts are tokenized and padded together,
    and a ``"labels"`` tensor is added: a copy of ``input_ids`` in which
    padding and image placeholder tokens are replaced by ``-100`` so that they
    are excluded from the loss.

    Args:
        samples: Dataset rows, each carrying a ``"messages"`` conversation.

    Returns:
        The processor's batch with an additional ``"labels"`` entry.
    """
    texts = [
        processor.apply_chat_template(
            sample["messages"], add_generation_prompt=False, tokenize=False
        ).strip()
        for sample in samples
    ]

    batch = processor(text=texts, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()

    # The id must be a plain int. Wrapping it in a list makes
    # `labels == image_token_id` a tensor-vs-list comparison, which does not
    # broadcast element-wise, so the mask silently selected nothing.
    image_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )
    # Hard-coded id carried over from the original cell; presumably the
    # model's image soft-token id -- TODO confirm against the model config.
    image_soft_token_id = 262144
    ignore_index = -100

    # Mask tokens that are not used in the loss computation
    for masked_token_id in (
        processor.tokenizer.pad_token_id,
        image_token_id,
        image_soft_token_id,
    ):
        labels[labels == masked_token_id] = ignore_index

    batch["labels"] = labels

    return batch

In [None]:

from trl import SFTConfig, SFTTrainer

num_train_epochs = 1
learning_rate = 2e-4

# NOTE(review): the output directory name mentions "27b-text-it" while
# `model_id` above is "google/medgemma-4b-it" -- confirm the name is intended.
args = SFTConfig(
    # Where checkpoints are written, and Hub / logging integration.
    output_dir="medgemma-27b-text-it-sft-lora-cardiovascular-disease",
    push_to_hub=True,
    report_to="tensorboard",
    # Schedule: epochs, batch sizes and gradient accumulation.
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    # Optimizer and learning-rate schedule.
    optim="adamw_torch_fused",
    learning_rate=learning_rate,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    # Precision and memory.
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Logging, evaluation and checkpoint cadence.
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="epoch",
    # The custom collate_fn builds the batches from the raw "messages" column,
    # so dataset preparation is skipped and no column is dropped.
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)

# Evaluate on a small subset of the validation split for a faster run.
# The shuffle is seeded so every run scores the same rows, and the subset size
# is capped at the split length so `select` can never index out of range.
EVAL_SUBSET_SIZE = 200
eval_split = data["validation"].shuffle(seed=42)
eval_subset = eval_split.select(range(min(EVAL_SUBSET_SIZE, len(eval_split))))

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=data["train"],
    eval_dataset=eval_subset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
    # compute_metrics=compute_metrics
)

In [None]:
# Run fine-tuning with the trainer configured in the previous cell.
trainer.train()

In [None]:
# Persist the trained model; the fine-tuned pipeline further down loads it
# from `args.output_dir`.
trainer.save_model()

In [None]:
# Release the training objects before loading the evaluation pipelines.
import gc

del model
del trainer
# Dropping the names is not enough when reference cycles keep tensors alive:
# run the garbage collector first, then return the cached blocks to the GPU.
gc.collect()
torch.cuda.empty_cache()

In [None]:
from typing import Any
from datasets import load_dataset


def format_test_data(sample: dict[str, Any]) -> dict[str, Any]:
    """Attach a single-turn (user only) conversation to one patient record.

    Same patient description as ``format_data`` but without an assistant
    turn, so the model has to generate the answer.

    NOTE: the final instruction reads "options(full name):" here, whereas
    ``format_data`` uses "options:".

    Args:
        sample: One dataset row. The fields read are ``age_years``, ``gender``,
            ``height``, ``weight``, ``bmi``, ``ap_hi``, ``ap_lo``,
            ``bp_category``, ``cholesterol``, ``gluc``, ``smoke``, ``alco``
            and ``active``.

    Returns:
        The same ``sample`` object, with a ``"messages"`` key added.
    """
    # The leading/trailing whitespace of this template is part of the prompt.
    question = f"""
    The patient is {sample['age_years']} years old, identified as {sample['gender']}, with a height of {sample['height']} cm and weight of {sample['weight']} kg (BMI: {sample['bmi']:.1f}).
    Their blood pressure is {sample['ap_hi']}/{sample['ap_lo']} mmHg, categorized as {sample['bp_category']}.
    Cholesterol level is {sample['cholesterol']}, glucose level is {sample['gluc']}.
    Lifestyle indicators: smoking = {sample['smoke']}, alcohol consumption = {sample['alco']}, physically active = {sample['active']}.
    Is there a risk of cardiovascular disease? Answer only by one of the following options(full name):\n{options}
    """

    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": question}],
    }
    sample["messages"] = [user_turn]
    return sample

# Build the evaluation set: 2048 rows drawn from `test_df` with a fixed shuffle
# seed, each given its user-only prompt.
test_data = (
    Dataset.from_pandas(test_df.reset_index(drop=True))
    .shuffle(seed=42)
    .select(range(2048))
    .map(format_test_data)
)

In [None]:

import evaluate

# Metric objects loaded by name through the `evaluate` library.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

# Ground-truth `cardio` values of the evaluation set, in dataset order.
REFERENCES = test_data["cardio"]


def compute_metrics(predictions: list[int]) -> dict[str, float]:
    """Score predicted class ids against the module-level ``REFERENCES``.

    This definition replaces the earlier ``compute_metrics`` in this notebook.

    Args:
        predictions: One predicted class id per evaluation row, in the same
            order as ``REFERENCES``.

    Returns:
        The accuracy result merged with the weighted-F1 result.
    """
    accuracy_result = accuracy_metric.compute(
        predictions=predictions,
        references=REFERENCES,
    )
    f1_result = f1_metric.compute(
        predictions=predictions,
        references=REFERENCES,
        average="weighted",
    )
    return {**accuracy_result, **f1_result}

In [None]:
from datasets import ClassLabel

# Re-type the `cardio` column as a ClassLabel whose names are the answer strings.
test_data = test_data.cast_column("cardio", ClassLabel(names=CLASSES))

# Feature object used to translate answer strings back to integer class ids.
LABEL_FEATURE = test_data.features["cardio"]

# Alternative spelling accepted for each answer: "0: No risk" -> "(0) No risk".
ALT_LABELS = {
    label: f"({label.replace(': ', ') ')}"
    for label in CLASSES
}


def postprocess(prediction: list[dict[str, str]], do_full_match: bool=False) -> int:
    """Convert one pipeline generation into an integer class id.

    Args:
        prediction: The pipeline output for a single input; only
            ``prediction[0]["generated_text"]`` is read.
        do_full_match: When true, the whole response must be a label accepted
            by ``LABEL_FEATURE.str2int``. When false, the first entry of
            ``CLASSES`` found in the response as a substring (in either its
            original or its ``ALT_LABELS`` spelling) wins.

    Returns:
        The matching class id, or ``-1`` when the response matches no label.
    """
    response_text = prediction[0]["generated_text"]
    if do_full_match:
        try:
            return LABEL_FEATURE.str2int(response_text)
        except ValueError:
            # A single off-format generation must not abort the evaluation of
            # the whole test set; report it with the same -1 sentinel that the
            # substring branch uses.
            return -1
    for label in CLASSES:
        if label in response_text or ALT_LABELS[label] in response_text:
            return LABEL_FEATURE.str2int(label)
    return -1

In [None]:
from transformers import pipeline
import torch
# model_id = "google/medgemma-27b-text-it"
model_id = "google/medgemma-4b-it"

# Baseline: the pre-trained checkpoint, without any fine-tuning.
pt_pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16)

from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(model_id)

# Greedy decoding, with the EOS token doubling as the padding token.
baseline_generation_config = pt_pipe.model.generation_config
baseline_generation_config.do_sample = False
baseline_generation_config.pad_token_id = processor.tokenizer.eos_token_id

pt_outputs = pt_pipe(
    text_inputs=test_data["messages"],
    max_new_tokens=40,
    batch_size=64,
    return_full_text=False,
)

# Substring matching (the postprocess default) for the baseline's free-form answers.
pt_predictions = [postprocess(generation) for generation in pt_outputs]

# from tqdm import tqdm

# batch_size = 8
# all_outputs = []

# texts = test_data["messages"]

# for i in tqdm(range(0, len(texts), batch_size), desc="Generating"):
#     batch = texts[i : i + batch_size]
#     outs = pt_pipe(
#         batch,
#         max_new_tokens=30,
#         return_full_text=False,
#     )
#     all_outputs.extend(outs)

# pt_predictions = [postprocess(out) for out in all_outputs]


In [None]:
# Score the baseline (pre-trained) predictions against REFERENCES.
pt_metrics = compute_metrics(pt_predictions)
print(f"Baseline metrics: {pt_metrics}")

Baseline metrics: {'accuracy': 0.509765625, 'f1': 0.5780870295577049}


In [None]:
from transformers import pipeline
import torch
# Fine-tuned model: load whatever the trainer saved in `args.output_dir`.
ft_pipe = pipeline(
    "text-generation",
    model=args.output_dir,
    torch_dtype=torch.bfloat16,
)

# Set `do_sample = False` for deterministic responses
ft_pipe.model.generation_config.do_sample = False
ft_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
# Use left padding during inference.
# `processor` was not passed to the pipeline above, so changing its tokenizer
# alone does not affect batched generation; the pipeline's own tokenizer has
# to be switched as well. The `processor` assignment is kept for later cells.
processor.tokenizer.padding_side = "left"
if ft_pipe.tokenizer is not None:
    ft_pipe.tokenizer.padding_side = "left"


ft_outputs = ft_pipe(
    text_inputs=test_data["messages"],
    max_new_tokens=40,
    batch_size=64,
    return_full_text=False,
)

# The fine-tuned model is expected to answer with an exact label string.
ft_predictions = [postprocess(out, do_full_match=True) for out in ft_outputs]

In [None]:
# Score the fine-tuned predictions against the same REFERENCES as the baseline.
ft_metrics = compute_metrics(ft_predictions)
print(f"Fine-tuned metrics: {ft_metrics}")

Fine-tuned metrics: {'accuracy': 0.681640625, 'f1': 0.6815495221036418}
