In [None]:
import os
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import random
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import warnings
import sys
import matplotlib as mpl 
warnings.filterwarnings("ignore")

from tqdm.notebook import tqdm
from tqdm import tqdm as tqdmm
from torch.utils.data import DataLoader
from torch.optim import Adam
from IPython.display import HTML, display,clear_output
from collections import Counter

import Utils.Dataload as d

# Select the CDnet2014 category / video to work on (uncomment exactly one pair).
# part = 'shadow/'
# subpart = 'copyMachine/'

part = 'baseline/'
subpart = 'office/'

# part = 'cameraJitter/'
# subpart = 'badminton/'

# part = 'badWeather/'
# subpart = 'blizzard/'

# Per-video output directory for logs, checkpoints and saved figures.
Dataset_store = 'Run_logging/cdnet/'
Dataset_name = Dataset_store+part+subpart
# makedirs creates every missing level (including Dataset_store itself, which a
# chain of os.mkdir calls cannot) and is a no-op if the directory already exists.
os.makedirs(Dataset_name, exist_ok=True)

# Dataset root; override with the CDNET_ROOT environment variable instead of
# editing the notebook. The default is the original hard-coded location.
CDNET_ROOT = os.environ.get('CDNET_ROOT', '/home/amax/yyq/Dataset/CDnet2014/change_dataset/')
DATA_PATH = os.path.join(CDNET_ROOT, part, subpart)
t_label_path = DATA_PATH+'/total/t_label/'

# Reorder the visible GPUs: physical GPU 1 becomes cuda:0 and GPU 0 becomes cuda:1,
# so device1/device0 name the physical GPU indices. Must be set before CUDA is
# initialised to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,0'
device1 = torch.device("cuda:0")
device0 = torch.device("cuda:1")
device2 = torch.device("cpu")

In [None]:
# Images and labels of both splits live under the same DATA_PATH directory;
# the splits differ only in the list file (train.txt / test.txt) they read.
TRAIN_DATA_PATH = TRAIN_LABEL_PATH = os.path.join(DATA_PATH)
TEST_DATA_PATH = TEST_LABEL_PATH = os.path.join(DATA_PATH)
TRAIN_TXT_PATH = os.path.join(TRAIN_DATA_PATH, 'train.txt')
TEST_TXT_PATH = os.path.join(TEST_DATA_PATH, 'test.txt')

In [None]:
# Batch sizes; validation shares the test setting.
train_batch_size = 2
val_batch_size = test_batch_size = 2

# Worker / pinning options shared by both loaders.
_loader_opts = dict(num_workers=8, pin_memory=True)

# Training split: transform=True (presumably random augmentation inside
# Utils.Dataload — confirm there) and shuffled batches.
train_data = d.Dataset(TRAIN_DATA_PATH, TRAIN_LABEL_PATH, TRAIN_TXT_PATH,
                       'train', transform=True)
train_loader = DataLoader(train_data, batch_size=train_batch_size,
                          shuffle=True, **_loader_opts)

# Test split: no transform, fixed order.
test_data = d.Dataset(TEST_DATA_PATH, TEST_LABEL_PATH, TEST_TXT_PATH,
                      'test', transform=False)
test_loader = DataLoader(test_data, batch_size=test_batch_size,
                         shuffle=False, **_loader_opts)

In [None]:
def plot(x, **kwargs):
    """Display ``x`` as an image at 100 dpi with the axes hidden.

    Extra keyword arguments (``cmap``, ``vmin``, ...) are forwarded to ``imshow``.
    """
    fig, ax = plt.subplots(dpi=100)
    ax.axis('off')
    ax.imshow(x, **kwargs)
    plt.show()

