In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import OneHotEncoder, MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, jaccard_score, hamming_loss, precision_recall_curve
from torch.utils.data import TensorDataset, DataLoader

df = pd.read_csv("/home/chanbo.s/personalized_ecoli/new_code/merged_microbiology_admissions_final.csv.csv")
bins = [0, 20, 40, 60, 80, 100]
labels = ['0-20', '21-40', '41-60', '61-80', '80+']
df['age_group'] = pd.cut(df['anchor_age'], bins=bins, labels=labels, include_lowest=True)
df['effective_antibiotics'] = df['effective_antibiotics'].apply(
    lambda x: [ant.strip() for ant in x.strip("[]").replace("'", "").split(",") if ant.strip()]
)

encoder = OneHotEncoder(sparse_output=False, drop='first')
X = encoder.fit_transform(df[['gender', 'age_group']])


mlb = MultiLabelBinarizer()
y = mlb.fit_transform(df['effective_antibiotics'])


X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.1, random_state=42)


X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=32, shuffle=True)

class AntibioticNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AntibioticNN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, output_dim)
        )

    def forward(self, x):
        return self.net(x)


model = AntibioticNN(X_train.shape[1], y_train.shape[1])
class_counts = y_train.sum(axis=0)
weights = 1.0 / (class_counts + 1e-6)
weights = torch.tensor(weights, dtype=torch.float32)
criterion = nn.BCEWithLogitsLoss(pos_weight=weights)

optimizer = optim.Adam(model.parameters(), lr=0.001)


best_val_loss = float('inf')
patience = 10
counter = 0
epochs = 100

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val_tensor)
        val_loss = criterion(val_outputs, y_val_tensor).item()

    print(f"Epoch {epoch+1} - Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break


model.load_state_dict(torch.load("best_model.pth"))


model.eval()
with torch.no_grad():
    y_logits = model(X_test_tensor)
    y_prob = torch.sigmoid(y_logits).numpy()

optimal_thresholds = []
for i in range(y.shape[1]):
    precision, recall, thresholds = precision_recall_curve(y_test[:, i], y_prob[:, i])
    f1_scores = 2 * precision * recall / (precision + recall + 1e-8)
    best_threshold = thresholds[np.argmax(f1_scores)] if len(thresholds) > 0 else 0.5
    optimal_thresholds.append(best_threshold)


y_pred_binary = np.array([
    (y_prob[:, i] >= optimal_thresholds[i]).astype(int)
    for i in range(y.shape[1])
]).T


f1 = f1_score(y_test, y_pred_binary, average='macro')
jaccard = jaccard_score(y_test, y_pred_binary, average='samples')
hamming = hamming_loss(y_test, y_pred_binary)

print(f"\nFinal Evaluation:")
print(f"F1 Score (macro): {f1:.3f}")
print(f"Jaccard Similarity: {jaccard:.3f}")
print(f"Hamming Loss: {hamming:.3f}")

# After evaluation prints

def predict_antibiotics(gender, age):
    model.eval()
    age_group = pd.cut([age], bins=[0, 20, 40, 60, 80, 100], labels=labels, include_lowest=True)[0]
    input_df = pd.DataFrame({'gender': [gender], 'age_group': [age_group]})
    input_encoded = encoder.transform(input_df)
    input_tensor = torch.tensor(input_encoded, dtype=torch.float32)

    with torch.no_grad():
        logits = model(input_tensor)
        probas = torch.sigmoid(logits).numpy()[0]

    preds = (probas >= optimal_thresholds).astype(int).reshape(1, -1)
    recommended_antibiotics = mlb.inverse_transform(preds)
    return recommended_antibiotics

gender, age = 'F', 30
recommended = predict_antibiotics(gender, age)
print(f"Recommended Antibiotics for {gender}, Age {age}: {recommended}")


Epoch 1 - Train Loss: 6.4469 | Val Loss: 0.0010
Epoch 2 - Train Loss: 0.3744 | Val Loss: 0.0006
Epoch 3 - Train Loss: 0.3143 | Val Loss: 0.0005
Epoch 4 - Train Loss: 0.3041 | Val Loss: 0.0005
Epoch 5 - Train Loss: 0.3019 | Val Loss: 0.0005
Epoch 6 - Train Loss: 0.2999 | Val Loss: 0.0005
Epoch 7 - Train Loss: 0.3031 | Val Loss: 0.0005
Epoch 8 - Train Loss: 0.3000 | Val Loss: 0.0005
Epoch 9 - Train Loss: 0.2992 | Val Loss: 0.0005
Epoch 10 - Train Loss: 0.3025 | Val Loss: 0.0005
Epoch 11 - Train Loss: 0.3004 | Val Loss: 0.0005
Epoch 12 - Train Loss: 0.3004 | Val Loss: 0.0005
Epoch 13 - Train Loss: 0.3017 | Val Loss: 0.0005
Epoch 14 - Train Loss: 0.2985 | Val Loss: 0.0005
Epoch 15 - Train Loss: 0.3002 | Val Loss: 0.0005
Epoch 16 - Train Loss: 0.3025 | Val Loss: 0.0005
Epoch 17 - Train Loss: 0.3006 | Val Loss: 0.0005
Epoch 18 - Train Loss: 0.3028 | Val Loss: 0.0005
Epoch 19 - Train Loss: 0.3014 | Val Loss: 0.0005
Epoch 20 - Train Loss: 0.3013 | Val Loss: 0.0005
Epoch 21 - Train Loss: 0.2986

