In [1]:
import h5py
import torch
import os
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
from torch.utils.data.dataset import Dataset
import numpy as np
import setting_2 as setting
import time
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import ComFunction as cf
import math

In [2]:
# import torch.backends.cudnn as cudnn
# from scipy.io import loadmat
# from scipy.io import savemat
# from sklearn.metrics import confusion_matrix
from vit_pytorch import ViT

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type=='cuda':
#     dtype = torch.float32
#     torch.set_default_tensor_type(torch.cuda.FloatTensor)
    dtype = torch.float64
    torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
#     dtype = torch.float32
#     torch.set_default_tensor_type(torch.FloatTensor)
    dtype = torch.float64
    torch.set_default_tensor_type(torch.DoubleTensor)
print(device.type)

cuda


In [4]:
class HyperData(Dataset):
    def __init__(self, data, labels, transfor):
        self.data = data
        self.transformer = transfor
        self.labels = labels
        

    def __getitem__(self, index):
        img = self.data[index,:,:]
        label = self.labels[index]
        return img, label

    def __len__(self):
        return len(self.labels)

    def __labels__(self):
        return self.labels

In [5]:
def gain_neighborhood_band(x_train, band, band_patch, patch=5):
    nn = band_patch // 2
    pp = (patch*patch) // 2
    x_train_reshape = x_train.reshape(x_train.shape[0], patch*patch, band)
    x_train_band = np.zeros((x_train.shape[0], patch*patch*band_patch, band),dtype=float)
    # 中心区域
    x_train_band[:,nn*patch*patch:(nn+1)*patch*patch,:] = x_train_reshape
    #左边镜像
    for i in range(nn):
        if pp > 0:
            x_train_band[:,i*patch*patch:(i+1)*patch*patch,:i+1] = x_train_reshape[:,:,band-i-1:]
            x_train_band[:,i*patch*patch:(i+1)*patch*patch,i+1:] = x_train_reshape[:,:,:band-i-1]
        else:
            x_train_band[:,i:(i+1),:(nn-i)] = x_train_reshape[:,0:1,(band-nn+i):]
            x_train_band[:,i:(i+1),(nn-i):] = x_train_reshape[:,0:1,:(band-nn+i)]
    #右边镜像
    for i in range(nn):
        if pp > 0:
            x_train_band[:,(nn+i+1)*patch*patch:(nn+i+2)*patch*patch,:band-i-1] = x_train_reshape[:,:,i+1:]
            x_train_band[:,(nn+i+1)*patch*patch:(nn+i+2)*patch*patch,band-i-1:] = x_train_reshape[:,:,:i+1]
        else:
            x_train_band[:,(nn+1+i):(nn+2+i),(band-i-1):] = x_train_reshape[:,0:1,:(i+1)]
            x_train_band[:,(nn+1+i):(nn+2+i),:(band-i-1)] = x_train_reshape[:,0:1,(i+1):]
    return x_train_band

In [6]:
EPOCH=setting.EPOCH
BATCH_SIZE=setting.BATCH_SIZE
LR=setting.LR

In [7]:
train_filename =setting.train_data_name +'_'+ str(setting.PATCH_SIZE) + '_'+ str(setting.DTYPE) +'.h5'
with h5py.File(train_filename,'r') as readfile:
    train=readfile['train_patch'][:]
    train_labels=readfile['train_labels'][:]
print('train size:', train.shape)
print('train label name:',np.unique(train_labels))
print(np.max(train))
print(np.min(train))

train_max=np.max(np.abs(train))
print('train max:',train_max)


train size: (14000, 7, 7, 204)
train label name: [0 1]
3.2608695030212402
0.06606606394052505
train max: 3.2608695030212402


In [8]:
train_band = gain_neighborhood_band(train, setting.band, setting.band_patches, setting.PATCH_SIZE)
print(train_band.shape)
train_band =train_band.transpose(0,2,1)
print(train_band .shape)

(14000, 343, 204)
(14000, 204, 343)


In [9]:
val_filename = setting.val_data_name +'_'+ str(setting.PATCH_SIZE) +'_'+ str(setting.DTYPE) +'.h5'
with h5py.File(val_filename,'r') as readfile:
    val=readfile['val_patch'][:]
    val_labels=readfile['val_labels'][:]
