# Human Anatomy Chatbot with BERT

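This notebook implements a human-anatomy chatbot backend: a BERT sequence classifier (`bert-base-uncased`) served through a Flask API. Run the cells in order to:
1. Set up the environment
2. Load and prepare data
3. Train the model
4. Launch the API server

**Note:** the data loader currently assigns the placeholder label `0` to every example. Replace it with real labels before training; otherwise the reported metrics are not meaningful.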

In [1]:
# Import Required Libraries
import json
import os
import threading
import time
from typing import Dict, Any

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification, 
    TrainingArguments, 
    Trainer
)
from datasets import Dataset
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score
from flask import Flask, request, jsonify

# Check if CUDA is available
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

# Seed the RNGs up front: the classification head created in the next cell is
# randomly initialised, so without a fixed seed every run starts from different
# weights and results are not reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

CUDA available: False


In [2]:
class AnatomyChatbot:
    """BERT sequence classifier with helpers for data loading, training and inference.

    Wraps a Hugging Face tokenizer and ``AutoModelForSequenceClassification``
    loaded from ``model_name``; the model is moved to CUDA when available.
    """

    # Maximum number of tokens kept by the tokenizer, for both training and inference.
    max_length: int = 512

    def __init__(self, model_name: str = "bert-base-uncased", num_labels: int = 2,
                 max_length: int = 512):
        """Load the tokenizer and model and move the model to the best device.

        Args:
            model_name: Hugging Face model id or local path.
            num_labels: Number of outputs of the classification head.
            max_length: Token limit used when tokenizing for training and prediction.
        """
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def load_and_prepare_dataset(self, folder_path: str) -> pd.DataFrame:
        """Load every ``*.jsonl`` file in ``folder_path`` into a ``text``/``labels`` frame.

        Each non-blank line must be a JSON object. Records must provide ``title``
        and ``content`` fields; they are joined with a single space into ``text``.

        Args:
            folder_path: Directory containing the JSONL files.

        Returns:
            DataFrame with exactly the columns ``text`` and ``labels``.

        Raises:
            FileNotFoundError: If the directory does not exist or has no JSONL files.
            ValueError: If the records lack a ``title`` or ``content`` column
                (``json.JSONDecodeError``, a ``ValueError`` subclass, is raised for
                a malformed non-blank line).
        """
        if not os.path.exists(folder_path):
            raise FileNotFoundError(f"Directory {folder_path} not found")

        # Sorted so row order (and therefore the train/test split) does not depend
        # on the filesystem's arbitrary os.listdir order.
        jsonl_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.jsonl'))
        if not jsonl_files:
            raise FileNotFoundError(f"No JSONL files found in {folder_path}")

        all_data = []
        for file_name in jsonl_files:
            with open(os.path.join(folder_path, file_name), 'r', encoding='utf-8') as file:
                # Skip blank lines (e.g. a trailing newline): json.loads("") raises.
                all_data.extend(json.loads(line) for line in file if line.strip())

        df = pd.DataFrame(all_data)
        if 'title' not in df.columns or 'content' not in df.columns:
            raise ValueError("Dataset missing required columns: title and content")

        # A record missing one of the fields shows up as NaN; fill it so the
        # concatenation always yields a string instead of NaN text.
        df['text'] = (
            df['title'].fillna('').astype(str) + " " + df['content'].fillna('').astype(str)
        )

        # Placeholder labels: every example is class 0. Replace with real labels
        # before training, otherwise the classifier has nothing to learn.
        df['labels'] = 0

        return df[['text', 'labels']]  # Only keep the columns we need

    def split_dataset(self, df: pd.DataFrame, train_ratio: float = 0.8,
                      shuffle: bool = False, seed: int = 42) -> tuple:
        """Split ``df`` into ``(train, test)`` frames.

        Args:
            df: Frame to split.
            train_ratio: Fraction of rows placed in the training split, in ``(0, 1]``.
            shuffle: If True, rows are permuted reproducibly (using ``seed``) before
                splitting. The default keeps the incoming order, so the test split
                is simply the tail of ``df``.
            seed: Random seed used only when ``shuffle`` is True.

        Returns:
            Tuple ``(train_data, test_data)``.

        Raises:
            ValueError: If ``train_ratio`` is outside ``(0, 1]``.
        """
        if not 0.0 < train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in (0, 1], got {train_ratio}")
        if shuffle:
            df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        train_size = int(len(df) * train_ratio)
        return df.iloc[:train_size], df.iloc[train_size:]

    def tokenize_function(self, examples: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize a batch of ``examples['text']`` for ``datasets.Dataset.map``.

        Pads/truncates every example to ``self.max_length`` tokens. Plain lists are
        returned (no ``return_tensors``) because ``map`` stores the result in Arrow
        anyway.
        """
        return self.tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def compute_metrics(self, pred) -> Dict[str, Any]:
        """Compute accuracy and a confusion matrix from a Trainer ``EvalPrediction``."""
        logits = pred.predictions
        # predictions may be a tuple when the model returns more than the logits.
        if isinstance(logits, tuple):
            logits = logits[0]
        labels = pred.label_ids
        preds = np.argmax(logits, axis=1)
        return {
            "accuracy": accuracy_score(labels, preds),
            # A fixed label set keeps the matrix num_labels x num_labels even when
            # a class is absent from the evaluation split.
            "confusion_matrix": confusion_matrix(
                labels, preds, labels=list(range(self.model.config.num_labels))
            ).tolist(),
        }

    def train(self, train_dataset, test_dataset, output_dir: str = "./results"):
        """Fine-tune ``self.model`` and return the final evaluation metrics.

        Evaluates and checkpoints every epoch, reloading the checkpoint with the
        best ``accuracy`` at the end.

        Args:
            train_dataset: Tokenized training dataset with a ``labels`` column.
            test_dataset: Tokenized evaluation dataset with a ``labels`` column.
            output_dir: Directory for checkpoints and logs.

        Returns:
            The metrics dict returned by ``Trainer.evaluate()``.
        """
        training_args = TrainingArguments(
            output_dir=output_dir,
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            num_train_epochs=3,
            weight_decay=0.01,
            logging_dir=os.path.join(output_dir, "logs"),
            logging_steps=10,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            remove_unused_columns=False
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            compute_metrics=self.compute_metrics
        )

        trainer.train()
        return trainer.evaluate()

    def predict(self, text: str) -> Dict[str, Any]:
        """Classify a single ``text`` string.

        Returns:
            Dict with ``prediction`` (class index), ``confidence`` (softmax
            probability of that class) and ``response_time`` (forward-pass seconds).
        """
        self.model.eval()
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            outputs = self.model(**inputs)
            response_time = time.perf_counter() - start_time

        # Reduce over the class axis explicitly; .item() then fails loudly if more
        # than one text was passed instead of returning a flattened index.
        prediction = outputs.logits.argmax(dim=-1).item()
        confidence = torch.softmax(outputs.logits, dim=-1).max().item()

        return {
            "prediction": prediction,
            "confidence": confidence,
            "response_time": response_time
        }

# Build the shared chatbot instance used by every later cell
# (tokenizer + classifier are downloaded/loaded here).
chatbot = AnatomyChatbot(model_name="bert-base-uncased")
print("Chatbot initialized successfully")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Chatbot initialized successfully


In [3]:
# Load and prepare the dataset
# Update this path to where your JSONL files are stored
DATA_DIR = "./data"

try:
    dataset = chatbot.load_and_prepare_dataset(DATA_DIR)
except Exception as e:
    print(f"Error loading dataset: {e}")
    # Re-raise: every later cell needs `dataset`, so carrying on would only
    # surface as a confusing NameError further down the notebook.
    raise

print(f"Dataset loaded successfully. Shape: {dataset.shape}")

# Display sample
print("\nSample data:")
display(dataset.head())

Dataset loaded successfully. Shape: (12060, 2)

Sample data:


Unnamed: 0,text,labels
0,Anatomy_Gray What is anatomy? Anatomy includes...,0
1,Anatomy_Gray Observation and visualization are...,0
2,Anatomy_Gray How can gross anatomy be studied?...,0
3,"Anatomy_Gray This includes the vasculature, th...",0
4,Anatomy_Gray Each of these approaches has bene...,0


In [4]:
# Split and prepare datasets
train_data, test_data = chatbot.split_dataset(dataset)
assert len(train_data) + len(test_data) == len(dataset), "split lost rows"

# Convert to Hugging Face datasets. preserve_index=False stops a non-default
# pandas index from being stored as an extra `__index_level_0__` column; since
# training uses remove_unused_columns=False, such a column could otherwise
# reach the model's forward call.
train_dataset = Dataset.from_pandas(train_data, preserve_index=False)
test_dataset = Dataset.from_pandas(test_data, preserve_index=False)

# Map the tokenization function across the datasets
train_dataset = train_dataset.map(chatbot.tokenize_function, batched=True)
test_dataset = test_dataset.map(chatbot.tokenize_function, batched=True)

print(f"Training set size: {len(train_dataset)}")
print(f"Testing set size: {len(test_dataset)}")

Map:   0%|          | 0/9648 [00:00<?, ? examples/s]

Map:   0%|          | 0/2412 [00:00<?, ? examples/s]

Training set size: 9648
Testing set size: 2412


In [None]:
# Train the model
print("Starting model training...")
try:
    metrics = chatbot.train(train_dataset, test_dataset)
except Exception as e:
    print(f"Error during training: {e}")
    # Re-raise so the notebook stops here instead of serving an untrained model.
    raise

print("\nTraining completed. Final metrics:")
for key, value in metrics.items():
    print(f"{key}: {value}")

Starting model training...


Epoch,Training Loss,Validation Loss


In [None]:
# Flask API setup
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """Classify the ``input`` string of a JSON body: ``{"input": "<text>"}``.

    Returns the ``chatbot.predict`` result as JSON. Responds 400 for a missing,
    non-JSON or non-object body, or a non-string/blank ``input``; 500 if
    inference itself fails.
    """
    # silent=True returns None instead of raising on a missing or invalid JSON
    # body, so client mistakes get a 400 instead of falling into the 500 path.
    data = request.get_json(silent=True)
    # isinstance guard: for a JSON string body, `'input' in data` would be a
    # substring test rather than a key lookup.
    if not isinstance(data, dict) or 'input' not in data:
        return jsonify({"error": "No input provided"}), 400

    text = data['input']
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "'input' must be a non-empty string"}), 400

    try:
        result = chatbot.predict(text)
    except Exception:
        # Keep the traceback in the server log; don't echo internals to clients.
        app.logger.exception("Prediction failed")
        return jsonify({"error": "Internal server error"}), 500
    return jsonify(result)

# Start the Flask server in a background thread. A blocking app.run() would
# occupy this kernel until the server stops, so the test cell below could never
# run. daemon=True ties the thread's lifetime to the kernel process.
# NOTE(review): host="0.0.0.0" exposes the API on all network interfaces; use
# "127.0.0.1" if it only needs to be reachable from this machine.
if __name__ == "__main__":
    server_thread = threading.Thread(
        target=app.run,
        kwargs={"host": "0.0.0.0", "port": 5000, "use_reloader": False},
        daemon=True,
    )
    server_thread.start()

## Testing the API

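Once the Flask server is running and reachable at `http://localhost:5000`, send a test request with the following cell. Note that a server started with a blocking `app.run(...)` in this same kernel prevents any further cell from executing; run the server in a background thread or a separate process instead.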

In [None]:
import requests

# Test the API
test_text = "The human heart has four chambers"
response = requests.post(
    'http://localhost:5000/predict',
    json={'input': test_text},
    timeout=30,  # fail with an error instead of hanging if the server is not up
)

print("API Response:")
print(f"Status: {response.status_code}")
print(response.json())