In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
 
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) prints every directory under the input directory together with its file count

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    # Report how many files each input directory holds; to list the files
    # themselves, print os.path.join(dirname, name) for each name instead.
    print(dirname, len(filenames))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
"""
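    MyDataSet notes
    Input directory: img_dir
    Reads 50,000 images (10,000 each for emcal, hcal, trkn, trkp and truth)
    Outputs X, Y
    X has shape (10000, 4, 56, 56): 10000 is the number of samples and 4 is the number of channels (emcal, hcal, trkn, trkp)
    Y has shape (10000, 1, 56, 56): the single-channel truth image of each sample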
"""
# 导入相关库
import tifffile as tiff #读取tiff文件格式
from PIL import Image #图片处理
#与torch 相关的库
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

#from sklearn.preprocessing import MinMaxScaler
import numpy as np
import imageio 

class MaxMinNormalizeGlobalPerChannel:
    """
    Min-max normalise a (Batch, Channel, Width, Height) tensor per channel.

    For each channel, the minimum and maximum are taken jointly over the whole
    batch and both spatial axes. The channel is then rescaled as
    (x - min) / (max - min + 1e-8), which maps it to approximately [0, 1].
    The epsilon avoids division by zero: a constant channel becomes all zeros.
    """
    def __call__(self, tensor):
        # Only 4-D batches are supported.
        assert tensor.dim() == 4, "Input tensor must have 4 dimensions: (Batch, Channel, Width, Height)."

        reduce_dims = (0, 2, 3)  # every axis except the channel axis
        # keepdim=True -> both statistics have shape (1, Channel, 1, 1) and
        # broadcast back over the input.
        lo = tensor.amin(dim=reduce_dims, keepdim=True)
        hi = tensor.amax(dim=reduce_dims, keepdim=True)
        span = hi - lo + 1e-8
        return (tensor - lo) / span


#创建数据集
class MyDataSet(Dataset):
    """Load the emcal/hcal/trkn/trkp/truth TIFF images in ``img_dir`` as tensors.

    Files are read as ``{prefix}_{i}.tiff`` for each prefix in
    ``('emcal', 'hcal', 'trkn', 'trkp', 'truth')`` and each ``i`` in
    ``range(size_in)``.

    After construction:
      * ``self.emcal`` ... ``self.truth`` hold the per-image tensors;
      * ``self.X`` holds the four detector images of every sample concatenated
        on the channel axis, shape ``(N, 4*C, H, W)`` (``(N, 4, 56, 56)`` for
        single-channel 56x56 images). ``self.Y`` is ``torch.stack`` of the
        truth images. Both are passed through ``transform``;
      * if ``splition`` is truthy, X/Y are split 80/10/10 into
        ``train_X/train_Y``, ``val_X/val_Y`` and ``test_X/test_Y``. Then
        ``self.X``/``self.Y`` are deleted to free memory, so ``len()`` and
        indexing on this object raise ``AttributeError``. Wrap one of the
        splits in a separate ``Dataset`` to iterate over it.

    NOTE(review): ``transform`` sees the full X/Y before the split, so a
    dataset-wide normaliser also uses statistics from the val/test samples.

    Args:
        img_dir: directory containing the TIFF files.
        group_size: accepted for backward compatibility but unused. Each prefix
            group is loaded into its own list, so no group offset is needed.
        size_in: number of images loaded per prefix.
        transform: callable applied to the whole stacked X and Y tensors, or
            None for identity.
        split_shuffle: shuffle samples before splitting; otherwise split in
            file-index order.
        splition: whether to create the train/val/test split.
        image_transform: per-image callable that turns each PIL image into a
            tensor. Defaults to ``transforms.ToTensor()``.
        seed: optional integer seed for the split shuffle. None uses the
            global torch RNG.
    """

    def __init__(self, img_dir, group_size=10000, size_in=10000, transform=None,
                 split_shuffle=True, splition=True, image_transform=None, seed=None):
        self.img_dir = img_dir
        self.images = os.listdir(img_dir)
        self.transform = transform
        self.image_transform = image_transform
        self.seed = seed
        self.all_imgs = []
        self.emcal = []
        self.hcal = []
        self.trkn = []
        self.trkp = []
        self.truth = []
        self.group_size = group_size
        self.size_in = size_in
        self.splition = splition
        self.split_shuffle = split_shuffle
        self.load_images()

    def load_images(self):
        """Read every image, build ``X``/``Y`` and optionally split them."""
        to_tensor = self.image_transform if self.image_transform is not None else transforms.ToTensor()
        prefixes = ('emcal', 'hcal', 'trkn', 'trkp', 'truth')
        # One list per prefix, so the groups cannot be mis-sliced when
        # size_in differs from group_size.
        (self.emcal, self.hcal, self.trkn,
         self.trkp, self.truth) = [self._load_prefix(p, to_tensor) for p in prefixes]

        if self.transform is not None:
            transformation = self.transform
            print('transformation is not None')
        else:
            transformation = lambda x: x
            print('transformation is None')

        self.X = transformation(self._stack_features(self.emcal, self.hcal, self.trkn, self.trkp))
        self.Y = transformation(torch.stack(self.truth))

        if self.splition:
            self._split()

    def _load_prefix(self, prefix, to_tensor):
        """Load ``{prefix}_0.tiff`` .. ``{prefix}_{size_in-1}.tiff`` as a list of tensors."""
        images = []
        for i in range(self.size_in):
            img_path = os.path.join(self.img_dir, f"{prefix}_{i}.tiff")
            images.append(to_tensor(Image.fromarray(tiff.imread(img_path))))
        return images

    @staticmethod
    def _stack_features(emcal, hcal, trkn, trkp):
        """Concatenate each sample's four (C, H, W) images on dim 0 and stack to (N, 4*C, H, W).

        No squeeze is applied, so the batch dimension survives when N == 1.
        """
        samples = [torch.cat((e, h, n, p), dim=0) for e, h, n, p in zip(emcal, hcal, trkn, trkp)]
        return torch.stack(samples)

    def _split_indices(self, n):
        """Return (train, val, test) index tensors for an 80/10/10 split of ``n`` samples."""
        train_size = int(0.8 * n)
        val_size = int(0.1 * n)
        if self.split_shuffle:
            generator = None
            if self.seed is not None:
                generator = torch.Generator().manual_seed(self.seed)
            indices = torch.randperm(n, generator=generator)
        else:
            indices = torch.arange(n)
        return (indices[:train_size],
                indices[train_size:train_size + val_size],
                indices[train_size + val_size:])

    def _split(self):
        """Create train/val/test attributes from X/Y, then drop X/Y to free memory."""
        train_idx, val_idx, test_idx = self._split_indices(self.X.size(0))
        self.train_X, self.train_Y = self.X[train_idx], self.Y[train_idx]
        self.val_X, self.val_Y = self.X[val_idx], self.Y[val_idx]
        self.test_X, self.test_Y = self.X[test_idx], self.Y[test_idx]
        del self.X
        del self.Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]
# Per-image preprocessing pipeline: ToTensor converts a PIL image to a tensor.
# Further preprocessing steps can be appended to this list later.
transform = transforms.Compose([transforms.ToTensor()])

    
class dataset_2(Dataset):
    """Map-style ``Dataset`` over pre-built, index-aligned inputs ``X`` and targets ``Y``.

    ``X`` and ``Y`` may be any objects supporting ``len`` and ``[]`` (for
    example tensors). The dataset length is taken from ``X``.
    """
    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.Y[idx]
        return features, target

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

class SelfAttention(nn.Module):
    """Self-attention across all spatial positions of a (B, C, H, W) feature map.

    Queries, keys and values come from 1x1 convolutions that keep the channel
    count. The weights are softmax(Q^T K) over the H*W key positions, with no
    1/sqrt(C) scaling. Each output position is the weighted sum of the value
    vectors, so the output shape equals the input shape. The attention matrix
    is (B, H*W, H*W), so memory grows with the square of the spatial size.
    """
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # Creation order (query, key, value) fixes the order of RNG draws at init.
        self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        q = self.query(x).flatten(2)  # (B, C, N), N = H*W
        k = self.key(x).flatten(2)    # (B, C, N)
        v = self.value(x).flatten(2)  # (B, C, N)
        weights = self.softmax(q.transpose(1, 2) @ k)  # (B, N, N); each row sums to 1
        attended = v @ weights.transpose(1, 2)         # (B, C, N)
        return attended.view_as(x)




class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel re-weighting; output shape equals input shape.

    Each channel is global-average-pooled, and the (B, C) vector is passed
    through a bias-free bottleneck MLP (C -> hidden -> C, ReLU in between).
    The input channels are then scaled by the resulting sigmoid gates.

    Args:
        channels: number of input channels C.
        reduction: bottleneck ratio; hidden = max(1, channels // reduction).
            The clamp keeps channels < reduction from producing a zero-width
            layer, where every gate would degenerate to sigmoid(0) = 0.5.
    """
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, channels, bias=False)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = F.adaptive_avg_pool2d(x, (1, 1)).view(b, c)  # squeeze: (B, C)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(b, c, 1, 1)  # excitation: per-channel gates
        return x * y.expand_as(x)  # scale

class DownSampling(nn.Module):
    """Learned 2x downsampling: 2x2 conv with stride 2, then LeakyReLU.

    Height and width are halved (floored for odd sizes); channels go C_in -> C_out.
    """
    def __init__(self, C_in, C_out):
        super(DownSampling, self).__init__()
        halve = nn.Conv2d(C_in, C_out, kernel_size=2, stride=2)
        self.Down = nn.Sequential(halve, nn.LeakyReLU(inplace=True))

    def forward(self, x):
        out = self.Down(x)
        return out
#定义上采样层
class UpSampling(nn.Module):
    """Nearest-neighbour upsampling, 1x1 channel projection, then skip concatenation.

    Args:
        C_in: channels of the low-resolution input ``x``.
        C_out: channels produced by the 1x1 projection.

    ``forward(x, r)`` returns ``cat([Up(upsample(x)), r], dim=1)``: ``C_out +
    r.size(1)`` channels at the spatial size of the skip tensor ``r``.
    """
    def __init__(self, C_in, C_out):
        super(UpSampling, self).__init__()
        self.Up = nn.Conv2d(C_in, C_out, kernel_size=1)  # 1x1 conv: changes channel count only

    def forward(self, x, r):
        # Upsample directly to r's spatial size rather than by a fixed factor of 2.
        # For exact 2x this is the same nearest-neighbour result, and it also
        # handles encoder maps with odd height/width (stride-2 convs floor them),
        # where scale_factor=2 would leave a 1-pixel mismatch and torch.cat would fail.
        up = F.interpolate(x, size=r.shape[-2:], mode='nearest')
        x = self.Up(up)
        return torch.cat([x, r], dim=1)  # skip connection


class Conv_UNet(nn.Module):
    """3x3 'same' convolution unit: Conv -> BatchNorm -> Dropout(0.3) -> LeakyReLU.

    Spatial size is preserved (padding=1); channels go C_in -> C_out.
    """
    def __init__(self, C_in, C_out):
        super(Conv_UNet, self).__init__()
        conv = nn.Conv2d(C_in, C_out, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(C_out)
        drop = nn.Dropout(0.3)
        act = nn.LeakyReLU(inplace=True)
        self.layer = nn.Sequential(conv, norm, drop, act)

    def forward(self, x):
        out = self.layer(x)
        return out

class DownSampling_UNet(nn.Module):
    """U-Net encoder step: stride-2 2x2 convolution (C_in -> C_out) followed by LeakyReLU.

    Output height/width are input // 2.
    """
    def __init__(self, C_in, C_out):
        super(DownSampling_UNet, self).__init__()
        layers = [
            nn.Conv2d(C_in, C_out, kernel_size=2, stride=2),
            nn.LeakyReLU(inplace=True),
        ]
        self.Down = nn.Sequential(*layers)

    def forward(self, x):
        return self.Down(x)

class UpSampling_UNet(nn.Module):
    """U-Net decoder step: upsample ``x`` to the skip tensor's size, project channels, concatenate.

    Args:
        C_in: channels of the incoming low-resolution map ``x``.
        C_out: channels after the 1x1 projection.

    ``forward(x, r)`` returns a tensor with ``C_out + r.size(1)`` channels at
    ``r``'s spatial size (projected upsampled ``x`` first, then ``r``).
    """
    def __init__(self, C_in, C_out):
        super(UpSampling_UNet, self).__init__()
        self.Up = nn.Conv2d(C_in, C_out, kernel_size=1)  # 1x1 conv: changes channel count only

    def forward(self, x, r):
        # Target r's exact height/width: same result as scale_factor=2 when r is
        # exactly twice as large, and still aligned when a stride-2 conv floored
        # an odd size on the way down (otherwise torch.cat would fail).
        up = F.interpolate(x, size=r.shape[-2:], mode='nearest')
        x = self.Up(up)
        return torch.cat([x, r], dim=1)  # skip connection
        

'''''''''
CNN_90k:
PS D:\LECINSUMMER\project4> & D:/anaconda/python.exe d:/LECINSUMMER/project4/demo.py
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 56, 56]           1,616
       BatchNorm2d-2           [-1, 16, 56, 56]              32
              ReLU-3           [-1, 16, 56, 56]               0
            Conv2d-4           [-1, 32, 56, 56]          12,832
       BatchNorm2d-5           [-1, 32, 56, 56]              64
              ReLU-6           [-1, 32, 56, 56]               0
            Conv2d-7           [-1, 64, 56, 56]          51,264
       BatchNorm2d-8           [-1, 64, 56, 56]             128
              ReLU-9           [-1, 64, 56, 56]               0
           Conv2d-10           [-1, 16, 56, 56]          25,616
           Conv2d-11            [-1, 1, 56, 56]             401
          Sigmoid-12            [-1, 1, 56, 56]               0
================================================================
Total params: 91,953
Trainable params: 91,953
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 8.47
Params size (MB): 0.35
Estimated Total Size (MB): 8.87
----------------------------------------------------------------
None
'''''''''

class CNN_90k(nn.Module):
    """Plain CNN (91,953 parameters) mapping a (B, 4, H, W) input to a (B, 1, H, W) map in (0, 1).

    Encoder: three 5x5 'same' conv -> BatchNorm -> ReLU stages (4 -> 16 -> 32 -> 64).
    Decoder: 5x5 conv 64 -> 16, then 5x5 conv 16 -> 1 (no activation between
    them), then a sigmoid.
    """
    def __init__(self):
        super(CNN_90k, self).__init__()
        stages = []
        for c_in, c_out in ((4, 16), (16, 32), (32, 64)):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=5, padding=2),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.encoder = nn.Sequential(*stages)
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=5, padding=2),
            nn.Conv2d(16, 1, kernel_size=5, padding=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: the four detector channels (emcal, hcal, trkn, trkp) stacked on dim 1
        features = self.encoder(x)
        return self.decoder(features)


'''''''''''''''''''''
CNNwithSEBlock_90k:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 56, 56]           1,616
       BatchNorm2d-2           [-1, 16, 56, 56]              32
              ReLU-3           [-1, 16, 56, 56]               0
            Conv2d-4           [-1, 32, 56, 56]          12,832
       BatchNorm2d-5           [-1, 32, 56, 56]              64
              ReLU-6           [-1, 32, 56, 56]               0
            Conv2d-7           [-1, 64, 56, 56]          51,264
       BatchNorm2d-8           [-1, 64, 56, 56]             128
              ReLU-9           [-1, 64, 56, 56]               0
           Linear-10                    [-1, 4]             256
           Linear-11                   [-1, 64]             256
          SEBlock-12           [-1, 64, 56, 56]               0
           Conv2d-13           [-1, 16, 56, 56]          25,616
           Conv2d-14            [-1, 1, 56, 56]             401
          Sigmoid-15            [-1, 1, 56, 56]               0
================================================================
Total params: 92,465
Trainable params: 92,465
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 10.00
Params size (MB): 0.35
Estimated Total Size (MB): 10.40
----------------------------------------------------------------
'''''''''''''''''''''''



class CNNwithSEBlock_90k(nn.Module):
    """CNN with a Squeeze-and-Excitation block between encoder and decoder.

    (B, 4, H, W) -> encoder (4 -> 16 -> 32 -> 64; 5x5 conv + BatchNorm + ReLU each)
    -> SEBlock(64) channel gating -> decoder (5x5 convs 64 -> 16 -> 1) -> sigmoid,
    giving a (B, 1, H, W) map in (0, 1).
    """
    def __init__(self):
        super(CNNwithSEBlock_90k, self).__init__()
        widths = (4, 16, 32, 64)
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Conv2d(c_in, c_out, kernel_size=5, padding=2),
                       nn.BatchNorm2d(c_out),
                       nn.ReLU(inplace=True)]
        self.encoder = nn.Sequential(*blocks)
        self.se1 = SEBlock(64)
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=5, padding=2),
            nn.Conv2d(16, 1, kernel_size=5, padding=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        gated = self.se1(encoded)
        return self.decoder(gated)




'''''''''''
CNNwithSelfattention_90k:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 56, 56]           1,616
       BatchNorm2d-2           [-1, 16, 56, 56]              32
              ReLU-3           [-1, 16, 56, 56]               0
            Conv2d-4           [-1, 32, 56, 56]          12,832
       BatchNorm2d-5           [-1, 32, 56, 56]              64
              ReLU-6           [-1, 32, 56, 56]               0
            Conv2d-7           [-1, 64, 56, 56]          51,264
       BatchNorm2d-8           [-1, 64, 56, 56]             128
              ReLU-9           [-1, 64, 56, 56]               0
           Conv2d-10           [-1, 64, 56, 56]           4,160
           Conv2d-11           [-1, 64, 56, 56]           4,160
           Conv2d-12           [-1, 64, 56, 56]           4,160
          Softmax-13           [-1, 3136, 3136]               0
    SelfAttention-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 16, 56, 56]          25,616
           Conv2d-16            [-1, 1, 56, 56]             401
          Sigmoid-17            [-1, 1, 56, 56]               0
================================================================
Total params: 104,433
Trainable params: 104,433
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 89.63
Params size (MB): 0.40
Estimated Total Size (MB): 90.07
----------------------------------------------------------------
'''''''''''

class CNNwithSelfattention_90k(nn.Module):
    """CNN with spatial self-attention between encoder and decoder.

    (B, 4, H, W) -> encoder (4 -> 16 -> 32 -> 64; 5x5 conv + BatchNorm + ReLU each)
    -> SelfAttention(64) at full resolution -> decoder (5x5 convs 64 -> 16 -> 1)
    -> sigmoid. Attention runs on all H*W positions, so at 56x56 each sample
    carries a 3136 x 3136 attention matrix.
    """
    def __init__(self):
        super(CNNwithSelfattention_90k, self).__init__()
        encoder_layers = []
        for c_in, c_out in [(4, 16), (16, 32), (32, 64)]:
            encoder_layers.append(nn.Conv2d(c_in, c_out, kernel_size=5, padding=2))
            encoder_layers.append(nn.BatchNorm2d(c_out))
            encoder_layers.append(nn.ReLU(inplace=True))
        self.encoder = nn.Sequential(*encoder_layers)
        self.attention = SelfAttention(64)
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=5, padding=2),
            nn.Conv2d(16, 1, kernel_size=5, padding=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.encoder(x)
        h = self.attention(h)
        return self.decoder(h)

''''''''''
CNN3D_90k:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1         [-1, 2, 1, 56, 56]              56
            Conv3d-2         [-1, 2, 1, 56, 56]             152
            Conv3d-3         [-1, 2, 1, 56, 56]             296
            Conv3d-4         [-1, 2, 1, 56, 56]              56
            Conv3d-5         [-1, 2, 1, 56, 56]             152
            Conv3d-6         [-1, 2, 1, 56, 56]             296
            Conv2d-7           [-1, 32, 56, 56]           9,632
       BatchNorm2d-8           [-1, 32, 56, 56]              64
              ReLU-9           [-1, 32, 56, 56]               0
           Conv2d-10           [-1, 64, 56, 56]          51,264
      BatchNorm2d-11           [-1, 64, 56, 56]             128
             ReLU-12           [-1, 64, 56, 56]               0
           Conv2d-13           [-1, 16, 56, 56]          25,616
           Conv2d-14            [-1, 1, 56, 56]             401
          Sigmoid-15            [-1, 1, 56, 56]               0
================================================================
Total params: 88,113
Trainable params: 88,113
Non-trainable params: 0
----------------------------------------------------------------
'''

class CNN3D_90k(nn.Module):
    """CNN that first mixes detector layers with shared 3D convolutions.

    The 4-channel input (emcal, hcal, trkn, trkp) is turned into two 3-layer
    stacks, (emcal, hcal, trkn) and (emcal, hcal, trkp). Each stack goes
    through the same three depth-3 Conv3d filters (3x3, 5x5 and 7x7 spatial
    kernels, 2 output channels each), which collapse the depth axis to 1. The
    resulting 12 channels pass through a 2-stage 5x5 conv encoder
    (12 -> 32 -> 64) and a decoder (64 -> 16 -> 1, sigmoid), giving a
    (B, 1, H, W) map in (0, 1) for any H, W.

    The 3D convs are applied to both stacks, so torchsummary lists (and
    counts) them twice; the model has 87,609 unique parameters.
    """
    def __init__(self):
        super(CNN3D_90k, self).__init__()
        self.conv3x3x3 = nn.Conv3d(1, 2, kernel_size=3, padding=(0, 1, 1))
        self.conv3x5x5 = nn.Conv3d(1, 2, kernel_size=(3, 5, 5), padding=(0, 2, 2))
        self.conv3x7x7 = nn.Conv3d(1, 2, kernel_size=(3, 7, 7), padding=(0, 3, 3))
        self.encoder = nn.Sequential(
            nn.Conv2d(12, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=5, padding=2),
            nn.Conv2d(16, 1, kernel_size=5, padding=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x: (B, 4, H, W) with channels (emcal, hcal, trkn, trkp).
        x = x.unsqueeze(1)                   # (B, 1, 4, H, W): detectors on the depth axis
        x_e_h_n = x[:, :, :3, :, :]          # emcal, hcal, trkn
        x_e_h_p = x[:, :, [0, 1, 3], :, :]   # emcal, hcal, trkp
        feats = [conv(stack)
                 for stack in (x_e_h_n, x_e_h_p)
                 for conv in (self.conv3x3x3, self.conv3x5x5, self.conv3x7x7)]
        # Each feature is (B, 2, 1, H, W). Merge the singleton depth into the
        # channel axis by flattening rather than a fixed-size view, so the batch
        # dimension is never folded when H, W differ from 56.
        x = torch.cat(feats, dim=1).flatten(1, 2)  # (B, 12, H, W)
        x = self.encoder(x)
        return self.decoder(x)

'''''''''''
UNet_90k:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 56, 56]             592
       BatchNorm2d-2           [-1, 16, 56, 56]              32
           Dropout-3           [-1, 16, 56, 56]               0
         LeakyReLU-4           [-1, 16, 56, 56]               0
         Conv_UNet-5           [-1, 16, 56, 56]               0
            Conv2d-6           [-1, 32, 28, 28]           2,080
         LeakyReLU-7           [-1, 32, 28, 28]               0
 DownSampling_UNet-8           [-1, 32, 28, 28]               0
            Conv2d-9           [-1, 32, 28, 28]           9,248
      BatchNorm2d-10           [-1, 32, 28, 28]              64
          Dropout-11           [-1, 32, 28, 28]               0
        LeakyReLU-12           [-1, 32, 28, 28]               0
        Conv_UNet-13           [-1, 32, 28, 28]               0
           Conv2d-14           [-1, 64, 14, 14]           8,256
        LeakyReLU-15           [-1, 64, 14, 14]               0
DownSampling_UNet-16           [-1, 64, 14, 14]               0
           Conv2d-17           [-1, 64, 14, 14]          36,928
      BatchNorm2d-18           [-1, 64, 14, 14]             128
          Dropout-19           [-1, 64, 14, 14]               0
        LeakyReLU-20           [-1, 64, 14, 14]               0
        Conv_UNet-21           [-1, 64, 14, 14]               0
           Conv2d-22           [-1, 32, 28, 28]           2,080
  UpSampling_UNet-23           [-1, 64, 28, 28]               0
           Conv2d-24           [-1, 32, 28, 28]          18,464
      BatchNorm2d-25           [-1, 32, 28, 28]              64
          Dropout-26           [-1, 32, 28, 28]               0
        LeakyReLU-27           [-1, 32, 28, 28]               0
        Conv_UNet-28           [-1, 32, 28, 28]               0
           Conv2d-29           [-1, 16, 56, 56]             528
  UpSampling_UNet-30           [-1, 32, 56, 56]               0
           Conv2d-31            [-1, 8, 56, 56]           2,312
      BatchNorm2d-32            [-1, 8, 56, 56]              16
          Dropout-33            [-1, 8, 56, 56]               0
        LeakyReLU-34            [-1, 8, 56, 56]               0
        Conv_UNet-35            [-1, 8, 56, 56]               0
           Conv2d-36            [-1, 1, 56, 56]              73
          Sigmoid-37            [-1, 1, 56, 56]               0
================================================================
Total params: 80,865
Trainable params: 80,865
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 7.90
Params size (MB): 0.31
Estimated Total Size (MB): 8.25
----------------------------------------------------------------
'''''''''''

class UNet_90k(nn.Module):
    """Two-level U-Net with a single-channel sigmoid output.

    Encoder: C1 (in -> 16) at full resolution, D1 + C2 (-> 32) at 1/2,
    D2 + C3 (-> 64) at 1/4.
    Decoder: U1 upsamples to 1/2 and concatenates the 1/2-res skip -> C4 (64 -> 32);
    U2 upsamples to full resolution and concatenates the full-res skip -> C5 (32 -> 8);
    then a 3x3 conv to 1 channel and a sigmoid.

    Args:
        in_channels: number of input channels (e.g. 4 detector images).
    """
    def __init__(self, in_channels):
        super(UNet_90k, self).__init__()
        self.in_channels = in_channels
        # Modules are created in the same order as the forward pass; this order
        # also fixes the RNG draws used for weight initialisation.
        self.C1 = Conv_UNet(self.in_channels, 16)
        self.D1 = DownSampling_UNet(16, 32)
        self.C2 = Conv_UNet(32, 32)
        self.D2 = DownSampling_UNet(32, 64)
        self.C3 = Conv_UNet(64, 64)
        self.U1 = UpSampling_UNet(64, 32)
        self.C4 = Conv_UNet(64, 32)
        self.U2 = UpSampling_UNet(32, 16)
        self.C5 = Conv_UNet(32, 8)
        self.pred = nn.Conv2d(8, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        skip_full = self.C1(x)
        skip_half = self.C2(self.D1(skip_full))
        bottleneck = self.C3(self.D2(skip_half))
        dec_half = self.C4(self.U1(bottleneck, skip_half))
        dec_full = self.C5(self.U2(dec_half, skip_full))
        logits = self.pred(dec_full)
        return self.sigmoid(logits)

'''''''''''
UnetwithSEBlock_90k:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 56, 56]             592
       BatchNorm2d-2           [-1, 16, 56, 56]              32
           Dropout-3           [-1, 16, 56, 56]               0
         LeakyReLU-4           [-1, 16, 56, 56]               0
         Conv_UNet-5           [-1, 16, 56, 56]               0
            Conv2d-6           [-1, 32, 28, 28]           2,080
         LeakyReLU-7           [-1, 32, 28, 28]               0
 DownSampling_UNet-8           [-1, 32, 28, 28]               0
            Conv2d-9           [-1, 32, 28, 28]           9,248
      BatchNorm2d-10           [-1, 32, 28, 28]              64
          Dropout-11           [-1, 32, 28, 28]               0
        LeakyReLU-12           [-1, 32, 28, 28]               0
        Conv_UNet-13           [-1, 32, 28, 28]               0
           Conv2d-14           [-1, 64, 14, 14]           8,256
        LeakyReLU-15           [-1, 64, 14, 14]               0
DownSampling_UNet-16           [-1, 64, 14, 14]               0
           Conv2d-17           [-1, 64, 14, 14]          36,928
      BatchNorm2d-18           [-1, 64, 14, 14]             128
          Dropout-19           [-1, 64, 14, 14]               0
        LeakyReLU-20           [-1, 64, 14, 14]               0
        Conv_UNet-21           [-1, 64, 14, 14]               0
           Linear-22                    [-1, 4]             256
           Linear-23                   [-1, 64]             256
          SEBlock-24           [-1, 64, 14, 14]               0
           Conv2d-25           [-1, 32, 28, 28]           2,080
  UpSampling_UNet-26           [-1, 64, 28, 28]               0
           Conv2d-27           [-1, 32, 28, 28]          18,464
      BatchNorm2d-28           [-1, 32, 28, 28]              64
          Dropout-29           [-1, 32, 28, 28]               0
        LeakyReLU-30           [-1, 32, 28, 28]               0
        Conv_UNet-31           [-1, 32, 28, 28]               0
           Conv2d-32           [-1, 16, 56, 56]             528
  UpSampling_UNet-33           [-1, 32, 56, 56]               0
           Conv2d-34            [-1, 8, 56, 56]           2,312
      BatchNorm2d-35            [-1, 8, 56, 56]              16
          Dropout-36            [-1, 8, 56, 56]               0
        LeakyReLU-37            [-1, 8, 56, 56]               0
        Conv_UNet-38            [-1, 8, 56, 56]               0
           Conv2d-39            [-1, 1, 56, 56]              73
          Sigmoid-40            [-1, 1, 56, 56]               0
================================================================
Total params: 81,377
Trainable params: 81,377
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 7.99
Params size (MB): 0.31
Estimated Total Size (MB): 8.35
----------------------------------------------------------------
'''''''''''
class UnetwithSEBlock_90k(nn.Module):
    """Two-level U-Net with a Squeeze-and-Excitation block on the bottleneck.

    Same encoder/decoder layout as a plain two-level U-Net (16 -> 32 -> 64
    channels down, 64 -> 32 -> 8 up, 3x3 conv + sigmoid head). The 1/4-resolution
    bottleneck features are channel-gated by SEBlock(64) before decoding.

    Args:
        in_channels: number of input channels (e.g. 4 detector images).
    """
    def __init__(self, in_channels):
        super(UnetwithSEBlock_90k, self).__init__()
        self.in_channels = in_channels
        # Encoder
        self.C1 = Conv_UNet(self.in_channels, 16)
        self.D1 = DownSampling_UNet(16, 32)
        self.C2 = Conv_UNet(32, 32)
        self.D2 = DownSampling_UNet(32, 64)
        self.C3 = Conv_UNet(64, 64)
        # Bottleneck channel attention
        self.se1 = SEBlock(64)
        # Decoder
        self.U1 = UpSampling_UNet(64, 32)
        self.C4 = Conv_UNet(64, 32)
        self.U2 = UpSampling_UNet(32, 16)
        self.C5 = Conv_UNet(32, 8)
        # Prediction head
        self.pred = nn.Conv2d(8, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        r1 = self.C1(x)
        r2 = self.C2(self.D1(r1))
        r3 = self.se1(self.C3(self.D2(r2)))
        up1 = self.C4(self.U1(r3, r2))
        c = self.C5(self.U2(up1, r1))
        return self.sigmoid(self.pred(c))

''''''''''''''''
UnetwithSelfattention_90k:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 56, 56]             592
       BatchNorm2d-2           [-1, 16, 56, 56]              32
           Dropout-3           [-1, 16, 56, 56]               0
         LeakyReLU-4           [-1, 16, 56, 56]               0
         Conv_UNet-5           [-1, 16, 56, 56]               0
            Conv2d-6           [-1, 32, 28, 28]           2,080
         LeakyReLU-7           [-1, 32, 28, 28]               0
 DownSampling_UNet-8           [-1, 32, 28, 28]               0
            Conv2d-9           [-1, 32, 28, 28]           9,248
      BatchNorm2d-10           [-1, 32, 28, 28]              64
          Dropout-11           [-1, 32, 28, 28]               0
        LeakyReLU-12           [-1, 32, 28, 28]               0
        Conv_UNet-13           [-1, 32, 28, 28]               0
           Conv2d-14           [-1, 64, 14, 14]           8,256
        LeakyReLU-15           [-1, 64, 14, 14]               0
DownSampling_UNet-16           [-1, 64, 14, 14]               0
           Conv2d-17           [-1, 64, 14, 14]          36,928
      BatchNorm2d-18           [-1, 64, 14, 14]             128
          Dropout-19           [-1, 64, 14, 14]               0
        LeakyReLU-20           [-1, 64, 14, 14]               0
        Conv_UNet-21           [-1, 64, 14, 14]               0
           Conv2d-22           [-1, 64, 14, 14]           4,160
           Conv2d-23           [-1, 64, 14, 14]           4,160
           Conv2d-24           [-1, 64, 14, 14]           4,160
          Softmax-25             [-1, 196, 196]               0
    SelfAttention-26           [-1, 64, 14, 14]               0
           Conv2d-27           [-1, 32, 28, 28]           2,080
  UpSampling_UNet-28           [-1, 64, 28, 28]               0
           Conv2d-29           [-1, 32, 28, 28]          18,464
      BatchNorm2d-30           [-1, 32, 28, 28]              64
          Dropout-31           [-1, 32, 28, 28]               0
        LeakyReLU-32           [-1, 32, 28, 28]               0
        Conv_UNet-33           [-1, 32, 28, 28]               0
           Conv2d-34           [-1, 16, 56, 56]             528
  UpSampling_UNet-35           [-1, 32, 56, 56]               0
           Conv2d-36            [-1, 8, 56, 56]           2,312
      BatchNorm2d-37            [-1, 8, 56, 56]              16
          Dropout-38            [-1, 8, 56, 56]               0
        LeakyReLU-39            [-1, 8, 56, 56]               0
        Conv_UNet-40            [-1, 8, 56, 56]               0
           Conv2d-41            [-1, 1, 56, 56]              73
          Sigmoid-42            [-1, 1, 56, 56]               0
================================================================
Total params: 93,345
Trainable params: 93,345
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 8.57
Params size (MB): 0.36
Estimated Total Size (MB): 8.98
----------------------------------------------------------------
'''''''''''''''
class UnetwithSelfattention_90k(nn.Module):
    """Compact U-Net with a self-attention block at the bottleneck.

    Encoder: Conv(in->16) -> Down(16->32) -> Conv(32->32) -> Down(32->64) -> Conv(64->64),
    followed by SelfAttention(64). Decoder: Up(64->32) merged with the 32-channel skip,
    Conv(64->32), Up(32->16) merged with the 16-channel skip, Conv(32->8).
    A 3x3 Conv2d(8->1) plus a sigmoid produces the single-channel output map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Submodules are created in the same order as before so that parameter
        # initialisation (RNG consumption) and state_dict keys are unchanged.
        self.C1 = Conv_UNet(self.in_channels, 16)
        self.D1 = DownSampling_UNet(16, 32)
        self.C2 = Conv_UNet(32, 32)
        self.D2 = DownSampling_UNet(32, 64)
        self.C3 = Conv_UNet(64, 64)
        self.attention = SelfAttention(64)
        self.U1 = UpSampling_UNet(64, 32)
        self.C4 = Conv_UNet(64, 32)
        self.U2 = UpSampling_UNet(32, 16)
        self.C5 = Conv_UNet(32, 8)
        self.pred = nn.Conv2d(8, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder path; the two intermediate maps are reused as skip connections.
        skip_full = self.C1(x)
        skip_half = self.C2(self.D1(skip_full))
        bottleneck = self.attention(self.C3(self.D2(skip_half)))
        # Decoder path: upsample and fuse with the matching encoder map.
        decoded = self.C4(self.U1(bottleneck, skip_half))
        decoded = self.C5(self.U2(decoded, skip_full))
        return self.sigmoid(self.pred(decoded))


'''''''''''''''''
Unet3D_90k:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1         [-1, 2, 1, 56, 56]              56
            Conv3d-2         [-1, 2, 1, 56, 56]             152
            Conv3d-3         [-1, 2, 1, 56, 56]             296
            Conv3d-4         [-1, 2, 1, 56, 56]              56
            Conv3d-5         [-1, 2, 1, 56, 56]             152
            Conv3d-6         [-1, 2, 1, 56, 56]             296
            Conv2d-7           [-1, 16, 56, 56]           1,744
       BatchNorm2d-8           [-1, 16, 56, 56]              32
           Dropout-9           [-1, 16, 56, 56]               0
        LeakyReLU-10           [-1, 16, 56, 56]               0
        Conv_UNet-11           [-1, 16, 56, 56]               0
           Conv2d-12           [-1, 32, 28, 28]           2,080
        LeakyReLU-13           [-1, 32, 28, 28]               0
DownSampling_UNet-14           [-1, 32, 28, 28]               0
           Conv2d-15           [-1, 32, 28, 28]           9,248
      BatchNorm2d-16           [-1, 32, 28, 28]              64
          Dropout-17           [-1, 32, 28, 28]               0
        LeakyReLU-18           [-1, 32, 28, 28]               0
        Conv_UNet-19           [-1, 32, 28, 28]               0
           Conv2d-20           [-1, 64, 14, 14]           8,256
        LeakyReLU-21           [-1, 64, 14, 14]               0
DownSampling_UNet-22           [-1, 64, 14, 14]               0
           Conv2d-23           [-1, 64, 14, 14]          36,928
      BatchNorm2d-24           [-1, 64, 14, 14]             128
          Dropout-25           [-1, 64, 14, 14]               0
        LeakyReLU-26           [-1, 64, 14, 14]               0
        Conv_UNet-27           [-1, 64, 14, 14]               0
           Conv2d-28           [-1, 32, 28, 28]           2,080
  UpSampling_UNet-29           [-1, 64, 28, 28]               0
           Conv2d-30           [-1, 32, 28, 28]          18,464
      BatchNorm2d-31           [-1, 32, 28, 28]              64
          Dropout-32           [-1, 32, 28, 28]               0
        LeakyReLU-33           [-1, 32, 28, 28]               0
        Conv_UNet-34           [-1, 32, 28, 28]               0
           Conv2d-35           [-1, 16, 56, 56]             528
  UpSampling_UNet-36           [-1, 32, 56, 56]               0
           Conv2d-37            [-1, 8, 56, 56]           2,312
      BatchNorm2d-38            [-1, 8, 56, 56]              16
          Dropout-39            [-1, 8, 56, 56]               0
        LeakyReLU-40            [-1, 8, 56, 56]               0
        Conv_UNet-41            [-1, 8, 56, 56]               0
           Conv2d-42            [-1, 1, 56, 56]              73
          Sigmoid-43            [-1, 1, 56, 56]               0
             UNet-44            [-1, 1, 56, 56]               0
================================================================
Total params: 83,025
Trainable params: 83,025
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 8.21
Params size (MB): 0.32
Estimated Total Size (MB): 8.57
----------------------------------------------------------------
'''''''''''''''''

class Unet3D_90k(nn.Module):
    """3-D convolutional front end followed by a 2-D U-Net back end.

    The 4-channel input (emcal, hcal, trkn, trkp) is turned into two 3-slice stacks:
    (emcal, hcal, trkn) and (emcal, hcal, trkp). Each stack goes through three
    Conv3d filters (3x3x3, 3x5x5, 3x7x7). These collapse the depth of 3 to 1 and
    keep height/width, giving 2 stacks x 3 filters x 2 channels = 12 feature maps
    for the U-Net.
    """

    def __init__(self):
        super(Unet3D_90k, self).__init__()
        # Each Conv3d maps (B, 1, 3, H, W) -> (B, 2, 1, H, W): no depth padding,
        # "same" padding in H and W.
        self.conv3x3x3 = nn.Conv3d(1, 2, kernel_size=3, padding=(0, 1, 1))
        self.conv3x5x5 = nn.Conv3d(1, 2, kernel_size=(3, 5, 5), padding=(0, 2, 2))
        self.conv3x7x7 = nn.Conv3d(1, 2, kernel_size=(3, 7, 7), padding=(0, 3, 3))
        # The previous `Unet3D_90k(12)` re-instantiated this very class with an
        # argument its __init__ does not take, so construction always failed with
        # TypeError. The recorded model summary lists a `UNet` submodule whose first
        # conv takes 12 input channels, so the 2-D back end is UNet(12).
        # NOTE(review): assumes a `UNet(in_channels)` class is defined elsewhere in this notebook.
        self.unet = UNet(12)

    def forward(self, x):
        # x = torch.cat((emcal, hcal, trkn, trkp), dim=1), shape (B, 4, H, W)
        x = x.unsqueeze(1)                      # (B, 1, 4, H, W)
        x_e_h_n = x[:, :, :3, :, :]             # emcal, hcal, trkn
        x_e_h_p = x[:, :, [0, 1, 3], :, :]      # emcal, hcal, trkp
        # Same order as before: all three filters on (e, h, n), then on (e, h, p).
        features = [conv(stack)
                    for stack in (x_e_h_n, x_e_h_p)
                    for conv in (self.conv3x3x3, self.conv3x5x5, self.conv3x7x7)]
        # (B, 12, 1, H, W) -> (B, 12, H, W). flatten() drops the unit depth axis
        # without hard-coding the 56x56 image size that view() required.
        x = torch.cat(features, dim=1).flatten(1, 2)
        return self.unet(x)

In [None]:
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch


def train_one_epoch(model, optimizer, data_loader, device, epoch, loss_function):
    """Run one training epoch and return the running mean of the per-batch losses.

    Args:
        model: network to train (switched to train mode).
        optimizer: optimizer that updates ``model``'s parameters.
        data_loader: iterable of ``(batch_X, batch_Y)`` pairs.
        device: device the batches are moved to.
        epoch: epoch index, used only in the progress-bar text.
        loss_function: callable ``loss_function(outputs, targets)`` returning a scalar tensor.

    Returns:
        float: unweighted mean of the batch losses seen during the epoch.

    Calls ``sys.exit(1)`` on the first non-finite loss.
    """
    model.train()
    mean_loss = torch.zeros(1, device=device)
    progress = tqdm(data_loader, file=sys.stdout)

    for step, (batch_X, batch_Y) in enumerate(progress):
        optimizer.zero_grad()
        outputs = model(batch_X.to(device))
        loss = loss_function(outputs, batch_Y.to(device))

        # Check before backward()/step(): previously the NaN/Inf gradients were
        # already applied to the weights when the check fired.
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        loss.backward()
        optimizer.step()

        # Incremental update of the running mean over batches.
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)
        progress.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 7))

    return mean_loss.item()


@torch.no_grad()
def evaluate(model, data_loader, device, loss_function):
    """Return the sample-weighted mean loss of ``model`` over ``data_loader``.

    Each batch loss is weighted by its batch size, so a short final batch no longer
    counts as much as a full one. This is the exact per-sample mean when
    ``loss_function`` averages over the batch (e.g. ``MSELoss()`` with the default
    ``reduction='mean'``).

    Args:
        model: network to evaluate (switched to eval mode).
        data_loader: iterable of ``(batch_X, batch_Y)`` tensor pairs.
        device: device the batches are moved to.
        loss_function: callable ``loss_function(outputs, targets)`` returning a scalar tensor.

    Returns:
        float: the weighted mean loss, or ``nan`` if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss = torch.zeros(1, device=device)
    n_samples = 0
    for batch_X, batch_Y in data_loader:
        outputs = model(batch_X.to(device))
        batch_size = batch_X.shape[0]
        total_loss += loss_function(outputs, batch_Y.to(device)).detach() * batch_size
        n_samples += batch_size
    if n_samples == 0:
        return float('nan')
    return (total_loss / n_samples).item()

@torch.no_grad()
def plot_image(net, data_loader, device, label, max_plots=5):
    """Build side-by-side ground-truth / prediction figures from the first batch.

    Args:
        net: model to evaluate (switched to eval mode).
        data_loader: iterable of ``(batch_X, batch_Y)``; only the first batch is used.
        device: device ``batch_X`` is moved to.
        label: figure super-title.
        max_plots: maximum number of samples (one figure each) to draw.

    Returns:
        list of matplotlib figures, one per plotted sample. The list is empty if
        ``data_loader`` yields nothing.
    """
    net.eval()
    fig_list = []
    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        return fig_list
    batch_X, batch_Y = first_batch
    outputs = net(batch_X.to(device)).detach()
    # Use the real size of this batch. data_loader.batch_size can exceed it (dataset
    # smaller than one batch -> IndexError) or be None (custom batch_sampler).
    plot_num = min(len(batch_X), max_plots)
    for i in range(plot_num):
        fig = plt.figure()
        fig.suptitle(label, fontsize=16)
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)
        ax1.imshow(batch_Y[i].cpu().numpy().squeeze(), cmap='jet')
        ax1.axis('off')
        ax1.set_title('Ground Truth')
        ax2.imshow(outputs[i].cpu().numpy().squeeze(), cmap='jet')
        ax2.axis('off')
        ax2.set_title('Prediction')
        fig_list.append(fig)
    return fig_list

In [None]:
import os
import math
import argparse
import random
import numpy as np
import tifffile
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
import shutil

#from model import CNN3D
#from DataSet import MaxMinNormalizeGlobalPerChannel,MyDataSet, dataset_2
#from train_and_eval import train_one_epoch, evaluate,plot_image

SEED = 26

# PYTHONHASHSEED is read only when a Python interpreter starts. Setting it here
# affects Python subprocesses launched later, not hashing in this running kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)
# With torch.use_deterministic_algorithms(True), cuBLAS matmuls on CUDA raise a
# RuntimeError unless a fixed workspace config is set. The variable must be in
# place before the first cuBLAS call in this process.
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

def _plot_test_samples(model, data_loader, device, num_cases=5, height=56, width=56):
    """Show ground truth (left) vs prediction (right) for the first ``num_cases`` samples.

    Only as many batches as needed to collect ``num_cases`` samples are run through
    ``model``. Nothing is plotted if ``data_loader`` is empty.
    """
    model.eval()
    predicted, truth = [], []
    collected = 0
    with torch.no_grad():
        for X_batch, Y_batch in data_loader:
            predicted.append(model(X_batch.to(device)).cpu().numpy())
            truth.append(Y_batch.cpu().numpy())
            collected += len(X_batch)
            if collected >= num_cases:
                break
    n_plot = min(num_cases, collected)
    if n_plot == 0:
        return
    predicted = np.concatenate(predicted, axis=0)
    truth = np.concatenate(truth, axis=0)

    # squeeze=False keeps `axes` 2-D even when only one row is drawn.
    fig, axes = plt.subplots(n_plot, 2, figsize=(10, 4 * n_plot), squeeze=False)
    for i in range(n_plot):
        # Each sample holds one height x width map, with or without a leading channel axis.
        true_img = truth[i].reshape(height, width)
        pred_img = predicted[i].reshape(height, width)
        axes[i, 0].imshow(true_img)
        axes[i, 0].set_title(f'True Image {i+1}')
        axes[i, 0].axis('off')
        axes[i, 1].imshow(pred_img)
        axes[i, 1].set_title(f'Predicted Image {i+1}')
        axes[i, 1].axis('off')
    plt.show()


def _export_jet_predictions(model, dataset, device, save_dir, batch_size=200):
    """Predict every sample of ``dataset`` and save channel 0 of each prediction.

    Files are written as ``<save_dir>/predict_<index>.tiff`` with raw float values
    and no rescaling. ``save_dir`` is deleted and recreated first, so files from a
    previous run never mix with the new ones.
    """
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
        print(f"文件夹 '{save_dir}' 及其内容已删除。")
        os.makedirs(save_dir)
        print(f"文件夹 '{save_dir}' 已重新创建。")
    else:
        os.makedirs(save_dir)
        print(f"文件夹 '{save_dir}' 已创建。")

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model.eval()
    batches = []
    with torch.no_grad():
        for X_batch, _ in loader:
            batches.append(model(X_batch.to(device)).cpu().numpy())
    if not batches:
        return
    predictions = np.concatenate(batches, axis=0)
    for index, prediction in enumerate(predictions):
        tifffile.imwrite(os.path.join(save_dir, f"predict_{index}.tiff"), prediction[0])


def train(args):
    """Train ``CNN_90k`` and log the run to TensorBoard.

    After training, it plots a few test samples and exports every jet-set
    prediction as a TIFF file under ``predicted_images_CNN3D``.

    ``args`` must provide:
        device, img_dir, jet_dir, batch_size, lr, epochs,
        weights (checkpoint path or None), freeze_layers,
        patten ("train" -> per-epoch metric on the validation split,
                anything else -> on the test split),
        saving_routine (save a checkpoint every N epochs; the last epoch is
                        always saved under ./weights).
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    tb_writer = SummaryWriter(log_dir="runs/CNN_80k/Demo1")
    os.makedirs("./weights", exist_ok=True)

    # Pre-processing: global per-channel min-max normalisation for both datasets.
    data_transform = {
        "without_jet": transforms.Compose([MaxMinNormalizeGlobalPerChannel()]),
        "jet": transforms.Compose([MaxMinNormalizeGlobalPerChannel()])}

    # Non-jet dataset, split into train / validation / test subsets.
    data_set = MyDataSet(img_dir=args.img_dir,
                         group_size=10000,
                         size_in=10000,
                         splition=True,
                         split_shuffle=False,
                         transform=data_transform["without_jet"])
    train_dataset = dataset_2(data_set.train_X, data_set.train_Y)
    val_dataset = dataset_2(data_set.val_X, data_set.val_Y)
    test_dataset = dataset_2(data_set.test_X, data_set.test_Y)
    # Jet dataset, unsplit. It is used for the periodic plots and for the final TIFF
    # export (it is no longer reloaded from disk for the export).
    data_set_jet = MyDataSet(img_dir=args.jet_dir,
                             group_size=1000,
                             size_in=1000,
                             splition=False,
                             split_shuffle=False,
                             transform=data_transform["jet"])
    jet_dataset = dataset_2(data_set_jet.X, data_set_jet.Y)

    batch_size = args.batch_size
    # NOTE(review): the trailing 0 in this min() caps the worker count at 0, so all
    # loading runs in the main process. Raise it (e.g. to 8) to enable workers.
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 0])
    print('Using {} dataloader workers every process'.format(nw))
    loader_kwargs = dict(batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=nw)
    train_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)
    jet_loader = torch.utils.data.DataLoader(jet_dataset, batch_size=batch_size, shuffle=False)

    model = CNN_90k().to(device)

    # Log the model graph to TensorBoard with a dummy (1, 4, 56, 56) input.
    init_img = torch.zeros((1, 4, 56, 56), device=device)
    tb_writer.add_graph(model, init_img)

    if args.weights is None:
        print("No weights file provided. Using random defaults.")
    else:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(args.weights, map_location=device))
        print("using pretrain-weights.")

    if args.freeze_layers:
        print("freeze layers except fc layer.")
        for name, para in model.named_parameters():
            # Freeze every parameter whose name does not contain "decoder".
            if "decoder" not in name:
                para.requires_grad_(False)

    # LR multiplier schedule:
    #   1 on [0, 40), 0.1 on [40, 80), a linear ramp from 0.1 towards 0.6 on [80, 83),
    #   then a cosine decay from 0.6 down to 0.1 at args.epochs.
    warmup_epochs_1 = 40
    warmup_epochs_2 = 80
    warmup_epochs_3 = 83
    learningrate = args.lr
    # LambdaLR evaluates the lambda at epoch == args.epochs after the final
    # scheduler.step(). With args.epochs == warmup_epochs_3 the cosine term used to
    # divide by zero. For args.epochs > warmup_epochs_3 this value is unchanged.
    cosine_span = max(1, args.epochs - warmup_epochs_3)

    def lf_function(epoch):
        if epoch < warmup_epochs_1:
            return 1
        elif epoch < warmup_epochs_2:
            return 0.1
        elif epoch < warmup_epochs_3:
            return ((epoch - warmup_epochs_2) / (warmup_epochs_3 - warmup_epochs_2)) * 0.5 + 0.1
        else:
            return ((1 + math.cos((epoch - warmup_epochs_3) * math.pi / cosine_span)) / 2) * 0.5 + 0.1

    optimizer = optim.Adam(model.parameters(), lr=learningrate)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf_function)
    loss_function = torch.nn.MSELoss()

    for epoch in range(args.epochs):
        train_loss = train_one_epoch(model=model,
                                     optimizer=optimizer,
                                     data_loader=train_loader,
                                     device=device,
                                     epoch=epoch,
                                     loss_function=loss_function)
        scheduler.step()

        # NOTE(review): unless args.patten == "train", the per-epoch metric is
        # computed on the test split rather than the validation split.
        eval_loader = val_loader if args.patten == "train" else test_loader
        test_loss = evaluate(model=model,
                             data_loader=eval_loader,
                             device=device,
                             loss_function=loss_function)

        print("[epoch {}] loss: {}".format(epoch, round(test_loss, 7)))
        tags = ["train_loss", "test_loss", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], test_loss, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        # Every 10 epochs, log qualitative predictions (validation split and jet set).
        if (epoch + 1) % 10 == 0:
            fig_test = plot_image(net=model,
                                  data_loader=val_loader,
                                  device=device,
                                  label="test")
            fig_jet = plot_image(net=model,
                                 data_loader=jet_loader,
                                 device=device,
                                 label="jet")
            if fig_test is not None:
                tb_writer.add_figure("predictions without jet",
                                     figure=fig_test,
                                     global_step=epoch)
            if fig_jet is not None:
                tb_writer.add_figure("predictions with jet",
                                     figure=fig_jet,
                                     global_step=epoch)

        if ((epoch + 1) % args.saving_routine == 0) or (epoch == args.epochs - 1):
            torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))

    # Flush pending TensorBoard events; the writer was previously never closed.
    tb_writer.close()

    _plot_test_samples(model, test_loader, device, num_cases=5)

    # Export every jet-set prediction as a float TIFF (no rescaling).
    _export_jet_predictions(model, jet_dataset, device,
                            save_dir='predicted_images_CNN3D', batch_size=200)

    
class Args:
    """Default hyper-parameters and paths consumed by ``train``."""

    def __init__(self):
        # Optimisation schedule.
        self.epochs = 100
        self.batch_size = 400
        self.lr = 0.001
        self.saving_routine = 20  # save a checkpoint every N epochs

        # Run mode: "train" evaluates each epoch on the validation split,
        # any other value on the test split.
        self.patten = "Parameter"

        # Data locations. Edit these to point at your own image / jet directories.
        self.img_dir = '/kaggle/input/gauss-s1-00-nl0-30-b0-50'
        self.jet_dir = '/kaggle/input/gauss-s1-00-nl0-30-b0-50-jet'

        # Model initialisation: path to pretrained weights (None = random init).
        self.weights = None
        self.freeze_layers = False

        # Miscellaneous.
        self.num_classes = 5  # not read by train()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the default configuration and run the full training / export pipeline.
opt = Args()
train(args=opt)

In [None]:
# Convert the 1000 predicted TIFF images in a folder into one text file (CNN3D is used as the example).
# Output layout:
#   - first, one line per image with its number of positive pixels;
#   - then, image by image, one "px py pz value" line per positive pixel.
# (px, py, pz) is the unit vector to that pixel's centre on a cylinder whose
# circumference equals the image width. Rows run along the cylinder axis, centred on 0.
import os
import sys
import math

import numpy as np
import tifffile as tiff

# NOTE(review): the training cell saves its TIFFs under 'predicted_images_CNN3D' in the
# working directory; confirm this folder name matches the run being exported.
directory_path = '/kaggle/working/predicted_images_CNN3D_80k'
output_path = '/kaggle/working/output_CNN3D_80k_predict.txt'
N_IMAGES = 1000

# Each file is read once. The old second pass opened "precdict_{n}.tiff" (typo).
# That name does not match the "predict_{n}.tiff" files, so it raised FileNotFoundError.
images = [tiff.imread(os.path.join(directory_path, f"predict_{num}.tiff"))
          for num in range(N_IMAGES)]
nums = [int(np.count_nonzero(image > 0)) for image in images]

# Write through a handle that is flushed and closed when the block ends. The old code
# permanently replaced sys.stdout, which left the file unflushed and swallowed all
# later notebook output.
with open(output_path, 'w') as out:
    for count in nums:
        print(count, file=out)
    for image in images:
        height, width = image.shape[0], image.shape[1]
        radius = width / 2 / math.pi
        # argwhere yields positive pixels in row-major order, like the old nested loops.
        for row, col in np.argwhere(image > 0):
            i, j = int(row), int(col)
            pixel_value = image[i, j]
            pz = i + 0.5 - height / 2
            angle = (j + 0.5) / width * 2 * math.pi
            py = radius * math.cos(angle)
            px = radius * math.sin(angle)
            pr = math.sqrt(px * px + py * py + pz * pz)
            print(f'{px / pr} {py / pr} {pz / pr} {pixel_value}', file=out)

In [None]:
# No need to run: the results have already been recorded.
# Convert the 1000 ground-truth TIFF images into one text file. Output layout:
#   - first, one line per image with its number of positive pixels;
#   - then, image by image, one "px py pz value" line per positive pixel.
# (px, py, pz) is the unit vector to that pixel's centre on a cylinder whose
# circumference equals the image width. Rows run along the cylinder axis, centred on 0.
import os
import sys
import math

import numpy as np
import tifffile as tiff

directory_path = '/kaggle/input/gauss-s1-00-nl0-30-b0-50-jet'
output_path = '/kaggle/working/output_truth.txt'
N_IMAGES = 1000

# Each file is read once (it used to be read twice).
images = [tiff.imread(os.path.join(directory_path, f"truth_{num}.tiff"))
          for num in range(N_IMAGES)]
nums = [int(np.count_nonzero(image > 0)) for image in images]

# Write through a handle that is flushed and closed when the block ends, instead of
# permanently replacing sys.stdout.
with open(output_path, 'w') as out:
    for count in nums:
        print(count, file=out)
    for image in images:
        height, width = image.shape[0], image.shape[1]
        radius = width / 2 / math.pi
        # argwhere yields positive pixels in row-major order, like the old nested loops.
        for row, col in np.argwhere(image > 0):
            i, j = int(row), int(col)
            pixel_value = image[i, j]
            pz = i + 0.5 - height / 2
            angle = (j + 0.5) / width * 2 * math.pi
            py = radius * math.cos(angle)
            px = radius * math.sin(angle)
            pr = math.sqrt(px * px + py * py + pz * pz)
            print(f'{px / pr} {py / pr} {pz / pr} {pixel_value}', file=out)

In [None]:
# No need to run: the results have already been recorded.
# Convert the 1000 jet TIFF images into one text file. Output layout:
#   - first, one line per image with its number of positive pixels;
#   - then, image by image, one "px py pz value" line per positive pixel.
# (px, py, pz) is the unit vector to that pixel's centre on a cylinder whose
# circumference equals the image width. Rows run along the cylinder axis, centred on 0.
import os
import sys
import math

import numpy as np
import tifffile as tiff

directory_path = '/kaggle/input/gauss-s1-00-nl0-30-b0-50-jet'
output_path = '/kaggle/working/output_jet.txt'
N_IMAGES = 1000

# Each file is read once (it used to be read twice).
images = [tiff.imread(os.path.join(directory_path, f"jet_{num}.tiff"))
          for num in range(N_IMAGES)]
nums = [int(np.count_nonzero(image > 0)) for image in images]

# Write through a handle that is flushed and closed when the block ends, instead of
# permanently replacing sys.stdout.
with open(output_path, 'w') as out:
    for count in nums:
        print(count, file=out)
    for image in images:
        height, width = image.shape[0], image.shape[1]
        radius = width / 2 / math.pi
        # argwhere yields positive pixels in row-major order, like the old nested loops.
        for row, col in np.argwhere(image > 0):
            i, j = int(row), int(col)
            pixel_value = image[i, j]
            pz = i + 0.5 - height / 2
            angle = (j + 0.5) / width * 2 * math.pi
            py = radius * math.cos(angle)
            px = radius * math.sin(angle)
            pr = math.sqrt(px * px + py * py + pz * pz)
            print(f'{px / pr} {py / pr} {pz / pr} {pixel_value}', file=out)


In [None]:
# Jet-matching error analysis (CNN3D is used as the example).
import math
import os
import random
import sys
from math import pi

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # kept for older matplotlib, which needs it for the '3d' projection

# Everything printed by the rest of this cell goes to the errors file.
# buffering=1 (line buffering) pushes each printed line to disk immediately. This
# handle is never closed, so with default block buffering the end of the output
# could stay in memory and the file would look truncated while the kernel is alive.
sys.stdout = open('/kaggle/working/errors_CNN3D_80k_jet_2dis.txt', 'w', buffering=1)
def chord_length(x1, y1, x2, y2, r):
    """Return the arc length on a circle of radius ``r`` cut off by the chord (x1, y1)-(x2, y2).

    Despite the name, the result is the arc length s = r * theta. For a chord of
    length L, the central angle is theta = 2 * asin(L / (2r)).

    The previous formula acos((L^2 + r^2 - d^2) / (2 L r)) with d^2 = r^2 - L^2/4
    reduces to acos(5L / (8r)). That gives about pi*r/2 for a zero-length chord and
    a domain error for a diameter.

    Raises:
        ValueError: if the chord is longer than the diameter (L > 2r).
        ZeroDivisionError: if r == 0.
    """
    chord = math.hypot(x2 - x1, y2 - y1)
    half_angle = math.asin(chord / (2 * r))
    return 2 * half_angle * r
    
def calculate_spatial_coordinates(rap, phi, Pt, E):
    """Return the momentum components divided by the energy, ``(p_x/E, p_y/E, p_z/E)``.

    Uses p_x = Pt*cos(phi), p_y = Pt*sin(phi) and p_z = E*sinh(rap), so the third
    value equals sinh(rap).
    NOTE(review): with rapidity y, the usual relation is p_z = m_T*sinh(y) = E*tanh(y);
    confirm that E*sinh(rap) is intended.
    Raises ZeroDivisionError when E == 0.
    """
    momentum = (Pt * math.cos(phi), Pt * math.sin(phi), E * math.sinh(rap))
    return tuple(component / E for component in momentum)

def plot_filled_circle_in_3d(px=0, py=0, area=0, height=0, color="black", axes=None):
    """Draw one jet on a 3-D axes as a filled disk plus a thin vertical cylinder.

    The disk (radius ``sqrt(area / (2*pi))``) is centred at (px, py) in the z = 0
    plane. A cylinder of radius 0.1 rises from it to ``height``. Axis limits and
    labels ('y', 'φ', 'Pt(GeV)') are set on every call.

    Args:
        px, py: disk centre on the two horizontal axes.
        area: value controlling the disk size (see the radius above).
        height: cylinder height on the vertical axis.
        color: matplotlib colour used for every artist.
        axes: 3-D axes to draw on. When omitted, falls back to a module-level
            ``ax``, which the original version always used.
    """
    target = axes if axes is not None else ax
    theta = np.linspace(0, 2 * np.pi, 100)
    # NOTE(review): a disk of area A has radius sqrt(A / pi). The extra /2 draws a
    # disk of half that area; confirm this scaling is intended.
    radius = math.sqrt(area / 2 / pi)
    target.set_xlim(-3, 3)
    target.set_ylim(-1, 7)
    target.set_zlim(0, 200)

    # Disk outline in the z = 0 plane.
    outline_x = radius * np.cos(theta) + px
    outline_y = radius * np.sin(theta) + py
    target.plot(outline_x, outline_y, np.zeros_like(outline_x), color=color)

    # Disk fill: one surface between the upper and lower semicircles. The original
    # redrew this identical surface 100 times inside a loop whose variable was never
    # used, which only stacked the alpha.
    x_fill = np.linspace(-radius, radius, 50)
    y_fill = np.sqrt(radius ** 2 - x_fill ** 2)
    z_fill = np.zeros_like(x_fill)
    target.plot_surface(np.array([x_fill + px, x_fill + px]),
                        np.array([y_fill + py, -y_fill + py]),
                        np.array([z_fill, z_fill]),
                        color=color, alpha=0.5)

    # Thin cylinder (radius 0.1) from z = 0 to z = height, plus surfaces at both ends.
    cyl_x = 0.1 * np.cos(theta) + px
    cyl_y = 0.1 * np.sin(theta) + py
    cyl_z = np.linspace(0, height, 100)
    X, Z = np.meshgrid(cyl_x, cyl_z)
    Y = np.meshgrid(cyl_y, cyl_z)[0]
    target.plot_surface(X, Y, Z, color=color, alpha=0.8)
    target.plot_surface(X, Y, np.full_like(X, cyl_z[-1]), color=color, alpha=0.5)
    target.plot_surface(X, Y, np.full_like(X, cyl_z[0]), color=color, alpha=0.5)

    target.set_xlabel('y')
    target.set_ylabel('φ')
    target.set_zlabel('Pt(GeV)')

# Read the three clustering result files (truth / jet / prediction) into flat
# per-cluster lists, keeping the global names used by the matching loop below.
# NOTE(review): 'kagge/input/...' is a path relative to the working directory (possibly
# meant to be '/kaggle/input/<dataset>/...'); confirm where these result files live.


def _parse_cluster_file(path):
    """Parse one clustering result file.

    The file is read as a flat stream of whitespace-separated numbers made of
    repeated groups ``N, (rap, phi, pt, E, area) * N``, one group per event.

    Returns:
        tuple ``(numbers, clusters, raps, phis, pts, energies, areas, xs, ys, zs)``.
        ``numbers`` is every value in the file; ``clusters[k]`` is the cluster count of
        event k. The remaining lists hold one entry per cluster, with events
        concatenated in file order. ``ys`` / ``xs`` / ``zs`` receive the first / second /
        third value returned by ``calculate_spatial_coordinates``; the x/y swap is kept
        from the original code.
    """
    with open(path, 'r') as fh:
        numbers = [float(token) for token in fh.read().split()]

    clusters, raps, phis, pts, energies, areas = [], [], [], [], [], []
    xs, ys, zs = [], [], []
    pos = 0
    while pos < len(numbers):
        count = int(numbers[pos])
        clusters.append(count)
        pos += 1
        for _ in range(count):
            if pos >= len(numbers):
                break
            rap, phi, pt, energy, area = (numbers[pos + k] for k in range(5))
            raps.append(rap)
            phis.append(phi)
            pts.append(pt)
            energies.append(energy)
            areas.append(area)
            first, second, third = calculate_spatial_coordinates(rap, phi, pt, energy)
            ys.append(first)
            xs.append(second)
            zs.append(third)
            pos += 5
    return numbers, clusters, raps, phis, pts, energies, areas, xs, ys, zs


(truth_numbers, truth_clusters, truth_raps, truth_phis, truth_pts, truth_E,
 truth_areas, truth_x, truth_y, truth_z) = _parse_cluster_file('kagge/input/result_truth.txt')
truth_numbers_size = len(truth_numbers)
truth_clusters_size = len(truth_clusters)
truth_pos = 0

(jet_numbers, jet_clusters, jet_raps, jet_phis, jet_pts, jet_E,
 jet_areas, jet_x, jet_y, jet_z) = _parse_cluster_file('kagge/input/result_jet.txt')
jet_numbers_size = len(jet_numbers)
jet_clusters_size = len(jet_clusters)
jet_pos = 0

(predict_numbers, predict_clusters, predict_raps, predict_phis, predict_pts, predict_E,
 predict_areas, predict_x, predict_y, predict_z) = _parse_cluster_file('kagge/input/result_predict.txt')
predict_numbers_size = len(predict_numbers)
predict_clusters_size = len(predict_clusters)
predict_pos = 0

# The matching loop advances each offset with `pos += clusters[i - 1]`. Zeroing the
# last entry makes that add nothing at i = 0; as a side effect it also zeroes the
# final event's cluster count. The guards avoid an IndexError on empty files.
if truth_clusters:
    truth_clusters[-1] = 0
if jet_clusters:
    jet_clusters[-1] = 0
if predict_clusters:
    predict_clusters[-1] = 0

def distance(x1, y1, x2, y2):
    """Return the Euclidean distance between the 2-D points (x1, y1) and (x2, y2).

    NOTE(review): when called on (rapidity, phi) pairs, the phi difference is not
    wrapped at 2*pi.
    """
    dx = x1 - x2
    dy = y1 - y2
    return (dx**2 + dy**2) ** 0.5
def cos_between_vectors(x1, y1, x2, y2):
    """Return the cosine of the angle between the 2-D vectors (x1, y1) and (x2, y2).

    Raises ZeroDivisionError if either vector has zero length.
    """
    norm_product = math.sqrt(x1**2 + y1**2) * math.sqrt(x2**2 + y2**2)
    return (x1 * x2 + y1 * y2) / norm_product

# Per-event matching. For every event that has at least one jet:
#   - find the truth cluster and the predicted cluster closest to the event's first
#     jet (index jet_pos), using Euclidean distance in (rap, phi);
#   - convert both clusters to an (arc position `dis`, z) pair;
#   - print the distance between them, one value per line, to sys.stdout (which the
#     top of this cell redirects to a file).
# *_pos are running offsets of the current event's first cluster in the flattened
# per-cluster lists. At i = 0 they rely on clusters[-1] having been zeroed during
# parsing, so `pos += clusters[i - 1]` adds nothing.
# NOTE(review): because jet_clusters[-1] was zeroed during parsing, that last jet event
# hits the `continue` below and is never matched.
# NOTE(review): output_folder is not used in the visible code.
output_folder='output_figure'
for i in range(truth_clusters_size):
    # Advance each offset past the previous event's clusters.
    truth_pos=truth_pos+truth_clusters[i-1]
    jet_pos = jet_pos + jet_clusters[i - 1]
    predict_pos = predict_pos + predict_clusters[i - 1]
    # Events without jets have nothing to match.
    if jet_clusters[i]==0:
        continue
    # NOTE(review): truth_min_cos / predict_min_cos are not used below.
    truth_min_cos =-1
    predict_min_cos = -1
    # Initial "large" distance sentinel: the diagonal of a 56 x 56 image.
    truth_min_distance = math.sqrt(56*56+56*56)
    predict_min_distance = math.sqrt(56 * 56 + 56 * 56)
    # NOTE(review): if truth_clusters[i] == 0, truth_id stays 0 and the truth lookups
    # below read the first cluster of the NEXT event.
    truth_id=0
    # Nearest truth cluster to the event's first jet in (rap, phi).
    # NOTE(review): the phi difference is not wrapped at 2*pi.
    for j in range(truth_clusters[i]):
       if truth_min_distance>distance(x1=truth_raps[j+ truth_pos],y1=truth_phis[j+ truth_pos],x2=jet_raps[jet_pos],y2=jet_phis[jet_pos]):
           truth_id=j
           truth_min_distance = distance(x1=truth_raps[j + truth_pos], y1=truth_phis[j + truth_pos], x2=jet_raps[jet_pos],y2=jet_phis[jet_pos])
    # Nearest predicted cluster to the same jet.
    # NOTE(review): predict_id is never reset per event. If predict_clusters[i] == 0 it
    # is undefined (NameError) on the first such event, or stale from an earlier event
    # (so it indexes an unrelated cluster).
    for j in range(predict_clusters[i]):
       if predict_min_distance > distance(x1=predict_raps[j + predict_pos], y1=predict_phis[j + predict_pos],x2=jet_raps[jet_pos], y2=jet_phis[jet_pos]):
           predict_id = j
           predict_min_distance = distance(x1=predict_raps[j + predict_pos], y1=predict_phis[j + predict_pos],x2=jet_raps[jet_pos], y2=jet_phis[jet_pos])
    # Jet position. x takes jet_y (the first value from calculate_spatial_coordinates,
    # p_x/E) and y takes jet_x (p_y/E), because of the swap done at parse time.
    # NOTE(review): this jet `dis` / `z` is overwritten below before it is ever used.
    x=jet_y[jet_pos]
    y=jet_x[jet_pos]
    z=jet_z[jet_pos]
    # Radius of a cylinder with a 56-pixel circumference.
    r=56/2/math.pi
    y=-y;
    # NOTE(review): x / fabs(y) raises ZeroDivisionError when y == 0
    # (math.atan2(x, math.fabs(y)) gives the same angle otherwise). Using |y| also maps
    # (x, y) and (x, -y) to the same angle.
    theta=math.atan(x/math.fabs(y))
    dis=(math.pi-theta)*r
    # Truth-cluster position; the result is kept as (dis1, z1), with x1 / y1 saved too.
    x=truth_y[truth_id+truth_pos]
    y=truth_x[truth_id+truth_pos]
    z=truth_z[truth_id+truth_pos]
    r = 56 / 2 / math.pi
    y = -y;
    theta = math.atan(x / math.fabs(y))
    dis = (math.pi - theta) * r
    dis1=dis
    z1=z
    x1=x
    y1=y
    # Predicted-cluster position (dis, z), computed the same way.
    x=predict_y[predict_id+predict_pos]
    y=predict_x[predict_id+predict_pos]
    z=predict_z[predict_id+predict_pos]
    r = 56 / 2 / math.pi
    y = -y;
    theta = math.atan(x / math.fabs(y))
    dis = (math.pi - theta) * r
    # print(f'x:{x*r/math.sqrt(x*x+y*y)}  y:{y*r/math.sqrt(x*x+y*y)} width:{z+28}  r:{x*x+y*y}');
    #print(f'width:{dis}  height:{28 + z}');
    # Distance between the truth and predicted positions in the (dis, z) plane.
    print(f'{math.sqrt((dis1-dis)*(dis1-dis)+(z1-z)*(z1-z))}')
   # print(f'{math.sqrt((x1 - x) * (x1 -x) + (y1-y)*(y1-y)+(z1 - z) * (z1 - z))}')
   # print(f'{(predict_E[predict_id+predict_pos]-truth_E[truth_id+truth_pos])/truth_E[truth_id+truth_pos]}')
    #print(f'{truth_pts[truth_id+truth_pos]}   {predict_pts[predict_id+predict_pos]}')