In [19]:
import numpy as np
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
from numpy.matlib import empty
from pandas.core.interchange.dataframe_protocol import DataFrame
import os
import os.path
import shutil


plt.rcParams['font.sans-serif']=['Songti SC'] #用来正常显示中文标签

def read_tbm_data(file_path):
    '''
    读取TBM数据文件
    参数:
        file_path: 文件路径
    返回:
        df: 处理后的DataFrame
    '''
    #读取第一行，防止列名数和列数不匹配
    # 读取数据，使用制表符分隔
    df = pd.read_csv(file_path, sep='\t',index_col=False)
    df = df.loc[:,['时间戳','运行时间','刀盘转速','推进速度','刀盘扭矩','贯入度','总推进力','推进速度给定百分比','刀盘给定转速显示值','推进速度电位器设定值']]
    return df

def tbm_cycle_judge(df,buffer_size = 300,threshold = 0.1):
    #可以添加条件(df['总推进力']>0)
    df['is_extract'] = (df["刀盘转速"]>0) & (df["刀盘扭矩"]>0) & (df["贯入度"]<=20) &(df["贯入度"]>0)
    df['运行时间'] = pd.to_datetime(df['运行时间'])
    df.sort_values('运行时间',inplace = True)

    #缓冲区buffer_size = 300
    #设定阈值用以判断一个小窗口是否符合，比例大于90%则认为是稳定掘进状态,threshold = 0.1
    #利用滑动窗口来判断掘进循环
    df["smoothing"] = False
    #df.rolling()在整个dataframe上使用rolling函数，再调用is_extract列来取窗口平均值判定
    extract_mean = df.rolling(window=60, on='运行时间', closed='both')['is_extract'].mean()
    df['smoothing'] = extract_mean > threshold


    #找到smoothing由false转换为true的那部分（起始索引）,并利用缓冲区往前延伸
    smoothing_shift = df['smoothing'].shift(1,fill_value=False)
    smoothing_starts = df[(df['smoothing'] == False) & (smoothing_shift == True)].index
    for start_idx in smoothing_starts:
        buffer_start = max(start_idx-buffer_size, 0)
        df.loc[buffer_start:start_idx,'smoothing'] = True

    #检测状态变化
    df['cycle_id'] = (df['smoothing']!=df['smoothing'].shift(1)).cumsum()
    #仅保留掘进状态的循环，df[df['smoothing']]只选择值为true的循环
    excavation_cycles = df[df['smoothing']].groupby('cycle_id')
    #过滤掉时长不足600s的循环
    valid_cycle = []
    duration = 0
    for cycle_id , cycle in excavation_cycles:
        duration = len(cycle)
        if duration > 600:
            valid_cycle.append({"cycle_id": cycle_id,
                               'starttime': cycle['运行时间'].iloc[0],
                               'endtime' :cycle['运行时间'].iloc[-1],
                               'duration': duration,
            })
    return valid_cycle


def display(valid_cycle,tbm_para,df):
    cycle_df = pd.DataFrame(valid_cycle)
    for index ,cycle in cycle_df.iterrows():
        #找出每个循环段的开始和截止时间，以及持续时间
        cycle_id = cycle['cycle_id']
        start = cycle['starttime']
        end = cycle['endtime']
        duration = cycle['duration']
        #提取，并且利用这个把需要的数据筛选出来
        #cycle_data = df[(df['运行时间']>=end - pd.Timedelta(minutes=35)) & (df['运行时间'] <= end)]
        #cycle_data = df[(df['运行时间']>=start-pd.Timedelta(minutes=1)) & (df['运行时间'] <= end)]
        cycle_data = df[(df['运行时间']>=start) & (df['运行时间'] <= end)]
        #绘制图像
        plt.figure(dpi = 200)
        plt.plot(cycle_data['运行时间'],cycle_data[tbm_para],'b',label = f'Cycle{cycle_id}')
        plt.xlabel('运行时间')
        plt.ylabel(tbm_para)
        plt.title('{}-掘进时间关系图'.format(tbm_para))
        plt.legend()
        plt.grid(True)
        plt.xticks(rotation = 90)
        plt.tight_layout()
        plt.show()

