In [1]:
from model import ss_fusion_seg
import torch
from torch  import nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report,cohen_kappa_score
from model import split_data,utils,create_graph
from sklearn import metrics, preprocessing
from mmengine.optim import build_optim_wrapper
from mmcv_custom import custom_layer_decay_optimizer_constructor,layer_decay_optimizer_constructor_vit
import random
import os
import torch.utils.data as Data
import copy
import scipy.io as sio
import spectral as spy
from collections import Counter
from sklearn.decomposition import PCA

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

In [3]:
class DataReader():
    def __init__(self):
        self.data_cube = None
        self.g_truth = None

    @property
    def cube(self):
        """
        origin data
        """
        return self.data_cube

    @property
    def truth(self):
        return self.g_truth

    @property
    def normal_cube(self):
        """
        normalization data: range(0, 1)
        """
        return (self.data_cube - np.min(self.data_cube)) / (np.max(self.data_cube) - np.min(self.data_cube))
class IndianRaw(DataReader):
    def __init__(self):
        super(IndianRaw, self).__init__()
        raw_data_package = sio.loadmat(r"data/Indian_pines_corrected.mat")
        self.data_cube = raw_data_package["data"].astype(np.float32)
        truth = sio.loadmat(r"data/Indian_pines_gt.mat")
        self.g_truth = truth["groundT"].astype(np.float32)

In [4]:
def setup_seed(seed):
    random.seed(seed)  # Python的随机性
    os.environ['PYTHONHASHSEED'] = str(seed)  # 设置Python哈希种子，为了禁止hash随机化，使得实验可复现
    np.random.seed(seed)  # numpy的随机性
    torch.manual_seed(seed)  # torch的CPU随机性，为CPU设置随机种子
    torch.cuda.manual_seed(seed)  # torch的GPU随机性，为当前GPU设置随机种子
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.   torch的GPU随机性，为所有GPU设置随机种子
    torch.backends.cudnn.deterministic = True # 选择确定性算法
    torch.backends.cudnn.benchmark = False # if benchmark=True, deterministic will be False
def evaluate_performance(network_output, train_samples_gt, train_samples_gt_onehot, zeros):
    with torch.no_grad():
        available_label_idx = (train_samples_gt!=0).float()        # 有效标签的坐标,用于排除背景
        available_label_count = available_label_idx.sum()          # 有效标签的个数
        correct_prediction = torch.where(network_output==torch.argmax(train_samples_gt_onehot, 1), available_label_idx, zeros).sum()
        OA= correct_prediction.cpu() / available_label_count
        return OA
def get_patch(img_size,data,data_gt,overlap_size):
    input_size=(img_size, img_size)
    height_orgin, width_orgin, bands = data.shape
    image_size=(height_orgin,width_orgin)

    LyEnd,LxEnd = np.subtract(image_size, input_size)
    Lx = np.linspace(0, LxEnd, int(np.ceil(LxEnd/float(input_size[1]-overlap_size)))+1, endpoint=True).astype('int')
    Ly = np.linspace(0, LyEnd, int(np.ceil(LyEnd/float(input_size[0]-overlap_size)))+1, endpoint=True).astype('int')
    N=len(Ly)*len(Lx)
    X_data=np.zeros([N,input_size[0],input_size[1],data.shape[-1]])#N,H,W,C
    y_data=np.zeros([N,input_size[0],input_size[1]])
    i=0
    for j in range(len(Ly)):
        for k in range(len(Lx)):
            rStart,cStart = (Ly[j],Lx[k])
            rEnd,cEnd = (rStart+input_size[0],cStart+input_size[1])
            X_data[i] = data[rStart:rEnd,cStart:cEnd,:]
            y_data[i] = data_gt[rStart:rEnd,cStart:cEnd]
            i+=1
    return X_data,y_data
