# train

In [None]:
%load_ext autoreload
%autoreload 2

# speed up the loading of the training data
import cv2
import os
import numpy as np
import torch as th
import itertools
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from onlineTripletloss import *
from selector import *
from model_SAGAN1_1 import NetG, NetD, NetA
from dataset2Loader_newtriplet import CASIABDataset
import torch.optim as optim
import visdom
from torchvision.utils import make_grid

Data_Dir = '../GaitRecognition/GEI_CASIA_B/gei/'
Model_Name = 'Model_64x64_TripletSAGAN_90_trial8'
Model_dir = './Transform_Model/' + Model_Name
# makedirs also creates the parent ./Transform_Model/ on a fresh checkout
# (os.mkdir raised FileNotFoundError there); exist_ok keeps re-runs safe.
os.makedirs(Model_dir, exist_ok=True)


# The visdom server must already be running: python -m visdom.server
vis = visdom.Visdom(port=8097)
win = None   # loss-curve window handle, created on the first plot
win1 = None  # sample-image window handle, created on the first plot
netg = NetG(nc=1)
netd = NetD(nc=1)
neta = NetA(nc=1)
# NOTE(review): hard-coded to the second GPU; change if only one device exists.
device = th.device("cuda:1")

# Weight init, N(0, 0.02) for conv weights and N(1, 0.02)/0 for BatchNorm
# affine params, applied to the direct children of each network's first
# submodule.
# The previous version chained a *list of iterators* onto an empty chain, so
# the loop only ever saw the three generator objects and no layer was
# initialised. chain.from_iterable flattens them into the actual modules.
all_mods = itertools.chain.from_iterable([
    list(netg.children())[0].children(),
    list(netd.children())[0].children(),
    list(neta.children())[0].children(),
])
for mod in all_mods:
    if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(mod.weight, 0.0, 0.02)
    elif isinstance(mod, nn.BatchNorm2d) and mod.affine:
        # affine=False BatchNorm has weight/bias == None; nothing to init.
        init.normal_(mod.weight, 1.0, 0.02)
        init.constant_(mod.bias, 0.0)

epoches = 700
glr = 0.00001   # generator learning rate
dlr = 0.00004   # learning rate for both discriminators (D and A)
real_label = 1
fake_label = 0
batchSize = 32
n_critic = 0
target = '090'  # target view angle the generator maps every input to
lambda_gp = 0   # gradient-penalty weight (0 -> penalty unused)
beta1 = 0
beta2 = 0.9
margin = 10     # triplet-loss margin
g_k = 2         # D and A are updated on every g_k-th iteration, G on every one

netg = netg.to(device)
netd = netd.to(device)
neta = neta.to(device)
netg.train()
netd.train()
neta.train()
dataset = CASIABDataset(data_dir=Data_Dir, target=target)
train_loader = th.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True, num_workers=2, pin_memory=False)


optimG = optim.Adam(netg.parameters(), lr=glr, betas=(beta1, beta2))
optimD = optim.Adam(netd.parameters(), lr=dlr, betas=(beta1, beta2))
optimA = optim.Adam(neta.parameters(), lr=dlr, betas=(beta1, beta2))

# Record the hyper-parameters of this run at the top of the snapshot log.
print("write parameter log...")
with open(Model_dir + "/snapshot_log.txt", "a") as myfile:
    myfile.write('Epoch = {}, margin = {}, dlr = {}, glr={}, g_k={}, batchsize = {}, beta1={}, beta2={}, n_critic = {}, target={},lambda_gp={} \n'.format(
        epoches, margin, dlr, glr, g_k, batchSize, beta1, beta2, n_critic, target, lambda_gp))