'''
异常值判定和处理
对于掘进过程中的极大异常点,采用 3σ准则判定。
通过计算整个掘进步贯入度和掘进速度的均值μ和标准差σ,当掘进速度和贯入度大于μ+ 3σ时,则按照异常点处理。
对识别出来的异常点,取5个临近数据点的平均值替代。
数据平滑处理
为了消除掘进参数中的白噪声,本文采用滑动均值滤波来处理。取窗口长度为n,从第一个数据点开始,计算相邻的n个数据点的算术平均值并作为该点滤波之后的新值,
均值滤波的窗口长度决定影响该点数值的数据范围,当选择较大的窗口长度时,可得到更加平滑曲线,但是忽视了很多数据的变化细节,如果选择的窗口长度过小,则噪声消除的效果不够理想。本文采用的滤波窗口长度为15s。
'''
def handle_outlier_data(df,valid_cycle,col_name,outlier_window):
    #这里写3sigma处理的算法
    cycle_df = pd.DataFrame(valid_cycle)
    for index,cycle in cycle_df.iterrows():
        cycle_id = cycle['cycle_id']
        start = cycle['starttime']
        end = cycle['endtime']
        cycle_data = df[(df['运行时间']>=start-pd.Timedelta(minutes=1)) & (df['运行时间'] <= end)]
        mu = cycle_data[col_name].mean()
        sigma = cycle_data[col_name].std()
        outlier = df[df[col_name] > (3*sigma + mu)]
        for idx in outlier.index:
            start_idx =max(idx - outlier_window, 0)
            end_idx = min(len(df),idx + outlier_window)
            df.at[idx,col_name] = (df.loc[start_idx:idx-1,col_name].mean()+df.loc[idx+1:end_idx,col_name].mean()+mu)/3.0
    return df

def smooth_data(df,valid_cycle,col_name,smooth_window):
    #这里写smooth处理的方法
    cycle_df = pd.DataFrame(valid_cycle)
    for index,cycle in cycle_df.iterrows():
        cycle_id = cycle['cycle_id']
        start = cycle['starttime']
        end = cycle['endtime']
        cycle_data = df[(df['运行时间']>=start-pd.Timedelta(minutes=1)) & (df['运行时间'] <= end)]
        temp = cycle_data[col_name].rolling(window=smooth_window,min_periods=1,center=True).mean()
        df.loc[cycle_data.index,col_name] = temp
    return df

def data_presv(df,valid_cycle,col_names = ['推进速度','贯入度' ,'刀盘扭矩'],outlier_window=20,smooth_window = 15):
    #调用上面的两个函数
    for col in col_names:
        df = handle_outlier_data(df,valid_cycle,col,outlier_window)
        df = smooth_data(df,valid_cycle,col,smooth_window)
    return df

