In [None]:
# Mount Google Drive at /content/drive; DATA_DIR and METADATA_PATH in the next cell point inside it.
from google.colab import drive
drive.mount('/content/drive')
print("Google Drive mounted successfully!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully!


In [None]:
import os
import random
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve, auc
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gym
from gym import spaces
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties

# Configure a Chinese-capable font so CJK plot labels do not render as empty boxes.
# The explicit SimHei file path below only exists on Windows. This notebook runs on
# Colab, where that file is missing, and matplotlib cannot load a non-existent `fname`
# when it draws a label. In that case fall back to the rcParams font list; CJK glyphs
# may then still render as boxes unless a CJK font is installed.
SIMHEI_PATH = r"C:\Windows\Fonts\simhei.ttf"
if os.path.exists(SIMHEI_PATH):
    font = FontProperties(fname=SIMHEI_PATH, size=12)
else:
    font = FontProperties(size=12)
# Prefer SimHei when installed; DejaVu Sans (bundled with matplotlib) is the explicit fallback.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
# Draw the minus sign with an ASCII hyphen, because some CJK fonts lack the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

# ----------------------------
# Dataset locations (inside the Drive copy of the Kaggle dataset).
# NOTE(review): only the dataset's `train` directory is searched. Metadata entries whose
# files live elsewhere will be reported as missing; the recorded run hit a KeyError on
# an SN13 file. Confirm this is the intended directory.
DATA_DIR = r'/content/drive/MyDrive/1/dataset/dataset/train'
METADATA_PATH = r'/content/drive/MyDrive/1/metadata.csv'
FEATURES = ['ton', 'thrust', 'mfr']  # sensor columns used as model input
WINDOW_SIZE = 10  # number of past rows in each observation window
TARGET_LENGTH = 1000  # rows kept per file after downsampling; must be >= WINDOW_SIZE
assert TARGET_LENGTH >= WINDOW_SIZE, "TARGET_LENGTH must be at least WINDOW_SIZE"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Seed every RNG in play so that policy initialisation and action sampling are
# reproducible across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ----------------------------
# Per-file metadata table; later cells read its 'filename' and 'anomaly_code' columns.
metadata = pd.read_csv(METADATA_PATH)

# ----------------------------
def uniform_downsample(X, y, target_length):
    """Thin (X, y) to `target_length` evenly spaced rows, keeping X and y aligned.

    If there are already at most `target_length` rows, the inputs are returned
    unchanged (the same objects; nothing is upsampled). Otherwise row i of the result
    is original row floor(i * n_rows / target_length).

    Args:
        X: array whose rows are indexed along axis 0.
        y: label array indexed like X's rows.
        target_length: number of rows to keep.

    Returns:
        (X_sampled, y_sampled)
    """
    n_rows = X.shape[0]
    if n_rows > target_length:
        stride = n_rows / target_length
        picks = np.floor(np.arange(target_length) * stride).astype(int)
        # The clip is defensive only; the floor already keeps every index below n_rows.
        picks = np.clip(picks, 0, n_rows - 1)
        X, y = X[picks], y[picks]
    return X, y

# ----------------------------
# Use every file listed in the metadata; no subsampling is applied.
sampled_files = metadata['filename'].to_numpy().tolist()
all_files = sampled_files
print(f"Using all {len(sampled_files)} files for training/testing.")

# ----------------------------
# Load each per-file CSV listed in the metadata and collapse its labels to binary.
# Files that are missing under DATA_DIR are reported and collected in `missing_files`.
# They are NOT added to `data_dict`, so later cells must check membership before
# indexing into it.
data_dict = {}
missing_files = []
for f in sampled_files:
    file_path = os.path.join(DATA_DIR, f)
    if not os.path.exists(file_path):
        print(f"Warning: file {f} not found in data directory.")
        missing_files.append(f)
        continue
    df = pd.read_csv(file_path)
    # Binary target: every non-zero anomaly code becomes 1, and missing codes count as normal (0).
    if 'anomaly_code' not in df.columns:
        df['anomaly_code'] = 0
    else:
        # Vectorised equivalent of a row-wise `0 if x == 0 else 1`.
        df['anomaly_code'] = df['anomaly_code'].fillna(0).astype(int).ne(0).astype(int)
    data_dict[f] = df

if missing_files:
    print(f"{len(missing_files)} of {len(sampled_files)} listed files were not found under {DATA_DIR}.")

# ----------------------------
# Fit the feature scaler on normal files only. A file counts as normal when its
# metadata anomaly_code == 0; for duplicate filenames the first metadata row wins.
# Only files that were actually loaded are used: indexing data_dict with a missing
# file raises KeyError.
# NOTE(review): this includes normal files of the test serial numbers (`test_sns`,
# SN13-SN24, defined later in this cell), so test-set statistics leak into the
# scaler. Consider restricting the fit to the training serial numbers.
label_by_file = metadata.drop_duplicates('filename').set_index('filename')['anomaly_code']
normal_files = [f for f in sampled_files if f in data_dict and label_by_file.get(f) == 0]
if not normal_files:
    raise ValueError("No normal files were loaded; cannot fit the StandardScaler. Check DATA_DIR.")
all_normal_data = np.vstack([data_dict[f][FEATURES].values for f in normal_files])
scaler = StandardScaler()
scaler.fit(all_normal_data)

# ----------------------------
def preprocess(df):
    """Return (scaled feature matrix, labels) for one per-file DataFrame.

    The `FEATURES` columns are transformed with the module-level `scaler`, which was
    fit on normal files. The labels are the raw `anomaly_code` column values.
    """
    scaled = scaler.transform(df[FEATURES].values)
    return scaled, df['anomaly_code'].values

# ----------------------------
class AnomalyDetectionEnv(gym.Env):
    """Step-by-step binary anomaly labelling over one multivariate time series.

    The observation at time step t is data[t - window_size + 1 : t + 1], with shape
    (window_size, feature_dim). The agent answers 0 (normal) or 1 (anomaly) for
    labels[t]. An episode starts at t = window_size - 1 and ends once labels[T - 1]
    has been answered.

    Rewards: +1.0 correct anomaly, +0.5 correct normal, -1.0 missed anomaly,
    -0.5 false alarm.

    Uses the classic gym API: reset() -> obs, step() -> (obs, reward, done, info).

    Raises:
        ValueError: if data is not 2-D, labels do not match data in length, or there
            are fewer rows than window_size.
    """

    def __init__(self, data, labels, window_size=WINDOW_SIZE):
        super().__init__()
        if data.ndim != 2:
            raise ValueError(f"data must be 2-D (time, features), got shape {data.shape}")
        if len(labels) != data.shape[0]:
            raise ValueError(f"labels length {len(labels)} does not match data length {data.shape[0]}")
        if window_size < 1 or data.shape[0] < window_size:
            raise ValueError(f"need at least window_size={window_size} rows, got {data.shape[0]}")
        self.data = data.astype(np.float32)
        self.labels = labels.astype(np.int64)
        self.window_size = window_size
        self.feature_dim = data.shape[1]
        self.T = data.shape[0]
        self.current_step = 0

        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(window_size, self.feature_dim), dtype=np.float32)
        self.action_space = spaces.Discrete(2)  # 0 = normal, 1 = anomaly

    def reset(self):
        """Start at the first step that has a full window and return its observation."""
        self.current_step = self.window_size - 1
        return self._get_state()

    def _get_state(self):
        """Return the last `window_size` rows, ending at (and including) current_step."""
        return self.data[self.current_step - self.window_size + 1:self.current_step + 1]

    def step(self, action):
        """Score `action` against labels[current_step], then advance one step.

        Returns:
            (next_state, reward, done, info). When done, next_state is an all-zero window.
        """
        true_label = self.labels[self.current_step]

        if action == true_label:
            reward = 1.0 if action == 1 else 0.5
        else:
            reward = -1.0 if true_label == 1 else -0.5

        self.current_step += 1
        # Finish only after every label up to index T - 1 has been scored. A
        # `>= T - 1` test would end one step early and never score the final row.
        done = self.current_step >= self.T

        next_state = np.zeros((self.window_size, self.feature_dim), dtype=np.float32) if done else self._get_state()
        return next_state, reward, done, {}

