# Import libraries

In [1]:
%load_ext autoreload
%autoreload 2

# speed up the loading of the training data
import cv2
import os
import numpy as np
import torch as th
import itertools
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from onlineTripletloss import *
from selector import *
from model_backbone.model_SAGAN1_1 import NetG, NetD, NetA
# from model_SAGAN2_Triplet import NetG, NetD, NetA
# from model_WGANGP import NetG, NetD, NetA
# from model_WGAN import NetG, NetD, NetA
# from model_siGAN import NetG, NetD, NetA
# from dataset2Loader import CASIABDataset
# from dataset2Loader_newtriplet import CASIABDataset
from dataset2Loader_triplet import CASIABDataset
import torch.optim as optim
import visdom
from torchvision.utils import make_grid

Data_Dir = '../GaitRecognition/GEI_CASIA_B/gei/'
Model_Name = 'Model_64x64_TripletSAGAN_90_trial10'
Model_dir = './Transform_Model/' + Model_Name
# Checkpoints and snapshot_log.txt are written into Model_dir.
# makedirs also creates the parent ./Transform_Model/ directory, which
# os.mkdir would fail on if it were missing. exist_ok keeps a re-run of this
# cell harmless.
os.makedirs(Model_dir, exist_ok=True)

## Epoch

In [None]:
# Quick sanity check of the data pipeline (dataset built without a `target`).
dataset = CASIABDataset(data_dir=Data_Dir)
train_loader = th.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2, pin_memory=False)
# A DataLoader is iterable but not an iterator: it has no .next() method.
# Fetch one batch through an explicit iterator instead.
sample_batch = next(iter(train_loader))

In [None]:
# Requires a running visdom server:  python -m visdom.server
vis = visdom.Visdom(port=8097)
win = win1 = None  # visdom window handles, created on the first plot
netg, netd, neta = NetG(nc=1), NetD(nc=1), NetA(nc=1)
device = th.device("cuda:1")

# weights init: (transposed) convolutions get N(0, 0.02) weights; batch norm
# gets N(1, 0.02) scale and zero shift. Only the direct children of each
# network's first sub-module are visited (assumes that sub-module holds the
# layers — TODO confirm for nested containers).
# The previous version chained a *list of iterators*, so the loop saw three
# generator objects instead of layers and initialised nothing.
# NOTE(review): if these layers are wrapped with spectral norm, `mod.weight`
# may be a derived tensor rather than the parameter; verify the init sticks.
all_mods = itertools.chain(
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children(),
)
for mod in all_mods:
    if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)

# Hyper-parameters for this run (written to snapshot_log.txt below).
epoches = 700                       # training epochs
glr, dlr = 0.00002, 0.00002         # base learning rates: G, and D/A
real_label, fake_label = 1, 0       # BCE targets
batchSize = 32
target = '090'                      # view passed to CASIABDataset(target=...)
lambda_gp, beta1, beta2 = 0, 0, 0   # logged only; not passed to this cell's optimisers
margin = 0                          # logged only; no margin-based loss in this cell
n_g, n_d = 0, 0                     # batch gates: n_g -> D/A update, n_d -> G update


# Module.to() moves parameters in place (and returns self), so no rebinding is needed.
for net in (netg, netd, neta):
    net.to(device)
    net.train()

dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# Fixed-size (batchSize x 1) BCE target buffer on the training device.
label = th.zeros((batchSize, 1), requires_grad=False).to(device)
# G trains at half its base rate; D and A at a third of theirs.
optimG = optim.Adam(netg.parameters(), lr=glr / 2)
optimD = optim.Adam(netd.parameters(), lr=dlr / 3)
optimA = optim.Adam(neta.parameters(), lr=dlr / 3)

print("write parameter log...")
params_line = 'Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={} \n'.format(
    epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp)
with open(Model_dir + "/snapshot_log.txt", "a") as log_file:
    log_file.write(params_line)

print('Training starts')
low_loss = 10

# Batch-index gates: n_g gates the D/A update and n_d gates the G update.
# A gate of 0 used to make `i % 0` raise ZeroDivisionError on the first
# batch. Any value <= 0 now means "update on every batch".
d_every = max(n_g, 1)
g_every = max(n_d, 1)


def _bce(output, value):
    """Binary cross-entropy of `output` against the constant target `value`.

    The target is built per call with the same shape as `output`. A final
    batch smaller than `batchSize` (the DataLoader does not drop it) then no
    longer clashes with a fixed-size target buffer.
    """
    return F.binary_cross_entropy(output, th.full_like(output, value))


def _snapshot():
    """Return the current state dicts of A, G and D for `th.save`."""
    return {
        'netA': neta.state_dict(),
        'netG': netg.state_dict(),
        'netD': netd.state_dict(),
    }


for epoch in range(1, epoches + 1):
    for i, (ass_label, noass_label, img) in enumerate(train_loader):
        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        img = img.to(device).to(th.float32)

        if i % d_every == 0:
            # update D: ass_label and noass_label images are "real". G(img)
            # is "fake" and is detached, so G gets no gradient here.
            lossD = 0
            optimD.zero_grad()
            lossD_real1 = _bce(netd(ass_label), real_label)
            lossD += lossD_real1.item()
            lossD_real1.backward()

            lossD_real2 = _bce(netd(noass_label), real_label)
            lossD += lossD_real2.item()
            lossD_real2.backward()

            fake = netg(img).detach()
            lossD_fake = _bce(netd(fake), fake_label)
            lossD += lossD_fake.item()
            lossD_fake.backward()

            optimD.step()

            # update A: the pair (img, ass_label) is "real". The pairs
            # (img, noass_label) and (img, G(img)) are "fake".
            lossA = 0
            optimA.zero_grad()
            fake = netg(img).detach()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            faked = th.cat((img, fake), 1)

            lossA_real1 = _bce(neta(assd), real_label)
            lossA += lossA_real1.item()
            lossA_real1.backward()

            lossA_real2 = _bce(neta(noassd), fake_label)
            lossA += lossA_real2.item()
            lossA_real2.backward()

            lossA_fake = _bce(neta(faked), fake_label)
            lossA += lossA_fake.item()
            lossA_fake.backward()

            optimA.step()

        if i % g_every == 0:
            # update G: make both D and A predict "real" for G(img).
            # Summing the terms and calling backward once gives the same
            # gradients as two backward calls, without needing retain_graph.
            optimG.zero_grad()
            fake = netg(img)
            lossGD = _bce(netd(fake), real_label)
            lossGA = _bce(neta(th.cat((img, fake), 1)), real_label)
            (lossGD + lossGA).backward()
            optimG.step()
            lossG = lossGD.item() + lossGA.item()

        if i % 20 == 0:
            # Preview grid: generated / ground-truth / input images joined on
            # dim 2 (height, assuming NCHW) and mapped by (x + 1) / 2 * 255.
            with th.no_grad():
                netg.eval()  # switch to inference mode for the preview
                fake = netg(img)
                netg.train()  # switch back to training mode
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # Per-term averages of the most recent gated batch's losses.
    errG, errA, errD = lossG / 2, lossA / 3, lossD / 3

    if epoch % 2 == 0:
        point = dict(X=np.array([[epoch, epoch, epoch]]),
                     Y=np.array([[errG, errA, errD]]))
        if win is None:
            win = vis.line(**point, opts=dict(
                title=Model_Name,
                ylabel='loss',
                xlabel='epochs',
                legend=['lossG', 'lossA', 'lossD'],
            ))
        else:
            vis.line(**point, win=win, update='append')

        with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, errG, errA, errD))
        print('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {}'.format(
            epoch, errG, errA, errD))

    # Periodic checkpoint late in training.
    if epoch >= 300 and epoch % 20 == 0:
        th.save(_snapshot(), Model_dir + '/snapshot' + Model_Name + '_%d.t7' % epoch)

    # Extra checkpoint whenever the reported G loss reaches a new low.
    if epoch >= 550 and errG < low_loss:
        low_loss = errG
        th.save(_snapshot(), Model_dir + '/lowest_snapshot' + Model_Name + '_%d.t7' % epoch)
        with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, errG, errA, errD))
        



## update Discriminator k times

In [None]:
# Requires a running visdom server:  python -m visdom.server
vis = visdom.Visdom(port=8097)
win = win1 = None  # visdom window handles, created on the first plot
netg, netd, neta = NetG(nc=1), NetD(nc=1), NetA(nc=1)
device = th.device("cuda:1")

# weights init: (transposed) convolutions get N(0, 0.02) weights; batch norm
# gets N(1, 0.02) scale and zero shift. Only the direct children of each
# network's first sub-module are visited (assumes that sub-module holds the
# layers — TODO confirm for nested containers).
# The previous version chained a *list of iterators*, so the loop saw three
# generator objects instead of layers and initialised nothing.
# NOTE(review): if these layers are wrapped with spectral norm, `mod.weight`
# may be a derived tensor rather than the parameter; verify the init sticks.
all_mods = itertools.chain(
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children(),
)
for mod in all_mods:
    if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)
        
# Hyper-parameters for this run (written to snapshot_log.txt below).
epoches = 700                       # training epochs
glr, dlr = 0.00002, 0.00002         # base learning rates: G, and D/A
real_label, fake_label = 1, 0       # BCE targets
batchSize = 32
target = '090'                      # view passed to CASIABDataset(target=...)
lambda_gp, beta1, beta2 = 0, 0, 0   # logged only; not passed to this cell's optimisers
margin = 0                          # logged only; no margin-based loss in this cell
n_g, n_d = 2, 0                     # batch gates: n_g -> D/A update, n_d -> G update


# Module.to() moves parameters in place (and returns self), so no rebinding is needed.
for net in (netg, netd, neta):
    net.to(device)
    net.train()

dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# Fixed-size (batchSize x 1) BCE target buffer on the training device.
label = th.zeros((batchSize, 1), requires_grad=False).to(device)
# G trains at half its base rate; D and A at a third of theirs.
optimG = optim.Adam(netg.parameters(), lr=glr / 2)
optimD = optim.Adam(netd.parameters(), lr=dlr / 3)
optimA = optim.Adam(neta.parameters(), lr=dlr / 3)

print("write parameter log...")
params_line = 'Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={} \n'.format(
    epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp)
with open(Model_dir + "/snapshot_log.txt", "a") as log_file:
    log_file.write(params_line)

low_loss = 10
print('Training starts')

# Batch-index gates: n_g gates the D/A update and n_d gates the G update.
# With n_d = 0 the original `i % n_d` raised ZeroDivisionError on the first
# batch. Any value <= 0 now means "update on every batch".
d_every = max(n_g, 1)
g_every = max(n_d, 1)


def _bce(output, value):
    """Binary cross-entropy of `output` against the constant target `value`.

    The target is built per call with the same shape as `output`. A final
    batch smaller than `batchSize` (the DataLoader does not drop it) then no
    longer clashes with a fixed-size target buffer.
    """
    return F.binary_cross_entropy(output, th.full_like(output, value))


def _snapshot():
    """Return the current state dicts of A, G and D for `th.save`."""
    return {
        'netA': neta.state_dict(),
        'netG': netg.state_dict(),
        'netD': netd.state_dict(),
    }


