In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Define Encoder (Maps image to latent space)
class Encoder(nn.Module):
    """Fully connected encoder that turns a 28x28 image into a latent vector.

    The network flattens its input, applies a 784 -> 256 linear layer with a
    ReLU, and projects the result to ``latent_dim`` features.

    Args:
        latent_dim: Number of features in the latent vector returned by
            ``forward``.
    """

    def __init__(self, latent_dim: int):
        super().__init__()
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the ``(batch, latent_dim)`` latent representation of ``img``."""
        latent = self.model(img)
        return latent

# Define Head (Performs classification and predicts loss)
class Head(nn.Module):
    """Pair of linear heads that read the same latent vector.

    ``classifier`` produces 10 class logits and ``loss_predictor`` produces a
    single scalar per sample.

    Args:
        latent_dim: Number of features in the incoming latent vector.
    """

    def __init__(self, latent_dim: int):
        super().__init__()
        # Creation order is kept (classifier first) so parameter
        # initialisation and state_dict key order stay the same.
        self.classifier = nn.Linear(latent_dim, 10)
        self.loss_predictor = nn.Linear(latent_dim, 1)

    def forward(self, latent):
        """Return ``(class_logits, loss_prediction)`` for ``latent``.

        For a ``(batch, latent_dim)`` input the outputs have shapes
        ``(batch, 10)`` and ``(batch, 1)``.
        """
        return self.classifier(latent), self.loss_predictor(latent)

# Define Refiner (Updates latent vector if predicted loss is high)
class Refiner(nn.Module):
    """Two-layer MLP that maps a latent vector to a new one of the same width.

    Args:
        latent_dim: Number of features in both the input and the output.
    """

    def __init__(self, latent_dim: int):
        super().__init__()
        hidden_dim = 256
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, latent):
        """Return the network's output for ``latent`` (same shape as the input)."""
        refined = self.model(latent)
        return refined

# Hyperparameters
seed = 42             # RNG seed for weight initialisation and batch shuffling
latent_dim = 64       # width of the latent vector produced by the encoder
batch_size = 64
lr = 0.001            # Adam learning rate
epochs = 10
loss_threshold = 0.5  # predicted-loss value above which a sample's latent is refined

# Seed PyTorch's global RNG before any model or DataLoader is created so that a
# fresh "Restart Kernel -> Run All" starts from the same weights and sees the
# batches in the same order. (GPU kernels may still add small nondeterminism.)
torch.manual_seed(seed)

# Load MNIST dataset (downloaded into ./data on first use).
# ToTensor yields values in [0, 1]; Normalize([0.5], [0.5]) rescales them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize models
encoder = Encoder(latent_dim)
head = Head(latent_dim)
refiner = Refiner(latent_dim)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
head.to(device)
refiner.to(device)


# Optimizers
# A single Adam instance updates all three modules jointly.
optimizer = optim.Adam(list(encoder.parameters()) + list(head.parameters()) + list(refiner.parameters()), lr=lr)

# Loss functions
classification_loss = nn.CrossEntropyLoss()
loss_prediction_loss = nn.MSELoss()

# Training loop
#
# For every batch: encode -> classify and predict the loss -> repeatedly
# replace the latent of samples whose predicted loss exceeds `loss_threshold`
# with the refiner's output -> backpropagate the classification loss plus the
# loss-prediction error of the final outputs.
max_refine_steps = 10  # number of refinement passes applied to every batch

for epoch in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        optimizer.zero_grad()

        imgs = imgs.to(device)
        labels = labels.to(device)

        # Encode image to latent space
        latent = encoder(imgs)

        # Get classification output and predicted loss
        class_output, predicted_loss = head(latent)

        # Refine the latent vector of each sample whose predicted loss is too
        # high. `predicted_loss` has one column, so the mask broadcasts across
        # the latent dimension and swaps whole rows; other rows are left as is.
        for _ in range(max_refine_steps):
            latent_tmp = refiner(latent)
            latent = torch.where(predicted_loss > loss_threshold, latent_tmp, latent)
            class_output, predicted_loss = head(latent)

        # Per-sample cross-entropy of the final outputs. Its mean is the
        # classification objective (identical to the default "mean" reduction).
        per_sample_loss = nn.functional.cross_entropy(class_output, labels, reduction="none")
        actual_loss = per_sample_loss.mean()

        # Train the loss predictor against each sample's own loss. Using the
        # scalar batch mean as the target would be broadcast to every sample and
        # teach the head to output the same value for all of them.
        # squeeze(-1) drops only the trailing column, so a batch of one keeps
        # its batch dimension. The target is detached so this term does not
        # push gradients through the classifier via the target.
        loss_pred_loss = loss_prediction_loss(predicted_loss.squeeze(-1), per_sample_loss.detach())

        # Total loss
        total_loss = actual_loss + loss_pred_loss
        total_loss.backward()
        optimizer.step()

    # Reports the losses of the last batch of the epoch.
    print(f"Epoch {epoch+1}/{epochs} | Classification Loss: {actual_loss.item():.4f} | Loss Prediction Loss: {loss_pred_loss.item():.4f}")


In [28]:
# Flat index of the largest entry in `predicted_loss`
torch.argmax(predicted_loss)

tensor(4, device='cuda:0')

In [29]:
import plotly.express as px  # NOTE(review): consider moving this to the imports cell at the top

# Show the image whose predicted loss is highest. The index is computed here
# rather than hard-coded from an earlier cell's printed output, so the plot
# stays correct when the notebook is re-run and the batch changes.
# Assumes `imgs` and `predicted_loss` come from the same batch.
worst_idx = int(predicted_loss.argmax())
px.imshow(imgs[worst_idx].squeeze().cpu().numpy())

In [4]:
# Display the raw loss predictions held in `predicted_loss` (one row per sample)
predicted_loss

tensor([[-0.1854],
        [-0.1140],
        [-0.2096],
        [-0.0622],
        [-0.1045],
        [-0.1606],
        [-0.2371],
        [-0.2042],
        [-0.1474],
        [-0.1966],
        [-0.1145],
        [-0.2203],
        [-0.1291],
        [-0.1771],
        [-0.0921],
        [-0.2438],
        [-0.2116],
        [-0.2724],
        [-0.1708],
        [-0.1663],
        [-0.1306],
        [-0.1410],
        [-0.1281],
        [-0.2084],
        [-0.0121],
        [-0.2215],
        [-0.0538],
        [-0.0116],
        [-0.0874],
        [-0.1568],
        [-0.1303],
        [-0.1734],
        [-0.1641],
        [-0.1585],
        [-0.1278],
        [-0.0612],
        [-0.1441],
        [-0.1778],
        [-0.1813],
        [-0.1492],
        [-0.1351],
        [-0.1792],
        [-0.2016],
        [-0.1699],
        [-0.1871],
        [-0.1851],
        [-0.1729],
        [-0.1695],
        [-0.1823],
        [-0.1069],
        [-0.1124],
        [-0.0905],
        [-0.