In [None]:
def test_model(data_loader):
    """Evaluate the global ``model`` on ``data_loader``, log the metrics and save checkpoints.

    The change probability returned by ``model`` is thresholded at 0.5 and
    compared with ``label``. TP/FP/FN are accumulated over the whole loader.
    The reported metrics are:
    AA (mean of per-batch ``OA``), OA (over all pixels), IoU, precision, recall
    and F1 of the positive class, each rounded to 4 decimals.

    Side effects:
        * appends an epoch summary line to ``Log_path + 'Metric_recording.txt'``;
        * always saves ``Log_path + 'Last_model.pth'``;
        * when F1 >= the global ``best_acc``: updates ``best_acc``, saves
          ``Best_model.pth`` and appends the detailed metrics to the log.

    Metrics whose denominator is 0 (no predicted / no labelled change pixels)
    are reported as 0.0 instead of raising ZeroDivisionError. The model's
    train/eval mode is restored on return.

    Relies on notebook globals ``model``, ``epoch``, ``Log_path``, ``best_acc``,
    ``OA`` and ``Load_Weight_FordataParallel``.
    """
    global best_acc

    def _ratio(num, den):
        # Undefined metric (zero denominator) -> 0.0.
        return num / den if den else 0.0

    was_training = model.training
    model.eval()
    accuracies = []
    right = total = 0
    TP = FP = FN = 0
    try:
        with torch.no_grad():
            for i1, i2, label, file_name, mask in tqdm(data_loader):
                i1, i2, label = i1.cuda(), i2.cuda(), label.cuda()

                prediction, c1, c2 = model(i1, i2)
                prediction = (prediction > 0.5).int()

                right += int(torch.sum(prediction == label))
                total += int(torch.sum(label > -1))
                accuracies.append(float(OA(prediction.view(-1), label.view(-1))))

                # Positive-class confusion counts.
                TP += int(torch.sum(prediction * (prediction == label)))
                FP += int(torch.sum(prediction * (prediction != label)))
                FN += int(torch.sum(label * (prediction != label)))
    finally:
        model.train(was_training)

    precision = _ratio(TP, TP + FP)
    recall = _ratio(TP, TP + FN)
    Average_accuracy = round(float(np.mean(accuracies)) if accuracies else 0.0, 4)
    Overrall_accuracy = round(_ratio(right, total), 4)
    Iou = round(_ratio(TP, TP + FP + FN), 4)
    # F1 is computed from the unrounded precision / recall.
    F1 = round(_ratio(2 * precision * recall, precision + recall), 4)
    Percison = round(precision, 4)
    Recall = round(recall, 4)

    print('AA: \t\t', Average_accuracy)
    print('OA:\t\t', Overrall_accuracy)
    print('Iou:\t\t', Iou)
    print('Percison:\t', Percison)
    print('Recall:\t\t', Recall)
    print('F1\t\t', F1)

    # Update the best score first so the logged "Bestoa" includes this epoch.
    is_best = F1 >= best_acc
    if is_best:
        best_acc = F1

    # Checkpoints are saved without DataParallel's "module." key prefix.
    state_dict = Load_Weight_FordataParallel(model.state_dict(), need_dataparallel=0)

    with open(Log_path + 'Metric_recording.txt', 'a') as f:
        f.write('Epoch:' + str(epoch) + '   Currentoa:' + str(round(F1, 4))
                + '   Bestoa:' + str(round(best_acc, 4)) + '\n')
    torch.save(state_dict, Log_path + 'Last_model.pth')

    if is_best:
        torch.save(state_dict, Log_path + 'Best_model.pth')
        with open(Log_path + 'Metric_recording.txt', 'a') as f:
            f.write('The details of performance: \n')
            f.write('AA: \t\t' + str(Average_accuracy) + '\n' + 'OA:\t\t' + str(Overrall_accuracy)
                    + '\n' + 'Iou:\t\t' + str(Iou) + '\n' + 'Percison:\t' + str(Percison)
                    + '\n' + 'Recall:\t\t' + str(Recall) + '\n' + 'F1\t\t' + str(F1) + '\n\n')
    