# ----------------------------
class PolicyNetwork(nn.Module):
    """Single-layer LSTM policy that maps a feature window to logits for {normal, anomaly}.

    `window_size` is accepted for interface compatibility but is not used; the LSTM
    runs over whatever sequence length it receives.
    """

    def __init__(self, input_dim, hidden_dim, window_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=2)

    def forward(self, x):
        # With batch_first=True, x is (batch, seq_len, input_dim); classify from the last step's output.
        sequence_out = self.lstm(x)[0]
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

# ----------------------------
class PPOAgent:
    """Clipped-surrogate (PPO-style) policy-gradient agent with no value network.

    Advantages are the discounted Monte-Carlo returns minus their batch mean; there is
    no learned critic baseline.

    Args:
        policy_net: module expected to map a (batch, window, features) float tensor
            to (batch, 2) logits.
        lr: Adam learning rate.
        gamma: discount factor used by `compute_returns`.
        eps_clip: clipping range for the probability ratio.
        k_epochs: optimisation steps per `update` call on the same batch (default 1).
        device: torch device for the network and batches; defaults to the global DEVICE.
    """

    def __init__(self, policy_net, lr=3e-4, gamma=0.99, eps_clip=0.2, k_epochs=1, device=None):
        if k_epochs < 1:
            raise ValueError(f"k_epochs must be >= 1, got {k_epochs}")
        self.device = DEVICE if device is None else device
        self.policy_net = policy_net.to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs

    @torch.no_grad()
    def select_action(self, state):
        """Sample an action for one (window, features) observation.

        Runs without autograd. The returned log-probability is only used as the fixed
        "old" value in `update`, which recomputes log-probabilities with gradients.

        Returns:
            (action, log_prob, p_anomaly): the sampled action as an int, its
            log-probability as a shape-(1,) tensor, and the probability of class 1
            as a float.
        """
        state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        # Building the distribution from logits avoids a separate softmax -> log round trip.
        dist = Categorical(logits=self.policy_net(state_t))
        action = dist.sample()
        return action.item(), dist.log_prob(action), dist.probs[0, 1].item()

    def compute_returns(self, rewards, dones):
        """Compute discounted returns G_t = r_t + gamma * G_{t+1}, restarting at episode ends.

        `dones[t]` marks that step t ended its episode, so G_t = r_t there.

        Returns:
            float32 tensor of shape (len(rewards),) on the agent's device.

        Raises:
            ValueError: if rewards and dones differ in length.
        """
        if len(rewards) != len(dones):
            raise ValueError(f"rewards ({len(rewards)}) and dones ({len(dones)}) differ in length")
        returns = [0.0] * len(rewards)
        running = 0.0
        for t in reversed(range(len(rewards))):
            if dones[t]:
                running = 0.0
            running = rewards[t] + self.gamma * running
            returns[t] = running
        return torch.tensor(returns, dtype=torch.float32, device=self.device)

    def _policy_loss(self, states, actions, old_log_probs, returns):
        """Return the clipped surrogate loss for batched tensors, each with leading dim N.

        states: (N, window, features); actions, old_log_probs, returns: (N,).
        """
        dist = Categorical(logits=self.policy_net(states))
        log_probs = dist.log_prob(actions)
        ratios = torch.exp(log_probs - old_log_probs)
        advantages = returns - returns.mean()
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
        return -torch.min(surr1, surr2).mean()

    def update(self, states, actions, old_log_probs, returns):
        """Run `k_epochs` clipped-surrogate gradient steps on one rollout batch.

        Args:
            states: list of (window, features) tensors.
            actions: list of int actions.
            old_log_probs: list of log-prob tensors from `select_action` (each shape (1,)).
            returns: (N,) tensor from `compute_returns`.
        """
        if len(states) == 0:
            return
        states_t = torch.stack(states).to(self.device)
        actions_t = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        # select_action yields shape-(1,) log-probs, and stacking them gives (N, 1).
        # Flatten to (N,) so the ratio is element-wise rather than an (N, N) broadcast.
        old_log_probs_t = torch.stack(old_log_probs).reshape(-1).to(self.device).detach()
        returns_t = returns.detach().to(self.device)

        for _ in range(self.k_epochs):
            loss = self._policy_loss(states_t, actions_t, old_log_probs_t, returns_t)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

