In [1]:
from scipy.io import loadmat
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import torch
from torch import nn
from mamba_ssm import Mamba
import math
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import cohen_kappa_score
import argparse
import os
import random
import time
from torchvision import models,transforms
import utils.data_load_operate as data_load_operate
from utils.Loss import head_loss,resize
from utils.evaluation import Evaluator
from utils.HSICommonUtils import normlize3D, ImageStretching
from utils.setup_logger import setup_logger
from utils.visual_predict import visualize_predict
from PIL import Image
from calflops import calculate_flops
import cv2
from typing import Tuple,Union
from GPUtil import showUtilization as gpu_usage

In [2]:
class SpeMamba(nn.Module):
    """Spectral Mamba block: scans the channel axis of every pixel as a token sequence.

    The channel axis is split into ``token_num`` consecutive groups of
    ``group_channel_num`` channels. Each pixel becomes a sequence of
    ``token_num`` tokens that is fed to a Mamba layer, then GroupNorm + SiLU.

    Args:
        channels: Number of input channels.
        token_num: Number of tokens the channel axis is split into.
        use_residual: If True, the input is added to the block output.
        group_num: Number of groups of the GroupNorm layer.

    Note:
        When ``channels`` is not a multiple of ``token_num`` the input is
        zero-padded along the channel axis up to ``channel_num``. With
        ``use_residual=True`` the padded channels are cropped again so the
        output has as many channels as the input; with ``use_residual=False``
        the padded channel count is returned unchanged.
    """

    def __init__(self,channels, token_num=8, use_residual=True, group_num=4):
        super(SpeMamba, self).__init__()
        self.token_num = token_num
        self.use_residual = use_residual

        # Channels per token, rounded up so token_num groups cover every channel.
        self.group_channel_num = math.ceil(channels/token_num)
        # Channel count after padding (>= channels).
        self.channel_num = self.token_num * self.group_channel_num

        self.mamba = Mamba( # This module uses roughly 3 * expand * d_model^2 parameters
                            d_model=self.group_channel_num,  # Model dimension d_model
                            d_state=16,  # SSM state expansion factor
                            d_conv=4,  # Local convolution width
                            expand=2,  # Block expansion factor
                            )

        self.proj = nn.Sequential(
            nn.GroupNorm(group_num, self.channel_num),
            nn.SiLU()
        )

    def padding_feature(self,x):
        """Zero-pad ``x`` (B, C, H, W) along dim 1 up to ``self.channel_num`` channels."""
        B, C, H, W = x.shape
        if C < self.channel_num:
            pad_c = self.channel_num - C
            # new_zeros inherits both device and dtype from x, so the
            # concatenation never mixes dtypes (e.g. under half precision).
            pad_features = x.new_zeros((B, pad_c, H, W))
            return torch.cat([x, pad_features], dim=1)
        return x

    def forward(self,x):
        """Apply the spectral scan to ``x`` of shape (B, C, H, W)."""
        in_channels = x.shape[1]
        x_pad = self.padding_feature(x)
        # Channels last, so that the channel axis can be split into tokens.
        x_pad = x_pad.permute(0, 2, 3, 1).contiguous()
        B, H, W, C_pad = x_pad.shape
        # One sequence per pixel: (B*H*W, token_num, group_channel_num).
        x_flat = x_pad.view(B * H * W, self.token_num, self.group_channel_num)
        x_flat = self.mamba(x_flat)
        x_recon = x_flat.view(B, H, W, C_pad)
        x_recon = x_recon.permute(0, 3, 1, 2).contiguous()
        x_proj = self.proj(x_recon)
        if self.use_residual:
            # Drop the padded channels so the shapes match; adding the padded
            # tensor directly raises a size-mismatch error whenever padding
            # was applied.
            return x + x_proj[:, :in_channels]
        return x_proj


class SpaMamba(nn.Module):
    """Spatial Mamba block: scans the pixels of a feature map as one sequence.

    The (B, C, H, W) input is flattened in row-major order into one sequence
    of ``H * W`` tokens of dimension ``C`` per sample and fed to a Mamba
    layer, optionally followed by GroupNorm + SiLU and a residual connection.

    Args:
        channels: Number of input (and output) channels.
        use_residual: If True, the input is added to the block output.
        group_num: Number of groups of the GroupNorm layer.
        use_proj: If True, apply GroupNorm + SiLU after the Mamba layer.
    """

    def __init__(self,channels,use_residual=True,group_num=4,use_proj=True):
        super(SpaMamba, self).__init__()
        self.use_residual = use_residual
        self.use_proj = use_proj
        self.mamba = Mamba(  # This module uses roughly 3 * expand * d_model^2 parameters
                           d_model=channels,  # Model dimension d_model
                           d_state=16,  # SSM state expansion factor
                           d_conv=4,  # Local convolution width
                           expand=2,  # Block expansion factor
                           )
        if self.use_proj:
            self.proj = nn.Sequential(
                nn.GroupNorm(group_num, channels),
                nn.SiLU()
            )

    def forward(self,x):
        """Apply the spatial scan to ``x`` of shape (B, C, H, W)."""
        x_re = x.permute(0, 2, 3, 1).contiguous()
        B,H,W,C = x_re.shape
        # Keep the batch axis: one sequence of H*W tokens per sample.
        # Flattening to (1, B*H*W, C) would chain all samples into a single
        # sequence and let the scan state leak from one sample into the next.
        # For B == 1 both layouts are identical.
        x_flat = x_re.view(B, -1, C)
        x_flat = self.mamba(x_flat)

        x_recon = x_flat.view(B, H, W, C)
        x_recon = x_recon.permute(0, 3, 1, 2).contiguous()
        if self.use_proj:
            x_recon = self.proj(x_recon)
        if self.use_residual:
            return x_recon + x
        return x_recon


