In [1]:
import argparse
import numpy as np
import torch.nn as nn
import torch.utils.data
from torchsummary import summary
from utils.dataset import load_mat_hsi, sample_gt, HSIDataset
from utils.utils import split_info_print, metrics, show_results
from utils.scheduler import load_scheduler
from models.get_model import get_model
from train import train, test
from utils.utils import Draw
import time
from sklearn.utils import shuffle
from tqdm import tqdm
import torch.backends.cudnn as cudnn
import os
import random

In [2]:
parser = argparse.ArgumentParser(description="run patch-based HSI classification")
parser.add_argument("--model", type=str, default='gscvit') # model name
parser.add_argument("--dataset_name", type=str, default="UP") # dataset name
parser.add_argument("--dataset_dir", type=str, default="../data") # dataset dir
parser.add_argument("--device", type=str, default="0")
parser.add_argument("--patch_size", type=int, default=8) # patch_size
parser.add_argument("--num_run", type=int, default=5)
parser.add_argument("--epoch", type=int, default=200)
parser.add_argument("--bs", type=int, default=128)  # bs = batch size
parser.add_argument("--ratio", type=float, default=0.02) # ratio of training + validation sample

opts,_ = parser.parse_known_args()

device = torch.device("cuda:{}".format(opts.device))

# print parameters
print("experiments will run on GPU device {}".format(opts.device))
print("model = {}".format(opts.model))
print("dataset = {}".format(opts.dataset_name))
print("dataset folder = {}".format(opts.dataset_dir))
print("patch size = {}".format(opts.patch_size))
print("batch size = {}".format(opts.bs))
print("total epoch = {}".format(opts.epoch))
print("{} for training, {} for validation and {} testing".format(opts.ratio / 2, opts.ratio / 2, 1 - opts.ratio))

# load data
image, gt, labels = load_mat_hsi(opts.dataset_name, opts.dataset_dir)

num_classes = len(labels)

num_bands = image.shape[-1]

# random seeds
seeds = [202401, 202402, 202403, 202404, 202405, 202406, 202407, 202408, 202409, 202410]

# empty list to storing results
results = []


experiments will run on GPU device 0
model = gscvit
dataset = UP
dataset folder = ../data
patch size = 8
batch size = 128
total epoch = 200
0.01 for training, 0.01 for validation and 0.98 testing


In [3]:
def setup_seed(seed):
    # PyTorch CPU随机种子
    torch.manual_seed(seed)
    # 所有GPU的随机种子
    torch.cuda.manual_seed_all(seed)
    # Python哈希种子（影响字典等数据结构的行为）
    os.environ['PYTHONHASHSEED'] = str(seed)
    # NumPy随机种子
    np.random.seed(seed)
    # Python内置随机种子
    random.seed(seed)
    # 启用确定性算法（降低性能但保证可重复）
    torch.backends.cudnn.deterministic = True
    # 关闭自动优化（固定卷积算法选择）
    torch.backends.cudnn.benchmark = False

In [4]:
def setup_dataloaders(image, train_gt, val_gt, opts, seed):
    """
    设置数据加载器
    """
    # 创建数据集
    train_set = HSIDataset(image, train_gt, patch_size=opts.patch_size, data_aug=True)
    
    val_set = HSIDataset(image, val_gt, patch_size=opts.patch_size, data_aug=True)

    # 创建数据加载器
    generator = torch.Generator()
    generator.manual_seed(seed)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=opts.bs,
        shuffle=True,
        drop_last=False,
        generator=generator,
        worker_init_fn=lambda worker_id: np.random.seed(seed + worker_id)
    )

    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=opts.bs,
        shuffle=False,
        drop_last=False
    )

    return train_loader, val_loader

In [5]:
time_all = []
for run in range(1):
    setup_seed(seeds[run])
    print("running an experiment with the {} model".format(opts.model))
    print("run {} / {}".format(run + 1, opts.num_run))

    # get train_gt, val_gt and test_gt
    train_gt, val_gt, test_gt = sample_gt(
            gt, 
            train_samples_per_class=30,  # 每类30个训练样本
            val_samples_per_class=10,    # 每类10个验证样本
            seed=seeds[run]
        )