# ----------------------------
# Split files by serial number: SN01-SN12 go to training, SN13-SN24 go to testing.
train_sns = [f"SN{i:02d}" for i in range(1, 13)]
test_sns = [f"SN{i:02d}" for i in range(13, 25)]


def _split_for_file(filename):
    """Return 'train' or 'test' for a filename containing one of the serial numbers, else None."""
    if any(sn in filename for sn in train_sns):
        return 'train'
    if any(sn in filename for sn in test_sns):
        return 'test'
    return None


train_data_list, train_label_list = [], []
test_data_list, test_label_list = [], []
not_loaded = []

for f in sampled_files:
    split = _split_for_file(f)
    if split is None:
        continue
    # Files that were not found on disk are absent from data_dict; indexing them raises KeyError.
    if f not in data_dict:
        not_loaded.append(f)
        continue
    X, y = preprocess(data_dict[f])
    # Comment out the next line to use full-resolution data (no downsampling).
    X, y = uniform_downsample(X, y, TARGET_LENGTH)
    if split == 'train':
        train_data_list.append(X)
        train_label_list.append(y)
    else:
        test_data_list.append(X)
        test_label_list.append(y)

if not_loaded:
    print(f"Skipped {len(not_loaded)} train/test file(s) missing from data_dict (not found under DATA_DIR).")
