In [7]:
from model import ss_fusion_cls
import torch
from torch  import nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report,cohen_kappa_score
from model import split_data,utils
from sklearn import metrics, preprocessing
from mmengine.optim import build_optim_wrapper
from mmcv_custom import custom_layer_decay_optimizer_constructor,layer_decay_optimizer_constructor_vit
import scipy.io as sio
from thop import profile
from multiprocessing import shared_memory

In [8]:
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
class DataReader():
    """Base holder for a hyperspectral data cube and its ground-truth map.

    Subclasses populate ``data_cube`` (and optionally ``g_truth``) in their
    ``__init__``; both start out as ``None`` here.
    """

    def __init__(self):
        self.data_cube = None  # raw cube, filled in by subclasses
        self.g_truth = None    # ground-truth labels; not set by the subclass in this notebook

    @property
    def cube(self):
        """The raw (un-normalized) data cube."""
        return self.data_cube

    @property
    def truth(self):
        """The ground-truth map, or ``None`` if it was never loaded."""
        return self.g_truth

    @property
    def normal_cube(self):
        """Return the cube min-max scaled to [0, 1] using its global min and max.

        A constant cube (max == min) has no range to scale. Instead of the
        0/0 = NaN the plain formula would produce, it is mapped to all zeros.
        """
        cube = self.data_cube
        lo = np.min(cube)
        span = np.max(cube) - lo
        shifted = cube - lo
        if span == 0:
            # Keep a floating dtype, matching what true division would have produced.
            return np.zeros_like(shifted, dtype=np.result_type(shifted.dtype, np.float32))
        return shifted / span
class dataRaw(DataReader):
    """Reader for the Indian Pines corrected cube stored in a MATLAB ``.mat`` file.

    Args:
        path: location of the ``.mat`` file (defaults to the notebook's original path).
        key: name of the variable holding the cube inside the ``.mat`` file.
    """

    def __init__(self, path=r"data/Indian_pines_corrected.mat", key="data"):
        super().__init__()
        raw_data_package = sio.loadmat(path)
        # Cast to float32 so normalisation and the later torch conversion use single precision.
        self.data_cube = raw_data_package[key].astype(np.float32)

In [10]:
def load_data():
    """Load the Indian Pines cube and return it min-max normalised to [0, 1]."""
    reader = dataRaw()
    return reader.normal_cube


In [11]:
# ---- configuration ----
img_size = 9                 # window side: passed as window_size to create_patches_inference and img_size to the model
patch_size = 2               # passed to SSFusionFramework as patch_size
pca_components = 10          # number of PCA bands kept; also the model's in_channels
class_num = 16               # labelled classes; the model is built with class_num + 1 outputs
max_epoch = 100              # not used in the cells shown
batch_size = 64              # mini-batch size
learning_rate = 1e-05        # not used in the cells shown
path_weight = r"weights//"   # not used in the cells shown
path_result = r"result//"    # not used in the cells shown

# Load the normalised cube and record its spatial / spectral extent.
data = load_data()
height, width, bands = data.shape

In [12]:
# Reduce the spectral axis to `pca_components` bands. The second return value is kept as
# `pca` (presumably the fitted PCA object; TODO confirm against split_data).
data, pca = split_data.apply_PCA(
    data,
    num_components=pca_components,
)
# Cut img_size x img_size windows for full-scene inference.
data_all = split_data.create_patches_inference(
    data,
    window_size=img_size,
)

In [13]:
# Move the last axis to position 1: presumably (N, H, W, C) -> (N, C, H, W) for the model.
# The recorded output shows pca_components (10) on axis 1 after the move.
# NOTE(review): the recorded patch count is 22201 (= 149 * 149); confirm it equals
# height * width, otherwise predictions cannot be mapped one-to-one back onto the image.
channels_first = (0, 3, 1, 2)
data_all = data_all.transpose(*channels_first)
print('after transpose: train shape: ', data_all.shape)


after transpose: train shape:  (22201, 10, 9, 9)


