# In [None]:
import argparse
import os
import numpy as np
import math
import sys
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.utils as vutils
import numpy as np
import torch
import torch.nn as nn
import os
import csv
from PIL import Image, ImageEnhance
import numbers
import pandas as pd
import torchvision.transforms as transforms
import torch.utils.data as data
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import autograd

# Image shape constant; presumably (channels, height, width) -- only referenced
# from commented-out code in this file.
img_shape = (1, 128, 128)

# Whether a CUDA device is usable by this PyTorch installation.
cuda = torch.cuda.is_available()

def weights_init_uniform(m):
    """Initialise one layer in place; intended for ``Module.apply``.

    ``nn.Conv2d`` and ``nn.Linear`` layers get Kaiming-uniform weights and a
    constant bias of 0.1.  Every other module type is left untouched.

    Args:
        m: The module visited by ``Module.apply``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    # ``kaiming_uniform_`` is the in-place replacement for the deprecated
    # ``kaiming_uniform`` alias.
    nn.init.kaiming_uniform_(m.weight)

    # Layers created with ``bias=False`` have ``bias is None``; skip them
    # instead of raising AttributeError.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.1)


# Block32 / Block64 / Block128 / Block256 are named after the number of output
# channels of the 3x3 convolutions inside each block; the trailing comment on
# each layer gives the feature-map size after that convolution.
class Block32(nn.Module):
    """First stage: two unpadded 3x3 convolutions producing 32 channels.

    Every convolution is followed by batch normalisation and an in-place
    ReLU.  The layers live in ``self.block`` (an ``nn.Sequential``).
    """

    # (in_channels, out_channels, stride) per conv -> BN -> ReLU triple.
    # Sizes in the comments are for a 1 x 128 x 128 input.
    _CONV_SPECS = (
        (1, 32, 1),   # 32 * 126 * 126
        (32, 32, 1),  # 32 * 124 * 124
    )

    def __init__(self):
        super().__init__()
        stages = []
        for c_in, c_out, step in self._CONV_SPECS:
            stages.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=0),
                nn.BatchNorm2d(c_out, eps=1e-5),
                nn.ReLU(inplace=True),
            ))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.block(x)


class Block64(nn.Module):
    """Second stage: a stride-2 3x3 convolution then two stride-1 ones, 64 channels.

    Every convolution is unpadded and followed by batch normalisation and an
    in-place ReLU.  The layers live in ``self.block`` (an ``nn.Sequential``).
    """

    # (in_channels, out_channels, stride) per conv -> BN -> ReLU triple.
    # Sizes in the comments are for a 32 x 124 x 124 input.
    _CONV_SPECS = (
        (32, 64, 2),  # 64 * 61 * 61
        (64, 64, 1),  # 64 * 59 * 59
        (64, 64, 1),  # 64 * 57 * 57
    )

    def __init__(self):
        super().__init__()
        stages = []
        for c_in, c_out, step in self._CONV_SPECS:
            stages.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=0),
                nn.BatchNorm2d(c_out, eps=1e-5),
                nn.ReLU(inplace=True),
            ))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.block(x)


class Block128(nn.Module):
    """Third stage: a stride-2 3x3 convolution then three stride-1 ones, 128 channels.

    Every convolution is unpadded and followed by batch normalisation and an
    in-place ReLU.  The layers live in ``self.block`` (an ``nn.Sequential``).
    """

    # (in_channels, out_channels, stride) per conv -> BN -> ReLU triple.
    # Sizes in the comments are for a 64 x 57 x 57 input.
    _CONV_SPECS = (
        (64, 128, 2),   # 128 * 28 * 28
        (128, 128, 1),  # 128 * 26 * 26
        (128, 128, 1),  # 128 * 24 * 24
        (128, 128, 1),  # 128 * 22 * 22
    )

    def __init__(self):
        super().__init__()
        stages = []
        for c_in, c_out, step in self._CONV_SPECS:
            stages.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=0),
                nn.BatchNorm2d(c_out, eps=1e-5),
                nn.ReLU(inplace=True),
            ))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.block(x)


class Block256(nn.Module):
    """Fourth stage: a stride-2 3x3 convolution then three stride-1 ones, 256 channels.

    Every convolution is unpadded and followed by batch normalisation and an
    in-place ReLU.  The layers live in ``self.block`` (an ``nn.Sequential``).
    An additional 4x4 convolution (down to 256 * 1 * 1) existed only as
    commented-out code and is not part of the block.
    """

    # (in_channels, out_channels, stride) per conv -> BN -> ReLU triple.
    # Sizes in the comments are for a 128 x 22 x 22 input.
    _CONV_SPECS = (
        (128, 256, 2),  # 256 * 10 * 10
        (256, 256, 1),  # 256 * 8 * 8
        (256, 256, 1),  # 256 * 6 * 6
        (256, 256, 1),  # 256 * 4 * 4
    )

    def __init__(self):
        super().__init__()
        stages = []
        for c_in, c_out, step in self._CONV_SPECS:
            stages.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=0),
                nn.BatchNorm2d(c_out, eps=1e-5),
                nn.ReLU(inplace=True),
            ))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.block(x)


# Channel attention
class ChannelAttention(nn.Module):
    """Re-weight the channels of a feature map with a learned per-channel gate.

    Global average- and max-pooled descriptors are each passed through the
    same two 1x1 convolutions (``kernel -> kernel // 8 -> kernel``), summed,
    squashed with a sigmoid and multiplied onto the input.

    Args:
        kernel: Number of channels of the incoming feature map.
    """

    def __init__(self, kernel):
        super().__init__()
        squeezed = kernel // 8
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(kernel, squeezed, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(squeezed, kernel, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _bottleneck(self, pooled):
        # Shared squeeze -> ReLU -> expand path used by both pooled descriptors.
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the attention gate."""
        from_avg = self._bottleneck(self.avg_pool(x))
        from_max = self._bottleneck(self.max_pool(x))
        gate = self.sigmoid(from_avg + from_max)
        return x * gate


# Spatial attention (combining CA with SA was tried, but the result was not as good as expected)
class SpatialAttention(nn.Module):
    """Re-weight the spatial positions of a feature map with a learned gate.

    The per-pixel mean and max over the channel axis are stacked into a
    2-channel map, convolved down to one channel, squashed with a sigmoid and
    multiplied onto the input.

    Args:
        kernel_size: Size of the gating convolution; must be 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # Padding chosen so the gate keeps the input's spatial size.
        pad = 1 if kernel_size != 7 else 3
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled position-wise by the attention gate."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        stats = torch.cat((channel_mean, channel_max), dim=1)
        gate = self.sigmoid(self.conv1(stats))
        return x * gate


class Siamese_VGG(nn.Module):
    """Two-input classifier whose branches share one VGG-style trunk.

    Both inputs go through the same four blocks.  After every block the two
    feature maps are global-average-pooled and added element-wise; the four
    pooled vectors (32 + 64 + 128 + 256 values) are concatenated and fed to a
    dropout + linear head with 2 outputs.  Between blocks each branch gets
    its channel-attention output added to it.

    Args:
        drop: Dropout probability of the classification head.
        use_w_init: When true, apply ``weights_init_uniform`` to the modules
            that exist at that point (the four blocks and the head).  The
            attention modules are created afterwards and keep their default
            initialisation.
    """

    def __init__(self, drop, use_w_init=True):
        super().__init__()

        # Convolutional trunk shared by both inputs.
        self.block32 = Block32()
        self.block64 = Block64()
        self.block128 = Block128()
        self.block256 = Block256()

        # Single fully connected layer: all pooled features -> 2 logits
        # (binary classification).
        pooled_width = 32 + 64 + 128 + 256
        self.final = nn.Sequential(
            nn.Dropout(p=drop),
            nn.Linear(pooled_width, 2)
        )

        # Weight / bias initialisation (runs before the attention modules exist).
        if use_w_init:
            self.apply(weights_init_uniform)

        # Channel attention, one module per block.  ``attention256`` is
        # registered but never called in ``forward``.
        self.attention32 = ChannelAttention(32)
        self.attention64 = ChannelAttention(64)
        self.attention128 = ChannelAttention(128)
        self.attention256 = ChannelAttention(256)

    @staticmethod
    def _pooled_sum(a, b):
        # Global average pool both maps, flatten to (batch, channels), add.
        pooled_a = F.adaptive_avg_pool2d(a, (1, 1))
        pooled_b = F.adaptive_avg_pool2d(b, (1, 1))
        return pooled_a.view(pooled_a.size(0), -1) + pooled_b.view(pooled_b.size(0), -1)

    def forward(self, x1, x2):
        """Classify the pair ``(x1, x2)``.

        Returns:
            A tuple ``(logits, feats_final)``: the head output and the
            concatenated pooled features it was computed from.
        """
        stages = (
            (self.block32, self.attention32),
            (self.block64, self.attention64),
            (self.block128, self.attention128),
            (self.block256, None),  # no attention after the last block
        )

        left, right = x1, x2
        pooled = []
        for trunk, attention in stages:
            left = trunk(left)
            right = trunk(right)
            # Pooled features come from the block output itself, not from the
            # attention-augmented tensor.
            pooled.append(self._pooled_sum(left, right))
            if attention is not None:
                # Input of the next block = block output + channel attention.
                left = left + attention(left)
                right = right + attention(right)

        # Combine all pooled vectors into the final feature vector.
        feats_final = torch.cat(pooled, 1)
        return self.final(feats_final), feats_final
    
class Decoder(nn.Module):
    """Transposed-convolution decoder from a 2048-channel code to a 1-channel image.

    Six ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages are followed by a
    final ``ConvTranspose2d -> Tanh``.  Output size per layer follows
    ``k + (in - 1) * s - 2 * p``; for a code with 1 x 1 spatial size this
    gives a 1 x 299 x 299 output.
    """

    # (in_channels, out_channels, kernel, stride, padding) of the six
    # BatchNorm + ReLU stages; spatial sizes assume a 1 x 1 input.
    _UPSAMPLE_SPECS = (
        (2048, 128 * 32, 4, 1, 0),      # 4096 x 4 x 4
        (128 * 32, 128 * 16, 4, 2, 1),  # 2048 x 8 x 8
        (128 * 16, 128 * 8, 4, 2, 0),   # 1024 x 18 x 18
        (128 * 8, 128 * 4, 4, 2, 1),    # 512 x 36 x 36
        (128 * 4, 128 * 2, 4, 2, 0),    # 256 x 74 x 74
        (128 * 2, 128, 4, 2, 1),        # 128 x 148 x 148
    )

    def __init__(self):
        super().__init__()
        layers = []
        for c_in, c_out, k, s, p in self._UPSAMPLE_SPECS:
            layers.append(nn.ConvTranspose2d(c_in, c_out, k, s, p, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Output layer: 1 x 299 x 299, squashed to (-1, 1) by tanh.
        layers.append(nn.ConvTranspose2d(128, 1, 7, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        """Decode the code tensor ``z`` and return the generated image."""
        return self.main(z)

    
class Encoder(nn.Module):
    """Two independent, identically shaped convolutional encoders of one image.

    ``encoder1`` and ``encoder2`` each apply seven unpadded 3x3 stride-2
    convolutions (1 -> 32 -> 64 -> ... -> 2048 channels) with LeakyReLU(0.2);
    all but the first convolution are followed by batch normalisation.
    For a 1 x 299 x 299 input the spatial size shrinks
    299 -> 149 -> 74 -> 36 -> 17 -> 8 -> 3 -> 1, so each branch yields a
    2048 x 1 x 1 code.  A 128 x 128 input is too small: it reaches 1 x 1
    after six convolutions and the seventh 3x3 kernel no longer fits.
    """

    def __init__(self):
        super().__init__()
        # Built in this order so parameter initialisation order is unchanged.
        self.encoder1 = self._make_branch()
        self.encoder2 = self._make_branch()

    @staticmethod
    def _make_branch():
        # Stem: convolution + LeakyReLU, no batch norm.
        layers = [
            nn.Conv2d(1, 32, 3, 2, 0, bias=False),  # 32 x 149 x 149
            nn.LeakyReLU(0.2, inplace=True),
        ]
        width = 32
        # Six channel-doubling stages: 64, 128, 256, 512, 1024, 2048 channels
        # at 74, 36, 17, 8, 3 and 1 pixels per side respectively.
        for _ in range(6):
            layers.append(nn.Conv2d(width, width * 2, 3, 2, 0, bias=False))
            layers.append(nn.BatchNorm2d(width * 2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            width *= 2
        return nn.Sequential(*layers)

    def forward(self, img):
        """Return ``(encoder1(img), encoder2(img))``."""
        first_code = self.encoder1(img)
        second_code = self.encoder2(img)
        return first_code, second_code
    
    
def weights_init(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02).  Otherwise, modules whose class name contains ``BatchNorm``
    get weights drawn from N(1, 0.02) and a zero bias.  Anything else is
    left untouched.
    """
    kind = m.__class__.__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


def get_pair(I):
    """Cut a left and a mirrored right 128x128 patch out of image ``I``.

    Both patches start ``floor(width / 3)`` pixels below the top edge.  The
    left patch is taken at the left border; the right patch is taken at the
    right border and flipped horizontally.

    Args:
        I: Image object providing ``size``, ``crop`` and ``transpose``
            (a PIL image in this file).

    Returns:
        ``(left_patch, flipped_right_patch)``.
    """
    width = I.size[0]
    # The vertical offset is derived from the width, not the height.
    top = int(np.floor(width / 3))
    patch = 128
    bottom = top + patch

    left_patch = I.crop([0, top, patch, bottom])
    right_patch = I.crop([width - patch, top, width, bottom])
    right_patch = right_patch.transpose(Image.FLIP_LEFT_RIGHT)

    return left_patch, right_patch


class KneeGradingDataset(data.Dataset):
    """Dataset of image pairs listed in a CSV index file.

    Each CSV row has four columns: ``image_a, label_a, image_b, label_b``.
    Image files are looked up at ``<dataset>/<first character of file
    name>/<file name>``.

    Args:
        dataset: Root directory holding the CSV index and the image folders.
        transform: Callable applied to every image / patch that is returned.
        augment: Callable applied to each full image right after loading.
        stage: ``'train'`` or ``'valid'``; selects which CSV index is read.

    Raises:
        ValueError: If ``stage`` is not a supported stage name.
    """

    # CSV index file read for each supported stage.
    _STAGE_FILES = {
        'train': "paper_0_2_100000.csv",
        'valid': "trainData_0_2.csv",
    }

    def __init__(self, dataset, transform, augment, stage='train'):
        super(KneeGradingDataset, self).__init__()
        self.dataset = dataset
        self.transform = transform
        self.augment = augment
        self.stage = stage
        if stage not in self._STAGE_FILES:
            # Previously an unknown stage left ``images`` / ``labels`` unset
            # and only failed later, on first access.
            raise ValueError(
                "stage must be one of %s, got %r" % (sorted(self._STAGE_FILES), stage)
            )
        self.images, self.labels = self.load_csv(self._STAGE_FILES[stage])

    def load_csv(self, filename):
        """Read ``filename`` (relative to the dataset root).

        Returns:
            ``(images, labels)`` where ``images[i]`` is ``[name_a, name_b]``
            and ``labels[i]`` is ``[int(label_a), int(label_b)]``.
        """
        images, labels = [], []
        # newline='' is what the csv module expects for files it reads.
        with open(os.path.join(self.dataset, filename), newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                img_0, l_0, img_2, l_2 = row
                images.append([img_0, img_2])
                labels.append([int(l_0), int(l_2)])
        return images, labels

    def __getitem__(self, index):
        """Load pair ``index``.

        Returns:
            ``(img, img2, l, m, l2, m2, label_0, label_2)``: the two
            transformed full images, the transformed ``get_pair`` patches of
            the first and of the second image, and the two integer labels.
        """
        img_0, img_2 = self.images[index]
        label_0, label_2 = self.labels[index]

        # Images live in a sub-folder named after the first character of
        # the file name.
        fname = os.path.join(self.dataset, img_0[0], img_0)
        fname2 = os.path.join(self.dataset, img_2[0], img_2)
        img = self.augment(Image.open(fname))
        img2 = self.augment(Image.open(fname2))

        # Patches are cut from the augmented (not yet transformed) images.
        l, m = get_pair(img)
        l2, m2 = get_pair(img2)

        l = self.transform(l)
        m = self.transform(m)
        l2 = self.transform(l2)
        m2 = self.transform(m2)

        img = self.transform(img)
        img2 = self.transform(img2)
        return img, img2, l, m, l2, m2, label_0, label_2

    def __len__(self):
        # Was hard-coded to 100000, which over-runs shorter CSV files and
        # silently truncates longer ones.
        return len(self.images)


class CenterCrop(object):
    """Callable that crops the central region out of an image.

    Args:
        size: Either a single number (square crop, converted with ``int``)
            or a ``(width, height)`` pair that is stored as given.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, img):
        """Return ``img.crop(box)`` with ``box`` centred in the image.

        ``img`` must expose ``size`` as ``(width, height)`` and a ``crop``
        method taking a ``(left, upper, right, lower)`` box.
        """
        width, height = img.size
        crop_w, crop_h = self.size
        left = int(round((width - crop_w) / 2.))
        upper = int(round((height - crop_h) / 2.))
        box = (left, upper, left + crop_w, upper + crop_h)
        return img.crop(box)


class distance_loss(nn.Module):
    """Ratio of summed variances to squared mean difference of two tensors.

    ``loss = (var(x1) + var(x2)) / ((mean(x1) - mean(x2)) ** 2 + eps)``

    Means and variances are taken over all elements of each tensor, so the
    result is a 0-dim tensor.

    Args:
        eps: Small constant added to the denominator so that equal means
            give a large finite loss instead of ``inf`` / ``nan``.
    """

    def __init__(self, eps=1e-8):
        super(distance_loss, self).__init__()
        self.eps = eps

    def forward(self, x1, x2):
        """Return the scalar loss for tensors ``x1`` and ``x2``."""
        spread = torch.var(x1) + torch.var(x2)
        gap = torch.square(torch.mean(x1) - torch.mean(x2))
        return spread / (gap + self.eps)

def get_pair_Tensor(I):
    """Cut the two side patches out of a batch of image tensors.

    Rows 100..227 are used for both patches.  The first returned tensor is
    columns 171..298 flipped along the last axis; the second is columns
    0..127 unchanged.  Every (sample, channel) plane becomes its own entry,
    so an input of shape ``(B, C, H, W)`` yields two tensors of shape
    ``(B * C, 1, 128, 128)``.

    Both outputs stay on the device of ``I``.  The previous per-sample
    ``torch.cat`` loop was quadratic in the batch size and moved only the
    first output to the CPU.  The second output may share memory with ``I``.

    NOTE(review): ``get_pair`` returns (left, flipped right) and starts at
    row ``floor(width / 3)`` (99 for a 299-pixel-wide image), whereas this
    function returns (flipped right, left) starting at row 100 -- confirm
    whether the differences are intended.

    Args:
        I: 4-D tensor; the fixed indices assume at least 228 rows and
            299 columns.

    Returns:
        ``(tensor_l, tensor_m)`` as described above.
    """
    rows = slice(100, 228)
    right_flipped = torch.flip(I[:, :, rows, 171:299], [-1])
    left = I[:, :, rows, 0:128]

    planes = I.shape[0] * I.shape[1]
    tensor_l = right_flipped.reshape(planes, 1, *right_flipped.shape[-2:])
    tensor_m = left.reshape(planes, 1, *left.shape[-2:])
    return tensor_l, tensor_m

# Build the Siamese classifier ("discriminator2"), the encoder and the decoder.
# discriminator = Discriminator()
discriminator2 = Siamese_VGG(0.2)
encoder = Encoder()
decoder = Decoder()

# Index of the CUDA device to use, and the matching torch device
# (falls back to the CPU when CUDA is unavailable).
CUDA = 0
device = torch.device("cuda:" + str(CUDA) if (torch.cuda.is_available()) else "cpu")

# Move every network to the target device, then re-initialise it with
# weights_init.  The initialisation used to live inside ``if cuda:`` and was
# therefore skipped entirely on CPU-only machines.
discriminator2.to(device)
discriminator2.apply(weights_init)
encoder.to(device)
encoder.apply(weights_init)
decoder.to(device)
decoder.apply(weights_init)

# Configure data loader
train_cats_length = []
val_length = []
test_length = []

# PIL image -> tensor conversion applied to every returned image / patch.
transf_tens = transforms.Compose([
    transforms.ToTensor()
])
# Central 299x299 crop applied to each full image before patches are cut.
augment_transforms = transforms.Compose([
    CenterCrop(299)
])

train_ds = KneeGradingDataset('./OAI_m', transform=transf_tens, augment=augment_transforms, stage='train')
train_loader = data.DataLoader(train_ds, batch_size=64, shuffle=True)

# One Adam optimizer per network.
#optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=0.0002)
optimizer_D_ = torch.optim.Adam(discriminator2.parameters(),lr=0.0001)
optimizer_E = torch.optim.Adam(encoder.parameters(),lr=0.0001)
optimizer_DE = torch.optim.Adam(decoder.parameters(),lr=0.0001)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

# Global count of processed batches (across epochs).
batches_done = 0

# Loss functions: reconstruction (MSE) and the key / unrelated code distance.
mse = nn.MSELoss()
loss_Dis = distance_loss()

# Per-batch loss histories filled by the training loop.
loss_Ds = []
loss_Gs = []
loss_Rs = []
loss_Distances = []
loss_CEs = []
# Main training loop.  Per batch:
#   1. encode both images into a "key" code and an "unrelated" code, decode
#      the plain sums (reconstructions) and the sums with swapped key codes;
#   2. update the Siamese classifier on the real patch pairs;
#   3. update encoder + decoder on reconstruction, code-distance and the
#      classification loss of the key-exchanged images against swapped labels.
for epoch in range(100):
    for i, (img, img2, l, m, l2, m2, label_0, label_2) in enumerate(train_loader):
        # Move the batch to the training device (``Variable`` wrappers are
        # no longer needed; tensors carry autograd state themselves).
        img0 = img.to(device)
        img2 = img2.to(device)
        l0 = l.to(device)
        l2 = l2.to(device)
        m0 = m.to(device)
        m2 = m2.to(device)
        label_0 = label_0.long().to(device)
        label_2 = label_2.long().to(device)

        # One encoder pass per image yields both codes (the encoder used to
        # be run twice per image just to pick one output each time).
        vector_Key0, vector_Unrelated0 = encoder(img0)
        vector_img0 = vector_Key0 + vector_Unrelated0

        vector_Key2, vector_Unrelated2 = encoder(img2)
        vector_img2 = vector_Key2 + vector_Unrelated2

        # Swap the key codes between the two images.
        vector_Key_Exchange_img0 = vector_Key2 + vector_Unrelated0
        vector_Key_Exchange_img2 = vector_Key0 + vector_Unrelated2

        R_img0 = decoder(vector_img0)
        R_img2 = decoder(vector_img2)

        Key_Exchange_img0 = decoder(vector_Key_Exchange_img0)
        Key_Exchange_img2 = decoder(vector_Key_Exchange_img2)

        # Cut the side patches once per generated image; ``.to(device)``
        # covers the case where a patch tensor comes back on the CPU.
        Key_Exchange_img0_l, Key_Exchange_img0_m = get_pair_Tensor(Key_Exchange_img0)
        Key_Exchange_img0_l = Key_Exchange_img0_l.to(device)
        Key_Exchange_img0_m = Key_Exchange_img0_m.to(device)

        Key_Exchange_img2_l, Key_Exchange_img2_m = get_pair_Tensor(Key_Exchange_img2)
        Key_Exchange_img2_l = Key_Exchange_img2_l.to(device)
        Key_Exchange_img2_m = Key_Exchange_img2_m.to(device)

        # ---- classifier update on real patch pairs ----
        optimizer_D_.zero_grad()

        for p in discriminator2.parameters():
            p.requires_grad = True

        loss_D1 = F.cross_entropy(discriminator2(l0,m0)[0],label_0)
        loss_D2 = F.cross_entropy(discriminator2(l2,m2)[0],label_2)
        loss_D = loss_D1 + loss_D2
        # Store a detached copy: appending the live tensor keeps the whole
        # autograd graph of every batch alive and leaks memory.
        loss_Ds.append(loss_D.detach())
        loss_D.backward()
        optimizer_D_.step()

        # ---- encoder / decoder update ----
        optimizer_E.zero_grad()
        optimizer_DE.zero_grad()

        # Freeze the classifier so only encoder and decoder receive gradients.
        for p in discriminator2.parameters():
            p.requires_grad = False

        loss_distance0 = loss_Dis(vector_Key0, vector_Unrelated0)
        loss_distance2 = loss_Dis(vector_Key2, vector_Unrelated2)

        loss_distance = loss_distance0 + loss_distance2
        loss_Distances.append(loss_distance.detach())

        loss_R0 = mse(img0, R_img0)
        loss_R2 = mse(img2, R_img2)

        loss_R = loss_R0 + loss_R2
        loss_Rs.append(loss_R.detach())

        # Key-exchanged images are scored against the label of the image
        # that donated the key code.
        loss_ce0 = F.cross_entropy(discriminator2(Key_Exchange_img0_l, Key_Exchange_img0_m)[0],label_2)
        loss_ce2 = F.cross_entropy(discriminator2(Key_Exchange_img2_l, Key_Exchange_img2_m)[0],label_0)

        loss_ce = loss_ce0 + loss_ce2
        loss_CEs.append(loss_ce.detach())

        loss_total = loss_R + 0.001 * loss_distance + 0.01 * loss_ce
        loss_total.backward()

        optimizer_E.step()
        optimizer_DE.step()

        print("[Epoch %d/%d] [Batch %d/%d] [loss_R: %f][loss_D: %f][loss_Dis: %f][loss_CE: %f]" % (epoch, 100, batches_done % len(train_loader), len(train_loader),loss_R.item(), loss_D.item(),loss_distance.item(),loss_ce.item()))
        batches_done += 1

    # Checkpoint every 10th epoch (not epoch 0).
    # NOTE(review): the file names say "epoch" but the numeric suffix is the
    # global batch counter; kept as-is so existing checkpoint names do not change.
    if epoch % 10 == 0 and epoch != 0:
        torch.save(encoder, 'KECAE_encoder_epoch_' + str(batches_done) + '.pth')
        torch.save(decoder,'KECAE_decoder_epoch_' + str(batches_done) + '.pth')
        torch.save(discriminator2,'KECAE_siamese_epoch_' + str(batches_done) + '.pth')