def Get_train_and_test_data(img_size, img,img_gt):
    H0, W0, C = img.shape
    if H0<img_size:
        gap = img_size-H0
        mirror_img = img[(H0-gap):H0,:,:]
        mirror_img_gt = img_gt[(H0-gap):H0,:]
        img = np.concatenate([img,mirror_img],axis=0)
        img_gt = np.concatenate([img_gt,mirror_img_gt],axis=0)
    if W0<img_size:
        gap = img_size-W0
        mirror_img = img[:,(W0 - gap):W0,:]
        mirror_img_gt = img_gt[(W0-gap):W0,:]
        img = np.concatenate([img,mirror_img],axis=1)
        img_gt = np.concatenate([img_gt,mirror_img_gt],axis=1)
    H, W, C = img.shape

    num_H = H // img_size
    num_W = W // img_size
    sub_H = H % img_size
    sub_W = W % img_size
    if sub_H != 0:
        gap = (num_H+1)*img_size - H
        mirror_img = img[(H - gap):H, :, :]
        mirror_img_gt = img_gt[(H - gap):H, :]
        img = np.concatenate([img, mirror_img], axis=0)
        img_gt = np.concatenate([img_gt,mirror_img_gt],axis=0)

    if sub_W != 0:
        gap = (num_W + 1) * img_size - W
        mirror_img = img[:, (W - gap):W, :]
        mirror_img_gt = img_gt[:, (W - gap):W]
        img = np.concatenate([img, mirror_img], axis=1)
        img_gt = np.concatenate([img_gt,mirror_img_gt],axis=1)
        # gap = img_size - num_W*img_size
        # img = img[:,(W - gap):W,:]
    H, W, C = img.shape
    print('padding img:', img.shape)

    num_H = H // img_size
    num_W = W // img_size
    index = torch.arange(1, H*W+1)
    index=index.reshape(H,W)
    sub_imgs = []
    sub_indexs= []

    for i in range(num_H):
        for j in range(num_W):
            z = img[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :]
            sub_imgs.append(z)
            w = index[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size]
            sub_indexs.append(w)
    sub_imgs = np.array(sub_imgs)
    sub_indexs = np.array(sub_indexs)  # [num_H*num_W,img_size,img_size, C ]

    return sub_imgs,sub_indexs, num_H, num_W,img,img_gt
def patch_reshape(pred,num_H, num_W, class_num, img_size):
    pred = torch.reshape(pred, [num_H, num_W, class_num, img_size, img_size])
    pred = torch.permute(pred, [2, 0, 3, 1, 4])  # [2,num_H, img_size,num_W, img_size]]
    pred = torch.reshape(pred, [class_num, num_H * img_size* num_W * img_size])
    pred = torch.permute(pred, [1, 0]) 
    return pred

In [5]:
def load_data():
    data = IndianRaw().normal_cube
    data_gt = IndianRaw().truth
    return data, data_gt


In [6]:
data, data_gt = load_data()

In [7]:
# ============ 数据集参数配置（PU / IP / HH / HC） ============
CONFIGS = {
    "PU": {  
        "patch_size": 8,
        "img_size": 128,
        "pca_components": 20,
        "split_type": "number",
        "train_num": 20,
        "val_num": 20,
        "train_ratio": 0.05,
        "val_ratio": 0.01,
        "max_epoch": 500,
        "batch_size": 2,
        "overlap_size": 32
    },

    "IP": {  
        "patch_size": 8,
        "img_size": 145,
        "pca_components": 20,
        "split_type": "number",
        "train_num": 10,
        "val_num": 5,
        "train_ratio": 0.05,
        "val_ratio": 0.01,
        "max_epoch": 500,
        "batch_size": 1,
        "overlap_size": 0
    },

    "HH": {  
        "patch_size": 8,
        "img_size": 128,
        "pca_components": 30,
        "split_type": "number",
        "train_num": 50,
        "val_num": 50,
        "train_ratio": 0.05,
        "val_ratio": 0.01,
        "max_epoch": 500,
        "batch_size": 2,
        "overlap_size": 32
    },

    "HC": { 
       "patch_size": 8,
        "img_size": 128,
        "pca_components": 30,
        "split_type": "number",
        "train_num": 50,
        "val_num": 50,
        "train_ratio": 0.05,
        "val_ratio": 0.01,
        "max_epoch": 500,
        "batch_size": 2,
        "overlap_size": 32
    }
}

In [8]:
dataset_name = "IP"   # ← 修改这里即可切换 PU / IP / HH / HC

cfg = CONFIGS[dataset_name]

patch_size      = cfg["patch_size"]
img_size        = cfg["img_size"]
pca_components  = cfg["pca_components"]
split_type      = cfg["split_type"]
train_num       = cfg["train_num"]
val_num         = cfg["val_num"]
train_ratio     = cfg["train_ratio"]
val_ratio       = cfg["val_ratio"]
max_epoch       = cfg["max_epoch"]
batch_size      = cfg["batch_size"]
overlap_size    = cfg["overlap_size"]

