In [1]:
from scipy.io import loadmat
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import torch
from torch import nn
from mamba_ssm import Mamba
import math
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import cohen_kappa_score
import argparse
import os
import random
import time
from torchvision import models,transforms
import utils.data_load_operate as data_load_operate
from utils.Loss import head_loss,resize
from utils.evaluation import Evaluator
from utils.HSICommonUtils import normlize3D, ImageStretching
from utils.setup_logger import setup_logger
from utils.visual_predict import visualize_predict
from PIL import Image
from calflops import calculate_flops
import cv2
from typing import Tuple,Union
from GPUtil import showUtilization as gpu_usage
from torch.cuda.amp import GradScaler, autocast

In [2]:
class SpeMamba(nn.Module):
    """Spectral Mamba block: runs a Mamba layer along the channel axis of each pixel.

    The channel vector of every pixel is zero-padded up to
    ``token_num * ceil(channels / token_num)`` entries, cut into ``token_num``
    tokens of ``group_channel_num`` features and passed through ``Mamba`` as a
    sequence of length ``token_num``. The result is reshaped back to
    ``(B, C_pad, H, W)`` and sent through GroupNorm + SiLU.

    Args:
        channels: number of input channels ``C``.
        token_num: number of tokens the channel axis is split into.
        use_residual: if True the (un-padded) input is added to the output.
        group_num: number of groups of the GroupNorm; the padded channel
            count must be divisible by it (GroupNorm requirement).

    Returns (forward):
        ``(B, C, H, W)`` when ``use_residual`` is True, otherwise the padded
        ``(B, C_pad, H, W)`` tensor (unchanged legacy behaviour).
    """

    def __init__(self,channels, token_num=8, use_residual=True, group_num=4):
        super(SpeMamba, self).__init__()
        self.token_num = token_num
        self.use_residual = use_residual

        # Features per token, rounded up so that every channel fits in a token.
        self.group_channel_num = math.ceil(channels/token_num)
        # Padded channel count actually processed by the block.
        self.channel_num = self.token_num * self.group_channel_num

        self.mamba = Mamba( # This module uses roughly 3 * expand * d_model^2 parameters
                            d_model=self.group_channel_num,  # Model dimension d_model
                            d_state=16,  # SSM state expansion factor
                            d_conv=4,  # Local convolution width
                            expand=2,  # Block expansion factor
                            )

        self.proj = nn.Sequential(
            nn.GroupNorm(group_num, self.channel_num),
            nn.SiLU()
        )

    def padding_feature(self,x):
        """Zero-pad the channel axis of ``x`` (B, C, H, W) up to ``self.channel_num``."""
        B, C, H, W = x.shape
        if C < self.channel_num:
            # new_zeros inherits dtype and device from x, so half/double inputs
            # are not silently mixed with float32 padding.
            pad_features = x.new_zeros((B, self.channel_num - C, H, W))
            return torch.cat([x, pad_features], dim=1)
        return x

    def forward(self,x):
        in_channels = x.shape[1]
        x_pad = self.padding_feature(x)
        x_pad = x_pad.permute(0, 2, 3, 1).contiguous()
        B, H, W, C_pad = x_pad.shape
        # One sequence per pixel: (B*H*W, token_num, group_channel_num)
        x_flat = x_pad.view(B * H * W, self.token_num, self.group_channel_num)
        x_flat = self.mamba(x_flat)
        x_recon = x_flat.view(B, H, W, C_pad)
        x_recon = x_recon.permute(0, 3, 1, 2).contiguous()
        x_proj = self.proj(x_recon)
        if self.use_residual:
            # Drop the padded channels before the skip connection; without the
            # slice the addition fails whenever channels % token_num != 0.
            return x + x_proj[:, :in_channels]
        return x_proj


class SpaMamba(nn.Module):
    """Spatial Mamba block: runs a Mamba layer over the flattened pixel sequence.

    The input ``(B, C, H, W)`` is rearranged to channels-last and flattened
    into a single sequence of shape ``(1, B*H*W, C)`` (all batch items share
    one sequence), processed by ``Mamba``, reshaped back to ``(B, C, H, W)``
    and optionally passed through GroupNorm + SiLU and a skip connection.

    Args:
        channels: number of input channels (used as Mamba ``d_model``).
        use_residual: add the input to the output when True.
        group_num: number of GroupNorm groups of the projection.
        use_proj: apply the GroupNorm + SiLU projection when True.
    """

    def __init__(self,channels,use_residual=True,group_num=4,use_proj=True):
        super().__init__()
        self.use_residual = use_residual
        self.use_proj = use_proj
        # This module uses roughly 3 * expand * d_model^2 parameters
        self.mamba = Mamba(
            d_model=channels,  # Model dimension d_model
            d_state=16,        # SSM state expansion factor
            d_conv=4,          # Local convolution width
            expand=2,          # Block expansion factor
        )
        if use_proj:
            self.proj = nn.Sequential(nn.GroupNorm(group_num, channels), nn.SiLU())

    def forward(self,x):
        batch, chans, rows, cols = x.shape
        # channels-last, then one long token sequence for the whole batch
        tokens = x.permute(0, 2, 3, 1).contiguous().view(1, -1, chans)
        mixed = self.mamba(tokens)

        out = mixed.view(batch, rows, cols, chans).permute(0, 3, 1, 2).contiguous()
        if self.use_proj:
            out = self.proj(out)
        return out + x if self.use_residual else out