class BothMamba(nn.Module):
    """Fuses a spatial (SpaMamba) and a spectral (SpeMamba) branch.

    Both branches receive the same input. Their outputs are either summed
    directly or blended with two learnable weights normalised by a softmax.

    Args:
        channels: Number of feature channels handed to both branches.
        token_num: Token count forwarded to the spectral branch.
        use_residual: Forwarded to both branches; also adds the input to the
            fused result.
        group_num: GroupNorm group count forwarded to both branches.
        use_att: If True, blend the branches with learnable softmax weights;
            otherwise add them.
    """

    def __init__(self,channels,token_num,use_residual,group_num=4,use_att=True):
        super(BothMamba, self).__init__()
        self.use_att = use_att
        self.use_residual = use_residual
        if use_att:
            # Two branch logits, initialised equal (0.5 each).
            self.weights = nn.Parameter(torch.ones(2) / 2)
            self.softmax = nn.Softmax(dim=0)

        branch_kwargs = dict(use_residual=use_residual, group_num=group_num)
        self.spa_mamba = SpaMamba(channels, **branch_kwargs)
        self.spe_mamba = SpeMamba(channels, token_num=token_num, **branch_kwargs)

    def forward(self,x):
        """Return the fused branch output for ``x``."""
        spatial_out = self.spa_mamba(x)
        spectral_out = self.spe_mamba(x)
        if not self.use_att:
            fused = spatial_out + spectral_out
        else:
            mix = self.softmax(self.weights)
            fused = spatial_out * mix[0] + spectral_out * mix[1]
        return fused + x if self.use_residual else fused


In [3]:
# Save visualisation images.
def vis_a_image(gt_vis,pred_vis,save_single_predict_path,save_single_gt_path,only_vis_label=False):
    """Write two visualisations of a prediction via ``visualize_predict``.

    The first call uses the caller's ``only_vis_label`` and writes to
    ``save_single_predict_path``. The second call always passes
    ``only_vis_label=True`` and writes to the same path with ``.png``
    replaced by ``_mask.png``.

    Args:
        gt_vis: Ground-truth map forwarded to ``visualize_predict``.
        pred_vis: Prediction map forwarded to ``visualize_predict``.
        save_single_predict_path: Output path of the prediction image.
        save_single_gt_path: Output path of the ground-truth image.
        only_vis_label: Forwarded to the first ``visualize_predict`` call.
    """
    visualize_predict(
        gt_vis,
        pred_vis,
        save_single_predict_path,
        save_single_gt_path,
        only_vis_label=only_vis_label,
    )
    mask_predict_path = save_single_predict_path.replace('.png','_mask.png')
    visualize_predict(
        gt_vis,
        pred_vis,
        mask_predict_path,
        save_single_gt_path,
        only_vis_label=True,
    )