In [20]:
def classify_excavation_phase(df,cycle_list,T_threshold = 0,Tf_threshold = 600,Ff_threshold = 1000,sigma_threshold = 100,duration_threshold = 1000,v_set_rate = 0.1):
    #T_threshold = 0
    #Tf_threshold = 600
    #Ff_threshold = 1000
    #duration_threshold = 60
    phase_list = []
    cycle_list = pd.DataFrame(cycle_list)
    for index,cycle in cycle_list.iterrows():
        cycle_id = cycle['cycle_id']
        start = cycle['starttime']
        end = cycle['endtime']
        cycle_data = df[(df['运行时间']>=start-pd.Timedelta(minutes=1)) & (df['运行时间'] <= end)].copy()

        #空推段
        push_start = cycle_data[cycle_data['刀盘扭矩']>T_threshold].index
        push_start_time = cycle_data.at[push_start[0],'运行时间'] if not push_start.empty else None

        #上升段
        rising_start = cycle_data[(cycle_data['刀盘扭矩'] > Tf_threshold) & (cycle_data['总推进力'] > Ff_threshold)].index
        rising_start_time = cycle_data.at[rising_start[0],'运行时间'] if not rising_start.empty else None

        #稳定段
        #stable_phase_data = cycle_data[cycle_data['运行时间'] > rising_start_time].copy()
        cycle_data['v_std'] = cycle_data['推进速度'].rolling(window = 60,min_periods=1).std()
        cycle_data['v_set_std'] = cycle_data['推进速度电位器设定值'].rolling(window = 30,min_periods=1).std()
        cycle_data['v_set_m'] = cycle_data['推进速度电位器设定值'].rolling(window = 60,min_periods=1).mean()

        #动态阈值设定，防止返回none
        #sigma_threshold = cycle_data['v_set_std'].mean()*sigma_threshold
        v_set_threshold = cycle_data['v_set_m']*v_set_rate

        #条件为设定10%的阈值+利用设定推进速度的sigma来判断
        conditions = (cycle_data['v_set_std']<=sigma_threshold)&(cycle_data['推进速度电位器设定值']>cycle_data['v_set_std']*10)
        #conditions = (cycle_data['v_set_std']<=sigma_threshold)(cycle_data['推进速度']>v_set_threshold)

        cycle_data['is_stable_candidate'] = conditions
        #检查是否有duration_threshold的持续稳定，文中设定为60s
        stable_candidates = cycle_data[(cycle_data['is_stable_candidate']) & (cycle_data['运行时间']>rising_start_time+pd.Timedelta(minutes=1))].copy()
        #stable_candidates = cycle_data[cycle_data['v_set_std'] <= sigma_threshold].copy()
        stable_candidates['diff'] = stable_candidates['运行时间'].diff().dt.total_seconds()
        stable_candidates['group'] = (stable_candidates['diff'] > 3).cumsum()
        stable_groups = stable_candidates.groupby('group')

        stable_start_time = rising_start_time
        for group_id, group in stable_groups:
            duration = (group['运行时间'].iloc[-1] - group['运行时间'].iloc[0]).total_seconds()
            if duration >= duration_threshold:
                stable_start_time = group['运行时间'].iloc[0]
                break
        #下降段
        falling_start = cycle_data[(cycle_data['运行时间'] > stable_start_time)&(cycle_data['刀盘扭矩'] <= Tf_threshold) &(cycle_data['总推进力'] <= Ff_threshold)].index

        #falling_start = cycle_data[(cycle_data['刀盘扭矩'] <= Tf_threshold) &(cycle_data['总推进力'] <= Ff_threshold)].index
        falling_start_time = cycle_data.at[falling_start[0], '运行时间'] if not falling_start.empty else None

        phases = {
                'cycle_id': cycle_id,
                'push_start': push_start_time,
                'rising_start': rising_start_time,
                'stable_start': stable_start_time,
                'falling_start': falling_start_time
            }
        phase_list.append(phases)


    return phase_list


