In [None]:
import os
import torch
import argparse
import numpy as np
import torch.nn as nn
import scipy.io as sio
from Model_DADSSFF import DADSSFF
import torch.nn.functional as F 
from GeneratePic import generate_png
import torch.backends.cudnn as cudnn
from sklearn import metrics, preprocessing
from Gain_batch import gain_train_test_batch, gain_total_batch

### 1. Define Hyperparameters

In [None]:
# Hyperparameters are declared as argparse options so the notebook mirrors a
# command-line script; every value below is the default used by the notebook.
parser = argparse.ArgumentParser(description='CD-DADSSFF')
parser.add_argument('--dataset',              default='River', help ='name   of datasets: China, River, USA')
parser.add_argument('--patches',  type=int,   default=7,       help ='size   of patches')
parser.add_argument('--batches',  type=int,   default=64,      help ='number of batches')
parser.add_argument('--epoches',  type=int,   default=100,     help ='amount of epoches')
parser.add_argument('--tr_rate',  type=float, default=5e-2,    help ='rate   of train')
parser.add_argument('--lr_rate',  type=float, default=0.001,   help ='rate   of learning')
parser.add_argument('--decay',    type=float, default=0.001,   help ='weight of decay ')
parser.add_argument('--cuda',     type=int,   default=2,       help ='ID     of cuda ')
parser.add_argument('--seed',     type=int,   default=1024,    help ='random seed ')

# Weights of the two auxiliary loss terms added to the cross-entropy loss.
parser.add_argument('--lambdas1', type=float, default=1.0,     help ='Hyperparameter of DA  loss')
parser.add_argument('--lambdas2', type=float, default=1.0,     help ='Hyperparameter of KLD loss')

# An explicit empty argv keeps argparse from reading the Jupyter kernel's own
# command line; edit the defaults above to change a setting.
args = parser.parse_args(args=[])

# Resolve the compute device. The requested GPU index is validated against the
# number of visible GPUs so that a machine with fewer GPUs than the default
# index does not fail later with an "invalid device ordinal" error.
if not torch.cuda.is_available():
    Device = torch.device('cpu')
elif 0 <= args.cuda < torch.cuda.device_count():
    Device = torch.device('cuda:' + str(args.cuda))
else:
    print('Requested cuda:{} but only {} GPU(s) are visible; falling back to cuda:0.'.format(
        args.cuda, torch.cuda.device_count()))
    Device = torch.device('cuda:0')

In [None]:
# Feed the single --seed value to every random number generator used here
# (NumPy, PyTorch CPU, PyTorch CUDA), in that order.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(args.seed)

# Request deterministic cuDNN behaviour and switch off its benchmark mode.
cudnn.deterministic, cudnn.benchmark = True, False

### 2. Importing Dataset

In [None]:
# Folder holding the three .mat files (<name>_T1, <name>_T2, <name>_GT) of each
# supported dataset.
# NOTE(review): these are absolute, machine-specific paths, and 'China' sits
# under a different root than the other two — confirm, and consider a
# configurable data directory.
dataset_dirs = {
    'China': '/home/qinxuexiang/01_ChangeDetection/code/DADSSFF/data/01_China',
    'River': '/home/qinxuexiang/01_ChangeDetection/data/02_River',
    'USA':   '/home/qinxuexiang/01_ChangeDetection/data/03_USA',
}

# Fail immediately on an unknown dataset name instead of printing a hint and
# then crashing a few lines later with a NameError on T1.
if args.dataset not in dataset_dirs:
    raise ValueError("Please enter one of 'China', 'River' and 'USA'!")

# Each .mat file stores its array under a key equal to the file stem.
data_dir = dataset_dirs[args.dataset]
T1 = sio.loadmat(os.path.join(data_dir, args.dataset + '_T1.mat'))[args.dataset + '_T1']
T2 = sio.loadmat(os.path.join(data_dir, args.dataset + '_T2.mat'))[args.dataset + '_T2']
GT = sio.loadmat(os.path.join(data_dir, args.dataset + '_GT.mat'))[args.dataset + '_GT']