# Seed every random number generator so that parameter initialisation is
# identical between runs and results can be reproduced.
def setup_seed(seed):
    """Seed Python, NumPy and PyTorch and switch cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Hash seed, exported as an environment variable.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python's built-in generator and NumPy's global generator.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, then every GPU generator.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels and no auto-tuned algorithm selection
    # (slower, but repeatable).
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Returns the parsed argument namespace so the main program can read the
# settings as args.dataset_index, args.lr, and so on.
def get_parser():
    """Parse the experiment settings from the command line.

    Unknown arguments are ignored (``parse_known_args``), abbreviations are
    disabled and no ``-h`` option is registered, so extra arguments present
    in ``sys.argv`` do not make parsing fail.

    Returns:
        argparse.Namespace with ``dataset_index``, ``data_set_path``,
        ``work_dir``, ``lr``, ``max_epoch``, ``train_samples``,
        ``val_samples``, ``exp_name`` and ``record_computecost``.
    """
    def str2bool(value):
        # argparse hands the raw string to the type callable. Plain bool()
        # would turn every non-empty string, including "False", into True.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value (true/false), got {!r}'.format(value))

    parser = argparse.ArgumentParser(
        allow_abbrev=False,
        add_help=False
    )
    parser.add_argument('--dataset_index', type=int, default=3)      # index into the data set list
    parser.add_argument('--data_set_path', type=str, default='./data') # data set directory
    parser.add_argument('--work_dir', type=str, default='./')        # working directory
    parser.add_argument('--lr', type=float, default=0.0003)         # learning rate
    parser.add_argument('--max_epoch', type=int, default=200)       # maximum number of epochs
    parser.add_argument('--train_samples', type=int, default=30)    # training samples per class
    parser.add_argument('--val_samples', type=int, default=10)      # validation samples per class
    parser.add_argument('--exp_name', type=str, default='RUNS')     # experiment name
    parser.add_argument('--record_computecost', type=str2bool, default=True) # whether to record the compute cost
    # Parse only the arguments declared above; everything else is discarded.
    args, _ = parser.parse_known_args()
    return args


# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args = get_parser()

# Unpack the parsed settings into module-level names used by later cells.
work_dir = args.work_dir
exp_name = args.exp_name
dataset_index = args.dataset_index
learning_rate = args.lr
max_epoch = args.max_epoch
record_computecost = args.record_computecost

# One independent run per seed.
seed_list = list(range(10))

# [training samples per class, validation samples per class]
num_list = [args.train_samples, args.val_samples]
net_name = 'MambaHSI'

paras_dict = dict(
    net_name=net_name,
    dataset_index=dataset_index,
    num_list=num_list,
    lr=learning_rate,
    seed_list=seed_list,
)

# Data set names, selected by --dataset_index (valid indices 0..3).
data_set_name_list = ['UP', 'HanChuan', 'HongHu', 'Houston']
data_set_name = data_set_name_list[dataset_index]

# These data sets are trained on two halves of the image instead of the whole image.
split_image = data_set_name in ('HanChuan', 'Houston')

# Only conversion to a tensor is applied; no resizing or normalisation.
transform = transforms.Compose([transforms.ToTensor()])


In [4]:
class MambaHSI(nn.Module):
    """Fully convolutional Mamba network producing per-pixel class logits.

    Pipeline: 1x1 conv embedding (+ GroupNorm, SiLU) -> three Mamba blocks
    separated by two 2x2 average poolings -> 1x1 conv classification head.
    Because of the two stride-2 poolings the logits have roughly a quarter of
    the input's height and width; callers must upsample them to label size.

    Args:
        in_channels: Number of channels of the input image.
        hidden_dim: Channel width of the embedding and of every Mamba block.
        num_classes: Number of output classes.
        use_residual: Forwarded to every Mamba block.
        mamba_type: ``'spa'`` (SpaMamba), ``'spe'`` (SpeMamba) or ``'both'``
            (BothMamba).
        token_num: Token count for the spectral blocks (``'spe'``/``'both'``).
        group_num: Number of groups of all GroupNorm layers.
        use_att: Forwarded to BothMamba (``'both'`` only).

    Raises:
        ValueError: If ``mamba_type`` is not one of the supported values.
    """

    _MAMBA_TYPES = ('spa', 'spe', 'both')

    def __init__(self,in_channels=128,hidden_dim=64,num_classes=10,use_residual=True,mamba_type='both',token_num=4,group_num=4,use_att=True):
        super(MambaHSI, self).__init__()
        # Fail at construction time: an unknown type used to leave
        # ``self.mamba`` undefined and only surfaced as an AttributeError in
        # forward().
        if mamba_type not in self._MAMBA_TYPES:
            raise ValueError(
                "mamba_type must be one of {}, got {!r}".format(self._MAMBA_TYPES, mamba_type))
        self.mamba_type = mamba_type

        self.patch_embedding = nn.Sequential(nn.Conv2d(in_channels=in_channels,out_channels=hidden_dim,kernel_size=1,stride=1,padding=0),
                                             nn.GroupNorm(group_num,hidden_dim),
                                             nn.SiLU())

        def make_block():
            # Build one Mamba block of the requested flavour.
            if mamba_type == 'spa':
                return SpaMamba(hidden_dim,use_residual=use_residual,group_num=group_num)
            if mamba_type == 'spe':
                return SpeMamba(hidden_dim,token_num=token_num,use_residual=use_residual,group_num=group_num)
            return BothMamba(channels=hidden_dim,token_num=token_num,use_residual=use_residual,group_num=group_num,use_att=use_att)

        # Same layer order (and therefore the same state_dict keys) for every
        # mamba_type: block, pool, block, pool, block.
        self.mamba = nn.Sequential(make_block(),
                                   nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                                   make_block(),
                                   nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                                   make_block(),
                                   )

        self.cls_head = nn.Sequential(nn.Conv2d(in_channels=hidden_dim, out_channels=128, kernel_size=1, stride=1, padding=0),
                                      nn.GroupNorm(group_num,128),
                                      nn.SiLU(),
                                      nn.Conv2d(in_channels=128,out_channels=num_classes,kernel_size=1,stride=1,padding=0))

    def forward(self,x):
        """Return class logits for ``x`` of shape (B, in_channels, H, W)."""
        x = self.patch_embedding(x)
        x = self.mamba(x)

        logits = self.cls_head(x)
        return logits

In [None]:
def _train_on_halves(net, optimizer, loss_func, x, y_train, overlap=5):
    """Run one optimisation step on each of two overlapping halves of the image.

    ``x`` is cut along dim 2 at its midpoint; each half is extended by
    ``overlap`` rows across the cut. A separate forward / backward /
    optimizer step is performed per half, which lowers peak memory compared
    with a single step on the whole image.

    Args:
        net: Model being trained.
        optimizer: Optimizer that updates ``net``.
        loss_func: Criterion forwarded to ``head_loss``.
        x: Input tensor, sliced as ``x[:, :, rows, :]``.
        y_train: Label tensor, sliced as ``y_train[:, rows, :]``.
        overlap: Number of rows each half reaches beyond the midpoint.

    Returns:
        Detached tensor holding the sum of the two half losses.
    """
    mid = x.shape[2] // 2
    row_slices = (slice(None, mid + overlap), slice(mid - overlap, None))
    total_loss = None
    for rows in row_slices:
        pred = net(x[:, :, rows, :])
        loss = head_loss(loss_func, pred, y_train[:, rows, :].long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step_loss = loss.detach()
        total_loss = step_loss if total_loss is None else total_loss + step_loss
        # Release this half's tensors before the next half is processed.
        del pred, loss
        torch.cuda.empty_cache()
    return total_loss


def _evaluate(model, x, label, evaluator):
    """Score ``model`` on the pixels of ``label`` using ``evaluator``.

    The evaluator is reset, the model is run on ``x`` without gradients, the
    logits are bilinearly resized to the label size and the arg-max
    prediction is added to the evaluator. Label value -1 is mapped to 255
    before it is handed to the evaluator.

    Args:
        model: Network to evaluate (the caller puts it into eval mode).
        x: Input tensor for the network.
        label: 2-D label tensor.
        evaluator: ``Evaluator`` instance that accumulates the metrics.

    Returns:
        Tuple ``(predict, OA, mIOU, IOU, mAcc, Acc, Kappa)`` where ``predict``
        is the arg-max prediction as a NumPy array and every metric is read
        from ``evaluator``.
    """
    with torch.no_grad():
        evaluator.reset()
        logits = model(x)
        target = label.unsqueeze(0)
        logits = resize(input=logits,
                        size=target.shape[1:],
                        mode='bilinear',
                        align_corners=True)
        # arg-max over dim 1 (the class axis) removes that axis
        predict = torch.argmax(logits, dim=1).cpu().numpy()
        label_np = label.cpu().numpy()
        label_255 = np.where(label_np == -1, 255, label_np)  # replace -1 with 255
        evaluator.add_batch(np.expand_dims(label_255, axis=0), predict)
        overall_acc = evaluator.Pixel_Accuracy()
        mean_iou, iou = evaluator.Mean_Intersection_over_Union()
        mean_acc, acc = evaluator.Pixel_Accuracy_Class()
        kappa = evaluator.Kappa()
    return predict, overall_acc, mean_iou, iou, mean_acc, acc, kappa


if __name__ == '__main__':
    data_set_path = args.data_set_path
    work_dir = args.work_dir
    setting_name = 'tr{}val{}'.format(str(args.train_samples),str(args.val_samples)) + '_lr{}'.format(str(learning_rate))
    dataset_name = data_set_name

    exp_name = args.exp_name

    save_folder = os.path.join(work_dir, exp_name, net_name, dataset_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
        print("makedirs {}".format(save_folder))

    save_log_path = os.path.join(save_folder,'train_tr{}_val{}.log'.format(num_list[0],num_list[1]))
    # Logger named after the data set, writing to the log file above.
    logger = setup_logger(
        name='{}'.format(dataset_name),
        logfile=save_log_path
    )
    torch.cuda.empty_cache()  # release cached, currently unused GPU memory

    logger.info(save_folder)

    data, gt = data_load_operate.load_data(data_set_name, data_set_path)

    height, width, channels = data.shape

    gt_reshape = gt.reshape(-1)

    # Per the original note: stretches every channel to the 0-255 range.
    img = ImageStretching(data)

    # Largest label value, taken as the number of classes; a plain Python int
    # so it can be used directly as an array shape and layer size.
    class_count = int(np.max(gt))
    flag_list = [1, 0]  # ratio or num
    ratio_list = [0.1, 0.01]  # [train_ratio,val_ratio]

    # Pixels labelled -1 do not contribute to the loss.
    loss_func = torch.nn.CrossEntropyLoss(ignore_index=-1)

    OA_ALL = []
    AA_ALL = []
    KPP_ALL = []
    EACH_ACC_ALL = []
    Train_Time_ALL = []
    Test_Time_ALL = []
    CLASS_ACC = np.zeros([len(seed_list), class_count])
    evaluator = Evaluator(num_class=class_count)

    for exp_idx,curr_seed in enumerate(seed_list):
        setup_seed(curr_seed)
        single_experiment_name = 'run{}_seed{}'.format(str(exp_idx), str(curr_seed))
        save_single_experiment_folder = os.path.join(save_folder, single_experiment_name)
        if not os.path.exists(save_single_experiment_folder):
            os.mkdir(save_single_experiment_folder)
        save_vis_folder = os.path.join(save_single_experiment_folder, 'vis')
        if not os.path.exists(save_vis_folder):
            os.makedirs(save_vis_folder)
            print("makedirs {}".format(save_vis_folder))

        save_weight_path = os.path.join(save_single_experiment_folder, "best_tr{}_val{}.pth".format(num_list[0], num_list[1]))
        results_save_path = os.path.join(save_single_experiment_folder, 'result_tr{}_val{}.txt'.format(num_list[0], num_list[1]))
        predict_save_path = os.path.join(save_single_experiment_folder, 'pred_vis_tr{}_val{}.png'.format(num_list[0], num_list[1]))
        gt_save_path = os.path.join(save_single_experiment_folder, 'gt_vis_tr{}_val{}.png'.format(num_list[0], num_list[1]))

        # Per the original note: returns flat (1-D) pixel indices.
        train_data_index, val_data_index, test_data_index, all_data_index = data_load_operate.sampling(ratio_list,
                                                                                                       num_list,
                                                                                                       gt_reshape,
                                                                                                       class_count,
                                                                                                       flag_list[0])
        index = (train_data_index, val_data_index, test_data_index)
        # Per the original note: each label map has the size of the image;
        # selected pixels hold (label - 1), all other pixels hold -1.
        train_label, val_label, test_label = data_load_operate.generate_image_iter(data, height, width, gt_reshape, index)

        # build Model
        net = MambaHSI(in_channels=channels, num_classes=class_count, hidden_dim=128)
        logger.info(paras_dict)
        logger.info(net)

        # Whole image as a single (1, C, H, W) float tensor.
        x = transform(np.array(img))
        x = x.unsqueeze(0).float().to(device)

        train_label = train_label.to(device)
        test_label = test_label.to(device)
        val_label = val_label.to(device)

        net.to(device)

        train_loss_list = [100]
        train_acc_list = [0]
        val_loss_list = [100]
        val_acc_list = [0]

        optimizer = torch.optim.Adam(net.parameters(),lr=learning_rate)

        logger.info(optimizer)
        best_loss = 99999
        if record_computecost:
            # Report parameter count and FLOPs for one forward pass on the full image.
            net.eval()
            flops, macs1, para = calculate_flops(model=net,
                                                 input_shape=(1, x.shape[1], x.shape[2], x.shape[3]), )
            logger.info("para:{}\n,flops:{}".format(para, flops))

        tic1 = time.perf_counter()  # start of the training timer
        best_val_acc = 0

        # Each epoch is one step on the whole image, or two steps on
        # overlapping halves for large images.
        for epoch in range(max_epoch):
            y_train = train_label.unsqueeze(0)

            net.train()

            if split_image:
                split_loss = _train_on_halves(net, optimizer, loss_func, x, y_train)
                logger.info('Iter:{}|loss:{}'.format(epoch, split_loss.cpu().numpy()))
            else:
                try:
                    y_pred = net(x)
                    ls = head_loss(loss_func,y_pred, y_train.long())
                    optimizer.zero_grad()
                    ls.backward()
                    optimizer.step()
                    logger.info('Iter:{}|loss:{}'.format(epoch, ls.detach().cpu().numpy()))
                except RuntimeError as err:
                    # CUDA out-of-memory errors are RuntimeErrors. Fall back to
                    # split training for this and all following epochs. Other
                    # exception types (e.g. KeyboardInterrupt) now propagate.
                    logger.warning('Full-image step failed ({}); switching to split training.'.format(err))
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()
                    split_image = True
                    split_loss = _train_on_halves(net, optimizer, loss_func, x, y_train)
                    logger.info('Iter:{}|loss:{}'.format(epoch, split_loss.cpu().numpy()))

            torch.cuda.empty_cache()
            # evaluate stage
            net.eval()
            predict, OA, mIOU, IOU, mAcc, Acc, Kappa = _evaluate(net, x, val_label, evaluator)
            logger.info('Evaluate {}|OA:{}|MACC:{}|Kappa:{}|MIOU:{}|IOU:{}|ACC:{}'.format(epoch, OA,mAcc,Kappa,mIOU,IOU,Acc))
            # save weight
            if OA>=best_val_acc:
                best_epoch = epoch + 1
                best_val_acc = OA
                torch.save(net.state_dict(), save_weight_path)
            if (epoch+1)%50==0:
                save_single_predict_path = os.path.join(save_vis_folder,'predict_{}.png'.format(str(epoch+1)))
                save_single_gt_path = os.path.join(save_vis_folder,'gt.png')
                vis_a_image(gt,predict,save_single_predict_path, save_single_gt_path)

            torch.cuda.empty_cache()

        # Wall-clock time of the epoch loop (includes the per-epoch validation).
        train_time = time.perf_counter() - tic1
        Train_Time_ALL.append(train_time)
        logger.info('Training time: {:.2f}s'.format(train_time))

        logger.info("\n\n====================Starting evaluation for testing set.========================\n")
        start_time = time.time()
        load_weight_path = save_weight_path
        # Fresh model initialised from the best checkpoint of this run.
        best_net = MambaHSI(in_channels=channels, num_classes=class_count, hidden_dim=128)

        best_net.to(device)
        best_net.load_state_dict(torch.load(load_weight_path, map_location=device))
        best_net.eval()
        test_evaluator = Evaluator(num_class=class_count)
        # All test metrics, Kappa included, come from test_evaluator.
        predict_test, OA_test, mIOU_test, IOU_test, mAcc_test, Acc_test, Kappa_test = _evaluate(
            best_net, x, test_label, test_evaluator)
        logger.info('Test {}|OA:{}|MACC:{}|Kappa:{}|MIOU:{}|IOU:{}|ACC:{}'.format(epoch, OA_test, mAcc_test, Kappa_test, mIOU_test, IOU_test,
                                                                                Acc_test))
        vis_a_image(gt, predict_test, predict_save_path, gt_save_path)
        # stop the test timer
        end_time = time.time()

        # elapsed test time in seconds
        elapsed_time = end_time - start_time
        print(elapsed_time)
        # Output infors
        str_results = '\n======================' \
                      + " exp_idx=" + str(exp_idx) \
                      + " seed=" + str(curr_seed) \
                      + " learning rate=" + str(learning_rate) \
                      + " epochs=" + str(max_epoch) \
                      + " train ratio=" + str(ratio_list[0]) \
                      + " val ratio=" + str(ratio_list[1]) \
                      + " ======================" \
                      + "\nOA=" + str(OA_test) \
                      + "\nAA=" + str(mAcc_test) \
                      + '\nkpp=' + str(Kappa_test) \
                      + '\nmIOU_test:' + str(mIOU_test) \
                      + "\nIOU_test:" + str(IOU_test) \
                      + "\nAcc_test:" + str(Acc_test) + "\n"
        logger.info(str_results)
        with open(results_save_path, 'a+') as f:
            f.write(str_results)

        OA_ALL.append(OA_test)
        AA_ALL.append(mAcc_test)
        KPP_ALL.append(Kappa_test)
        EACH_ACC_ALL.append(Acc_test)
        Test_Time_ALL.append(elapsed_time)

        torch.cuda.empty_cache()

    OA_ALL = np.array(OA_ALL)
    AA_ALL = np.array(AA_ALL)
    KPP_ALL = np.array(KPP_ALL)
    EACH_ACC_ALL = np.array(EACH_ACC_ALL)
    Train_Time_ALL = np.array(Train_Time_ALL)
    Test_Time_ALL = np.array(Test_Time_ALL)

    np.set_printoptions(precision=4)
    # Logger methods treat extra positional arguments as %-format arguments,
    # so every value is formatted into the message string itself.
    logger.info("\n====================Mean result of {} times runs =========================".format(len(seed_list)))
    logger.info('List of OA: {}'.format(list(OA_ALL)))
    logger.info('List of AA: {}'.format(list(AA_ALL)))
    logger.info('List of KPP: {}'.format(list(KPP_ALL)))
    logger.info('OA= {} +- {}'.format(round(np.mean(OA_ALL) * 100, 2), round(np.std(OA_ALL) * 100, 2)))
    logger.info('AA= {} +- {}'.format(round(np.mean(AA_ALL) * 100, 2), round(np.std(AA_ALL) * 100, 2)))
    logger.info('Kpp= {} +- {}'.format(round(np.mean(KPP_ALL) * 100, 2), round(np.std(KPP_ALL) * 100, 2)))
    logger.info('Acc per class= {} +- {}'.format(np.round(np.mean(EACH_ACC_ALL, 0) * 100, decimals=2),
                                                 np.round(np.std(EACH_ACC_ALL, 0) * 100, decimals=2)))

    logger.info("Average training time= {} +- {}".format(round(np.mean(Train_Time_ALL), 2), round(np.std(Train_Time_ALL), 3)))
    # Testing time is reported in milliseconds (mean and std alike).
    logger.info("Average testing time= {} +- {}".format(round(np.mean(Test_Time_ALL) * 1000, 2),
                                                        round(np.std(Test_Time_ALL) * 1000, 3)))

    # Output infors
    mean_result_path = os.path.join(save_folder,'mean_result.txt')
    str_results = '\n\n***************Mean result of ' + str(len(seed_list)) + 'times runs ********************' \
                  + '\nList of OA:' + str(list(OA_ALL)) \
                  + '\nList of AA:' + str(list(AA_ALL)) \
                  + '\nList of KPP:' + str(list(KPP_ALL)) \
                  + '\nOA=' + str(round(np.mean(OA_ALL) * 100, 2)) + '+-' + str(round(np.std(OA_ALL) * 100, 2)) \
                  + '\nAA=' + str(round(np.mean(AA_ALL) * 100, 2)) + '+-' + str(round(np.std(AA_ALL) * 100, 2)) \
                  + '\nKpp=' + str(round(np.mean(KPP_ALL) * 100, 2)) + '+-' + str(
        round(np.std(KPP_ALL) * 100, 2)) \
                  + '\nAcc per class=\n' + str(np.round(np.mean(EACH_ACC_ALL, 0) * 100, 2)) + '+-' + str(
        np.round(np.std(EACH_ACC_ALL, 0) * 100, 2)) \
                  + "\nAverage training time=" + str(
        np.round(np.mean(Train_Time_ALL), decimals=2)) + '+-' + str(
        np.round(np.std(Train_Time_ALL), decimals=3)) \
                  + "\nAverage testing time=" + str(
        np.round(np.mean(Test_Time_ALL) * 1000, decimals=2)) + '+-' + str(
        np.round(np.std(Test_Time_ALL) * 1000, decimals=3))
    with open(mean_result_path, 'w') as f:
        f.write(str_results)

    del net


[I 250410 11:40:26 724153362:22] ./RUNS/MambaHSI/Houston
[I 250410 11:40:32 724153362:76] {'net_name': 'MambaHSI', 'dataset_index': 3, 'num_list': [30, 10], 'lr': 0.0003, 'seed_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}
[I 250410 11:40:32 724153362:77] MambaHSI(
      (patch_embedding): Sequential(
        (0): Conv2d(144, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): GroupNorm(4, 128, eps=1e-05, affine=True)
        (2): SiLU()
      )
      (mamba): Sequential(
        (0): BothMamba(
          (softmax): Softmax(dim=0)
          (spa_mamba): SpaMamba(
            (mamba): Mamba(
              (in_proj): Linear(in_features=128, out_features=512, bias=False)
              (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
              (act): SiLU()
              (x_proj): Linear(in_features=256, out_features=40, bias=False)
              (dt_proj): Linear(in_features=8, out_features=256, bias=True)
              (out_proj): Linear(in_features=256, 


------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  418.26 K
fwd MACs:                                                               65.24 GMACs
fwd FLOPs:                                                              132.57 GFLOPS
fwd+bwd MACs:                                                           195.73 GMACs
fwd+bwd FLOPs:                                                          397.7 GFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module c

[I 250410 11:40:39 724153362:140] Iter:0|loss:5.499206066131592
[I 250410 11:40:40 724153362:194] Evaluate 0|OA:0.08|MACC:0.08000000000000002|Kappa:0.014285714285714289|MIOU:0.01728494623655914|IOU:[0.         0.         0.         0.         0.09677419 0.
     0.         0.1        0.         0.         0.         0.
     0.         0.         0.0625    ]|ACC:[0.  0.  0.  0.  0.9 0.  0.  0.2 0.  0.  0.  0.  0.  0.  0.1]
[I 250410 11:40:43 724153362:140] Iter:1|loss:5.273512363433838
[I 250410 11:40:44 724153362:194] Evaluate 1|OA:0.12666666666666668|MACC:0.12666666666666668|Kappa:0.0642857142857143|MIOU:0.05423400673400673|IOU:[0.3125    0.        0.        0.        0.1010101 0.        0.
     0.3       0.        0.        0.        0.1       0.        0.
     0.       ]|ACC:[0.5 0.  0.  0.  1.  0.  0.  0.3 0.  0.  0.  0.1 0.  0.  0. ]
[I 250410 11:40:47 724153362:140] Iter:2|loss:5.14153528213501
[I 250410 11:40:48 724153362:194] Evaluate 2|OA:0.17333333333333334|MACC:0.173333333333

1.1018519401550293


[I 250410 11:55:00 724153362:98] Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        capturable: False
        differentiable: False
        eps: 1e-08
        foreach: None
        fused: None
        lr: 0.0003
        maximize: False
        weight_decay: 0
    )
[I 250410 11:55:00 724153362:104] para:418.26 K
    ,flops:132.57 GFLOPS



------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  418.26 K
fwd MACs:                                                               65.24 GMACs
fwd FLOPs:                                                              132.57 GFLOPS
fwd+bwd MACs:                                                           195.73 GMACs
fwd+bwd FLOPs:                                                          397.7 GFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module c

[I 250410 11:55:04 724153362:140] Iter:0|loss:5.4585700035095215
[I 250410 11:55:05 724153362:194] Evaluate 0|OA:0.16|MACC:0.16|Kappa:0.1|MIOU:0.07258467023172906|IOU:[0.81818182 0.         0.         0.         0.         0.
     0.         0.         0.         0.         0.         0.
     0.         0.07058824 0.2       ]|ACC:[0.9 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.6 0.9]
[I 250410 11:55:08 724153362:140] Iter:1|loss:5.2978949546813965
[I 250410 11:55:09 724153362:194] Evaluate 1|OA:0.12|MACC:0.12000000000000001|Kappa:0.05714285714285714|MIOU:0.027792207792207792|IOU:[0.3        0.         0.         0.         0.         0.
     0.         0.         0.         0.         0.         0.
     0.         0.         0.11688312]|ACC:[0.9 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.9]
[I 250410 11:55:12 724153362:140] Iter:2|loss:5.174700736999512
[I 250410 11:55:13 724153362:194] Evaluate 2|OA:0.14666666666666667|MACC:0.14666666666666667|Kappa:0.085714285714285

1.1081578731536865
makedirs ./RUNS/MambaHSI/Houston/run2_seed2/vis


[I 250410 12:09:47 724153362:98] Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        capturable: False
        differentiable: False
        eps: 1e-08
        foreach: None
        fused: None
        lr: 0.0003
        maximize: False
        weight_decay: 0
    )
[I 250410 12:09:47 724153362:104] para:418.26 K
    ,flops:132.57 GFLOPS



------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  418.26 K
fwd MACs:                                                               65.24 GMACs
fwd FLOPs:                                                              132.57 GFLOPS
fwd+bwd MACs:                                                           195.73 GMACs
fwd+bwd FLOPs:                                                          397.7 GFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module c

[I 250410 12:09:51 724153362:140] Iter:0|loss:5.507107734680176
[I 250410 12:09:52 724153362:194] Evaluate 0|OA:0.12666666666666668|MACC:0.12666666666666665|Kappa:0.0642857142857143|MIOU:0.07280706221882692|IOU:[0.         0.         0.         0.         0.63636364 0.0952381
     0.14285714 0.1        0.11764706 0.         0.         0.
     0.         0.         0.        ]|ACC:[0.  0.  0.  0.  0.7 0.2 0.3 0.1 0.6 0.  0.  0.  0.  0.  0. ]
[I 250410 12:09:55 724153362:140] Iter:1|loss:5.2886457443237305
[I 250410 12:09:56 724153362:194] Evaluate 1|OA:0.20666666666666667|MACC:0.20666666666666667|Kappa:0.15000000000000002|MIOU:0.10442216652742968|IOU:[0.57142857 0.         0.         0.         0.26315789 0.1
     0.14285714 0.4        0.08888889 0.         0.         0.
     0.         0.         0.        ]|ACC:[0.8 0.  0.  0.  1.  0.2 0.3 0.4 0.4 0.  0.  0.  0.  0.  0. ]
[I 250410 12:10:00 724153362:140] Iter:2|loss:5.146531581878662
[I 250410 12:10:00 724153362:194] Evaluate 2|OA:0.

1.0982389450073242
makedirs ./RUNS/MambaHSI/Houston/run3_seed3/vis


[I 250410 12:24:11 724153362:98] Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        capturable: False
        differentiable: False
        eps: 1e-08
        foreach: None
        fused: None
        lr: 0.0003
        maximize: False
        weight_decay: 0
    )
[I 250410 12:24:11 724153362:104] para:418.26 K
    ,flops:132.57 GFLOPS



------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  418.26 K
fwd MACs:                                                               65.24 GMACs
fwd FLOPs:                                                              132.57 GFLOPS
fwd+bwd MACs:                                                           195.73 GMACs
fwd+bwd FLOPs:                                                          397.7 GFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module c

[I 250410 12:24:15 724153362:140] Iter:0|loss:5.404987335205078
[I 250410 12:24:16 724153362:194] Evaluate 0|OA:0.04|MACC:0.04|Kappa:-0.02857142857142857|MIOU:0.010687285223367699|IOU:[0.         0.         0.         0.         0.01030928 0.
     0.         0.1        0.         0.         0.         0.
     0.05       0.         0.        ]|ACC:[0.  0.  0.  0.  0.1 0.  0.  0.4 0.  0.  0.  0.  0.1 0.  0. ]
[I 250410 12:24:19 724153362:140] Iter:1|loss:5.194299697875977
[I 250410 12:24:20 724153362:194] Evaluate 1|OA:0.14|MACC:0.14|Kappa:0.07857142857142858|MIOU:0.06149740297027118|IOU:[0.25       0.         0.         0.         0.06976744 0.
     0.         0.14814815 0.         0.         0.         0.
     0.         0.         0.45454545]|ACC:[0.6 0.  0.  0.  0.6 0.  0.  0.4 0.  0.  0.  0.  0.  0.  0.5]
[I 250410 12:24:24 724153362:140] Iter:2|loss:5.077144622802734
[I 250410 12:24:24 724153362:194] Evaluate 2|OA:0.15333333333333332|MACC:0.15333333333333335|Kappa:0.092857142857142

1.0960330963134766
makedirs ./RUNS/MambaHSI/Houston/run4_seed4/vis


[I 250410 12:38:42 724153362:98] Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        capturable: False
        differentiable: False
        eps: 1e-08
        foreach: None
        fused: None
        lr: 0.0003
        maximize: False
        weight_decay: 0
    )
[I 250410 12:38:42 724153362:104] para:418.26 K
    ,flops:132.57 GFLOPS



------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  418.26 K
fwd MACs:                                                               65.24 GMACs
fwd FLOPs:                                                              132.57 GFLOPS
fwd+bwd MACs:                                                           195.73 GMACs
fwd+bwd FLOPs:                                                          397.7 GFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module c

[I 250410 12:38:46 724153362:140] Iter:0|loss:5.466779708862305
[I 250410 12:38:47 724153362:194] Evaluate 0|OA:0.17333333333333334|MACC:0.1733333333333333|Kappa:0.1142857142857143|MIOU:0.05246748624680626|IOU:[0.         0.23255814 0.11764706 0.         0.07317073 0.
     0.         0.         0.         0.         0.         0.
     0.         0.36363636 0.        ]|ACC:[0.  1.  0.2 0.  0.6 0.  0.  0.  0.  0.  0.  0.  0.  0.8 0. ]
[I 250410 12:38:50 724153362:140] Iter:1|loss:5.286522388458252
[I 250410 12:38:51 724153362:194] Evaluate 1|OA:0.13333333333333333|MACC:0.13333333333333333|Kappa:0.07142857142857142|MIOU:0.02210354546803145|IOU:[0.         0.23809524 0.         0.         0.09345794 0.
     0.         0.         0.         0.         0.         0.
     0.         0.         0.        ]|ACC:[0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[I 250410 12:38:54 724153362:140] Iter:2|loss:5.160701751708984
[I 250410 12:38:55 724153362:194] Evaluate 2|OA:0.13333333333333333|MACC:0.

In [6]:
def measure_memory(model, input_shape):
    """Report how much extra CUDA memory one forward pass of ``model`` peaks at.

    Args:
        model: Callable module; it is moved to ``'cuda'`` before being run.
        input_shape: Passed unchanged to ``model(...)`` for both the warm-up
            call and the measured call (it is used as the model input, not
            as a shape).

    Returns:
        ``torch.cuda.max_memory_allocated()`` read after the measured forward
        pass minus ``torch.cuda.memory_allocated()`` read just before it,
        divided by ``1024**3``.

    NOTE(review): this function does not disable autograd itself, so the
    figure presumably includes whatever the forward pass retains for a
    backward pass unless the caller wraps the call in ``torch.no_grad()``
    -- confirm this is the quantity you want to report.
    """
    bytes_per_gb = 1024**3

    # Setup: release cached allocator blocks and place the model on the GPU.
    torch.cuda.empty_cache()
    model = model.to('cuda')

    # Warm-up forward pass; its result stays bound until the function returns.
    warmup_out = model(input_shape)

    # Baseline: restart peak tracking, then record what is allocated right now.
    torch.cuda.reset_peak_memory_stats()
    baseline_bytes = torch.cuda.memory_allocated()

    # Measured forward pass (no backward pass is run here).
    measured_out = model(input_shape)

    # Highest allocation observed since the reset above.
    peak_bytes = torch.cuda.max_memory_allocated()
    del measured_out
    return (peak_bytes - baseline_bytes) / bytes_per_gb

In [10]:
# Instantiate MambaHSI and restore the weights stored in the run0_seed0 checkpoint.
best_net = MambaHSI(in_channels=channels, num_classes=class_count, hidden_dim=128)
best_net.to(device)
# map_location remaps the stored tensors onto `device` while loading, so the
# checkpoint can be restored even if it was saved from a different device.
best_net.load_state_dict(
    torch.load(
        "./RUNS/MambaHSI/Houston/run0_seed0/best_tr30_val10.pth",
        map_location=device,
    )
)
# Switch to evaluation mode; as the cell's last expression this also displays
# the module structure.
best_net.eval()

MambaHSI(
  (patch_embedding): Sequential(
    (0): Conv2d(144, 128, kernel_size=(1, 1), stride=(1, 1))
    (1): GroupNorm(4, 128, eps=1e-05, affine=True)
    (2): SiLU()
  )
  (mamba): Sequential(
    (0): BothMamba(
      (softmax): Softmax(dim=0)
      (spa_mamba): SpaMamba(
        (mamba): Mamba(
          (in_proj): Linear(in_features=128, out_features=512, bias=False)
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
          (act): SiLU()
          (x_proj): Linear(in_features=256, out_features=40, bias=False)
          (dt_proj): Linear(in_features=8, out_features=256, bias=True)
          (out_proj): Linear(in_features=256, out_features=128, bias=False)
        )
        (proj): Sequential(
          (0): GroupNorm(4, 128, eps=1e-05, affine=True)
          (1): SiLU()
        )
      )
      (spe_mamba): SpeMamba(
        (mamba): Mamba(
          (in_proj): Linear(in_features=32, out_features=128, bias=False)
          (conv1d): C

In [11]:
# Peak forward-pass memory of `best_net` on input `x`, as returned by measure_memory.
model_a_mem = measure_memory(best_net, x)
print(f"Model A 峰值内存: {model_a_mem:.3f} GB")
# A second model can be measured the same way to report the relative difference.

Model A 峰值内存: 17.504 GB


In [14]:
# Start timing: the interval covers evaluator setup, one forward pass of
# `best_net` on `x`, and the metric computation below.
start_time = time.time()
# NOTE: `evaluator` never receives predictions in this cell; it is only
# re-created here so the name stays defined for any other cell that reads it.
evaluator = Evaluator(num_class=class_count)
test_evaluator = Evaluator(num_class=class_count)
with torch.no_grad():
    test_evaluator.reset()
    output_test = best_net(x)

    # Add a leading axis to the labels so that shape[1:] is the target size
    # the logits are resized to.
    y_test = test_label.unsqueeze(0)
    seg_logits_test = resize(input=output_test,
                             size=y_test.shape[1:],
                             mode='bilinear',
                             align_corners=True)
    # Index of the largest logit along dim 1 for every position.
    predict_test = torch.argmax(seg_logits_test, dim=1).cpu().numpy()
    Y_test_np = test_label.cpu().numpy()
    # Remap label -1 to 255 before passing the ground truth to the evaluator
    # (presumably -1 marks unlabeled pixels and 255 is the evaluator's ignore
    # value -- TODO confirm against utils.evaluation).
    Y_test_255 = np.where(Y_test_np == -1, 255, Y_test_np)
    test_evaluator.add_batch(np.expand_dims(Y_test_255, axis=0), predict_test)
    OA_test = test_evaluator.Pixel_Accuracy()
    mIOU_test, IOU_test = test_evaluator.Mean_Intersection_over_Union()
    mAcc_test, Acc_test = test_evaluator.Pixel_Accuracy_Class()
    # Kappa must come from the evaluator that actually received the test
    # batch; the untouched `evaluator` instance holds no predictions.
    Kappa_test = test_evaluator.Kappa()
    # logger.info('Test {}|OA:{}|MACC:{}|Kappa:{}|MIOU:{}|IOU:{}|ACC:{}'.format(epoch, OA_test, mAcc_test, Kappa_test, mIOU_test, IOU_test,
                                                                            # Acc_test))

# Stop timing.
end_time = time.time()

# Elapsed wall-clock time in seconds.
elapsed_time = end_time - start_time
print(elapsed_time)

0.5481879711151123
