In [1]:
import argparse
import numpy as np
import torch.nn as nn
import torch.utils.data
from torchsummary import summary
from utils.dataset import load_mat_hsi, sample_gt, HSIDataset
from utils.utils import split_info_print, metrics, show_results
from utils.scheduler import load_scheduler
from models.get_model import get_model
from utils.utils import Draw
import time
from sklearn.utils import shuffle
from tqdm import tqdm
import torch.backends.cudnn as cudnn

import os
from utils.utils import grouper, sliding_window, count_sliding_window

In [2]:
parser = argparse.ArgumentParser(description="run patch-based HSI classification")
parser.add_argument("--model", type=str, default='gscvit_HL_INT') # model name
parser.add_argument("--dataset_name", type=str, default="HanChuan") # dataset name
parser.add_argument("--dataset_dir", type=str, default="../data") # dataset dir
parser.add_argument("--device", type=str, default="0")
parser.add_argument("--patch_size", type=int, default=8) # patch_size
parser.add_argument("--num_run", type=int, default=5)
parser.add_argument("--epoch", type=int, default=200)
parser.add_argument("--bs", type=int, default=128)  # bs = batch size
parser.add_argument("--ratio", type=float, default=0.02) # ratio of training + validation sample

opts,_ = parser.parse_known_args()

device = torch.device("cuda:{}".format(opts.device))

# print parameters
print("experiments will run on GPU device {}".format(opts.device))
print("model = {}".format(opts.model))
print("dataset = {}".format(opts.dataset_name))
print("dataset folder = {}".format(opts.dataset_dir))
print("patch size = {}".format(opts.patch_size))
print("batch size = {}".format(opts.bs))
print("total epoch = {}".format(opts.epoch))
print("{} for training, {} for validation and {} testing".format(opts.ratio / 2, opts.ratio / 2, 1 - opts.ratio))

# load data
image, gt, labels = load_mat_hsi(opts.dataset_name, opts.dataset_dir)

num_classes = len(labels)

num_bands = image.shape[-1]

# random seeds
seeds = [202401, 202402, 202403, 202404, 202405, 202406, 202407, 202408, 202409, 202410]

# empty list to storing results
results = []


experiments will run on GPU device 0
model = gscvit_HL_INT
dataset = HanChuan
dataset folder = ../data
patch size = 8
batch size = 128
total epoch = 200
0.01 for training, 0.01 for validation and 0.98 testing


In [3]:
def setup_dataloaders(image, train_gt, val_gt, opts, seed):
    """
    设置数据加载器
    """
    # 创建数据集
    train_set = HSIDataset(image, train_gt, patch_size=opts.patch_size, data_aug=True)
    
    val_set = HSIDataset(image, val_gt, patch_size=opts.patch_size, data_aug=True)

    # 创建数据加载器
    generator = torch.Generator()
    generator.manual_seed(seed)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=opts.bs,
        shuffle=True,
        drop_last=False,
        generator=generator,
        worker_init_fn=lambda worker_id: np.random.seed(seed + worker_id)
    )

    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=opts.bs,
        shuffle=False,
        drop_last=False
    )

    return train_loader, val_loader

In [4]:
def setup_seed(seed):
    # PyTorch CPU随机种子
    torch.manual_seed(seed)
    # 所有GPU的随机种子
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    # Python哈希种子（影响字典等数据结构的行为）
    os.environ['PYTHONHASHSEED'] = str(seed)
    # NumPy随机种子
    np.random.seed(seed)
    # Python内置随机种子
    random.seed(seed)
    # 启用确定性算法（降低性能但保证可重复）
    torch.backends.cudnn.deterministic = True
    # 关闭自动优化（固定卷积算法选择）
    torch.backends.cudnn.benchmark = False
    # 设置线程数，确保一致性
    torch.set_num_threads(1)