if not train_data_list or not test_data_list:
    raise ValueError(
        f"Empty split: {len(train_data_list)} train / {len(test_data_list)} test files loaded. "
        "Check that DATA_DIR contains files for both serial-number ranges."
    )

# NOTE(review): files are concatenated end to end, so training windows near file
# boundaries mix rows from two different files.
train_data = np.vstack(train_data_list)
train_labels = np.concatenate(train_label_list)
test_data = np.vstack(test_data_list)
test_labels = np.concatenate(test_label_list)

print(f"训练数据长度: {train_data.shape[0]}, 测试数据长度: {test_data.shape[0]}")

# ----------------------------
# Policy (LSTM over the feature window) and its PPO agent.
policy_net = PolicyNetwork(
    input_dim=len(FEATURES),
    hidden_dim=64,
    window_size=WINDOW_SIZE,
)
agent = PPOAgent(policy_net)
# A single training environment over the concatenated training series.
train_env = AnomalyDetectionEnv(
    train_data,
    train_labels,
    window_size=WINDOW_SIZE,
)

# ----------------------------
def train_ppo(env, agent, max_episodes=100, update_timestep=2000):
    """Collect on-policy rollouts from `env` and periodically call `agent.update`.

    The rollout buffer is flushed into an update whenever the global step counter is
    a multiple of `update_timestep`, and also at the end of every episode, so no
    batch spans two episodes. Prints each episode's total reward.

    Args:
        env: environment with classic-gym `reset()` / `step(action)` methods.
        agent: object providing `select_action`, `compute_returns` and `update`.
        max_episodes: number of episodes to run.
        update_timestep: global-step interval between updates.
    """
    def new_buffer():
        return {name: [] for name in ('states', 'actions', 'log_probs', 'rewards', 'dones')}

    buffer = new_buffer()
    global_step = 0

    for episode in range(max_episodes):
        obs = env.reset()
        episode_reward = 0
        finished = False

        while not finished:
            with torch.no_grad():
                act, act_log_prob, _ = agent.select_action(obs)

            buffer['states'].append(torch.tensor(obs, dtype=torch.float32))
            buffer['actions'].append(act)
            buffer['log_probs'].append(act_log_prob)

            obs, reward, finished, _ = env.step(act)
            buffer['rewards'].append(reward)
            buffer['dones'].append(finished)

            episode_reward += reward
            global_step += 1

            if global_step % update_timestep == 0 or finished:
                discounted = agent.compute_returns(buffer['rewards'], buffer['dones'])
                agent.update(buffer['states'], buffer['actions'], buffer['log_probs'], discounted)
                buffer = new_buffer()

        print(f"Episode {episode + 1} Reward: {episode_reward:.2f}")

