In [15]:
import os
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import random
import numpy as np
import matplotlib as mpl 
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm
from tqdm import tqdm as tqdmm
from torch.utils.data import DataLoader
from torch.optim import Adam
from IPython.display import HTML, display,clear_output
from model.backbone_resnet import resnet18

# from tool.utils import *
from tool.train_tool import *
import tool.dataload_seg as d

# Select the change-detection benchmark; DATA_PATH is resolved from this name below.
# Dataset = 'CDD'
# Dataset = 'SYSU-CD'
Dataset = "LEVIR"

# Root directory of every supported dataset (machine-specific absolute paths).
DATASET_PATHS = {
    'CDD': '/home/amax/yyq/Dataset/CDD',
    'SYSU-CD': '/home/amax/yyq/Dataset/SYSU-CD/',
    'LEVIR': '/home/yang/yyq/Dataset/LEVIR-CD-crop/',
}
if Dataset not in DATASET_PATHS:
    # Fail here with a clear message instead of a NameError on DATA_PATH in a later cell.
    raise ValueError("Unknown Dataset {!r}; expected one of {}".format(Dataset, sorted(DATASET_PATHS)))
DATA_PATH = DATASET_PATHS[Dataset]

# CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet in this
# process. With '1,0', logical cuda:0 is physical GPU 1 and logical cuda:1 is physical GPU 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,0'
device1 = torch.device("cuda:0")
device0 = torch.device("cuda:1")
device2 = torch.device("cpu")

## Data paths

In [16]:
# All splits share one dataset root for both images and labels; each split is
# enumerated by its own list file inside that root.
TRAIN_DATA_PATH = TRAIN_LABEL_PATH = os.path.join(DATA_PATH)
VAL_DATA_PATH = VAL_LABEL_PATH = os.path.join(DATA_PATH)
TEST_DATA_PATH = TEST_LABEL_PATH = os.path.join(DATA_PATH)

TRAIN_TXT_PATH, VAL_TXT_PATH, TEST_TXT_PATH = (
    os.path.join(split_root, list_file)
    for split_root, list_file in (
        (TRAIN_DATA_PATH, 'train.txt'),
        (VAL_DATA_PATH, 'val.txt'),
        (TEST_DATA_PATH, 'test.txt'),
    )
)

## Data loading

In [17]:
batch_size = 32
train_batch_size = 16
seg_num = 200


def _make_split(data_path, label_path, txt_path, split, augment, loader_batch_size, shuffle):
    """Create the dataset for one split and wrap it in a DataLoader.

    Returns a ``(dataset, loader)`` pair; ``seg_num`` is read from module scope.
    """
    split_dataset = d.Dataset(data_path, label_path,
                              txt_path, split, transform=augment, seg_num=seg_num, ratio=2)
    split_loader = DataLoader(split_dataset, batch_size=loader_batch_size,
                              shuffle=shuffle, num_workers=8, pin_memory=True)
    return split_dataset, split_loader


# Training split: augmentation on, shuffled.
train_data, train_loader = _make_split(TRAIN_DATA_PATH, TRAIN_LABEL_PATH, TRAIN_TXT_PATH,
                                       'train', True, train_batch_size, True)

# Evaluation splits: no augmentation, fixed order.
test_batch_size = batch_size
test_data, test_loader = _make_split(TEST_DATA_PATH, TEST_LABEL_PATH, TEST_TXT_PATH,
                                     'test', False, test_batch_size, False)

val_batch_size = batch_size
val_data, val_loader = _make_split(VAL_DATA_PATH, VAL_LABEL_PATH, VAL_TXT_PATH,
                                   'val', False, val_batch_size, False)

## Training configuration

In [18]:
# Training hyper-parameters.
Epoch = 200                      # number of training epochs
Channel_num, word_num = 64, 64   # both passed to ChangeDetector when the model is built
key = 'res1'                     # also forwarded to ChangeDetector

# Mutable training state: best F1 seen so far (checkpoint criterion, despite the
# name) and the running-mean loss curve that the training loop extends.
Best_oa = 0
loss_mean = list()

accumulation = 1                 # not used by the training loop below

