In [1]:
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


In [5]:
'''
Data Preprocessing

'''
#TODO Preprocess the data to a suitable format

import matplotlib.pyplot as plt
import torchvision

# Fetch the MNIST training split through torchvision (stored under ./data,
# downloaded if it is not there yet); every image is converted to a tensor.
mnist_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Pick one (image, label) pair at a random index of the dataset
image, label = mnist_dataset[np.random.randint(0, len(mnist_dataset))]

# # Optional visual check of the sampled image
# plt.imshow(image[0], cmap='gray')
# plt.title(f'Label: {label}')
# plt.show()

# Take the first channel and lay it out as rows of 28 * 28 values
test_image = image[0].view(-1, 28 * 28)


In [25]:
'''
Settings

'''
#TODO Define the settings of the IMSAT model.

# Trade-off weight between the mutual-information term and the smoothness
# regularization term of the objective.
# NOTE(review): not referenced by any other cell visible here yet.
lam = 0.1


In [4]:
'''
Deep Neural Network

'''

class NeuralNet(nn.Module):
    """
    Fully connected classifier: two hidden Linear -> BatchNorm -> ReLU stages
    followed by a linear output layer that produces unnormalized logits.

    Args:
    - input_dim: number of features per flattened sample (default 28 * 28)
    - hidden_dim: width of both hidden layers (default 1200)
    - num_classes: number of output logits (default 10)
    - bn_eps: epsilon used by every batch-normalization layer (default 2e-5)

    The defaults reproduce the original hard-coded 784-1200-1200-10 network.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=1200, num_classes=10, bn_eps=2e-5):
        super().__init__()

        # Layers are created in the same order as before so that a seeded
        # initialization yields the same weights.

        # First hidden layer, He-normal initialized for the ReLU that follows
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        self.bn1 = nn.BatchNorm1d(hidden_dim, eps=bn_eps)
        # NOTE(review): bn1_F / bn2_F are never used in forward(); they are kept
        # so the module's attributes and state_dict keys stay unchanged.
        self.bn1_F = nn.BatchNorm1d(hidden_dim, eps=bn_eps, affine=False)
        self.relu1 = nn.ReLU()

        # Second hidden layer, same recipe
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        init.kaiming_normal_(self.fc2.weight, nonlinearity='relu')
        self.bn2 = nn.BatchNorm1d(hidden_dim, eps=bn_eps)
        self.bn2_F = nn.BatchNorm1d(hidden_dim, eps=bn_eps, affine=False)
        self.relu2 = nn.ReLU()

        # Output layer; no activation follows, hence the 'linear' gain
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        init.kaiming_normal_(self.fc3.weight, nonlinearity='linear')

    def forward(self, x):
        """
        Run a batch through the network.

        Args:
        - x: batch of samples, either already flattened (batch_size x input_dim)
          or with extra trailing dimensions (e.g. batch_size x 1 x 28 x 28)

        Returns:
        - logits (batch_size x num_classes)
        """
        # Flatten image-shaped batches per sample. Inputs whose last dimension
        # already equals the feature size are passed through untouched, so every
        # input the original forward() accepted is handled exactly as before.
        if x.dim() > 2 and x.shape[-1] != self.fc1.in_features:
            x = x.flatten(start_dim=1)

        x = self.relu1(self.bn1(self.fc1(x)))
        x = self.relu2(self.bn2(self.fc2(x)))
        return self.fc3(x)


# Module-level network instance with the default layer sizes
net = NeuralNet()

In [22]:
'''
Mutual Information

'''

def shannon_entropy(probabilities: torch.Tensor) -> torch.Tensor:
    """
    Computes the Shannon entropy (in nats) of a tensor of probabilities. According to Eq. (9)

    Entries that are exactly zero contribute 0 to the sum (the limit of
    p * log(p) as p -> 0) instead of turning the result into NaN.

    Args:
    - probabilities: a PyTorch tensor of probabilities; all elements are summed

    Returns:
    - the Shannon entropy as a 0-dim PyTorch tensor
    """
    # torch.finfo only accepts floating dtypes
    if not probabilities.is_floating_point():
        probabilities = probabilities.to(torch.get_default_dtype())

    # Clamp only the argument of the log: for p >= tiny the value is unchanged,
    # for p == 0 the product is 0 * log(tiny) = 0, and the gradient stays finite.
    tiny = torch.finfo(probabilities.dtype).tiny
    return -torch.sum(probabilities * torch.log(probabilities.clamp_min(tiny)))


def mutual_information(probabilities: torch.Tensor, conditionals: torch.Tensor) -> torch.Tensor:
    """
    Calculate the mutual information between two discrete random variables. According to Eq. (7)

    Computed as the entropy of the marginal distribution minus the mean entropy
    of the conditional distributions.

    Parameters:
    probabilities (torch.Tensor): The marginal probabilities of the second random variable.
    conditionals (torch.Tensor): The conditional probabilities of each outcome of the second
        random variable given the first random variable. Either a single distribution (1D) or
        a batch with one distribution per row along the last dimension.

    Returns:
    torch.Tensor: The mutual information between the two random variables (0-dim tensor).
    """
    marg_entropy = shannon_entropy(probabilities)

    # shannon_entropy sums over every element, so for a batch it returns the
    # sum of the per-row entropies. Dividing by the number of rows gives the
    # mean conditional entropy; for a 1D input the divisor is 1.
    num_distributions = max(conditionals.shape[:-1].numel(), 1)
    cond_entropy = shannon_entropy(conditionals) / num_distributions

    return marg_entropy - cond_entropy

In [28]:
'''
Self-Augmented Training (SAT)

'''

import torch
import torch.nn.functional as F

def virtual_adversarial_perturbation(x, logit, eps=1.0, xi=1e-6, num_iters=1, model=None):
    """
    Calculate the virtual adversarial perturbation for a batch of input samples x.

    Starting from random noise, the noise is repeatedly replaced by the gradient
    (w.r.t. the noise) of the KL divergence between the prediction on x and the
    prediction on x + noise. The final direction is rescaled to norm eps.

    Args:
    - x: input samples (batch_size x input_dim)
    - logit: unnormalized log probabilities of the model for the input samples (batch_size x num_classes)
    - eps: perturbation size, i.e. the L2 norm of each returned perturbation (float)
    - xi: norm of the probe noise added to x while estimating the direction (float)
    - num_iters: number of iterations to use for computing the perturbation (int)
    - model: network mapping inputs to logits; required when num_iters > 0,
      because the gradient has to be taken through the model

    Returns:
    - d: virtual adversarial perturbation for the input samples (batch_size x input_dim)

    Raises:
    - ValueError: if num_iters > 0 and no model is given
    """
    if num_iters > 0 and model is None:
        raise ValueError(
            "virtual_adversarial_perturbation needs `model` when num_iters > 0"
        )

    # Prediction on the clean input, treated as a constant target
    target = F.softmax(logit.detach(), dim=1)
    clean_x = x.detach()

    # Sample random noise with the same shape as x
    d = torch.randn_like(x)
    for _ in range(num_iters):
        # Normalize the noise, scale it by xi and make it a leaf we can differentiate
        d = (xi * F.normalize(d, dim=1)).requires_grad_(True)
        # enable_grad keeps this working when the caller is inside torch.no_grad()
        with torch.enable_grad():
            # The noise is added to the INPUT and pushed through the model
            logit_hat = model(clean_x + d)
            # KL(target || prediction on the perturbed input)
            kl = F.kl_div(F.log_softmax(logit_hat, dim=1), target, reduction='batchmean')
        # Gradient w.r.t. the noise only; model parameters get no .grad from this
        d_grad, = torch.autograd.grad(kl, [d])
        d = d_grad.detach()

    # Rescale the final direction to the requested perturbation size
    return eps * F.normalize(d, dim=1)


def virtual_adversarial_training(model, x, y, eps=1.0, xi=1e-6, num_iters=1):
    """
    Apply virtual adversarial training to a batch of input samples x and labels y.

    Args:
    - model: neural network model
    - x: input samples (batch_size x input_dim)
    - y: labels for the input samples (batch_size)
    - eps: perturbation size (float)
    - xi: norm of the probe noise used while estimating the perturbation (float)
    - num_iters: number of iterations to use for computing the perturbation (int)

    Returns:
    - loss: total loss (sum of cross-entropy loss on original input and perturbed input)
      for the batch, as a 0-dim tensor that can be backpropagated
    """
    # Compute the logits of the model on the input
    logit = model(x)
    # The perturbation is computed from a detached copy, so `logit` itself stays
    # attached to the graph and the clean loss term can train the model
    vadv = virtual_adversarial_perturbation(
        x, logit.detach(), eps, xi, num_iters, model=model
    )
    # Compute the logits of the model on the perturbed input
    logit_p = model(x + vadv)
    # Cross-entropy on the original input plus cross-entropy on the perturbed input
    loss = F.cross_entropy(logit, y) + F.cross_entropy(logit_p, y)
    return loss


#TODO Implement Virtual Adversarial Training (VAT)

#TODO Implement regularization penalty according to Eq. (4), (6), (6).

'\nSelf-Augmented Training (SAT)\n\n'

In [24]:
'''

'''

# # For running net on single training example
# net.eval()

# cond_pr = F.softmax(net(test_image), dim=1)
# marg_pr = cond_pr # TODO Implement according to Eq. 15

# mutual_information(marg_pr, cond_pr)
# # shannon_entropy(marg_pr)











tensor(0., grad_fn=<SubBackward0>)

In [None]:
'''
Training

'''

In [None]:
'''
Evaluation Metric

'''
# TODO Implement the unsupervised clustering accuracy (ACC) according to Eq. (16)