def eval_model(data_loader=None):
    """Print binary change-detection metrics of the global ``model`` (no logging or saving).

    The change probability is thresholded at 0.5 and compared with ``label``.
    The following metrics are printed, each rounded to 4 decimals:
    AA (mean per-batch ``OA``), OA (over all pixels), IoU, precision, recall and F1.

    Args:
        data_loader: loader to evaluate; defaults to the global ``test_loader``.

    Metrics with a zero denominator are reported as 0.0 instead of raising
    ZeroDivisionError. The model's train/eval mode is restored on return.
    """
    if data_loader is None:
        data_loader = test_loader

    def _ratio(num, den):
        # Undefined metric (zero denominator) -> 0.0.
        return num / den if den else 0.0

    was_training = model.training
    model.eval()
    accuracies = []
    right = total = 0
    TP = FP = FN = 0
    try:
        with torch.no_grad():
            for i1, i2, label, file_name, mask in tqdm(data_loader):
                i1, i2, label = i1.cuda(), i2.cuda(), label.cuda()

                prediction, c1, c2 = model(i1, i2)
                prediction = (prediction > 0.5).int()

                right += int(torch.sum(prediction == label))
                total += int(torch.sum(label > -1))
                accuracies.append(float(OA(prediction.view(-1), label.view(-1))))

                # Positive-class confusion counts.
                TP += int(torch.sum(prediction * (prediction == label)))
                FP += int(torch.sum(prediction * (prediction != label)))
                FN += int(torch.sum(label * (prediction != label)))
    finally:
        model.train(was_training)

    precision = _ratio(TP, TP + FP)
    recall = _ratio(TP, TP + FN)
    Average_accuracy = round(float(np.mean(accuracies)) if accuracies else 0.0, 4)
    Overrall_accuracy = round(_ratio(right, total), 4)
    Iou = round(_ratio(TP, TP + FP + FN), 4)
    F1 = round(_ratio(2 * precision * recall, precision + recall), 4)
    Percison = round(precision, 4)
    Recall = round(recall, 4)

    print('AA: \t\t', Average_accuracy)
    print('OA:\t\t', Overrall_accuracy)
    print('Iou:\t\t', Iou)
    print('Percison:\t', Percison)
    print('Recall:\t\t', Recall)
    print('F1\t\t', F1)
    
def trend_test():
    """Score per-pixel change-trend predictions against the stored trend labels.

    Builds a batch-size-1 loader over the test split. For every image pair, the
    argmax class maps of ``c1`` / ``c2`` are turned into a trend map:
    1 = appear (class 0 at t1, non-zero at t2), 2 = disappear (non-zero at t1,
    class 0 at t2), 3 = transform (non-zero at both), 0 otherwise.
    Each trend map is compared with ``<t_label_path>/<image stem>.npy``.
    For each trend value 1..3, TP/FP/FN and pixel accuracy are accumulated.

    Per-trend OA, IoU, precision, recall and F1 lists (index 0..2 for trends
    1..3) are appended to ``Log_path + 'Final_performance.txt'`` and printed.

    Metrics whose denominator is 0 (a trend never predicted / never present)
    are reported as 0.0. The model's train/eval mode is restored on return,
    also when an error occurs — the training loop calls this every epoch and
    must not continue training with BatchNorm frozen in eval mode.
    """
    def _ratio(num, den):
        # Undefined metric (zero denominator) -> 0.0.
        return num / den if den else 0.0

    test_data = d.Dataset(TEST_DATA_PATH, TEST_LABEL_PATH,
                          TEST_TXT_PATH, 'test', transform=False)
    # batch_size=1: file_name[0] / c1_[0] identify the single sample of each batch.
    test_loader = DataLoader(test_data, batch_size=1,
                             shuffle=False, num_workers=8, pin_memory=True)

    trend_values = (1, 2, 3)
    TPs = [0, 0, 0]
    FPs = [0, 0, 0]
    FNs = [0, 0, 0]
    Ts = [0, 0, 0]
    Sums = [0, 0, 0]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i1, i2, label, file_name, mask in tqdm(test_loader):
                i1, i2 = i1.cuda(), i2.cuda()
                p, c1, c2 = model(i1, i2)
                c1_, c2_ = torch.max(c1, dim=1)[1], torch.max(c2, dim=1)[1]
                t1_semantic, t2_semantic = c1_[0], c2_[0]

                # (t2 - t1) == t2 holds exactly where t1 == 0 (and symmetrically).
                appear = ((t2_semantic-t1_semantic)==t2_semantic)*(t2_semantic!=0).int()
                disappear = ((t1_semantic-t2_semantic)==t1_semantic)*(t1_semantic!=0).int()
                transform = ((t2_semantic!=0).int())*((t1_semantic!=0).int())
                trend_map = (appear+disappear*2+transform*3).cpu()

                trend_name = file_name[0].split('/')[-1].split('.')[0]+'.npy'
                label_trend = torch.tensor(np.load(os.path.join(t_label_path, trend_name)))

                for i, value in enumerate(trend_values):
                    trend_map_i = (trend_map == value).int()
                    label_trend_i = (label_trend == value).int()
                    TPs[i] += int(torch.sum(trend_map_i * (trend_map_i == label_trend_i)))
                    FPs[i] += int(torch.sum(trend_map_i * (trend_map_i != label_trend_i)))
                    FNs[i] += int(torch.sum(label_trend_i * (trend_map_i != label_trend_i)))
                    Ts[i] += int(torch.sum(trend_map_i == label_trend_i))
                    Sums[i] += trend_map_i.numel()
    finally:
        model.train(was_training)

    OAs, IoUs, Ps, Rs, F1s = [], [], [], [], []
    for TP, FP, FN, T, total in zip(TPs, FPs, FNs, Ts, Sums):
        precision = _ratio(TP, TP + FP)
        recall = _ratio(TP, TP + FN)
        OAs.append(_ratio(T, total))
        IoUs.append(_ratio(TP, TP + FP + FN))
        Ps.append(precision)
        Rs.append(recall)
        F1s.append(_ratio(2 * precision * recall, precision + recall))

    with open(Log_path + 'Final_performance.txt', 'a') as f:
        f.write('OA:'+str(OAs)+'\nIoU:'+str(IoUs)+'\nPrecision:'+str(Ps)+'\nRecall:'+str(Rs)+'\nF1:'+str(F1s)+'\n')

    print(Ps)
    print(Rs)
    print(F1s)
    print(IoUs)
    print(OAs)


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch

def OA(pre_classes, gt_classes):
    """Overall accuracy: fraction of elements where prediction equals ground truth.

    Args:
        pre_classes: predicted labels (tensor of any shape).
        gt_classes: ground-truth labels with the same shape.

    Returns:
        0-dim float tensor in [0, 1]. The denominator is the total element count
        (``numel``), so N-D inputs are handled correctly, not just flattened ones.
        Empty input yields NaN (0 / 0).
    """
    return torch.sum(pre_classes == gt_classes).float() / pre_classes.numel()

def T_softmax(x, dim=1, T=0.1):
    """Temperature-scaled softmax: ``softmax(x / T)`` along ``dim``.

    Smaller ``T`` sharpens the distribution towards the arg-max. Normalisation is
    done along the requested ``dim`` (including negative dims), and
    ``torch.softmax`` is numerically stable, so large inputs do not overflow
    ``exp`` into NaN.
    """
    return torch.softmax(x / T, dim=dim)

from collections import OrderedDict
def Load_Weight_FordataParallel(state_dict, need_dataparallel=0):
    """Add or strip the ``module.`` key prefix used by ``nn.DataParallel`` state dicts.

    Args:
        state_dict: mapping from parameter/buffer names to values.
        need_dataparallel: 1 -> return keys carrying the ``module.`` prefix
            (for a DataParallel-wrapped model); any other value -> return keys
            without it (for a bare model).

    Returns:
        ``state_dict`` itself when it is already in the requested form, otherwise
        a new ``OrderedDict`` with the same values and rewritten keys.
    """
    prefix = "module."
    # DataParallel-style only if *every* key has the full "module." prefix (an
    # empty dict counts as such). Matching "module." rather than the bare
    # 6-character "module" keeps keys like "module_list.0.weight" intact.
    is_dataparallel = all(k.startswith(prefix) for k in state_dict)

    if need_dataparallel == 1:
        if is_dataparallel:
            return state_dict
        return OrderedDict((prefix + k, v) for k, v in state_dict.items())

    if not is_dataparallel:
        return state_dict
    return OrderedDict((k[len(prefix):], v) for k, v in state_dict.items())