# ----------------------------
# Train for 50 episodes, each a full pass over the concatenated training series. The
# policy is updated every 1000 steps and at the end of each episode.
train_ppo(train_env, agent, max_episodes=50, update_timestep=1000)

# ----------------------------
def test_agent(agent, data_dict, test_files, window_size=WINDOW_SIZE, target_length=TARGET_LENGTH,
               deterministic=False):
    """Evaluate `agent` file by file, then print and plot binary anomaly-detection metrics.

    Each test file is scaled with `preprocess`, downsampled to `target_length` rows and
    rolled out in its own `AnomalyDetectionEnv`, so windows never span two files. The
    k-th prediction for a file is compared with y[window_size - 1 + k].

    Args:
        agent: object with `select_action(state) -> (action, log_prob, p_anomaly)`.
        data_dict: mapping of filename -> loaded DataFrame.
        test_files: filenames to evaluate. Names missing from `data_dict`, and files
            shorter than `window_size` after downsampling, are skipped and counted.
        window_size: observation window length.
        target_length: downsampling target passed to `uniform_downsample`.
        deterministic: if True, predict 1 exactly when p_anomaly > 0.5 instead of
            using the sampled action. The default keeps the stochastic evaluation.

    Returns:
        (all_true, all_preds, all_probs) as NumPy arrays.

    Raises:
        ValueError: if no test file could be evaluated.
    """
    all_preds = []
    all_true = []
    all_probs = []
    skipped = []

    for f in test_files:
        if f not in data_dict:
            skipped.append(f)
            continue
        X, y = preprocess(data_dict[f])
        # Comment out the next line to evaluate on full-resolution data (no downsampling).
        X, y = uniform_downsample(X, y, target_length)
        if X.shape[0] < window_size:
            skipped.append(f)
            continue
        env = AnomalyDetectionEnv(X, y, window_size=window_size)
        state = env.reset()
        done = False
        preds = []
        probs = []
        # Inference only: the rollout needs no autograd graph.
        with torch.no_grad():
            while not done:
                action, _, prob = agent.select_action(state)
                if deterministic:
                    action = int(prob > 0.5)
                preds.append(action)
                probs.append(prob)
                state, _, done, _ = env.step(action)
        all_preds.extend(preds)
        all_true.extend(y[window_size-1:window_size-1+len(preds)])
        all_probs.extend(probs)

    if skipped:
        print(f"Skipped {len(skipped)} test file(s): not loaded or shorter than window_size.")
    if not all_preds:
        raise ValueError("No test predictions were produced; check test_files and DATA_DIR.")

    all_true = np.array(all_true)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)

    print("\n测试集分类报告:")
    print(classification_report(all_true, all_preds, digits=4))

    accuracy = accuracy_score(all_true, all_preds)
    precision = precision_score(all_true, all_preds, zero_division=0)
    recall = recall_score(all_true, all_preds, zero_division=0)
    f1 = f1_score(all_true, all_preds, zero_division=0)

    # roc_auc_score raises ValueError when only one class is present; report NaN in that case.
    try:
        auc_roc = roc_auc_score(all_true, all_probs)
    except ValueError:
        auc_roc = float('nan')

    print(f"准确率 (Accuracy): {accuracy:.4f}")
    print(f"精确率 (Precision): {precision:.4f}")
    print(f"召回率 (Recall): {recall:.4f}")
    print(f"F1分数 (F1-score): {f1:.4f}")
    print(f"AUC-ROC: {auc_roc:.4f}")

    plot_time_series(all_true, all_preds, length=500)
    plot_confusion_matrix(all_true, all_preds)
    plot_roc_curve(all_true, all_probs)
    plot_precision_recall_curve(all_true, all_probs)

    return all_true, all_preds, all_probs