print('val size:', val.shape)
print('val label name:',np.unique(val_labels))
print(np.max(val))
print(np.min(val))

val_max=np.max(np.abs(val))
print('val max:',val_max)


val_band = gain_neighborhood_band(val, setting.band, setting.band_patches, setting.PATCH_SIZE)
print(val_band.shape)
val_band =val_band.transpose(0,2,1)
print(val_band.shape)

val size: (7000, 7, 7, 204)
val label name: [0 1]
1.2545454502105713
0.04953031614422798
val max: 1.2545454502105713
(7000, 343, 204)
(7000, 204, 343)


In [10]:
# cpu
# train_set=HyperData(train_band,train_labels, None)
# trainloader= Data.DataLoader(dataset=train_set,batch_size=BATCH_SIZE,shuffle=True, num_workers=0)

# val_set=HyperData(val_band, val_labels, None)
# valloader= Data.DataLoader(dataset=val_set,batch_size=BATCH_SIZE,shuffle=False, num_workers=0)

In [11]:
# # gpu
train_set=HyperData(train_band,train_labels, None)
trainloader= Data.DataLoader(dataset=train_set,batch_size=BATCH_SIZE,shuffle=True, num_workers=0,generator=torch.Generator(device='cuda'))

val_set=HyperData(val_band, val_labels, None)
valloader= Data.DataLoader(dataset=val_set,batch_size=BATCH_SIZE,shuffle=False, num_workers=0)

In [12]:
time_str=time.strftime("%Y_%m_%d_%H_%M", time.localtime())

writer_name=os.path.join(setting.writer_name,'nn_'+time_str)
print(writer_name)
writer = SummaryWriter(writer_name)
# tensorboard --logdir=runs

runs\nn_2023_01_16_10_43


In [13]:
net=ViT(
    image_size = setting.PATCH_SIZE,
    near_band = setting.band_patches,
    num_patches = setting.band,
    num_classes = setting.num_class,
    dim = 64,
    depth =2,
    heads = 4,
    mlp_dim =6,
    dropout = 0.1,
    emb_dropout = 0.1,
    mode = setting.mode
).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=setting.LR, weight_decay=setting.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=setting.EPOCH//10, gamma=setting.gamma)

In [14]:
best_acc=0
best_epoch=0

global_steps=0

In [15]:
filename=os.path.join(setting.train_result_dir,'traintime_'+time_str+'.txt')
print(filename)

