In [1]:
import argparse
import numpy as np
import torch.nn as nn
import torch.utils.data
from torchsummary import summary
from utils.dataset import load_mat_hsi, sample_gt, HSIDataset
from utils.utils import split_info_print, metrics, show_results
from utils.scheduler import load_scheduler
from models.get_model import get_model
from train import train, test
from utils.utils import Draw
import time
from sklearn.utils import shuffle
from tqdm import tqdm
import torch.backends.cudnn as cudnn
import os
import random
from scipy.io import loadmat

In [2]:
parser = argparse.ArgumentParser(description="run patch-based HSI classification")
parser.add_argument("--model", type=str, default='gscvit_HL') # model name
parser.add_argument("--dataset_name", type=str, default="HongHu") # dataset name
parser.add_argument("--dataset_dir", type=str, default="../data") # dataset dir
parser.add_argument("--device", type=str, default="0")
parser.add_argument("--patch_size", type=int, default=8) # patch_size
parser.add_argument("--num_run", type=int, default=5)
parser.add_argument("--epoch", type=int, default=200)
parser.add_argument("--bs", type=int, default=128)  # bs = batch size
parser.add_argument("--ratio", type=float, default=0.02) # ratio of training + validation sample

opts,_ = parser.parse_known_args()

device = torch.device("cuda:{}".format(opts.device))

# print parameters
print("experiments will run on GPU device {}".format(opts.device))
print("model = {}".format(opts.model))
print("dataset = {}".format(opts.dataset_name))
print("dataset folder = {}".format(opts.dataset_dir))
print("patch size = {}".format(opts.patch_size))
print("batch size = {}".format(opts.bs))
print("total epoch = {}".format(opts.epoch))
print("{} for training, {} for validation and {} testing".format(opts.ratio / 2, opts.ratio / 2, 1 - opts.ratio))

# load data
image, gt, labels = load_mat_hsi(opts.dataset_name, opts.dataset_dir)

num_classes = len(labels)

num_bands = image.shape[-1]

# random seeds
seeds = [202401, 202402, 202403, 202404, 202405, 202406, 202407, 202408, 202409, 202410]

# empty list to storing results
results = []


experiments will run on GPU device 0
model = gscvit_HL
dataset = HongHu
dataset folder = ../data
patch size = 8
batch size = 128
total epoch = 200
0.01 for training, 0.01 for validation and 0.98 testing


In [3]:
def setup_dataloaders(image, train_gt, val_gt, opts, seed):
    """
    设置数据加载器
    """
    # 创建数据集
    train_set = HSIDataset(image, train_gt, patch_size=opts.patch_size, data_aug=True)
    
    val_set = HSIDataset(image, val_gt, patch_size=opts.patch_size, data_aug=True)

    # 创建数据加载器
    generator = torch.Generator()
    generator.manual_seed(seed)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=opts.bs,
        shuffle=True,
        drop_last=False,
        generator=generator,
        worker_init_fn=lambda worker_id: np.random.seed(seed + worker_id)
    )

    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=opts.bs,
        shuffle=False,
        drop_last=False
    )

    return train_loader, val_loader

In [4]:
def setup_seed(seed):
    # PyTorch CPU随机种子
    torch.manual_seed(seed)
    # 所有GPU的随机种子
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    # Python哈希种子（影响字典等数据结构的行为）
    os.environ['PYTHONHASHSEED'] = str(seed)
    # NumPy随机种子
    np.random.seed(seed)
    # Python内置随机种子
    random.seed(seed)
    # 启用确定性算法（降低性能但保证可重复）
    torch.backends.cudnn.deterministic = True
    # 关闭自动优化（固定卷积算法选择）
    torch.backends.cudnn.benchmark = False
    # 设置线程数，确保一致性
    torch.set_num_threads(1)

In [None]:
time_all = []
for run in range(5):
    setup_seed(seeds[run])
    print("running an experiment with the {} model".format(opts.model))
    print("run {} / {}".format(run + 1, opts.num_run))

    # get train_gt, val_gt and test_gt
    train_gt, val_gt, test_gt = sample_gt(
            gt, 
            train_samples_per_class=30,  # 每类30个训练样本
            val_samples_per_class=10,    # 每类10个验证样本
            seed=seeds[run]
        )

