# Import Lib

In [1]:
import torch
from torchvision import datasets, transforms
import torchvision 
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from sklearn.model_selection import train_test_split

# Data Generation 

In [2]:
# Configuration
num_bags = 1000
instances_per_bag = 10
label_digit = 7  # digit to detect
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor()
])
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Prepare the data
def create_mil_mnist_bags(dataset, num_bags=1000, instances_per_bag=10, target_digit=7):
    bags = []
    labels = []  # Two binary labels per bag

    data = dataset.data
    targets = dataset.targets

    for _ in range(num_bags):
        indices = np.random.choice(len(dataset), instances_per_bag, replace=False)
        images = data[indices]
        digits = targets[indices]

        # First label: at least one image of the target digit
        label_1 = (digits == target_digit).any().item()
        # Second label: more than one image of the target digit
        label_2 = (digits == target_digit).sum().item() > 1

        bags.append(images)
        labels.append(torch.tensor([label_1, label_2], dtype=torch.float32))

    bags = torch.stack(bags)  # [num_bags, instances_per_bag, H, W]
    labels = torch.stack(labels)  # [num_bags, 2]

    return bags, labels

# Generate MIL-style bags
mil_bags, mil_labels = create_mil_mnist_bags(mnist_dataset, num_bags, instances_per_bag, label_digit)
mil_bags.shape, mil_labels.shape

(torch.Size([1000, 10, 28, 28]), torch.Size([1000, 2]))

# Attention Layer 

In [3]:
class AttentionLayer(nn.Module):
  def __init__(self, input_dim):
    super(AttentionLayer, self).__init__()
    self.input_dim = input_dim
    self.attention_weights = nn.Parameter(torch.randn(input_dim, 1))
    self.softmax = nn.Softmax(dim=1)
    self.fc = nn.Linear(input_dim, input_dim)

  def forward(self, x):
    # x: [batch_size, num_instances, input_dim]
    batch_size, num_instances, _ = x.size()
    x_flat = x.view(batch_size * num_instances, self.input_dim)
    
    # Compute attention scores
    attention_scores = torch.matmul(x_flat, self.attention_weights).view(batch_size, num_instances)
    
    # Apply softmax to get attention weights
    attention_weights = self.softmax(attention_scores)
    
    # Weighted sum of instances
    weighted_sum = torch.bmm(attention_weights.unsqueeze(1), x).squeeze(1)
    
    # Pass through a linear layer
    output = self.fc(weighted_sum)
    
    return output

# rbf kernel

In [4]:
class rbf_kernel(nn.Module):
    def __init__(self, input_dim, lengthscale_init=1.0):
        super().__init__()
        self.input_dim = input_dim
        # Use one lengthscale per dimension (ARD - Automatic Relevance Determination)
        self.log_lengthscale = nn.Parameter(torch.log(torch.ones(input_dim) * lengthscale_init))

    def forward(self, X1, X2):
        """
        Compute RBF (Gaussian) kernel matrix between X1 and X2 using ARD lengthscales.
        Args:
            X1: Tensor of shape (N1, D)
            X2: Tensor of shape (N2, D)
        Returns:
            Kernel matrix of shape (N1, N2)
        """
        # Ensure input is 2D
        if X1.ndimension() == 1:
            X1 = X1.unsqueeze(0)
        if X2.ndimension() == 1:
            X2 = X2.unsqueeze(0)

        # Scale by lengthscale (ARD: each dimension can have a different scale)
        X1_scaled = X1 / self.log_lengthscale.exp()
        X2_scaled = X2 / self.log_lengthscale.exp()

        # Compute squared Euclidean distance
        sqdist = torch.cdist(X1_scaled, X2_scaled, p=2).pow(2)

        # Compute RBF
        return torch.exp(-0.5 * sqdist)

# SVGP Model 