for epoch in range(1, epoches + 1):
    for i, (ass_label, noass_label, img) in enumerate(train_loader):
        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        img = img.to(device).to(th.float32)

        if i % d_every == 0:
            # update D: ass_label and noass_label images are "real". G(img)
            # is "fake" and is detached, so G gets no gradient here.
            lossD = 0
            optimD.zero_grad()
            lossD_real1 = _bce(netd(ass_label), real_label)
            lossD += lossD_real1.item()
            lossD_real1.backward()

            lossD_real2 = _bce(netd(noass_label), real_label)
            lossD += lossD_real2.item()
            lossD_real2.backward()

            fake = netg(img).detach()
            lossD_fake = _bce(netd(fake), fake_label)
            lossD += lossD_fake.item()
            lossD_fake.backward()

            optimD.step()

            # update A: the pair (img, ass_label) is "real". The pairs
            # (img, noass_label) and (img, G(img)) are "fake".
            lossA = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            fake = netg(img).detach()
            faked = th.cat((img, fake), 1)

            lossA_real1 = _bce(neta(assd), real_label)
            lossA += lossA_real1.item()
            lossA_real1.backward()

            lossA_real2 = _bce(neta(noassd), fake_label)
            lossA += lossA_real2.item()
            lossA_real2.backward()

            lossA_fake = _bce(neta(faked), fake_label)
            lossA += lossA_fake.item()
            lossA_fake.backward()

            optimA.step()

        if i % g_every == 0:
            # update G: make both D and A predict "real" for G(img).
            # Summing the terms and calling backward once gives the same
            # gradients as two backward calls, without needing retain_graph.
            optimG.zero_grad()
            fake = netg(img)
            lossGD = _bce(netd(fake), real_label)
            lossGA = _bce(neta(th.cat((img, fake), 1)), real_label)
            (lossGD + lossGA).backward()
            optimG.step()
            lossG = lossGD.item() + lossGA.item()

        if i % 20 == 0:
            # Preview grid: generated / ground-truth / input images joined on
            # dim 2 (height, assuming NCHW) and mapped by (x + 1) / 2 * 255.
            with th.no_grad():
                netg.eval()  # switch to inference mode for the preview
                fake = netg(img)
                netg.train()  # switch back to training mode
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # Per-term averages of the most recent gated batch's losses.
    errG, errA, errD = lossG / 2, lossA / 3, lossD / 3

    if epoch % 2 == 0:
        point = dict(X=np.array([[epoch, epoch, epoch]]),
                     Y=np.array([[errG, errA, errD]]))
        if win is None:
            win = vis.line(**point, opts=dict(
                title=Model_Name,
                ylabel='loss',
                xlabel='epochs',
                legend=['lossG', 'lossA', 'lossD'],
            ))
        else:
            vis.line(**point, win=win, update='append')

        with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, errG, errA, errD))
        print('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {}'.format(
            epoch, errG, errA, errD))

    # Periodic checkpoint late in training.
    if epoch >= 300 and epoch % 20 == 0:
        th.save(_snapshot(), Model_dir + '/snapshot' + Model_Name + '_%d.t7' % epoch)

    # Extra checkpoint whenever the reported G loss reaches a new low.
    if epoch >= 550 and errG < low_loss:
        low_loss = errG
        th.save(_snapshot(), Model_dir + '/lowest_snapshot' + Model_Name + '_%d.t7' % epoch)
        with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, errG, errA, errD))
        

# GaitGAN and triplet 

In [None]:
%load_ext autoreload
%autoreload 2

# Requires a running visdom server:  python -m visdom.server
vis = visdom.Visdom(port=8097)
win = win1 = None  # visdom window handles, created on the first plot
netg, netd, neta = NetG(nc=1), NetD(nc=1), NetA(nc=1)
device = th.device("cuda:1")

# weights init: (transposed) convolutions get N(0, 0.02) weights; batch norm
# gets N(1, 0.02) scale and zero shift. Only the direct children of each
# network's first sub-module are visited (assumes that sub-module holds the
# layers — TODO confirm for nested containers).
# The previous version chained a *list of iterators*, so the loop saw three
# generator objects instead of layers and initialised nothing.
# NOTE(review): if these layers are wrapped with spectral norm, `mod.weight`
# may be a derived tensor rather than the parameter; verify the init sticks.
all_mods = itertools.chain(
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children(),
)
for mod in all_mods:
    if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)
        
# Hyper-parameters for this run (written to snapshot_log.txt below).
epoches = 700                       # training epochs
glr, dlr = 0.00002, 0.00002         # base learning rates: G, and D/A
real_label, fake_label = 1, 0       # BCE targets
batchSize = 32
target = '036'                      # view passed to CASIABDataset(target=...)
lambda_gp, beta1, beta2 = 0, 0, 0   # logged only; not passed to this cell's optimisers
margin = 5                          # margin of the generator triplet loss
n_g, n_d = 0, 0                     # batch gates: n_g -> D/A update, n_d -> G update


# Module.to() moves parameters in place (and returns self), so no rebinding is needed.
for net in (netg, netd, neta):
    net.to(device)
    net.train()

dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# Fixed-size (batchSize x 1) BCE target buffer on the training device.
label = th.zeros((batchSize, 1), requires_grad=False).to(device)
# G trains at half its base rate; D and A at a third of theirs.
optimG = optim.Adam(netg.parameters(), lr=glr / 2)
optimD = optim.Adam(netd.parameters(), lr=dlr / 3)
optimA = optim.Adam(neta.parameters(), lr=dlr / 3)

print("write parameter log...")
params_line = 'Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={} \n'.format(
    epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp)
with open(Model_dir + "/snapshot_log.txt", "a") as log_file:
    log_file.write(params_line)

low_loss = 10

# Batch-index gates: n_g gates the D/A update and n_d gates the G update.
# A gate of 0 used to make `i % 0` raise ZeroDivisionError on the first
# batch. Any value <= 0 now means "update on every batch".
d_every = max(n_g, 1)
g_every = max(n_d, 1)


def _bce(output, value):
    """Binary cross-entropy of `output` against the constant target `value`.

    The target is built per call with the same shape as `output`. A final
    batch smaller than `batchSize` (the DataLoader does not drop it) then no
    longer clashes with a fixed-size target buffer.
    """
    return F.binary_cross_entropy(output, th.full_like(output, value))


def _snapshot():
    """Return the current state dicts of A, G and D for `th.save`."""
    return {
        'netA': neta.state_dict(),
        'netG': netg.state_dict(),
        'netD': netd.state_dict(),
    }


print('Training starts')
for epoch in range(1, epoches + 1):
    for i, (ass_label, noass_label, img) in enumerate(train_loader):
        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        img = img.to(device).to(th.float32)

        if i % d_every == 0:
            # update D: ass_label and noass_label images are "real". G(img)
            # is "fake" and is detached, so G gets no gradient here.
            lossD = 0
            optimD.zero_grad()
            lossD_real1 = _bce(netd(ass_label), real_label)
            lossD += lossD_real1.item()
            lossD_real1.backward()

            lossD_real2 = _bce(netd(noass_label), real_label)
            lossD += lossD_real2.item()
            lossD_real2.backward()

            fake, _ = netg(img)
            lossD_fake = _bce(netd(fake.detach()), fake_label)
            lossD += lossD_fake.item()
            lossD_fake.backward()

            optimD.step()

            # update A: the pair (img, ass_label) is "real". The pairs
            # (img, noass_label) and (img, G(img)) are "fake".
            lossA = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            fake, _ = netg(img)
            faked = th.cat((img, fake), 1)

            lossA_real1 = _bce(neta(assd), real_label)
            lossA += lossA_real1.item()
            lossA_real1.backward()

            lossA_real2 = _bce(neta(noassd), fake_label)
            lossA += lossA_real2.item()
            lossA_real2.backward()

            lossA_fake = _bce(neta(faked.detach()), fake_label)
            lossA += lossA_fake.item()
            lossA_fake.backward()

            optimA.step()

        if i % g_every == 0:
            # update G: adversarial terms against D and A, plus a triplet
            # constraint on generator outputs. G(img) should lie closer to
            # G(ass_label) than to G(noass_label) by at least `margin`.
            optimG.zero_grad()
            fake, A = netg(img)
            lossGD = _bce(netd(fake), real_label)
            lossGA = _bce(neta(th.cat((img, fake), 1)), real_label)
            fake_ass, P = netg(ass_label)
            fake_noass, N = netg(noass_label)
            lossTriplet = F.triplet_margin_loss(fake, fake_ass, fake_noass, margin=margin)
            # All three terms go through the graph of netg(img). The previous
            # code called lossGA.backward() without retain_graph, which freed
            # that graph before lossTriplet.backward() needed it, and raised.
            # One backward on the sum gives the intended summed gradients.
            # (Alternative tried earlier: triplet on netg's second outputs,
            # F.triplet_margin_loss(A, P, N, margin=margin).)
            (lossGD + lossGA + lossTriplet).backward()
            optimG.step()
            lossG = lossGD.item() + lossGA.item() + lossTriplet.item()

        if i % 20 == 0:
            # Preview grid: generated / ground-truth / input images joined on
            # dim 2 (height, assuming NCHW) and mapped by (x + 1) / 2 * 255.
            with th.no_grad():
                netg.eval()  # switch to inference mode for the preview
                fake, _ = netg(img)
                netg.train()  # switch back to training mode
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # Averages of the most recent gated batch's losses. lossG holds two
    # adversarial terms plus the triplet term, but is still divided by 2 so
    # the logs stay comparable with earlier runs.
    errG, errA, errD = lossG / 2, lossA / 3, lossD / 3

    if epoch % 2 == 0:
        point = dict(X=np.array([[epoch, epoch, epoch]]),
                     Y=np.array([[errG, errA, errD]]))
        if win is None:
            win = vis.line(**point, opts=dict(
                title=Model_Name,
                ylabel='loss',
                xlabel='epochs',
                legend=['lossG', 'lossA', 'lossD'],
            ))
        else:
            vis.line(**point, win=win, update='append')

        with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, errG, errA, errD))
        print('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {}'.format(
            epoch, errG, errA, errD))

    # Periodic checkpoint late in training.
    if epoch >= 300 and epoch % 20 == 0:
        th.save(_snapshot(), Model_dir + '/snapshot' + Model_Name + '_%d.t7' % epoch)

    # Extra checkpoint whenever the reported G loss reaches a new low.
    if epoch >= 550 and errG < low_loss:
        low_loss = errG
        th.save(_snapshot(), Model_dir + '/lowest_snapshot' + Model_Name + '_%d.t7' % epoch)
        with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, errG, errA, errD))
    


# GaitGAN k times and triplet   

In [None]:
%load_ext autoreload
%autoreload 2

# Requires a running visdom server:  python -m visdom.server
vis = visdom.Visdom(port=8097)
win = win1 = None  # visdom window handles, created on the first plot
netg, netd, neta = NetG(nc=1), NetD(nc=1), NetA(nc=1)
device = th.device("cuda:1")

# weights init: (transposed) convolutions get N(0, 0.02) weights; batch norm
# gets N(1, 0.02) scale and zero shift. Only the direct children of each
# network's first sub-module are visited (assumes that sub-module holds the
# layers — TODO confirm for nested containers).
# The previous version chained a *list of iterators*, so the loop saw three
# generator objects instead of layers and initialised nothing.
# NOTE(review): if these layers are wrapped with spectral norm, `mod.weight`
# may be a derived tensor rather than the parameter; verify the init sticks.
all_mods = itertools.chain(
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children(),
)
for mod in all_mods:
    if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)
        