print("当前数据集:", dataset_name)
print(cfg)

path_weight = r"weights//"
path_result = r"result//"
height_orgin, width_orgin, bands = data.shape
class_num_level2 = np.max(data_gt)
class_num_level2 = class_num_level2.astype(int)
setup_seed(3704)


当前数据集: PU
{'patch_size': 8, 'img_size': 128, 'pca_components': 20, 'split_type': 'number', 'train_num': 20, 'val_num': 20, 'train_ratio': 0.05, 'val_ratio': 0.01, 'max_epoch': 500, 'batch_size': 2, 'overlap_size': 32}


In [9]:
data, pca = split_data.apply_PCA(data, num_components=pca_components)
height_orgin, width_orgin, bands = data.shape

In [10]:
gt_reshape = np.reshape(data_gt, [-1])
class_num = np.max(gt_reshape)
class_num = class_num.astype(int)
train_index, val_index, test_index = split_data.split_data(gt_reshape, 
            class_num, train_ratio, train_ratio, train_num, val_num, split_type)

In [11]:
train_index=train_index.astype(int)
val_index=val_index.astype(int)
test_index=test_index.astype(int)
class_num = np.max(gt_reshape)
class_num = class_num.astype(int)

In [12]:
train_samples_gt, test_samples_gt, val_samples_gt = create_graph.get_label(gt_reshape,
                                                train_index, val_index, test_index)

train_label_mask, test_label_mask, val_label_mask = create_graph.get_label_mask(train_samples_gt, 
                                        test_samples_gt, val_samples_gt, data_gt, class_num)



In [13]:
train_gt = np.reshape(train_samples_gt,[height_orgin,width_orgin])
test_gt = np.reshape(test_samples_gt,[height_orgin,width_orgin])
val_gt = np.reshape(val_samples_gt,[height_orgin,width_orgin])


train_gt_onehot = create_graph.label_to_one_hot(train_gt, class_num)
test_gt_onehot = create_graph.label_to_one_hot(test_gt, class_num)
val_gt_onehot = create_graph.label_to_one_hot(val_gt, class_num)



In [14]:
train_samples_gt=torch.from_numpy(train_samples_gt.astype(np.float32))
test_samples_gt=torch.from_numpy(test_samples_gt.astype(np.float32))
val_samples_gt=torch.from_numpy(val_samples_gt.astype(np.float32))

train_gt_onehot = torch.from_numpy(train_gt_onehot.astype(np.float32))
test_gt_onehot = torch.from_numpy(test_gt_onehot.astype(np.float32))
val_gt_onehot = torch.from_numpy(val_gt_onehot.astype(np.float32))

train_label_mask = torch.from_numpy(train_label_mask.astype(np.float32))
test_label_mask = torch.from_numpy(test_label_mask.astype(np.float32))
val_label_mask = torch.from_numpy(val_label_mask.astype(np.float32))




In [15]:


img_train_hsi,gt_train=get_patch(img_size, data,train_gt,overlap_size) 
img_val_hsi,gt_val=get_patch(img_size, data,val_gt,overlap_size) 
img_test_hsi,gt_test=get_patch(img_size, data,test_gt,overlap_size)  


gt_train = torch.from_numpy(gt_train).type(torch.LongTensor) 
gt_test = torch.from_numpy(gt_test).type(torch.LongTensor)  
gt_val = torch.from_numpy(gt_val).type(torch.LongTensor) 


####################################################################

img_train_hsi = torch.from_numpy(img_train_hsi.transpose(0,3,1,2)).type(torch.FloatTensor) 
img_test_hsi = torch.from_numpy(img_test_hsi.transpose(0,3,1,2)).type(torch.FloatTensor) 
img_val_hsi = torch.from_numpy(img_val_hsi.transpose(0,3,1,2)).type(torch.FloatTensor) 


In [16]:
input_size=(img_size, img_size)
height_orgin, width_orgin, bands = data.shape
image_size=(height_orgin,width_orgin)

LyEnd,LxEnd = np.subtract(image_size, input_size)
Lx = np.linspace(0, LxEnd, int(np.ceil(LxEnd/float(input_size[1]-overlap_size)))+1, endpoint=True).astype('int')
Ly = np.linspace(0, LyEnd, int(np.ceil(LyEnd/float(input_size[0]-overlap_size)))+1, endpoint=True).astype('int')