In [23]:
def plot_cycle_phases(df, phase_list, tbm_param='刀盘扭矩', buffer_minutes=0):
    '''
    绘制每个有效循环的图像，并标记各阶段起点
    参数:
        df: DataFrame，包含原始数据。
        phase_list: 阶段划分后的列表，每个元素包含 cycle_id 和各阶段的起点时间。
        tbm_param: 要绘制的参数列名（如'刀盘扭矩'、'推进速度'）。
        buffer_minutes: 缓冲区时间（分钟）。
    '''
    for phase in phase_list:
        cycle_id = phase['cycle_id']
        push_start = phase['push_start']
        rising_start = phase['rising_start']
        stable_start = phase['stable_start']
        falling_start = phase['falling_start']

        # 定义循环的开始和结束时间，包含缓冲区前1分钟
        # 假设 cycle['starttime'] 和 cycle['endtime'] 存在
        # 但在 phase_list 中只有 'push_start', 'rising_start' 等
        # 所以需要根据具体情况调整
        # 这里假设 'push_start' 是循环的开始
        start_time = push_start - pd.Timedelta(minutes=buffer_minutes) if push_start else None
        # 结束时间可以根据数据范围或其他逻辑确定
        # 这里假设循环结束时间为 'falling_start' 后的某个时间点
        end_time = falling_start + pd.Timedelta(minutes=buffer_minutes) if falling_start else None

        # 提取循环数据
        if start_time and end_time:
            cycle_data = df[(df['运行时间'] >= start_time) & (df['运行时间'] <= end_time)].copy()
        elif start_time:
            cycle_data = df[df['运行时间'] >= start_time].copy()
        elif end_time:
            cycle_data = df[df['运行时间'] <= end_time].copy()
        else:
            cycle_data = df.copy()

        # 检查是否有数据
        if cycle_data.empty:
            print(f"Cycle {cycle_id} has no data in the specified time range.")
            continue

        # 绘制参数曲线
        plt.figure(figsize=(12, 6),dpi=200)
        plt.plot(cycle_data['运行时间'], cycle_data[tbm_param], label=f'Cycle {cycle_id} {tbm_param}')

        # 标记各阶段起点
        if push_start:
            plt.axvline(push_start, color='green', linestyle='--', label='空推段起点')
        if rising_start:
            plt.axvline(rising_start, color='orange', linestyle='--', label='上升段起点')
        if stable_start:
            plt.axvline(stable_start, color='blue', linestyle='--', label='稳定段起点')
        if falling_start:
            plt.axvline(falling_start, color='red', linestyle='--', label='下降段起点')

        plt.xlabel('运行时间')
        plt.ylabel(tbm_param)
        plt.title(f'{tbm_param} - 掘进时间关系图 (Cycle {cycle_id})')
        plt.legend()
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()

In [24]:
import shutil

In [25]:
def unzip_file(zip_path = '/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data' ,unzip_path = '/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/un_zip',type = '.zip'):
    #确保输出文件夹存在
    chdir = os.chdir(zip_path)
    if not os.path.exists(unzip_path):
        try:
            os.mkdir(unzip_path)
        except Exception as e:
            print(f'无法创建目标文件夹{unzip_path}')
    for file in os.listdir(zip_path):
        if file.endswith(type):
            shutil.unpack_archive(file,unzip_path)

In [26]:
unzip_file()

In [27]:
print(os.listdir('/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/un_zip'))

['CREC188_20150713.txt', 'CREC188_20150707.txt', 'CREC188_20150712.txt', 'CREC188_20150710.txt', 'CREC188_20170629.txt', 'CREC188_20150711.txt', 'CREC188_20150708.txt', 'CREC188_20150709.txt']


In [28]:
list_of_bc = os.listdir('/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/un_zip')

In [29]:
list_of_bc

['CREC188_20150713.txt',
 'CREC188_20150707.txt',
 'CREC188_20150712.txt',
 'CREC188_20150710.txt',
 'CREC188_20170629.txt',
 'CREC188_20150711.txt',
 'CREC188_20150708.txt',
 'CREC188_20150709.txt']

In [30]:
import os
import pandas as pd

def save_process(df, phase_list, output_folder, start_cycle_id=1):
    '''
    读取TBM数据文件，并对每个循环重新标注 cycle_id（保证全局连续）
    参数:
        df：包含了提取循环后的所有数据
        phase_list：记录掘进分段节点时间的列表，
                    例如：[{'cycle_id': 2, 'push_start': Timestamp(...),
                             'rising_start': Timestamp(...),
                             'stable_start': Timestamp(...),
                             'falling_start': Timestamp(...)}]
        output_folder: 文件输出路径
        start_cycle_id: 本次处理的起始 cycle_id
    返回:
        combined_df: 处理后的DataFrame
        next_cycle_id: 下一个可用的 cycle_id（即全局连续计数器的最新值）
    '''
    all_cycle_data = []
    current_cycle_id = start_cycle_id  # 使用传入的初始值

    for phase in phase_list:
        # 使用全局的 current_cycle_id 而非 phase['cycle_id']
        cycle_id = current_cycle_id
        current_cycle_id += 1

        rising_start = phase['rising_start']
        stable_start = phase['stable_start']
        falling_start = phase['falling_start']

        # 标记空推段，只需要用上升段预测稳定段
        rising_data = df[(df['运行时间'] > rising_start) & (df['运行时间'] < stable_start)].copy()
        stable_data = df[(df['运行时间'] > stable_start) & (df['运行时间'] < falling_start)].copy()

        rising_data['cycle_id'] = cycle_id
        stable_data['cycle_id'] = cycle_id
        rising_data['phase_label'] = 'rising'
        stable_data['phase_label'] = 'stable'

        all_cycle_data.append(rising_data)
        all_cycle_data.append(stable_data)

    os.makedirs(output_folder, exist_ok=True)
    basename = 'data'
    output_path = os.path.join(output_folder, f'{basename}_process.csv')

    if all_cycle_data:
        combined_df = pd.concat(all_cycle_data, ignore_index=True)
        # 判断输出文件是否存在，存在则追加，否则写入表头
        if not os.path.exists(output_path):
            combined_df.to_csv(path_or_buf=output_path, index=False, mode='w', header=True)
        else:
            combined_df.to_csv(path_or_buf=output_path, index=False, mode='a', header=False)
        return combined_df, current_cycle_id
    else:
        return None, start_cycle_id