# 设置数据加载器
    train_loader, val_loader = setup_dataloaders(
        image=image,
        train_gt=train_gt,
        val_gt=val_gt,
        opts=opts,
        seed=seeds[run]
    )

    # load model and loss
    model = get_model(opts.model, opts.dataset_name, opts.patch_size)
    model = model.to(device)
    if run == 0:
        split_info_print(train_gt, val_gt, test_gt, labels)
        print("network information:")
        with torch.no_grad():
            input_shape = (1, num_bands, opts.patch_size, opts.patch_size)
            summary(model, input_shape)

    model = model.to(device)
    # print(model)
    optimizer, scheduler = load_scheduler(opts.model, model)

    criterion = nn.CrossEntropyLoss()

    # where to save checkpoint model
    model_dir = "./checkpoints/" + opts.model + '/' + opts.dataset_name + '/' + str(run)

    try:
        train(model, optimizer, criterion, train_loader, val_loader, opts.epoch, model_dir, device, scheduler)
    except KeyboardInterrupt:
        print('"ctrl+c" is pused, the training is over')
    start = time.time()
    # test the model
    probabilities = test(model, model_dir, image, opts.patch_size, num_classes, device)

    prediction = np.argmax(probabilities, axis=-1)
    end = time.time()
    test_time = end - start
    time_all.append(test_time)
    # computing metrics
    run_results = metrics(prediction, test_gt, n_classes=num_classes)  # only for test set
    results.append(run_results)
    show_results(run_results, label_values=labels)

    # draw the classification map
    # Draw(model,image,gt,opts.patch_size,opts.dataset_name,opts.model,num_classes)
if opts.num_run > 1:
        show_results(results, label_values=labels, agregated=True)

running an experiment with the gscvit model
run 1 / 5
目标: 每类30个训练样本, 10个验证样本
------------------------------------------------------------
类别 0: 训练集 30样本, 验证集 10样本, 测试集 6591样本
类别 1: 训练集 30样本, 验证集 10样本, 测试集 18609样本
类别 2: 训练集 30样本, 验证集 10样本, 测试集 2059样本
类别 3: 训练集 30样本, 验证集 10样本, 测试集 3024样本
类别 4: 训练集 30样本, 验证集 10样本, 测试集 1305样本
类别 5: 训练集 30样本, 验证集 10样本, 测试集 4989样本
类别 6: 训练集 30样本, 验证集 10样本, 测试集 1290样本
类别 7: 训练集 30样本, 验证集 10样本, 测试集 3642样本
类别 8: 训练集 30样本, 验证集 10样本, 测试集 907样本

验证采样结果:
------------------------------------------------------------
类别                               训练集        验证集        测试集
------------------------------------------------------------
name1                             30         10       6591
name2                             30         10      18609
name3                             30         10       2059
name4                             30         10       3024
name5                             30         10       1305
name6                             30        

  0%|▌                                                                                                                           | 1/200 [00:00<01:34,  2.11it/s]

train at epoch 1/200, loss=2.428562
epoch = 1: best OA = 0.1111, loss = 2.428562
epoch = 1: current loss = 2.428562 (best loss = 2.428562)
epoch = 2: best OA = 0.1111, loss = 1.725430
epoch = 2: current loss = 1.725430 (best loss = 1.725430)


  2%|███                                                                                                                         | 5/200 [00:00<00:26,  7.23it/s]

epoch = 3: best OA = 0.1889, loss = 1.446032
epoch = 3: current loss = 1.446032 (best loss = 1.446032)
epoch = 4: current loss = 1.213050 (best loss = 1.213050)
train at epoch 5/200, loss=1.129685
epoch = 5: current loss = 1.129685 (best loss = 1.129685)


  4%|████▎                                                                                                                       | 7/200 [00:00<00:21,  8.83it/s]

epoch = 6: current loss = 1.121097 (best loss = 1.121097)
epoch = 7: current loss = 1.009536 (best loss = 1.009536)


  6%|██████▊                                                                                                                    | 11/200 [00:01<00:18, 10.46it/s]

epoch = 9: current loss = 0.820491 (best loss = 0.820491)
train at epoch 10/200, loss=0.826428
epoch = 10: best OA = 0.2111, loss = 0.826428
epoch = 11: best OA = 0.2333, loss = 0.773180
epoch = 11: current loss = 0.773180 (best loss = 0.773180)


  6%|███████▉                                                                                                                   | 13/200 [00:01<00:17, 10.90it/s]

epoch = 12: best OA = 0.3222, loss = 0.770780
epoch = 12: current loss = 0.770780 (best loss = 0.770780)
epoch = 13: current loss = 0.717329 (best loss = 0.717329)
epoch = 14: best OA = 0.3556, loss = 0.853958


  8%|█████████▏                                                                                                                 | 15/200 [00:01<00:16, 11.26it/s]