In [5]:
class MultiOutputSVGP(nn.Module):
    def __init__(self, input_dim, num_tasks, num_inducing):
        super().__init__()
        self.input_dim = input_dim
        self.num_tasks = num_tasks
        self.num_inducing = num_inducing

        # Shared inducing points Z
        self.Z = nn.Parameter(torch.randn(num_inducing, input_dim))

        # Variational parameters
        self.m = nn.Parameter(torch.zeros(num_inducing))  # Mean of variational distribution
        self.L = nn.Parameter(torch.eye(num_inducing))  # Lower-triangular Cholesky of S

        # Base kernel
        self.kernel = rbf_kernel(input_dim)

        # Coregionalization matrix B (T x T)
        self.B = nn.Parameter(torch.eye(num_tasks))
        
        # Register buffer for device tracking
        self.register_buffer('dummy', torch.zeros(1))

    def compute_coregionalized_kernel(self, X1, T1, X2, T2):
        """
        Coregionalized kernel: K((x,t), (x',t')) = k(x,x') * B[t,t']
        """
        # Get the device from model parameters
        device = self.dummy.device
        
        # Ensure T1 and T2 are on the correct device
        if isinstance(T1, torch.Tensor):
            T1 = T1.to(device)
        else:
            T1 = torch.tensor(T1, dtype=torch.long, device=device)
            
        if isinstance(T2, torch.Tensor):
            T2 = T2.to(device)
        else:
            T2 = torch.tensor(T2, dtype=torch.long, device=device)
            
        k_base = self.kernel(X1, X2)  # [N, M]
        B_selected = self.B[T1][:, T2]  # [N, M]
        return k_base * B_selected
    
    def forward(self, X, T, Y, full_n):
        """
        X: [B, D], input features
        T: [B], task indices
        Y: [B], binary labels (0 or 1)
        full_n: total number of training samples (for scaling ELBO)
        """
        # Get the device from model parameters
        device = self.dummy.device
        
        B = X.shape[0]  # batch size
        M = self.num_inducing

        # Compute kernel matrices
        Kxz = self.compute_coregionalized_kernel(X, T, self.Z, torch.zeros(M, dtype=torch.long, device=device))  # [B, M]
        Kzz = self.compute_coregionalized_kernel(self.Z, torch.zeros(M, dtype=torch.long, device=device),
                                                 self.Z, torch.zeros(M, dtype=torch.long, device=device)) + 1e-6 * torch.eye(M, device=device)

        # Compute predictive mean and variance of q(f)
        Lzz = torch.linalg.cholesky(Kzz)
        Kzz_inv = torch.cholesky_inverse(Lzz)

        S = self.L @ self.L.T  # Ensure positive-definite

        mean_f = Kxz @ Kzz_inv @ self.m  # [B]
        cov_f = (Kxz @ Kzz_inv @ S @ Kzz_inv @ Kxz.T).diag()  # [B]
        std_f = torch.sqrt(cov_f + 1e-6)

        # Monte Carlo sampling (approximation of latent function)
        eps = torch.randn_like(mean_f)
        f_sample = mean_f + std_f * eps

        # Sigmoid for binary classification
        prob = torch.sigmoid(f_sample)

        # Log likelihood term (binary cross-entropy for classification)
        log_lik = Y * F.logsigmoid(f_sample) + (1 - Y) * F.logsigmoid(-f_sample) # Log-likelihood of Bernoulli
        log_lik_term = (full_n / B) * log_lik.sum()

        # KL[q(u) || p(u)] (variational KL divergence)
        KL = 0.5 * (
            torch.trace(Kzz_inv @ S) +
            self.m @ Kzz_inv @ self.m -
            M +
            torch.logdet(Kzz) - torch.logdet(S + 1e-6 * torch.eye(M, device=device))
        )

        elbo = log_lik_term - KL
        return -elbo  # minimize negative ELBO

    def predict(self, X_test, T_test):
        """
        X_test: [B_test, D], test input features
        T_test: [B_test], task indices for the test set
        """
        # Get the device from model parameters
        device = self.dummy.device
        
        self.eval()
        B_test = X_test.shape[0]  # batch size of test points
        M = self.num_inducing

        # Compute kernel between test points and inducing points
        Kxz_test = self.compute_coregionalized_kernel(X_test, T_test, self.Z, torch.zeros(M, dtype=torch.long, device=device))  # [B_test, M]

        # Compute kernel between inducing points
        Kzz = self.compute_coregionalized_kernel(self.Z, torch.zeros(M, dtype=torch.long, device=device),
                                                 self.Z, torch.zeros(M, dtype=torch.long, device=device)) + 1e-6 * torch.eye(M, device=device)

        # Cholesky decomposition of Kzz
        Lzz = torch.linalg.cholesky(Kzz)
        Kzz_inv = torch.cholesky_inverse(Lzz)

        # Compute the posterior mean: Kxz_test * Kzz_inv * m
        mean_f_test = Kxz_test @ Kzz_inv @ self.m  # [B_test]

        # Compute the posterior covariance: Kxz_test * Kzz_inv * S * Kzz_inv * Kxz_test.T
        S = self.L @ self.L.T  # Ensure positive-definite
        cov_f_test = (Kxz_test @ Kzz_inv @ S @ Kzz_inv @ Kxz_test.T).diag()  # [B_test]

        # Compute the posterior standard deviation
        std_f_test = torch.sqrt(cov_f_test + 1e-6)

        # Apply sigmoid to get probabilities
        prob_test = torch.sigmoid(mean_f_test)

        return prob_test, std_f_test