# Hyper-parameters for this run (written to snapshot_log.txt below).
epoches = 700                       # training epochs
glr, dlr = 0.00002, 0.00002         # base learning rates: G, and D/A
real_label, fake_label = 1, 0       # BCE targets
batchSize = 32
target = '090'                      # view passed to CASIABDataset(target=...)
lambda_gp, beta1, beta2 = 0, 0, 0   # logged only; not passed to this cell's optimisers
margin = 10                         # margin passed to TripletLoss
n_g, n_d = 2, 2                     # batch gates: n_g -> D/A update, n_d -> G update


# Module.to() moves parameters in place (and returns self), so no rebinding is needed.
for net in (netg, netd, neta):
    net.to(device)
    net.train()

dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# Fixed-size (batchSize x 1) BCE target buffer on the training device.
label = th.zeros((batchSize, 1), requires_grad=False).to(device)
# G, D and A all train at a third of their base rates in this cell.
optimG = optim.Adam(netg.parameters(), lr=glr / 3)
optimD = optim.Adam(netd.parameters(), lr=dlr / 3)
optimA = optim.Adam(neta.parameters(), lr=dlr / 3)

print("write parameter log...")
params_line = 'Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={} \n'.format(
    epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp)
with open(Model_dir + "/snapshot_log.txt", "a") as log_file:
    log_file.write(params_line)

low_loss = 10

# Batch-index gates: n_g gates the D/A update and n_d gates the G update.
# Any value <= 0 means "update on every batch" instead of raising
# ZeroDivisionError in `i % n`.
d_every = max(n_g, 1)
g_every = max(n_d, 1)

# Triplet criterion from onlineTripletloss. It is built once instead of on
# every G step (assumes TripletLoss keeps no per-call state — TODO confirm).
lossf = TripletLoss(margin)


def _bce(output, value):
    """Binary cross-entropy of `output` against the constant target `value`.

    The target is built per call with the same shape as `output`. A final
    batch smaller than `batchSize` (the DataLoader does not drop it) then no
    longer clashes with a fixed-size target buffer.
    """
    return F.binary_cross_entropy(output, th.full_like(output, value))


def _snapshot():
    """Return the current state dicts of A, G and D for `th.save`."""
    return {
        'netA': neta.state_dict(),
        'netG': netg.state_dict(),
        'netD': netd.state_dict(),
    }


print('Training starts')
for epoch in range(1, epoches + 1):
    for i, (ass_label, noass_label, noass_img, img, ass_img) in enumerate(train_loader):
        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        noass_img = noass_img.to(device).to(th.float32)
        img = img.to(device).to(th.float32)
        ass_img = ass_img.to(device).to(th.float32)

        if i % d_every == 0:
            # update D: ass_label and noass_label images are "real", G(img)
            # is "fake".
            lossD = 0
            optimD.zero_grad()
            lossD_real1 = _bce(netd(ass_label), real_label)
            lossD += lossD_real1.item()
            lossD_real1.backward()

            lossD_real2 = _bce(netd(noass_label), real_label)
            lossD += lossD_real2.item()
            lossD_real2.backward()

            fake, _ = netg(img)
            # detach: this step must not update the generator that produced fake
            lossD_fake = _bce(netd(fake.detach()), fake_label)
            lossD += lossD_fake.item()
            lossD_fake.backward()

            optimD.step()

            # update A: the pair (img, ass_label) is "real". The pairs
            # (img, noass_label) and (img, G(img)) are "fake".
            lossA = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            fake, _ = netg(img)
            faked = th.cat((img, fake.detach()), 1)  # detach: do not update G here

            lossA_real1 = _bce(neta(assd), real_label)
            lossA += lossA_real1.item()
            lossA_real1.backward()

            lossA_real2 = _bce(neta(noassd), fake_label)
            lossA += lossA_real2.item()
            lossA_real2.backward()

            lossA_fake = _bce(neta(faked), fake_label)
            lossA += lossA_fake.item()
            lossA_fake.backward()

            optimA.step()

        if i % g_every == 0:
            # update G: adversarial terms against D and A, plus a triplet loss
            # on netg's second outputs. A is from img, P from ass_img and N
            # from noass_img (presumably encoder features — confirm).
            optimG.zero_grad()
            fake, A = netg(img)
            lossGD = _bce(netd(fake), real_label)
            lossGA = _bce(neta(th.cat((img, fake), 1)), real_label)
            _, P = netg(ass_img)
            __, N = netg(noass_img)
            lossTriplet = lossf(A, P, N)
            # A and fake come from the same netg(img) forward, so they very
            # likely share graph nodes. A separate lossGA.backward() without
            # retain_graph would free them before the triplet term is
            # back-propagated. One backward on the sum avoids that and yields
            # the summed gradients.
            (lossGD + lossGA + lossTriplet).backward()
            optimG.step()
            lossG = lossGD.item() + lossGA.item() + lossTriplet.item()

        if i % 20 == 0:
            # Preview grid: generated / ground-truth / input images joined on
            # dim 2 (height, assuming NCHW) and mapped by (x + 1) / 2 * 255.
            with th.no_grad():
                netg.eval()  # switch to inference mode for the preview
                fake, _ = netg(img)
                netg.train()  # switch back to training mode
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # Averages of the most recent gated batch's losses. lossG holds two
    # adversarial terms plus the triplet term, but is still divided by 2 so
    # the logs stay comparable with earlier runs.
    errG, errA, errD = lossG / 2, lossA / 3, lossD / 3

    if epoch % 2 == 0:
        point = dict(X=np.array([[epoch, epoch, epoch]]),
                     Y=np.array([[errG, errA, errD]]))
        if win is None:
            win = vis.line(**point, opts=dict(
                title=Model_Name,
                ylabel='loss',
                xlabel='epochs',
                legend=['lossG', 'lossA', 'lossD'],
            ))
        else:
            vis.line(**point, win=win, update='append')

        with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, errG, errA, errD))
        print('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {}'.format(
            epoch, errG, errA, errD))

    # Periodic checkpoint late in training.
    if epoch >= 300 and epoch % 20 == 0:
        th.save(_snapshot(), Model_dir + '/snapshot' + Model_Name + '_%d.t7' % epoch)

    # Extra checkpoint whenever the reported G loss reaches a new low.
    if epoch >= 550 and errG < low_loss:
        low_loss = errG
        th.save(_snapshot(), Model_dir + '/lowest_snapshot' + Model_Name + '_%d.t7' % epoch)
        with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG  Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, errG, errA, errD))
    


# WGAN 

In [2]:
# Plain WGAN (weight clipping) setup. Start the visdom server first:
#   python -m visdom.server
vis = visdom.Visdom(port=8097)
win = None   # visdom window for the loss curves (created on first plot)
win1 = None  # visdom window for the sample image grid (created on first plot)
netg = NetG(nc=1)  # generator: netg(img) produces the fake image
netd = NetD(nc=1)  # critic on single images
neta = NetA(nc=1)  # critic on (img, image) pairs concatenated along dim 1
device = th.device("cuda:1")

# Weight init: conv / transposed-conv weights ~ N(0, 0.02); BatchNorm weight
# ~ N(1, 0.02) and bias = 0. Applied to the children of each network's first
# child module.
# chain.from_iterable flattens the three child iterators into one stream of
# modules. Wrapping them as itertools.chain(itertools.chain(), [it1, it2, it3])
# yields the three iterator objects themselves, so no isinstance() check
# ever matches and the initialisation is silently skipped.
all_mods = itertools.chain.from_iterable([
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children(),
])
for mod in all_mods:
    if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)

# Hyper-parameters
epoches = 700
glr = 0.00002        # generator learning rate (RMSprop is created with glr/2)
dlr = 0.00002        # critic learning rate (RMSprop is created with dlr/3)
real_label = 1       # not read by the WGAN losses below
fake_label = 0
batchSize = 32
target = '090'       # passed to CASIABDataset(target=...)
lambda_gp = 0        # no gradient penalty in this variant; logged only
beta1 = 0            # RMSprop is used, so the betas are only logged
beta2 = 0
margin = 0           # logged only
n_g = 1              # D and A are updated when i % n_g == 0
n_d = 2              # G is updated when i % n_d == 0
clip = 0.1           # critic weights are clamped to [-clip, clip]


# Put all three networks on the target device in training mode
# (Module.to and Module.train both return the module itself).
netg, netd, neta = (net.to(device).train() for net in (netg, netd, neta))

dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# All-zeros (batchSize, 1) target tensor; not read by the training loop below.
label = th.zeros((batchSize, 1), requires_grad=False).to(device)

# RMSprop for every network: G steps with glr/2, both critics with dlr/3.
optimG = optim.RMSprop(netg.parameters(), lr=glr / 2)
optimD, optimA = (optim.RMSprop(critic.parameters(), lr=dlr / 3)
                  for critic in (netd, neta))

# Record this run's hyper-parameters in the snapshot log.
print("write parameter log...")
with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
    myfile.write(
        'Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={},clip={} \n'
        .format(epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp, clip)
    )

# Lowest generator loss seen so far; gates the "lowest_snapshot" checkpoints.
low_loss = 10

# ---- WGAN training loop (weight-clipped critics, RMSprop) ----
# Each batch is (ass_label, noass_label, img); `img` is the generator input.
# Judging by how neta scores the pairs below, `ass_label` is the target image
# that matches `img` and `noass_label` is a real image that does not
# (NOTE(review): naming presumably means associated / non-associated —
# confirm against CASIABDataset).
# D and A are updated when i % n_g == 0; G is updated when i % n_d == 0.
print('Training starts')
for epoch in range(1,epoches+1):
    for i, (ass_label, noass_label, img) in enumerate(train_loader):

        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        img = img.to(device).to(th.float32)

        # update D
        # Critic loss: -mean(D(ass_label)) - mean(D(noass_label)) + mean(D(G(img))).
        # Each term is backpropagated on its own and the gradients accumulate
        # until optimD.step(); lossD keeps the plain-float sum for logging.
        if i % n_g==0:
            lossD = 0
            optimD.zero_grad()
            output = netd(ass_label)
            lossD_real1 = -th.mean(output)
            lossD += lossD_real1.item()
            lossD_real1.backward()

            output1 = netd(noass_label)
            lossD_real2 = -th.mean(output1)
            lossD += lossD_real2.item()
            lossD_real2.backward()

            # detach(): the critic step must not backpropagate into the generator.
            fake = netg(img).detach()
            output2 = netd(fake)
            lossD_fake = th.mean(output2)
            lossD += lossD_fake.item()
            lossD_fake.backward()

            optimD.step()

            # Weight clipping: clamp every critic parameter to [-clip, clip].
            for p in netd.parameters():
                p.data.clamp_(-clip, clip)

            # update A
            # A scores (img, image) pairs concatenated along dim 1. The matching
            # pair (img, ass_label) is pushed up; the non-matching real pair
            # (img, noass_label) and the generated pair (img, G(img)) are pushed down.
            lossA = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            fake = netg(img).detach()
            faked = th.cat((img, fake), 1)

            output1 = neta(assd)
            lossA_real1 = -th.mean(output1)
            lossA += lossA_real1.item()
            lossA_real1.backward()

            output = neta(noassd)
            lossA_real2 = th.mean(output)
            lossA += lossA_real2.item()
            lossA_real2.backward()

            output = neta(faked)
            lossA_fake = th.mean(output)
            lossA += lossA_fake.item()
            lossA_fake.backward()

            optimA.step()

            # Weight clipping for the pair critic as well.
            for p in neta.parameters():
                p.data.clamp_(-clip, clip)

        # update G
        # G maximises both critics' scores on G(img). The first backward keeps
        # the graph because the second one runs through the same netg(img)
        # forward pass. These backwards also leave gradients in netd/neta;
        # those are cleared by zero_grad() before each critic step.
        if i % n_d == 0:
            lossG = 0
            optimG.zero_grad()
            fake = netg(img)
            output = netd(fake)

            lossGD = -th.mean(output)
            lossG += lossGD.item()
            lossGD.backward(retain_graph=True)

            faked = th.cat((img, fake), 1)
            output = neta(faked)
            lossGA = -th.mean(output)
            lossG += lossGA.item()
            lossGA.backward()

            optimG.step()

        # Every 20 iterations: show [G(img), ass_label, img] concatenated along
        # dim 2 as a visdom image grid (20 images per row). (x + 1) / 2 * 255
        # assumes the tensors are scaled to [-1, 1] — TODO confirm.
        if i % 20 == 0:
            with th.no_grad():
                netg.eval()  # switch to eval mode for the preview
                fake = netg(img)
                netg.train() # switch back to training mode
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # Every 2 epochs: plot and log the per-term average losses (lossG/2,
    # lossA/3, lossD/3) of the LAST update in the epoch, not an epoch mean.
    if epoch % 2==0:
        if win is None:
            win = vis.line(X=np.array([[epoch, epoch,
                                        epoch]]),
                           Y=np.array([[lossG/2, lossA/3, lossD/3]]),
                           opts=dict(
                               title=Model_Name,
                               ylabel='loss',
                               xlabel='epochs',
                               legend=['lossG', 'lossA', 'lossD']
                           ))
        else:
            vis.line(X=np.array([[epoch, epoch,
                                  epoch]]),
                     Y=np.array([[lossG/2, lossA/3, lossD/3]]),
                     win=win,
                     update='append')

        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
            epoch, lossG/2, lossA/3, lossD/3
        ))
        print('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {}'.format(
            epoch, lossG/2, lossA/3, lossD/3
        ))

    # From epoch 300 on, checkpoint all three networks every 20 epochs.
    if (epoch>= 300) and epoch%20==0:
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/snapshot'+ Model_Name +'_%d.t7' % epoch)

    # From epoch 550 on, also checkpoint whenever the last generator loss is a
    # new minimum (low_loss starts at 10).
    if (epoch>= 550) and (lossG/2)<low_loss:
        low_loss = lossG/2
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/lowest_snapshot'+ Model_Name +'_%d.t7' % epoch)
        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
            epoch, lossG/2, lossA/3, lossD/3
        ))
        




