In [28]:
import torch.nn as nn
import torch

torch.manual_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # encoder
        self.input = nn.Linear(784,512)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(512,128)
        self.fc2 = nn.Linear(128,32)
        self.fc3 = nn.Linear(32,10)
        
        # decoder
        self.up_fc1 = nn.Linear(10,32)
        self.up_fc2 = nn.Linear(32,128)
        self.up_fc3 = nn.Linear(128,512)
        self.up_fc4 = nn.Linear(512,784)
        self.sigmoid = nn.Sigmoid()

        
    def encoder(self, x):
        x = self.input(x)
        x = self.relu(x)
        x = self.tanh(x) 
        x = self.fc1(x)
        x = self.relu(x)
        x = self.tanh(x) 
        x = self.fc2(x)
        x = self.tanh(x) 
        x = self.fc3(x)
        return x
        
    def decoder(self, x):
        x = self.up_fc1(x)
        x = self.tanh(x)
        x = self.up_fc2(x)
        x = self.tanh(x)
        x = self.up_fc3(x)
        x = self.tanh(x)
        x = self.up_fc4(x)
        x = self.sigmoid(x)
        return x
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [29]:
import torch.optim as optim

autoencoder = Autoencoder()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.003)
loss_function = nn.MSELoss()

In [30]:
def custom_loss(outputs, inputs,pred,truth , model_parameters):
    inputs = inputs.to(device)
    outputs = outputs.to(device)
    pred = pred.to(device)
    truth = truth.to(device)
    
    # Find indices where truth is equal to pred
    matching_indices = torch.nonzero(truth == pred)
    non_matching = torch.nonzero(truth != pred)

    same_labels =  torch.sum(torch.exp(- ((outputs[matching_indices] - inputs[matching_indices])).pow(2)))   / ((truth == pred).sum().item()+1)
    dif_labels =   torch.sum((1 - torch.exp(- ((outputs[non_matching] - inputs[non_matching])).pow(2))))  / ((truth != pred).sum().item()+1)
    weight_term =   same_labels * dif_labels
    cmse_loss =  (weight_term * ((pred - truth) ** 2)).mean()
    l2_penalty = 0.00001  * sum([(p**2).sum() for p in model_parameters])
    loss = cmse_loss + l2_penalty
    return loss

In [31]:
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.0,), (1.0,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=32)
test_loader = DataLoader(dataset=test_dataset, batch_size=32)

In [32]:
from torch.optim.lr_scheduler import StepLR


num_epochs = 10
autoencoder.to(device)
lr_scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(num_epochs):
    running_loss = 0.0

    for data in train_loader:
        inputs, truth = data
        inputs = inputs.to(device).view(-1, 784)
        optimizer.zero_grad()
        outputs = autoencoder(inputs)
#         pred  = autoencoder.encoder(inputs)
        loss = loss_function(outputs, inputs)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
    lr_scheduler.step() 
    with torch.no_grad():
        test_running_loss=0
        for data in test_loader:
            inputs, truth = data
            inputs = inputs.to(device).view(-1, 784)
            outputs = autoencoder(inputs)
            loss = loss_function(outputs, inputs)
            test_running_loss += loss.item()
    print(f"Epoch {epoch+1}, Train Loss:{running_loss}, Test Loss:{test_running_loss}")
            
        

Epoch 1, Train Loss:98.9900325126946, Test Loss:14.687195479869843
Epoch 2, Train Loss:69.62784658931196, Test Loss:11.956211220473051
Epoch 3, Train Loss:60.32480302080512, Test Loss:11.879587262868881
Epoch 4, Train Loss:53.515615133568645, Test Loss:10.101189261302352
Epoch 5, Train Loss:49.12290708348155, Test Loss:9.773617381229997
Epoch 6, Train Loss:46.14490120485425, Test Loss:9.50498066842556
Epoch 7, Train Loss:44.63449816033244, Test Loss:8.858461201190948
Epoch 8, Train Loss:43.16488500405103, Test Loss:8.494042042642832
Epoch 9, Train Loss:42.08986044768244, Test Loss:9.206453433260322
Epoch 10, Train Loss:41.420900595374405, Test Loss:8.610432829707861