# Flatten each image to (pixels, bands), standardise it with
# sklearn.preprocessing.scale, then restore the original (H, W, B) layout.
H1, W1, B1 = T1.shape
H2, W2, B2 = T2.shape
TT1 = T1.reshape(H1*W1, B1)
TT2 = T2.reshape(H2*W2, B2)
T1 = preprocessing.scale(TT1)   # from here on T1/T2 hold the scaled 2-D arrays
T2 = preprocessing.scale(TT2)
Time1 = T1.reshape(H1, W1, B1)
Time2 = T2.reshape(H2, W2, B2)

# Number of classes = largest label + 1. The maximum is converted to a Python
# int first so the "+ 1" is not evaluated in GT's own dtype, where it could
# wrap around (GT's dtype is not visible here — presumably a small integer type).
classes_num = int(np.max(GT)) + 1
print("Time1 shape: ", Time1.shape)
print("Time2 shape: ", Time2.shape)
print('class_num:',    classes_num)

### 3. Selecting Train and Test

In [None]:
# Build the train/test batch loaders from the two scaled images and the ground
# truth, using the train rate, patch size, batch size and seed from `args`.
Train_loader, Test_loader = gain_train_test_batch(
    Time1,
    Time2,
    GT,
    args.tr_rate,
    args.patches,
    args.batches,
    args.seed,
)

### 4. Load Network Structure

In [None]:
# Classification loss; the model's two auxiliary losses are added to it in the
# training loop.
criterion = nn.CrossEntropyLoss()

# Network sized from the band count of the first image and the class count,
# moved to the selected device.
model = DADSSFF(
    bands=B1,
    num_classes=classes_num,
).to(Device)

# AdamW over all model parameters with the configured learning rate and decay.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr_rate,
    weight_decay=args.decay,
)

### 5. Train

In [None]:
# Training loop.
# Every 10th epoch — and always on the final epoch, so that a checkpoint exists
# even when args.epoches is below or not a multiple of 10 — the model is scored
# on Test_loader, and the whole model object is saved whenever its F1 matches
# or beats the best value so far.
# NOTE(review): the checkpoint is selected on Test_loader, the same loader used
# for the final evaluation — confirm this is intended.
best = 0                                     # best test F1 seen so far
name = 'DADSSFF_' + args.dataset + '.pth'    # checkpoint file of the best model
for epoch in range(args.epoches):
    model.train()
    Train_Loss, Train_GT, Train_label = [], [], []

    for train_T1, train_T2, train_gt in Train_loader:

        train_T1, train_T2, train_gt = train_T1.to(Device), train_T2.to(Device), train_gt.to(Device)

        ## 1) Zero the gradients left over from the previous step
        optimizer.zero_grad()

        ## 2) Forward: the model returns class scores plus two auxiliary losses
        train_pred, train_DA_loss, train_KLD_loss = model(train_T1, train_T2)
        loss = criterion(train_pred, train_gt) + args.lambdas1*train_DA_loss + args.lambdas2*train_KLD_loss

        ## 3) Backward and update weights
        loss.backward()
        optimizer.step()

        ## 4) Estimate accuracy: softmax is monotonic, so the arg-max of the
        ##    raw scores is already the predicted class
        train_label = train_pred.argmax(dim=1)

        Train_Loss.append(loss.item())
        Train_GT.extend(train_gt.cpu().numpy())
        Train_label.extend(train_label.cpu().numpy())

    Train_Loss = np.array(Train_Loss)
    Train_GT   = np.array(Train_GT)
    Train_label= np.array(Train_label)
    Train_F1   = metrics.f1_score(Train_GT, Train_label)

    ## Test
    is_last_epoch = (epoch + 1 == args.epoches)
    if (epoch+1) % 10 == 0 or is_last_epoch:
        model.eval()
        with torch.no_grad():

            Test_loss, Test_GT, Test_label = [], [], []

            for test_T1, test_T2, test_gt in Test_loader:

                test_T1, test_T2, test_gt = test_T1.to(Device), test_T2.to(Device), test_gt.to(Device)

                test_pred, test_DA_loss, test_KLD_loss = model(test_T1, test_T2)
                test_loss = criterion(test_pred, test_gt) + args.lambdas1*test_DA_loss + args.lambdas2*test_KLD_loss

                test_label = test_pred.argmax(dim=1)

                Test_loss.append(test_loss.item())
                Test_GT.extend(test_gt.cpu().numpy())
                Test_label.extend(test_label.cpu().numpy())

        Test_loss = np.array(Test_loss)
        Test_GT   = np.array(Test_GT)
        Test_label= np.array(Test_label)
        Test_F1   = metrics.f1_score(Test_GT, Test_label)

        print('epoch: {:03d}/{}; Train loss: {:.4f}; Train F1: {:.4f}; // Test loss:{:.4f}; Test F1: {:.4f}'.format(
            epoch+1, args.epoches, np.mean(Train_Loss), Train_F1*100, np.mean(Test_loss), Test_F1*100))

        # ">=" keeps the most recent model when the score ties.
        if Test_F1 >= best:
            best = Test_F1
            torch.save(model, name)