n_con= 10 ,n_ang= 11
target =  090
write parameter log...
Training starts
Epoch = 2, ErrG = 3.752237558364868, ErrA = -1.8985527356465657, ErrD = -3.5465147892634072
Epoch = 4, ErrG = 5.705430269241333, ErrA = -2.807316621144613, ErrD = -5.611066023508708
Epoch = 6, ErrG = 4.896071195602417, ErrA = -3.07237180074056, ErrD = -6.601948658625285
Epoch = 8, ErrG = 8.87260103225708, ErrA = -4.403616587320964, ErrD = -9.87374464670817
Epoch = 10, ErrG = 7.947443723678589, ErrA = -4.1470723152160645, ErrD = -8.064921021461487
Epoch = 12, ErrG = 1.238088607788086, ErrA = -4.921435038248698, ErrD = -4.710442225138347
Epoch = 14, ErrG = 0.5194168090820312, ErrA = -5.6255998611450195, ErrD = -5.323735237121582
Epoch = 16, ErrG = 6.348399639129639, ErrA = -7.5277970631917315, ErrD = -7.016017913818359
Epoch = 18, ErrG = 0.8842306137084961, ErrA = -7.25491460164388, ErrD = -7.062686284383138
Epoch = 20, ErrG = 0.0035724639892578125, ErrA = -7.803771336873372, ErrD = -7.927731831868489
Epoch = 22, E

Epoch = 178, ErrG = 13.589683532714844, ErrA = -53.47724533081055, ErrD = -60.667805989583336
Epoch = 180, ErrG = 27.157360076904297, ErrA = -54.507459004720054, ErrD = -51.485514322916664
Epoch = 182, ErrG = 56.5279426574707, ErrA = -54.72101338704427, ErrD = -54.934326171875
Epoch = 184, ErrG = 26.18047523498535, ErrA = -55.34730529785156, ErrD = -59.38808059692383
Epoch = 186, ErrG = 13.245643615722656, ErrA = -56.8813362121582, ErrD = -60.9880002339681
Epoch = 188, ErrG = 14.880294799804688, ErrA = -53.91685994466146, ErrD = -41.10221862792969
Epoch = 190, ErrG = 33.4071044921875, ErrA = -55.473968505859375, ErrD = -54.388196309407554
Epoch = 192, ErrG = 54.02478504180908, ErrA = -50.48332722981771, ErrD = -56.40521113077799
Epoch = 194, ErrG = 35.50703430175781, ErrA = -58.89679463704427, ErrD = -55.31469599405924
Epoch = 196, ErrG = 47.07379150390625, ErrA = -56.774068196614586, ErrD = -64.39537811279297
Epoch = 198, ErrG = 8.856361389160156, ErrA = -59.77715301513672, ErrD = -64

Epoch = 356, ErrG = 33.6126594543457, ErrA = -54.70630900065104, ErrD = -60.700757344563804
Epoch = 358, ErrG = 38.190752029418945, ErrA = -54.00933074951172, ErrD = -58.47474924723307
Epoch = 360, ErrG = 24.69669532775879, ErrA = -51.98524602254232, ErrD = -69.35917663574219
Epoch = 362, ErrG = 20.46735382080078, ErrA = -58.84796142578125, ErrD = -65.61435190836589
Epoch = 364, ErrG = 32.753963470458984, ErrA = -54.45555623372396, ErrD = -72.78407796223958
Epoch = 366, ErrG = 50.228535652160645, ErrA = -63.61809285481771, ErrD = -68.19214502970378
Epoch = 368, ErrG = 35.94788932800293, ErrA = -55.87086486816406, ErrD = -63.45636494954427
Epoch = 370, ErrG = 13.586647033691406, ErrA = -55.783398946126304, ErrD = -54.916760762532554
Epoch = 372, ErrG = 17.255874633789062, ErrA = -48.30359903971354, ErrD = -46.36835225423177
Epoch = 374, ErrG = 32.9892520904541, ErrA = -55.58989715576172, ErrD = -59.28937530517578
Epoch = 376, ErrG = 22.925832748413086, ErrA = -53.68057378133138, ErrD = 

Epoch = 534, ErrG = 35.67050743103027, ErrA = -64.6280288696289, ErrD = -75.14555549621582
Epoch = 536, ErrG = 46.15334701538086, ErrA = -52.71881357828776, ErrD = -65.6470947265625
Epoch = 538, ErrG = 28.56289291381836, ErrA = -67.59828821818034, ErrD = -64.95235061645508
Epoch = 540, ErrG = 29.25141143798828, ErrA = -51.203712463378906, ErrD = -58.197113037109375
Epoch = 542, ErrG = 42.39829158782959, ErrA = -52.20666249593099, ErrD = -70.6221415201823
Epoch = 544, ErrG = 0.5973739624023438, ErrA = -60.515010833740234, ErrD = -52.27333068847656
Epoch = 546, ErrG = 56.850443840026855, ErrA = -64.28490193684895, ErrD = -68.94095865885417
Epoch = 548, ErrG = 48.342369079589844, ErrA = -63.897132873535156, ErrD = -73.07439422607422
Epoch = 550, ErrG = 20.176572799682617, ErrA = -54.681190490722656, ErrD = -57.4961903889974
Epoch = 552, ErrG = 27.500083923339844, ErrA = -62.00841776529948, ErrD = -61.99379221598307
Epoch = 554, ErrG = 41.838890075683594, ErrA = -57.539103190104164, ErrD =

# WGAN-GP

In [None]:
%load_ext autoreload
%autoreload 2
from torch.autograd import grad, Variable

#python -m visdom.server
vis = visdom.Visdom(port=8097)
win = None
win1 = None
netg = NetG(nc=1)
netd = NetD(nc=1)
neta = NetA(nc=1)
device = th.device("cuda:1")

# weights init
all_mods = itertools.chain()
all_mods = itertools.chain(all_mods, [
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children()
])
for mod in all_mods:
    if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.ConvTranspose2d):
#         init.xavier_normal_(tensor, gain=1.)
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)
        
epoches = 700
glr = 0.00002
dlr = 0.00002
real_label = 1
fake_label = 0
batchSize = 32
target = '090'
lambda_gp = 10
beta1 = 0
beta2 = 0.999
margin = 0
n_g = 0
n_d = 5

# Put all three networks on the target device in training mode
# (Module.to and Module.train both return the module itself).
netg, netd, neta = (net.to(device).train() for net in (netg, netd, neta))

dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# Adam with betas=(beta1, beta2); G uses glr, both critics use dlr.
optimG = optim.Adam(netg.parameters(), lr=glr, betas=(beta1, beta2))
optimD, optimA = (optim.Adam(critic.parameters(), lr=dlr, betas=(beta1, beta2))
                  for critic in (netd, neta))

# Record this run's hyper-parameters in the snapshot log.
print("write parameter log...")
with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
    myfile.write(
        'Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={} \n'
        .format(epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp)
    )