# 设置数据加载器
    train_loader, val_loader = setup_dataloaders(
        image=image,
        train_gt=train_gt,
        val_gt=val_gt,
        opts=opts,
        seed=seeds[run]
    )

    # load model and loss
    model = get_model(opts.model, opts.dataset_name, opts.patch_size)
    model = model.to(device)
    if run == 0:
        split_info_print(train_gt, val_gt, test_gt, labels)
        print("network information:")
        with torch.no_grad():
            input_shape = (1, num_bands, opts.patch_size, opts.patch_size)
            summary(model, input_shape)

    model = model.to(device)
    # print(model)
    optimizer, scheduler = load_scheduler(opts.model, model)

    criterion = nn.CrossEntropyLoss()

    # where to save checkpoint model
    model_dir = "./checkpoints/" + opts.model + '/' + opts.dataset_name + '/' + str(run)

    try:
        train(model, optimizer, criterion, train_loader, val_loader, opts.epoch, model_dir, device, scheduler)
    except KeyboardInterrupt:
        print('"ctrl+c" is pused, the training is over')
    start = time.time()
    # test the model
    probabilities = test(model, model_dir, image, opts.patch_size, num_classes, device)

    prediction = np.argmax(probabilities, axis=-1)
    end = time.time()
    test_time = end - start
    time_all.append(test_time)
    # computing metrics
    run_results = metrics(prediction, test_gt, n_classes=num_classes)  # only for test set
    results.append(run_results)
    show_results(run_results, label_values=labels)

    # draw the classification map
    # Draw(model,image,gt,opts.patch_size,opts.dataset_name,opts.model,num_classes)
if opts.num_run > 1:
        show_results(results, label_values=labels, agregated=True)

running an experiment with the gscvit_HL model
run 1 / 5
目标: 每类30个训练样本, 10个验证样本
------------------------------------------------------------
类别 0: 训练集 30样本, 验证集 10样本, 测试集 14001样本
类别 1: 训练集 30样本, 验证集 10样本, 测试集 3472样本
类别 2: 训练集 30样本, 验证集 10样本, 测试集 21781样本
类别 3: 训练集 30样本, 验证集 10样本, 测试集 163245样本
类别 4: 训练集 30样本, 验证集 10样本, 测试集 6178样本
类别 5: 训练集 30样本, 验证集 10样本, 测试集 44517样本
类别 6: 训练集 30样本, 验证集 10样本, 测试集 24063样本
类别 7: 训练集 30样本, 验证集 10样本, 测试集 4014样本
类别 8: 训练集 30样本, 验证集 10样本, 测试集 10779样本
类别 9: 训练集 30样本, 验证集 10样本, 测试集 12354样本
类别 10: 训练集 30样本, 验证集 10样本, 测试集 10975样本
类别 11: 训练集 30样本, 验证集 10样本, 测试集 8914样本
类别 12: 训练集 30样本, 验证集 10样本, 测试集 22467样本
类别 13: 训练集 30样本, 验证集 10样本, 测试集 7316样本
类别 14: 训练集 30样本, 验证集 10样本, 测试集 962样本
类别 15: 训练集 30样本, 验证集 10样本, 测试集 7222样本
类别 16: 训练集 30样本, 验证集 10样本, 测试集 2970样本
类别 17: 训练集 30样本, 验证集 10样本, 测试集 3177样本
类别 18: 训练集 30样本, 验证集 10样本, 测试集 8672样本
类别 19: 训练集 30样本, 验证集 10样本, 测试集 3446样本
类别 20: 训练集 30样本, 验证集 10样本, 测试集 1288样本
类别 21: 训练集 30样本, 验证集 10样本, 测试集 4000样本

验证采样结果:
---------------

  0%|▌                                                                                                                        | 1/200 [00:00<02:03,  1.61it/s]

train at epoch 1/200, loss=2.660554
epoch = 1: best OA = 0.0455, loss = 2.660554
epoch = 1: current loss = 2.660554 (best loss = 2.660554)


  1%|█▏                                                                                                                       | 2/200 [00:00<01:25,  2.33it/s]

epoch = 2: best OA = 0.0500, loss = 2.176465
epoch = 2: current loss = 2.176465 (best loss = 2.176465)


  2%|█▊                                                                                                                       | 3/200 [00:01<01:11,  2.76it/s]

epoch = 3: current loss = 1.955819 (best loss = 1.955819)


  2%|██▍                                                                                                                      | 4/200 [00:01<01:05,  3.00it/s]

epoch = 4: current loss = 1.747265 (best loss = 1.747265)


  2%|███                                                                                                                      | 5/200 [00:01<01:02,  3.14it/s]

train at epoch 5/200, loss=1.719204
epoch = 5: current loss = 1.719204 (best loss = 1.719204)


  3%|███▋                                                                                                                     | 6/200 [00:02<00:59,  3.26it/s]

