<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Example_Symbolic_Driven_Decision_Making_in_Medical_Diagnostics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Rule-Based System
class RuleBasedSystem:
    """First-match rule engine.

    ``rules`` is an ordered sequence of dicts, each holding a ``'condition'``
    callable (called with the patient record) and a ``'diagnosis'`` value.
    """

    def __init__(self, rules):
        self.rules = rules

    def diagnose(self, input_data):
        """Return the diagnosis of the first rule whose condition holds, else None.

        Conditions are evaluated in list order and evaluation stops at the
        first one that returns a truthy value.
        """
        matches = (
            entry['diagnosis']
            for entry in self.rules
            if entry['condition'](input_data)
        )
        return next(matches, None)

# Example neural network for medical diagnostics
class MyMedicalNN(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that emits raw logits.

    The final layer has no activation, so ``forward`` returns logits; this is
    what ``nn.BCEWithLogitsLoss`` expects during training.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the hidden layer.
        output_size: number of output logits (``predict`` assumes 1).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_size)`` to logits ``(..., output_size)``."""
        x = self.relu(self.fc1(x))
        return self.fc2(x)

    def predict(self, input_data, threshold=0.5):
        """Classify a single sample as Condition A or Condition B.

        The network output is a logit, so it is passed through a sigmoid and
        the resulting probability is compared against ``threshold``.

        Args:
            input_data: tensor whose forward pass yields exactly one logit
                (e.g. shape ``(1, input_size)`` with ``output_size == 1``).
            threshold: probability above which Condition A is reported.

        Returns:
            ``"Diagnosis: Condition A"`` if the probability exceeds
            ``threshold``, otherwise ``"Diagnosis: Condition B"``.

        Raises:
            RuntimeError: if the forward pass yields more than one value
                (raised by ``Tensor.item``).
        """
        # Restore the caller's mode afterwards so predict() does not silently
        # leave the model in eval mode (which would affect later training).
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probability = torch.sigmoid(self(input_data)).item()
        finally:
            self.train(was_training)
        if probability > threshold:
            return "Diagnosis: Condition A"
        return "Diagnosis: Condition B"

# Hybrid Model
class HybridModel:
    """Rule-first diagnostic pipeline with a neural-network fallback.

    ``rule_based_system`` must provide ``diagnose(input_data)`` returning a
    diagnosis or ``None``; ``neural_network`` must provide ``predict(tensor)``.
    """

    def __init__(self, rule_based_system, neural_network):
        self.rule_based_system = rule_based_system
        self.neural_network = neural_network

    def diagnose(self, input_data):
        """Diagnose a patient record.

        ``input_data`` is read for its ``'symptoms'`` (a sized collection) and
        ``'age'`` entries. The rule system is consulted first; only when it
        returns ``None`` is the network queried.
        """
        verdict = self.rule_based_system.diagnose(input_data)
        if verdict is None:
            # Network features: [symptom count, age] as a single-row float batch.
            features = [len(input_data['symptoms']), input_data['age']]
            batch = torch.tensor([features], dtype=torch.float)
            verdict = self.neural_network.predict(batch)
        return verdict

# Example rules
def fever_rule(input_data):
    """Return True when ``"fever"`` appears in the record's ``'symptoms'``."""
    reported = input_data['symptoms']
    return "fever" in reported

# A single hand-written rule: patients reporting "fever" are labelled "Flu".
rules = [
    {'condition': fever_rule, 'diagnosis': "Flu"},
]

# Network dimensions: two input features (number of symptoms and age),
# a 4-unit hidden layer, and one output.
input_size, hidden_size, output_size = 2, 4, 1

# Assemble the rule engine, the fallback network, and the hybrid wrapper.
rule_based_system = RuleBasedSystem(rules)
neural_network = MyMedicalNN(input_size, hidden_size, output_size)
model = HybridModel(rule_based_system, neural_network)

# Diagnose an example patient. This record lists "fever", so fever_rule
# matches and the (still untrained) network is not consulted.
input_data = {"symptoms": ["fever", "cough"], "age": 45}
diagnosis = model.diagnose(input_data)
print("Diagnosis:", diagnosis)

# Sample DataLoader setup (replace with actual data loading logic)
class MedicalDataset(Dataset):
    """Map-style dataset over three parallel, indexable columns.

    Item ``i`` is the tuple ``(symptom_counts[i], ages[i], diagnoses[i])``.
    """

    def __init__(self, symptom_counts, ages, diagnoses):
        self.symptom_counts, self.ages, self.diagnoses = symptom_counts, ages, diagnoses

    def __len__(self):
        # Length comes from symptom_counts alone.
        # NOTE(review): assumes ages and diagnoses are at least as long — not validated.
        return len(self.symptom_counts)

    def __getitem__(self, idx):
        columns = (self.symptom_counts, self.ages, self.diagnoses)
        return tuple(column[idx] for column in columns)

# Create dummy data: one entry per patient (symptom count, age, binary label).
symptom_counts = torch.tensor([2, 3, 1, 4], dtype=torch.float)
ages = torch.tensor([45, 34, 23, 54], dtype=torch.float)
diagnoses = torch.tensor([1, 0, 1, 0], dtype=torch.float)

dataset = MedicalDataset(symptom_counts, ages, diagnoses)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Model and training setup.
# BCEWithLogitsLoss applies the sigmoid itself, so the network must emit raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(neural_network.parameters(), lr=0.01)

# Training loop
num_epochs = 10
# Ensure training mode regardless of any earlier eval()/predict() calls.
neural_network.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_samples = 0
    # Distinct batch names so the full-dataset tensors above are not overwritten.
    for batch_counts, batch_ages, batch_labels in dataloader:
        optimizer.zero_grad()
        features = torch.stack((batch_counts, batch_ages), dim=1)
        # squeeze(1) rather than squeeze(): a batch of size 1 must stay shape (1,)
        # to match the labels; squeeze() would collapse it to a 0-d tensor.
        outputs = neural_network(features).squeeze(1)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean is correct for a ragged last batch.
        epoch_loss += loss.item() * batch_labels.size(0)
        num_samples += batch_labels.size(0)
    mean_loss = epoch_loss / max(num_samples, 1)
    print(f'Epoch {epoch + 1}, Loss: {mean_loss:.4f}')

# Making a neural network diagnosis for a patient.
# predict() switches to eval mode and disables autograd internally.
test_input = torch.tensor([[len(["fever", "cough"]), 45]], dtype=torch.float)
nn_diagnosis = neural_network.predict(test_input)
print(nn_diagnosis)