train at epoch 15/200, loss=0.816652
epoch = 15: best OA = 0.3889, loss = 0.816652
epoch = 16: best OA = 0.4556, loss = 0.657204
epoch = 16: current loss = 0.657204 (best loss = 0.657204)
epoch = 17: best OA = 0.4889, loss = 0.580714
epoch = 17: current loss = 0.580714 (best loss = 0.580714)


 10%|███████████▋                                                                                                               | 19/200 [00:02<00:15, 11.42it/s]

epoch = 18: best OA = 0.5889, loss = 0.734310
epoch = 19: best OA = 0.6333, loss = 0.586981
train at epoch 20/200, loss=0.535312
epoch = 20: best OA = 0.6556, loss = 0.535312
epoch = 20: current loss = 0.535312 (best loss = 0.535312)


 12%|██████████████▏                                                                                                            | 23/200 [00:02<00:15, 11.20it/s]

epoch = 22: best OA = 0.6556, loss = 0.607126
epoch = 23: best OA = 0.7444, loss = 0.443095
epoch = 23: current loss = 0.443095 (best loss = 0.443095)
epoch = 24: best OA = 0.7889, loss = 0.478813


 12%|███████████████▍                                                                                                           | 25/200 [00:02<00:15, 11.42it/s]

train at epoch 25/200, loss=0.473851
epoch = 25: best OA = 0.8000, loss = 0.473851
epoch = 26: best OA = 0.8000, loss = 0.518911
epoch = 27: best OA = 0.8333, loss = 0.508677


 14%|█████████████████▊                                                                                                         | 29/200 [00:02<00:15, 11.08it/s]

epoch = 28: best OA = 0.8667, loss = 0.443251
train at epoch 30/200, loss=0.429924


 16%|███████████████████                                                                                                        | 31/200 [00:03<00:16, 10.35it/s]

epoch = 30: current loss = 0.429924 (best loss = 0.429924)
epoch = 31: current loss = 0.393374 (best loss = 0.393374)


 16%|████████████████████▎                                                                                                      | 33/200 [00:03<00:16, 10.11it/s]

epoch = 32: current loss = 0.366055 (best loss = 0.366055)
epoch = 33: best OA = 0.8778, loss = 0.382992


 18%|█████████████████████▌                                                                                                     | 35/200 [00:03<00:17,  9.61it/s]

epoch = 34: best OA = 0.8778, loss = 0.389648
train at epoch 35/200, loss=0.457788
epoch = 35: best OA = 0.8778, loss = 0.457788


 18%|██████████████████████▊                                                                                                    | 37/200 [00:03<00:17,  9.28it/s]

epoch = 36: best OA = 0.8778, loss = 0.303943
epoch = 36: current loss = 0.303943 (best loss = 0.303943)
epoch = 37: current loss = 0.284085 (best loss = 0.284085)


 20%|███████████████████████▉                                                                                                   | 39/200 [00:04<00:17,  9.11it/s]

epoch = 38: best OA = 0.9000, loss = 0.254412
epoch = 38: current loss = 0.254412 (best loss = 0.254412)
epoch = 39: best OA = 0.9111, loss = 0.341413


 20%|████████████████████████▌                                                                                                  | 40/200 [00:04<00:17,  9.02it/s]

train at epoch 40/200, loss=0.337031
epoch = 40: best OA = 0.9111, loss = 0.337031


 23%|████████████████████████████▎                                                                                              | 46/200 [00:04<00:14, 10.30it/s]

train at epoch 45/200, loss=0.254860


 24%|█████████████████████████████▌                                                                                             | 48/200 [00:05<00:14, 10.15it/s]

epoch = 48: best OA = 0.9111, loss = 0.294253
train at epoch 50/200, loss=0.266905
epoch = 50: best OA = 0.9222, loss = 0.266905


 27%|█████████████████████████████████▏                                                                                         | 54/200 [00:05<00:14, 10.28it/s]

epoch = 52: current loss = 0.250370 (best loss = 0.250370)


 28%|██████████████████████████████████▍                                                                                        | 56/200 [00:05<00:14,  9.85it/s]

train at epoch 55/200, loss=0.273476
epoch = 55: best OA = 0.9333, loss = 0.273476
epoch = 56: best OA = 0.9333, loss = 0.348710


 30%|████████████████████████████████████▎                                                                                      | 59/200 [00:06<00:14, 10.03it/s]

