In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
# SimpleCNN adapted for CIFAR-10 (3-channel, 32x32 inputs).
class SimpleCNN(nn.Module):
    """Small two-conv-layer CNN for 32x32 RGB images, with dropout for MC-dropout UQ.

    Architecture: conv(3->32) -> maxpool -> relu -> conv(32->64) -> maxpool -> relu
    -> flatten -> fc(4096->128) -> relu -> dropout -> fc(128->num_classes).

    Args:
        num_classes: number of output classes (default 10, as in CIFAR-10).
        dropout_p: dropout probability applied before the classifier head
            (default 0.25). This dropout layer is what MC-dropout inference
            keeps active to estimate uncertainty.

    forward(x) returns log-probabilities of shape (batch, num_classes), suitable
    for ``F.nll_loss``. The fixed fc1 input size assumes 32x32 spatial input.
    """

    def __init__(self, num_classes=10, dropout_p=0.25):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)  # 3 input channels (RGB)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(dropout_p)  # kept active at inference for MC-dropout UQ
        # padding=1 keeps spatial size; two 2x2 max-pools take 32 -> 16 -> 8,
        # so the flattened feature size is 64 channels * 8 * 8 = 4096.
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = torch.flatten(x, 1)  # (batch, 64, 8, 8) -> (batch, 4096)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities for NLL loss

**Định nghĩa model:**

- Trong class `SimpleCNN`, layer `self.dropout = nn.Dropout(0.25)` là nền tảng cho UQ (dropout được giữ active trong inference).
  - **Khúc:** trong `__init__` (dòng `self.dropout = nn.Dropout(0.25)`) và trong `forward` (dòng `x = self.dropout(x)`).


In [None]:
# Training function: one epoch of NLL-loss optimization.
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and return its mean loss.

    Args:
        model: module whose forward returns log-probabilities (e.g. SimpleCNN).
        device: device each (data, target) batch is moved to.
        train_loader: iterable of (data, target) batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index; accepted for interface compatibility, not used.

    Returns:
        Mean NLL loss over all samples seen this epoch, computed from each
        batch's loss before its optimizer step. Returns NaN if the loader
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # nll_loss averages over the batch; weight by batch size so the epoch
        # mean is exact even when the last batch is smaller. Accumulating the
        # detached tensor avoids a device sync on every batch.
        batch_size = target.size(0)
        total_loss = total_loss + loss.detach() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        return float("nan")
    return float(total_loss) / num_samples

In [None]:
# Inference with MC Dropout for uncertainty quantification (UQ).
def predict_with_uq(model, device, data, num_samples=20):
    """Monte Carlo dropout prediction with a per-input uncertainty score.

    Runs ``num_samples`` stochastic forward passes with dropout active. It
    averages the class probabilities and uses the variance across passes,
    averaged over classes, as the uncertainty.

    Only dropout modules are put in training mode; all other modules run in
    eval mode. Every module's original train/eval flag is restored afterwards,
    so calling this does not leak state into later training or evaluation.
    NOTE: dropout applied functionally (``F.dropout(..., training=self.training)``)
    inside a non-dropout module stays disabled here; only dropout *modules* are
    switched on.

    Args:
        model: module whose forward returns log-probabilities of shape
            (batch_size, num_classes).
        device: device to run on; ``data`` is moved there.
        data: input batch.
        num_samples: number of stochastic passes. Must be >= 2, because the
            unbiased variance of a single pass is undefined (NaN).

    Returns:
        (predicted_class, uncertainty). ``predicted_class`` is the argmax of the
        mean probabilities, shape (batch_size,). ``uncertainty`` has shape
        (batch_size,).

    Raises:
        ValueError: if ``num_samples`` < 2.
    """
    if num_samples < 2:
        raise ValueError(f"num_samples must be >= 2 to estimate a variance, got {num_samples}")
    data = data.to(device)
    saved_modes = {module: module.training for module in model.modules()}
    model.eval()
    for module in model.modules():
        # _DropoutNd is the common base of nn.Dropout, Dropout1d/2d/3d, AlphaDropout, ...
        if isinstance(module, nn.modules.dropout._DropoutNd):
            module.train()
    try:
        with torch.no_grad():
            preds = torch.stack([model(data) for _ in range(num_samples)], dim=0)  # (num_samples, batch_size, num_classes)
            probs = preds.exp()  # log-probs -> probs
            mean_probs = probs.mean(dim=0)  # average prediction over passes
            uncertainty = probs.var(dim=0).mean(dim=1)  # variance across passes, averaged over classes
    finally:
        for module, was_training in saved_modes.items():
            module.training = was_training
    return mean_probs.argmax(dim=1), uncertainty

- **Hàm chính cho UQ:** Hàm `predict_with_uq` – toàn bộ hàm này tính UQ.
  - Khúc cụ thể:
    - Bật các layer dropout trong lúc inference (MC Dropout), để mỗi lần forward cho ra một dự đoán khác nhau.
    - `preds = torch.stack([model(data) for _ in range(num_samples)], dim=0)`: Chạy nhiều pass.
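    - `uncertainty = probs.var(dim=0).mean(dim=1)`: tính variance của xác suất qua `num_samples` lần forward (cho từng class), rồi lấy trung bình qua các class → mỗi ảnh có một giá trị uncertainty.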

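- **Sử dụng:** Trong phần evaluation (full test loop), uncertainty của từng ảnh được gom lại và giá trị trung bình được in ra; trong phần visualize, uncertainty của mẫu được hiển thị trong tiêu đề hình (e.g., `preds, unc = predict_with_uq(...)`).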

In [None]:
# CIFAR-10 preprocessing: convert to tensor, then standardize each RGB channel.
# Mean/std are the per-channel pixel statistics of the CIFAR-10 training set.
# The frequently copied std (0.2023, 0.1994, 0.2010) is not the per-channel
# pixel std of the dataset; the correct values are used here.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

In [None]:
# Load the CIFAR-10 train/test splits and build the data loaders.
SEED = 42
BATCH_SIZE = 64
# Seed torch's global RNG (CPU and CUDA) so that shuffling, the weight init in the
# next cell and MC-dropout masks repeat on Restart & Run All. Some GPU kernels
# may still be nondeterministic.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
pin_memory = device.type == "cuda"  # page-locked host memory speeds host->GPU copies
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_memory)

In [None]:
# Build the network on the selected device, then an Adam optimizer (lr = 1e-3) over its parameters.
model = SimpleCNN()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Train for NUM_EPOCHS epochs (increased from 3 for CIFAR-10). range(1, NUM_EPOCHS + 1)
# yields epochs 1..NUM_EPOCHS, i.e. exactly 10 epochs.
NUM_EPOCHS = 10
for epoch in range(1, NUM_EPOCHS + 1):
    train(model, device, train_loader, optimizer, epoch)
    print(f"Epoch {epoch} completed")  # progress tracking

In [None]:
# Full test-set evaluation: count correct MC-dropout predictions and collect one
# uncertainty score per test image.
correct, total = 0, 0
uncertainties = []
for data, targets in test_loader:
    data = data.to(device)
    targets = targets.to(device)
    preds, unc = predict_with_uq(model, device, data)
    correct += int(preds.eq(targets).sum())
    total += len(targets)
    uncertainties += unc.cpu().tolist()

In [None]:
# Aggregate the evaluation loop's counters into summary metrics.
avg_uncertainty = sum(uncertainties, 0.0) / len(uncertainties)  # mean per-image MC-dropout score
accuracy = correct / total  # fraction of test images predicted correctly
print(f"Test Accuracy: {accuracy:.4f}, Average Uncertainty: {avg_uncertainty:.6f}")

In [None]:
# Example: visualize the first misclassified test image (MC-dropout prediction)
# together with its uncertainty. Falls back to the first test image if every
# prediction is correct.
sample = None
for batch_images, batch_labels in test_loader:
    batch_preds, batch_unc = predict_with_uq(model, device, batch_images.to(device))
    batch_preds, batch_unc = batch_preds.cpu(), batch_unc.cpu()
    wrong = (batch_preds != batch_labels).nonzero(as_tuple=True)[0]
    found = wrong.numel() > 0
    if found or sample is None:
        i = wrong[0].item() if found else 0
        sample = (batch_images[i], batch_labels[i].item(), batch_preds[i].item(), batch_unc[i].item())
    if found:
        break
test_image, test_label, pred, unc_val = sample

# Undo the Normalize step so imshow receives RGB values in [0, 1]. Plotting the
# standardized tensor directly clips most pixels and distorts the colors.
normalize = next(t for t in transform.transforms if isinstance(t, transforms.Normalize))
mean = torch.as_tensor(normalize.mean).view(-1, 1, 1)
std = torch.as_tensor(normalize.std).view(-1, 1, 1)
image_rgb = (test_image * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()  # CHW -> HWC

class_names = test_dataset.classes
fig, ax = plt.subplots()
ax.imshow(image_rgb)
ax.set_title(
    f"Ground Truth: {class_names[test_label]}, Predicted: {class_names[pred]}, "
    f"Uncertainty: {unc_val:.6f}"
)
ax.axis("off")
plt.show()