# Lowest generator loss seen so far; gates the "lowest_snapshot" checkpoints.
low_loss = 10

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty of critic ``D``.

    Each sample is interpolated between ``real_samples`` and ``fake_samples``
    with its own random weight, and the penalty is
    ``mean_over_batch((||dD(x)/dx||_2 - 1) ** 2)`` at those points.

    Args:
        D: critic callable. If it returns a tuple/list (as the self-attention
            models in this notebook do, e.g. ``out, attn = netd(x)``), the
            first element is taken as the score.
        real_samples: batch of real inputs, shape ``(N, ...)``.
        fake_samples: batch of generated inputs, same shape as ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        added to the critic loss and backpropagated into ``D``'s parameters.
    """
    batch = real_samples.size(0)
    # One interpolation weight per sample, broadcast over the remaining dims.
    # Sized from the actual batch (not the global batchSize) so a short final
    # DataLoader batch does not cause a shape mismatch; created on the same
    # device/dtype as the inputs instead of relying on a global `device`.
    alpha_shape = (batch,) + (1,) * (real_samples.dim() - 1)
    alpha = th.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates = D(interpolates)
    if isinstance(d_interpolates, (tuple, list)):
        d_interpolates = d_interpolates[0]

    gradients = grad(outputs=d_interpolates,
                     inputs=interpolates,
                     grad_outputs=th.ones_like(d_interpolates),
                     create_graph=True,
                     retain_graph=True)[0]
    gradients = gradients.reshape(batch, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# ---- WGAN-GP training loop (Adam, gradient penalty instead of clipping) ----
# Each batch is (ass_label, noass_label, img); `img` is the generator input.
# Judging by how neta scores the pairs below, `ass_label` is the matching
# target and `noass_label` a real but non-matching image (NOTE(review):
# confirm against CASIABDataset).
# D and A are updated when i % n_g == 0 and G when i % n_d == 0, so both
# periods must be >= 1 (the modulo raises ZeroDivisionError for 0).
print('Training starts')
for epoch in range(1,epoches+1):
    for i, (ass_label, noass_label, img) in enumerate(train_loader):

        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        img = img.to(device).to(th.float32)

        # update D
        # lossD_ (tensor) = mean of the three WGAN terms + lambda_gp * GP,
        # backpropagated once; lossD (float) is the plain sum of the three
        # terms for logging. The GP interpolates between ass_label and G(img) only.
        if i % n_g==0:
            lossD = 0
            lossD_ = 0
            optimD.zero_grad()
            output = netd(ass_label)
            lossD_real1 = -th.mean(output)
            lossD_ += lossD_real1
            lossD += lossD_real1.item()

            output1 = netd(noass_label)
            lossD_real2 = -th.mean(output1)
            lossD_ += lossD_real2
            lossD += lossD_real2.item()

            # detach(): the critic step must not backpropagate into the generator.
            fake = netg(img).detach()
            output2 = netd(fake)
            lossD_fake = th.mean(output2)
            lossD_ += lossD_fake
            lossD += lossD_fake.item()
            gradient_penalty = compute_gradient_penalty(netd, ass_label.data, fake.data)
            lossD_ = lossD_/3 + lambda_gp * gradient_penalty
            lossD_.backward()

            optimD.step()


            # update A
            # Same structure as D, on (img, image) pairs concatenated along
            # dim 1: (img, ass_label) is pushed up, (img, noass_label) and
            # (img, G(img)) are pushed down; the GP is between assd and faked.
            lossA = 0
            lossA_ = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            fake = netg(img).detach()
            faked = th.cat((img, fake), 1)

            output1 = neta(assd)
            lossA_real1 = -th.mean(output1)
            lossA += lossA_real1.item()
            lossA_ += lossA_real1

            output = neta(noassd)
            lossA_real2 = th.mean(output)
            lossA += lossA_real2.item()
            lossA_ += lossA_real2

            output = neta(faked)
            lossA_fake = th.mean(output)
            lossA += lossA_fake.item()
            lossA_ += lossA_fake
            gradient_penalty = compute_gradient_penalty(neta, assd.data, faked.data)
            lossA_ = lossA_/3 + lambda_gp * gradient_penalty

            lossA_.backward()
            optimA.step()


        # update G
        # G maximises both critics' scores on G(img); lossG_ is the mean of the
        # two terms and lossG their float sum for logging.
        # NOTE(review): retain_graph=True is not needed for this single backward.
        if i % n_d == 0:
            lossG = 0
            lossG_ = 0
            optimG.zero_grad()
            fake = netg(img)
            output = netd(fake)

            lossGD = -th.mean(output)
            lossG += lossGD.item()
            lossG_ += lossGD

            faked = th.cat((img, fake), 1)
            output = neta(faked)
            lossGA = -th.mean(output)
            lossG += lossGA.item()
            lossG_ += lossGA
            lossG_ = lossG_/2
            lossG_.backward(retain_graph=True)
            optimG.step()

        # Every 20 iterations: show [G(img), ass_label, img] concatenated along
        # dim 2 as a visdom image grid (20 images per row). (x + 1) / 2 * 255
        # assumes the tensors are scaled to [-1, 1] — TODO confirm.
        if i % 20 == 0:
            with th.no_grad():
                netg.eval()  # switch to eval mode for the preview
                fake = netg(img)
                netg.train() # switch back to training mode
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # Every 2 epochs: plot and log the per-term average losses (lossG/2,
    # lossA/3, lossD/3) of the LAST update in the epoch, not an epoch mean.
    if epoch % 2==0:
        if win is None:
            win = vis.line(X=np.array([[epoch, epoch,
                                        epoch]]),
                           Y=np.array([[lossG/2, lossA/3, lossD/3]]),
                           opts=dict(
                               title=Model_Name,
                               ylabel='loss',
                               xlabel='epochs',
                               legend=['lossG', 'lossA', 'lossD']
                           ))
        else:
            vis.line(X=np.array([[epoch, epoch,
                                  epoch]]),
                     Y=np.array([[lossG/2, lossA/3, lossD/3]]),
                     win=win,
                     update='append')

        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
            epoch, lossG/2, lossA/3, lossD/3
        ))
        print('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {}'.format(
            epoch, lossG/2, lossA/3, lossD/3
        ))

    # From epoch 300 on, checkpoint all three networks every 20 epochs.
    if (epoch>= 300) and epoch%20==0:
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/snapshot'+ Model_Name +'_%d.t7' % epoch)

    # From epoch 550 on, also checkpoint whenever the last generator loss is a
    # new minimum (low_loss starts at 10).
    if (epoch>= 550) and (lossG/2)<low_loss:
        low_loss = lossG/2
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/lowest_snapshot'+ Model_Name +'_%d.t7' % epoch)
        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
            epoch, lossG/2, lossA/3, lossD/3
        ))

        
        


# SA GaitGAN (hinge loss)

In [None]:
%load_ext autoreload
%autoreload 2
from torch.autograd import grad, Variable

#python -m visdom.server
vis = visdom.Visdom(port=8097)
win = None
win1 = None
netg = NetG(nc=1)
netd = NetD(nc=1)
neta = NetA(nc=1)
device = th.device("cuda:1")

# weights init
all_mods = itertools.chain()
all_mods = itertools.chain(all_mods, [
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children()
])
for mod in all_mods:
    if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.ConvTranspose2d):
#         init.xavier_normal_(tensor, gain=1.)
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)
        

epoches = 700
glr = 0.00001
dlr = 0.00004
real_label = 1
fake_label = 0
batchSize = 32
target = '090'
lambda_gp = 0
beta1 = 0
beta2 = 0.9
margin = 10
n_g = 2
n_d = 1


# Put all three networks on the target device in training mode
# (Module.to and Module.train both return the module itself).
netg, netd, neta = (net.to(device).train() for net in (netg, netd, neta))

dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# Adam with betas=(beta1, beta2); G uses glr, both critics use dlr.
optimG = optim.Adam(netg.parameters(), lr=glr, betas=(beta1, beta2))
optimD, optimA = (optim.Adam(critic.parameters(), lr=dlr, betas=(beta1, beta2))
                  for critic in (netd, neta))

# Record this run's hyper-parameters in the snapshot log.
print("write parameter log...")
with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
    myfile.write(
        'Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={} \n'
        .format(epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp)
    )

# Lowest generator loss seen so far; gates the "lowest_snapshot" checkpoints.
low_loss = 10

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty of critic ``D``.

    Each sample is interpolated between ``real_samples`` and ``fake_samples``
    with its own random weight, and the penalty is
    ``mean_over_batch((||dD(x)/dx||_2 - 1) ** 2)`` at those points.

    Args:
        D: critic callable. If it returns a tuple/list (as the self-attention
            models in this notebook do, e.g. ``out, attn = netd(x)``), the
            first element is taken as the score.
        real_samples: batch of real inputs, shape ``(N, ...)``.
        fake_samples: batch of generated inputs, same shape as ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        added to the critic loss and backpropagated into ``D``'s parameters.
    """
    batch = real_samples.size(0)
    # One interpolation weight per sample, broadcast over the remaining dims.
    # Sized from the actual batch (not the global batchSize) so a short final
    # DataLoader batch does not cause a shape mismatch; created on the same
    # device/dtype as the inputs instead of relying on a global `device`.
    alpha_shape = (batch,) + (1,) * (real_samples.dim() - 1)
    alpha = th.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates = D(interpolates)
    if isinstance(d_interpolates, (tuple, list)):
        d_interpolates = d_interpolates[0]

    gradients = grad(outputs=d_interpolates,
                     inputs=interpolates,
                     grad_outputs=th.ones_like(d_interpolates),
                     create_graph=True,
                     retain_graph=True)[0]
    gradients = gradients.reshape(batch, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# ---- SA GaitGAN training loop with hinge loss ----
# netg, netd and neta each return a pair here; only the first element
# (image / score) is used by the losses. The second is presumably the
# self-attention map — TODO confirm in model_SAGAN1_1.
# D and A are updated when i % n_g == 0; G is updated when i % n_d == 0.
print('Training starts')
for epoch in range(1,epoches+1):
    for i, (ass_label, noass_label, img) in enumerate(train_loader):

        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        img = img.to(device).to(th.float32)

        if i % n_g ==0:
            # update D
            # Hinge loss: mean(relu(1 - D(x))) for the real ass_label and
            # noass_label, mean(relu(1 + D(G(img)))) for the fake. lossD_
            # averages the three terms; lossD keeps their float sum for logging.
            lossD = 0
            lossD_ = 0
            optimD.zero_grad()
            d_out_assreal,dr1 = netd(ass_label)
            d_loss_assreal = nn.ReLU()(1.0 - d_out_assreal).mean()

            lossD_ += d_loss_assreal
            lossD += d_loss_assreal.item()

            d_out_noassreal,dr2 = netd(noass_label)
            d_loss_noassreal = nn.ReLU()(1.0 - d_out_noassreal).mean()

            lossD_ += d_loss_noassreal
            lossD += d_loss_noassreal.item()

            fake, gf1 = netg(img)
            d_out_fake, df1 = netd(fake.detach())  # detach: the D update must not backpropagate into the generator
            d_loss_fake = nn.ReLU()(1.0 + d_out_fake).mean()

            lossD_ += d_loss_fake
            lossD += d_loss_fake.item()
            lossD_ = lossD_/3
            lossD_.backward()
            optimD.step()


            # update A
            # Hinge loss on (img, image) pairs concatenated along dim 1:
            # (img, ass_label) is treated as real; (img, noass_label) and
            # (img, G(img)) are treated as fake.
            lossA = 0
            lossA_ = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            faked, gf1 = netg(img)
            faked = th.cat((img, faked.detach()), 1)  # detach: the A update must not backpropagate into the generator

            d_out_assreal,dr1 = neta(assd)
            d_loss_assreal = nn.ReLU()(1.0 - d_out_assreal).mean()
            lossA += d_loss_assreal.item()
            lossA_ += d_loss_assreal

            d_out_noassreal,dr2 = neta(noassd)
            d_loss_noassreal = nn.ReLU()(1.0 + d_out_noassreal).mean()

            lossA_ += d_loss_noassreal
            lossA += d_loss_noassreal.item()

            d_out_faked, df3 = neta(faked)
            d_loss_faked = nn.ReLU()(1.0 + d_out_faked).mean()

            lossA_ += d_loss_faked
            lossA += d_loss_faked.item()
            lossA_ = lossA_/3
            lossA_.backward()
            optimA.step()


        # update G
        # G maximises both critics' raw scores on G(img); lossG_ is the mean
        # of the two terms and lossG their float sum for logging.
        if i % n_d == 0:
            lossG = 0
            lossG_ = 0
            optimG.zero_grad()
            fake,_= netg(img)
            g_out_fake,_ = netd(fake)
            g_loss_fake = - g_out_fake.mean()

            lossG += g_loss_fake.item()
            lossG_ += g_loss_fake

            faked = th.cat((img, fake), 1)
            g_out_faked,_ = neta(faked)
            g_loss_faked = - g_out_faked.mean()
            lossG += g_loss_faked.item()
            lossG_ += g_loss_faked

            lossG_ = lossG_/2
            lossG_.backward(retain_graph=True) # retain_graph is not actually needed here: backward runs only once on this graph.
            # It would only be required on the first call if another backward through the same graph followed.
            optimG.step()

        # Every 20 iterations: show [G(img), ass_label, img] concatenated along
        # dim 2 as a visdom image grid (20 images per row). (x + 1) / 2 * 255
        # assumes the tensors are scaled to [-1, 1] — TODO confirm.
        if i % 20 == 0:
            with th.no_grad():
                netg.eval()  # switch to eval mode for the preview
                fake,_ = netg(img)
                netg.train() # switch back to training mode
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # Every 2 epochs: plot and log the per-term average losses (lossG/2,
    # lossA/3, lossD/3) of the LAST update in the epoch, not an epoch mean.
    # The console line also prints each network's attn.gamma value
    # (presumably the learned self-attention scale — TODO confirm).
    if epoch % 2==0:
        if win is None:
            win = vis.line(X=np.array([[epoch, epoch,
                                        epoch]]),
                           Y=np.array([[lossG/2, lossA/3, lossD/3]]),
                           opts=dict(
                               title=Model_Name,
                               ylabel='loss',
                               xlabel='epochs',
                               legend=['lossG', 'lossA', 'lossD']
                           ))
        else:
            vis.line(X=np.array([[epoch, epoch,
                                  epoch]]),
                     Y=np.array([[lossG/2, lossA/3, lossD/3]]),
                     win=win,
                     update='append')

        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
            epoch, lossG/2, lossA/3, lossD/3
        ))
        print('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {}, Gattn={}, Dattn={}, Aattn={}'.format(
            epoch, lossG/2, lossA/3, lossD/3, netg.attn.gamma.item(), netd.attn.gamma.item(), neta.attn.gamma.item()
        ))

    # From epoch 300 on, checkpoint all three networks every 20 epochs.
    if (epoch>= 300) and epoch%20==0:
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/snapshot'+ Model_Name +'_%d.t7' % epoch)

    # From epoch 550 on, also checkpoint whenever the last generator loss is a
    # new minimum (low_loss starts at 10).
    if (epoch>= 550) and (lossG/2)<low_loss:
        low_loss = lossG/2
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/lowest_snapshot'+ Model_Name +'_%d.t7' % epoch)
        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
            epoch, lossG/2, lossA/3, lossD/3
        ))

        