In [19]:
from model.CDet_dl import ChangeDetector,Load_Weight_FordataParallel

# Change detector with a ResNet-34 backbone (2 is presumably the number of output
# classes: unchanged / changed).
model = ChangeDetector(Channel_num, 2, key, word_num=word_num, backbone='resnet34')
# model.load_state_dict(torch.load('Results_changedetector_LEVIR/Best_model.pth'))

model = nn.DataParallel(model)
model = model.cuda()

STEPS_PER_EPOCH = len(train_loader)
# Adam's lr=1e-4 is effectively overridden: OneCycleLR starts at max_lr / div_factor
# (default 25) and follows a cosine one-cycle schedule over total_steps. The schedule
# is sized in batches, so scheduler.step() must be called once per batch.
opt = torch.optim.Adam(model.parameters(), 1e-4, betas=(0.9, 0.999))
scheduler = lr_scheduler.OneCycleLR(opt, max_lr=1e-3, anneal_strategy='cos',
                                    total_steps=Epoch * STEPS_PER_EPOCH)

Loss_function_classify = nn.CrossEntropyLoss()

# Directory for checkpoints and metric logs. makedirs also creates the parent
# 'results/' directory, which os.mkdir would fail on when it does not exist yet.
Log_path = 'results/'+Dataset+'/'
os.makedirs(Log_path, exist_ok=True)

resnet34


In [20]:
def eval_model(data_loader, model, parameters):
    """Load ``parameters`` into ``model`` and print change-detection metrics.

    Args:
        data_loader: iterable of 11-tuples
            ``(imgt1, imgt2, gt, filename, _, seg1, num1, max_num1, seg2, num2, max_num2)``.
        model: network called as
            ``model(imgt1, imgt2, seg1, num1, max_num1, seg2, num2, max_num2)``;
            dim 1 of its output is the class axis.
        parameters: state dict loaded into ``model`` before evaluating.

    Side effects:
        The weights of ``model`` are replaced in place by ``parameters`` and are
        NOT restored afterwards; the model is left in eval mode.

    Prints AA (mean of per-batch ``OA``), OA (pixel accuracy over the whole
    loader), IoU, precision, recall and F1 of the positive class. A metric whose
    denominator is zero is reported as 0.0 instead of raising ZeroDivisionError.

    NOTE: a later cell redefines ``eval_model`` (adding PNG export), which
    shadows this definition once that cell has run.
    """
    def _ratio(num, den):
        # Undefined ratio (e.g. no predicted or no true change pixels) -> 0.0.
        return num / den if den else 0.0

    model.load_state_dict(parameters)
    model.eval()
    with torch.no_grad():
        batch_accuracies = []
        n_correct = 0
        n_pixels = 0
        tp = 0
        fp = 0
        fn = 0

        for (imgt1, imgt2, gt, _, _, seg1, num1, max_num1, seg2, num2, max_num2) in tqdm(data_loader):
            imgt1, imgt2, gt = imgt1.cuda(), imgt2.cuda(), gt.cuda()
            seg1, num1, max_num1 = seg1.cuda(), num1.cuda(), max_num1.cuda()
            seg2, num2, max_num2 = seg2.cuda(), num2.cuda(), max_num2.cuda()

            prediction = model(imgt1, imgt2, seg1, num1, max_num1, seg2, num2, max_num2)
            # Per-pixel predicted class = argmax over the class axis.
            prediction = torch.max(prediction, dim=1)[1]

            n_correct += torch.sum(prediction == gt)
            # Counts pixels with label > -1 (presumably all pixels; -1 would mark ignored ones — confirm).
            n_pixels += torch.sum(gt > -1)

            batch_accuracies.append(float(OA(prediction.view(-1), gt.view(-1))))

            # Positive-class confusion counts; assumes predictions and labels are 0/1,
            # so multiplying by the mask keeps only positive pixels.
            tp += int(torch.sum(prediction * (prediction == gt)))
            fp += int(torch.sum(prediction * (prediction != gt)))
            fn += int(torch.sum(gt * (prediction != gt)))

        average_accuracy = float(np.mean(batch_accuracies)) if batch_accuracies else 0.0
        overall_accuracy = float(n_correct / n_pixels) if n_pixels else 0.0
        iou = _ratio(tp, tp + fp + fn)
        precision = _ratio(tp, tp + fp)
        recall = _ratio(tp, tp + fn)
        f1 = _ratio(2 * precision * recall, precision + recall)

    print('AA: \t\t', round(average_accuracy, 4))
    print('OA:\t\t', round(overall_accuracy, 4))
    print('Iou:\t\t', round(iou, 4))
    print('Percison:\t', round(precision, 4))
    print('Recall:\t\t', round(recall, 4))
    print('F1\t\t', round(f1, 4))