In [None]:
def save_process4test(df,phase_list,output_folder):
    '''
    读取TBM数据文件
    参数:
        output_folder: 文件输出路径
    返回:
        df: 处理后的DataFrame
    '''
    all_cycle_data = []
    for phase in phase_list:
        push_start = phase['push_start']
        rising_start = phase['rising_start']
        stable_start = phase['stable_start']
        falling_start = phase['falling_start']

        buffer_minutes = 0
        start_time = push_start - pd.Timedelta(minutes=buffer_minutes) if push_start else None
        # 结束时间可以根据数据范围或其他逻辑确定
        # 这里假设循环结束时间为 'falling_start' 后的某个时间点
        end_time = falling_start + pd.Timedelta(minutes=buffer_minutes) if falling_start else None

        #标记空推段，只需要用上升段预测稳定段
        push_data = df[(df['运行时间']>start_time)&(df['运行时间']<end_time)]

        all_cycle_data.append(push_data)
    os.makedirs(output_folder,exist_ok = True)
    basename = 'data'
    output_path = os.path.join(output_folder,f'{basename}_process.csv')
    if all_cycle_data:
        combined_df = pd.concat(all_cycle_data,ignore_index=True)
        #确保有输出文件夹
        if not os.path.exists(output_folder):
            combined_df.to_csv(path_or_buf=output_path,index = False,mode='w')
        else:
            #文件已经存在跳过header
            combined_df.to_csv(path_or_buf=output_path,index = False,mode='a')

In [31]:
data = []
global_cycle_id = 1 #设置全局cycle_id值
input_folder = '/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/un_zip'
output_folder = '/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/processed'
for file in list_of_bc:
    df = read_tbm_data('/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/un_zip/'+file)
    #取出这天的循环列表，方便我后面用手段对数据进行划分
    cycle_list = tbm_cycle_judge(df)
    df = data_presv(df,cycle_list,col_names=['推进速度','贯入度','刀盘扭矩','刀盘转速','刀盘给定转速显示值','推进速度给定百分比'],outlier_window = 5,smooth_window = 5)
    phase_list = classify_excavation_phase(df,cycle_list,T_threshold = 0,Tf_threshold = 600,Ff_threshold = 1000,sigma_threshold = 20,duration_threshold = 60,v_set_rate = 0)
    processed_df,global_cycle_id = save_process(df,phase_list,output_folder,start_cycle_id=global_cycle_id)
    data += phase_list

    #plt.figure(figsize=(12, 6),dpi=200)
    #plt.plot(df['运行时间'], df['刀盘转速'])

  df = pd.read_csv(file_path, sep='\t',index_col=False)
  df = pd.read_csv(file_path, sep='\t',index_col=False)
  df = pd.read_csv(file_path, sep='\t',index_col=False)
  df = pd.read_csv(file_path, sep='\t',index_col=False)
  df = pd.read_csv(file_path, sep='\t',index_col=False)
  df = pd.read_csv(file_path, sep='\t',index_col=False)
  df = pd.read_csv(file_path, sep='\t',index_col=False)
  df = pd.read_csv(file_path, sep='\t',index_col=False)