class BothMamba(nn.Module):
    """Fuses a spatial (``SpaMamba``) and a spectral (``SpeMamba``) branch.

    Both branches receive the same input. With ``use_att`` the two outputs are
    combined with a learnable two-element weight vector normalised by softmax;
    otherwise they are summed. With ``use_residual`` the input is added to the
    fused result (the same flag is also forwarded to both branches).
    """

    def __init__(self,channels,token_num,use_residual,group_num=4,use_att=True):
        super().__init__()
        self.use_att = use_att
        self.use_residual = use_residual
        if use_att:
            # Start from equal weights; softmax keeps them positive and summing to 1.
            self.weights = nn.Parameter(torch.ones(2) / 2)
            self.softmax = nn.Softmax(dim=0)

        self.spa_mamba = SpaMamba(channels,use_residual=use_residual,group_num=group_num)
        self.spe_mamba = SpeMamba(channels,token_num=token_num,use_residual=use_residual,group_num=group_num)

    def forward(self,x):
        spatial_out = self.spa_mamba(x)
        spectral_out = self.spe_mamba(x)
        if not self.use_att:
            fused = spatial_out + spectral_out
        else:
            branch_w = self.softmax(self.weights)
            fused = spatial_out * branch_w[0] + spectral_out * branch_w[1]
        return fused + x if self.use_residual else fused


In [3]:
# Save the visualisation images for one prediction
def vis_a_image(gt_vis,pred_vis,save_single_predict_path,save_single_gt_path,only_vis_label=False):
    """Render a prediction twice through ``visualize_predict``.

    The first call writes to ``save_single_predict_path`` with the caller's
    ``only_vis_label`` flag. The second call writes to the same path with
    ``.png`` replaced by ``_mask.png`` and ``only_vis_label=True``.
    """
    visualize_predict(gt_vis,pred_vis,save_single_predict_path,save_single_gt_path,only_vis_label=only_vis_label)
    masked_predict_path = save_single_predict_path.replace('.png','_mask.png')
    visualize_predict(gt_vis,pred_vis,masked_predict_path,save_single_gt_path,only_vis_label=True)

# Seed every random number generator so that parameter initialisation is reproducible
def setup_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    # Python built-in RNG and the hash-seed environment variable
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # NumPy global RNG
    np.random.seed(seed)
    # PyTorch: CPU generator, then every CUDA device
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels, no auto-tuned algorithm selection
    # (trades speed for repeatability)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Returns the parsed argument namespace so the main program can read
# values such as args.dataset_index.
def get_parser():
    """Parse the experiment options from ``sys.argv`` and return the namespace.

    Unknown arguments are ignored (``parse_known_args``), abbreviations and
    the automatic ``-h`` option are disabled, so extra arguments present on
    the command line (for example those of a notebook kernel) do not make
    parsing fail.

    Returns:
        argparse.Namespace with ``dataset_index``, ``data_set_path``,
        ``work_dir``, ``lr``, ``max_epoch``, ``train_samples``,
        ``val_samples``, ``exp_name`` and ``record_computecost``.
    """

    def _str2bool(value):
        # argparse's type=bool calls bool() on the raw string, so "False"
        # would parse as True; convert the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(
        allow_abbrev=False,
        add_help=False
    )
    parser.add_argument('--dataset_index', type=int, default=2)      # index into data_set_name_list
    parser.add_argument('--data_set_path', type=str, default='./data') # dataset root folder
    parser.add_argument('--work_dir', type=str, default='./')        # working directory
    parser.add_argument('--lr', type=float, default=0.0003)         # learning rate
    parser.add_argument('--max_epoch', type=int, default=200)       # maximum number of training epochs
    parser.add_argument('--train_samples', type=int, default=30)    # training samples per class
    parser.add_argument('--val_samples', type=int, default=10)      # validation samples per class
    parser.add_argument('--exp_name', type=str, default='RESULT_FP16_Conv')     # experiment name
    parser.add_argument('--record_computecost', type=_str2bool, default=True) # whether to record the compute cost
    # Only the known options are parsed; everything else is discarded.
    args, _ = parser.parse_known_args()
    return args


# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args = get_parser()

# Unpack the parsed options into the module-level names used by later cells.
work_dir = args.work_dir
exp_name = args.exp_name
dataset_index = args.dataset_index
learning_rate = args.lr
max_epoch = args.max_epoch
record_computecost = args.record_computecost

# One experiment is run per seed; use [0,1,2,3,4,5,6,7,8,9] for ten runs.
seed_list = [0]

# [training samples per class, validation samples per class]
num_list = [args.train_samples, args.val_samples]
net_name = 'MambaHSI'

paras_dict = dict(net_name=net_name, dataset_index=dataset_index, num_list=num_list,
                  lr=learning_rate, seed_list=seed_list)

# dataset_index -> dataset name (0: UP, 1: HanChuan, 2: HongHu, 3: Houston)
data_set_name_list = ['UP', 'HanChuan', 'HongHu', 'Houston']
data_set_name = data_set_name_list[dataset_index]

# HanChuan and Houston are trained on two image halves instead of the whole image.
split_image = data_set_name in ['HanChuan','Houston']

# Only ToTensor is applied: it converts a PIL image / NumPy array to a tensor.
# NOTE(review): per the torchvision docs the [0, 1] rescaling only happens for
# uint8-style inputs - confirm against what ImageStretching returns.
# No Resize or Normalize step is used.
transform = transforms.Compose([transforms.ToTensor()])