# Dataset and Loader

In [6]:
class MILMNISTDataset(torch.utils.data.Dataset):
    def __init__(self, bags, labels):
        """
        A simple dataset class for MIL MNIST data
        
        Parameters:
        bags: [num_bags, instances_per_bag, H, W] - Bags of MNIST images
        labels: [num_bags, num_tasks] - Labels for each bag and task
        """
        self.bags = bags
        self.labels = labels
        self.num_bags = bags.shape[0]
        self.num_tasks = labels.shape[1]
        
    def __len__(self):
        # Return total number of bag-task pairs
        return self.num_bags * self.num_tasks
    
    def __getitem__(self, idx):
        # Convert idx to bag_idx and task_idx
        bag_idx = idx // self.num_tasks
        task_idx = idx % self.num_tasks
        
        # Get the bag
        bag = self.bags[bag_idx]  # [instances_per_bag, H, W]
        
        # Get the label for this task
        label = self.labels[bag_idx, task_idx]
        
        return bag, torch.tensor(task_idx, dtype=torch.long), label

In [7]:
from torchvision.models import resnet18

class AttentionLayer(nn.Module):
    def __init__(self, feature_dim, hidden_dim=128):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        
    def forward(self, x):
        # x: [batch_size, num_instances, feature_dim]
        attention_weights = self.attention(x)  # [batch_size, num_instances, 1]
        attention_weights = F.softmax(attention_weights, dim=1)  # Normalize weights across instances
        
        # Weighted sum of instance features
        weighted_features = torch.sum(x * attention_weights, dim=1)  # [batch_size, feature_dim]
        
        return weighted_features, attention_weights