# Best (lowest) logged generator loss so far; used for "lowest_snapshot" saves.
low_loss = 10

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``.

    ``x_hat`` is a random per-sample linear interpolation between
    ``real_samples`` and ``fake_samples``.

    Args:
        D: discriminator. It may return either a score tensor or a tuple/list
            whose first element is the score (netd/neta in this notebook return
            a 2-tuple, e.g. ``d_out, dr1 = netd(x)``).
        real_samples: tensor of shape ``(N, ...)``.
        fake_samples: tensor with the same shape as ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be added
        to a discriminator loss and backpropagated.
    """
    n = real_samples.size(0)
    # One mixing weight per sample, broadcast over the remaining dims. Batch
    # size, device and dtype come from the input rather than globals, so a
    # short last batch or a CPU tensor also works.
    alpha = th.rand((n,) + (1,) * (real_samples.dim() - 1),
                    device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates = D(interpolates)
    if isinstance(d_interpolates, (tuple, list)):
        # Only the score (first element) enters the penalty.
        d_interpolates = d_interpolates[0]

    gradients = th.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        # ones_like matches whatever score shape D produces ((N,), (N, 1), ...).
        grad_outputs=th.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(n, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# Main training loop.
# D: hinge-loss discriminator on single target-view GEIs.
# A: hinge-loss discriminator on (source, target) pairs concatenated on dim 1.
# G: maps a source GEI to the target view. It is trained against D and A plus
#    a triplet constraint on its second output.
# compute_gradient_penalty is defined above but currently unused (lambda_gp = 0).
print('Training starts')
for epoch in range(1, epoches + 1):
    for i, (ass_label, noass_label, noass_img, img, ass_img) in enumerate(train_loader):
        # Presumed batch layout (matches the CASIABDataset cell in this
        # notebook; verify against dataset2Loader_newtriplet):
        #   ass_label   - target-view GEI of the same subject as `img`
        #   noass_label - target-view GEI of a different subject
        #   noass_img   - another GEI of that different subject
        #   img         - source GEI to transform
        #   ass_img     - another GEI of the same subject as `img`
        ass_label = ass_label.to(device).to(th.float32)
        noass_label = noass_label.to(device).to(th.float32)
        noass_img = noass_img.to(device).to(th.float32)
        img = img.to(device).to(th.float32)
        ass_img = ass_img.to(device).to(th.float32)

        if i % g_k == 0:
            # ---- update D: real target-view GEIs vs generated ones ----
            lossD = 0   # Python float, for logging
            lossD_ = 0  # tensor, for backprop
            optimD.zero_grad()
            d_out_assreal, _ = netd(ass_label)
            d_loss_assreal = F.relu(1.0 - d_out_assreal).mean()
            lossD_ += d_loss_assreal
            lossD += d_loss_assreal.item()

            d_out_noassreal, _ = netd(noass_label)
            d_loss_noassreal = F.relu(1.0 - d_out_noassreal).mean()
            lossD_ += d_loss_noassreal
            lossD += d_loss_noassreal.item()

            fake, code = netg(img)
            d_out_fake, _ = netd(fake.detach())  # detach: no G gradients here
            d_loss_fake = F.relu(1.0 + d_out_fake).mean()
            lossD_ += d_loss_fake
            lossD += d_loss_fake.item()

            lossD_ = lossD_ / 3
            lossD_.backward()
            optimD.step()

            # ---- update A: (source, target) pair discriminator ----
            # Same-subject pairs are pushed towards "real"; different-subject
            # pairs and (source, generated) pairs are pushed towards "fake".
            lossA = 0
            lossA_ = 0
            optimA.zero_grad()
            assd = th.cat((img, ass_label), 1)
            noassd = th.cat((img, noass_label), 1)
            faked, code = netg(img)
            faked = th.cat((img, faked.detach()), 1)

            d_out_assreal, _ = neta(assd)
            d_loss_assreal = F.relu(1.0 - d_out_assreal).mean()
            lossA += d_loss_assreal.item()
            lossA_ += d_loss_assreal

            d_out_noassreal, _ = neta(noassd)
            d_loss_noassreal = F.relu(1.0 + d_out_noassreal).mean()
            lossA_ += d_loss_noassreal
            lossA += d_loss_noassreal.item()

            d_out_faked, _ = neta(faked)
            d_loss_faked = F.relu(1.0 + d_out_faked).mean()
            lossA_ += d_loss_faked
            lossA += d_loss_faked.item()

            lossA_ = lossA_ / 3
            lossA_.backward()
            optimA.step()

        # ---- update G (every iteration) ----
        lossG = 0
        lossG_ = 0
        optimG.zero_grad()
        fake, A = netg(img)
        g_out_fake, _ = netd(fake)
        g_loss_fake = -g_out_fake.mean()
        lossG += g_loss_fake.item()
        lossG_ += g_loss_fake

        faked = th.cat((img, fake), 1)
        g_out_faked, _ = neta(faked)
        g_loss_faked = -g_out_faked.mean()
        lossG += g_loss_faked.item()
        lossG_ += g_loss_faked

        # Triplet constraint on the generator's second output: the anchor
        # (from `img`) should be closer to the same-subject positive (from
        # `ass_img`) than to the different-subject negative (from `noass_img`).
        _, P = netg(ass_img)
        _, N = netg(noass_img)
        lossTriplet = F.triplet_margin_loss(A, P, N, margin=margin)
        lossG_ += lossTriplet
        lossG += lossTriplet.item()

        lossG_ = lossG_ / 3
        # retain_graph is unnecessary: nothing backpropagates through this
        # graph again, and retaining it kept the G/D/A activations alive into
        # the next iteration.
        lossG_.backward()
        optimG.step()

        if i % 20 == 0:
            with th.no_grad():
                netg.eval()   # switch to eval mode for the preview
                fake, _ = netg(img)
                netg.train()  # switch back to training mode
            # Map [-1, 1] to [0, 255] for display. The inputs are in [-1, 1];
            # this assumes the generator output uses the same range.
            fake = (fake + 1) / 2 * 255
            real = (ass_label + 1) / 2 * 255
            ori = (img + 1) / 2 * 255
            al = th.cat((fake, real, ori), 2)  # stack generated/target/source vertically
            display = make_grid(al, 20).cpu().numpy()
            if win1 is None:
                win1 = vis.image(display,
                                 opts=dict(title="train", caption='train'))
            else:
                vis.image(display, win=win1)

    # NOTE(review): lossG/lossA/lossD hold only the epoch's LAST update, not
    # an epoch average. The "lowest" checkpoint below is therefore chosen on a
    # single noisy batch.
    if epoch % 2 == 0:
        if win is None:
            win = vis.line(X=np.array([[epoch, epoch,
                                        epoch]]),
                           Y=np.array([[lossG/3, lossA/3, lossD/3]]),
                           opts=dict(
                               title=Model_Name,
                               ylabel='loss',
                               xlabel='epochs',
                               legend=['lossG', 'lossA', 'lossD']
                           ))
        else:
            vis.line(X=np.array([[epoch, epoch,
                                  epoch]]),
                     Y=np.array([[lossG/3, lossA/3, lossD/3]]),
                     win=win,
                     update='append')

        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, lossG/3, lossA/3, lossD/3
            ))
        print('Epoch = {}, ErrG = {}, ErrTri = {}, ErrA = {}, ErrD = {}, Gattn={}, Dattn={}, Aattn={}'.format(
            epoch, lossG/3, lossTriplet.item(), lossA/3, lossD/3, netg.attn.gamma.item(), netd.attn.gamma.item(), neta.attn.gamma.item()
        ))

    # Periodic checkpoints in the second half of training.
    if (epoch >= 300) and epoch % 20 == 0:
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/snapshot'+ Model_Name +'_%d.t7' % epoch)

    # Keep a checkpoint whenever the logged generator loss hits a new low.
    if (epoch >= 550) and (lossG/3) < low_loss:
        low_loss = lossG/3
        state = {
            'netA': neta.state_dict(),
            'netG': netg.state_dict(),
            'netD': netd.state_dict()
        }
        th.save(state, Model_dir+'/lowest_snapshot'+ Model_Name +'_%d.t7' % epoch)
        with open(Model_dir+"/snapshot_log.txt", "a") as myfile:
            myfile.write('lower_lossG Epoch = {}, ErrG = {}, ErrA = {}, ErrD = {} \n'.format(
                epoch, lossG/3, lossA/3, lossD/3
            ))

# dataset loader

In [None]:
import torch as th
import torch.utils.data as data
import cv2
import numpy as np
import os
import glob
import random
# Fix every RNG used by the data pipeline so sampling is reproducible:
# torch (CPU default generator plus the CUDA generator) uses seed 29, while
# numpy's global RNG and the stdlib `random` module use seed 0.
th.manual_seed(29)
th.cuda.manual_seed(29)
np.random.seed(0)
random.seed(0)

def loadImage(path, size=64):
    """Load a grayscale image and return it as a ``(1, size, size)`` tensor in [-1, 1].

    Pixels are divided by the maximum value of the image's integer dtype
    (8-bit input -> [0, 1]). The image is resized so its shorter side equals
    ``size`` (aspect ratio kept), cropped to the top-left ``size x size``
    window, and mapped to [-1, 1]. The leading singleton dimension is the
    single grayscale channel.

    Args:
        path: image file path readable by ``cv2.imread``.
        size: output side length in pixels (default 64, the previously
            hard-coded value).

    Returns:
        float64 tensor of shape ``(1, size, size)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    inImage = cv2.imread(path, 0)  # flag 0 -> load as single-channel grayscale
    if inImage is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('cannot read image: %s' % path)
    info = np.iinfo(inImage.dtype)  # e.g. uint8 -> max 255
    # Normalise to [0, 1]. np.float was removed in NumPy 1.24; np.float64 is
    # the identical type.
    inImage = inImage.astype(np.float64) / info.max

    ih, iw = inImage.shape[:2]
    # cv2.resize takes (width, height); scale the shorter side to `size`.
    if iw <= ih:
        inImage = cv2.resize(inImage, (size, int(size * ih / iw)))
    else:
        inImage = cv2.resize(inImage, (int(size * iw / ih), size))  # e.g. (160,80) -> (128,64)
    inImage = inImage[0:size, 0:size]

    # [0, 1] -> [-1, 1]; unsqueeze(0) adds the grayscale channel dimension.
    return th.from_numpy(2 * inImage - 1).unsqueeze(0)


