In [15]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import accuracy_score, classification_report, f1_score, precision_score, roc_auc_score


In [16]:
data_path = 'cleaned_data1.csv'
data = pd.read_csv(data_path)

In [17]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.features = data.drop(columns=['aki_stage']).values
        self.labels = data['aki_stage'].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature = self.features[idx]
        label = self.labels[idx]
        return torch.tensor(feature, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

dataset = CustomDataset(data)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [18]:
# Initialize models
class ImpModel(nn.Module):
    def __init__(self, input_dim):
        super(ImpModel, self).__init__()
        self.layer = nn.Linear(input_dim, 10)

    def forward(self, x):
        return self.layer(x)

#classifiers and policies
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.layer = nn.Linear(10, 3)  # 假设aki_stage有3个类别

    def forward(self, x):
        return self.layer(x)

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.layer = nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

def construct_rl_environment(state_embedding):
    pass

def run_rl_policy(policy, T):
    observations = []
    for t in range(T):
        observation = np.random.randn(10)
        observations.append(observation)
    return observations

def update_panel_selection_policy(observations, num_classes=3):
    return np.random.choice(num_classes, size=len(observations))

In [19]:
# main
def smddpo(dataloader, val_loader, num_epochs=10):
    input_dim = dataset.features.shape[1]
    imp_model = ImpModel(input_dim)
    classifier = Classifier()
    policy = Policy()
    L, L1, L2, eta = num_epochs, 10, 10, 0.01
    best_accuracy = 0.0

    for epoch in range(L):
        imp_model.train()
        classifier.train()
        for inputs, labels in dataloader:
            state_embedding = imp_model(inputs)
            construct_rl_environment(state_embedding)

            for j in range(L1):
                observations = run_rl_policy(policy, T=10)
                panel_selection_policy = update_panel_selection_policy(observations, num_classes=3)

            classifier_optimizer = optim.SGD(classifier.parameters(), lr=eta)

            for j in range(L2):
                B_j = torch.tensor(observations, dtype=torch.float32)
                selected_labels = torch.tensor(panel_selection_policy, dtype=torch.long)
                if len(selected_labels) != len(B_j):
                    selected_labels = selected_labels[:len(B_j)]

                outputs = classifier(B_j)
                loss = nn.CrossEntropyLoss()(outputs, selected_labels)

                classifier_optimizer.zero_grad()
                loss.backward()
                classifier_optimizer.step()

        # 在每个epoch结束时评估模型
        accuracy, f1, precision, auc, _ = evaluate(imp_model, classifier, val_loader)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'imp_model_state_dict': imp_model.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'policy_state_dict': policy.state_dict(),
            }, 'best_model.pth')
        print(f"Epoch {epoch + 1}/{L}, Accuracy: {accuracy}, F1: {f1}, Precision: {precision}, AUC: {auc}")

    return imp_model, classifier, policy