class MIL_SVGP(nn.Module):
    def __init__(self, num_classes, feature_dim=512, num_inducing=50, pretrained=True):
        super().__init__()
        
        # Feature extractor (ResNet18)
        self.feature_extractor = resnet18(pretrained=pretrained)
        # Modify the first conv layer to accept grayscale images (1 channel)
        self.feature_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Remove the final classification layer
        self.feature_dim = feature_dim
        self.feature_extractor = nn.Sequential(*list(self.feature_extractor.children())[:-1])
        
        # Attention mechanism for instance aggregation
        self.attention = AttentionLayer(feature_dim)
        
        # Final classifier
        self.classifier = nn.Linear(feature_dim, num_classes)

        # Linear layer 
        self.linear_layer = nn.Linear(feature_dim, 8)
        
        # SVGP model for each task/class
        self.svgp = MultiOutputSVGP(
            input_dim=8,
            num_tasks=num_classes,
            num_inducing=num_inducing
        )
        
        # Register a buffer to track device
        self.register_buffer('dummy', torch.zeros(1))
        
    def extract_instance_features(self, x):
        """
        Extract features from individual instances
        
        Parameters:
        x: [batch_size, num_instances, channels, height, width]
        
        Returns:
        features: [batch_size, num_instances, feature_dim]
        """
        batch_size, num_instances = x.shape[0], x.shape[1]
        
        # Reshape to process all instances at once
        x_reshaped = x.view(batch_size * num_instances, *x.shape[2:])
        
        # Extract features
        features = self.feature_extractor(x_reshaped)
        features = features.view(features.size(0), -1)  # Flatten spatial dimensions
        
        # Reshape back to [batch_size, num_instances, feature_dim]
        features = features.view(batch_size, num_instances, -1)
        
        return features
    
    def forward(self, x, task_indices, labels=None, full_n=None):
        """
        Forward pass through the network
        
        Parameters:
        x: [batch_size, num_instances, channels, height, width] - Bag of instances
        task_indices: [batch_size] - Task indices for each bag
        labels: [batch_size] - Labels for each bag (optional, for training)
        full_n: Total number of training samples (optional, for ELBO scaling)
        
        Returns:
        logits: [batch_size, num_classes] - Classification logits
        loss: Scalar - SVGP loss (if labels and full_n are provided)
        """
        device = self.dummy.device
        batch_size = x.shape[0]
        
        # Ensure x is properly formatted
        if len(x.shape) == 4:  # [batch_size, num_instances, height, width]
            # Add channel dimension for grayscale
            x = x.unsqueeze(2)  # [batch_size, num_instances, 1, height, width]
        
        # Normalize pixel values
        x = x.float() / 255.0
        
        # Extract features from each instance in each bag
        instance_features = self.extract_instance_features(x)  # [batch_size, num_instances, feature_dim]
        
        # Aggregate instance features using attention
        bag_features, attention_weights = self.attention(instance_features)  # [batch_size, feature_dim]
        
        # Get classification logits
        logits = self.classifier(bag_features)  # [batch_size, num_classes]
        
        # Calculate SVGP loss if in training mode
        loss = None
        if labels is not None and full_n is not None:
            bag_features = self.linear_layer(bag_features)  # [batch_size, 8]
            loss = self.svgp(bag_features, task_indices, labels, full_n)
        
        return logits, loss, bag_features, attention_weights
    
    def predict(self, x, task_indices):
        """
        Make predictions for test data
        
        Parameters:
        x: [batch_size, num_instances, channels, height, width] - Bag of instances
        task_indices: [batch_size] - Task indices for each bag
        
        Returns:
        probs: [batch_size] - Predicted probabilities
        stds: [batch_size] - Uncertainty estimates
        """
        self.eval()
        
        # Ensure x is properly formatted
        if len(x.shape) == 4:  # [batch_size, num_instances, height, width]
            # Add channel dimension for grayscale
            x = x.unsqueeze(2)  # [batch_size, num_instances, 1, height, width]
        
        # Normalize pixel values
        x = x.float() / 255.0
        
        # Extract and aggregate features
        with torch.no_grad():
            instance_features = self.extract_instance_features(x)
            bag_features, _ = self.attention(instance_features)
            
            # Get SVGP predictions
            probs, stds = self.svgp.predict(bag_features, task_indices)
        
        return probs, stds

In [8]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Split data
indices = np.arange(len(mil_bags))
train_indices, test_indices = train_test_split(
    indices, 
    test_size=0.2,
    random_state=42,
    shuffle=True
)

train_bags = mil_bags[train_indices]
train_labels = mil_labels[train_indices]
test_bags = mil_bags[test_indices]
test_labels = mil_labels[test_indices]

print(f"Train bags shape: {train_bags.shape}, Train labels shape: {train_labels.shape}")
print(f"Test bags shape: {test_bags.shape}, Test labels shape: {test_labels.shape}")

