In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from transformers import AutoModel, TrainingArguments, Trainer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType


# Report the installed PyTorch build and whether CUDA is visible.
print(torch.__version__)
print(torch.cuda.is_available())

# Pick the accelerator in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [2]:
def set_seed(seed):
    """
    Seed every random number generator this notebook can touch.

    Seeds Python's built-in ``random`` module, NumPy's legacy global RNG, and
    PyTorch (CPU generator plus all CUDA devices).

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: `random` is not part of the notebook's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # If using GPUs


set_seed(4242)

In [None]:
class NVEmbedForSequenceClassification(nn.Module):
    """
    Wrap NV-Embed-v2 to produce embeddings, then classify using a linear head.

    The wrapper feeds raw strings to the backbone's ``encode`` helper,
    L2-normalizes the resulting embeddings and maps them to ``num_labels``
    logits. When ``labels`` are supplied, a cross-entropy loss is returned
    alongside the logits so the Hugging Face ``Trainer`` can optimize the model.

    Args:
        model_name: Hub id or local path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of target classes (output size of the linear head).
        instruction: Instruction string forwarded to ``encode`` for every batch.
        embedding_dim: Size of the embeddings returned by ``encode``
            (4096 is the NV-Embed-v2 default).
    """

    def __init__(
        self,
        model_name="nvidia/NV-Embed-v2",
        num_labels=5,
        instruction="",
        embedding_dim=4096,
    ):
        super().__init__()
        self.instruction = instruction
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        # Freeze config if needed
        # self.model.config.use_cache = False

        # Classification head on top of the sentence embeddings.
        self.classifier = nn.Linear(embedding_dim, num_labels)

    def forward(
        self,
        text=None,
        labels=None,
        max_length=2048,
        use_cache=False,
        # The Trainer / PEFT wrapper may pass arguments like "input_ids" or
        # "attention_mask"; they are accepted and ignored because the model
        # relies on "text" from the collator.
        **kwargs,
    ):
        """
        Embed a batch of strings and classify them.

        Args:
            text: List of strings in the batch.
            labels: Optional tensor of shape [batch_size] with class indices.
            max_length: Maximum sequence length forwarded to ``encode``.
            use_cache: Forwarded unchanged to ``encode``.
            **kwargs: Extra keyword arguments; ignored.

        Returns:
            A dict with ``"logits"`` of shape [batch_size, num_labels] and,
            when ``labels`` is given, a scalar ``"loss"`` (cross-entropy).

        Raises:
            ValueError: If ``text`` is not provided.
        """
        if text is None:
            raise ValueError(
                "`text` (a list of strings) is required; use a collator that keeps raw text."
            )

        # 1) Produce embeddings (expected shape: [batch_size, embedding_dim]).
        # NOTE(review): the remote-code `encode` helper may run under
        # `torch.no_grad()` -- TODO confirm. If it does, gradients reach only
        # the classifier head and the backbone/LoRA weights are never updated.
        embeddings = self.model.encode(
            text,
            instruction=self.instruction,
            max_length=max_length,
            use_cache=use_cache,
        )

        # 2) (Optional) normalize the embeddings:
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # 3) Classification:
        logits = self.classifier(embeddings)  # [batch_size, num_labels]

        outputs = {"logits": logits}
        if labels is not None:
            # The Trainer requires a "loss" entry whenever labels are passed.
            # Compute it in float32 so half-precision logits stay stable.
            loss = F.cross_entropy(logits.float(), labels.to(logits.device).long())
            outputs = {"loss": loss, "logits": logits}
        return outputs

In [None]:
from datasets import ClassLabel, Dataset, Features, Value, load_dataset
from sklearn.model_selection import train_test_split

# The CSV is expected to expose the columns ["text", "labels"] with 5 classes.
# Point data_file_path at your own file if it lives elsewhere.
data_file_path = "data/processed/finetuning_5_labels_topic_pruned.csv"

# Explicit schema: free text plus a 5-way class label.
features = Features(
    {
        "text": Value("string"),
        "labels": ClassLabel(names=["0.0", "1.0", "2.0", "3.0", "4.0"]),
    }
)

# Loading a single CSV yields only a "train" split.
dataset = load_dataset("csv", data_files=data_file_path, features=features)

# Build a stratified 90/10 train/test split via pandas + scikit-learn.
df_all = dataset["train"].to_pandas()
train_df, test_df = train_test_split(
    df_all,
    test_size=0.1,
    random_state=42,
    stratify=df_all["labels"],
)

# Convert both frames back to `Dataset` objects with a fresh 0..n-1 index.
train_dataset, test_dataset = (
    Dataset.from_pandas(frame.reset_index(drop=True)) for frame in (train_df, test_df)
)

In [None]:
class NVEmbedCollator:
    """
    Batch builder that keeps the raw text instead of tokenizing it.

    The model's forward pass calls ``encode`` on plain strings, so each batch
    carries the texts as a Python list and the labels as a ``torch.long``
    tensor.

    Args:
        label_name: Key under which each example stores its label.
    """

    def __init__(self, label_name="labels"):
        self.label_name = label_name

    def __call__(self, features):
        """
        Collate a list of example dicts, e.g. ``[{"text": ..., "labels": ...}, ...]``.

        Returns:
            ``{"text": list[str], "labels": LongTensor}`` -- the label tensor is
            always stored under the key ``"labels"``.
        """
        label_key = self.label_name
        texts, label_values = [], []
        for example in features:
            texts.append(example["text"])
            label_values.append(example[label_key])

        return {
            "text": texts,
            "labels": torch.tensor(label_values, dtype=torch.long),
        }


collator = NVEmbedCollator(label_name="labels")

In [None]:
# Number of target classes for the linear head.
num_labels = 5

# Task instruction that the wrapper forwards to the encoder with every batch.
task_instructions = "Given a biotech press release, classify it into 5 categories (0..4)."

# Build the NV-Embed-v2 backbone together with its classification head.
base_model = NVEmbedForSequenceClassification(
    "nvidia/NV-Embed-v2",
    num_labels=num_labels,
    instruction=task_instructions,
)

In [None]:
# LoRA config for sequence classification.
# `target_modules` must be given explicitly: the wrapper is a plain nn.Module,
# so PEFT has no known architecture from which to infer default target layers.
# NOTE(review): assumes the NV-Embed-v2 backbone names its attention
# projections q_proj/k_proj/v_proj/o_proj (Mistral-style) -- confirm with
# `[name for name, _ in base_model.named_modules()]` before training.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type=TaskType.SEQ_CLS,  # "SEQ_CLS" or "CAUSAL_LM" etc.
)

lora_model = get_peft_model(base_model, lora_config)

# Print trainable parameters
lora_model.print_trainable_parameters()

In [None]:
import evaluate

# Metric objects are loaded once and reused for every evaluation pass.
metric_accuracy = evaluate.load("accuracy")
metric_f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """
    Turn raw evaluation output into accuracy and weighted F1.

    Args:
        eval_pred: Pair that unpacks into ``(logits, labels)``.

    Returns:
        Dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    logits, labels = eval_pred
    predicted_ids = np.argmax(logits, axis=-1)  # highest-scoring class per row

    shared = {"references": labels, "predictions": predicted_ids}
    accuracy = metric_accuracy.compute(**shared)["accuracy"]
    weighted_f1 = metric_f1.compute(average="weighted", **shared)["f1"]
    return {"accuracy": accuracy, "f1": weighted_f1}