In [17]:
class TrainDS(torch.utils.data.Dataset):
    def __init__(self):
        self.len = img_train_hsi.shape[0]
        self.x_data_hsi = torch.FloatTensor(img_train_hsi)
        self.y_data = torch.LongTensor(gt_train)
    def __getitem__(self, index):
        # 根据索引返回数据和对应的标签
        return self.x_data_hsi[index], self.y_data[index]

    def __len__(self):
        # 返回文件数据的数目
        return self.len


""" Val dataset"""
class ValDS(torch.utils.data.Dataset):
    def __init__(self):
        self.len = img_val_hsi.shape[0]
        self.x_data_hsi = torch.FloatTensor(img_val_hsi)
        self.y_data = torch.LongTensor(gt_val)
    def __getitem__(self, index):
        # 根据索引返回数据和对应的标签
        return self.x_data_hsi[index], self.y_data[index]

    def __len__(self):
        # 返回文件数据的数目
        return self.len



""" Testing dataset"""


class TestDS(torch.utils.data.Dataset):
    def __init__(self):
        self.len = img_test_hsi.shape[0]
        self.x_data_hsi = torch.FloatTensor(img_test_hsi)
        self.y_data = torch.LongTensor(gt_test) 
    def __getitem__(self, index):
        # 根据索引返回数据和对应的标签
        return self.x_data_hsi[index], self.y_data[index]

    def __len__(self):
        # 返回文件数据的数目
        return self.len


# 创建 trainloader 和 testloader
trainset = TrainDS()
valset = ValDS()
testset = TestDS()
train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=0,drop_last=True)
Val_loader = torch.utils.data.DataLoader(dataset=valset, batch_size=batch_size, shuffle=True, num_workers=0,drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, num_workers=0,drop_last=True)

In [18]:
test_samples_gt=test_samples_gt.to(device)
test_gt_onehot=test_gt_onehot.to(device)

In [19]:
zeros = torch.zeros([height_orgin * width_orgin]).to(device).float()
model = ss_fusion_seg.SSFusionFramework(
                img_size = img_size,
                in_channels = pca_components,
                patch_size=patch_size,
                classes = class_num,
                model_size='base'#The optional values are 'base','large' and 'huge'
)
optim_wrapper = dict(
    optimizer=dict(
    type='AdamW', lr=6e-5, betas=(0.9, 0.999), weight_decay=0.05),
    constructor='LayerDecayOptimizerConstructor_ViT', 
    paramwise_cfg=dict(
        num_layers=12, 
        layer_decay_rate=0.9,
        )
        )
optimizer = build_optim_wrapper(model, optim_wrapper)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.optimizer, max_epoch, eta_min=0, last_epoch=-1)
criterion = nn.CrossEntropyLoss()
model.to(device)
count = 0
best_loss = 99999
train_losses = []
val_losses = []
train_index =train_index.reshape(-1,)
test_index =test_index.reshape(-1,)
val_loss = 0
train_loss = 0
train_correct = 0
train_total= 1
val_correct = 0
val_total= 0
best_loss=999
test_correct=0
test_total=0

