## Model: Regular autoencoder with contractive loss (implemented in two ways)

In [10]:
import torch
from torch import nn
from torch import optim
from torch.autograd import grad, Variable,functional
import torchvision
import torchvision.transforms as transforms


# ---------------------------------Model----------------------------------------------
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        # Keep the leading (batch) dimension; let view() infer the rest.
        batch_size = input.size(0)
        return input.view(batch_size, -1)
class Reshape(nn.Module):
    """Reshape a batch into feature maps of a fixed per-sample shape.

    Args:
        shape: Per-sample target shape (without the batch dimension).
            Defaults to ``(512, 8, 8)``, so ``Reshape()`` keeps the
            previously hard-coded behaviour.
    """

    def __init__(self, shape=(512, 8, 8)):
        super().__init__()
        self.target_shape = tuple(shape)

    def forward(self, input):
        # -1 lets view() infer the batch size from the number of elements.
        return input.view(-1, *self.target_shape)

class Regular_AE(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder halves the spatial size three times (kernel 4, stride 2,
    padding 1) and then flattens into a ``Linear(4*4*512, laten_dims)``
    layer, so it expects 3-channel 32x32 inputs. The decoder projects the
    code to ``8*8*512`` values, reshapes them to 512x8x8 and doubles the
    spatial size twice, ending in a sigmoid over 3 output channels.

    Args:
        laten_dims: Size of the latent code produced by the encoder.
    """

    def __init__(self,laten_dims=64):
        super().__init__()
        self.laten_dims = laten_dims

        # Encoder layers are created before decoder layers; the order fixes
        # both the Sequential indices and the parameter initialisation order.
        encoder_layers = [
            nn.Conv2d(3, 128, 4, stride=2, padding=1),      # 3x32x32   -> 128x16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),    # 128x16x16 -> 256x8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),    # 256x8x8   -> 512x4x4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            Flatten(),
            nn.Linear(4 * 4 * 512, self.laten_dims),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(laten_dims, 8 * 8 * 512),
            Reshape(),                                               # -> 512x8x8
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),    # -> 256x16x16
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),    # -> 128x32x32
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # 1x1 kernel with stride 1: only the channel count changes.
            nn.ConvTranspose2d(128, 3, 1, stride=1, padding=0),      # -> 3x32x32
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the input batch ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction
    


In [11]:
# Version 1: contractive loss computed one latent element at a time with autograd.grad
def ctr_lossv1(x, encoding):
    """Contractive penalty for one encoding, one latent element at a time.

    For every element of ``encoding`` the gradient with respect to ``x`` is
    taken with ``torch.autograd.grad`` and the squares of all its entries
    are added to the running total.

    Args:
        x: Tensor with ``requires_grad=True`` from which ``encoding`` was
            computed.
        encoding: Iterable of scalar tensors (for example a 1-D latent
            vector) that depend on ``x``.

    Returns:
        The accumulated sum of squared gradient entries as a tensor, or the
        float ``0.0`` when ``encoding`` is empty.
    """
    # Optional debug hook: if a module-level ``list_grad`` list exists, every
    # gradient is appended to it as before. Looking it up through globals()
    # avoids a NameError when the function is used without that list.
    grad_log = globals().get("list_grad")

    contractive_loss = 0.0
    for encoding_i in encoding:
        # create_graph=True keeps the result differentiable and retains the
        # graph so the next latent element can be differentiated as well.
        (grad_x,) = grad(encoding_i, x, create_graph=True)
        if grad_log is not None:
            grad_log.append(grad_x)
        # Sum of squares over the whole gradient; equal to summing the
        # squared norms of its slices, without the intermediate square root.
        contractive_loss = contractive_loss + grad_x.pow(2).sum()
    return contractive_loss

def ctr_sum(x_batch,encoding_batch):
    """Add up ``ctr_lossv1`` over the encodings of a batch.

    One term is computed per index in ``range(len(x_batch))``; each term
    passes the whole ``x_batch`` together with ``encoding_batch[i]``.

    Returns:
        The accumulated total, starting from the float ``0.0``.
    """
    per_sample = (
        ctr_lossv1(x_batch, encoding_batch[idx])
        for idx in range(len(x_batch))
    )
    # Start from 0.0 so an empty batch yields a float.
    return sum(per_sample, 0.0)
# Version 3: contractive loss computed with PyTorch's autograd.functional.jacobian

def ctr_lossv3(x,function):
    """Contractive penalty computed from the Jacobian of ``function`` at ``x``.

    The Jacobian is obtained with ``torch.autograd.functional.jacobian``
    (``create_graph=True`` so the result stays differentiable). The squared
    norm of each slice along its first dimension is then added up.

    Args:
        x: Input passed to ``function``.
        function: Callable whose Jacobian with respect to ``x`` is taken.

    Returns:
        The sum of the squared slice norms (the integer ``0`` if the
        Jacobian has no slices).
    """
    jac = functional.jacobian(function, x, create_graph=True)
    total = 0
    for jac_slice in jac:
        total = total + jac_slice.norm() ** 2
    return total



### Test the loss functions


In [12]:
# Smoke test: run one optimisation step on a random batch and compare the two
# contractive-loss implementations (loop-based ctr_sum vs. Jacobian-based
# ctr_lossv3).
net = Regular_AE()
# Module-level list that ctr_lossv1 appends each computed gradient to.
list_grad = []

criterion = nn.BCELoss()
# NOTE(review): criterion2 and num_epochs are not used anywhere in this cell.
criterion2 = nn.CrossEntropyLoss()
num_epochs = 100
wd = 5e-04
optimizer = optim.Adam(net.parameters(),weight_decay=wd, lr=0.001)


# for i, data in enumerate(trainloader,0):
# Random batch of 2 inputs of shape 3x32x32; requires_grad=True so gradients
# with respect to the input can be taken by the loss functions below.
x = Variable(torch.rand((2,3,32,32)), requires_grad=True)            
optimizer.zero_grad()
encoding, decoding = net(x)


# fc: loop-based penalty built from the encoding of the forward pass above.
# It is only printed; backward() is never called on it.
fc = ctr_sum(x,encoding)
# fc3: Jacobian-based penalty; net.encoder is evaluated again inside
# functional.jacobian.
fc3 = ctr_lossv3(x, net.encoder)
fc3.backward()


# Reconstruction loss against the detached input. There is no zero_grad()
# between the two backward() calls, so these gradients are added to the ones
# already accumulated from fc3.
loss = criterion(decoding,x.detach())
loss.backward()
# Drop the gradient accumulated on the input itself before stepping.
x.grad = None
optimizer.step()




# The two printed values are expected to agree up to floating-point error.
print(fc)
print(fc3)



tensor(516.3995, grad_fn=<AddBackward0>)
tensor(516.3872, grad_fn=<AddBackward0>)