In [None]:
data = []
input_folder = '/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/un_zip'
output_folder = '/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/processed'
for file in list_of_bc:
    df = read_tbm_data('/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/un_zip/'+file)
    #取出这天的循环列表，方便我后面用手段对数据进行划分
    cycle_list = tbm_cycle_judge(df)
    df = data_presv(df,cycle_list,col_names=['推进速度','贯入度','刀盘扭矩','刀盘转速','刀盘给定转速显示值',''],outlier_window = 5,smooth_window = 5)
    #save_process(df,file,output_folder,)
    phase_list = classify_excavation_phase(df,cycle_list,T_threshold = 0,Tf_threshold = 600,Ff_threshold = 1000,sigma_threshold = 20,duration_threshold = 180,v_set_rate = 0)
    save_process4test(df,phase_list,output_folder)
    data += phase_list
    #批量画图用以修改阈值
    plot_cycle_phases(df, phase_list, tbm_param='刀盘扭矩', buffer_minutes=0)
    plt.figure(figsize=(12, 6),dpi=200)
    #plt.plot(df['运行时间'], df['刀盘转速'])

In [6]:
fp = '/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/processed/data_process'
df = pd.read_csv(fp, sep=',',index_col=False)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/xudongzuo/Library/CloudStorage/OneDrive-个人/文档/workspace/test_data/processed/data_process'

In [None]:
df

In [None]:
data

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 假设我们使用上升段的时间序列数据作为输入
# 定义序列长度（时间步数）
time_steps = 60  # 例如，60秒的序列

# 提取上升段序列数据
# 假设上升段的时间范围内有足够的数据点
input_features = ['刀盘转速', '推进速度', '总推进力', '刀盘扭矩', '贯入度']
output_features = ['avg_stable_torque', 'avg_stable_thrust']

# 创建输入和输出列表
X = []
y = []

for phase in data:
    cycle_id = phase['cycle_id']
    rising_start = phase['rising_start']
    stable_start = phase['stable_start']
    falling_start = phase['falling_start']

    # 提取上升段数据
    rising_data = df[(df['运行时间'] >= rising_start) & (df['运行时间'] < stable_start)]

    # 确保有足够的时间步
    if len(rising_data) < time_steps:
        continue  # 或者使用填充
    else:
        # 滚动窗口生成序列
        for i in range(len(rising_data) - time_steps + 1):
            seq = rising_data.iloc[i:i+time_steps][input_features].values
            X.append(seq)

            # 对应的稳定段参数
            stable_data = df[(df['运行时间'] >= stable_start) & (df['运行时间'] < falling_start)]
            if not stable_data.empty:
                avg_stable_torque = stable_data['刀盘扭矩'].mean()
                avg_stable_thrust = stable_data['总推进力'].mean()
                y.append([avg_stable_torque, avg_stable_thrust])
            else:
                y.append([np.nan, np.nan])  # 处理缺失值

# 转换为numpy数组
X = np.array(X)
y = np.array(y)

# 移除包含NaN的样本
valid_indices = ~np.isnan(y).any(axis=1)
X = X[valid_indices]
y = y[valid_indices]

print(f'Input shape: {X.shape}')
print(f'Output shape: {y.shape}')

In [None]:
import scipy.io as sio
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
from torch.autograd import Variable
import math
import csv

# Define LSTM Neural Networks
class LstmRNN(nn.Module):
    """
        Parameters：
        - input_size: feature size
        - hidden_size: number of hidden units
        - output_size: number of output
        - num_layers: layers of LSTM to stack
    """

    def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):
        super().__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # utilize the LSTM model in torch.nn
        self.linear1 = nn.Linear(hidden_size, output_size) # 全连接层

    def forward(self, _x):
        x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)
        s, b, h = x.shape  # x is output, size (seq_len, batch, hidden_size)
        x = self.linear1(x)
        return x[-1, :, :]

