### Get the E.coli Genome list 

In [135]:
from Bio import SeqIO

# 定义读取FASTA文件并返回基因-序列字典的函数
def read_fasta(file_path):
    gene_sequence_dict = {}
    
    # 使用SeqIO解析FASTA文件
    for record in SeqIO.parse(file_path, "fasta"):
        # 获取基因名称（FASTA文件的标头部分）
        gene_name = record.id
        # 获取蛋白质序列（序列部分）
        sequence = str(record.seq)
        # 将基因-序列的键值对添加到字典中
        gene_sequence_dict[gene_name] = sequence

    return gene_sequence_dict

# 读取FASTA文件并生成字典
file_path = "/data1/xpgeng/cross_pathogen/autoencoder/E.coli.tag_seq.fasta"  # 请替换为你的FASTA文件路径
gene_sequence_dict = read_fasta(file_path)

# 打印字典的前几个项以确认
for gene, sequence in list(gene_sequence_dict.items())[:5]:
    print(f"Gene: {gene}, Sequence: {sequence[:30]}...")  # 只打印前30个氨基酸

Gene: b0001, Sequence: MKRISTTITTTITITTGNGAG...
Gene: b0002, Sequence: MRVLKFGGTSVANAERFLRVADILESNARQ...
Gene: b0003, Sequence: MVKVYAPASSANMSVGFDVLGAAVTPVDGA...
Gene: b0004, Sequence: MKLYNLKDHNEQVSFAQAVTQGLGKNQGLF...
Gene: b0005, Sequence: MKKMQSIVLALSLVLVAPMAAQAAEITLVP...


In [136]:
all_genes = set(list(gene_sequence_dict.keys()))

In [137]:
import numpy as np
from collections import Counter
from sklearn.decomposition import PCA

# 定义所有可能的2-mer组合
standard_amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
all_2mers = [a + b for a in standard_amino_acids for b in standard_amino_acids]
two_mer_index = {two_mer: idx for idx, two_mer in enumerate(all_2mers)}

# 生成基因2-mer特征字典
two_mer_dict = {}

# 遍历每个基因和序列
for gene, sequence in gene_sequence_dict.items():
    # 清洗序列，移除非标准氨基酸字符
    sequence = ''.join([aa for aa in sequence if aa in standard_amino_acids])
    
    # 计算2-mer出现次数
    two_mer_counts = Counter([sequence[i:i+2] for i in range(len(sequence)-1)])
    
    # 计算2-mer的总数
    total_two_mers = sum(two_mer_counts.values())
    
    # 初始化400维的零向量
    feature_vector = np.zeros(400)
    
    # 将2-mer的频率映射到向量的对应位置
    for two_mer, count in two_mer_counts.items():
        if two_mer in two_mer_index:
            # 计算频率而不是计数
            frequency = count / total_two_mers
            #frequency = count
            feature_vector[two_mer_index[two_mer]] = frequency
            
    # 将计算的特征向量保存到字典中
    two_mer_dict[gene] = feature_vector

# 打印前5个基因的2-mer特征查看
for gene, feature_vector in list(two_mer_dict.items())[:5]:
    print(f"Gene: {gene}")
    print(f"2-mer Feature Vector: {feature_vector}")


