# LLM Fine-Tuning for Disease Suggestion from Symptoms
Educational project notebook.


In [ ]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# NOTE(review): no cell visible in this notebook reads from /content/drive
# (the CSV is opened by a relative path) - confirm the mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

## Install libraries

In [ ]:
# Notebook shell escape: installs the packages for the fine-tuning cell.
# transformers, datasets and peft are imported there; accelerate and
# bitsandbytes are not imported directly anywhere in this notebook.
!pip install transformers datasets peft accelerate bitsandbytes

## Load Dataset

In [ ]:
import pandas as pd
# Read the disease/symptom table from the current working directory.
# NOTE(review): the JSONL cell indexes the 'Symptoms' and 'Disease' columns -
# confirm the CSV header uses exactly those names.
df = pd.read_csv('DiseaseAndSymptoms.csv')
# Last expression of the cell, so the notebook renders a preview of the first rows.
df.head()

## Prepare JSONL

In [ ]:
import json

# Build one instruction-tuning record per CSV row and write them to
# train.jsonl, one JSON object per line.
#
# Rows with a missing symptom or disease are skipped: pandas reads empty CSV
# cells as NaN, which json.dumps would emit as a bare NaN token (not valid
# JSON) and which cannot be joined into prompt text during tokenization.
complete_rows = df.dropna(subset=['Symptoms', 'Disease'])

train = [
    {
        'instruction': 'Identify the disease pattern based on symptoms.',
        # str() keeps the field a string even if pandas parsed it as a number.
        'input': str(symptoms),
        'output': f"Disease: {disease}\nExplanation: Pattern from dataset.\nNote: This is not medical advice.",
    }
    for symptoms, disease in zip(complete_rows['Symptoms'], complete_rows['Disease'])
]

# Explicit encoding so the file does not depend on the platform default.
with open('train.jsonl', 'w', encoding='utf-8') as f:
    f.writelines(json.dumps(ex) + '\n' for ex in train)

## Fine-Tune with QLoRA

In [ ]:
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Load the instruction records from train.jsonl as a single 'train' split.
dataset = load_dataset('json', data_files='train.jsonl', split='train')

model_id = 'google/gemma-2b'
# NOTE(review): the base model is loaded without a quantization config, so
# this cell performs plain LoRA fine-tuning. For 4-bit QLoRA, pass a
# BitsAndBytesConfig through `quantization_config` here.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The collator below pads batches, which requires a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def preprocess(examples):
    """Tokenize one record as instruction, input and output joined by newlines.

    ``dataset.map`` is called without batching, so ``examples`` is a single
    row holding the 'instruction', 'input' and 'output' fields. Returns the
    tokenizer's encoding of the joined text.
    """
    prompt = '\n'.join([examples['instruction'], examples['input'], examples['output']])
    return tokenizer(prompt, truncation=True)


# Drop the raw text columns so only the tokenizer's fields reach the collator.
tokenized = dataset.map(preprocess, remove_columns=dataset.column_names)

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=['q_proj', 'v_proj'],
    lora_dropout=0.05,
    task_type='CAUSAL_LM',
)
model = get_peft_model(model, peft_config)

args = TrainingArguments(
    output_dir='finetuned',
    per_device_train_batch_size=4,
    num_train_epochs=2,
    logging_steps=10,
)

# With mlm=False the collator pads each batch and copies input_ids into
# `labels` (padding positions masked out). Without it the variable-length
# examples cannot be batched and the model has no labels to compute a loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized,
    data_collator=data_collator,
)
trainer.train()

# `model` is the PEFT-wrapped model returned by get_peft_model.
model.save_pretrained('adapter_out')

## Confusion Matrix Placeholder (regenerate from real model predictions after training)

In [ ]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Hard-coded 2x2 placeholder counts; they are not computed from any model output.
cm = np.array([[5, 2], [1, 7]])
plt.imshow(cm)
plt.title('Confusion Matrix (Placeholder)')
plt.xlabel('Predicted')
plt.ylabel('Actual')

# savefig does not create missing directories, and nothing else in this
# notebook creates /mnt/data, so make sure the target folder exists first.
out_path = Path('/mnt/data/confusion_matrix.png')
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path)
plt.show()