In [21]:
def test_model(data_loader, epoch=None):
    """Evaluate the global ``model``, checkpoint it when F1 improves, and log metrics.

    Args:
        data_loader: iterable of 11-tuples
            ``(imgt1, imgt2, gt, filename, _, seg1, num1, max_num1, seg2, num2, max_num2)``.
        epoch: epoch number written to the log. When omitted, the module-level
            ``epoch`` variable set by the training loop is used (``None`` if it
            does not exist, instead of raising NameError as before).

    Side effects:
        - puts the global ``model`` in eval mode;
        - if F1 exceeds the global ``Best_oa`` (the best F1 so far, despite the
          name), updates ``Best_oa`` and saves the weights, converted with
          ``Load_Weight_FordataParallel(..., need_dataparallel=0)``, to
          ``Log_path + 'Best_model.pth'``;
        - appends the metrics and a summary line to
          ``Log_path + 'Metric_recording.txt'``.

    A metric whose denominator is zero is reported as 0.0 instead of raising.
    """
    global Best_oa
    if epoch is None:
        # Backward compatibility with callers that rely on the global loop variable.
        epoch = globals().get('epoch')

    def _ratio(num, den):
        # Undefined ratio (e.g. no predicted or no true change pixels) -> 0.0.
        return num / den if den else 0.0

    model.eval()
    with torch.no_grad():
        batch_accuracies = []
        n_correct = 0
        n_pixels = 0
        tp = 0
        fp = 0
        fn = 0

        for (imgt1, imgt2, gt, _, _, seg1, num1, max_num1, seg2, num2, max_num2) in tqdm(data_loader):
            imgt1, imgt2, gt = imgt1.cuda(), imgt2.cuda(), gt.cuda()
            seg1, num1, max_num1 = seg1.cuda(), num1.cuda(), max_num1.cuda()
            seg2, num2, max_num2 = seg2.cuda(), num2.cuda(), max_num2.cuda()

            prediction = model(imgt1, imgt2, seg1, num1, max_num1, seg2, num2, max_num2)
            # Per-pixel predicted class = argmax over the class axis.
            prediction = torch.max(prediction, dim=1)[1]

            n_correct += torch.sum(prediction == gt)
            n_pixels += torch.sum(gt > -1)

            batch_accuracies.append(float(OA(prediction.view(-1), gt.view(-1))))

            # Positive-class confusion counts; assumes 0/1 predictions and labels.
            tp += int(torch.sum(prediction * (prediction == gt)))
            fp += int(torch.sum(prediction * (prediction != gt)))
            fn += int(torch.sum(gt * (prediction != gt)))

        Average_accuracy = round(float(np.mean(batch_accuracies)) if batch_accuracies else 0.0, 4)
        Overrall_accuracy = round(float(n_correct / n_pixels) if n_pixels else 0.0, 4)
        Iou = _ratio(tp, tp + fp + fn)
        Percison = _ratio(tp, tp + fp)
        Recall = _ratio(tp, tp + fn)
        F1 = _ratio(2 * Percison * Recall, Percison + Recall)

        Iou = round(Iou, 4)
        Percison = round(Percison, 4)
        Recall = round(Recall, 4)
        F1 = round(F1, 4)

        print('AA: \t\t',Average_accuracy)
        print('OA:\t\t',Overrall_accuracy)
        print('Iou:\t\t',Iou)
        print('Percison:\t',Percison)
        print('Recall:\t\t',Recall)
        print('F1\t\t',F1)

    if F1 > Best_oa:
        Best_oa = F1
        state_dict = Load_Weight_FordataParallel(model.state_dict(), need_dataparallel=0)
        torch.save(state_dict, Log_path+'Best_model.pth')

    # `with` guarantees the log file is closed even if a write fails.
    with open(Log_path+'Metric_recording.txt', 'a') as f:
        f.write('AA: \t\t'+str(Average_accuracy)+'\n'+'OA:\t\t'+str(Overrall_accuracy)+'\n'+'Iou:\t\t'+str(Iou)+'\n'+'Percison:\t'+str(Percison)+'\n'+'Recall:\t\t'+str(Recall)+'\n'+'F1\t\t'+str(F1)+'\n')
        f.write('Epoch:'+str(epoch)+'   Currentoa:'+str(round(F1,4))+'   Bestoa:'+str(round(Best_oa,4))+'\n')