In [None]:
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# (the old spelling was deprecated and later dropped). Use whichever keyword
# the installed version defines so this cell runs on both.
eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in TrainingArguments.__dataclass_fields__
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results-nv-embed-lora",
    save_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=10,
    load_best_model_at_end=True,
    report_to="none",  # Turn off W&B or HF logs
    fp16=device.type == "cuda",  # mixed precision only on CUDA
    # The collator needs the raw "text" column. Keep the Trainer from pruning
    # dataset columns based on the (PEFT-wrapped) model's forward signature.
    remove_unused_columns=False,
    # For large embeddings you may want gradient_accumulation_steps, etc.
    **{eval_strategy_key: "epoch"},
)

trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning with the Trainer configured earlier in the notebook
# (3 epochs, with evaluation and checkpointing set to "epoch").
trainer.train()

In [None]:
# Evaluate on the Trainer's eval_dataset (the held-out 10% split) and print
# whatever metrics dict `evaluate()` returns.
eval_result = trainer.evaluate()
print("Evaluation:", eval_result)

In [None]:
test_texts = [
    "BioVie Announces Alignment with FDA on Clinical Trial to Assess Bezisterim in Parkinson’s Disease...",
    "The development of a recombinant polyclonal antibody therapy for COVID-19 by GigaGen...",
]

# The forward pass only needs the raw strings. Labels are optional, so no
# dummy label tensor is built (passing one would only trigger a meaningless
# loss computation).
batch_for_inference = {"text": test_texts}

# Switch to eval mode explicitly so dropout (including LoRA dropout) is off
# regardless of which cell ran last.
inference_model = trainer.model
inference_model.eval()

with torch.no_grad():
    outputs = inference_model(**batch_for_inference)
logits = outputs["logits"]
preds = torch.argmax(logits, dim=-1)  # most likely class index per text
print("Predicted labels:", preds.tolist())