class conv_block(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; spatial size is preserved.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels produced by both convolutions.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order fix the state_dict keys (conv.0.weight, ...).
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class up_conv(nn.Module):
    """Bilinear upsampling followed by 3x3 conv -> BatchNorm -> ReLU.

    ``forward(x, size=None)`` resizes ``x`` to ``size`` (H, W) when given,
    otherwise upsamples by a factor of 2, then applies the conv stage.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(out_ch)
        self.up = nn.Sequential(conv, norm, nn.ReLU(inplace=True))

    def forward(self, x, size=None):
        if size is None:
            resize = nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            resize = nn.Upsample(size=size, mode='bilinear')
        return self.up(resize(x))


class UNet_Encoder(nn.Module):
    """Five-level UNet contracting path.

    Paper : https://arxiv.org/abs/1505.04597

    Channel widths are nl, 2nl, 4nl, 8nl, 16nl. Every level after the first
    halves H and W with 2x2 max pooling. ``forward`` returns a dict with the
    feature map of each level under 'e1' (full resolution) .. 'e5' (coarsest).
    """
    def __init__(self, in_ch=3, nl=32):
        super().__init__()
        widths = [nl * 2 ** level for level in range(5)]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(in_ch, widths[0])
        self.Conv2 = conv_block(widths[0], widths[1])
        self.Conv3 = conv_block(widths[1], widths[2])
        self.Conv4 = conv_block(widths[2], widths[3])
        self.Conv5 = conv_block(widths[3], widths[4])

    def forward(self, x):
        feats = {'e1': self.Conv1(x)}
        stages = (
            ('e2', self.Maxpool1, self.Conv2),
            ('e3', self.Maxpool2, self.Conv3),
            ('e4', self.Maxpool3, self.Conv4),
            ('e5', self.Maxpool4, self.Conv5),
        )
        current = feats['e1']
        for key, pool, conv in stages:
            current = conv(pool(current))
            feats[key] = current
        return feats

class UNet_Decoder_I(nn.Module):
    """UNet expanding path over a single image's encoder features.

    Args:
        out_ch: unused; kept for signature compatibility.
        nl: base width; must match the encoder's ``nl`` because the skip
            connections are concatenated channel-wise.

    ``forward(e)`` takes the encoder dict ('e1'..'e5') and returns a map with
    ``nl`` channels at the resolution of 'e1'.
    """
    def __init__(self, out_ch=2, nl = 64):
        super().__init__()
        f = [nl * 2 ** level for level in range(5)]

        self.Up5 = up_conv(f[4], f[3])
        self.Up_conv5 = conv_block(f[4], f[3])

        self.Up4 = up_conv(f[3], f[2])
        self.Up_conv4 = conv_block(f[3], f[2])

        self.Up3 = up_conv(f[2], f[1])
        self.Up_conv3 = conv_block(f[2], f[1])

        self.Up2 = up_conv(f[1], f[0])
        self.Up_conv2 = conv_block(f[1], f[0])

    def forward(self, e):
        d = e['e5']
        for up, fuse, skip_key in ((self.Up5, self.Up_conv5, 'e4'),
                                   (self.Up4, self.Up_conv4, 'e3'),
                                   (self.Up3, self.Up_conv3, 'e2'),
                                   (self.Up2, self.Up_conv2, 'e1')):
            skip = e[skip_key]
            # Upsample to the skip's exact H, W, concatenate (skip first), then fuse.
            d = fuse(torch.cat((skip, up(d, size=skip.shape[2:])), dim=1))
        return d


class UNet_Decoder_S(nn.Module):
    """UNet expanding path over the channel-wise concatenation of two images' encoder features.

    Args:
        out_ch: unused; kept for signature compatibility.
        nl: base width; must be twice the encoder's ``nl``, since every level
            concatenates the t1 and t2 features.

    ``forward(t1_e, t2_e)`` takes two encoder dicts ('e1'..'e5') and returns a
    map with ``nl`` channels at the resolution of 'e1'.
    """
    def __init__(self, out_ch=2, nl = 128):
        super().__init__()
        f = [nl * 2 ** level for level in range(5)]

        self.Up5 = up_conv(f[4], f[3])
        self.Up_conv5 = conv_block(f[4], f[3])

        self.Up4 = up_conv(f[3], f[2])
        self.Up_conv4 = conv_block(f[3], f[2])

        self.Up3 = up_conv(f[2], f[1])
        self.Up_conv3 = conv_block(f[2], f[1])

        self.Up2 = up_conv(f[1], f[0])
        self.Up_conv2 = conv_block(f[1], f[0])

    def forward(self, t1_e, t2_e):
        # Fuse the two time steps level by level (t1 channels first).
        e = {key: torch.cat((t1_e[key], t2_e[key]), dim=1)
             for key in ('e5', 'e4', 'e3', 'e2', 'e1')}

        d = e['e5']
        for up, fuse, skip_key in ((self.Up5, self.Up_conv5, 'e4'),
                                   (self.Up4, self.Up_conv4, 'e3'),
                                   (self.Up3, self.Up_conv3, 'e2'),
                                   (self.Up2, self.Up_conv2, 'e1')):
            skip = e[skip_key]
            d = fuse(torch.cat((skip, up(d, size=skip.shape[2:])), dim=1))
        return d

class UNet(nn.Module):
    """Siamese UNet producing a change probability and per-image trend distributions.

    A shared encoder processes both images. ``decoder_i`` decodes each image
    separately, and ``decoder_s`` decodes the concatenated features of both.

    ``forward(t1_i, t2_i)`` returns ``(p, t1_trend, t2_trend)``:
        * ``p``: 1 - sum over classes of the product of the two joint trend
          distributions (from ``conv_s1`` / ``conv_s2``). Low where both images
          agree on the trend class.
        * ``t1_trend`` / ``t2_trend``: per-image trend distributions over
          ``trend_num`` classes (dim 1), from ``conv_i1`` / ``conv_i2``.

    All distributions use a temperature softmax with temperature ``T``.
    ``out_ch`` is passed to the decoders, which do not use it.
    """
    def __init__(self, in_ch=3, out_ch=2, nl=64, trend_num = 3, T=0.1):
        super().__init__()
        self.encoder = UNet_Encoder(in_ch, nl)
        self.decoder_s = UNet_Decoder_S(out_ch, nl=nl*2)
        self.decoder_i = UNet_Decoder_I(out_ch, nl=nl)
        # Joint heads see [shared decoder (2*nl ch), single-image decoder (nl ch)].
        self.conv_s1 = nn.Conv2d(nl+nl*2, trend_num, kernel_size=1)
        self.conv_s2 = nn.Conv2d(nl+nl*2, trend_num, kernel_size=1)
        self.conv_i1 = nn.Conv2d(nl, trend_num, kernel_size=1)
        self.conv_i2 = nn.Conv2d(nl, trend_num, kernel_size=1)
        self.T = T

    def Normalization(self, x,dim=1):
        """Min-max scale ``x`` to [0, 1] along ``dim``.

        Divides by zero (inf/NaN) wherever max == min along ``dim``.
        Not used by ``forward``.
        """
        lo = x.min(dim=dim, keepdim=True)[0]
        hi = x.max(dim=dim, keepdim=True)[0]
        return (x - lo) / (hi - lo)

    def _trend_probs(self, logits):
        # Shifting by the per-pixel max leaves the softmax unchanged but keeps
        # the exponentials bounded at small temperatures.
        shifted = logits - torch.max(logits, dim=1)[0].unsqueeze(1)
        return T_softmax(shifted, dim=1, T=self.T)

    def forward(self, t1_i,t2_i):
        t1_f = self.encoder(t1_i)
        t2_f = self.encoder(t2_i)

        t1_s = self.decoder_i(t1_f)
        t2_s = self.decoder_i(t2_f)
        s = self.decoder_s(t1_f, t2_f)

        t1_joint = self._trend_probs(self.conv_s1(torch.cat((s, t1_s), dim=1)))
        t2_joint = self._trend_probs(self.conv_s2(torch.cat((s, t2_s), dim=1)))
        # Change probability = 1 - agreement of the two joint trend distributions.
        p = 1 - torch.sum(t1_joint * t2_joint, dim=1)

        t1_trend = self._trend_probs(self.conv_i1(t1_s))
        t2_trend = self._trend_probs(self.conv_i2(t2_s))

        return p, t1_trend, t2_trend

In [None]:
# Training hyper-parameters and run state.
epoch_num = 99
trend_num = 3            # trend classes per image (channels of the model's c1 / c2 outputs)
best_acc = 0             # best test F1 so far; compared against and updated by test_model()
loss_mean = []
pretrained_path = None   # optional checkpoint to warm-start from
Log_path = Dataset_name  # logs, checkpoints and figures are written here

# Seed every RNG the notebook uses so runs are repeatable (weight init, DataLoader
# shuffling, and presumably the dataset's random transforms). cuDNN kernels may
# still introduce some non-determinism.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
model = UNet(3,2,trend_num=trend_num)
model = nn.DataParallel(model)
model = model.cuda()

if pretrained_path is not None:
    # Load on the CPU: load_state_dict copies each tensor onto the device of the
    # matching parameter. This also works when the checkpoint was saved on a GPU
    # that is not visible here.
    state_dict = Load_Weight_FordataParallel(
        torch.load(pretrained_path, map_location='cpu'), need_dataparallel=1)
    # Partial warm start: take a pretrained tensor only when both its name and
    # its shape match the current model; everything else keeps its fresh init
    # (a shape mismatch would otherwise make load_state_dict raise).
    pretrain_state_dict = {
        k: (state_dict[k] if k in state_dict and state_dict[k].shape == v.shape else v)
        for k, v in model.state_dict().items()
    }
    model.load_state_dict(pretrain_state_dict)

opt = torch.optim.Adam(model.parameters(),lr=1e-4)

# The scheduler is stepped once per training batch, so the learning rate is
# multiplied by 0.1 every LR_DECAY_EPOCHS epochs.
LR_DECAY_EPOCHS = 80
STEPS_PER_EPOCH = len(train_loader)
TOTAL_STEPS = STEPS_PER_EPOCH * LR_DECAY_EPOCHS
scheduler = lr_scheduler.StepLR(opt, step_size=TOTAL_STEPS, gamma=0.1)

Loss_function_classify = nn.CrossEntropyLoss()

In [None]:
def loss_p_n(p, label, weight=None):
    """Binary cross-entropy with separately averaged positive and negative terms.

    ``p`` (probabilities) and ``label`` are averaged over dims 1 and 2 per
    sample, e.g. (batch, H, W). Without ``weight`` the per-sample loss is
    pos + neg. With ``weight`` it is pos * (1 - weight) + neg * weight.
    Returns the mean over the remaining (batch) dimension.
    """
    eps = 1e-10
    pos_term = (-label * torch.log(p + eps)).mean(dim=(1, 2))
    neg_term = (-(1 - label) * torch.log(1 - p + eps)).mean(dim=(1, 2))
    if weight is None:
        per_sample = pos_term + neg_term
    else:
        per_sample = pos_term * (1 - weight) + neg_term * weight
    return per_sample.mean()

def loss_p(p, label, weight=None):
    """Positive-only cross-entropy term: mean of -label * log(p) over dims 1 and 2.

    The optional ``weight`` scales each sample's loss. Returns the mean over the
    batch dimension.
    """
    per_sample = (-label * torch.log(p + 1e-10)).mean(dim=(1, 2))
    if weight is not None:
        per_sample = per_sample * weight
    return per_sample.mean()

def sigmoid_coe(x, T=0.01):
    """Steep logistic step centred at 0.5: 1 / (1 + exp((0.5 - x) / T)).

    Small ``T`` makes it approach a hard threshold at x = 0.5.
    """
    return torch.sigmoid((x - 0.5) / T)

In [None]:
loss_sum = []

for epoch in range(epoch_num):
    # trend_test() at the end of the previous epoch switches the model to eval
    # mode; switch back so BatchNorm uses batch statistics and keeps updating
    # its running statistics while training.
    model.train()
    loss_mean = []
    loss_mean_1 = []
    loss_mean_2 = []
    for iter_, (i1,i2,label,file_name,mask) in enumerate(tqdm(train_loader)):
        i1,i2,label = i1.cuda(),i2.cuda(),label.cuda()
        p,c1,c2 = model(i1,i2)
        # Probability of trend class 0, the class trend_test() treats as "absent".
        nc1, nc2 = c1[:,0],c2[:,0]

        # Change loss: predicted change probability vs. binary change label.
        loss1 = loss_p_n(p,label)

        # Trend loss: unchanged pixels (1 - label) should be class 0 in both
        # images, and the disagreement of the two trend distributions should
        # match the change label.
        loss2_n = loss_p(nc1*nc2, (1-label))
        loss2_g = loss_p_n(1-torch.sum(c1*c2,dim=1),label)
        loss2 = loss2_n+loss2_g

        loss = (loss1+loss2)*10

        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()

        loss_mean.append(loss.item())
        loss_mean_1.append(loss1.item())
        loss_mean_2.append(loss2.item())

        if (iter_+1)%10 ==0:
            print('******************')
            print()
            print('Total loss: ',round(np.mean(loss_mean),5))
            print('Change loss: ',round(np.mean(loss_mean_1),5))
            print('Trend loss: ',round(np.mean(loss_mean_2),5))
            print('******************')
            loss_sum.append(np.mean(loss_mean))
            loss_mean = []
            loss_mean_1 = []
            loss_mean_2 = []
    try:
        trend_test()
    except Exception as exc:
        # Trend evaluation is best-effort (e.g. a missing label file must not
        # abort a long run), but report the failure instead of hiding it.
        # KeyboardInterrupt is no longer swallowed.
        print('trend_test failed after epoch', epoch, ':', repr(exc))

In [None]:
def plot_save(x,**kwargs):
    """Write array ``x`` directly to an image file with ``plt.imsave``.

    Keyword arguments are forwarded to ``plt.imsave`` (e.g. ``fname``, ``cmap``,
    ``vmin``, ``vmax``). No figure is created: ``imsave`` renders the array
    itself. Opening a figure per call without closing it would accumulate
    figures (and memory) when this is called in a loop.
    """
    plt.imsave(arr=x, **kwargs)

In [None]:
# Colours for trend values 0..3 (none / appear / disappear / transform).
colors = ['black', 'deepskyblue', 'white','red']
cmap = mpl.colors.ListedColormap(colors)

test_batch_size=1

test_data = d.Dataset(TEST_DATA_PATH, TEST_LABEL_PATH,
                      TEST_TXT_PATH,'test', transform=False)
test_loader = DataLoader(test_data, batch_size=test_batch_size,
                         shuffle=False, num_workers=8, pin_memory=True)

# Un-transformed training split for optional inspection (swap it in for
# test_loader below). It is stored under its own name so the training cell's
# `train_loader` (batch size 2, transform=True) is not silently replaced.
vis_train_data = d.Dataset(TRAIN_DATA_PATH, TRAIN_LABEL_PATH,
                           TRAIN_TXT_PATH,'train',transform=False)
vis_train_loader = DataLoader(vis_train_data, batch_size=test_batch_size,
                              shuffle=True, num_workers=8, pin_memory=True)

# Do not rely on whichever cell last changed the model's mode: inference must
# use BatchNorm running statistics, especially with batch size 1.
model.eval()
index = 0  # batch size is 1, so each batch holds a single sample

for iter_, (i1,i2,label,file_name,mask) in enumerate(tqdm(test_loader)):
    i1,i2,label = i1.cuda(),i2.cuda(),label.cuda()
    with torch.no_grad():
        p,c1,c2 = model(i1,i2)
        c1_,c2_ = torch.max(c1,dim=1)[1],torch.max(c2,dim=1)[1]

    save_path = Log_path+'save_pic/'+str(iter_)+'/'
    os.makedirs(save_path, exist_ok=True)

    i_name = file_name[index].split('/')[-1]
    print(i_name)

    t1_semantic, t2_semantic = c1_[index], c2_[index]

    # Same encoding as trend_test(): 1 appear, 2 disappear, 3 transform, 0 none.
    appear = ((t2_semantic-t1_semantic)==t2_semantic)*(t2_semantic!=0).int()
    disappear = ((t1_semantic-t2_semantic)==t1_semantic)*(t1_semantic!=0).int()
    transform = ((t2_semantic!=0).int())*((t1_semantic!=0).int())
    trend_map = (appear+disappear*2+transform*3).cpu().numpy()

    print('predicted_trend_map')
    plot_save(trend_map,vmin=0,vmax=trend_num,cmap=cmap,fname=save_path+'Trend_P.png')

    trend_name = file_name[index].split('/')[-1].split('.')[0]+'.npy'
    label_trend = np.load(os.path.join(t_label_path,trend_name))

    print('label_trend_map')
    plot_save(label_trend,vmin=0,vmax=trend_num,cmap=cmap,fname=save_path+'Trend_GT.png')

    plot_save(i1[index].permute(1,2,0).cpu().numpy()+0.5,fname=save_path+'T1.png')
    plot_save(i2[index].permute(1,2,0).cpu().numpy()+0.5,fname=save_path+'T2.png')
    plot_save((p[index].detach().cpu().numpy()>0.5).astype(int),fname=save_path+'P.png')

    print("label")
    plot_save(label[index].detach().cpu().numpy(),cmap='gray',fname=save_path+'GT.png')

    print('**********')
    plot_save(torch.sum(c1[index,1:],dim=0).detach().cpu().numpy(),fname=save_path+'Trend_T1.png')
    plot_save(torch.sum(c2[index,1:],dim=0).detach().cpu().numpy(),fname=save_path+'Trend_T2.png')

    # Release any figures left open by plotting helpers so memory does not grow
    # with the number of test images.
    plt.close('all')