epoch = 57: current loss = 0.204181 (best loss = 0.204181)


 30%|████████████████████████████████████▉                                                                                      | 60/200 [00:06<00:14,  9.97it/s]

train at epoch 60/200, loss=0.208975


 31%|██████████████████████████████████████▏                                                                                    | 62/200 [00:06<00:13, 10.00it/s]

epoch = 62: current loss = 0.202581 (best loss = 0.202581)
epoch = 64: best OA = 0.9333, loss = 0.220690


 33%|████████████████████████████████████████▌                                                                                  | 66/200 [00:06<00:13, 10.14it/s]

train at epoch 65/200, loss=0.307624


 34%|█████████████████████████████████████████▊                                                                                 | 68/200 [00:06<00:13,  9.87it/s]

epoch = 67: current loss = 0.184440 (best loss = 0.184440)
epoch = 68: best OA = 0.9333, loss = 0.334845


 36%|███████████████████████████████████████████▋                                                                               | 71/200 [00:07<00:12, 10.17it/s]

epoch = 69: best OA = 0.9444, loss = 0.298304
train at epoch 70/200, loss=0.194212


 36%|████████████████████████████████████████████▉                                                                              | 73/200 [00:07<00:12, 10.19it/s]

epoch = 72: current loss = 0.153545 (best loss = 0.153545)


 38%|██████████████████████████████████████████████▏                                                                            | 75/200 [00:07<00:12, 10.07it/s]

train at epoch 75/200, loss=0.231907
epoch = 75: best OA = 0.9556, loss = 0.231907


 40%|█████████████████████████████████████████████████▊                                                                         | 81/200 [00:08<00:11, 10.65it/s]

train at epoch 80/200, loss=0.230127


 44%|█████████████████████████████████████████████████████▌                                                                     | 87/200 [00:08<00:10, 10.78it/s]

train at epoch 85/200, loss=0.242596


 46%|███████████████████████████████████████████████████████▉                                                                   | 91/200 [00:09<00:10, 10.66it/s]

epoch = 89: current loss = 0.133978 (best loss = 0.133978)
train at epoch 90/200, loss=0.181533


 46%|█████████████████████████████████████████████████████████▏                                                                 | 93/200 [00:09<00:10, 10.32it/s]

epoch = 92: best OA = 0.9556, loss = 0.165759
epoch = 93: best OA = 0.9556, loss = 0.249829


 48%|██████████████████████████████████████████████████████████▍                                                                | 95/200 [00:09<00:10,  9.96it/s]

epoch = 94: best OA = 0.9667, loss = 0.107224
epoch = 94: current loss = 0.107224 (best loss = 0.107224)
train at epoch 95/200, loss=0.175964


 50%|████████████████████████████████████████████████████████████▉                                                              | 99/200 [00:09<00:09, 10.21it/s]

epoch = 98: current loss = 0.100357 (best loss = 0.100357)
train at epoch 100/200, loss=0.274861


 52%|████████████████████████████████████████████████████████████████                                                          | 105/200 [00:10<00:09, 10.41it/s]

train at epoch 105/200, loss=0.302019


 56%|███████████████████████████████████████████████████████████████████▋                                                      | 111/200 [00:11<00:08, 10.50it/s]

train at epoch 110/200, loss=0.257060


 58%|███████████████████████████████████████████████████████████████████████▎                                                  | 117/200 [00:11<00:06, 11.87it/s]

train at epoch 115/200, loss=0.126424


 60%|████████████████████████████████████████████████████████████████████████▌                                                 | 119/200 [00:11<00:06, 11.73it/s]

epoch = 118: current loss = 0.089712 (best loss = 0.089712)
train at epoch 120/200, loss=0.417147


 62%|████████████████████████████████████████████████████████████████████████████▎                                             | 125/200 [00:12<00:06, 10.81it/s]

epoch = 124: best OA = 0.9667, loss = 0.225920
train at epoch 125/200, loss=0.113897


 66%|███████████████████████████████████████████████████████████████████████████████▉                                          | 131/200 [00:12<00:06, 10.77it/s]

train at epoch 130/200, loss=0.263522


 68%|██████████████████████████████████████████████████████████████████████████████████▎                                       | 135/200 [00:13<00:06, 10.83it/s]

train at epoch 135/200, loss=0.232171


 70%|██████████████████████████████████████████████████████████████████████████████████████                                    | 141/200 [00:13<00:05, 10.63it/s]

