#albert model

ALBERT-Base is a lighter model (approximately 11M parameters) than DistilBERT (66M). There isn’t a widely available biomedically pre-trained ALBERT-Base on Hugging Face, so we fine-tune the general-purpose `albert-base-v2` checkpoint downstream on your biomedical CSV data for disease and treatment prediction; the PDFs are indexed separately and used for retrieval (RAG) context rather than for fine-tuning.



In [None]:
# Step 2: Install required packages
!pip install -q pandas transformers torch langchain pymupdf sentence-transformers faiss-cpu scikit-learn

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m84.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m55.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m42.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
!pip install -U langchain-community # Install the langchain-community package, which contains the necessary PyMuPDFLoader class

Collecting langchain-community
  Downloading langchain_community-0.3.18-py3-none-any.whl.metadata (2.4 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Downloading pydantic_settings-2.8.1-py3-none-any.whl.metadata (3.5 kB)
Collecting httpx-sse<1.0.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting python-dotenv>=0.21.0 (from pydantic-settings<3.0.0,>=2.4.0->langchain-community)
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB

#optimized
Reduced Epochs:
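Lowered num_train_epochs from 3 to 1. For many tasks, a single epoch can suffice for fine-tuning, especially with a small dataset. (Note: the code cell below currently sets num_train_epochs=15, so either this note or the code should be updated to keep them consistent.)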

Increased Batch Size:
Raised per_device_train_batch_size and per_device_eval_batch_size from 16 to 32. ALBERT’s small size (11M parameters) allows larger batches, reducing the number of training steps.

Mixed Precision Training (FP16):
Added fp16=True in TrainingArguments. This uses half-precision floating-point numbers, speeding up training and reducing memory usage on GPU.

Reduced Sequence Length:
Lowered max_length from 512 to 256 in prepare_dataset and rag_predict. Shorter sequences decrease computation time, though ensure your symptoms data fits within this limit.

Efficient Dataset Preparation:
Replaced iterative tokenization with batch tokenization in prepare_dataset. This processes all texts at once, leveraging tokenizer efficiency.

Used usecols in pd.read_csv to load only required columns, reducing memory overhead.

Multi-threaded PDF Loading:
Implemented ThreadPoolExecutor in load_pdfs_from_folder to load PDFs concurrently, speeding up I/O operations.

Smaller Text Chunks:
Reduced chunk_size from 1000 to 500 and chunk_overlap from 200 to 100 in RecursiveCharacterTextSplitter. Smaller chunks decrease embedding computation time.

Reduced Warmup and Logging:
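Lowered warmup_steps from 200 to 100 so the learning rate reaches its peak value sooner. logging_steps was changed from 10 to 5; note that this logs twice as often as before, so it slightly increases (rather than reduces) logging overhead during training.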



In [None]:

# Import libraries
import os
import pandas as pd
from transformers import AlbertTokenizer, AlbertForSequenceClassification, Trainer, TrainingArguments
import torch
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from sklearn.metrics import precision_score, recall_score, f1_score
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

# Paths (adjusted for Colab)
DATA_PATH = "/content/drive/MyDrive/01-nlp/modeling/data/merge_demo_amos_v3.csv"
PDF_FOLDER_PATH = "/content/drive/MyDrive/01-nlp/modeling/pdf"
MODEL_SAVE_PATH = "/content/drive/MyDrive/01-nlp/modeling/albert_finetuned"

# Columns the pipeline needs; every other CSV column is skipped at parse time.
required_cols = ["symptoms", "disease_name", "treatment"]

# Load CSV dataset efficiently. A callable `usecols` keeps only the required
# columns without raising when one is absent; passing a list would make
# read_csv raise on its own and leave the explicit check below unreachable.
df = pd.read_csv(DATA_PATH, usecols=lambda name: name in required_cols)

# Ensure required columns exist (reports the first missing column by name)
for col in required_cols:
    if col not in df.columns:
        raise ValueError(f"Column '{col}' is missing from the dataset")

# Load PDFs with multi-threading
def load_pdf(file_path):
    """Load a single PDF with PyMuPDFLoader and return the result of its ``load()``.

    ``file_path`` may be a ``Path`` or a string; it is converted with ``str``
    before being handed to the loader.
    """
    return PyMuPDFLoader(str(file_path)).load()

def load_pdfs_from_folder(folder_path):
    """Load every ``*.pdf`` file directly inside *folder_path* using a thread pool.

    Args:
        folder_path: Directory to scan (non-recursive).

    Returns:
        A flat list containing the documents returned by ``load_pdf`` for each
        file, ordered by sorted file path. Empty if the folder has no PDFs.
    """
    # Path.glob yields entries in arbitrary, filesystem-dependent order;
    # sorting keeps the resulting document order reproducible across runs.
    pdf_files = sorted(Path(folder_path).glob("*.pdf"))
    with ThreadPoolExecutor() as executor:
        # executor.map preserves input order, so results line up with pdf_files.
        per_file_docs = list(executor.map(load_pdf, pdf_files))
    return [doc for docs in per_file_docs for doc in docs]  # Flatten list

# Load all PDFs from the configured folder and report how many documents came back.
pdf_docs = load_pdfs_from_folder(PDF_FOLDER_PATH)
print(f"Loaded {len(pdf_docs)} PDF pages")

# Split PDF documents into chunks of at most 500 characters with 100 characters of overlap
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)  # Reduced size for efficiency
split_docs = text_splitter.split_documents(pdf_docs)

# Create embeddings and vector store for RAG: every chunk is embedded with the
# MiniLM sentence-transformer and indexed in an in-memory FAISS store.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = FAISS.from_documents(split_docs, embedding_model)

# Initialize ALBERT tokenizer and model (two output labels; see prepare_dataset
# for how the 0/1 labels are derived)
tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)

# Prepare dataset efficiently with batch tokenization
def prepare_dataset(df, tokenizer, max_length=256):  # Reduced max_length
    """Tokenize all symptom texts in one batch and derive binary labels.

    Args:
        df: DataFrame with ``symptoms`` and ``treatment`` columns.
        tokenizer: Callable invoked once as ``tokenizer(texts, truncation=True,
            padding="max_length", max_length=max_length, return_tensors="pt")``;
            its result must support ``["input_ids"]`` and ``["attention_mask"]``.
        max_length: Length every sequence is padded/truncated to.

    Returns:
        ``(encodings, labels)`` where ``encodings`` is a dict with the keys
        ``"input_ids"`` and ``"attention_mask"`` and ``labels`` is a list with
        1 for rows whose ``treatment`` is present and non-empty, else 0.
    """
    # Iterate the two needed columns directly instead of walking df.iterrows()
    # twice, which builds a full Series per row on every pass.
    texts = [f"Symptoms: {symptom}" for symptom in df["symptoms"]]
    labels = [
        1 if pd.notna(treatment) and treatment != "" else 0
        for treatment in df["treatment"]
    ]
    encodings = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    return {"input_ids": encodings["input_ids"], "attention_mask": encodings["attention_mask"]}, labels

# Tokenize the whole CSV once; `labels` holds one 0/1 entry per row.
encodings, labels = prepare_dataset(df, tokenizer)

# Convert to PyTorch dataset
class DiseaseDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer outputs with their labels.

    ``encodings`` is a mapping from field name to an indexable collection
    (one entry per sample); ``labels`` is an indexable sequence of the same
    length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The label sequence defines the number of samples.
        return len(self.labels)

    def __getitem__(self, idx):
        # Pick the idx-th entry of every encoding field, then attach the label
        # as a tensor under the "labels" key.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = values[idx]
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample

train_dataset = DiseaseDataset(encodings, labels)

# Split dataset for training and evaluation (80% / 20%)
train_size = int(0.8 * len(train_dataset))
eval_size = len(train_dataset) - train_size
# Seed the split so the train/eval partition, and therefore the reported
# metrics, are reproducible when the notebook is re-run.
train_subset, eval_subset = torch.utils.data.random_split(
    train_dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(42),
)

# Define training arguments
training_args = TrainingArguments(
    output_dir="/content/results",
    num_train_epochs=15,  # evaluation and checkpointing happen once per epoch (see strategies below)
    per_device_train_batch_size=32,  # Increased batch size
    per_device_eval_batch_size=32,
    warmup_steps=100,  # Further reduced warmup
    weight_decay=0.01,
    logging_dir="/content/logs",
    logging_steps=5,  # log the training loss every 5 steps
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Mixed precision needs a CUDA device; fall back to full precision on
    # CPU-only runtimes instead of failing when the arguments are built.
    fp16=torch.cuda.is_available(),
    report_to="none"
)

# Define metrics for evaluation
def compute_metrics(pred):
    """Compute binary precision, recall and F1 for one evaluation pass.

    ``pred`` must expose ``label_ids`` (true labels) and ``predictions``
    (per-class scores); the predicted class is the argmax over the last axis.
    Returns a dict with the keys ``precision``, ``recall`` and ``f1``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    scorers = {"precision": precision_score, "recall": recall_score, "f1": f1_score}
    return {name: scorer(y_true, y_pred, average="binary") for name, scorer in scorers.items()}

# Initialize Trainer with the train/eval subsets and the metric callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_subset,
    eval_dataset=eval_subset,
    compute_metrics=compute_metrics,
)

# Fine-tune the model
print("Starting efficient downstream fine-tuning with ALBERT-Base...")
trainer.train()

# Evaluate the model on the held-out subset; the returned dict is reused by the
# summary printout further down.
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)

# Save the fine-tuned model and its tokenizer to Google Drive so they can be
# reloaded later with from_pretrained(MODEL_SAVE_PATH)
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
model.save_pretrained(MODEL_SAVE_PATH)
tokenizer.save_pretrained(MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")

# RAG Pipeline: Predict disease and treatment from symptoms
def rag_predict(symptoms, vector_store, model, tokenizer, df, top_k=3):
    """Predict a likely disease and treatment for a free-text symptom query.

    The function (1) retrieves the ``top_k`` most similar PDF chunks for
    context, (2) asks the classifier whether a treatment is expected, and
    (3) picks the best-matching CSV row by symptom overlap.

    Args:
        symptoms: Query such as ``"fever, cough, fatigue"``; commas separate
            individual symptom terms.
        vector_store: Object exposing ``similarity_search(query, k=...)`` whose
            results carry a ``page_content`` attribute.
        model: Classifier called as ``model(**inputs)``; must expose ``.device``
            and return an object with ``.logits``.
        tokenizer: Callable returning a mapping of tensors for the input text.
        df: DataFrame with ``symptoms``, ``disease_name`` and ``treatment``
            columns. It is not modified.
        top_k: Number of chunks to retrieve for the context snippet.

    Returns:
        Dict with the keys ``disease``, ``treatment`` and ``context_snippet``
        (context truncated to 200 characters plus ``"..."`` when longer).
        When no row shares any symptom with the query, ``disease`` is
        ``"No matching disease found."`` instead of an arbitrary row.
    """
    retrieved_docs = vector_store.similarity_search(symptoms, k=top_k)
    context = " ".join(doc.page_content for doc in retrieved_docs)

    input_text = f"Symptoms: {symptoms}"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=256)  # Match training max_length

    # Move inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()

    query = symptoms.lower().strip()
    terms = [term.strip() for term in query.split(",") if term.strip()]

    def match_score(row_symptoms):
        # A row containing the whole query outranks any partial match, which
        # keeps the previous whole-substring behaviour whenever it applies;
        # otherwise rows are ranked by how many query terms they mention.
        text = str(row_symptoms).lower()
        if query and query in text:
            return len(terms) + 1
        return sum(term in text for term in terms)

    # Scores live in a local list so the caller's DataFrame is left untouched.
    scores = [match_score(value) for value in df["symptoms"]]
    best_score = max(scores, default=0)

    if best_score == 0:
        # Nothing matched (or df is empty): say so rather than returning the
        # first row, which is what idxmax over all-zero scores used to do.
        disease = "No matching disease found."
        treatment = "No clear treatment identified."
    else:
        # list.index returns the first best row, mirroring idxmax tie-breaking.
        likely_disease_row = df.iloc[scores.index(best_score)]
        disease = likely_disease_row["disease_name"]
        if prediction == 1 and pd.notna(likely_disease_row["treatment"]):
            treatment = likely_disease_row["treatment"]
        else:
            treatment = "No clear treatment identified."

    return {
        "disease": disease,
        "treatment": treatment,
        "context_snippet": context[:200] + "..." if len(context) > 200 else context
    }
# Example usage: run the full pipeline for one comma-separated symptom query
symptoms_query = "fever, cough, fatigue"
prediction = rag_predict(symptoms_query, vector_store, model, tokenizer, df)
print("\nPrediction Result:")
print(f"Disease: {prediction['disease']}")
print(f"Treatment: {prediction['treatment']}")
print(f"Context from PDFs: {prediction['context_snippet']}")

# Print model evaluation summary (every value in eval_results is formatted
# with four decimals)
print("\nModel Evaluation Summary:")
for metric, value in eval_results.items():
    print(f"{metric}: {value:.4f}")

Loaded 5338 PDF pages


Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting efficient downstream fine-tuning with ALBERT-Base...


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,0.4331,0.407003,0.732877,0.757075,0.74478
2,0.4988,0.433561,0.801325,0.570755,0.666667
3,0.4714,0.39194,0.719828,0.787736,0.752252
4,0.4183,0.398073,0.72949,0.775943,0.752
5,0.3526,0.394021,0.771429,0.700472,0.73424
6,0.3495,0.391105,0.741935,0.759434,0.750583
7,0.3626,0.383133,0.777228,0.740566,0.758454
8,0.4282,0.432785,0.822695,0.54717,0.657224
9,0.3503,0.394015,0.763959,0.709906,0.735941
10,0.2915,0.382192,0.743056,0.757075,0.75


Evaluation Results: {'eval_loss': 0.3757314383983612, 'eval_precision': 0.7431818181818182, 'eval_recall': 0.7712264150943396, 'eval_f1': 0.7569444444444444, 'eval_runtime': 8.6643, 'eval_samples_per_second': 174.163, 'eval_steps_per_second': 5.54, 'epoch': 15.0}
Model saved to /content/drive/MyDrive/01-nlp/modeling/albert_finetuned

Prediction Result:
Disease: 5 common symptoms of ovulation
Treatment: No clear treatment identified.
Context from PDFs: Symptoms: Fever, fatigue, muscle aches, runny nose.   
Diagnosis: Clinical symptoms, viral PCR if needed, rule out bacterial.   
Treatment: Supportive care, acetaminophen (650 mg), fluids. 
Pneumonic ...

Model Evaluation Summary:
eval_loss: 0.3757
eval_precision: 0.7432
eval_recall: 0.7712
eval_f1: 0.7569
eval_runtime: 8.6643
eval_samples_per_second: 174.1630
eval_steps_per_second: 5.5400
epoch: 15.0000


In [None]:
# Another usage example: reuses the objects built in the training cell
# (vector_store, model, tokenizer, df) with a different symptom query.
symptoms_query = "Fatigue, right upper quadrant discomfort"
prediction = rag_predict(symptoms_query, vector_store, model, tokenizer, df)
print("\nPrediction Result:")
print(f"Disease: {prediction['disease']}")
print(f"Treatment: {prediction['treatment']}")
print(f"Context from PDFs: {prediction['context_snippet']}")


Prediction Result:
Disease: 5 common symptoms of ovulation
Treatment: No clear treatment identified.
Context from PDFs: - Symptoms: Fatigue, right upper quadrant discomfort, enlarged liver. 
- Diagnosis: Ultrasound (fatty liver), LFTs (AST >ALT), history. 
- Treatment: Abstinence, vitamin E (800 IU daily), nutritional ...


In [None]:
# Print model evaluation summary from the most recent trainer.evaluate() result
# stored in eval_results.
print("\nModel Evaluation Summary:")
for metric, value in eval_results.items():
    print(f"{metric}: {value:.4f}")


Model Evaluation Summary:
eval_loss: 0.3617
eval_precision: 0.8122
eval_recall: 0.7483
eval_f1: 0.7789
eval_runtime: 8.7789
eval_samples_per_second: 171.8900
eval_steps_per_second: 5.4680
epoch: 5.0000


#streamlit code

In [None]:
import streamlit as st
import pandas as pd
from transformers import AlbertTokenizer, AlbertForSequenceClassification
import torch
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

# Paths, resolved relative to the directory the app is started from.
# Adjust them to wherever the CSV, the PDF folder and the saved model live.
DATA_PATH = "merge_demo_amos_v3.csv"  # CSV with symptoms / disease_name / treatment columns
PDF_FOLDER_PATH = "pdf"  # Folder scanned for *.pdf files
MODEL_SAVE_PATH = "albert_finetuned"  # Directory holding the saved model and tokenizer

# Load pre-trained model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Load the saved tokenizer and classifier from ``MODEL_SAVE_PATH``.

    The result is cached by Streamlit via ``st.cache_resource``. Returns a
    ``(tokenizer, model)`` tuple with the model switched to evaluation mode.
    """
    saved_tokenizer = AlbertTokenizer.from_pretrained(MODEL_SAVE_PATH)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    classifier = AlbertForSequenceClassification.from_pretrained(MODEL_SAVE_PATH).eval()
    return saved_tokenizer, classifier

# Load CSV dataset
@st.cache_data
def load_csv_data():
    """Read the symptom/disease/treatment columns from ``DATA_PATH``.

    The result is cached by Streamlit via ``st.cache_data``.
    """
    wanted_columns = ["symptoms", "disease_name", "treatment"]
    return pd.read_csv(DATA_PATH, usecols=wanted_columns)

# Load and process PDFs for RAG
@st.cache_resource
def load_vector_store():
    """Build the FAISS vector store from the PDFs in ``PDF_FOLDER_PATH``.

    PDFs are loaded concurrently, split into 500-character chunks with
    100 characters of overlap, embedded with the MiniLM sentence-transformer
    and indexed. The result is cached by Streamlit via ``st.cache_resource``.
    """
    def read_pdf(pdf_path):
        return PyMuPDFLoader(str(pdf_path)).load()

    pdf_paths = list(Path(PDF_FOLDER_PATH).glob("*.pdf"))
    with ThreadPoolExecutor() as pool:
        per_file_docs = list(pool.map(read_pdf, pdf_paths))

    # Flatten the per-file lists into one list of documents.
    all_docs = []
    for file_docs in per_file_docs:
        all_docs.extend(file_docs)

    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    chunks = splitter.split_documents(all_docs)
    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(chunks, embedder)

# Prediction function (adapted from your rag_predict)
def rag_predict(symptoms, vector_store, model, tokenizer, df, top_k=3):
    """Predict a likely disease and treatment for a free-text symptom query.

    Retrieves the ``top_k`` most similar PDF chunks for context, asks the
    classifier whether a treatment is expected, and picks the best-matching
    CSV row by symptom overlap.

    Args:
        symptoms: Query such as ``"fever, cough, fatigue"``; commas separate
            individual symptom terms.
        vector_store: Object exposing ``similarity_search(query, k=...)`` whose
            results carry a ``page_content`` attribute.
        model: Classifier supporting ``.to(device)`` and ``model(**inputs)``,
            returning an object with ``.logits``.
        tokenizer: Callable returning a mapping of tensors for the input text.
        df: DataFrame with ``symptoms``, ``disease_name`` and ``treatment``
            columns. It is not modified.
        top_k: Number of chunks to retrieve for the context snippet.

    Returns:
        Dict with the keys ``disease``, ``treatment`` and ``context_snippet``
        (context truncated to 200 characters plus ``"..."`` when longer).
        When no row shares any symptom with the query, ``disease`` is
        ``"No matching disease found."`` instead of an arbitrary row.
    """
    retrieved_docs = vector_store.similarity_search(symptoms, k=top_k)
    context = " ".join(doc.page_content for doc in retrieved_docs)

    input_text = f"Symptoms: {symptoms}"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=256)

    # Move the model and its inputs to the same device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()

    query = symptoms.lower().strip()
    terms = [term.strip() for term in query.split(",") if term.strip()]

    def match_score(row_symptoms):
        # A row containing the whole query outranks any partial match, which
        # keeps the previous whole-substring behaviour whenever it applies;
        # otherwise rows are ranked by how many query terms they mention.
        text = str(row_symptoms).lower()
        if query and query in text:
            return len(terms) + 1
        return sum(term in text for term in terms)

    # Scores live in a local list so the caller's DataFrame is left untouched.
    scores = [match_score(value) for value in df["symptoms"]]
    best_score = max(scores, default=0)

    if best_score == 0:
        # Nothing matched (or df is empty): say so rather than returning the
        # first row, which is what idxmax over all-zero scores used to do.
        disease = "No matching disease found."
        treatment = "No clear treatment identified."
    else:
        # list.index returns the first best row, mirroring idxmax tie-breaking.
        likely_disease_row = df.iloc[scores.index(best_score)]
        disease = likely_disease_row["disease_name"]
        if prediction == 1 and pd.notna(likely_disease_row["treatment"]):
            treatment = likely_disease_row["treatment"]
        else:
            treatment = "No clear treatment identified."

    return {
        "disease": disease,
        "treatment": treatment,
        "context_snippet": context[:200] + "..." if len(context) > 200 else context
    }

# Streamlit app
def main():
    """Render the Streamlit page: load resources, read symptoms, show a prediction."""
    st.title("Symptom-Based Disease and Treatment Prediction")
    st.write("Enter your symptoms below to predict the likely disease and treatment based on a fine-tuned ALBERT model and RAG pipeline.")

    # Model, CSV and vector store come from cached loaders.
    with st.spinner("Loading model and data..."):
        tokenizer, model = load_model_and_tokenizer()
        df = load_csv_data()
        vector_store = load_vector_store()

    symptoms = st.text_area("Enter your symptoms (e.g., fever, cough, fatigue):", "")
    predict_clicked = st.button("Predict")

    if predict_clicked and symptoms.strip() == "":
        # Blank input: report it instead of running the pipeline.
        st.error("Please enter some symptoms.")
    elif predict_clicked:
        with st.spinner("Making prediction..."):
            result = rag_predict(symptoms, vector_store, model, tokenizer, df)
            st.success("Prediction complete!")

            st.subheader("Prediction Results")
            st.write(f"**Likely Disease:** {result['disease']}")
            st.write(f"**Recommended Treatment:** {result['treatment']}")
            st.write(f"**Context from PDFs:** {result['context_snippet']}")

    # Footer is rendered on every run, whether or not a prediction was made.
    st.write("---")
    st.write("Built with Streamlit, ALBERT, and LangChain by xAI.")

if __name__ == "__main__":
    main()

#albert model

In [None]:
# Step 1: Mount Google Drive at /content/drive so the /content/drive/MyDrive/...
# paths used further down in this cell resolve (Colab only).
from google.colab import drive
drive.mount('/content/drive')

# Step 2: Install required packages
!pip install -q pandas transformers torch langchain pymupdf sentence-transformers faiss-cpu scikit-learn

# Import libraries
import os
import pandas as pd
from transformers import AlbertTokenizer, AlbertForSequenceClassification, Trainer, TrainingArguments
import torch
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from sklearn.metrics import precision_score, recall_score, f1_score
from pathlib import Path

# Paths (adjusted for Colab)
DATA_PATH = "/content/drive/MyDrive/01-nlp/modeling/data/merge_demo_amos_v3.csv"  # Your CSV path
PDF_FOLDER_PATH = "/content/drive/MyDrive/01-nlp/modeling/pdf"  # Your PDF folder path
MODEL_SAVE_PATH = "/content/drive/MyDrive/01-nlp/modeling/albert_finetuned"  # Save path for model

# Load CSV dataset (all columns are read)
df = pd.read_csv(DATA_PATH)

# Ensure required columns exist; stops at the first missing column and names it
required_cols = ["symptoms", "disease_name", "treatment"]
for col in required_cols:
    if col not in df.columns:
        raise ValueError(f"Column '{col}' is missing from the dataset")

# Load PDFs from folder
def load_pdfs_from_folder(folder_path):
    """Load every ``*.pdf`` directly inside *folder_path*, one file at a time.

    Returns a flat list with the documents produced by ``PyMuPDFLoader.load()``
    for each file, in the order the files are yielded by ``Path.glob``.
    """
    collected = []
    for pdf_path in Path(folder_path).glob("*.pdf"):
        collected += PyMuPDFLoader(str(pdf_path)).load()
    return collected

# Load all PDFs from the configured folder and report how many documents came back.
pdf_docs = load_pdfs_from_folder(PDF_FOLDER_PATH)
print(f"Loaded {len(pdf_docs)} PDF pages")

# Split PDF documents into chunks of at most 1000 characters with 200 characters of overlap
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
split_docs = text_splitter.split_documents(pdf_docs)

# Create embeddings and vector store for RAG: every chunk is embedded with the
# MiniLM sentence-transformer and indexed in an in-memory FAISS store.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = FAISS.from_documents(split_docs, embedding_model)

# Initialize ALBERT tokenizer and model
tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)  # Binary classification

# Prepare dataset for downstream fine-tuning
def prepare_dataset(df):
    """Tokenize each row's symptoms and derive a binary treatment label.

    Uses the module-level ``tokenizer``. Every row is tokenized on its own,
    padded/truncated to 512 tokens. Returns ``(encodings, labels)`` where
    ``encodings`` maps ``'input_ids'`` and ``'attention_mask'`` to lists of
    per-row tensors and ``labels`` holds 1 when ``treatment`` is present and
    non-empty, else 0.
    """
    input_ids, attention_masks, labels = [], [], []
    for _, record in df.iterrows():
        # Input is symptoms only
        encoded = tokenizer(
            f"Symptoms: {record['symptoms']}",
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )
        has_treatment = pd.notna(record["treatment"]) and record["treatment"] != ""
        # squeeze() drops the size-1 dimensions of the single-text encoding.
        input_ids.append(encoded["input_ids"].squeeze())
        attention_masks.append(encoded["attention_mask"].squeeze())
        labels.append(1 if has_treatment else 0)
    return {'input_ids': input_ids, 'attention_mask': attention_masks}, labels

# Tokenize the whole CSV row by row; `labels` holds one 0/1 entry per row.
encodings, labels = prepare_dataset(df)

# Convert to PyTorch dataset
class DiseaseDataset(torch.utils.data.Dataset):
    """Dataset wrapper around per-field encodings and a parallel label list.

    Indexing returns a dict holding the idx-th entry of every encoding field
    plus the label as a tensor under ``"labels"``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        row = dict((field, entries[idx]) for field, entries in self.encodings.items())
        row.update(labels=torch.tensor(self.labels[idx]))
        return row

    def __len__(self):
        # One sample per label.
        return len(self.labels)

train_dataset = DiseaseDataset(encodings, labels)

# Split dataset for training and evaluation (80% / 20%)
# NOTE(review): no generator/seed is passed to random_split, so the partition
# presumably changes between runs -- confirm whether reproducibility is needed.
train_size = int(0.8 * len(train_dataset))
eval_size = len(train_dataset) - train_size
train_subset, eval_subset = torch.utils.data.random_split(train_dataset, [train_size, eval_size])

# Define training arguments (optimized for ALBERT's smaller size)
training_args = TrainingArguments(
    output_dir="/content/results",  # Temp dir in Colab
    num_train_epochs=3,
    per_device_train_batch_size=16,  # Larger batch size due to ALBERT's efficiency
    per_device_eval_batch_size=16,
    warmup_steps=200,  # Reduced warmup steps for faster convergence
    weight_decay=0.01,
    logging_dir="/content/logs",  # Temp dir in Colab
    logging_steps=10,  # Log every 10 steps
    evaluation_strategy="epoch",  # Evaluate once per epoch
    save_strategy="epoch",  # Checkpoint once per epoch
    load_best_model_at_end=True,
    report_to="none"  # Disable wandb logging in Colab
)

# Define metrics for evaluation
def compute_metrics(pred):
    """Return binary precision, recall and F1 for an evaluation prediction.

    ``pred.label_ids`` holds the true labels and ``pred.predictions`` the
    per-class scores; the predicted class is the argmax over the last axis.
    """
    expected = pred.label_ids
    predicted = pred.predictions.argmax(-1)
    metrics = {}
    metrics["precision"] = precision_score(expected, predicted, average="binary")
    metrics["recall"] = recall_score(expected, predicted, average="binary")
    metrics["f1"] = f1_score(expected, predicted, average="binary")
    return metrics

# Initialize Trainer with the train/eval subsets and the metric callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_subset,
    eval_dataset=eval_subset,
    compute_metrics=compute_metrics,
)

# Fine-tune the model downstream
print("Starting downstream fine-tuning with ALBERT-Base...")
trainer.train()

# Evaluate the model on the held-out subset; the returned dict is reused by the
# summary printout at the end of the cell.
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)

# Save the fine-tuned model and its tokenizer to Google Drive so they can be
# reloaded later with from_pretrained(MODEL_SAVE_PATH)
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
model.save_pretrained(MODEL_SAVE_PATH)
tokenizer.save_pretrained(MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")

# RAG Pipeline: Predict disease and treatment from symptoms
def rag_predict(symptoms, vector_store, model, tokenizer, df, top_k=3):
    """Predict a likely disease and treatment for a free-text symptom query.

    Retrieves the ``top_k`` most similar PDF chunks for context, asks the
    classifier whether a treatment is expected, and picks the best-matching
    CSV row by symptom overlap.

    Args:
        symptoms: Query such as ``"fever, cough, fatigue"``; commas separate
            individual symptom terms.
        vector_store: Object exposing ``similarity_search(query, k=...)`` whose
            results carry a ``page_content`` attribute.
        model: Classifier called as ``model(**inputs)``; must expose ``.device``
            and return an object with ``.logits``.
        tokenizer: Callable returning a mapping of tensors for the input text.
        df: DataFrame with ``symptoms``, ``disease_name`` and ``treatment``
            columns. It is not modified.
        top_k: Number of chunks to retrieve for the context snippet.

    Returns:
        Dict with the keys ``disease``, ``treatment`` and ``context_snippet``
        (context truncated to 200 characters plus ``"..."`` when longer).
        When no row shares any symptom with the query, ``disease`` is
        ``"No matching disease found."`` instead of an arbitrary row.
    """
    # Retrieve relevant documents from vector store
    retrieved_docs = vector_store.similarity_search(symptoms, k=top_k)
    context = " ".join(doc.page_content for doc in retrieved_docs)

    # Tokenize input symptoms
    input_text = f"Symptoms: {symptoms}"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)

    # Move inputs to the model's device: after training on a GPU the model
    # lives on CUDA, and feeding it CPU tensors raises a device-mismatch error.
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    # Predict treatment presence
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()

    # Find closest matching disease from CSV based on symptoms
    query = symptoms.lower().strip()
    terms = [term.strip() for term in query.split(",") if term.strip()]

    def match_score(row_symptoms):
        # A row containing the whole query outranks any partial match, which
        # keeps the previous whole-substring behaviour whenever it applies;
        # otherwise rows are ranked by how many query terms they mention.
        text = str(row_symptoms).lower()
        if query and query in text:
            return len(terms) + 1
        return sum(term in text for term in terms)

    # Scores live in a local list so the caller's DataFrame is left untouched.
    scores = [match_score(value) for value in df["symptoms"]]
    best_score = max(scores, default=0)

    if best_score == 0:
        # Nothing matched (or df is empty): say so rather than returning the
        # first row, which is what idxmax over all-zero scores used to do.
        disease = "No matching disease found."
        treatment = "No clear treatment identified."
    else:
        # list.index returns the first best row, mirroring idxmax tie-breaking.
        likely_disease_row = df.iloc[scores.index(best_score)]
        disease = likely_disease_row["disease_name"]
        if prediction == 1 and pd.notna(likely_disease_row["treatment"]):
            treatment = likely_disease_row["treatment"]
        else:
            treatment = "No clear treatment identified."

    # Generate response
    response = {
        "disease": disease,
        "treatment": treatment,
        "context_snippet": context[:200] + "..." if len(context) > 200 else context
    }
    return response

# Example usage: run the full pipeline for one comma-separated symptom query
symptoms_query = "fever, cough, fatigue"
prediction = rag_predict(symptoms_query, vector_store, model, tokenizer, df)
print("\nPrediction Result:")
print(f"Disease: {prediction['disease']}")
print(f"Treatment: {prediction['treatment']}")
print(f"Context from PDFs: {prediction['context_snippet']}")

# Print model evaluation summary (every value in eval_results is formatted
# with four decimals)
print("\nModel Evaluation Summary:")
for metric, value in eval_results.items():
    print(f"{metric}: {value:.4f}")