# ----------------------------
def plot_time_series(true_labels, pred_labels, length=500):
    """Overlay true and predicted 0/1 anomaly labels for the first `length` time steps.

    `length` is clipped to the number of available points. Without the clip, fewer
    points would make `x` longer than the sliced labels and plt.plot would reject
    the mismatched lengths.
    """
    length = min(length, len(true_labels), len(pred_labels))
    plt.figure(figsize=(15,4))
    x = np.arange(length)
    plt.plot(x, true_labels[:length], label='真实异常标签', color='blue', marker='x', linestyle='-')
    plt.plot(x, pred_labels[:length], label='预测异常标签', color='red', marker='o', linestyle='--')
    plt.xlabel('时间步', fontproperties=font)
    plt.ylabel('异常状态 (0=正常,1=异常)', fontproperties=font)
    plt.title(f'异常检测结果对比（前{length}步）', fontproperties=font)
    plt.legend(prop=font)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(true_labels, pred_labels):
    """Show a heatmap of the 2x2 confusion matrix (rows = true label, columns = predicted).

    `labels=[0, 1]` forces a 2x2 matrix even when only one class occurs, so the
    matrix always matches the two tick labels.
    """
    cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1])
    plt.figure(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['正常','异常'], yticklabels=['正常','异常'])
    plt.xlabel('预测标签', fontproperties=font)
    plt.ylabel('真实标签', fontproperties=font)
    plt.title('混淆矩阵', fontproperties=font)
    plt.show()

def plot_roc_curve(true_labels, pred_probs):
    """Plot the ROC curve of anomaly probabilities against a chance diagonal, with its AUC."""
    fpr, tpr, _ = roc_curve(true_labels, pred_probs)
    area = auc(fpr, tpr)
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {area:.4f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance level
    ax.set_xlabel('假阳性率 (FPR)', fontproperties=font)
    ax.set_ylabel('真阳性率 (TPR)', fontproperties=font)
    ax.set_title('ROC曲线', fontproperties=font)
    ax.legend(loc='lower right', prop=font)
    ax.grid(True)
    plt.show()

def plot_precision_recall_curve(true_labels, pred_probs):
    """Plot the precision-recall curve of anomaly probabilities, with the trapezoidal area under it."""
    prec, rec, _ = precision_recall_curve(true_labels, pred_probs)
    area = auc(rec, prec)
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.plot(rec, prec, color='purple', lw=2, label=f'PR曲线 (AUC = {area:.4f})')
    ax.set_xlabel('召回率 (Recall)', fontproperties=font)
    ax.set_ylabel('精确率 (Precision)', fontproperties=font)
    ax.set_title('精确率-召回率曲线', fontproperties=font)
    ax.legend(loc='lower left', prop=font)
    ax.grid(True)
    plt.show()

# ----------------------------
# Evaluate on loaded files whose names contain a test serial number (SN13-SN24).
# Files listed in the metadata but not loaded from DATA_DIR are excluded here,
# because indexing data_dict with them raises KeyError.
test_files = [f for f in sampled_files if f in data_dict and any(s in f for s in test_sns)]
all_true, all_preds, all_probs = test_agent(agent, data_dict, test_files, window_size=WINDOW_SIZE, target_length=TARGET_LENGTH)


Using device: cpu
Using all 2612 files for training/testing.


  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)




  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)




  return datetime.utcnow().replace(tzinfo=utc)