train at epoch 140/200, loss=0.176639


 72%|████████████████████████████████████████████████████████████████████████████████████████▍                                 | 145/200 [00:14<00:05, 10.53it/s]

train at epoch 145/200, loss=0.121587
epoch = 145: best OA = 0.9667, loss = 0.121587


 76%|████████████████████████████████████████████████████████████████████████████████████████████                              | 151/200 [00:14<00:04, 10.77it/s]

train at epoch 150/200, loss=0.184233


 78%|██████████████████████████████████████████████████████████████████████████████████████████████▌                           | 155/200 [00:15<00:04, 10.84it/s]

train at epoch 155/200, loss=0.193220
epoch = 156: best OA = 0.9667, loss = 0.096569


 80%|████████████████████████████████████████████████████████████████████████████████████████████████▉                         | 159/200 [00:15<00:03, 10.40it/s]

epoch = 157: best OA = 0.9778, loss = 0.133281
epoch = 159: best OA = 0.9778, loss = 0.155026


 80%|██████████████████████████████████████████████████████████████████████████████████████████████████▏                       | 161/200 [00:15<00:03, 10.32it/s]

train at epoch 160/200, loss=0.186511
epoch = 161: current loss = 0.073748 (best loss = 0.073748)


 82%|███████████████████████████████████████████████████████████████████████████████████████████████████▍                      | 163/200 [00:15<00:03, 10.11it/s]

epoch = 162: best OA = 0.9778, loss = 0.166162
epoch = 163: best OA = 0.9778, loss = 0.077001


 82%|████████████████████████████████████████████████████████████████████████████████████████████████████▋                     | 165/200 [00:16<00:03, 10.34it/s]

train at epoch 165/200, loss=0.082011


 84%|███████████████████████████████████████████████████████████████████████████████████████████████████████                   | 169/200 [00:16<00:02, 10.56it/s]

epoch = 167: best OA = 0.9889, loss = 0.083794


 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████▎                 | 171/200 [00:16<00:02, 10.34it/s]

train at epoch 170/200, loss=0.072496
epoch = 170: current loss = 0.072496 (best loss = 0.072496)


 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▊               | 175/200 [00:16<00:02, 10.19it/s]

epoch = 174: current loss = 0.062790 (best loss = 0.062790)
train at epoch 175/200, loss=0.075213


 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍           | 181/200 [00:17<00:01, 11.14it/s]

train at epoch 180/200, loss=0.367062


 94%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████        | 187/200 [00:17<00:01, 11.90it/s]

train at epoch 185/200, loss=0.120718
epoch = 187: current loss = 0.055310 (best loss = 0.055310)


 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎      | 189/200 [00:18<00:00, 11.50it/s]

epoch = 188: current loss = 0.041280 (best loss = 0.041280)
train at epoch 190/200, loss=0.063892


 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 197/200 [00:18<00:00, 12.64it/s]

train at epoch 195/200, loss=0.162822


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:19<00:00, 10.50it/s]


train at epoch 200/200, loss=0.056381


inference on the HSI: 3256it [00:13, 244.42it/s]                                                                                                                 

Confusion matrix :
[[ 6168     0    14     0     0     0   353    56     0]
 [    0 16478     0   254     0  1877     0     0     0]
 [   17     0  1938     0     0     0     0   103     1]
 [   32    19     0  2944     0     7     0    14     8]
 [    1     0     0     0  1304     0     0     0     0]
 [    0   477     0     0     0  4512     0     0     0]
 [    5     0     0     0     0     0  1284     1     0]
 [   26     0   129    23     0     0     0  3464     0]
 [    0     0     0     0     5     0     0     0   902]]---
Accuracy : 91.93%
---
class acc :
	name1: 93.58
	name2: 88.55
	name3: 94.12
	name4: 97.35
	name5: 99.92
	name6: 90.44
	name7: 99.53
	name8: 95.11
	name9: 99.45
---
AA: 95.34%
Kappa: 89.47

Agregated results :
Confusion matrix :
[[6.1680e+03 0.0000e+00 1.4000e+01 0.0000e+00 0.0000e+00 0.0000e+00
  3.5300e+02 5.6000e+01 0.0000e+00]
 [0.0000e+00 1.6478e+04 0.0000e+00 2.5400e+02 0.0000e+00 1.8770e+03
  0.0000e+00 0.0000e+00 0.0000e+00]
 [1.7000e+01 0.0000e+00 1.93


