In [4]:
# Install required libraries
!pip install transformers datasets torch accelerate

# Import necessary modules
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
import torch

# Step 1: Load a Generalist Model
general_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
general_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
general_model.to(device)

example_text = "The patient shows signs of myocardial infarction and requires immediate intervention."
inputs = general_tokenizer(example_text, return_tensors="pt").to(device)  # Move inputs to the same device
logits = general_model(**inputs).logits
predicted_class = torch.argmax(logits, dim=1).item()
print(f"Generalist model prediction (untrained): Class {predicted_class}")

# Step 2: Prepare Domain-Specific Dataset
domain_texts = [
    "Patient has elevated liver enzymes.",
    "MRI scan indicates no neurological abnormalities.",
    "Prescription drug interactions are possible.",
    "Common cold symptoms include cough and fever.",
    "Suspected case of deep vein thrombosis.",
    "Routine dental cleaning scheduled."
]
domain_labels = [1, 1, 1, 0, 1, 0]  # 1 = relevant to diagnosis, 0 = not relevant

domain_data = Dataset.from_dict({"text": domain_texts, "label": domain_labels})

def tokenize_function(examples):
    return general_tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

tokenized_domain_data = domain_data.map(tokenize_function, batched=True)
tokenized_domain_data = tokenized_domain_data.remove_columns(["text"])  # Keep only tokenized inputs + labels

# Step 3: Fine-Tune the Model on Domain Data
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",
    logging_dir="./logs",
    report_to="none"
)

# Move model to CPU before creating Trainer (TrainingArguments handles device placement automatically)
general_model.to("cpu")  # Reset for Trainer compatibility

trainer = Trainer(
    model=general_model,
    args=training_args,
    train_dataset=tokenized_domain_data,
    eval_dataset=tokenized_domain_data
)

trainer.train()

# Step 4: Evaluate Both Models
test_input = "Patient presents with acute respiratory distress syndrome."
inputs = general_tokenizer(test_input, return_tensors="pt").to(device)  # Move inputs to GPU if available

# Generalist model prediction (still on CPU or GPU depending on above setup)
general_model.to(device)
logits_general = general_model(**inputs).logits
class_general = torch.argmax(logits_general, dim=1).item()

# Domain-specific model prediction
logits_domain = trainer.model.to(device)(**inputs).logits
class_domain = torch.argmax(logits_domain, dim=1).item()

print(f"Generalist model prediction: Class {class_general}")
print(f"Domain-specific model prediction: Class {class_domain}")



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Generalist model prediction (untrained): Class 0


Map:   0%|          | 0/6 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss
1,No log,0.52827
2,No log,0.472516
3,No log,0.440746


Generalist model prediction: Class 1
Domain-specific model prediction: Class 1