if __name__ == '__main__':

    # checking if GPU is available
    device = torch.device("cpu")

    if (torch.cuda.is_available()):
        device = torch.device("cuda:0")
        print('Training on GPU.')
    else:
        print('No GPU available, training on CPU.')

    # 数据读取&类型转换
    data_x = np.array(pd.read_csv('Data_x.csv', header=None)).astype('float32')
    data_y = np.array(pd.read_csv('Data_y.csv', header=None)).astype('float32')

    # 数据集分割
    data_len = len(data_x)
    t = np.linspace(0, data_len, data_len + 1)

    train_data_ratio = 0.8  # Choose 80% of the data for training
    train_data_len = int(data_len * train_data_ratio)

    train_x = data_x[5:train_data_len]
    train_y = data_y[5:train_data_len]
    t_for_training = t[5:train_data_len]

    test_x = data_x[train_data_len:]
    test_y = data_y[train_data_len:]
    t_for_testing = t[train_data_len:]

    # ----------------- train -------------------
    INPUT_FEATURES_NUM = 1
    OUTPUT_FEATURES_NUM = 1
    train_x_tensor = train_x.reshape(5, -1, INPUT_FEATURES_NUM)
    train_y_tensor = train_y.reshape(1, OUTPUT_FEATURES_NUM)
    # transfer data to pytorch tensor
    train_x_tensor = torch.from_numpy(train_x_tensor)
    train_y_tensor = torch.from_numpy(train_y_tensor)

    lstm_model = LstmRNN(INPUT_FEATURES_NUM, 20, output_size=OUTPUT_FEATURES_NUM, num_layers=1)  # 20 hidden units
    print('LSTM model:', lstm_model)
    print('model.parameters:', lstm_model.parameters)
    print('train x tensor dimension:', Variable(train_x_tensor).size())

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(lstm_model.parameters(), lr=1e-2)

    prev_loss = 1000
    max_epochs = 2000

    train_x_tensor = train_x_tensor.to(device)

    for epoch in range(max_epochs):
        output = lstm_model(train_x_tensor).to(device)
        loss = criterion(output, train_y_tensor)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if loss < prev_loss:
            torch.save(lstm_model.state_dict(), 'lstm_model.pt')  # save model parameters to files
            prev_loss = loss

        if loss.item() < 1e-4:
            print('Epoch [{}/{}], Loss: {:.5f}'.format(epoch + 1, max_epochs, loss.item()))
            print("The loss value is reached")
            break
        elif (epoch + 1) % 100 == 0:
            print('Epoch: [{}/{}], Loss:{:.5f}'.format(epoch + 1, max_epochs, loss.item()))

    # prediction on training dataset
    pred_y_for_train = lstm_model(train_x_tensor).to(device)
    pred_y_for_train = pred_y_for_train.view(-1, OUTPUT_FEATURES_NUM).data.numpy()

    # ----------------- test -------------------
    lstm_model = lstm_model.eval()  # switch to testing model

    # prediction on test dataset
    test_x_tensor = test_x.reshape(5, -1, INPUT_FEATURES_NUM)
    test_x_tensor = torch.from_numpy(test_x_tensor)  # 变为tensor
    test_x_tensor = test_x_tensor.to(device)

    pred_y_for_test = lstm_model(test_x_tensor).to(device)
    pred_y_for_test = pred_y_for_test.view(-1, OUTPUT_FEATURES_NUM).data.numpy()

    loss = criterion(torch.from_numpy(pred_y_for_test), torch.from_numpy(test_y))
    print("test loss：", loss.item())

    # ----------------- plot -------------------
    plt.figure()
    plt.plot(t_for_training, train_y, 'b', label='y_trn')
    plt.plot(t_for_training, pred_y_for_train, 'y--', label='pre_trn')

    plt.plot(t_for_testing, test_y, 'k', label='y_tst')
    plt.plot(t_for_testing, pred_y_for_test, 'm--', label='pre_tst')

    plt.xlabel('t')
    plt.ylabel('Vce')
    plt.show()