### 6. Evaluation

In [None]:
# Reload the best checkpoint written by the training cell and score it on
# Test_loader.
name = 'DADSSFF_' + args.dataset + '.pth'

# The checkpoint is a fully pickled model object (torch.save(model, ...)), so:
#   * weights_only=False is required on PyTorch versions whose torch.load
#     defaults to weights-only loading, which rejects pickled modules;
#   * map_location=Device restores the tensors onto the device selected in
#     this session rather than the one they were saved from.
# NOTE: unpickling executes code stored in the file — only load checkpoints
# produced by this notebook.
model = torch.load(name, map_location=Device, weights_only=False)
model.eval()
with torch.no_grad():

    Test_GT, Test_label = [], []

    for test_T1, test_T2, test_gt in Test_loader:

        # Labels are only compared on the host, so only the inputs are moved.
        test_T1, test_T2 = test_T1.to(Device), test_T2.to(Device)

        test_pred, _, _ = model(test_T1, test_T2)

        # Softmax is monotonic: the arg-max of the raw scores is the class.
        test_label = test_pred.argmax(dim=1)

        Test_GT.extend(test_gt.cpu().numpy())
        Test_label.extend(test_label.cpu().numpy())

Test_GT, Test_label = np.array(Test_GT), np.array(Test_label)

Test_OA    = metrics.accuracy_score(Test_GT,Test_label)
Test_Kappa = metrics.cohen_kappa_score(Test_GT,Test_label)
F1_score   = metrics.f1_score(Test_GT,Test_label)
Precision  = metrics.precision_score(Test_GT,Test_label)
Recall     = metrics.recall_score(Test_GT,Test_label)

# F1 and OA are printed as percentages; Kappa, Precision and Recall as fractions.
print('F1_socre: {:.2f}; OA: {:.2f}; Kappa: {:.4f}; Precision: {:.4f}; Recall: {:.4f}; '.format(F1_score*100, Test_OA*100, Test_Kappa, Precision, Recall))

### 5. Generate result picture

#### 5.1 gain total samples

In [None]:
# Build the loader used for the full-image prediction from both scaled images
# and the ground truth, with the configured patch and batch sizes.
Total_loader = gain_total_batch(
    Time1,
    Time2,
    GT,
    args.patches,
    args.batches,
)

#### 5.2 gain predict

In [None]:
# Predict a label for every sample yielded by Total_loader, keeping the
# location that comes with each sample so the result map can be assembled.
# The checkpoint name is rebuilt here so the cell does not depend on a variable
# left behind by an earlier cell.
name = 'DADSSFF_' + args.dataset + '.pth'

# Fully pickled model: weights_only=False is needed on PyTorch versions that
# default to weights-only loading, and map_location=Device restores the tensors
# onto the device selected in this session.
# NOTE: unpickling executes code stored in the file — only load checkpoints
# produced by this notebook.
model = torch.load(name, map_location=Device, weights_only=False)
model.eval()
with torch.no_grad():

    Total_pred, Total_position = [], []

    # The fourth item of each batch is not used here.
    for total_location, Total_T1, Total_T2, _ in Total_loader:

        Total_T1, Total_T2 = Total_T1.to(Device), Total_T2.to(Device)
        total_pred, _, _   = model(Total_T1, Total_T2)

        # Softmax is monotonic: the arg-max of the raw scores is the class.
        total_label = total_pred.argmax(dim=1)

        Total_pred.extend(total_label.cpu().numpy())
        Total_position.extend(total_location.cpu().numpy())

Total_pred     = np.array(Total_pred)
Total_position = np.array(Total_position)

#### 5.3 Get result

In [None]:
# Hand the ground truth, the sample locations and the predicted labels to
# generate_png together with the dataset name; its return value is discarded.
_ = generate_png(
    GT,
    Total_position,
    Total_pred,
    args.dataset,
)