# Create a simple dataset class that returns bags, task indices, and labels
class MILDataset(torch.utils.data.Dataset):
    def __init__(self, bags, labels):
        self.bags = bags
        self.labels = labels
        self.num_bags = bags.shape[0]
        self.num_tasks = labels.shape[1]
        
    def __len__(self):
        return self.num_bags * self.num_tasks
    
    def __getitem__(self, idx):
        bag_idx = idx // self.num_tasks
        task_idx = idx % self.num_tasks
        
        bag = self.bags[bag_idx]
        label_index = self.labels[bag_idx, task_idx]
        labels = self.labels[bag_idx]
        
        return bag, torch.tensor(task_idx, dtype=torch.long), label_index, labels

# Create datasets and dataloaders
train_dataset = MILDataset(train_bags, train_labels)
test_dataset = MILDataset(test_bags, test_labels)

train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True,
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, 
    batch_size=32, 
    shuffle=False,
    pin_memory=True
)

# Initialize the integrated MIL-SVGP model
model = MIL_SVGP(
    num_classes=2,  # Binary classification for each task
    feature_dim=512,  # ResNet18 feature dimension
    num_inducing=50,
    pretrained=True
).to(device)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
bce_loss = nn.BCEWithLogitsLoss()

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    
    for bags, task_indices, labels_index, labels in train_loader:
        # Move data to device
        bags = bags.to(device)
        task_indices = task_indices.to(device)
        labels_index = labels_index.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        logit, loss, _, _ = model(bags, task_indices, labels_index, len(train_dataset))
        # print(f'Shape of labels: {labels.shape}')
        loss = bce_loss(logit, labels) * 0.5 + loss * 0.5
        # Backward pass
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

Using device: cuda
Train bags shape: torch.Size([800, 10, 28, 28]), Train labels shape: torch.Size([800, 2])
Test bags shape: torch.Size([200, 10, 28, 28]), Test labels shape: torch.Size([200, 2])




Epoch 1/20, Loss: 555.3259
Epoch 2/20, Loss: 554.7806
Epoch 3/20, Loss: 554.6634
Epoch 4/20, Loss: 554.6228
Epoch 5/20, Loss: 554.5673
Epoch 6/20, Loss: 554.5479
Epoch 7/20, Loss: 554.5397
Epoch 8/20, Loss: 554.5405
Epoch 9/20, Loss: 554.5300
Epoch 10/20, Loss: 554.5286
Epoch 11/20, Loss: 554.5289
Epoch 12/20, Loss: 554.5209
Epoch 13/20, Loss: 554.5324
Epoch 14/20, Loss: 554.5168
Epoch 15/20, Loss: 554.5343
Epoch 16/20, Loss: 554.5167
Epoch 17/20, Loss: 554.5370
Epoch 18/20, Loss: 554.5519
Epoch 19/20, Loss: 554.5300
Epoch 20/20, Loss: 554.5347


# Evaluation 

In [9]:
# Đánh giá mô hình trên tập kiểm tra sau khi huấn luyện xong
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch_features, batch_tasks, batch_labels_index, batch_labels in test_loader:
        batch_features = batch_features.to(device)
        batch_tasks = batch_tasks.to(device)
        
        # Dự đoán xác suất cho mỗi mẫu trong batch
        # probs, _ = model.predict(batch_features, batch_tasks)
        logits, _, _, _ = model(batch_features, batch_tasks)
        probs = torch.sigmoid(logits)
        probs = probs.cpu()
        
        # Chuyển xác suất thành nhãn dự đoán (ngưỡng 0.5)
        preds = (probs >= 0.5).float()
        
        all_preds.append(preds.cpu())
        all_labels.append(batch_labels.cpu())

# Gộp các batch lại
all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

accuracy = accuracy_score(all_labels, all_preds)
# precision = precision_score(all_labels, all_preds)
# recall = recall_score(all_labels, all_preds)
# f1 = f1_score(all_labels, all_preds)

print(f"Accuracy: {accuracy:.4f}")
# print(f"Precision: {precision:.4f}")
# print(f"Recall: {recall:.4f}")
# print(f"F1 Score: {f1:.4f}")

Accuracy: 0.7650