{'num_layers': 12, 'layer_decay_rate': 0.9}
Build LayerDecayOptimizerConstructor 0.900000 - 14
Param groups = {
  "layer_13_no_decay": {
    "param_names": [
      "spat_encoder.pos_embed",
      "spat_encoder.classifier.0.bias",
      "spat_encoder.classifier.1.bias",
      "spat_encoder.classifier.2.bias",
      "spat_encoder.patch_embed.proj.bias",
      "spat_encoder.blocks.0.norm1.weight",
      "spat_encoder.blocks.0.norm1.bias",
      "spat_encoder.blocks.0.attn.qkv.bias",
      "spat_encoder.blocks.0.attn.sampling_offsets.bias",
      "spat_encoder.blocks.0.attn.proj.bias",
      "spat_encoder.blocks.0.norm2.weight",
      "spat_encoder.blocks.0.norm2.bias",
      "spat_encoder.blocks.0.mlp.fc1.bias",
      "spat_encoder.blocks.0.mlp.fc2.bias",
      "spat_encoder.blocks.1.norm1.weight",
      "spat_encoder.blocks.1.norm1.bias",
      "spat_encoder.blocks.1.attn.qkv.bias",
      "spat_encoder.blocks.1.attn.sampling_offsets.bias",
      "spat_encoder.blocks.1.attn.proj.bias",
  

In [20]:
model_params =model.state_dict()
spat_net = torch.load((r"spat-base.pth"), map_location=torch.device('cpu'))
for k in list(spat_net['model'].keys()):
    if 'patch_embed.proj' in k:
        del spat_net['model'][k]
for k in list(spat_net['model'].keys()):
    if 'spat_map' in k:
        del spat_net['model'][k]
for k in list(spat_net['model'].keys()):
    if 'spat_output_maps' in k:
        del spat_net['model'][k]
for k in list(spat_net['model'].keys()):
    if 'pos_embed' in k:
        del spat_net['model'][k]
spat_weights = {}
prefix = 'spat_encoder.'
for key, value in spat_net['model'].items():
    new_key = prefix + key
    spat_weights[new_key] = value
per_net = torch.load((r"spec-base.pth"), map_location=torch.device('cpu'))
model_params =model.state_dict()
for k in list(per_net['model'].keys()):
    if 'patch_embed.proj' in k:
        del per_net['model'][k]
    if 'spat_map' in k:
        del per_net['model'][k]
    if 'fpn1.0.weight' in k:
        del per_net['model'][k]
spec_weights = {}
prefix = 'spec_encoder.'
for key, value in per_net['model'].items():
    new_key = prefix + key
    spec_weights[new_key] = value
model_params =model.state_dict()
for k in list(spec_weights.keys()):
    if 'spec_encoder.patch_embed' in k:
        del spec_weights[k]
merged_params = {**spat_weights, **spec_weights}
same_parsms = {k: v for k, v in merged_params.items() if k in model_params.keys()}
model_params.update(same_parsms)
model.load_state_dict(model_params)

<All keys matched successfully>

In [21]:
img_hsi=torch.from_numpy(data).permute(2,0,1).to(device)
img_hsi = img_hsi.to(torch.float32)


In [22]:
for epoch in range(100+ 1):
    count=0
    train_loss = 0
    for netinput_hsi,netinput_gt in train_loader:
        model.train()
        netinput_hsi = netinput_hsi.to(device)
        b,c,h,w = netinput_hsi.shape
        batch_pred = model(netinput_hsi)
        batch_pred = batch_pred.reshape(b*img_size*img_size,-1)
        netinput_gt = netinput_gt.reshape(-1)
        available_label_idx = np.where(netinput_gt!=0)
        if len(available_label_idx[0]) != 0:
            available_label_idx= torch.from_numpy(available_label_idx[0]).to(device)
            netinput_gt = netinput_gt.to(device)
            netinput_gt = netinput_gt-1
            batch_pred_loss = batch_pred[available_label_idx]
            netinput_gt_loss = netinput_gt[available_label_idx]
            loss = criterion(batch_pred_loss, netinput_gt_loss.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss = train_loss+loss.item()
            y_pred = torch.argmax(batch_pred_loss, dim=1)
            train_correct += (y_pred == netinput_gt_loss).sum().item()
            train_total += netinput_gt_loss.size(0)
            if count == 0:
                y_pred_train =  y_pred.cpu().numpy()
                y_gt_train = netinput_gt_loss.cpu().numpy()
                count = 1
            else:
                y_pred_train = np.concatenate( (y_pred_train, y_pred.cpu().numpy()) )
                y_gt_train = np.concatenate( (y_gt_train, netinput_gt_loss.cpu().numpy()) ) 
    OA_train = train_correct / train_total
    if epoch%10==0:
        count=0
        val_loss = 0
        torch.cuda.empty_cache()
        with torch.no_grad():
            model.eval()
            for netinput_hsi,netinput_gt in Val_loader:
                netinput_hsi = netinput_hsi.to(device)
                b,c,h,w = netinput_hsi.shape
                batch_pred = model(netinput_hsi)
                batch_pred = batch_pred.reshape(b*img_size*img_size,-1)
                netinput_gt = netinput_gt.reshape(-1)
                available_label_idx = np.where(netinput_gt!=0)
                if len(available_label_idx[0]) != 0:
                    available_label_idx= torch.from_numpy(available_label_idx[0]).to(device)
                    netinput_gt = netinput_gt.to(device)
                    netinput_gt = netinput_gt-1
                    batch_pred_loss = batch_pred[available_label_idx]
                    netinput_gt_loss = netinput_gt[available_label_idx]
                    loss_val = criterion(batch_pred_loss, netinput_gt_loss.long())
                    val_loss =val_loss+loss_val.item()
                    y_pred = torch.argmax(batch_pred_loss, dim=1)
                    val_correct += (y_pred == netinput_gt_loss).sum().item()
                    val_total += netinput_gt_loss.size(0)
                    if count == 0:
                        y_pred_val =  y_pred.cpu().numpy()
                        y_gt_val = netinput_gt_loss.cpu().numpy()
                        count = 1
                    else:
                        y_pred_val = np.concatenate( (y_pred_val, y_pred.cpu().numpy()) )
                        y_gt_val = np.concatenate( (y_gt_val, netinput_gt_loss.cpu().numpy()) )  
            OA_val = val_correct / val_total 

        print('epoch',epoch)
        print('OA_train',OA_train)
        print('loss_train',train_loss)
        print('OA_val',OA_val)
        print('loss_val',val_loss)

        if val_loss < best_loss :
            best_loss = val_loss
            print('######################save model######################')
            torch.save(model.state_dict(), path_weight + r"model.pt")
        count=0

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


epoch 0
OA_train 0.1650943396226415
loss_train 51.218028485774994
OA_val 0.29603729603729606
loss_val 26.367133021354675
######################save model######################
epoch 10
OA_train 0.8158573270305114
loss_train 0.7228915365412831
OA_val 0.5990675990675991
loss_val 4.920184265822172
######################save model######################
epoch 20
OA_train 0.9010580819450698
loss_train 0.08274528034962714
OA_val 0.6985236985236986
loss_val 3.8707724497653544
######################save model######################


In [None]:
torch.cuda.empty_cache()
with torch.no_grad():
    model.load_state_dict(torch.load(path_weight + r"model.pt"))
    model.eval()
    y_tes_pred_combine = np.zeros([height_orgin,width_orgin])
    for j in range(len(Ly)):
        for k in range(len(Lx)):
            rStart, cStart = (Ly[j], Lx[k])
            rEnd, cEnd = (rStart + input_size[0], cStart + input_size[1])
            img_part_hsi = img_hsi[:,rStart:rEnd,cStart:cEnd].unsqueeze(0)
            batch_pred = model(img_part_hsi)
            batch_pred = torch.squeeze(batch_pred, 0)
            batch_pred = batch_pred.detach().cpu().numpy()
            batch_pred = np.argmax(batch_pred,1).reshape(img_size,img_size)
            if j == 0 and k == 0:
                y_tes_pred_combine[rStart:rEnd, cStart:cEnd] = batch_pred
            elif j == 0 and k > 0:
                y_tes_pred_combine[rStart:rEnd, cStart + int(overlap_size / 2):cEnd] = batch_pred[:,
                                                                                    int(overlap_size / 2):]
            elif j > 0 and k == 0:
                y_tes_pred_combine[rStart + int(overlap_size / 2):rEnd, cStart:cEnd] = batch_pred[
                                                                                    int(overlap_size / 2):,
                                                                                    :]
            else:
                y_tes_pred_combine[rStart + int(overlap_size / 2):rEnd,
                cStart + int(overlap_size / 2):cEnd] = batch_pred[int(overlap_size / 2):,
                                                            int(overlap_size / 2):]

y_tes_pred_combine =torch.from_numpy(y_tes_pred_combine).type(torch.LongTensor).to(device)+1
overall_acc,OA_hi1,average_acc,kappa,each_acc=utils.evaluate_performance_all(y_tes_pred_combine.reshape(height_orgin*width_orgin), test_samples_gt, test_gt_onehot,  height_orgin, width_orgin, class_num, test_gt,device, require_AA_KPP=True, printFlag=False)
print("test OA={:.4f}".format(overall_acc))
print('kappa=',kappa)
print('each_acc=',each_acc)
print('average_acc=',average_acc)
testOA_combine = evaluate_performance(y_tes_pred_combine.reshape(height_orgin*width_orgin), test_samples_gt, test_gt_onehot, zeros)

test OA=0.9242
kappa= [0.8999523172057162]
each_acc= [array([0.72477621, 0.98113816, 0.98737251, 0.79563492, 0.95555556,
       0.98296252, 0.96434109, 0.98682043, 0.81477398])]
average_acc= [0.9103750408174558]