# SA GaitGAN and triplet 

In [2]:
%load_ext autoreload
%autoreload 2
from torch.autograd import grad, Variable

#python -m visdom.server
vis = visdom.Visdom(port=8097)
win = None
win1 = None
netg = NetG(nc=1)
netd = NetD(nc=1)
neta = NetA(nc=1)
device = th.device("cuda:1")

# weights init
all_mods = itertools.chain()
all_mods = itertools.chain(all_mods, [
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children()
])
for mod in all_mods:
    if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.ConvTranspose2d):
#         init.xavier_normal_(tensor, gain=1.)
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d):
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)

epoches = 700
glr = 0.00001
dlr = 0.00004
real_label = 1
fake_label = 0
batchSize = 32
target = '090'
lambda_gp = 0
beta1 = 0
beta2 = 0.9
margin = 5
n_g = 2
n_d = 1

# Put all three networks on the target device in training mode
# (Module.to and Module.train both return the module itself).
netg, netd, neta = (net.to(device).train() for net in (netg, netd, neta))

dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
)

# Adam with betas=(beta1, beta2); G uses glr, both critics use dlr.
optimG = optim.Adam(netg.parameters(), lr=glr, betas=(beta1, beta2))
optimD, optimA = (optim.Adam(critic.parameters(), lr=dlr, betas=(beta1, beta2))
                  for critic in (netd, neta))

# Record this run's hyper-parameters in the snapshot log.
print("write parameter log...")
with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
    myfile.write(
        'Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={} \n'
        .format(epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp)
    )