In [4]:
class HalfPrecisionConv(nn.Conv2d):
    """Conv2d whose weight and bias are stored as float16 parameters.

    The constructor accepts every ``nn.Conv2d`` keyword (padding, stride, ...).
    ``forward`` casts the input to float16, runs the convolution and returns
    the result cast back to float32.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)

        # The parent constructor must run first; afterwards each existing
        # parameter is replaced by a float16 copy (bias may be None).
        for param_name in ('weight', 'bias'):
            current = getattr(self, param_name)
            if current is not None:
                setattr(self, param_name, nn.Parameter(current.half()))

    def forward(self, x):
        half_out = super().forward(x.half())
        # Hand float32 back to the following layers.
        return half_out.float()
            
class SpectralDecomp(nn.Module):
    """Low/high frequency decomposition for multi-channel (hyperspectral) input.

    Every channel is blurred independently with the same
    ``(2*spatial_radius+1)``-wide kernel, initialised to a normalised Gaussian
    with ``sigma = spatial_radius / 3``. The blurred image is the low-frequency
    part; the high-frequency part is ``(x - low) * texture_contrast``.

    NOTE(review): the kernel is an ordinary ``nn.Conv2d`` weight, so it has
    ``requires_grad=True`` and is updated by any optimizer that receives it -
    confirm that a learnable blur is intended.

    Args:
        spatial_radius: kernel radius in pixels. Reflection padding requires
            the input height/width to be larger than this value.
        texture_contrast: scale factor applied to the high-frequency part.
    """

    def __init__(self, spatial_radius=7, texture_contrast=0.8):
        super().__init__()
        self.spatial_radius = spatial_radius
        self.texture_contrast = texture_contrast

        # Single-channel blur; channels are folded into the batch axis in forward().
        self.gaussian_blur = nn.Sequential(
            nn.ReflectionPad2d(spatial_radius),
            nn.Conv2d(1, 1, kernel_size=2*spatial_radius+1, bias=False)
        )
        self._init_gaussian_weights()

    def _init_gaussian_weights(self):
        """Overwrite the conv weight with a normalised 2-D Gaussian kernel."""
        # Guard radius 0 (1x1 kernel): sigma would be 0 and the kernel NaN.
        sigma = max(self.spatial_radius, 1) / 3
        offsets = torch.arange(-self.spatial_radius, self.spatial_radius + 1, dtype=torch.float32)
        kernel_1d = torch.exp(-offsets.pow(2) / (2 * sigma**2))
        # Outer product of the 1-D profile -> (1, 1, K, K)
        kernel_2d = kernel_1d.view(1, 1, 1, -1) * kernel_1d.view(1, 1, -1, 1)
        # In-place copy keeps the registered Parameter (and its dtype/device)
        # instead of swapping its .data tensor.
        with torch.no_grad():
            self.gaussian_blur[1].weight.copy_(kernel_2d / kernel_2d.sum())

    def forward(self, x):
        """
        Input:  tensor of shape (B, C, H, W)
        Output: tuple (low_freq, high_freq), both of shape (B, C, H, W)
        """
        B, C, H, W = x.shape

        # Fold channels into the batch axis. reshape (not view) so that
        # non-contiguous inputs, e.g. permuted tensors, are accepted.
        x_flat = x.reshape(B*C, 1, H, W)

        # Low-frequency component
        low = self.gaussian_blur(x_flat).reshape(B, C, H, W)

        # High-frequency component
        high = (x - low) * self.texture_contrast

        return low, high
        
class Mamba_float16(nn.Module):
    """Three-stage high/low-frequency Mamba network with float16 convolutions on the low branch.

    Each stage splits its input with ``SpectralDecomp``. The high-frequency
    part is reduced to 32 channels and goes through a ``BothMamba`` block; the
    low-frequency part goes through ``HalfPrecisionConv`` layers. The two
    branches are concatenated (``hidden_dim + 32`` channels) and fed to the
    next stage. The first two stages each halve the spatial size with an
    average pooling, so the logits have 1/4 of the input height and width.
    ``high_mamba``, ``low_channel_128_128``, ``high_channel_160_32`` and
    ``low_channel_160_128`` are reused across stages (shared weights).

    The attribute names keep their historical ``128``/``160``/``32`` suffixes
    (they are state-dict keys); the actual widths follow ``hidden_dim``.

    Args:
        in_channels: number of input (spectral) channels.
        hidden_dim: width of the embedding and of the low-frequency branch;
            must be divisible by ``group_num``.
        num_classes: number of output classes.
        use_residual: forwarded to the Mamba blocks.
        mamba_type: only ``'both'`` is supported by ``forward``. ``'spa'`` and
            ``'spe'`` still build ``self.mamba`` as before, but it is not used
            by ``forward``.
        token_num: number of spectral tokens for ``SpeMamba``.
        group_num: number of GroupNorm groups.
        use_att: forwarded to ``BothMamba``.
    """

    def __init__(self,in_channels=128,hidden_dim=64,num_classes=10,use_residual=True,mamba_type='both',token_num=4,group_num=4,use_att=True):
        super(Mamba_float16, self).__init__()
        self.mamba_type = mamba_type

        # Width of the high-frequency branch and of the concatenated features.
        # With hidden_dim=128 these give the original 32 / 160 channel layout.
        high_dim = 32
        fused_dim = hidden_dim + high_dim

        self.patch_embedding = nn.Sequential(nn.Conv2d(in_channels=in_channels,out_channels=hidden_dim,kernel_size=1,stride=1,padding=0),
                                             nn.GroupNorm(group_num,hidden_dim),
                                             nn.SiLU())

        # hidden_dim -> high_dim (first stage, high-frequency branch)
        self.high_channel_128_32 = nn.Sequential(nn.Conv2d(in_channels=hidden_dim,out_channels=high_dim,kernel_size=1,stride=1,padding=0),
                                             nn.GroupNorm(group_num,high_dim),
                                             nn.SiLU())

        # fused_dim -> high_dim (later stages, high-frequency branch)
        self.high_channel_160_32 = nn.Sequential(nn.Conv2d(in_channels=fused_dim,out_channels=high_dim,kernel_size=1,stride=1,padding=0),
                                             nn.GroupNorm(group_num,high_dim),
                                             nn.SiLU())

        # fused_dim -> hidden_dim (later stages, low-frequency branch, float16 conv)
        self.low_channel_160_128 = nn.Sequential(HalfPrecisionConv(in_channels=fused_dim,out_channels=hidden_dim,kernel_size=1,stride=1,padding=0),
                                             nn.GroupNorm(group_num,hidden_dim),
                                             nn.SiLU())

        # hidden_dim -> hidden_dim with 2x average pooling (low-frequency branch)
        self.low_channel_128_128 = nn.Sequential(HalfPrecisionConv(in_channels=hidden_dim,out_channels=hidden_dim,kernel_size=1,stride=1,padding=0),
                                             nn.GroupNorm(group_num,hidden_dim),
                                             nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                                             nn.SiLU())

        self.decomp = SpectralDecomp(spatial_radius=7)

        if mamba_type == 'spa':
            self.mamba = nn.Sequential(SpaMamba(hidden_dim,use_residual=use_residual,group_num=group_num),
                                        nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                                        SpaMamba(hidden_dim,use_residual=use_residual,group_num=group_num),
                                        nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                                        SpaMamba(hidden_dim,use_residual=use_residual,group_num=group_num),
                                        )
        elif mamba_type == 'spe':
            self.mamba = nn.Sequential(SpeMamba(hidden_dim,token_num=token_num,use_residual=use_residual,group_num=group_num),
                                        nn.AvgPool2d(kernel_size=2, stride=2, padding=0),

                                        SpeMamba(hidden_dim,token_num=token_num,use_residual=use_residual,group_num=group_num),
                                        nn.AvgPool2d(kernel_size=2, stride=2, padding=0),

                                        SpeMamba(hidden_dim,token_num=token_num,use_residual=use_residual,group_num=group_num)
                                        )

        elif mamba_type=='both':
            self.high_mamba = nn.Sequential(BothMamba(channels=high_dim,token_num=token_num,use_residual=use_residual,group_num=group_num,use_att=use_att),
                                       nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
                                       )

            self.high_mamba_finally = BothMamba(channels=high_dim,token_num=token_num,use_residual=use_residual,group_num=group_num,use_att=use_att)
            self.low_mamba_finally = BothMamba(channels=hidden_dim,token_num=token_num,use_residual=use_residual,group_num=group_num,use_att=use_att)

        self.cls_head = nn.Sequential(nn.Conv2d(in_channels=fused_dim, out_channels=64, kernel_size=1, stride=1, padding=0),
                                      nn.GroupNorm(group_num,64),
                                      nn.SiLU(),
                                      nn.Conv2d(in_channels=64,out_channels=num_classes,kernel_size=1,stride=1,padding=0))

    def forward(self,x):
        """Map ``(B, in_channels, H, W)`` to class logits ``(B, num_classes, H//4, W//4)``."""
        if self.mamba_type != 'both':
            # The high/low pipeline below only exists for 'both'; fail with a
            # clear message instead of an AttributeError on self.high_mamba.
            raise NotImplementedError(
                "Mamba_float16.forward only supports mamba_type='both', got {!r}".format(self.mamba_type))

        # Pixel-wise embedding to hidden_dim channels
        x = self.patch_embedding(x)

        # Stage 1: frequency split, both branches are pooled by 2
        low,high = self.decomp(x)
        high = self.high_channel_128_32(high)
        high_mamba = self.high_mamba(high)
        low_mamba = self.low_channel_128_128(low)
        combined = torch.cat([low_mamba, high_mamba], dim=1)

        # Stage 2: frequency split of the fused features, pooled by 2 again
        low,high = self.decomp(combined)
        high = self.high_channel_160_32(high)
        low = self.low_channel_160_128(low)
        high_mamba = self.high_mamba(high)
        low_mamba = self.low_channel_128_128(low)

        combined = torch.cat([low_mamba, high_mamba], dim=1)

        # Stage 3: final Mamba blocks on both branches, no pooling
        low,high = self.decomp(combined)
        high = self.high_channel_160_32(high)
        low = self.low_channel_160_128(low)
        high_mamba = self.high_mamba_finally(high)
        low_mamba = self.low_mamba_finally(low)

        combined = torch.cat([low_mamba, high_mamba], dim=1)
        logits = self.cls_head(combined)
        return logits

In [5]:
# Load the selected dataset and build the network input tensor.
data_set_path = args.data_set_path
data, gt = data_load_operate.load_data(data_set_name, data_set_path)
img = ImageStretching(data)  # per-channel stretch to 0-255 (per the original author's note)
x = transform(np.array(img))  # to tensor, channels first
original = x.permute(1, 2, 0)  # channels-last copy of the image tensor
x = x.unsqueeze(0).float()  # add the batch axis

# Sample train / validation / test pixels from the flattened ground truth.
gt_reshape = gt.reshape(-1)
height, width, channels = data.shape
ratio_list = [0.1, 0.01]
class_count = max(np.unique(gt))  # largest label value in the ground truth
flag_list = [1, 0]
train_data_index, val_data_index, test_data_index, all_data_index = data_load_operate.sampling(ratio_list,num_list,gt_reshape,class_count,flag_list[0])
index = (train_data_index, val_data_index, test_data_index)
train_label, val_label, test_label = data_load_operate.generate_image_iter(data, height, width, gt_reshape, index)
train_label = train_label.to(device)
test_label = test_label.to(device)
val_label = val_label.to(device)
# The class count follows the loaded dataset instead of a fixed 9, so the
# classifier matches the labels sampled above for every dataset_index.
model = Mamba_float16(in_channels=x.shape[1],num_classes=class_count,hidden_dim=128)

In [5]:
class HybridOptimizer:
    def __init__(self, model, base_lr=0.0003):
        # 分离参数组
        self.fp16_params = []
        self.fp32_params = []
        
        # 自动识别参数类型
        for name, param in model.named_parameters():
            if param.dtype == torch.float16:
                self.fp16_params.append(param)
                param.requires_grad_(True)  # 确保梯度计算
            else:
                self.fp32_params.append(param)
        
        # 仅初始化FP32优化器
        self.optimizer = torch.optim.Adam(
            self.fp32_params, 
            lr=base_lr,
        )
        
        # FP16手动优化参数
        self.fp16_lr = base_lr * 10  # 更高学习率
        self.grad_clip = 1.0

    def step(self):
        """执行参数更新"""
        # 先更新FP32参数
        self.optimizer.step()
        
        # 手动更新FP16参数
        with torch.no_grad():
            for param in self.fp16_params:
                if param.grad is not None:
                    # 梯度裁剪
                    grad = torch.clamp(param.grad, -self.grad_clip, self.grad_clip)
                    # 参数更新 (SGD示例)
                    param.data.sub_(self.fp16_lr * grad)
                    
                    # 可选：Adam风格更新
                    # self._adam_update(param, grad)
    
    def zero_grad(self):
        """统一清空梯度"""
        self.optimizer.zero_grad()
        for param in self.fp16_params:
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()

In [6]:
# Results are written to <work_dir>/<exp_name>/<net_name>/<data_set_name>.
save_folder = os.path.join(work_dir, exp_name, net_name, data_set_name)
if not os.path.exists(save_folder):
    os.makedirs(save_folder)
    print(f"makedirs {save_folder}")

# The logger is named after the dataset and writes to a file inside save_folder.
save_log_path = os.path.join(save_folder, f'train_tr{num_list[0]}_val{num_list[1]}.log')
logger = setup_logger(name=f'{data_set_name}', logfile=save_log_path)
torch.cuda.empty_cache()  # release cached, currently unused GPU memory
logger.info(save_folder)

# Load the dataset and derive the quantities shared by every run.
data_set_path = args.data_set_path
data, gt = data_load_operate.load_data(data_set_name, data_set_path)
gt_reshape = gt.reshape(-1)
height, width, channels = data.shape
img = ImageStretching(data)  # per-channel stretch to 0-255 (per the original author's note)
class_count = max(np.unique(gt))
ratio_list, flag_list = [0.1, 0.01], [1, 0]

# Per-run metric accumulators; one entry is appended per seed.
OA_ALL, AA_ALL, KPP_ALL = [], [], []
EACH_ACC_ALL = []
Train_Time_ALL, Test_Time_ALL = [], []
CLASS_ACC = np.zeros([len(seed_list), class_count])
evaluator = Evaluator(num_class=class_count)
# Pixels labelled -1 are excluded from the loss.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=-1)

def _train_step(model, optimizer, loss_func, inputs, targets):
    """Run one forward/backward/update pass and return the detached loss tensor."""
    prediction = model(inputs)
    loss = head_loss(loss_func, prediction, targets.long())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()


def _train_on_halves(model, optimizer, loss_func, inputs, targets, overlap=5):
    """Train on two overlapping halves, split along dim 2 of ``inputs`` (dim 1 of ``targets``).

    Each half extends ``overlap`` rows past the middle. Returns the sum of
    the two (detached) losses.
    """
    middle = inputs.shape[2] // 2
    loss_first = _train_step(model, optimizer, loss_func,
                             inputs[:, :, :middle + overlap, :], targets[:, :middle + overlap, :])
    torch.cuda.empty_cache()
    loss_second = _train_step(model, optimizer, loss_func,
                              inputs[:, :, middle - overlap:, :], targets[:, middle - overlap:, :])
    torch.cuda.empty_cache()
    return loss_first + loss_second


# One complete train / validate / test experiment per seed.
for exp_idx,curr_seed in enumerate(seed_list):
    setup_seed(curr_seed)
    single_experiment_name = 'run{}_seed{}'.format(str(exp_idx), str(curr_seed))
    save_single_experiment_folder = os.path.join(save_folder, single_experiment_name)
    os.makedirs(save_single_experiment_folder, exist_ok=True)
    save_vis_folder = os.path.join(save_single_experiment_folder, 'vis')
    if not os.path.exists(save_vis_folder):
        os.makedirs(save_vis_folder)
        print("makedirs {}".format(save_vis_folder))
    save_weight_path = os.path.join(save_single_experiment_folder, "best_tr{}_val{}.pth".format(num_list[0], num_list[1]))
    results_save_path = os.path.join(save_single_experiment_folder, 'result_tr{}_val{}.txt'.format(num_list[0], num_list[1]))
    predict_save_path = os.path.join(save_single_experiment_folder, 'pred_vis_tr{}_val{}.png'.format(num_list[0], num_list[1]))
    gt_save_path = os.path.join(save_single_experiment_folder, 'gt_vis_tr{}_val{}.png'.format(num_list[0], num_list[1]))

    train_data_index, val_data_index, test_data_index, all_data_index = data_load_operate.sampling(ratio_list,num_list,gt_reshape,class_count,flag_list[0])
    index = (train_data_index, val_data_index, test_data_index)
    train_label, val_label, test_label = data_load_operate.generate_image_iter(data, height, width, gt_reshape, index)

    # Build the model
    model = Mamba_float16(in_channels=channels,num_classes=class_count,hidden_dim=128).to(device)

    logger.info(paras_dict)
    logger.info(model)

    x = transform(np.array(img))  # to tensor
    x = x.unsqueeze(0).float().to(device)

    train_label = train_label.to(device)
    test_label = test_label.to(device)
    val_label = val_label.to(device)

    # Use the configured learning rate; previously the optimizer silently
    # fell back to its own default and --lr had no effect on training.
    optimizer = HybridOptimizer(model, base_lr=learning_rate)

    logger.info(optimizer)

    if record_computecost:  # report parameter count and FLOPs of the model
        model.eval()
        flops, macs1, para = calculate_flops(model=model,
                                             input_shape=(1, x.shape[1], x.shape[2], x.shape[3]), )
        logger.info("para:{}\n,flops:{}".format(para, flops))

    tic1 = time.perf_counter()  # start of the training phase
    best_val_acc = 0

    # The whole image is the input; large images are trained as two halves.
    for epoch in range(max_epoch):
        y_train = train_label.unsqueeze(0)

        model.train()

        if split_image:
            loss_value = _train_on_halves(model, optimizer, loss_func, x, y_train)
        else:
            full_image_failed = False
            try:
                loss_value = _train_step(model, optimizer, loss_func, x, y_train)
            except RuntimeError as err:
                # Typically a CUDA out-of-memory error. Only RuntimeError is
                # caught so that KeyboardInterrupt and programming errors are
                # no longer swallowed.
                logger.info('Full-image training step failed ({}); switching to split-image training.'.format(err))
                full_image_failed = True
            if full_image_failed:
                # Run the fallback outside the except block so the traceback no
                # longer keeps the failed step's tensors alive.
                optimizer.zero_grad()
                torch.cuda.empty_cache()
                split_image = True
                loss_value = _train_on_halves(model, optimizer, loss_func, x, y_train)
        logger.info('Iter:{}|loss:{}'.format(epoch, loss_value.cpu().numpy()))

        torch.cuda.empty_cache()
        model.eval()
        with torch.no_grad():
            evaluator.reset()
            output_val = model(x)
            y_val = val_label.unsqueeze(0)
            seg_logits = resize(input=output_val,
                                size=y_val.shape[1:],
                                mode='bilinear',
                                align_corners=True)
            # argmax over the class axis (dim 1) gives the predicted label map
            predict = torch.argmax(seg_logits,dim=1).cpu().numpy()
            Y_val_np = val_label.cpu().numpy()
            Y_val_255 = np.where(Y_val_np==-1,255,Y_val_np)  # replace the -1 labels by 255
            evaluator.add_batch(np.expand_dims(Y_val_255,axis=0),predict)
            OA = evaluator.Pixel_Accuracy()
            mIOU, IOU = evaluator.Mean_Intersection_over_Union()
            mAcc, Acc = evaluator.Pixel_Accuracy_Class()
            Kappa = evaluator.Kappa()

            logger.info('Evaluate {}|OA:{}|MACC:{}|Kappa:{}|MIOU:{}|IOU:{}|ACC:{}'.format(epoch, OA,mAcc,Kappa,mIOU,IOU,Acc))
            # Keep the weights of the best validation OA seen so far
            if OA>=best_val_acc:
                best_epoch = epoch + 1
                best_val_acc = OA
                torch.save(model.state_dict(), save_weight_path)
            if (epoch+1)%50==0:
                save_single_predict_path = os.path.join(save_vis_folder,'predict_{}.png'.format(str(epoch+1)))
                save_single_gt_path = os.path.join(save_vis_folder,'gt.png')
                vis_a_image(gt,predict,save_single_predict_path, save_single_gt_path)
        torch.cuda.empty_cache()

    # Record the wall-clock time of the training phase (it includes the
    # per-epoch validation); previously nothing was ever appended here.
    Train_Time_ALL.append(time.perf_counter() - tic1)

    logger.info("\n\n====================Starting evaluation for testing set.========================\n")
    # Start timing the test phase
    start_time = time.time()
    load_weight_path = save_weight_path
    best_model = Mamba_float16(in_channels=channels, num_classes=class_count, hidden_dim=128)

    best_model.to(device)
    # map_location keeps the checkpoint loadable when the device differs
    # from the one it was saved on.
    best_model.load_state_dict(torch.load(load_weight_path, map_location=device))
    best_model.eval()
    test_evaluator = Evaluator(num_class=class_count)

    with torch.no_grad():
        test_evaluator.reset()
        output_test = best_model(x)

        y_test = test_label.unsqueeze(0)
        seg_logits_test = resize(input=output_test,
                            size=y_test.shape[1:],
                            mode='bilinear',
                            align_corners=True)
        predict_test = torch.argmax(seg_logits_test, dim=1).cpu().numpy()
        Y_test_np = test_label.cpu().numpy()
        Y_test_255 = np.where(Y_test_np == -1, 255, Y_test_np)
        test_evaluator.add_batch(np.expand_dims(Y_test_255, axis=0), predict_test)
        OA_test = test_evaluator.Pixel_Accuracy()
        mIOU_test, IOU_test = test_evaluator.Mean_Intersection_over_Union()
        mAcc_test, Acc_test = test_evaluator.Pixel_Accuracy_Class()
        # Kappa must come from the test evaluator, not the validation one.
        Kappa_test = test_evaluator.Kappa()
        logger.info('Test {}|OA:{}|MACC:{}|Kappa:{}|MIOU:{}|IOU:{}|ACC:{}'.format(epoch, OA_test, mAcc_test, Kappa_test, mIOU_test, IOU_test,
                                                                                Acc_test))
        vis_a_image(gt, predict_test, predict_save_path, gt_save_path)
    # Stop timing
    end_time = time.time()

    # Elapsed test time in seconds
    elapsed_time = end_time - start_time
    print(elapsed_time)
    # Report the actual run index and seed instead of a literal 0.
    str_results = '\n======================' \
                  + " exp_idx=" + str(exp_idx) \
                  + " seed=" + str(curr_seed) \
                  + " learning rate=" + str(learning_rate) \
                  + " epochs=" + str(max_epoch) \
                  + " train ratio=" + str(ratio_list[0]) \
                  + " val ratio=" + str(ratio_list[1]) \
                  + " ======================" \
                  + "\nOA=" + str(OA_test) \
                  + "\nAA=" + str(mAcc_test) \
                  + '\nkpp=' + str(Kappa_test) \
                  + '\nmIOU_test:' + str(mIOU_test) \
                  + "\nIOU_test:" + str(IOU_test) \
                  + "\nAcc_test:" + str(Acc_test) + "\n"
    logger.info(str_results)
    # The context manager closes the file even if the write fails.
    with open(results_save_path, 'a+') as f:
        f.write(str_results)

    OA_ALL.append(OA_test)
    AA_ALL.append(mAcc_test)
    KPP_ALL.append(Kappa_test)
    EACH_ACC_ALL.append(Acc_test)
    Test_Time_ALL.append(elapsed_time)

    torch.cuda.empty_cache()
    
# Summary over all runs (one entry per seed in seed_list)
OA_ALL = np.array(OA_ALL)
AA_ALL = np.array(AA_ALL)
KPP_ALL = np.array(KPP_ALL)
EACH_ACC_ALL = np.array(EACH_ACC_ALL)
Train_Time_ALL = np.array(Train_Time_ALL)
Test_Time_ALL = np.array(Test_Time_ALL)

np.set_printoptions(precision=4)
# Each message is formatted into ONE string: logging treats extra positional
# arguments as %-format arguments, so print-style calls such as
# logger.info('OA=', value) fail to format and the values never get logged.
logger.info("\n====================Mean result of {} times runs =========================".format(len(seed_list)))
logger.info('List of OA: {}'.format(list(OA_ALL)))
logger.info('List of AA: {}'.format(list(AA_ALL)))
logger.info('List of KPP: {}'.format(list(KPP_ALL)))
logger.info('OA= {} +- {}'.format(round(np.mean(OA_ALL) * 100, 2), round(np.std(OA_ALL) * 100, 2)))
logger.info('AA= {} +- {}'.format(round(np.mean(AA_ALL) * 100, 2), round(np.std(AA_ALL) * 100, 2)))
logger.info('Kpp= {} +- {}'.format(round(np.mean(KPP_ALL) * 100, 2), round(np.std(KPP_ALL) * 100, 2)))
logger.info('Acc per class= {} +- {}'.format(np.round(np.mean(EACH_ACC_ALL, 0) * 100, decimals=2),
                                             np.round(np.std(EACH_ACC_ALL, 0) * 100, decimals=2)))

# Training time is reported in seconds, testing time in milliseconds.
logger.info("Average training time= {} +- {}".format(round(np.mean(Train_Time_ALL), 2), round(np.std(Train_Time_ALL), 3)))
logger.info("Average testing time= {} +- {}".format(round(np.mean(Test_Time_ALL) * 1000, 2), round(np.std(Test_Time_ALL) * 1000, 3)))

# Write the same summary to <save_folder>/mean_result.txt
mean_result_path = os.path.join(save_folder,'mean_result.txt')
str_results = '\n\n***************Mean result of ' + str(len(seed_list)) + 'times runs ********************' \
              + '\nList of OA:' + str(list(OA_ALL)) \
              + '\nList of AA:' + str(list(AA_ALL)) \
              + '\nList of KPP:' + str(list(KPP_ALL)) \
              + '\nOA=' + str(round(np.mean(OA_ALL) * 100, 2)) + '+-' + str(round(np.std(OA_ALL) * 100, 2)) \
              + '\nAA=' + str(round(np.mean(AA_ALL) * 100, 2)) + '+-' + str(round(np.std(AA_ALL) * 100, 2)) \
              + '\nKpp=' + str(round(np.mean(KPP_ALL) * 100, 2)) + '+-' + str(
    round(np.std(KPP_ALL) * 100, 2)) \
              + '\nAcc per class=\n' + str(np.round(np.mean(EACH_ACC_ALL, 0) * 100, 2)) + '+-' + str(
    np.round(np.std(EACH_ACC_ALL, 0) * 100, 2)) \
              + "\nAverage training time=" + str(
    np.round(np.mean(Train_Time_ALL), decimals=2)) + '+-' + str(
    np.round(np.std(Train_Time_ALL), decimals=3)) \
              + "\nAverage testing time=" + str(
    np.round(np.mean(Test_Time_ALL) * 1000, decimals=2)) + '+-' + str(
    np.round(np.std(Test_Time_ALL) * 1000, decimals=3))  # std in ms, same unit as the mean (was * 100)
# The context manager closes the file even if the write fails.
with open(mean_result_path, 'w') as f:
    f.write(str_results)

del model

[I 250519 10:03:38 3913338442:13] ./RESULT_FP16_Conv/MambaHSI/HongHu


makedirs ./RESULT_FP16_Conv/MambaHSI/HongHu
makedirs ./RESULT_FP16_Conv/MambaHSI/HongHu/run0_seed0/vis


[I 250519 10:03:48 3913338442:56] {'net_name': 'MambaHSI', 'dataset_index': 2, 'num_list': [30, 10], 'lr': 0.0003, 'seed_list': [0]}
[I 250519 10:03:48 3913338442:57] Mamba_float16(
      (patch_embedding): Sequential(
        (0): Conv2d(270, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): GroupNorm(4, 128, eps=1e-05, affine=True)
        (2): SiLU()
      )
      (high_channel_128_32): Sequential(
        (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): GroupNorm(4, 32, eps=1e-05, affine=True)
        (2): SiLU()
      )
      (high_channel_160_32): Sequential(
        (0): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): GroupNorm(4, 32, eps=1e-05, affine=True)
        (2): SiLU()
      )
      (low_channel_160_128): Sequential(
        (0): HalfPrecisionConv(160, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): GroupNorm(4, 128, eps=1e-05, affine=True)
        (2): SiLU()
      )
      (low_channel_128_128): Sequential(
        (0): HalfP


------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  243.68 K
fwd MACs:                                                               53.82 GMACs
fwd FLOPs:                                                              109.12 GFLOPS
fwd+bwd MACs:                                                           161.47 GMACs
fwd+bwd FLOPs:                                                          327.36 GFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module 

[I 250519 10:03:58 3913338442:129] Iter:0|loss:3.205014944076538
[I 250519 10:03:58 3913338442:172] Evaluate 0|OA:0.05454545454545454|MACC:0.05454545454545454|Kappa:0.00952380952380952|MIOU:0.009983437615016562|IOU:[0.         0.         0.         0.         0.         0.
     0.         0.         0.         0.06578947 0.         0.
     0.         0.         0.         0.         0.         0.
     0.15384615 0.         0.         0.        ]|ACC:[0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.
     0.2 0.  0.  0. ]
[I 250519 10:04:00 3913338442:129] Iter:1|loss:3.096360921859741
[I 250519 10:04:00 3913338442:172] Evaluate 1|OA:0.10454545454545454|MACC:0.10454545454545454|Kappa:0.06190476190476189|MIOU:0.023516268501803342|IOU:[0.         0.         0.         0.         0.         0.
     0.         0.         0.         0.08474576 0.07692308 0.
     0.         0.         0.         0.13207547 0.         0.
     0.15384615 0.         0.06976744 0.        ]|AC

0.4053347110748291


In [None]:
# Print the weight and bias dtype of both half-precision convolutions.
# `best_model` is inspected because `model` is deleted at the end of the
# training cell; the original cell also printed the second layer's bias
# twice instead of its weight and bias.
for layer_name in ('low_channel_128_128', 'low_channel_160_128'):
    half_conv = getattr(best_model, layer_name)[0]
    print(half_conv.weight.dtype)
    print(half_conv.bias.dtype)

In [9]:
def measure_memory(model, input_shape):
    """Measure the additional peak CUDA memory of one forward pass.

    Args:
        model: Callable model. It is moved to ``'cuda'`` before measuring.
        input_shape: Despite the name, this value is passed directly to
            ``model(...)``, so it must be the actual model input rather than
            a shape tuple.

    Returns:
        ``(peak allocated - allocated before the pass) / 1024**3``, i.e. the
        extra memory of the measured forward pass expressed in GB
        (1024**3 bytes per unit).
    """
    # Initialisation: drop cached blocks and place the model on the GPU.
    torch.cuda.empty_cache()
    model = model.to('cuda')

    # Warm-up pass. Its result is released immediately so that it (and
    # anything it keeps alive) is not held while the measured pass runs.
    warmup_output = model(input_shape)
    del warmup_output

    # Memory baseline: reset the peak counter and record current usage.
    torch.cuda.reset_peak_memory_stats()
    start_mem = torch.cuda.memory_allocated()

    # Measured forward pass.
    output = model(input_shape)

    # Backward pass (enable if it should be part of the measurement).
    # loss = output.sum()
    # loss.backward()

    # Peak usage reached since the counter was reset.
    peak_mem = torch.cuda.max_memory_allocated()
    del output
    return (peak_mem - start_mem) / 1024**3  # bytes -> GB

In [10]:
# Rebuild the network, restore the saved weights, switch to eval mode and
# measure the peak memory of a forward pass on `x`.
# Optional A/B comparison against another model (disabled):
# model_a_mem = measure_memory(best_net,x)
best_model = Mamba_float16(in_channels=channels, num_classes=class_count, hidden_dim=128)
best_model.to(device)
# map_location remaps the stored tensors onto `device`, so the checkpoint
# loads even if it was saved from a different device.
best_model.load_state_dict(torch.load(load_weight_path, map_location=device))
best_model.eval()
model_b_mem = measure_memory(best_model,x)

# print(f"Model A 峰值内存: {model_a_mem:.3f} GB")
print(f"Model B 峰值内存: {model_b_mem:.3f} GB")
# print(f"内存差异: {(model_b_mem - model_a_mem)/model_a_mem*100:.1f}%")

Model B 峰值内存: 9.444 GB


In [6]:
# Load the data array and its ground-truth label map.
data_set_path = args.data_set_path
data, gt = data_load_operate.load_data(data_set_name, data_set_path)

# Flattened labels and the three dimensions of the data array.
gt_reshape = gt.reshape(-1)
height, width, channels = data.shape

# Per the original note: rescale each channel's data to the 0-255 range.
img = ImageStretching(data)

# Largest label value present in the ground truth.
class_count = np.unique(gt).max()

ratio_list, flag_list = [0.1, 0.01], [1, 0]

# Convert to a tensor, add a leading batch dimension, cast to float and
# move to the target device.
x = transform(np.array(img)).unsqueeze(0).float().to(device)

In [7]:
# Build the network on `device` and restore the trained checkpoint.
model = Mamba_float16(in_channels=channels, num_classes=class_count, hidden_dim=128).to(device)
# map_location remaps the stored tensors onto `device`, so the checkpoint
# loads even if it was saved from a different device.
model.load_state_dict(torch.load("./RESULT_FLOAT/MambaHSI/UP/run0_seed0/best_tr30_val10.pth", map_location=device))

<All keys matched successfully>

In [None]:
# Evaluate the loaded model on the full input `x` against `test_label`.
# NOTE(review): `test_label` is not defined in the cells shown here; it must
# already exist in the kernel — confirm where it is created.
model.eval()
# Use the class count the model was built with instead of a hard-coded 9;
# int() turns a possible NumPy scalar into a plain Python int.
test_evaluator = Evaluator(num_class=int(class_count))
with torch.no_grad():
    test_evaluator.reset()
    output_test = model(x)

    # Add a leading batch dimension to the labels and resize the logits to
    # the label map's trailing dimensions.
    y_test = test_label.unsqueeze(0)
    seg_logits_test = resize(input=output_test,
                        size=y_test.shape[1:],
                        mode='bilinear',
                        align_corners=True)
    # Predicted class per position = index of the largest logit along dim 1.
    predict_test = torch.argmax(seg_logits_test, dim=1).cpu().numpy()
    Y_test_np = test_label.cpu().numpy()
    # Labels equal to -1 are remapped to 255 before being passed to the
    # evaluator (presumably its ignore value — verify against Evaluator).
    Y_test_255 = np.where(Y_test_np == -1, 255, Y_test_np)
    test_evaluator.add_batch(np.expand_dims(Y_test_255, axis=0), predict_test)
    OA_test = test_evaluator.Pixel_Accuracy()
    mIOU_test, IOU_test = test_evaluator.Mean_Intersection_over_Union()
    mAcc_test, Acc_test = test_evaluator.Pixel_Accuracy_Class()
    Kappa_test = test_evaluator.Kappa()
    print('Test {}|OA:{}|MACC:{}|Kappa:{}|MIOU:{}|IOU:{}|ACC:{}'.format(1, OA_test, mAcc_test, Kappa_test, mIOU_test, IOU_test,
                                                                            Acc_test))