In [None]:
if __name__ == "__main__":
    # Per-iteration losses (rounded to 4 d.p.) plus their running total, so the global
    # `loss_mean` curve is extended in O(1) per step instead of re-averaging the
    # whole history every iteration (which was O(n^2) over training).
    loss_sum = []
    running_loss_total = 0.0
    for epoch in tqdm(range(Epoch)):

        model.train()
        for iter_, (imgt1, imgt2, gt, _, train_mask, seg1, num1, max_num1, seg2, num2, max_num2) in enumerate(tqdm(train_loader)):
            opt.zero_grad()

            imgt1, imgt2, gt, train_mask = imgt1.cuda(), imgt2.cuda(), gt.cuda(), train_mask.cuda()
            seg1, num1, max_num1, seg2, num2, max_num2 = seg1.cuda(), num1.cuda(), max_num1.cuda(), seg2.cuda(), num2.cuda(), max_num2.cuda()

            prediction = model(imgt1, imgt2, seg1, num1, max_num1, seg2, num2, max_num2)

            # Dictionary loss: Frobenius norm of (D @ D.T - eye). Assuming
            # model.module.eye is an identity matrix, this pushes the dictionary
            # rows towards orthonormality.
            dictionary = model.module.dictionary
            dic_eye = model.module.eye
            dictionary_loss = torch.sqrt(torch.sum((dictionary.matmul(dictionary.T) - dic_eye)**2))

            # Change-detection loss (changeloss presumably comes from tool.train_tool).
            loss_change = changeloss(train_mask, gt, prediction)

            loss = loss_change + dictionary_loss
            print(round(loss_change.item(), 6))

            loss.backward()
            opt.step()
            # The OneCycleLR schedule is sized in batches, so it is stepped per batch.
            scheduler.step()

            loss_sum.append(round(loss.item(), 4))
            running_loss_total += loss_sum[-1]
            loss_mean.append(running_loss_total / len(loss_sum))

            if (iter_+1) % 10 == 0:
                clear_output(wait=True)
                show_plot(np.arange(len(loss_mean)), loss_mean, title='The total loss change')

        test_model(test_loader)
        if (epoch+1) % 10 == 0:
            # eval_model loads the given weights into `model` in place. Snapshot the
            # current training weights (and buffers) and restore them afterwards, so
            # training continues from where it was rather than from the best
            # checkpoint (which would also no longer match the optimizer state).
            training_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            state_dict = torch.load(Log_path+'Best_model.pth')
            state_dict = Load_Weight_FordataParallel(state_dict, need_dataparallel=1)
            eval_model(test_loader, model, state_dict)
            model.load_state_dict(training_state)