Gene: b0001
2-mer Feature Vector: [0.   0.   0.   0.   0.   0.05 0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.05 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.05
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.05 0.15 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.05 0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0

In [37]:
len(two_mer_dict)

4305

### test file 

In [66]:
import pandas as pd

# 输入文件路径
input_file = '/data1/xpgeng/cross_pathogen/MLP/iML1515_test.csv'

# 输出文件路径
output_file = 'all.csv'

# 读取 CSV 文件，跳过第一行
df = pd.read_csv(input_file, skiprows=1)

# 提取前三列基因和最后一列 Label
df_filtered = df.iloc[:, [0, 1, 2, -1]]

# 过滤掉包含指定字符串的行
unwanted_strings = {'b2092', 'Fail', 'b4104'}
df_filtered = df_filtered[~df_filtered.iloc[:, :3].apply(lambda row: any(gene in unwanted_strings for gene in row), axis=1)]

# 将处理后的数据保存为新的 CSV 文件
df_filtered.to_csv(output_file, index=False)

print(f"Data has been processed and saved to {output_file}")


Data has been processed and saved to all.csv


In [67]:
import pandas as pd
import os

# 读取最终合并的 CSV 文件
input_file = 'all.csv'
output_folder = '/data1/xpgeng/cross_pathogen/MLP'

# 确保输出文件夹存在
os.makedirs(output_folder, exist_ok=True)

# 读取 CSV 文件，假设没有表头
df = pd.read_csv(input_file, header=None)

# 打乱数据
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# 计算每份数据的大小
num_splits = 10
split_size = len(df) // num_splits

# 划分数据并保存
for i in range(num_splits):
    # 为每一份计算开始和结束的索引
    start_idx = i * split_size
    # 最后一份可能包含剩余所有数据
    end_idx = (i + 1) * split_size if i < num_splits - 1 else len(df)
    
    # 获取数据
    df_split = df.iloc[start_idx:end_idx]

    # 输出文件名
    output_file = os.path.join(output_folder, f'iML1515-{i+1}.csv')
    
    # 保存数据到文件
    df_split.to_csv(output_file, index=False, header=False)

print(f"Data successfully split into {num_splits} parts and saved in {output_folder}")


Data successfully split into 10 parts and saved in /data1/xpgeng/cross_pathogen/MLP


In [68]:
import pandas as pd
import os

# 文件夹路径
folder_path = '/data1/xpgeng/cross_pathogen/MLP/'

# 初始化行数总和
total_rows = 0

# 遍历文件夹中的所有 CSV 文件
for filename in os.listdir(folder_path):
    if filename.endswith('.csv'):  # 筛选 .csv 文件
        file_path = os.path.join(folder_path, filename)
        
        # 读取 CSV 文件
        df = pd.read_csv(file_path, header=None)  # 如果没有表头，使用 header=None
        
        # 计算当前文件的行数
        num_rows = len(df)
        
        # 更新总行数
        total_rows += num_rows
        
        # 输出文件名和前 3 行数据
        print(f"First 3 rows of {filename}:")
        print(df.head(3))  # 显示前 3 行
        print(f"Total rows in {filename}: {num_rows}")
        print("\n" + "="*50 + "\n")

# 输出所有文件的总行数
print(f"Total rows across all files: {total_rows}")

  df = pd.read_csv(file_path, header=None)  # 如果没有表头，使用 header=None


First 3 rows of iML1515_test.csv:
       0      1      2               3               4      5
0   Gene  Gene1  Gene2  Initial Growth  Altered Growth  Label
1  b0070  b0071  b0072           0.877             0.0      1
2  b0070  b0071  b0073           0.877             0.0      1
Total rows in iML1515_test.csv: 1100387


First 3 rows of all.csv:
       0      1      2  3
0  b0070  b0071  b0072  1
1  b0070  b0071  b0073  1
2  b0070  b0071  b0074  1
Total rows in all.csv: 1097421


First 3 rows of iML1515-1.csv:
       0      1      2  3
0  b0070  b0857  b3480  0
1  b0070  b2521  b4055  0
2  b0070  b0652  b3450  0
Total rows in iML1515-1.csv: 109742


First 3 rows of iML1515-2.csv:
       0      1      2  3
0  b0070  b0914  b2539  1
1  b0070  b1798  b4266  0
2  b0070  b1686  b2285  0
Total rows in iML1515-2.csv: 109742


First 3 rows of iML1515-3.csv:
       0      1      2  3
0  b0070  b0674  b2222  0
1  b0070  b1448  b2480  0
2  b0070  b0446  b3994  1
Total rows in iML1515-3.csv: 1097

#### Define autoencoder to process genome 

In [138]:
import torch
import numpy as np

# Define the model
class Autoencoder(torch.nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(400, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.35),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.35),
            torch.nn.Linear(128, 3),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(3, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 400),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

class Autoencoder2(torch.nn.Module):
    def __init__(self):
        super(Autoencoder2, self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(4304, 3000),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(3000, 1000),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(1000, 400),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(400, 1000),
            torch.nn.ReLU(),
            torch.nn.Linear(1000, 3000),
            torch.nn.ReLU(),
            torch.nn.Linear(3000, 4304),
        )
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# Load the models once
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model = Autoencoder().to(device)
model.load_state_dict(torch.load('/data1/xpgeng/cross_pathogen/autoencoder/ae1_all_data_training.pth'))
model.eval()

model2 = Autoencoder2().to(device)
model2.load_state_dict(torch.load('/data1/xpgeng/cross_pathogen/autoencoder/ae2_all_data_training.pth'))
model2.eval()

def ae1_2(three_genes):
    rest_genes = list(all_genes - three_genes)
    inputs = np.vstack([two_mer_dict[gene] for gene in rest_genes]).astype(np.float32)

    # Create two rows of zeros and append them to the inputs
    zeros_400 = np.zeros((2, 400), dtype=np.float32)
    inputs = np.vstack([inputs, zeros_400])

    inputs = torch.tensor(inputs).to(device)
    inputs = model.encoder(inputs)  # autoencoder1
    inputs = inputs.cpu().detach().numpy().T  # transpose for autoencoder2 input
    inputs = torch.tensor(inputs).to(device)
    outputs = model2.encoder(inputs)  # autoencoder2 [3, 400] shape

    return outputs

In [96]:
ae1_2({'b4358','b0002','b4366'})

tensor([[-16.4019, -19.3173,  17.1869,  ...,  11.3169,  39.8056,  -7.2701],
        [-11.1142, -13.1415,  11.7600,  ...,   6.9131,  26.2965,  -5.5048],
        [ -9.2013, -10.5971,   9.5181,  ...,   6.2495,  21.4457,  -4.2572]],
       device='cuda:2', grad_fn=<AddmmBackward0>)

In [99]:
three_gene_features = np.array([two_mer_dict[gene] for gene in ['b0001','b4358','b4366']])
three_gene_features = three_gene_features.flatten()
three_gene_features = (three_gene_features - np.mean(three_gene_features)) / (np.std(three_gene_features) + 1e-8)  # 避免除以 0
print(three_gene_features)

rest_gene_features = ae1_2({'b0001','b4358','b4366'}).detach().cpu().numpy()
rest_gene_features = rest_gene_features.flatten()
rest_gene_features = (rest_gene_features - np.mean(rest_gene_features)) / (np.std(rest_gene_features) + 1e-8)  # 归一化，避免除 0
#rest_gene_features *= 0.2
print(rest_gene_features)

normalized_features = np.concatenate([three_gene_features, rest_gene_features])

[-0.25292426 -0.25292426 -0.25292426 ... -0.25292426 -0.25292426
 -0.25292426]
[-1.2737218  -1.5077536   1.4229503  ...  0.50894725  1.6286922
 -0.27716023]


### MLP model 10 fold Cross-validation

In [140]:
#%%time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import os
import traceback
from tqdm import tqdm
import time

# Check if CUDA device is available
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
log_file = os.path.join(os.getcwd(), "mlp_cross_training_log_10p.txt") ###########################
error_log_file = os.path.join(os.getcwd(), "mlp_cross_error_log_10p.txt") ###########################

# Redirect print to a file
def log_print(message):
    with open(log_file, "a") as f:
        f.write(message + "\n")
    print(message)

# Log error messages to a file
def log_error(message):
    with open(error_log_file, "a") as f:
        f.write(message + "\n")
    print(f"Error logged: {message}")  # Debugging line to ensure errors are logged
    
# Function to load and update del_twogenes files in batches
def load_three_genes(file_idx, batch_size):
    file_path = f'/data1/xpgeng/cross_pathogen/FBA/iML1515_parts/iML1515-{file_idx}.csv' ###########################
    #file_path = f'/data1/xpgeng/cross_pathogen/MLP/iML1515-{file_idx}.csv' ###########################
    df = pd.read_csv(file_path, header=None)
    
    # Split file into batches of size batch_size
    batches = [df.iloc[i:i + batch_size].values for i in range(0, len(df), batch_size)]
    
    return batches  # each batch contains three genes and a label

# Function to compute AUC
def compute_auc(predicted, labels):
    from sklearn.metrics import roc_auc_score
    return roc_auc_score(labels.cpu(), predicted.cpu())

# Define MLP model class
class MLP(nn.Module):
    def __init__(self, input_size=2400, hidden_size1=512, hidden_size2=256, output_size=1, dropout_rate=0.1):  #######
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(hidden_size2, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return self.sigmoid(x)

def train_mlp(train_data, train_labels, model, criterion, optimizer, num_epochs=4):  
    model.train()
    epoch_loss = 0.0

    optimizer.zero_grad()
    outputs = model(train_data)
    loss = criterion(outputs.view(-1), train_labels)
    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()

    log_print(f"Loss: {epoch_loss / len(train_data):.6f}")  # Average loss over the epoch
    return model

In [45]:
# data preparation
import os
import pickle
import numpy as np
import torch
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor

# 已有函数 two_mer_dict 和 ae1_2 可用
# two_mer_dict[gene] 返回基因的 400 维向量
# ae1_2({gene1, gene2, gene3}) 返回 (3, 400) 维的张量

def process_and_save(file_idx):
    data = load_three_genes(file_idx, batch_size=60000)  # 读取文件，按 batch_size 分批

    for batch_idx, batch in enumerate(data)
        batch_data = []
        batch_labels = []
        
        for row in tqdm(batch, desc=f"Processing Fold {file_idx} - Batch {batch_idx+1}"):
            # 获取 three_gene_features 并进行归一化
            three_gene_features = np.array([two_mer_dict[gene] for gene in [row[0], row[1], row[2]]])
            three_gene_features = three_gene_features.flatten()
            three_gene_features = (three_gene_features - np.mean(three_gene_features)) / (np.std(three_gene_features) + 1e-8)  # 避免除以 0
                        
            rest_gene_features = ae1_2({row[0], row[1], row[2]}).detach().cpu().numpy() 
            rest_gene_features = rest_gene_features.flatten()
            rest_gene_features = (rest_gene_features - np.mean(rest_gene_features)) / (np.std(rest_gene_features) + 1e-8)  # 归一化，避免除 0
            #rest_gene_features *= 0.2

            normalized_features = np.concatenate([three_gene_features, rest_gene_features])

            batch_data.append(normalized_features)
            #batch_data.append(three_gene_features)
            batch_labels.append(row[3])  # 最后一列是 Label

         # 存储为 pickle 文件
        output_file = os.path.join("/data2/xpgeng/iML1515_MLP/", f"{file_idx}_{batch_idx+1}.pkl")
        with open(output_file, "wb") as f:
            pickle.dump((np.array(batch_data), np.array(batch_labels)), f)

        print(f"Saved")
        
process_and_save(1)
# Loop this from 1 to 10 to get intermidiate files for mlp training


KeyboardInterrupt



In [101]:
import os
import pickle
import numpy as np
import torch
from tqdm import tqdm

def cross_validate():
    all_indices = list(range(1, 11))   ##########
    total_auc = []

    log_print(f"Start time: {time.time()}")

    for fold in range(10):      ###########
        try:
            fold_auc = []
            train_files = [i for i in all_indices if i != fold + 1]  # 训练集 9 份
            test_file = fold + 1  # 测试集 1 份

            # 创建 MLP 模型
            model = MLP(input_size=2400).to(device)
            criterion = nn.BCELoss()
            optimizer = optim.Adam(model.parameters(), lr=0.001)

            # 训练 9 份数据
            for file_idx in train_files:
                log_print(f"Loading train file {file_idx}")

                batch_idx = 1
                while True:
                    pickle_file = f"/data2/xpgeng/iML1515_MLP/{file_idx}_{batch_idx}.pkl"
                    if not os.path.exists(pickle_file):
                        break  # 没有更多批次数据

                    with open(pickle_file, "rb") as f:
                        batch_data, batch_labels = pickle.load(f)
                    
                    # 转换为张量
                    batch_data = torch.tensor(batch_data, dtype=torch.float32).to(device)
                    batch_labels = torch.tensor(batch_labels, dtype=torch.float32).to(device)
                    
                    # Split into smaller batches of size 200
                    for i in range(0, len(batch_data), 200):
                        small_batch_data = batch_data[i:i+200]
                        small_batch_labels = batch_labels[i:i+200]
                        
                        # Train the model on this smaller batch
                        model = train_mlp(small_batch_data, small_batch_labels, model, criterion, optimizer, num_epochs=4)

                    log_print(f"Trained on {file_idx} - Batch {batch_idx}")
                    batch_idx += 1

            # 读取测试数据
            log_print(f"Loading test file {test_file}")
            
            batch_idx = 1
            while True:
                pickle_file = f"/data2/xpgeng/iML1515_MLP/{file_idx}_{batch_idx}.pkl"
                if not os.path.exists(pickle_file):
                    break  # 没有更多批次数据

                with open(pickle_file, "rb") as f:
                    test_data, test_labels = pickle.load(f)

                # 转换为张量
                test_data = torch.tensor(np.array(test_data), dtype=torch.float32).to(device)
                test_labels = torch.tensor(np.array(test_labels), dtype=torch.float32).to(device)
                
                # 评估
                with torch.no_grad():
                    model.eval()
                    outputs = model(test_data)
                    probabilities = torch.sigmoid(outputs.view(-1))  
                    auc = compute_auc(probabilities, test_labels)
                    fold_auc.append(auc)
                    total_auc.append(auc)
                
                batch_idx += 1

            log_print(f"Fold {fold+1} AUC: {np.mean(fold_auc):.6f}")

        except Exception as e:
            error_message = f"Error in fold {fold+1}: {str(e)}\n{traceback.format_exc()}"
            log_error(error_message)

    log_print(f"Average AUC over 10 folds: {np.mean(total_auc):.6f}")
    log_print(f"End time: {time.time()}")

cross_validate()

Start time: 1747057279.0576444
Loading train file 2
Loss: 0.003517
Loss: 0.004471
Loss: 0.003148
Loss: 0.003382
Loss: 0.003313
Loss: 0.003288
Loss: 0.003364
Loss: 0.003316
Loss: 0.003148
Loss: 0.003080
Loss: 0.003091
Loss: 0.003160
Loss: 0.003604
Loss: 0.003307
Loss: 0.003166
Loss: 0.003206
Loss: 0.003253
Loss: 0.003155
Loss: 0.003132
Loss: 0.003127
Loss: 0.003185
Loss: 0.003253
Loss: 0.003022
Loss: 0.002997
Loss: 0.003036
Loss: 0.003141
Loss: 0.003185
Loss: 0.003071
Loss: 0.002922
Loss: 0.002905
Loss: 0.003147
Loss: 0.003082
Loss: 0.002858
Loss: 0.002983
Loss: 0.002993
Loss: 0.003068
Loss: 0.002958
Loss: 0.002975
Loss: 0.003251
Loss: 0.002890
Loss: 0.003225
Loss: 0.003107
Loss: 0.002988
Loss: 0.003126
Loss: 0.002932
Loss: 0.002900
Loss: 0.003109
Loss: 0.002991
Loss: 0.003015
Loss: 0.002794
Loss: 0.003020
Loss: 0.002908
Loss: 0.002762
Loss: 0.003346
Loss: 0.002948
Loss: 0.003010
Loss: 0.002778
Loss: 0.002828
Loss: 0.002532
Loss: 0.002907
Loss: 0.002860
Loss: 0.003017
Loss: 0.002734
Los

### MLP model train all data

In [142]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import os
import traceback
from tqdm import tqdm
import time

# Check if CUDA device is available
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
log_file = os.path.join(os.getcwd(), "mlp_all_training_log_10p.txt")          ############
error_log_file = os.path.join(os.getcwd(), "mlp_all_error_log_10p.txt")       ############

# Redirect print to a file
def log_print(message):
    with open(log_file, "a") as f:
        f.write(message + "\n")
    print(message)

# Log error messages to a file
def log_error(message):
    with open(error_log_file, "a") as f:
        f.write(message + "\n")
    print(f"Error logged: {message}")  # Debugging line to ensure errors are logged
    
# Function to load and update del_twogenes files in batches
def load_three_genes(file_idx, batch_size):
    file_path = f'/data1/xpgeng/cross_pathogen/FBA/iML1515_parts/iML1515-{file_idx}.csv' ###########################
    #file_path = f'/data1/xpgeng/cross_pathogen/MLP/iML1515-{file_idx}.csv' ###########################
    df = pd.read_csv(file_path, header=None)
    
    # Split file into batches of size batch_size
    batches = [df.iloc[i:i + batch_size].values for i in range(0, len(df), batch_size)]
    
    return batches  # each batch contains three genes and a label

# Function to compute AUC
def compute_auc(predicted, labels):
    from sklearn.metrics import roc_auc_score
    return roc_auc_score(labels.cpu(), predicted.cpu())

# Define MLP model class
class MLP(nn.Module):
    def __init__(self, input_size=2400, hidden_size1=512, hidden_size2=256, output_size=1, dropout_rate=0.1):  #######
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(hidden_size2, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return self.sigmoid(x)

def train_mlp(train_data, train_labels, model, criterion, optimizer, num_epochs=4):  
    model.train()
    epoch_loss = 0.0

    optimizer.zero_grad()
    outputs = model(train_data)
    loss = criterion(outputs.view(-1), train_labels)
    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()

    log_print(f"Loss: {epoch_loss / len(train_data):.6f}")  # Average loss over the epoch
    return model

In [None]:
def mlp_all():
    all_indices = list(range(1, 11))   ##########
    log_print(f"Start time: {time.time()}")

    for fold in range(1):      ###########
        try:
            train_files = all_indices
          
            # 创建 MLP 模型
            model = MLP(input_size=2400).to(device)
            criterion = nn.BCELoss()
            optimizer = optim.Adam(model.parameters(), lr=0.001)

            # train all data
            for file_idx in train_files:
                log_print(f"Loading train file {file_idx}")

                batch_idx = 1
                while True:
                    pickle_file = f"/data2/xpgeng/iML1515_MLP/{file_idx}_{batch_idx}.pkl"  ####
                    if not os.path.exists(pickle_file):
                        break  # 没有更多批次数据
                        
                    with open(pickle_file, "rb") as f:
                        batch_data, batch_labels = pickle.load(f)

                    # 转换为张量
                    batch_data = torch.tensor(batch_data, dtype=torch.float32).to(device)
                    batch_labels = torch.tensor(batch_labels, dtype=torch.float32).to(device)
                    
                    # Split into smaller batches of size 200
                    for i in range(0, len(batch_data), 200):
                        small_batch_data = batch_data[i:i+200]
                        small_batch_labels = batch_labels[i:i+200]
                        
                        # Train the model on this smaller batch
                        model = train_mlp(small_batch_data, small_batch_labels, model, criterion, optimizer, num_epochs=4)
                   
                    log_print(f"Trained on {file_idx} - Batch {batch_idx}")
                    batch_idx += 1
                    
            # Save model parameters
            model_save_path = os.path.join(os.getcwd(), f"mlp_all_data_10p.pth")   ### model name
            torch.save(model.state_dict(), model_save_path)

        except Exception as e:
            error_message = f"Error in fold {fold+1}: {str(e)}\n{traceback.format_exc()}"
            log_error(error_message)

    log_print(f"End time: {time.time()}")

mlp_all()

Start time: 1747639679.9935484
Loading train file 1
Loss: 0.003425
Loss: 0.005545
Loss: 0.002974
Loss: 0.003338
Loss: 0.003167
Loss: 0.003449
Loss: 0.003292
Loss: 0.003067
Loss: 0.003330
Loss: 0.003294
Loss: 0.003280
Loss: 0.003103
Loss: 0.003209
Loss: 0.003253
Loss: 0.003399
Loss: 0.003030
Loss: 0.003073
Loss: 0.003150
Loss: 0.003164
Loss: 0.003159
Loss: 0.003095
Loss: 0.003027
Loss: 0.003203
Loss: 0.003208
Loss: 0.003046
Loss: 0.003086
Loss: 0.003235
Loss: 0.003014
Loss: 0.003186
Loss: 0.003198
Loss: 0.003021
Loss: 0.003010
Loss: 0.003115
Loss: 0.002738
Loss: 0.003184
Loss: 0.002956
Loss: 0.002879
Loss: 0.002925
Loss: 0.002864
Loss: 0.003215
Loss: 0.002911
Loss: 0.002919
Loss: 0.002868
Loss: 0.002955
Loss: 0.002841
Loss: 0.003078
Loss: 0.002920
Loss: 0.003110
Loss: 0.002800
Loss: 0.002839
Loss: 0.003162
Loss: 0.002775
Loss: 0.002869
Loss: 0.002662
Loss: 0.002952
Loss: 0.002757
Loss: 0.002783
Loss: 0.002906
Loss: 0.002764
Loss: 0.002774
Loss: 0.002409
Loss: 0.002876
Loss: 0.002624
Los

Loss: 0.000268
Loss: 0.000092
Loss: 0.000271
Loss: 0.000288
Loss: 0.000266
Loss: 0.000228
Loss: 0.000171
Loss: 0.000086
Loss: 0.000311
Loss: 0.000143
Loss: 0.000428
Loss: 0.000234
Loss: 0.000398
Loss: 0.000199
Loss: 0.000237
Loss: 0.000095
Loss: 0.000180
Loss: 0.000102
Loss: 0.000210
Loss: 0.000107
Loss: 0.000170
Loss: 0.000143
Loss: 0.000067
Loss: 0.000268
Loss: 0.000219
Loss: 0.000294
Loss: 0.000261
Loss: 0.000236
Loss: 0.000156
Loss: 0.000152
Loss: 0.000077
Loss: 0.000113
Loss: 0.000300
Loss: 0.000289
Loss: 0.000282
Loss: 0.000077
Loss: 0.000205
Loss: 0.000040
Loss: 0.000201
Loss: 0.000344
Loss: 0.000224
Loss: 0.000118
Loss: 0.000108
Loss: 0.000057
Loss: 0.000144
Loss: 0.000229
Loss: 0.000354
Loss: 0.000206
Loss: 0.000087
Loss: 0.000094
Trained on 1 - Batch 36
Loss: 0.000158
Loss: 0.000264
Loss: 0.000119
Loss: 0.000159
Loss: 0.000288
Loss: 0.000237
Loss: 0.000086
Loss: 0.000161
Loss: 0.000161
Loss: 0.000204
Loss: 0.000136
Loss: 0.000143
Loss: 0.000250
Loss: 0.000487
Loss: 0.000243
L

Loss: 0.000186
Loss: 0.000451
Loss: 0.000221
Loss: 0.000279
Loss: 0.000256
Loss: 0.000141
Loss: 0.000061
Loss: 0.000166
Loss: 0.000080
Loss: 0.000204
Loss: 0.000195
Loss: 0.000382
Loss: 0.000240
Loss: 0.000175
Loss: 0.000376
Loss: 0.000204
Loss: 0.000512
Loss: 0.000047
Loss: 0.000068
Loss: 0.000315
Loss: 0.000102
Loss: 0.000094
Loss: 0.000138
Loss: 0.000214
Loss: 0.000058
Loss: 0.000531
Trained on 1 - Batch 38
Loss: 0.000070
Loss: 0.000193
Loss: 0.000090
Loss: 0.000154
Loss: 0.000128
Loss: 0.000254
Loss: 0.000181
Loss: 0.000744
Loss: 0.000372
Loss: 0.000170
Loss: 0.000376
Loss: 0.000240
Loss: 0.000266
Loss: 0.000302
Loss: 0.000400
Loss: 0.000290
Loss: 0.000336
Loss: 0.000232
Loss: 0.000261
Loss: 0.000198
Loss: 0.000313
Loss: 0.000193
Loss: 0.000121
Loss: 0.000101
Loss: 0.000176
Loss: 0.000267
Loss: 0.000541
Loss: 0.000031
Loss: 0.000146
Loss: 0.000158
Loss: 0.000245
Loss: 0.000278
Loss: 0.000063
Loss: 0.000409
Loss: 0.000282
Loss: 0.000610
Loss: 0.000150
Loss: 0.000287
Loss: 0.000295
L

Loss: 0.000045
Loss: 0.000349
Loss: 0.000098
Loss: 0.000253
Loss: 0.000392
Loss: 0.000164
Loss: 0.000163
Loss: 0.000365
Loss: 0.000085
Loss: 0.000094
Loss: 0.000140
Loss: 0.000224
Loss: 0.000272
Loss: 0.000100
Loss: 0.000169
Loss: 0.000143
Loss: 0.000056
Loss: 0.000193
Loss: 0.000145
Loss: 0.000120
Loss: 0.000145
Loss: 0.000157
Loss: 0.000156
Loss: 0.000156
Loss: 0.000569
Loss: 0.000133
Loss: 0.000090
Loss: 0.000150
Loss: 0.000030
Loss: 0.000274
Loss: 0.000476
Loss: 0.000132
Loss: 0.000048
Loss: 0.000082
Loss: 0.000084
Loss: 0.000223
Loss: 0.000256
Loss: 0.000123
Loss: 0.000037
Loss: 0.000329
Loss: 0.000123
Trained on 1 - Batch 40
Loss: 0.000102
Loss: 0.000125
Loss: 0.000190
Loss: 0.000339
Loss: 0.000374
Loss: 0.000019
Loss: 0.000125
Loss: 0.000122
Loss: 0.000381
Loss: 0.000318
Loss: 0.000112
Loss: 0.000208
Loss: 0.000241
Loss: 0.000305
Loss: 0.000256
Loss: 0.000070
Loss: 0.000315
Loss: 0.000052
Loss: 0.000330
Loss: 0.000172
Loss: 0.000313
Loss: 0.000115
Loss: 0.000373
Loss: 0.000321
L

Loss: 0.000038
Loss: 0.000294
Loss: 0.000175
Loss: 0.000364
Loss: 0.000178
Loss: 0.000228
Loss: 0.000331
Loss: 0.000338
Loss: 0.000358
Loss: 0.000233
Loss: 0.000125
Loss: 0.000114
Loss: 0.000486
Loss: 0.000111
Loss: 0.000269
Loss: 0.000081
Loss: 0.000266
Loss: 0.000111
Loss: 0.000485
Loss: 0.000060
Loss: 0.000209
Loss: 0.000291
Loss: 0.000083
Loss: 0.000304
Loss: 0.000410
Loss: 0.000214
Loss: 0.000347
Loss: 0.000354
Loss: 0.000253
Loss: 0.000126
Loss: 0.000152
Loss: 0.000578
Loss: 0.000222
Loss: 0.000239
Loss: 0.000322
Loss: 0.000408
Loss: 0.000177
Trained on 1 - Batch 42
Loss: 0.000240
Loss: 0.000194
Loss: 0.000298
Loss: 0.000258
Loss: 0.000163
Loss: 0.000178
Loss: 0.000138
Loss: 0.000102
Loss: 0.000289
Loss: 0.000484
Loss: 0.000164
Loss: 0.000472
Loss: 0.000440
Loss: 0.000472
Loss: 0.000353
Loss: 0.000316
Loss: 0.000257
Loss: 0.000157
Loss: 0.000144
Loss: 0.000262
Loss: 0.000039
Loss: 0.000478
Loss: 0.000317
Loss: 0.000230
Loss: 0.000214
Loss: 0.000145
Loss: 0.000220
Loss: 0.000144
L

Loss: 0.000171
Loss: 0.000201
Loss: 0.000343
Loss: 0.000490
Loss: 0.000168
Loss: 0.000390
Loss: 0.000374
Loss: 0.000151
Loss: 0.000176
Loss: 0.000196
Loss: 0.000143
Loss: 0.000087
Loss: 0.000069
Loss: 0.000248
Loss: 0.000123
Loss: 0.000408
Loss: 0.000125
Loss: 0.000246
Loss: 0.000193
Loss: 0.000335
Loss: 0.000100
Loss: 0.000284
Loss: 0.000142
Loss: 0.000229
Loss: 0.000209
Loss: 0.000351
Loss: 0.000044
Loss: 0.000253
Loss: 0.000154
Loss: 0.000264
Loss: 0.000571
Loss: 0.000110
Loss: 0.000225
Loss: 0.000324
Loss: 0.000225
Trained on 1 - Batch 44
Loss: 0.000370
Loss: 0.000246
Loss: 0.000094
Loss: 0.000129
Loss: 0.000375
Loss: 0.000168
Loss: 0.000666
Loss: 0.000347
Loss: 0.000239
Loss: 0.000442
Loss: 0.000382
Loss: 0.000251
Loss: 0.000093
Loss: 0.000186
Loss: 0.000294
Loss: 0.000137
Loss: 0.000157
Loss: 0.000051
Loss: 0.000216
Loss: 0.000096
Loss: 0.000143
Loss: 0.000123
Loss: 0.000070
Loss: 0.000086
Loss: 0.000097
Loss: 0.000228
Loss: 0.000125
Loss: 0.000063
Loss: 0.000128
Loss: 0.000235
L

Loss: 0.000257
Loss: 0.000114
Loss: 0.000121
Loss: 0.000204
Loss: 0.000180
Loss: 0.000141
Loss: 0.000036
Loss: 0.000227
Loss: 0.000183
Loss: 0.000133
Loss: 0.000042
Loss: 0.000114
Loss: 0.000115
Loss: 0.000100
Loss: 0.000079
Loss: 0.000215
Loss: 0.000040
Loss: 0.000051
Loss: 0.000237
Loss: 0.000126
Loss: 0.000393
Loss: 0.000273
Loss: 0.000150
Loss: 0.000226
Loss: 0.000219
Loss: 0.000170
Loss: 0.000117
Loss: 0.000172
Trained on 1 - Batch 46
Loss: 0.000118
Loss: 0.000185
Loss: 0.000246
Loss: 0.000080
Loss: 0.000233
Loss: 0.000140
Loss: 0.000133
Loss: 0.000083
Loss: 0.000069
Loss: 0.000022
Loss: 0.000128
Loss: 0.000335
Loss: 0.000243
Loss: 0.000601
Loss: 0.000183
Loss: 0.000158
Loss: 0.000082
Loss: 0.000392
Loss: 0.000079
Loss: 0.000181
Loss: 0.000051
Loss: 0.000084
Loss: 0.000138
Loss: 0.000227
Loss: 0.000107
Loss: 0.000113
Loss: 0.000150
Loss: 0.000084
Loss: 0.000141
Loss: 0.000127
Loss: 0.000158
Loss: 0.000130
Loss: 0.000138
Loss: 0.000286
Loss: 0.000113
Loss: 0.000118
Loss: 0.000337
L

Loss: 0.000316
Loss: 0.000122
Loss: 0.000129
Loss: 0.000058
Loss: 0.000039
Loss: 0.000057
Loss: 0.000240
Loss: 0.000189
Loss: 0.000020
Loss: 0.000334
Loss: 0.000194
Loss: 0.000028
Loss: 0.000368
Loss: 0.000096
Loss: 0.000100
Loss: 0.000123
Loss: 0.000132
Loss: 0.000209
Loss: 0.000046
Loss: 0.000243
Loss: 0.000152
Loss: 0.000308
Loss: 0.000397
Loss: 0.000274
Loss: 0.000081
Loss: 0.000232
Loss: 0.000166
Loss: 0.000045
Loss: 0.000100
Loss: 0.000115
Loss: 0.000074
Loss: 0.000070
Loss: 0.000133
Trained on 1 - Batch 48
Loss: 0.000037
Loss: 0.000229
Loss: 0.000425
Loss: 0.000129
Loss: 0.000164
Loss: 0.000161
Loss: 0.000157
Loss: 0.000085
Loss: 0.000067
Loss: 0.000124
Loss: 0.000216
Loss: 0.000108
Loss: 0.000135
Loss: 0.000182
Loss: 0.000223
Loss: 0.000243
Loss: 0.000283
Loss: 0.000338
Loss: 0.000372
Loss: 0.000233
Loss: 0.000219
Loss: 0.000363
Loss: 0.000223
Loss: 0.000193
Loss: 0.000087
Loss: 0.000180
Loss: 0.000119
Loss: 0.000261
Loss: 0.000083
Loss: 0.000258
Loss: 0.000047
Loss: 0.000176
L

Loss: 0.000081
Loss: 0.000179
Loss: 0.000287
Loss: 0.000106
Loss: 0.000193
Loss: 0.000326
Loss: 0.000198
Loss: 0.000128
Loss: 0.000304
Loss: 0.000159
Loss: 0.000216
Loss: 0.000253
Loss: 0.000242
Loss: 0.000108
Loss: 0.000089
Loss: 0.000112
Loss: 0.000141
Loss: 0.000280
Loss: 0.000131
Loss: 0.000120
Loss: 0.000259
Loss: 0.000328
Loss: 0.000076
Loss: 0.000157
Loss: 0.000024
Loss: 0.000272
Loss: 0.000397
Trained on 1 - Batch 50
Loss: 0.000290
Loss: 0.000213
Loss: 0.000068
Loss: 0.000098
Loss: 0.000108
Loss: 0.000066
Loss: 0.000140
Loss: 0.000271
Loss: 0.000113
Loss: 0.000091
Loss: 0.000083
Loss: 0.000251
Loss: 0.000150
Loss: 0.000072
Loss: 0.000232
Loss: 0.000225
Loss: 0.000229
Loss: 0.000093
Loss: 0.000113
Loss: 0.000278
Loss: 0.000139
Loss: 0.000055
Loss: 0.000218
Loss: 0.000073
Loss: 0.000215
Loss: 0.000020
Loss: 0.000055
Loss: 0.000170
Loss: 0.000252
Loss: 0.000096
Loss: 0.000219
Loss: 0.000108
Loss: 0.000190
Loss: 0.000063
Loss: 0.000072
Loss: 0.000269
Loss: 0.000103
Loss: 0.000145
L

Loss: 0.000302
Loss: 0.000211
Loss: 0.000083
Loss: 0.000072
Loss: 0.000201
Loss: 0.000091
Loss: 0.000184
Loss: 0.000127
Loss: 0.000127
Loss: 0.000198
Loss: 0.000213
Loss: 0.000037
Loss: 0.000151
Loss: 0.000084
Loss: 0.000120
Loss: 0.000134
Loss: 0.000138
Trained on 1 - Batch 52
Loss: 0.000032
Loss: 0.000257
Loss: 0.000029
Loss: 0.000033
Loss: 0.000166
Loss: 0.000259
Loss: 0.000123
Loss: 0.000111
Loss: 0.000100
Loss: 0.000539
Loss: 0.000048
Loss: 0.000189
Loss: 0.000183
Loss: 0.000085
Loss: 0.000051
Loss: 0.000193
Loss: 0.000232
Loss: 0.000195
Loss: 0.000103
Loss: 0.000088
Loss: 0.000325
Loss: 0.000127
Loss: 0.000187
Loss: 0.000055
Loss: 0.000073
Loss: 0.000220
Loss: 0.000320
Loss: 0.000119
Loss: 0.000187
Loss: 0.000453
Loss: 0.000184
Loss: 0.000058
Loss: 0.000116
Loss: 0.000133
Loss: 0.000132
Loss: 0.000212
Loss: 0.000150
Loss: 0.000027
Loss: 0.000181
Loss: 0.000062
Loss: 0.000209
Loss: 0.000150
Loss: 0.000276
Loss: 0.000141
Loss: 0.000296
Loss: 0.000131
Loss: 0.000206
Loss: 0.000099
L

Loss: 0.000140
Loss: 0.000238
Loss: 0.000405
Loss: 0.000261
Loss: 0.000249
Loss: 0.000108
Loss: 0.000696
Loss: 0.000299
Loss: 0.000111
Loss: 0.000161
Loss: 0.000057
Loss: 0.000223
Loss: 0.000122
Loss: 0.000171
Trained on 1 - Batch 54
Loss: 0.000266
Loss: 0.000211
Loss: 0.000188
Loss: 0.000082
Loss: 0.000139
Loss: 0.000081
Loss: 0.000143
Loss: 0.000236
Loss: 0.000083
Loss: 0.000220
Loss: 0.000056
Loss: 0.000092
Loss: 0.000062
Loss: 0.000026
Loss: 0.000037
Loss: 0.000127
Loss: 0.000096
Loss: 0.000109
Loss: 0.000386
Loss: 0.000245
Loss: 0.000329
Loss: 0.000133
Loss: 0.000328
Loss: 0.000057
Loss: 0.000176
Loss: 0.000244
Loss: 0.000048
Loss: 0.000029
Loss: 0.000151
Loss: 0.000150
Loss: 0.000170
Loss: 0.000343
Loss: 0.000279
Loss: 0.000122
Loss: 0.000263
Loss: 0.000116
Loss: 0.000092
Loss: 0.000081
Loss: 0.000104
Loss: 0.000088
Loss: 0.000153
Loss: 0.000117
Loss: 0.000083
Loss: 0.000060
Loss: 0.000214
Loss: 0.000174
Loss: 0.000310
Loss: 0.000074
Loss: 0.000111
Loss: 0.000058
Loss: 0.000153
L

Loss: 0.000124
Loss: 0.000149
Loss: 0.000218
Loss: 0.000319
Loss: 0.000269
Loss: 0.000166
Loss: 0.000185
Loss: 0.000263
Loss: 0.000250
Loss: 0.000102
Loss: 0.000271
Loss: 0.000100
Loss: 0.000126
Loss: 0.000235
Loss: 0.000275
Loss: 0.000025
Loss: 0.000372
Loss: 0.000009
Loss: 0.000249
Loss: 0.000111
Loss: 0.000101
Loss: 0.000115
Loss: 0.000088
Loss: 0.000216
Loss: 0.000207
Loss: 0.000121
Loss: 0.000047
Loss: 0.000119
Loss: 0.000342
Loss: 0.000221
Loss: 0.000155
Loss: 0.000336
Loss: 0.000117
Loss: 0.000098
Loss: 0.000213
Loss: 0.000070
Trained on 1 - Batch 56
Loss: 0.000336
Loss: 0.000064
Loss: 0.000211
Loss: 0.000146
Loss: 0.000120
Loss: 0.000100
Loss: 0.000354
Loss: 0.000285
Loss: 0.000092
Loss: 0.000553
Loss: 0.000055
Loss: 0.000075
Loss: 0.000195
Loss: 0.000127
Loss: 0.000092
Loss: 0.000192
Loss: 0.000080
Loss: 0.000215
Loss: 0.000036
Loss: 0.000222
Loss: 0.000205
Loss: 0.000035
Loss: 0.000128
Loss: 0.000103
Loss: 0.000049
Loss: 0.000234
Loss: 0.000234
Loss: 0.000140
Loss: 0.000101
L

Loss: 0.000228
Loss: 0.000309
Loss: 0.000077
Loss: 0.000078
Loss: 0.000134
Loss: 0.000057
Loss: 0.000329
Loss: 0.000158
Loss: 0.000173
Loss: 0.000157
Loss: 0.000273
Loss: 0.000146
Loss: 0.000061
Loss: 0.000408
Loss: 0.000213
Loss: 0.000324
Loss: 0.000161
Loss: 0.000384
Loss: 0.000182
Loss: 0.000276
Loss: 0.000147
Loss: 0.000063
Loss: 0.000141
Loss: 0.000032
Loss: 0.000294
Loss: 0.000162
Loss: 0.000040
Loss: 0.000164
Loss: 0.000233
Loss: 0.000168
Loss: 0.000230
Loss: 0.000197
Loss: 0.000184
Loss: 0.000086
Loss: 0.000078
Loss: 0.000083
Loss: 0.000153
Loss: 0.000038
Loss: 0.000098
Loss: 0.000046
Loss: 0.000210
Loss: 0.000069
Loss: 0.000257
Loss: 0.000175
Loss: 0.000449
Trained on 1 - Batch 58
Loss: 0.000090
Loss: 0.000169
Loss: 0.000162
Loss: 0.000205
Loss: 0.000310
Loss: 0.000152
Loss: 0.000310
Loss: 0.000303
Loss: 0.000052
Loss: 0.000273
Loss: 0.000096
Loss: 0.000081
Loss: 0.000158
Loss: 0.000063
Loss: 0.000041
Loss: 0.000035
Loss: 0.000164
Loss: 0.000040
Loss: 0.000194
Loss: 0.000157
L

Loss: 0.000063
Loss: 0.000107
Loss: 0.000345
Loss: 0.000208
Loss: 0.000059
Loss: 0.000217
Loss: 0.000062
Loss: 0.000271
Loss: 0.000054
Loss: 0.000065
Loss: 0.000212
Loss: 0.000222
Loss: 0.000284
Loss: 0.000137
Loss: 0.000110
Loss: 0.000021
Loss: 0.000105
Loss: 0.000412
Loss: 0.000137
Loss: 0.000119
Loss: 0.000105
Loss: 0.000166
Loss: 0.000172
Loss: 0.000190
Loss: 0.000116
Loss: 0.000102
Loss: 0.000184
Loss: 0.000088
Loss: 0.000053
Loss: 0.000211
Loss: 0.000237
Loss: 0.000343
Loss: 0.000209
Loss: 0.000104
Loss: 0.000121
Loss: 0.000202
Loss: 0.000042
Loss: 0.000064
Loss: 0.000033
Loss: 0.000099
Loss: 0.000090
Loss: 0.000241
Loss: 0.000174
Loss: 0.000133
Loss: 0.000123
Loss: 0.000113
Loss: 0.000078
Loss: 0.000417
Loss: 0.000088
Loss: 0.000058
Loss: 0.000107
Loss: 0.000318
Loss: 0.000128
Loss: 0.000369
Loss: 0.000092
Loss: 0.000138
Loss: 0.000092
Loss: 0.000303
Trained on 1 - Batch 60
Loss: 0.000239
Loss: 0.000195
Loss: 0.000125
Loss: 0.000180
Loss: 0.000228
Loss: 0.000168
Loss: 0.000110
L

Loss: 0.000089
Loss: 0.000230
Loss: 0.000050
Loss: 0.000244
Loss: 0.000050
Loss: 0.000085
Loss: 0.000062
Loss: 0.000085
Loss: 0.000191
Loss: 0.000353
Loss: 0.000202
Loss: 0.000080
Loss: 0.000036
Loss: 0.000029
Loss: 0.000107
Loss: 0.000321
Loss: 0.000101
Loss: 0.000092
Loss: 0.000138
Loss: 0.000066
Loss: 0.000266
Loss: 0.000071
Loss: 0.000036
Loss: 0.000328
Loss: 0.000153
Loss: 0.000179
Loss: 0.000398
Loss: 0.000014
Loss: 0.000071
Loss: 0.000134
Loss: 0.000148
Loss: 0.000106
Loss: 0.000124
Loss: 0.000121
Loss: 0.000228
Loss: 0.000108
Loss: 0.000035
Loss: 0.000210
Loss: 0.000034
Loss: 0.000070
Loss: 0.000078
Loss: 0.000088
Loss: 0.000161
Loss: 0.000064
Loss: 0.000188
Loss: 0.000058
Loss: 0.000278
Loss: 0.000068
Loss: 0.000590
Loss: 0.000065
Loss: 0.000060
Loss: 0.000144
Loss: 0.000200
Loss: 0.000084
Loss: 0.000301
Loss: 0.000303
Loss: 0.000109
Loss: 0.000464
Loss: 0.000106
Loss: 0.000334
Loss: 0.000149
Loss: 0.000135
Loss: 0.000164
Loss: 0.000294
Trained on 1 - Batch 62
Loss: 0.000510
L

Loss: 0.000103
Loss: 0.000227
Loss: 0.000326
Loss: 0.000227
Loss: 0.000111
Loss: 0.000059
Loss: 0.000184
Loss: 0.000154
Loss: 0.000029
Loss: 0.000084
Loss: 0.000089
Loss: 0.000079
Loss: 0.000098
Loss: 0.000127
Loss: 0.000184
Loss: 0.000033
Loss: 0.000138
Loss: 0.000124
Loss: 0.000060
Loss: 0.000024
Loss: 0.000113
Loss: 0.000236
Loss: 0.000029
Loss: 0.000312
Loss: 0.000079
Loss: 0.000128
Loss: 0.000211
Loss: 0.000104
Loss: 0.000279
Loss: 0.000173
Loss: 0.000044
Loss: 0.000383
Loss: 0.000232
Loss: 0.000205
Loss: 0.000068
Loss: 0.000268
Loss: 0.000318
Loss: 0.000300
Loss: 0.000094
Loss: 0.000098
Loss: 0.000377
Loss: 0.000065
Loss: 0.000129
Loss: 0.000066
Loss: 0.000022
Loss: 0.000264
Loss: 0.000108
Loss: 0.000021
Loss: 0.000033
Loss: 0.000193
Loss: 0.000139
Loss: 0.000284
Loss: 0.000240
Loss: 0.000348
Loss: 0.000391
Loss: 0.000396
Loss: 0.000186
Loss: 0.000043
Loss: 0.000043
Loss: 0.000164
Loss: 0.000298
Loss: 0.000074
Loss: 0.000527
Loss: 0.000385
Loss: 0.000131
Loss: 0.000150
Loss: 0.00

Loss: 0.000327
Loss: 0.000283
Loss: 0.000070
Loss: 0.000314
Loss: 0.000174
Loss: 0.000145
Loss: 0.000198
Loss: 0.000086
Loss: 0.000259
Loss: 0.000131
Loss: 0.000144
Loss: 0.000101
Loss: 0.000065
Loss: 0.000122
Loss: 0.000100
Loss: 0.000107
Loss: 0.000254
Loss: 0.000065
Loss: 0.000193
Loss: 0.000065
Loss: 0.000100
Loss: 0.000122
Loss: 0.000170
Loss: 0.000257
Loss: 0.000062
Loss: 0.000060
Loss: 0.000127
Loss: 0.000073
Loss: 0.000197
Loss: 0.000028
Loss: 0.000060
Loss: 0.000173
Loss: 0.000116
Loss: 0.000202
Loss: 0.000176
Loss: 0.000165
Loss: 0.000035
Loss: 0.000172
Loss: 0.000050
Loss: 0.000022
Loss: 0.000115
Loss: 0.000125
Loss: 0.000074
Loss: 0.000148
Loss: 0.000079
Loss: 0.000128
Loss: 0.000083
Loss: 0.000141
Loss: 0.000171
Loss: 0.000185
Loss: 0.000108
Loss: 0.000143
Loss: 0.000062
Loss: 0.000016
Loss: 0.000073
Loss: 0.000273
Loss: 0.000098
Loss: 0.000208
Loss: 0.000163
Loss: 0.000120
Loss: 0.000183
Loss: 0.000212
Loss: 0.000176
Loss: 0.000083
Loss: 0.000334
Loss: 0.000126
Loss: 0.00

Loss: 0.000091
Loss: 0.000063
Loss: 0.000072
Loss: 0.000230
Loss: 0.000162
Loss: 0.000199
Loss: 0.000012
Loss: 0.000035
Loss: 0.000044
Loss: 0.000038
Loss: 0.000178
Loss: 0.000102
Loss: 0.000025
Loss: 0.000142
Loss: 0.000334
Loss: 0.000111
Loss: 0.000063
Loss: 0.000338
Loss: 0.000165
Loss: 0.000213
Loss: 0.000242
Loss: 0.000299
Loss: 0.000019
Loss: 0.000015
Loss: 0.000212
Loss: 0.000246
Loss: 0.000348
Loss: 0.000151
Loss: 0.000390
Loss: 0.000170
Loss: 0.000084
Loss: 0.000098
Loss: 0.000105
Loss: 0.000162
Loss: 0.000059
Loss: 0.000065
Loss: 0.000397
Loss: 0.000047
Loss: 0.000048
Loss: 0.000255
Loss: 0.000142
Loss: 0.000155
Loss: 0.000098
Loss: 0.000023
Loss: 0.000419
Loss: 0.000085
Loss: 0.000114
Loss: 0.000128
Loss: 0.000066
Loss: 0.000077
Loss: 0.000053
Loss: 0.000150
Loss: 0.000174
Loss: 0.000205
Loss: 0.000122
Loss: 0.000172
Loss: 0.000111
Loss: 0.000165
Loss: 0.000195
Loss: 0.000130
Loss: 0.000235
Loss: 0.000126
Loss: 0.002705
Loss: 0.000135
Loss: 0.000307
Loss: 0.000033
Loss: 0.00

Loss: 0.000131
Loss: 0.000148
Loss: 0.000195
Loss: 0.000083
Loss: 0.000070
Loss: 0.000018
Loss: 0.000128
Loss: 0.000101
Loss: 0.000074
Loss: 0.000242
Loss: 0.000053
Loss: 0.000140
Loss: 0.000372
Loss: 0.000084
Loss: 0.000074
Loss: 0.000068
Loss: 0.000099
Loss: 0.000302
Loss: 0.000147
Loss: 0.000107
Loss: 0.000048
Loss: 0.000202
Loss: 0.000192
Loss: 0.000086
Loss: 0.000069
Loss: 0.000156
Loss: 0.000410
Loss: 0.000013
Loss: 0.000092
Loss: 0.000156
Loss: 0.000190
Loss: 0.000089
Loss: 0.000056
Loss: 0.000146
Loss: 0.000384
Loss: 0.000100
Loss: 0.000059
Loss: 0.000073
Loss: 0.000347
Loss: 0.000103
Loss: 0.000058
Loss: 0.000372
Loss: 0.000091
Loss: 0.000346
Loss: 0.000153
Loss: 0.000076
Loss: 0.000342
Loss: 0.000126
Loss: 0.000068
Loss: 0.000035
Loss: 0.000036
Loss: 0.000460
Loss: 0.000043
Loss: 0.000153
Loss: 0.000169
Loss: 0.000191
Loss: 0.000051
Loss: 0.000108
Loss: 0.000101
Loss: 0.000103
Loss: 0.000022
Loss: 0.000167
Loss: 0.000079
Loss: 0.000114
Loss: 0.000166
Loss: 0.000146
Loss: 0.00

Loss: 0.000070
Loss: 0.000148
Loss: 0.000088
Loss: 0.000362
Loss: 0.000090
Loss: 0.000140
Loss: 0.000215
Loss: 0.000113
Loss: 0.000043
Loss: 0.000114
Loss: 0.000278
Loss: 0.000166
Loss: 0.000099
Loss: 0.000020
Loss: 0.000146
Loss: 0.000165
Loss: 0.000086
Loss: 0.000138
Loss: 0.000236
Loss: 0.000189
Loss: 0.000350
Loss: 0.000134
Loss: 0.000108
Loss: 0.000104
Loss: 0.000035
Loss: 0.000112
Loss: 0.000038
Loss: 0.000166
Loss: 0.000050
Loss: 0.000236
Loss: 0.000293
Loss: 0.000013
Loss: 0.000067
Loss: 0.000066
Loss: 0.000179
Loss: 0.000169
Loss: 0.000078
Loss: 0.000103
Loss: 0.000111
Loss: 0.000174
Loss: 0.000040
Loss: 0.000132
Loss: 0.000046
Loss: 0.000032
Loss: 0.000109
Loss: 0.000268
Loss: 0.000258
Loss: 0.000128
Loss: 0.000455
Loss: 0.000166
Loss: 0.000241
Loss: 0.000141
Loss: 0.000112
Loss: 0.000200
Loss: 0.000179
Loss: 0.000066
Loss: 0.000018
Loss: 0.000310
Loss: 0.000111
Loss: 0.000204
Loss: 0.000208
Loss: 0.000308
Loss: 0.000112
Loss: 0.000041
Loss: 0.000058
Loss: 0.000363
Loss: 0.00

Loss: 0.000073
Loss: 0.000089
Loss: 0.000145
Loss: 0.000149
Loss: 0.000511
Loss: 0.000048
Loss: 0.000046
Loss: 0.000107
Loss: 0.000415
Loss: 0.000131
Loss: 0.000229
Loss: 0.000023
Loss: 0.000046
Loss: 0.000379
Loss: 0.000020
Loss: 0.000094
Loss: 0.000119
Loss: 0.000198
Loss: 0.000163
Loss: 0.000156
Loss: 0.000212
Loss: 0.000119
Loss: 0.000147
Loss: 0.000081
Loss: 0.000091
Loss: 0.000287
Loss: 0.000015
Loss: 0.000167
Loss: 0.000067
Loss: 0.000024
Loss: 0.000179
Loss: 0.000185
Loss: 0.000205
Loss: 0.000011
Loss: 0.000569
Loss: 0.000153
Loss: 0.000102
Loss: 0.000140
Loss: 0.000159
Loss: 0.000209
Loss: 0.000228
Loss: 0.000165
Loss: 0.000063
Loss: 0.000178
Loss: 0.000047
Loss: 0.000293
Loss: 0.000145
Loss: 0.000189
Loss: 0.000225
Loss: 0.000138
Loss: 0.000255
Loss: 0.000112
Loss: 0.000183
Loss: 0.000083
Loss: 0.000058
Loss: 0.000320
Loss: 0.000114
Loss: 0.000168
Loss: 0.000092
Loss: 0.000248
Loss: 0.000026
Loss: 0.000209
Loss: 0.000044
Loss: 0.000197
Loss: 0.000084
Loss: 0.000172
Loss: 0.00

Loss: 0.000163
Loss: 0.000544
Loss: 0.000201
Loss: 0.000100
Loss: 0.000412
Loss: 0.000098
Loss: 0.000151
Loss: 0.000215
Loss: 0.000091
Loss: 0.000079
Loss: 0.000072
Loss: 0.000179
Loss: 0.000098
Loss: 0.000031
Loss: 0.000353
Loss: 0.000023
Loss: 0.000161
Loss: 0.000048
Loss: 0.000033
Loss: 0.000092
Loss: 0.000065
Loss: 0.000287
Loss: 0.000174
Loss: 0.000062
Loss: 0.000518
Loss: 0.000160
Loss: 0.000081
Loss: 0.000069
Loss: 0.000130
Loss: 0.000147
Loss: 0.000054
Loss: 0.000029
Loss: 0.000232
Loss: 0.000201
Loss: 0.000041
Loss: 0.000029
Loss: 0.000093
Loss: 0.000058
Loss: 0.000341
Loss: 0.000030
Loss: 0.000066
Loss: 0.000186
Loss: 0.000414
Loss: 0.000076
Loss: 0.000100
Loss: 0.000023
Loss: 0.000129
Loss: 0.000092
Loss: 0.000410
Loss: 0.000047
Loss: 0.000219
Loss: 0.000010
Loss: 0.000024
Loss: 0.000133
Loss: 0.000467
Loss: 0.000111
Loss: 0.000031
Loss: 0.000258
Loss: 0.000127
Loss: 0.000057
Loss: 0.000116
Loss: 0.000315
Loss: 0.000128
Loss: 0.000263
Loss: 0.000210
Loss: 0.000142
Loss: 0.00

Loss: 0.000149
Loss: 0.000036
Loss: 0.000058
Loss: 0.000067
Loss: 0.000014
Loss: 0.000102
Loss: 0.000216
Loss: 0.000314
Loss: 0.000099
Loss: 0.000278
Loss: 0.000043
Loss: 0.000031
Loss: 0.000125
Loss: 0.000073
Loss: 0.000070
Loss: 0.000083
Loss: 0.000110
Loss: 0.000291
Loss: 0.000026
Loss: 0.000013
Loss: 0.000047
Loss: 0.000028
Loss: 0.000079
Loss: 0.000245
Loss: 0.000133
Loss: 0.000434
Loss: 0.000040
Loss: 0.000128
Loss: 0.000126
Loss: 0.000051
Loss: 0.000241
Loss: 0.000026
Loss: 0.000198
Loss: 0.000057
Loss: 0.000037
Loss: 0.000027
Loss: 0.000404
Loss: 0.000166
Loss: 0.000307
Loss: 0.000245
Loss: 0.000264
Loss: 0.000042
Loss: 0.000149
Loss: 0.000144
Loss: 0.000077
Loss: 0.000254
Loss: 0.000109
Loss: 0.000058
Loss: 0.000038
Loss: 0.000138
Loss: 0.000213
Loss: 0.000176
Loss: 0.000143
Loss: 0.000297
Loss: 0.000057
Loss: 0.000069
Loss: 0.000101
Loss: 0.000046
Loss: 0.000062
Loss: 0.000018
Loss: 0.000015
Loss: 0.000050
Loss: 0.000034
Loss: 0.000190
Loss: 0.000123
Loss: 0.000035
Loss: 0.00

Loss: 0.000381
Loss: 0.000057
Loss: 0.000098
Loss: 0.000047
Loss: 0.000061
Loss: 0.000070
Loss: 0.000030
Loss: 0.000222
Loss: 0.000052
Loss: 0.000034
Loss: 0.000156
Loss: 0.000153
Loss: 0.000186
Loss: 0.000092
Loss: 0.000037
Loss: 0.000258
Loss: 0.000197
Loss: 0.000545
Loss: 0.000190
Loss: 0.000067
Loss: 0.000281
Loss: 0.000130
Loss: 0.000229
Loss: 0.000052
Loss: 0.000307
Loss: 0.000082
Loss: 0.000117
Loss: 0.000347
Loss: 0.000208
Loss: 0.000015
Loss: 0.000088
Loss: 0.000166
Loss: 0.000259
Loss: 0.000100
Loss: 0.000071
Loss: 0.000149
Loss: 0.000040
Loss: 0.000060
Loss: 0.000075
Loss: 0.000056
Loss: 0.000053
Loss: 0.000112
Loss: 0.000078
Loss: 0.000101
Loss: 0.000173
Loss: 0.000114
Loss: 0.000127
Loss: 0.000284
Loss: 0.000223
Loss: 0.000014
Loss: 0.000096
Loss: 0.000198
Loss: 0.000161
Loss: 0.000125
Loss: 0.000026
Loss: 0.000071
Loss: 0.000188
Loss: 0.000149
Loss: 0.000161
Loss: 0.000207
Loss: 0.000070
Loss: 0.000100
Loss: 0.000048
Loss: 0.000086
Loss: 0.000063
Loss: 0.000317
Loss: 0.00

In [102]:
pickle_file = f"/data2/xpgeng/iML1515_MLP/{1}_{1}.pkl"

with open(pickle_file, "rb") as f:
    batch_data, batch_labels = pickle.load(f)

# 转换为张量
batch_data = torch.tensor(batch_data, dtype=torch.float32).to(device)

tensor([[ 3.6431, -0.7383,  2.7668,  ...,  0.0252, -0.1478, -0.0602],
        [ 3.8448, -0.7019,  2.1398,  ...,  0.0302, -0.1493, -0.0578],
        [ 0.1532, -0.5058,  0.1532,  ...,  0.0313, -0.1570, -0.0748],
        ...,
        [ 1.1735, -0.7811,  0.0565,  ...,  0.0287, -0.1808, -0.0865],
        [-0.6403, -0.6403,  0.3966,  ...,  0.0199, -0.1605, -0.0596],
        [ 2.7856, -0.6605,  0.4882,  ...,  0.0263, -0.1715, -0.0879]],
       device='cuda:3')

In [109]:
print(batch_data[0][600:1200])

tensor([ 0.4973, -0.7383, -0.7383,  0.4973, -0.1205, -0.7383, -0.7383,  0.4973,
        -0.1205, -0.1205, -0.7383, -0.7383, -0.7383, -0.1205, -0.1205,  0.4973,
        -0.7383,  0.4973, -0.7383, -0.1205, -0.7383, -0.7383, -0.7383, -0.7383,
        -0.7383,  0.4973, -0.1205, -0.1205, -0.7383, -0.1205, -0.7383, -0.7383,
        -0.7383, -0.7383,  0.4973, -0.7383, -0.7383, -0.1205, -0.7383, -0.7383,
        -0.1205, -0.7383, -0.1205, -0.7383,  0.4973,  1.1151,  0.4973, -0.1205,
        -0.7383,  1.1151, -0.1205, -0.7383, -0.1205,  0.4973, -0.1205, -0.1205,
        -0.1205, -0.7383, -0.7383, -0.7383,  0.4973, -0.7383, -0.1205, -0.7383,
        -0.7383, -0.7383, -0.7383, -0.7383, -0.7383,  1.1151, -0.7383, -0.7383,
        -0.7383, -0.1205,  0.4973, -0.1205, -0.1205, -0.7383, -0.7383, -0.7383,
         0.4973, -0.7383, -0.7383, -0.1205, -0.7383, -0.1205, -0.7383, -0.7383,
        -0.7383,  2.3507, -0.7383, -0.7383, -0.7383, -0.7383, -0.1205, -0.1205,
        -0.7383,  0.4973, -0.1205,  0.49