epoch = 6: current loss = 1.572338 (best loss = 1.572338)


  4%|████▏                                                                                                                    | 7/200 [00:02<00:58,  3.29it/s]

epoch = 7: best OA = 0.0591, loss = 1.465154
epoch = 7: current loss = 1.465154 (best loss = 1.465154)


  4%|████▊                                                                                                                    | 8/200 [00:02<00:58,  3.30it/s]

epoch = 8: best OA = 0.1545, loss = 1.376328
epoch = 8: current loss = 1.376328 (best loss = 1.376328)


  4%|█████▍                                                                                                                   | 9/200 [00:02<00:57,  3.31it/s]

epoch = 9: best OA = 0.3364, loss = 1.341211
epoch = 9: current loss = 1.341211 (best loss = 1.341211)


  5%|██████                                                                                                                  | 10/200 [00:03<00:57,  3.31it/s]

train at epoch 10/200, loss=1.218582
epoch = 10: best OA = 0.4773, loss = 1.218582
epoch = 10: current loss = 1.218582 (best loss = 1.218582)


  6%|██████▌                                                                                                                 | 11/200 [00:03<00:56,  3.33it/s]

epoch = 11: best OA = 0.5909, loss = 1.174028
epoch = 11: current loss = 1.174028 (best loss = 1.174028)


  6%|███████▏                                                                                                                | 12/200 [00:03<00:55,  3.37it/s]

epoch = 12: best OA = 0.6818, loss = 1.137229
epoch = 12: current loss = 1.137229 (best loss = 1.137229)


  6%|███████▊                                                                                                                | 13/200 [00:04<00:55,  3.37it/s]

epoch = 13: best OA = 0.7273, loss = 1.123157
epoch = 13: current loss = 1.123157 (best loss = 1.123157)


  7%|████████▍                                                                                                               | 14/200 [00:04<00:55,  3.33it/s]

epoch = 14: current loss = 1.100103 (best loss = 1.100103)


  8%|█████████                                                                                                               | 15/200 [00:04<00:56,  3.28it/s]

train at epoch 15/200, loss=0.973648
epoch = 15: best OA = 0.7455, loss = 0.973648
epoch = 15: current loss = 0.973648 (best loss = 0.973648)


  8%|█████████▌                                                                                                              | 16/200 [00:05<00:56,  3.25it/s]

epoch = 16: current loss = 0.956830 (best loss = 0.956830)


  8%|██████████▏                                                                                                             | 17/200 [00:05<00:55,  3.28it/s]

epoch = 17: best OA = 0.7636, loss = 0.949544
epoch = 17: current loss = 0.949544 (best loss = 0.949544)


  9%|██████████▊                                                                                                             | 18/200 [00:05<00:55,  3.25it/s]

epoch = 18: best OA = 0.7818, loss = 0.874506
epoch = 18: current loss = 0.874506 (best loss = 0.874506)


 10%|███████████▍                                                                                                            | 19/200 [00:05<00:54,  3.30it/s]

epoch = 19: current loss = 0.810659 (best loss = 0.810659)


 10%|████████████                                                                                                            | 20/200 [00:06<00:54,  3.31it/s]

train at epoch 20/200, loss=0.810598
epoch = 20: current loss = 0.810598 (best loss = 0.810598)


 10%|████████████▌                                                                                                           | 21/200 [00:06<00:54,  3.26it/s]

epoch = 21: best OA = 0.7909, loss = 0.834502


 11%|█████████████▏                                                                                                          | 22/200 [00:06<00:55,  3.23it/s]

epoch = 22: best OA = 0.7909, loss = 0.767533
epoch = 22: current loss = 0.767533 (best loss = 0.767533)


 12%|██████████████▍                                                                                                         | 24/200 [00:07<00:53,  3.28it/s]

epoch = 24: best OA = 0.7909, loss = 0.682579
epoch = 24: current loss = 0.682579 (best loss = 0.682579)


 12%|███████████████                                                                                                         | 25/200 [00:07<00:52,  3.33it/s]

train at epoch 25/200, loss=0.694251


 13%|███████████████▌                                                                                                        | 26/200 [00:08<00:52,  3.34it/s]

epoch = 26: current loss = 0.600039 (best loss = 0.600039)


 14%|████████████████▏                                                                                                       | 27/200 [00:08<00:52,  3.32it/s]

epoch = 27: best OA = 0.8091, loss = 0.616573


 14%|████████████████▊                                                                                                       | 28/200 [00:08<00:51,  3.34it/s]

epoch = 28: best OA = 0.8091, loss = 0.610493


 15%|██████████████████                                                                                                      | 30/200 [00:09<00:50,  3.38it/s]