# Lowest generator loss seen so far; presumably gates "lowest_snapshot" saves.
low_loss = 10

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty of critic ``D``.

    Each sample is linearly interpolated between ``real_samples`` and
    ``fake_samples`` with its own random weight alpha ~ U[0, 1). The penalty is
    ``mean((||dD(x_hat)/dx_hat||_2 - 1) ** 2)`` over the batch.

    Args:
        D: critic callable. It may return either a score tensor or a tuple
            whose first element is the score tensor. The nets in this notebook
            are unpacked as ``out, attn = net(x)``.
        real_samples: batch-first real tensor. ``alpha`` is shaped
            (N, 1, 1, 1), so 4-D (N, C, H, W) input is expected.
        fake_samples: generated tensor with the same shape as ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        back-propagated into ``D``'s parameters.
    """
    n = real_samples.size(0)
    # One alpha per sample, broadcast over the remaining dims. It is sized from
    # the actual batch (not the global batchSize), so the last, smaller
    # DataLoader batch also works. It is created on the inputs' device and
    # dtype rather than a global device.
    alpha = th.rand((n, 1, 1, 1), device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates = D(interpolates)
    if isinstance(d_interpolates, (tuple, list)):
        # Critics here return (score, attention); only the score is penalised.
        d_interpolates = d_interpolates[0]

    gradients = grad(outputs=d_interpolates,
                     inputs=interpolates,
                     grad_outputs=th.ones_like(d_interpolates),  # matches any score shape
                     create_graph=True,
                     retain_graph=True,
                     only_inputs=True)[0]
    gradients = gradients.reshape(n, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

print('Training starts')
for epoch in range(1, epoches + 1):
    for i, (ass_label, noass_label, noass_img, img, ass_img) in enumerate(train_loader):
        # NOTE(review): the names suggest img = anchor input, ass_img / noass_img =
        # positive / negative inputs for the triplet term, and ass_label /
        # noass_label = paired / unpaired target GEIs — confirm against
        # dataset2Loader_triplet.CASIABDataset.
        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        noass_img = noass_img.to(device).to(th.float32)
        img = img.to(device).to(th.float32)
        ass_img = ass_img.to(device).to(th.float32)

        # D and A are updated on batches where i % n_g == 0 (every batch if n_g <= 1).
        if n_g <= 1 or i % n_g == 0:
            # update D (hinge loss): both real targets are pushed to score >= 1,
            # the generated image to score <= -1.
            lossD = 0   # float total, for logging
            lossD_ = 0  # tensor total, for backprop
            optimD.zero_grad()
            d_out_assreal, dr1 = netd(ass_label)
            d_loss_assreal = F.relu(1.0 - d_out_assreal).mean()
            lossD_ += d_loss_assreal
            lossD += d_loss_assreal.item()

            d_out_noassreal, dr2 = netd(noass_label)
            d_loss_noassreal = F.relu(1.0 - d_out_noassreal).mean()
            lossD_ += d_loss_noassreal
            lossD += d_loss_noassreal.item()

            # The generator output is only used detached here, so don't build its graph.
            with th.no_grad():
                fake, code = netg(img)
            d_out_fake, df1 = netd(fake)
            d_loss_fake = F.relu(1.0 + d_out_fake).mean()
            lossD_ += d_loss_fake
            lossD += d_loss_fake.item()
            # gradient_penalty = compute_gradient_penalty(netd, ass_label.data, fake.data)
            lossD_ = lossD_ / 3
            lossD_.backward()
            optimD.step()

            # update A: scores (input, target) pairs stacked along dim 1. Only
            # (img, ass_label) is pushed to >= 1; (img, noass_label) and
            # (img, G(img)) are both pushed to <= -1.
            lossA = 0
            lossA_ = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            with th.no_grad():
                faked, code = netg(img)
            faked = th.cat((img, faked), 1)

            d_out_assreal, dr1 = neta(assd)
            d_loss_assreal = F.relu(1.0 - d_out_assreal).mean()
            lossA += d_loss_assreal.item()
            lossA_ += d_loss_assreal

            d_out_noassreal, dr2 = neta(noassd)
            d_loss_noassreal = F.relu(1.0 + d_out_noassreal).mean()
            lossA_ += d_loss_noassreal
            lossA += d_loss_noassreal.item()

            d_out_faked, df3 = neta(faked)
            d_loss_faked = F.relu(1.0 + d_out_faked).mean()
            lossA_ += d_loss_faked
            lossA += d_loss_faked.item()
            # gradient_penalty = compute_gradient_penalty(neta, assd.data, faked.data)
            lossA_ = lossA_ / 3
            lossA_.backward()
            optimA.step()

        # G is updated on batches where i % n_d == 0 (every batch if n_d <= 1;
        # a literal 0 would otherwise raise ZeroDivisionError).
        if n_d <= 1 or i % n_d == 0:
            lossG = 0
            lossG_ = 0
            optimG.zero_grad()
            fake, A = netg(img)
            g_out_fake, _ = netd(fake)
            g_loss_fake = -g_out_fake.mean()
            lossG += g_loss_fake.item()
            lossG_ += g_loss_fake

            faked = th.cat((img, fake), 1)
            g_out_faked, _ = neta(faked)
            g_loss_faked = -g_out_faked.mean()
            lossG += g_loss_faked.item()
            lossG_ += g_loss_faked

            # Triplet constraint on the generator's second output:
            # anchor A (img) vs positive P (ass_img) and negative N (noass_img).
            _, P = netg(ass_img)
            _, N = netg(noass_img)
            lossTriplet = F.triplet_margin_loss(A, P, N, margin=margin)
            lossG_ += lossTriplet
            lossG += lossTriplet.item()

            lossG_ = lossG_ / 3
            # Nothing backpropagates through this graph again, so it need not be retained.
            lossG_.backward()
            optimG.step()

        # Every 20 batches, preview [G(img); ass_label; img] concatenated along dim 2.
        if i % 20 == 0:
            with th.no_grad():
                netg.eval()   # switch to eval mode for the preview
                fake, _ = netg(img)
                netg.train()  # switch back to training mode
            # Rescale to [0, 255]; assumes the tensors are in [-1, 1] — TODO confirm.
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # The logged losses come from the last D/A/G update of the epoch. Each is
    # divided by 3, the number of terms summed into it.
    if epoch % 2 == 0:
        if win is None:
            win = vis.line(X=np.array([[epoch, epoch, epoch]]),
                           Y=np.array([[lossG/3, lossA/3, lossD/3]]),
                           opts=dict(
                               title=Model_Name,
                               ylabel='loss',
                               xlabel='epochs',
                               legend=['lossG', 'lossA', 'lossD']
                           ))
        else:
            vis.line(X=np.array([[epoch, epoch, epoch]]),
                     Y=np.array([[lossG/3, lossA/3, lossD/3]]),
                     win=win,
                     update='append')

        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, lossG/3, lossA/3, lossD/3
            ))
        print('Epoch = {}, ErrG = {}, ErrTri = {}, ErrA = {}, ErrD = {}, Gattn={}, Dattn={}, Aattn={}'.format(
            epoch, lossG/3, lossTriplet.item(), lossA/3, lossD/3, netg.attn.gamma.item(), netd.attn.gamma.item(), neta.attn.gamma.item()
        ))

    # Periodic checkpoint every 20 epochs from epoch 300 on.
    if (epoch >= 300) and epoch % 20 == 0:
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/snapshot'+ Model_Name +'_%d.t7' % epoch)

    # From epoch 550 on, also keep a checkpoint whenever ErrG hits a new low.
    if (epoch >= 550) and (lossG/3) < low_loss:
        low_loss = lossG/3
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/lowest_snapshot'+ Model_Name +'_%d.t7' % epoch)
        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, lossG/3, lossA/3, lossD/3
            ))



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
n_con= 10 ,n_ang= 11
target =  090
write parameter log...
Training starts
Epoch = 2, ErrG = 0.9585853020350138, ErrTri = 2.466822624206543, ErrA = 0.7065815528233846, ErrD = 0.6579362954944372, Gattn=-0.004149867687374353, Dattn=0.00046767963794991374, Aattn=-8.339400665136054e-05
Epoch = 4, ErrG = 0.7519396940867106, ErrTri = 2.0908613204956055, ErrA = 0.7104401787122091, ErrD = 0.6771175134927034, Gattn=-0.00813620537519455, Dattn=-2.712438435992226e-05, Aattn=1.7503132767160423e-05
Epoch = 6, ErrG = 0.6494477788607279, ErrTri = 1.9485763311386108, ErrA = 0.6816310677677393, ErrD = 0.674371782069405, Gattn=-0.011028623208403587, Dattn=-9.603230137145147e-05, Aattn=-6.488304734375561e-06
Epoch = 8, ErrG = 0.7547747691472372, ErrTri = 2.2539944648742676, ErrA = 0.6777164159963528, ErrD = 0.6727034399906794, Gattn=-0.011997980065643787, Dattn=-8.682869520271197e-05, Aattn=-1.9308947230456397e-05
Epoc

Epoch = 80, ErrG = 0.12779359022776285, ErrTri = 0.3435516953468323, ErrA = 0.667331455896298, ErrD = 0.6712576548258463, Gattn=-0.012592598795890808, Dattn=-0.04463300108909607, Aattn=-0.010896462947130203
Epoch = 82, ErrG = 0.13504366079966226, ErrTri = 0.3858763575553894, ErrA = 0.6675103244682153, ErrD = 0.668526558826367, Gattn=-0.012568212114274502, Dattn=-0.04516434296965599, Aattn=-0.010990452021360397
Epoch = 84, ErrG = 0.2641621232032776, ErrTri = 0.7865022420883179, ErrA = 0.6666676619400581, ErrD = 0.6699176772187153, Gattn=-0.012581859715282917, Dattn=-0.04560864716768265, Aattn=-0.011078497394919395
Epoch = 86, ErrG = 0.1082709829012553, ErrTri = 0.2936580777168274, ErrA = 0.6693165500958761, ErrD = 0.6687939849992593, Gattn=-0.012538941577076912, Dattn=-0.04600320756435394, Aattn=-0.011201289482414722
Epoch = 88, ErrG = 0.20183048645655313, ErrTri = 0.5872557163238525, ErrA = 0.6697991186132034, ErrD = 0.6681809425354004, Gattn=-0.01254214532673359, Dattn=-0.046354942023

Epoch = 160, ErrG = 0.04624592264493307, ErrTri = 0.14531847834587097, ErrA = 0.6654125942538182, ErrD = 0.6675088405609131, Gattn=-0.011807342991232872, Dattn=-0.04994853958487511, Aattn=-0.02006554789841175
Epoch = 162, ErrG = 0.15079299608866373, ErrTri = 0.4418911933898926, ErrA = 0.6669775992631912, ErrD = 0.6680883566538492, Gattn=-0.011787527240812778, Dattn=-0.04997759312391281, Aattn=-0.020188095048069954
Epoch = 164, ErrG = 0.08177668352921803, ErrTri = 0.24216826260089874, ErrA = 0.6677520424127579, ErrD = 0.6688441435496012, Gattn=-0.011746217496693134, Dattn=-0.05001070722937584, Aattn=-0.020334459841251373
Epoch = 166, ErrG = 0.1105264921983083, ErrTri = 0.32114341855049133, ErrA = 0.6676631731291612, ErrD = 0.6682314872741699, Gattn=-0.01171492412686348, Dattn=-0.05003757402300835, Aattn=-0.020436042919754982
Epoch = 168, ErrG = 0.07819304863611858, ErrTri = 0.2211034893989563, ErrA = 0.6673837937414646, ErrD = 0.6680321854849657, Gattn=-0.011712619103491306, Dattn=-0.05

Epoch = 240, ErrG = 0.06343314051628113, ErrTri = 0.17051109671592712, ErrA = 0.6662570306410392, ErrD = 0.6675957838694254, Gattn=-0.011148124933242798, Dattn=-0.050880447030067444, Aattn=-0.024350441992282867
Epoch = 242, ErrG = 0.03076508641242981, ErrTri = 0.12920352816581726, ErrA = 0.6760872583836317, ErrD = 0.6685718602190415, Gattn=-0.011108435690402985, Dattn=-0.05089928209781647, Aattn=-0.02452228218317032
Epoch = 244, ErrG = 0.0406149427096049, ErrTri = 0.1906483769416809, ErrA = 0.667980999375383, ErrD = 0.6654911724229654, Gattn=-0.01110122911632061, Dattn=-0.050915610045194626, Aattn=-0.024672985076904297
Epoch = 246, ErrG = 0.04960868755976359, ErrTri = 0.152554452419281, ErrA = 0.6683341929068168, ErrD = 0.6685439745585123, Gattn=-0.01108759269118309, Dattn=-0.0509263314306736, Aattn=-0.024717384949326515
Epoch = 248, ErrG = 0.06105399131774902, ErrTri = 0.18426120281219482, ErrA = 0.6671962607651949, ErrD = 0.6683210531870524, Gattn=-0.011069265194237232, Dattn=-0.0509

Epoch = 320, ErrG = 0.026926105221112568, ErrTri = 0.14490027725696564, ErrA = 0.6673168887694677, ErrD = 0.6674735341221094, Gattn=-0.010641888715326786, Dattn=-0.051305003464221954, Aattn=-0.027669396251440048
Epoch = 322, ErrG = 0.015490168084700903, ErrTri = 0.04573426768183708, ErrA = 0.665820856889089, ErrD = 0.6601831782609224, Gattn=-0.010675662197172642, Dattn=-0.05130710452795029, Aattn=-0.02767622470855713
Epoch = 324, ErrG = 0.028827245036760967, ErrTri = 0.10632704198360443, ErrA = 0.6672664651026329, ErrD = 0.6675263550132513, Gattn=-0.010639852844178677, Dattn=-0.05131368339061737, Aattn=-0.027707574889063835
Epoch = 326, ErrG = 0.010874231656392416, ErrTri = 0.05341231822967529, ErrA = 0.6664720041056474, ErrD = 0.6677729946871599, Gattn=-0.01066367607563734, Dattn=-0.05132419988512993, Aattn=-0.027837183326482773
Epoch = 328, ErrG = 0.06360767285029094, ErrTri = 0.20423787832260132, ErrA = 0.6760392362872759, ErrD = 0.6653173758337895, Gattn=-0.01062257681041956, Dattn

Epoch = 400, ErrG = 0.021567632754643757, ErrTri = 0.12746253609657288, ErrA = 0.6599019343654314, ErrD = 0.6571185514330864, Gattn=-0.01032545417547226, Dattn=-0.0519290491938591, Aattn=-0.030334975570440292
Epoch = 402, ErrG = 0.04079435269037882, ErrTri = 0.05331122875213623, ErrA = 0.6625541479637226, ErrD = 0.6555084679275751, Gattn=-0.010306482203304768, Dattn=-0.0520239919424057, Aattn=-0.03038717806339264
Epoch = 404, ErrG = 0.028792386253674824, ErrTri = 0.06831564009189606, ErrA = 0.674206554889679, ErrD = 0.6595599465072155, Gattn=-0.0102873919531703, Dattn=-0.05208193138241768, Aattn=-0.030543651431798935
Epoch = 406, ErrG = 0.004470611612002055, ErrTri = 0.03739680349826813, ErrA = 0.6678738923122486, ErrD = 0.655048793181777, Gattn=-0.0102952029556036, Dattn=-0.052107617259025574, Aattn=-0.03067343309521675
Epoch = 408, ErrG = 0.04852007826169332, ErrTri = 0.07453939318656921, ErrA = 0.6637147789200147, ErrD = 0.6592970831940571, Gattn=-0.010261110961437225, Dattn=-0.0522

Epoch = 480, ErrG = 0.04912222425142924, ErrTri = 0.05676022171974182, ErrA = 0.5488464310765266, ErrD = 0.6545193611333767, Gattn=-0.010075208730995655, Dattn=-0.08983349055051804, Aattn=-0.04358437657356262
Epoch = 482, ErrG = 0.06257649511098862, ErrTri = 0.0672198012471199, ErrA = 0.5491865649819374, ErrD = 0.6330285668373108, Gattn=-0.010065283626317978, Dattn=-0.0904770940542221, Aattn=-0.043725091964006424
Epoch = 484, ErrG = 0.30302573430041474, ErrTri = 0.026455482468008995, ErrA = 0.6374905606110891, ErrD = 0.5661847845961651, Gattn=-0.010069114156067371, Dattn=-0.0912172868847847, Aattn=-0.043923716992139816
Epoch = 486, ErrG = 0.470584262162447, ErrTri = 0.03002079948782921, ErrA = 0.6097608841955662, ErrD = 0.5843923129141331, Gattn=-0.010053095407783985, Dattn=-0.09177210927009583, Aattn=-0.043994344770908356
Epoch = 488, ErrG = 0.1825961054613193, ErrTri = 0.02878015674650669, ErrA = 0.571100685124596, ErrD = 0.5436303373426199, Gattn=-0.010031440295279026, Dattn=-0.0920

Epoch = 560, ErrG = 0.36791494488716125, ErrTri = 0.22656866908073425, ErrA = 0.5459384309748808, ErrD = 0.6169107723981142, Gattn=-0.00985210295766592, Dattn=-0.10228914022445679, Aattn=-0.04870302230119705
Epoch = 562, ErrG = -0.04202384750048319, ErrTri = 0.11790540814399719, ErrA = 0.557707667350769, ErrD = 0.5708788732687632, Gattn=-0.009852294810116291, Dattn=-0.10251547396183014, Aattn=-0.04881814867258072
Epoch = 564, ErrG = 0.4410789559284846, ErrTri = 0.2288297861814499, ErrA = 0.5224854654322068, ErrD = 0.6153865456581116, Gattn=-0.009843547828495502, Dattn=-0.10289381444454193, Aattn=-0.049043238162994385
Epoch = 566, ErrG = 0.4450821578502655, ErrTri = 0.17591990530490875, ErrA = 0.41933204730351764, ErrD = 0.6039719227701426, Gattn=-0.009832552634179592, Dattn=-0.10308627039194107, Aattn=-0.0490964911878109
Epoch = 568, ErrG = 0.03848965217669805, ErrTri = 0.10629246383905411, ErrA = 0.4614044427871704, ErrD = 0.5783117612202963, Gattn=-0.009824334643781185, Dattn=-0.1033

Epoch = 640, ErrG = 0.10186848292748134, ErrTri = 0.05243108421564102, ErrA = 0.45602838198343915, ErrD = 0.5280692875385284, Gattn=-0.009716290980577469, Dattn=-0.11052778363227844, Aattn=-0.05559198930859566
Epoch = 642, ErrG = -0.008460988601048788, ErrTri = 0.11227580904960632, ErrA = 0.43526018410921097, ErrD = 0.6677882671356201, Gattn=-0.009705569595098495, Dattn=-0.11067089438438416, Aattn=-0.055719271302223206
Epoch = 644, ErrG = -0.0007916390895843506, ErrTri = 0.12355789542198181, ErrA = 0.434973140557607, ErrD = 0.5586033860842387, Gattn=-0.009694783017039299, Dattn=-0.1108897253870964, Aattn=-0.055882662534713745
Epoch = 646, ErrG = 0.22210735827684402, ErrTri = 0.04333358258008957, ErrA = 0.4343907907605171, ErrD = 0.525574350108703, Gattn=-0.0097075579687953, Dattn=-0.11107366532087326, Aattn=-0.05607428029179573
Epoch = 648, ErrG = 0.2745477110147476, ErrTri = 0.1526755839586258, ErrA = 0.368615984916687, ErrD = 0.4923664319018523, Gattn=-0.009697654284536839, Dattn=-0.

# SA-GaitGAN with online triplet loss (semi-hard negative mining)

In [None]:
%load_ext autoreload
%autoreload 2
from torch.autograd import grad, Variable

# Requires a running visdom server:  python -m visdom.server
vis = visdom.Visdom(port=8097)
win = None   # visdom window for the loss curves; created on the first plot
win1 = None  # visdom window for the sample grid; created on the first plot
netg = NetG(nc=1)
netd = NetD(nc=1)
neta = NetA(nc=1)
device = th.device("cuda:1")

# Weight init: N(0, 0.02) for (transposed) convolutions; N(1, 0.02) weight and
# zero bias for BatchNorm2d.
# chain.from_iterable flattens the three child iterators into one stream of
# modules. The previous `itertools.chain(all_mods, [g1, g2, g3])` yielded the
# three generator objects themselves, so no isinstance check ever matched and
# no layer was actually initialised.
# NOTE(review): only the children of each network's FIRST child are visited;
# this assumes that child is the container holding the conv/BN layers —
# confirm against model_backbone.model_SAGAN1_1.
all_mods = itertools.chain.from_iterable([
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children(),
])
for mod in all_mods:
    if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d) and mod.affine:
        # Non-affine BatchNorm has weight/bias == None; nothing to initialise.
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)

epoches = 700
glr = 0.00001      # generator learning rate
dlr = 0.00004      # learning rate for both discriminators (D and A)
real_label = 1
fake_label = 0
batchSize = 32
target = '090'
lambda_gp = 0
beta1 = 0
beta2 = 0.9
margin = 10        # triplet margin (also used by the semi-hard negative selector)
n_g = 2            # D and A are updated when i % n_g == 0
# G is updated when i % n_d == 0, so n_d must be >= 1: the previous value 0
# made `i % n_d` raise ZeroDivisionError on the first batch. 1 = every batch.
n_d = 1

netg = netg.to(device)
netd = netd.to(device)
neta = neta.to(device)
netg.train()
netd.train()
neta.train()
dataset = CASIABDataset(data_dir=Data_Dir,target=target)
train_loader = th.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True, num_workers=2, pin_memory=False)


optimG = optim.Adam(netg.parameters(), lr=glr, betas=(beta1, beta2))
optimD = optim.Adam(netd.parameters(), lr=dlr, betas=(beta1, beta2))
optimA = optim.Adam(neta.parameters(), lr=dlr, betas=(beta1, beta2))
# optimG = optim.RMSprop(netg.parameters(), lr=lr)
# optimD = optim.RMSprop(netd.parameters(), lr=lr)
# optimA = optim.RMSprop(neta.parameters(), lr=lr)

print("write parameter log...")
with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
    myfile.write('Epoch = {}, margin = {}, dlr = {}, glr={}, batchsize = {}, beta1={}, beta2={}, n_d = {}, n_g={} target={},lambda_gp={} \n'.format(
        epoches, margin, dlr, glr, batchSize, beta1, beta2, n_d, n_g, target, lambda_gp))

# Lowest ErrG (lossG / 3) seen so far; used to decide when to save a
# "lowest_snapshot" checkpoint.
low_loss = 10

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty of critic ``D``.

    Each sample is linearly interpolated between ``real_samples`` and
    ``fake_samples`` with its own random weight alpha ~ U[0, 1). The penalty is
    ``mean((||dD(x_hat)/dx_hat||_2 - 1) ** 2)`` over the batch.

    Args:
        D: critic callable. It may return either a score tensor or a tuple
            whose first element is the score tensor. The nets in this notebook
            are unpacked as ``out, attn = net(x)``.
        real_samples: batch-first real tensor. ``alpha`` is shaped
            (N, 1, 1, 1), so 4-D (N, C, H, W) input is expected.
        fake_samples: generated tensor with the same shape as ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        back-propagated into ``D``'s parameters.
    """
    n = real_samples.size(0)
    # One alpha per sample, broadcast over the remaining dims. It is sized from
    # the actual batch (not the global batchSize), so the last, smaller
    # DataLoader batch also works. It is created on the inputs' device and
    # dtype rather than a global device.
    alpha = th.rand((n, 1, 1, 1), device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates = D(interpolates)
    if isinstance(d_interpolates, (tuple, list)):
        # Critics here return (score, attention); only the score is penalised.
        d_interpolates = d_interpolates[0]

    gradients = grad(outputs=d_interpolates,
                     inputs=interpolates,
                     grad_outputs=th.ones_like(d_interpolates),  # matches any score shape
                     create_graph=True,
                     retain_graph=True,
                     only_inputs=True)[0]
    gradients = gradients.reshape(n, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# Online triplet loss with semi-hard negative mining over the combined batch.
# It is built once here rather than on every generator step.
# NOTE(review): assumes OnlineTripletLoss / SemihardNegativeTripletSelector keep
# no per-call state — confirm in onlineTripletloss.py / selector.py.
loss_fn = OnlineTripletLoss(margin, SemihardNegativeTripletSelector(margin))

print('Training starts')
for epoch in range(1, epoches + 1):
    # NOTE(review): this unpacks 8 items per batch, while the previous cell's
    # loop unpacks 5 from the same CASIABDataset class — confirm the dataset
    # currently yields (ass_label, noass_label, noass_img, img, ass_img,
    # label_neg, label_anc, label_pos).
    for i, (ass_label, noass_label, noass_img, img, ass_img, label_neg, label_anc, label_pos) in enumerate(train_loader):
        # Stack the negative, anchor and positive inputs (and their labels)
        # along the batch axis for online triplet mining.
        com_img = th.cat((noass_img, img, ass_img), 0)
        com_label = th.cat((label_neg, label_anc, label_pos), 0)
        com_img = com_img.to(device).to(th.float32)
        com_label = com_label.to(device).to(th.float32)

        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        noass_img = noass_img.to(device).to(th.float32)
        img = img.to(device).to(th.float32)
        ass_img = ass_img.to(device).to(th.float32)

        # D and A are updated on batches where i % n_g == 0 (every batch if n_g <= 1).
        if n_g <= 1 or i % n_g == 0:
            # update D (hinge loss): both real targets are pushed to score >= 1,
            # the generated image to score <= -1.
            lossD = 0   # float total, for logging
            lossD_ = 0  # tensor total, for backprop
            optimD.zero_grad()
            d_out_assreal, dr1 = netd(ass_label)
            d_loss_assreal = F.relu(1.0 - d_out_assreal).mean()
            lossD_ += d_loss_assreal
            lossD += d_loss_assreal.item()

            d_out_noassreal, dr2 = netd(noass_label)
            d_loss_noassreal = F.relu(1.0 - d_out_noassreal).mean()
            lossD_ += d_loss_noassreal
            lossD += d_loss_noassreal.item()

            # The generator output is only used detached here, so don't build its graph.
            with th.no_grad():
                fake, code = netg(img)
            d_out_fake, df1 = netd(fake)
            d_loss_fake = F.relu(1.0 + d_out_fake).mean()
            lossD_ += d_loss_fake
            lossD += d_loss_fake.item()
            # gradient_penalty = compute_gradient_penalty(netd, ass_label.data, fake.data)
            lossD_ = lossD_ / 3
            lossD_.backward()
            optimD.step()

            # update A: scores (input, target) pairs stacked along dim 1. Only
            # (img, ass_label) is pushed to >= 1; (img, noass_label) and
            # (img, G(img)) are both pushed to <= -1.
            lossA = 0
            lossA_ = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            with th.no_grad():
                faked, code = netg(img)
            faked = th.cat((img, faked), 1)

            d_out_assreal, dr1 = neta(assd)
            d_loss_assreal = F.relu(1.0 - d_out_assreal).mean()
            lossA += d_loss_assreal.item()
            lossA_ += d_loss_assreal

            d_out_noassreal, dr2 = neta(noassd)
            d_loss_noassreal = F.relu(1.0 + d_out_noassreal).mean()
            lossA_ += d_loss_noassreal
            lossA += d_loss_noassreal.item()

            d_out_faked, df3 = neta(faked)
            d_loss_faked = F.relu(1.0 + d_out_faked).mean()
            lossA_ += d_loss_faked
            lossA += d_loss_faked.item()
            # gradient_penalty = compute_gradient_penalty(neta, assd.data, faked.data)
            lossA_ = lossA_ / 3
            lossA_.backward()
            optimA.step()

        # G is updated on batches where i % n_d == 0. With n_d <= 1 it runs on
        # every batch; this also stops n_d = 0 from raising ZeroDivisionError.
        if n_d <= 1 or i % n_d == 0:
            lossG = 0
            lossG_ = 0
            optimG.zero_grad()
            fake, A = netg(img)
            g_out_fake, _ = netd(fake)
            g_loss_fake = -g_out_fake.mean()
            lossG += g_loss_fake.item()
            lossG_ += g_loss_fake

            faked = th.cat((img, fake), 1)
            g_out_faked, _ = neta(faked)
            g_loss_faked = -g_out_faked.mean()
            lossG += g_loss_faked.item()
            lossG_ += g_loss_faked

            # Online triplet loss on the generator's second output for the
            # combined batch. (Earlier variants applied F.triplet_margin_loss /
            # TripletLoss to netg(img), netg(ass_img), netg(noass_img).)
            _, com = netg(com_img)
            lossTriplet, len_triplet = loss_fn(com, com_label)
            lossG += lossTriplet.item()
            lossG_ += lossTriplet

            lossG_ = lossG_ / 3
            # Nothing backpropagates through this graph again, so it need not be retained.
            lossG_.backward()
            optimG.step()

        # Every 20 batches, preview [G(img); ass_label; img] concatenated along dim 2.
        if i % 20 == 0:
            with th.no_grad():
                netg.eval()   # switch to eval mode for the preview
                fake, _ = netg(img)
                netg.train()  # switch back to training mode
            # Rescale to [0, 255]; assumes the tensors are in [-1, 1] — TODO confirm.
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # The logged losses come from the last D/A/G update of the epoch. Each is
    # divided by 3: lossD and lossA sum 3 hinge terms, while lossG sums only 2
    # adversarial terms plus the triplet term.
    if epoch % 2 == 0:
        if win is None:
            win = vis.line(X=np.array([[epoch, epoch, epoch]]),
                           Y=np.array([[lossG/3, lossA/3, lossD/3]]),
                           opts=dict(
                               title=Model_Name,
                               ylabel='loss',
                               xlabel='epochs',
                               legend=['lossG', 'lossA', 'lossD']
                           ))
        else:
            vis.line(X=np.array([[epoch, epoch, epoch]]),
                     Y=np.array([[lossG/3, lossA/3, lossD/3]]),
                     win=win,
                     update='append')

        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, lossG/3, lossA/3, lossD/3
            ))
        print('Epoch = {}, ErrG = {}, ErrTri = {}, ErrA = {}, ErrD = {}, Gattn={}, Dattn={}, Aattn={}'.format(
            epoch, lossG/3, lossTriplet.item(), lossA/3, lossD/3, netg.attn.gamma.item(), netd.attn.gamma.item(), neta.attn.gamma.item()
        ))

    # Periodic checkpoint every 20 epochs from epoch 300 on.
    if (epoch >= 300) and epoch % 20 == 0:
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/snapshot'+ Model_Name +'_%d.t7' % epoch)

    # From epoch 550 on, also keep a checkpoint whenever ErrG hits a new low.
    if (epoch >= 550) and (lossG/3) < low_loss:
        low_loss = lossG/3
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/lowest_snapshot'+ Model_Name +'_%d.t7' % epoch)
        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, lossG/3, lossA/3, lossD/3
            ))