In [6]:
from loaders import tiny_imagenet_loader, tiny_imagenet_corrupted_loader
import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR
from tqdm.notebook import tqdm

In [11]:
# Base model parameters
loader = tiny_imagenet_loader(train=True, batch_size=128, num_workers=4)
model = models.resnet18()
optimizer = SGD(model.parameters(), lr=0.01)
loss_fn = CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
epochs = 300

# SWA parameters
swa_model = AveragedModel(model)

model.to(device)
swa_model.to(device)

# TODO: Review documentation (Comes from SWA paper)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

Using device cuda


In [None]:
# Check the model parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)

11689512


In [None]:
# Train the model on ImageNet
for epoch in tqdm(range(epochs), desc="Training", leave=False):
    print(f"Epoch {epoch + 1}/{epochs}")
    
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step()
    
    # Schedule learning rate with SWA
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

2. On the test set of corrupted images, compute the adapted parameters
$$\hat{\theta_i} = \arg\max_{\theta} \frac{\tilde{\alpha}}{m}\sum_{j=1}^m -H(Y\mid \tilde{x}_j,\theta) + \log q_i(\theta)$$

For $q$ Gaussian with mean $\mu$ and covariance $C$,
$$q(\theta) = (2\pi)^{-d/2}\text{det}(C)^{-1/2}\exp(-\frac{1}{2}(\theta-\mu)C^{-1}(\theta-\mu))$$
$$\log q(\theta) = \log ((2\pi)^{-d/2}\text{det}(C)^{-1/2}) -\frac{1}{2}(\theta-\mu)C^{-1}(\theta-\mu))$$

Then the explicit form to minimize is
$$\hat{\theta_i} = \arg\min_{\theta} \frac{\tilde{\alpha}}{m}\sum_{j=1}^m \sum_{y \in Y} p(y|x,\theta)\log p(y|x,\theta) + \frac{1}{2}(\theta-\mu)C^{-1}(\theta-\mu)$$

In [None]:
def test_adapted_loss_fn(model, test_samples, mean, covariance, num_classes=200, alpha=1):
    """ Loss function for test adaptation for Bayesian Adaptation as described above """

    total_test_entropy = 0
    for j in range(len(test_samples)):
        for y in range(num_classes):
            pred = model(test_samples[j])
            total_test_entropy += pred * math.log(pred)

    test_entropy_term = alpha / len(test_samples) * total_test_entropy

    C_inv = torch.linalg.inv(covariance)

    flat_params = torch.cat([p.view(-1) for p in model.parameters()])

    train_entropy_term = 0.5 * (flat_params - mean) @ C_inv @ (flat_params - mean)

    return test_entropy_term + train_entropy_term
    

def test_adaptation(model, test_loader):
    """ Take the baseline model and test loader, output a new model that is adapted to the test set """

    

3. For each test input $\tilde{x}_j$, marginalize over ensemble
$$P(y\mid\tilde{x_j}) = \frac{1}{k}\sum_{i=1}^k P(y\mid\tilde{x_j}, \hat{\theta_i})$$

In [1]:
def ensemble_prediction(x, models, num_classes):
    """ Output the prediction of the ensemble of models """
    average_prediction = tensor.zeros(num_classes)
    for model in models:
        average_prediction += 1/len(models) * model(x)
    return average_prediction    

_IncompleteInputError: incomplete input (1435617026.py, line 3)