class CASIABDataset(data.Dataset):
    """Randomly sampled CASIA-B GEI tuples for cross-view training with a triplet term.

    ``index`` is ignored. Every ``__getitem__`` call rejection-samples a fresh
    tuple from ``data_dir`` using the global torch RNG, and ``__len__`` only
    fixes how many draws make up one epoch.

    Files are looked up as ``data_dir + '<id>/<cond>/<id>-<cond>-<angle>.png'``.

    Each item is ``(img1, img2, img2_1, img3, img3_1)``, all built by
    ``loadImage``:
        img1   -- subject A, an 'nm-*' condition, ``target`` view (target GT)
        img2   -- subject B != A, an 'nm-*' condition, ``target`` view
        img2_1 -- subject B, an 'nm-*' condition, random view (file != img2)
        img3   -- subject A, any condition, random view (source image)
        img3_1 -- subject A, an 'nm-*' condition, random view (file != img3)
    """

    # Cap on rejection-sampling attempts, so a wrong ``data_dir`` or a
    # ``target`` with no files raises instead of hanging a DataLoader worker.
    _MAX_TRIES = 10000

    def __init__(self, data_dir, target, length=6400):
        """
        Args:
            data_dir: dataset root. It is joined to relative paths by plain
                string concatenation, so it must end with a path separator.
            target: target view angle string, one of ``self.angles`` (e.g. '090').
            length: nominal epoch size returned by ``__len__`` (default 6400,
                the previously hard-coded value).
        """
        self.data_dir = data_dir
        self.ids = np.arange(1, 63)  # subject ids 1-62
        # Indices 4..9 are the 'nm-*' conditions; several draws below only use those.
        self.cond = ['bg-01', 'bg-02', 'cl-01', 'cl-02',
                     'nm-01', 'nm-02', 'nm-03', 'nm-04',
                     'nm-05', 'nm-06']
        # '090' was added to the original 10-angle list.
        self.angles = ['000', '018', '036', '054', '072', '090',
                       '108', '126', '144', '162', '180']
        self.n_id = 62
        self.n_cond = len(self.cond)
        self.n_ang = len(self.angles)
        print("n_con=", self.n_cond, ',n_ang=', self.n_ang)
        self.target = target
        print('target = ', self.target)
        self.length = length

    @staticmethod
    def _randint(low, high):
        """Draw one integer uniformly from [low, high) with the global torch RNG."""
        return int(th.randint(low, high, (1,)).item())

    @staticmethod
    def _rel_path(subject, cond, angle):
        """Relative path '<id>/<cond>/<id>-<cond>-<angle>.png'."""
        return subject + '/' + cond + '/' + subject + '-' + cond + '-' + angle + '.png'

    def _exists(self, rel_path):
        """True if ``data_dir + rel_path`` exists on disk."""
        return os.path.exists(self.data_dir + rel_path)

    def _give_up(self, what):
        """Build the error raised when rejection sampling exhausts its tries."""
        return FileNotFoundError('no %s found under %r after %d tries'
                                 % (what, self.data_dir, self._MAX_TRIES))

    def __getitem__(self, index):
        # RNG draws happen in the same order as in the original loops, so a
        # fixed seed reproduces the same samples.

        # r1: subject A (id1) at the target view, 'nm-*' condition.
        for _ in range(self._MAX_TRIES):
            id1 = '%03d' % (self._randint(0, self.n_id) + 1)
            cond1 = self.cond[self._randint(4, self.n_cond)]
            r1 = self._rel_path(id1, cond1, self.target)
            if self._exists(r1):
                break
        else:
            raise self._give_up('target-view sample (r1)')

        # r2: another subject (id2 != id1) at the target view; r2_1: the same
        # subject at a random view, and a different file from r2.
        for _ in range(self._MAX_TRIES):
            id2 = '%03d' % (self._randint(0, self.n_id) + 1)
            cond2 = self.cond[self._randint(4, self.n_cond)]
            r2 = self._rel_path(id2, cond2, self.target)
            cond2_1 = self.cond[self._randint(4, self.n_cond)]
            angle2_1 = self.angles[self._randint(0, self.n_ang)]
            r2_1 = self._rel_path(id2, cond2_1, angle2_1)
            if (id2 != id1 and r2 != r2_1
                    and self._exists(r2) and self._exists(r2_1)):
                break
        else:
            raise self._give_up('different-subject pair (r2, r2_1)')

        # r3: source image of id1 (any condition, random view); r3_1: another
        # 'nm-*' sample of id1 at a random view, and a different file from r3.
        for _ in range(self._MAX_TRIES):
            angle = self.angles[self._randint(0, self.n_ang)]
            cond3 = self.cond[self._randint(0, self.n_cond)]
            r3 = self._rel_path(id1, cond3, angle)
            cond3_1 = self.cond[self._randint(4, self.n_cond)]
            angle3_1 = self.angles[self._randint(0, self.n_ang)]
            r3_1 = self._rel_path(id1, cond3_1, angle3_1)
            if r3 != r3_1 and self._exists(r3) and self._exists(r3_1):
                break
        else:
            raise self._give_up('same-subject pair (r3, r3_1)')

        img1 = loadImage(self.data_dir + r1)
        img2 = loadImage(self.data_dir + r2)
        img2_1 = loadImage(self.data_dir + r2_1)
        img3 = loadImage(self.data_dir + r3)
        img3_1 = loadImage(self.data_dir + r3_1)
        return img1, img2, img2_1, img3, img3_1

    def __len__(self):
        # Nominal epoch size only; items are sampled at random regardless of index.
        return self.length