In [68]:
def eval_model(data_loader, model, parameters, save_dir='prediction_results_LEVIR/'):
    """Load ``parameters`` into ``model``, save prediction maps, and print metrics.

    Redefines (and shadows) the earlier ``eval_model``. In addition to computing
    the metrics, it writes every predicted map to ``save_dir`` as a grayscale PNG.

    Args:
        data_loader: iterable of 11-tuples
            ``(imgt1, imgt2, gt, filename, _, seg1, num1, max_num1, seg2, num2, max_num2)``;
            ``filename`` is indexable per sample.
        model: network called as
            ``model(imgt1, imgt2, seg1, num1, max_num1, seg2, num2, max_num2)``;
            dim 1 of its output is the class axis.
        parameters: state dict loaded into ``model`` before evaluating.
        save_dir: output directory for the PNGs; it is created if missing. The
            default keeps the previous hard-coded LEVIR folder, whatever ``Dataset`` is.

    Side effects:
        The weights of ``model`` are replaced in place by ``parameters`` and are
        NOT restored; the model is left in eval mode.

    Each PNG is named after the stem of the sample's file name (text before the
    first '.'). A metric whose denominator is zero is reported as 0.0.
    """
    def _ratio(num, den):
        # Undefined ratio (e.g. no predicted or no true change pixels) -> 0.0.
        return num / den if den else 0.0

    model.load_state_dict(parameters)
    model.eval()
    # plt.imsave does not create directories.
    os.makedirs(save_dir, exist_ok=True)
    with torch.no_grad():
        batch_accuracies = []
        n_correct = 0
        n_pixels = 0
        tp = 0
        fp = 0
        fn = 0

        for (imgt1, imgt2, gt, filename, _, seg1, num1, max_num1, seg2, num2, max_num2) in tqdm(data_loader):
            imgt1, imgt2, gt = imgt1.cuda(), imgt2.cuda(), gt.cuda()
            seg1, num1, max_num1 = seg1.cuda(), num1.cuda(), max_num1.cuda()
            seg2, num2, max_num2 = seg2.cuda(), num2.cuda(), max_num2.cuda()

            prediction = model(imgt1, imgt2, seg1, num1, max_num1, seg2, num2, max_num2)
            # Per-pixel predicted class = argmax over the class axis.
            prediction = torch.max(prediction, dim=1)[1]

            n_correct += torch.sum(prediction == gt)
            n_pixels += torch.sum(gt > -1)

            for sample_idx in range(imgt1.shape[0]):
                pr_name = filename[sample_idx].split('/')[-1].split('.')[0] + '.png'
                plt.imsave(os.path.join(save_dir, pr_name), prediction[sample_idx].detach().cpu(), cmap='gray')

            batch_accuracies.append(float(OA(prediction.view(-1), gt.view(-1))))

            # Positive-class confusion counts; assumes 0/1 predictions and labels.
            tp += int(torch.sum(prediction * (prediction == gt)))
            fp += int(torch.sum(prediction * (prediction != gt)))
            fn += int(torch.sum(gt * (prediction != gt)))

        average_accuracy = float(np.mean(batch_accuracies)) if batch_accuracies else 0.0
        overall_accuracy = float(n_correct / n_pixels) if n_pixels else 0.0
        iou = _ratio(tp, tp + fp + fn)
        precision = _ratio(tp, tp + fp)
        recall = _ratio(tp, tp + fn)
        f1 = _ratio(2 * precision * recall, precision + recall)

    print('AA: \t\t', round(average_accuracy, 4))
    print('OA:\t\t', round(overall_accuracy, 4))
    print('Iou:\t\t', round(iou, 4))
    print('Percison:\t', round(precision, 4))
    print('Recall:\t\t', round(recall, 4))
    print('F1\t\t', round(f1, 4))

In [69]:
# Rebuild the test split for the final evaluation with a fixed batch size of 16
# (no augmentation, deterministic order).
test_data = d.Dataset(TEST_DATA_PATH, TEST_LABEL_PATH, TEST_TXT_PATH, 'test',
                      transform=False, seg_num=seg_num, ratio=2)
test_loader = DataLoader(test_data, shuffle=False, batch_size=16,
                         pin_memory=True, num_workers=8)

In [70]:
# Evaluate the best checkpoint saved during training. Load_Weight_FordataParallel
# with need_dataparallel=1 adapts it to the DataParallel-wrapped `model`
# (presumably by restoring the `module.` key prefix — confirm in model.CDet_dl).
state_dict = Load_Weight_FordataParallel(torch.load(Log_path + 'Best_model.pth'),
                                         need_dataparallel=1)
eval_model(test_loader, model, state_dict)

100%|██████████| 128/128 [00:58<00:00,  2.18it/s]

AA: 		 0.9907
OA:		 0.9907
Iou:		 0.8315
Percison:	 0.9188
Recall:		 0.8974
F1		 0.908