result=open(filename,'a')
start=time.time()
for epoch in range(setting.EPOCH):  # loop over the dataset multiple times
    print('epoch:',epoch)
    result.write('epoch:'+str(epoch)+'\n')
    correct = 0
    total = 0
    
    running_loss = 0.0
    train_loss= 0.0
    net.train()
    for i, data in enumerate(trainloader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        b_inputs=Variable(inputs).to(device)
        b_labels=Variable(labels).to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()
#         print(inputs.shape)

        # forward + backward + optimize
        outputs = net(b_inputs)
#         print(labels.unique)
#         print(outputs.unique)
        loss = criterion(outputs, b_labels)
        loss.backward()
        optimizer.step()
        
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        global_steps+=1
#         loss_list.append(loss)
            
        # print statistics
        running_loss += loss.item()
        train_loss += loss.item()
        if (global_steps) % 200 == 0:    # print every 2000 mini-batches
            print('global steps %5d running loss: %.3f' %
                  (global_steps, running_loss / 200))
#             state = {
#                 'epoch': epoch,
#                 'step': i,
#                 'net': net.state_dict(),
#                 'optimizer':optimizer.state_dict(),
#                 'loss':running_loss / 200
#             }
#             checkpoint_name=os.path.join(setting.checkpoint_dir,'model_time_'+time_str+'_epoch_'+str(epoch)+'_step_'+str(i)+'.pth')
#             torch.save(state, checkpoint_name)
            running_loss = 0.0
            
        #save loss log file
#         if global_steps%10 ==0:
#             writer.add_scalar('train_loss',loss,global_steps)
            
        
    
    accuracy = 100 * correct / total
    print( 'train oa:', 100 * correct / total,'train loss',train_loss/len(trainloader))
    result.write('train accuracy:'+str(100 * correct / total) +'\t'+'train loss:'+str(train_loss/len(trainloader)))
    #save training accuracy log file
    writer.add_scalar('train_accuracy',accuracy,epoch)
    writer.add_scalar('train_loss',train_loss/len(trainloader),epoch)
    
    
    val_correct = 0.0
    val_total = 0.0
    val_loss = 0.0
    net.eval()
    for valdata in valloader:
        val_inputs, val_labels = valdata
        val_inputs = val_inputs.to(device)
        val_labels =  val_labels.to(device)
        val_outputs = net(val_inputs)
        valloss = criterion(val_outputs, val_labels)
        val_loss += valloss.item()
        _, val_predicted = torch.max(val_outputs.data, 1)
        val_total += val_labels.size(0)
        val_correct += (val_predicted == val_labels).sum().item()
    val_oa=val_correct/val_total
    print(' val oa:',val_oa, 'val_loss:',val_loss/len(valloader))
    result.write('\t val oa:'+str(val_oa)+'\t'+ 'val_loss:'+str(val_loss/len(valloader)))
    writer.add_scalar('val_Accu',val_oa , epoch)
    writer.add_scalar('val_loss',val_loss/len(valloader) , epoch)
    
    if val_oa > best_acc:
        state = {
            'epoch': epoch,
            'accuracy': val_oa,
            'net': net.state_dict(),
            'optimizer':optimizer.state_dict(),
            'loss':loss
        }
        best_model_name=os.path.join(setting.best_model_dir,'model_time_'+time_str+'.pth')
        torch.save(state, best_model_name)
        best_acc=val_oa
        best_epoch=epoch
        print('epoch:',epoch,'best val accuracy:',val_oa)
    
    #pridict  testing data and save accuracy log file
#     if (epoch) % 10 == 0:
#         test_correct = 0.0
#         test_total = 0.0
#         test_loss = 0.0
#         net.eval()
#         for testdata in testloader:
#             test_inputs, test_labels = testdata
#             test_inputs = test_inputs.to(device)
#             test_labels =  test_labels.to(device)
#             test_outputs = net(test_inputs)
#             tsloss = criterion(test_outputs, test_labels)
#             test_loss += tsloss.item()
#             _, test_predicted = torch.max(test_outputs.data, 1)
#             test_total += test_labels.size(0)
#             test_correct += (test_predicted == test_labels).sum().item()
#         oa=test_correct/test_total
#         print(' test oa:',oa, 'loss:',test_loss/len(testloader))
#         result.write('\t test oa:'+str(oa)+'\t'+ 'loss:'+str(test_loss/len(testloader)))
#         writer.add_scalar('Test_Accu',oa , epoch)
#         writer.add_scalar('Test_loss',test_loss/len(testloader) , epoch)
    
#         if oa > best_acc:
#             state = {
#                 'epoch': epoch,
#                 'accuracy': oa,
#                 'net': net.state_dict(),
#                 'optimizer':optimizer.state_dict(),
#                 'loss':loss
#             }
#             best_model_name=os.path.join(setting.best_model_dir,'model_time_'+time_str+'.pth')
#             torch.save(state, best_model_name)
#             best_acc=oa
#             print('epoch:',epoch,'best test accuracy:',oa)
        

    finish_state = {
        'epoch': epoch,
        'net': net.state_dict(),
        'optimizer':optimizer.state_dict(),
        'loss':loss,
        'train_oa':accuracy,
#         'test_oa':oa
        }
end=time.time()
model_name=os.path.join(setting.model_dir,'model_time_'+time_str+'.pth')
torch.save(finish_state, model_name)
#plt.ioff()
#plt.show()
print('Finished Training')
writer.close()
result.close()

.\ViT_train_result\traintime_2023_01_16_10_43.txt
epoch: 0
global steps   200 running loss: 0.590
train oa: 66.75 train loss 0.5624863202559458
 val oa: 0.8544285714285714 val_loss: 0.38154877831289297
epoch: 0 best val accuracy: 0.8544285714285714
epoch: 1
global steps   400 running loss: 0.213
train oa: 90.66428571428571 train loss 0.23109978807442813
 val oa: 0.8962857142857142 val_loss: 0.26265342898832006
epoch: 1 best val accuracy: 0.8962857142857142
epoch: 2
global steps   600 running loss: 0.156
train oa: 92.1 train loss 0.19160443793728132
 val oa: 0.8922857142857142 val_loss: 0.2620066463661791
epoch: 3
global steps   800 running loss: 0.135
train oa: 92.60714285714286 train loss 0.18070028070042984
 val oa: 0.8997142857142857 val_loss: 0.25448680380171496
epoch: 3 best val accuracy: 0.8997142857142857
epoch: 4
global steps  1000 running loss: 0.114
train oa: 92.88571428571429 train loss 0.17790197074349598
 val oa: 0.8977142857142857 val_loss: 0.255058312255156
epoch: 5
glob

In [16]:
print('best_acc:',best_acc)
print('best_epoch:',best_epoch)
print('train time:', end-start)
print(best_model_name)
print(model_name)

best_acc: 0.9128571428571428
best_epoch: 12
train time: 576.1477062702179
.\model\model_best\model_time_2023_01_16_10_43.pth
.\model\model_time_2023_01_16_10_43.pth


In [17]:

# model_name='./model/salinas/model_best/model_time_2021_03_18_12_17.pth'
# model_name='.\model\salinas\model_time_2020_03_31_11_43.pth'
# model = torch.load(model_name)
model = torch.load(best_model_name)

In [18]:
net.load_state_dict(model['net'])
# optimizer.load_state_dict(model['optimizer'])
# start_epoch = model['epoch'] + 1
# loss=model['loss']

<All keys matched successfully>

In [19]:
predict_trainlabels=[]
trainlabels=[]
t1=time.time()
with torch.no_grad():
    for trdata in trainloader:
        trinputs, trlabels = trdata
        trinputs = trinputs.to(device)
        trlabels = trlabels.to(device)
        troutputs = net(trinputs)
        _, trpredicted = torch.max(F.softmax(troutputs), 1)
        predict_trainlabels.extend(trpredicted)
        trainlabels.extend(trlabels)
    print('predict training set finished')
#print(len(trainlabels))
#print(len(predict_trainlabels))
t2=time.time()
print('predict train time:',t2-t1)

predict_trainlabels=torch.tensor(predict_trainlabels, device='cpu')
trainlabels=torch.tensor(trainlabels,device='cpu')

oa_train, aa_train, kappa_train, acc_train=cf.eval_results_own(predict_trainlabels,trainlabels,2)
print('OA_train:',oa_train, '\nAA_train:', aa_train, '\nkappa_train:', kappa_train, '\nacc_train:', acc_train)

  _, trpredicted = torch.max(F.softmax(troutputs), 1)


predict training set finished
predict train time: 9.871007442474365
OA_train: 0.965 
AA_train: 0.965 
kappa_train: 0.9299999999999999 
acc_train: [0.95385714 0.97614286]


In [20]:
predict_vallabels=[]
val_labels=[]
t1=time.time()
with torch.no_grad():
    for valdata in valloader:
        vinputs, vlabels = valdata
        vinputs = vinputs.to(device)
        vlabels = vlabels.to(device)
        voutputs = net(vinputs)
        _, vpredicted = torch.max(F.softmax(voutputs), 1)
        predict_vallabels.extend(vpredicted)
        val_labels.extend(vlabels)
    print('predict val set finished')
#print(len(trainlabels))
#print(len(predict_trainlabels))
t2=time.time()
print('predict val time:',t2-t1)

predict_vallabels=torch.tensor(predict_vallabels, device='cpu')
val_labels=torch.tensor(val_labels,device='cpu')

oa_val, aa_val, kappa_val, acc_val=cf.eval_results_own(predict_vallabels,val_labels,2)
print('OA_val:',oa_val, '\nAA_val:', aa_val, '\nkappa_val:', kappa_val, '\nacc_val:', acc_val)

  _, vpredicted = torch.max(F.softmax(voutputs), 1)


predict val set finished
predict val time: 4.927295207977295
OA_val: 0.9128571428571428 
AA_val: 0.9128571428571428 
kappa_val: 0.8257142857142856 
acc_val: [0.94542857 0.88028571]