KeyError: '01270_002_SN13_21bars_ssf.csv'

  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
# NOTE(review): this duplicates the first cell. If Drive is already mounted, re-running it
# only reports that (see the first cell's output), so this cell can likely be removed.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the kagglehub download (the cache path printed by the download cell) into Drive
# as MyDrive/1, which is the directory DATA_DIR and METADATA_PATH point into. The
# source is the kagglehub cache directory, so run the download cell first.
cp -r /root/.cache/kagglehub/datasets/patrickfleith/spacecraft-thruster-firing-tests-dataset/versions/1 /content/drive/MyDrive/

In [None]:
import kagglehub

# Download the latest version of the Kaggle dataset. `path` is the local directory
# holding the files (a kagglehub cache path, as printed in this cell's output).
path = kagglehub.dataset_download("patrickfleith/spacecraft-thruster-firing-tests-dataset")

print("Path to dataset files:", path)

  return datetime.utcnow().replace(tzinfo=utc)


Downloading from https://www.kaggle.com/api/v1/datasets/download/patrickfleith/spacecraft-thruster-firing-tests-dataset?dataset_version_number=1...


100%|██████████| 2.97G/2.97G [00:51<00:00, 62.5MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/patrickfleith/spacecraft-thruster-firing-tests-dataset/versions/1


In [None]:
# Show the current working directory (IPython magic).
pwd


'/content'

In [None]:
# List the contents of the current working directory. (`ls pwd` would try to list a
# file literally named "pwd" and fail.)
ls


ls: cannot access 'pwd': No such file or directory


In [None]:
# List /content to see which files and folders are present in the Colab workspace.
ls /content

[0m[01;34mdrive[0m/  DRL2.py  [01;34msample_data[0m/


In [None]:
# NOTE(review): duplicate of the earlier copy cell. It copies the kagglehub cache
# directory into Drive as MyDrive/1 (where DATA_DIR / METADATA_PATH point) and needs the
# download cell to have run first in this session.
cp -r /root/.cache/kagglehub/datasets/patrickfleith/spacecraft-thruster-firing-tests-dataset/versions/1 /content/drive/MyDrive/

In [None]:
import kagglehub

# NOTE(review): duplicate of the earlier download cell; one copy should be enough.
# Download the latest version of the Kaggle dataset; `path` is the local directory holding the files.
path = kagglehub.dataset_download("patrickfleith/spacecraft-thruster-firing-tests-dataset")

print("Path to dataset files:", path)

  return datetime.utcnow().replace(tzinfo=utc)


Downloading from https://www.kaggle.com/api/v1/datasets/download/patrickfleith/spacecraft-thruster-firing-tests-dataset?dataset_version_number=1...


100%|██████████| 2.97G/2.97G [00:51<00:00, 62.5MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/patrickfleith/spacecraft-thruster-firing-tests-dataset/versions/1


In [None]:
import kagglehub

# NOTE(review): second duplicate of the download cell; one copy should be enough.
# Download the latest version of the Kaggle dataset; `path` is the local directory holding the files.
path = kagglehub.dataset_download("patrickfleith/spacecraft-thruster-firing-tests-dataset")

print("Path to dataset files:", path)

  return datetime.utcnow().replace(tzinfo=utc)


Downloading from https://www.kaggle.com/api/v1/datasets/download/patrickfleith/spacecraft-thruster-firing-tests-dataset?dataset_version_number=1...


100%|██████████| 2.97G/2.97G [00:51<00:00, 62.5MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/patrickfleith/spacecraft-thruster-firing-tests-dataset/versions/1