In [5]:
def train(network, optimizer, criterion, train_loader, val_loader, epoch, saving_path, device, scheduler=None):

    best_acc = -0.1
    losses = []

    for e in tqdm(range(1, epoch+1), desc=""):
        network.train()
        for batch_idx, (images, targets) in enumerate(train_loader):
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = network(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        if e % 10 == 0 or e == 1:
            mean_losses = np.mean(losses)
            train_info = "train at epoch {}/{}, loss={:.6f}"
            train_info = train_info.format(e, epoch,  mean_losses)
            tqdm.write(train_info)
            losses = []
        else:
            losses = []

        val_acc = validation(network, val_loader, device)

        if scheduler is not None:
            scheduler.step()

        is_best = val_acc >= best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(network, is_best, saving_path, epoch=e, acc=best_acc)


def validation(network, val_loader, device):
    num_correct = 0.
    total_num = 0.
    network.eval()
    for batch_idx, (images, targets) in enumerate(val_loader):
        images, targets = images.to(device), targets.to(device)
        outputs = network(images)
        _, outputs = torch.max(outputs, dim=1) 
        for output, target in zip(outputs, targets):
            num_correct = num_correct + (output.item() == target.item())
            total_num = total_num + 1
    overall_acc = num_correct / total_num
    return overall_acc


def test(network, model_dir, image, patch_size, n_classes, device):
    #network.load_state_dict(torch.load(model_dir + "/model_best.pth"))
    network.eval()

    patch_size = patch_size
    batch_size = 64
    window_size = (patch_size, patch_size)
    image_w, image_h = image.shape[:2]
    pad_size = patch_size // 2

    # pad the image
    image = np.pad(image, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)), mode='reflect')

    probs = np.zeros(image.shape[:2] + (n_classes, ))

    iterations = count_sliding_window(image, window_size=window_size) // batch_size
    for batch in tqdm(grouper(batch_size, sliding_window(image, window_size=window_size)),
                      total=iterations,
                      desc="inference on the HSI"):
        with torch.no_grad():
            data = [b[0] for b in batch]
            data = np.copy(data)
            data = data.transpose((0, 3, 1, 2))
            data = torch.from_numpy(data)
            data = data.unsqueeze(1)

            indices = [b[1:] for b in batch]
            data = data.to(device)
            output = network(data)
            if isinstance(output, tuple):
                output = output[0]
            output = output.to('cpu').numpy()

            for (x, y, w, h), out in zip(indices, output):
                probs[x + w // 2, y + h // 2] += out
    return probs[pad_size:image_w + pad_size, pad_size:image_h + pad_size, :]


def save_checkpoint(network, is_best, saving_path, **kwargs):
    if not os.path.isdir(saving_path):
        os.makedirs(saving_path, exist_ok=True)

    if is_best:
        tqdm.write("epoch = {epoch}: best OA = {acc:.4f}".format(**kwargs))
        torch.save(network.state_dict(), os.path.join(saving_path, 'model_best.pth'))
    else:  # save the ckpt for each 10 epoch
        if kwargs['epoch'] % 10 == 0:
            torch.save(network.state_dict(), os.path.join(saving_path, 'model.pth'))



In [None]:
time_all = []
for run in range(opts.num_run):
    np.random.seed(seeds[run])
    torch.manual_seed(seeds[run])
    torch.cuda.manual_seed(seeds[run])
    cudnn.deterministic = True
    cudnn.benchmark = False
    print("running an experiment with the {} model".format(opts.model))
    print("run {} / {}".format(run + 1, opts.num_run))

    # get train_gt, val_gt and test_gt
    train_gt, val_gt, test_gt = sample_gt(
            gt, 
            train_samples_per_class=30,  # 每类30个训练样本
            val_samples_per_class=10,    # 每类10个验证样本
            seed=seeds[run]
        )

    # 设置数据加载器
    train_loader, val_loader = setup_dataloaders(
        image=image,
        train_gt=train_gt,
        val_gt=val_gt,
        opts=opts,
        seed=seeds[run]
    )

    # load model and loss
    model = get_model(opts.model, opts.dataset_name, opts.patch_size)
    model = model.to(device)
    
    model.set_qconfig()
    torch.ao.quantization.prepare_qat(model, inplace=True)
    if run == 0:
        split_info_print(train_gt, val_gt, test_gt, labels)
        print("network information:")
        with torch.no_grad():
            input_shape = (1, num_bands, opts.patch_size, opts.patch_size)
            summary(model, input_shape)

    model = model.to(device)
    # print(model)
    optimizer, scheduler = load_scheduler(opts.model, model)

    criterion = nn.CrossEntropyLoss()

    # where to save checkpoint model
    model_dir = "./checkpoints/" + opts.model + '/' + opts.dataset_name + '/' + str(run)

    try:
        train(model, optimizer, criterion, train_loader, val_loader, opts.epoch, model_dir, device, scheduler)
    except KeyboardInterrupt:
        print('"ctrl+c" is pused, the training is over')
    start = time.time()
    # test the model
    model.eval()
    torch.ao.quantization.convert(model, inplace=True)
    probabilities = test(model, model_dir, image, opts.patch_size, num_classes, device)

    prediction = np.argmax(probabilities, axis=-1)
    end = time.time()
    test_time = end - start
    time_all.append(test_time)
    # computing metrics
    run_results = metrics(prediction, test_gt, n_classes=num_classes)  # only for test set
    results.append(run_results)
    show_results(run_results, label_values=labels)

    # draw the classification map
    # Draw(model,image,gt,opts.patch_size,opts.dataset_name,opts.model,num_classes)
if opts.num_run > 1:
        show_results(results, label_values=labels, agregated=True)

running an experiment with the gscvit_HL_INT model
run 1 / 5
目标: 每类30个训练样本, 10个验证样本
------------------------------------------------------------
类别 0: 训练集 30样本, 验证集 10样本, 测试集 44695样本
类别 1: 训练集 30样本, 验证集 10样本, 测试集 22713样本
类别 2: 训练集 30样本, 验证集 10样本, 测试集 10247样本
类别 3: 训练集 30样本, 验证集 10样本, 测试集 5313样本
类别 4: 训练集 30样本, 验证集 10样本, 测试集 1160样本
类别 5: 训练集 30样本, 验证集 10样本, 测试集 4493样本
类别 6: 训练集 30样本, 验证集 10样本, 测试集 5863样本
类别 7: 训练集 30样本, 验证集 10样本, 测试集 17938样本
类别 8: 训练集 30样本, 验证集 10样本, 测试集 9429样本
类别 9: 训练集 30样本, 验证集 10样本, 测试集 10476样本
类别 10: 训练集 30样本, 验证集 10样本, 测试集 16871样本
类别 11: 训练集 30样本, 验证集 10样本, 测试集 3639样本
类别 12: 训练集 30样本, 验证集 10样本, 测试集 9076样本
类别 13: 训练集 30样本, 验证集 10样本, 测试集 18520样本
类别 14: 训练集 30样本, 验证集 10样本, 测试集 1096样本
类别 15: 训练集 30样本, 验证集 10样本, 测试集 75361样本

验证采样结果:
------------------------------------------------------------
类别                               训练集        验证集        测试集
------------------------------------------------------------
name1                             30         10      44695


  0%|▏                                         | 1/200 [00:00<02:21,  1.41it/s]

train at epoch 1/200, loss=2.537418
epoch = 1: best OA = 0.0625


  1%|▍                                         | 2/200 [00:00<01:30,  2.19it/s]

epoch = 2: best OA = 0.0625


  2%|▋                                         | 3/200 [00:01<01:14,  2.65it/s]

epoch = 3: best OA = 0.0813


  2%|▊                                         | 4/200 [00:01<01:06,  2.96it/s]

epoch = 4: best OA = 0.0938


  4%|█▋                                        | 8/200 [00:02<00:53,  3.57it/s]

epoch = 8: best OA = 0.1000


  4%|█▉                                        | 9/200 [00:02<00:53,  3.55it/s]

epoch = 9: best OA = 0.1062


  5%|██                                       | 10/200 [00:03<00:53,  3.56it/s]

train at epoch 10/200, loss=1.084388
epoch = 10: best OA = 0.1125


  6%|██▎                                      | 11/200 [00:03<00:53,  3.56it/s]

epoch = 11: best OA = 0.1625


  6%|██▍                                      | 12/200 [00:03<00:52,  3.55it/s]

epoch = 12: best OA = 0.1938


  7%|██▊                                      | 14/200 [00:04<00:51,  3.63it/s]

epoch = 14: best OA = 0.2188


  8%|███▎                                     | 16/200 [00:04<00:50,  3.63it/s]

epoch = 16: best OA = 0.2437


  8%|███▍                                     | 17/200 [00:05<00:50,  3.62it/s]

epoch = 17: best OA = 0.3438


  9%|███▋                                     | 18/200 [00:05<00:50,  3.61it/s]

epoch = 18: best OA = 0.4500


 10%|███▉                                     | 19/200 [00:05<00:50,  3.60it/s]

epoch = 19: best OA = 0.5375


 10%|████                                     | 20/200 [00:05<00:50,  3.58it/s]

train at epoch 20/200, loss=0.630276
epoch = 20: best OA = 0.5625


 10%|████▎                                    | 21/200 [00:06<00:50,  3.58it/s]

epoch = 21: best OA = 0.6813


 11%|████▌                                    | 22/200 [00:06<00:49,  3.58it/s]

epoch = 22: best OA = 0.7312


 12%|████▋                                    | 23/200 [00:06<00:49,  3.58it/s]

epoch = 23: best OA = 0.7625


 12%|█████▏                                   | 25/200 [00:07<00:48,  3.64it/s]

epoch = 25: best OA = 0.8000


 14%|█████▉                                   | 29/200 [00:08<00:45,  3.75it/s]

epoch = 29: best OA = 0.8125


 15%|██████▏                                  | 30/200 [00:08<00:45,  3.70it/s]

train at epoch 30/200, loss=0.504542


 16%|██████▌                                  | 32/200 [00:09<00:45,  3.71it/s]

epoch = 32: best OA = 0.8125


 18%|███████▏                                 | 35/200 [00:09<00:44,  3.73it/s]

epoch = 35: best OA = 0.8187


 19%|███████▊                                 | 38/200 [00:10<00:43,  3.76it/s]

epoch = 38: best OA = 0.8187


 20%|████████▏                                | 40/200 [00:11<00:42,  3.74it/s]

train at epoch 40/200, loss=0.357439


 25%|██████████▎                              | 50/200 [00:13<00:40,  3.75it/s]

train at epoch 50/200, loss=0.263092


 26%|██████████▋                              | 52/200 [00:14<00:39,  3.71it/s]

epoch = 52: best OA = 0.8250


 28%|███████████▎                             | 55/200 [00:15<00:38,  3.75it/s]

epoch = 55: best OA = 0.8313


 28%|███████████▍                             | 56/200 [00:15<00:38,  3.71it/s]

epoch = 56: best OA = 0.8562


 30%|████████████▎                            | 60/200 [00:16<00:37,  3.76it/s]

train at epoch 60/200, loss=0.196005


 35%|██████████████▎                          | 70/200 [00:19<00:35,  3.69it/s]

train at epoch 70/200, loss=0.164008


 40%|████████████████▍                        | 80/200 [00:21<00:32,  3.74it/s]

train at epoch 80/200, loss=0.162081


 45%|██████████████████▍                      | 90/200 [00:24<00:29,  3.74it/s]

train at epoch 90/200, loss=0.149350


 50%|████████████████████                    | 100/200 [00:27<00:26,  3.83it/s]

train at epoch 100/200, loss=0.092868


 52%|████████████████████▌                   | 103/200 [00:27<00:25,  3.79it/s]

epoch = 103: best OA = 0.8688


 55%|██████████████████████                  | 110/200 [00:29<00:24,  3.74it/s]

train at epoch 110/200, loss=0.073864


 60%|████████████████████████                | 120/200 [00:32<00:20,  3.83it/s]

train at epoch 120/200, loss=0.083121


 65%|██████████████████████████              | 130/200 [00:34<00:18,  3.77it/s]

train at epoch 130/200, loss=0.114193


 70%|████████████████████████████            | 140/200 [00:37<00:15,  3.90it/s]

train at epoch 140/200, loss=0.096641


 75%|██████████████████████████████          | 150/200 [00:39<00:13,  3.74it/s]

train at epoch 150/200, loss=0.086045


 80%|████████████████████████████████        | 160/200 [00:42<00:10,  3.82it/s]

train at epoch 160/200, loss=0.091842


 85%|██████████████████████████████████      | 170/200 [00:45<00:07,  3.81it/s]

train at epoch 170/200, loss=0.075529


 88%|███████████████████████████████████▏    | 176/200 [00:46<00:06,  3.81it/s]

epoch = 176: best OA = 0.8688


 90%|████████████████████████████████████    | 180/200 [00:47<00:05,  3.84it/s]

train at epoch 180/200, loss=0.153613


 95%|██████████████████████████████████████  | 190/200 [00:50<00:02,  3.83it/s]

train at epoch 190/200, loss=0.059368


100%|████████████████████████████████████████| 200/200 [00:52<00:00,  3.78it/s]


train at epoch 200/200, loss=0.057114


inference on the HSI: 5786it [00:40, 142.83it/s]                               


Confusion matrix :
[[34812    20     0     0   468   135     8     0    16     4    18   301
   8380   510    23     0]
 [   45 16557   329    45   768  1848   544   508   748     0     0     9
   1309     0     3     0]
 [    0    75  8637    30     0  1058     1   303     6     0     0     0
    137     0     0     0]
 [    0     0    27  5176     0     2     0     0   108     0     0     0
      0     0     0     0]
 [    0     0     0     0  1160     0     0     0     0     0     0     0
      0     0     0     0]
 [  312   110    65     4    65  3466     0   103   132     1     0    12
    215     0     8     0]
 [    0     0     0     0     0     0  4355   283   661     0     0     5
    559     0     0     0]
 [  169   353  1273   677    25  1090   213 13126   748    13     0     1
    237    12     1     0]
 [  209   397     0    74   851    60    10   196  7330     0     0     0
    302     0     0     0]
 [    0     3     1     0     0     0     0     1    66 10165    19     

  0%|▏                                         | 1/200 [00:00<00:56,  3.53it/s]

train at epoch 1/200, loss=2.484281
epoch = 1: best OA = 0.0625


  1%|▍                                         | 2/200 [00:00<00:55,  3.56it/s]

epoch = 2: best OA = 0.0625


  2%|▋                                         | 3/200 [00:00<00:55,  3.58it/s]

epoch = 3: best OA = 0.0625


  2%|▊                                         | 4/200 [00:01<00:55,  3.53it/s]

epoch = 4: best OA = 0.0625


  2%|█                                         | 5/200 [00:01<00:54,  3.55it/s]

epoch = 5: best OA = 0.0875


  3%|█▎                                        | 6/200 [00:01<00:54,  3.57it/s]

epoch = 6: best OA = 0.0875


  4%|█▋                                        | 8/200 [00:02<00:52,  3.66it/s]

epoch = 8: best OA = 0.0875


  4%|█▉                                        | 9/200 [00:02<00:52,  3.65it/s]

epoch = 9: best OA = 0.0938


  5%|██                                       | 10/200 [00:02<00:52,  3.63it/s]

train at epoch 10/200, loss=1.067634
epoch = 10: best OA = 0.1375


  6%|██▎                                      | 11/200 [00:03<00:52,  3.63it/s]

epoch = 11: best OA = 0.1437


  6%|██▍                                      | 12/200 [00:03<00:51,  3.62it/s]

epoch = 12: best OA = 0.1625


  6%|██▋                                      | 13/200 [00:03<00:51,  3.62it/s]

epoch = 13: best OA = 0.1938


  7%|██▊                                      | 14/200 [00:03<00:50,  3.71it/s]

epoch = 14: best OA = 0.2375


  8%|███                                      | 15/200 [00:04<00:50,  3.69it/s]

epoch = 15: best OA = 0.3563


  8%|███▎                                     | 16/200 [00:04<00:50,  3.67it/s]

epoch = 16: best OA = 0.3625


  8%|███▍                                     | 17/200 [00:04<00:50,  3.66it/s]

epoch = 17: best OA = 0.4625


  9%|███▋                                     | 18/200 [00:04<00:50,  3.62it/s]

epoch = 18: best OA = 0.5500


 10%|███▉                                     | 19/200 [00:05<00:49,  3.64it/s]

epoch = 19: best OA = 0.6438


 10%|████                                     | 20/200 [00:05<00:48,  3.68it/s]

train at epoch 20/200, loss=0.696743


 10%|████▎                                    | 21/200 [00:05<00:47,  3.80it/s]

epoch = 21: best OA = 0.7688


 11%|████▌                                    | 22/200 [00:05<00:45,  3.91it/s]

epoch = 22: best OA = 0.7688


 12%|████▋                                    | 23/200 [00:06<00:46,  3.80it/s]

epoch = 23: best OA = 0.8000


 12%|█████▏                                   | 25/200 [00:06<00:47,  3.70it/s]

epoch = 25: best OA = 0.8000


 14%|█████▋                                   | 28/200 [00:07<00:47,  3.64it/s]

epoch = 28: best OA = 0.8438


 15%|██████▏                                  | 30/200 [00:08<00:46,  3.65it/s]

train at epoch 30/200, loss=0.418000


 20%|████████▏                                | 40/200 [00:10<00:42,  3.72it/s]

train at epoch 40/200, loss=0.390314


 24%|█████████▋                               | 47/200 [00:12<00:41,  3.65it/s]

epoch = 47: best OA = 0.8562


 24%|██████████                               | 49/200 [00:13<00:42,  3.55it/s]

epoch = 49: best OA = 0.8562


 25%|██████████▎                              | 50/200 [00:13<00:42,  3.57it/s]

train at epoch 50/200, loss=0.270543


 27%|███████████                              | 54/200 [00:14<00:39,  3.71it/s]

epoch = 54: best OA = 0.8625


 30%|████████████▎                            | 60/200 [00:16<00:36,  3.84it/s]

train at epoch 60/200, loss=0.211697


 30%|████████████▌                            | 61/200 [00:16<00:37,  3.76it/s]

epoch = 61: best OA = 0.8938


 35%|██████████████▎                          | 70/200 [00:18<00:34,  3.80it/s]

train at epoch 70/200, loss=0.154588


 36%|██████████████▉                          | 73/200 [00:19<00:33,  3.78it/s]

epoch = 73: best OA = 0.9062


 40%|████████████████▍                        | 80/200 [00:21<00:32,  3.74it/s]

train at epoch 80/200, loss=0.099964


 45%|██████████████████▍                      | 90/200 [00:23<00:27,  3.98it/s]

train at epoch 90/200, loss=0.067807


 48%|███████████████████▋                     | 96/200 [00:25<00:27,  3.81it/s]

epoch = 96: best OA = 0.9125


 50%|████████████████████                    | 100/200 [00:26<00:26,  3.77it/s]

train at epoch 100/200, loss=0.078409


 55%|██████████████████████                  | 110/200 [00:29<00:24,  3.69it/s]

train at epoch 110/200, loss=0.079835


 60%|████████████████████████                | 120/200 [00:31<00:21,  3.70it/s]

train at epoch 120/200, loss=0.103338


 65%|██████████████████████████              | 130/200 [00:34<00:18,  3.78it/s]

train at epoch 130/200, loss=0.057645


 70%|████████████████████████████            | 140/200 [00:37<00:15,  3.79it/s]

train at epoch 140/200, loss=0.085303


 75%|██████████████████████████████          | 150/200 [00:39<00:13,  3.70it/s]

train at epoch 150/200, loss=0.152111


 80%|████████████████████████████████        | 160/200 [00:42<00:10,  3.78it/s]

train at epoch 160/200, loss=0.066089


 85%|██████████████████████████████████      | 170/200 [00:45<00:07,  3.76it/s]

train at epoch 170/200, loss=0.033085


 90%|████████████████████████████████████    | 180/200 [00:47<00:05,  3.74it/s]

train at epoch 180/200, loss=0.026896


 95%|██████████████████████████████████████  | 190/200 [00:50<00:02,  3.75it/s]

train at epoch 190/200, loss=0.057407


100%|████████████████████████████████████████| 200/200 [00:52<00:00,  3.78it/s]


train at epoch 200/200, loss=0.031745


inference on the HSI:  56%|████████▍      | 3231/5785 [00:21<00:15, 160.34it/s]