train at epoch 30/200, loss=0.597091
epoch = 30: current loss = 0.597091 (best loss = 0.597091)


 16%|██████████████████▌                                                                                                     | 31/200 [00:09<00:50,  3.37it/s]

epoch = 31: current loss = 0.559278 (best loss = 0.559278)


 16%|███████████████████▊                                                                                                    | 33/200 [00:10<00:49,  3.38it/s]

epoch = 33: best OA = 0.8091, loss = 0.471116
epoch = 33: current loss = 0.471116 (best loss = 0.471116)


 18%|█████████████████████                                                                                                   | 35/200 [00:10<00:49,  3.30it/s]

train at epoch 35/200, loss=0.467389
epoch = 35: best OA = 0.8273, loss = 0.467389
epoch = 35: current loss = 0.467389 (best loss = 0.467389)


 20%|███████████████████████▍                                                                                                | 39/200 [00:11<00:48,  3.31it/s]

epoch = 39: current loss = 0.451503 (best loss = 0.451503)


 20%|████████████████████████                                                                                                | 40/200 [00:12<00:49,  3.25it/s]

train at epoch 40/200, loss=0.461906
epoch = 40: best OA = 0.8409, loss = 0.461906


 20%|████████████████████████▌                                                                                               | 41/200 [00:12<00:48,  3.26it/s]

epoch = 41: best OA = 0.8591, loss = 0.429047
epoch = 41: current loss = 0.429047 (best loss = 0.429047)


 22%|█████████████████████████▊                                                                                              | 43/200 [00:13<00:47,  3.30it/s]

epoch = 43: current loss = 0.417893 (best loss = 0.417893)


 22%|██████████████████████████▍                                                                                             | 44/200 [00:13<00:47,  3.29it/s]

epoch = 44: current loss = 0.398015 (best loss = 0.398015)


 22%|███████████████████████████                                                                                             | 45/200 [00:13<00:47,  3.28it/s]

train at epoch 45/200, loss=0.371779
epoch = 45: current loss = 0.371779 (best loss = 0.371779)


 23%|███████████████████████████▌                                                                                            | 46/200 [00:14<00:46,  3.28it/s]

epoch = 46: current loss = 0.352306 (best loss = 0.352306)


 24%|█████████████████████████████▍                                                                                          | 49/200 [00:14<00:44,  3.39it/s]

epoch = 49: current loss = 0.337498 (best loss = 0.337498)


 25%|██████████████████████████████                                                                                          | 50/200 [00:15<00:44,  3.40it/s]

train at epoch 50/200, loss=0.402287


 26%|██████████████████████████████▌                                                                                         | 51/200 [00:15<00:44,  3.36it/s]

epoch = 51: current loss = 0.319382 (best loss = 0.319382)


 28%|█████████████████████████████████                                                                                       | 55/200 [00:16<00:42,  3.42it/s]

train at epoch 55/200, loss=0.294139
epoch = 55: current loss = 0.294139 (best loss = 0.294139)


 30%|███████████████████████████████████▍                                                                                    | 59/200 [00:17<00:40,  3.45it/s]

epoch = 59: current loss = 0.290089 (best loss = 0.290089)


 30%|████████████████████████████████████                                                                                    | 60/200 [00:18<00:39,  3.50it/s]

train at epoch 60/200, loss=0.325509


 32%|█████████████████████████████████████▊                                                                                  | 63/200 [00:19<00:39,  3.43it/s]

epoch = 63: current loss = 0.250182 (best loss = 0.250182)


 32%|███████████████████████████████████████                                                                                 | 65/200 [00:19<00:39,  3.38it/s]

train at epoch 65/200, loss=0.240137
epoch = 65: current loss = 0.240137 (best loss = 0.240137)


 33%|███████████████████████████████████████▌                                                                                | 66/200 [00:19<00:39,  3.38it/s]

epoch = 66: current loss = 0.217662 (best loss = 0.217662)


 34%|████████████████████████████████████████▏                                                                               | 67/200 [00:20<00:39,  3.37it/s]

epoch = 67: current loss = 0.195527 (best loss = 0.195527)


 35%|██████████████████████████████████████████                                                                              | 70/200 [00:21<00:38,  3.37it/s]

train at epoch 70/200, loss=0.319725


 38%|█████████████████████████████████████████████                                                                           | 75/200 [00:22<00:36,  3.41it/s]

train at epoch 75/200, loss=0.288447


 39%|██████████████████████████████████████████████▊                                                                         | 78/200 [00:23<00:35,  3.44it/s]