In [14]:
class TrainDS(torch.utils.data.Dataset):
    """Unlabelled map-style dataset over image patches.

    Despite the name, this notebook uses it for full-scene inference: each item
    is a single patch tensor with no label.

    Args:
        data: array-like of patches with the sample axis first. When omitted, the
            module-level ``data_all`` is used (the original behaviour).
    """

    def __init__(self, data=None):
        if data is None:
            data = data_all
        # as_tensor with an explicit float32 dtype also accepts float64 input,
        # which the legacy torch.FloatTensor(ndarray) constructor rejects.
        self.x_data = torch.as_tensor(data, dtype=torch.float32)
        self.len = self.x_data.shape[0]

    def __getitem__(self, index):
        # Return the patch at `index` (there are no labels).
        return self.x_data[index]

    def __len__(self):
        # Number of patches in the dataset.
        return self.len




# Build the loader that feeds every patch to the model for full-scene inference.
# shuffle must stay False: predictions are concatenated in loader order and have to
# line up with the patch order produced by create_patches_inference.
trainset = TrainDS()
train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=batch_size,   # use the configured batch size instead of a duplicated literal
    shuffle=False,
    num_workers=0,
)


In [15]:
# Spectral-spatial fusion classifier. It outputs class_num + 1 scores
# (presumably one extra slot for background/unlabelled pixels; TODO confirm).
model = ss_fusion_cls.SSFusionFramework(
    img_size=img_size,
    in_channels=pca_components,
    patch_size=patch_size,
    classes=class_num + 1,
    model_size='base',  # optional values: 'base', 'large', 'huge'
)
model = model.to(device)


In [16]:
def _filter_and_prefix(state_dict, drop_substrings, prefix):
    """Return a new dict with keys containing any of `drop_substrings` removed and `prefix` prepended.

    The input dict is not modified.
    """
    return {
        prefix + key: value
        for key, value in state_dict.items()
        if not any(sub in key for sub in drop_substrings)
    }


# Pretrained checkpoints. torch.load unpickles the file, so only load checkpoints from a trusted source.
spat_net = torch.load(r"spat-base.pth", map_location=torch.device('cpu'))
per_net = torch.load(r"spec-base.pth", map_location=torch.device('cpu'))

# Spatial encoder: drop the patch-embedding projection, position embedding and the spatial
# map / output-map tensors (presumably because their shapes depend on this dataset's input; TODO confirm).
spat_weights = _filter_and_prefix(
    spat_net['model'],
    ('patch_embed.proj', 'spat_map', 'spat_output_maps', 'pos_embed'),
    'spat_encoder.',
)

# Spectral encoder: same idea, with its own set of excluded tensors ...
spec_weights = _filter_and_prefix(
    per_net['model'],
    ('patch_embed.proj', 'spat_map', 'fpn1.0.weight'),
    'spec_encoder.',
)
# ... and additionally everything under the spectral encoder's patch embedding.
spec_weights = {k: v for k, v in spec_weights.items() if 'spec_encoder.patch_embed' not in k}

merged_params = {**spat_weights, **spec_weights}
model_params = model.state_dict()
same_params = {k: v for k, v in merged_params.items() if k in model_params}
# load_state_dict below always reports "All keys matched" because every key comes from the
# model's own state_dict. Report how many pretrained tensors were actually used so that
# naming mismatches (silently dropped tensors) are visible.
print(f"pretrained tensors used: {len(same_params)}/{len(merged_params)} "
      f"(model has {len(model_params)} tensors)")
model_params.update(same_params)
model.load_state_dict(model_params)

<All keys matched successfully>

In [17]:
# Run the model over every patch and collect the arg-max class index for each one.
model.eval()
batch_preds = []
with torch.no_grad():
    for x in train_loader:
        # `device` is cuda:0 only when CUDA is available; .to() on the CPU is a no-op,
        # so no separate availability check is needed.
        output = model(x.to(device))
        batch_preds.append(torch.argmax(output, dim=1).cpu().numpy())

# Concatenate once at the end; concatenating inside the loop copies the growing
# array on every batch (quadratic in the number of batches).
y_pred_test = np.concatenate(batch_preds) if batch_preds else np.empty(0, dtype=np.int64)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