In [20]:
# train
imp_model, classifier, policy = smddpo(train_loader, val_loader, num_epochs=20)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/20, Accuracy: 0.7318435754189944, F1: 0.6513287649870196, Precision: 0.5867739909424482, AUC: 0.49367854402262706


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2/20, Accuracy: 0.6033519553072626, F1: 0.5871036979830095, Precision: 0.5717076219869516, AUC: 0.48412961041359703


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3/20, Accuracy: 0.6815642458100558, F1: 0.6229474219714171, Precision: 0.5736146292376557, AUC: 0.5375921140539358


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 4/20, Accuracy: 0.7569832402234636, F1: 0.6522812658205362, Precision: 0.5730236259792141, AUC: 0.47717024321223195


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 5/20, Accuracy: 0.5893854748603352, F1: 0.5850676325536661, Precision: 0.580812595226003, AUC: 0.526655097941819


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 6/20, Accuracy: 0.019553072625698324, F1: 0.037849162011173186, Precision: 0.5887647423960273, AUC: 0.44990606926821813


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 7/20, Accuracy: 0.723463687150838, F1: 0.6470582812471192, Precision: 0.5852497290085884, AUC: 0.4799341078805401


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 8/20, Accuracy: 0.0335195530726257, F1: 0.059328548612629796, Precision: 0.5236233040702314, AUC: 0.5504744556616513


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 9/20, Accuracy: 0.05307262569832402, F1: 0.05261901351286826, Precision: 0.05217308966953887, AUC: 0.48219634740105916


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 10/20, Accuracy: 0.11173184357541899, F1: 0.18240560005384665, Precision: 0.4963824526055499, AUC: 0.5615534419781888


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 11/20, Accuracy: 0.0, F1: 0.0, Precision: 0.0, AUC: 0.534543394406798


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 12/20, Accuracy: 0.7458100558659218, F1: 0.6514755037348566, Precision: 0.5990403335026241, AUC: 0.4830472766183484


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 13/20, Accuracy: 0.09217877094972067, F1: 0.03751837694795648, Precision: 0.023552285088474883, AUC: 0.49136628441014524


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 14/20, Accuracy: 0.09497206703910614, F1: 0.03865529746153092, Precision: 0.024265990697216546, AUC: 0.48323236278689174


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 15/20, Accuracy: 0.36312849162011174, F1: 0.40559276775253505, Precision: 0.5856755741388786, AUC: 0.547060801102718


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 16/20, Accuracy: 0.7569832402234636, F1: 0.6522812658205362, Precision: 0.5730236259792141, AUC: 0.5414918628302728


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 17/20, Accuracy: 0.7569832402234636, F1: 0.6522812658205362, Precision: 0.5730236259792141, AUC: 0.501904190795393


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 18/20, Accuracy: 0.6201117318435754, F1: 0.6196903115753428, Precision: 0.619391977771866, AUC: 0.5226170278944169


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 19/20, Accuracy: 0.2569832402234637, F1: 0.2413710785469957, Precision: 0.6378819540995025, AUC: 0.49522067189019037
Epoch 20/20, Accuracy: 0.16201117318435754, F1: 0.04517619252256123, Precision: 0.026247620236571893, AUC: 0.47812079169423916


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [21]:
# evaluate function
def evaluate(imp_model, classifier, dataloader):
    imp_model.eval()
    classifier.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            embeddings = imp_model(inputs)
            outputs = classifier(embeddings)
            probs = nn.Softmax(dim=1)(outputs)
            _, preds = torch.max(outputs, 1)
            all_probs.extend(probs.numpy())
            all_preds.extend(preds.numpy())
            all_labels.extend(labels.numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    precision = precision_score(all_labels, all_preds, average='weighted')

    try:
        auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
    except ValueError:
        auc = None

    report = classification_report(all_labels, all_preds)

    return accuracy, f1, precision, auc, report

In [22]:
# best_model
checkpoint = torch.load('best_model.pth')
imp_model.load_state_dict(checkpoint['imp_model_state_dict'])
classifier.load_state_dict(checkpoint['classifier_state_dict'])
policy.load_state_dict(checkpoint['policy_state_dict'])

# evaluate model
accuracy, f1, precision, auc, report = evaluate(imp_model, classifier, val_loader)
print(f"accuracy：{accuracy}")
print(f"F1 score：{f1}")
print(f"precision：{precision}")
print(f"auc：{auc}")
print(f"report：\n{report}")

accuracy：0.7569832402234636
F1 score：0.6522812658205362
precision：0.5730236259792141
auc：0.47717024321223195
report：
              precision    recall  f1-score   support

           1       0.76      1.00      0.86       271
           2       0.00      0.00      0.00        58
           3       0.00      0.00      0.00        29

    accuracy                           0.76       358
   macro avg       0.25      0.33      0.29       358
weighted avg       0.57      0.76      0.65       358



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
