# MEMO for Test-Time Adaptation
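This notebook applies MEMO (marginal entropy minimization with one test point) to a pretrained ViT-B/16 on ImageNet-A. For each test image, the model is adapted by minimizing the entropy of its averaged prediction over several augmented copies of that image; the adapted model then classifies the image.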

Import libraries.

In [7]:
import torch
import torchvision

from torch.utils.data import Dataset
import boto3
from pathlib import Path

from io import BytesIO
from PIL import Image

import torch.nn.functional as F
import torchvision.transforms as T

import torch.nn as nn

## Data
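Images are streamed directly from the `deeplearning2024-datasets` S3 bucket (region `eu-west-1`). TODO: document the bucket layout.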

In [2]:
class S3ImageFolder(Dataset):
    """ImageFolder-style dataset that streams images from an S3 bucket.

    Every object under ``root`` whose extension is a known image type becomes
    one sample. Its label is the name of the object's parent "directory", and
    classes are indexed in alphabetical order (as in ``ImageFolder``).
    """

    # Lower-case file extensions accepted as images.
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

    def __init__(self, root, transform=None, bucket="deeplearning2024-datasets", region="eu-west-1"):
        """List every image under ``root``.

        Args:
            root: S3 key prefix to list (e.g. ``"imagenet-a/"``).
            transform: optional callable applied to each RGB PIL image.
            bucket: S3 bucket to read from.
            region: AWS region of the bucket.
        """
        self.s3_bucket = bucket
        self.s3_region = region
        # NOTE(review): this client is created once and inherited by DataLoader
        # worker processes; boto3 advises one client per process — confirm.
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)
        self.transform = transform

        # Keep only image files, labelled by their parent "directory".
        self.instances = []
        for key in self._iter_keys(root):
            path = Path(key)
            if path.suffix.lower() in self.IMG_EXTENSIONS:
                self.instances.append((path.parent.name, key))

        # Sort classes in alphabetical order (as in ImageFolder)
        self.classes = sorted({label for label, _ in self.instances})
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

    def _iter_keys(self, prefix):
        """Yield every object key under ``prefix``, following S3 pagination."""
        paginator = self.s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix):
            for item in page.get("Contents", []):
                yield item["Key"]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        """Download, decode and transform image ``idx``; return ``(image, class_index)``.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if downloading, decoding or transforming the image fails.
        """
        # Outside the try: an out-of-range index must surface as IndexError,
        # which the sequence/iteration protocol relies on.
        label, key = self.instances[idx]

        try:
            buffer = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=buffer)
            buffer.seek(0)  # rewind so PIL reads the download from its first byte

            img = Image.open(buffer).convert("RGB")

            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}") from e

        return img, self.class_to_idx[label]

In [12]:
def get_data(batch_size, img_root, train_fraction=0.8, num_workers=4, seed=None):
    """Build train/test DataLoaders over the images stored under ``img_root`` in S3.

    Args:
        batch_size: batch size for both loaders.
        img_root: S3 key prefix handed to ``S3ImageFolder``.
        train_fraction: share of samples (rounded down) put in the training split.
        num_workers: worker processes per DataLoader.
        seed: optional seed for the random split, for reproducible splits.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.

    Raises:
        ValueError: if ``train_fraction`` is outside [0, 1] or no images are found.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")

    # Prepare data transformations
    transform = T.Compose([
        T.Resize((256, 256)),                                                   # Resize each PIL image to 256 x 256
        T.RandomCrop((224, 224)),                                               # Randomly crop a 224 x 224 patch
        T.ToTensor(),                                                           # Convert the PIL image to a float tensor in [0, 1]
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),    # Normalize with ImageNet mean/std
        # TODO: consider ViT_B_16_Weights.DEFAULT.transforms() instead.
    ])
    # NOTE(review): both splits share this transform, so test images are also
    # randomly cropped; a deterministic CenterCrop may be preferable for evaluation.

    dataset = S3ImageFolder(root=img_root, transform=transform)

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError(f"No images found under {img_root!r}")

    # Floor the training share so both sizes are non-negative and sum to num_samples.
    training_samples = int(num_samples * train_fraction)
    test_samples = num_samples - training_samples

    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    training_data, test_data = torch.utils.data.random_split(
        dataset, [training_samples, test_samples], **split_kwargs
    )

    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

In [20]:
# from https://github.com/hendrycks/natural-adv-examples/blob/master/eval.py

# adversarial samples: class-index bookkeeping and evaluation transform
import warnings

import torchvision.transforms as trn

# Presumably maps each ImageNet-1k class index to its ImageNet-A index, with -1
# for classes that are not kept. TODO: copy the full mapping from eval.py above —
# it is empty here.
thousand_k_to_200 = {}
# Indices of the kept classes among the model's 1000 outputs, in mapping order.
indices_in_1k = [k for k, v in thousand_k_to_200.items() if v != -1]
if not indices_in_1k:
    # Slicing logits with an empty index list leaves zero classes, so anything
    # computed on the sliced logits is degenerate — make that visible early.
    warnings.warn("thousand_k_to_200 is empty: indices_in_1k selects no classes")

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Deterministic evaluation transform: resize to 256, center-crop 224, normalize.
test_transform = trn.Compose(
    [trn.Resize(256), trn.CenterCrop(224), trn.ToTensor(), trn.Normalize(mean, std)])

NameError: name 'trn' is not defined

## Model
ViT-B/16 (torchvision `vit_b_16`).

In [4]:
from torchvision.models import (
    ViT_B_16_Weights,
    vit_b_16,
)

# Pretrained ViT-B/16: take torchvision's DEFAULT weight entry for this
# architecture and instantiate the network with it.
weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights=weights)

In [None]:
def build_model():
    """Return a fresh ViT-B/16 loaded with ImageNet-1k pretrained weights.

    Uses torchvision's weights API. The deprecated ``pretrained=True`` resolved
    to ``ViT_B_16_Weights.IMAGENET1K_V1``, which is requested here explicitly.
    """
    from torchvision.models import ViT_B_16_Weights, vit_b_16

    return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

## Cost function

In [14]:
# TODO make as function
import math


class MarginalEntropy(nn.Module):
    """Entropy of the batch-averaged (marginal) predictive distribution.

    This is MEMO's objective: average the softmax distributions of several
    views of one test point, then take the entropy of that average.
    """

    def forward(self, outputs):
        """Compute the marginal entropy of ``outputs``.

        Args:
            outputs: logits reduced over dim 0 (views) and dim -1 (classes);
                presumably shape (num_views, num_classes).

        Returns:
            ``(entropy, avg_log_probs)``: the entropy of the averaged distribution
            and the log of that averaged distribution.
        """
        # Per-view log-softmax.
        log_probs = outputs - outputs.logsumexp(dim=-1, keepdim=True)
        # log(mean_i p_i), computed stably as logsumexp_i(log p_i) - log(N).
        avg_log_probs = log_probs.logsumexp(dim=0) - math.log(log_probs.shape[0])
        # Replace -inf with the most negative finite value so that a zero-probability
        # class contributes 0 * (finite) = 0 instead of 0 * -inf = NaN.
        min_real = torch.finfo(avg_log_probs.dtype).min
        avg_log_probs = torch.clamp(avg_log_probs, min=min_real)
        entropy = -(avg_log_probs * torch.exp(avg_log_probs)).sum(dim=-1)
        return entropy, avg_log_probs


loss_fn = MarginalEntropy()

## Augmentations for MEMO

In [None]:
## https://github.com/google-research/augmix
# AugMix-style augmentation as used by MEMO: a pre-augmented view is blended
# with a Dirichlet-weighted mix of several randomly augmented chains.
import numpy as np

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Geometric pre-augmentation applied to every view before the op chains.
preaugment = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
])
# Conversion to a normalized tensor, applied after the image-level ops.
preprocess = T.Compose([
    T.ToTensor(),
    T.Normalize(mean, std),
])


def _augmix_aug(x_orig, width=3, max_depth=3, alpha=1.0):
    """Return one AugMix-augmented, normalized tensor view of ``x_orig``.

    Args:
        x_orig: input image; must support ``.copy()`` and the ops in the
            module-level ``augmentations`` list (presumably a PIL image — TODO confirm).
        width: number of augmentation chains mixed together.
        max_depth: each chain applies between 1 and ``max_depth`` randomly chosen ops.
        alpha: concentration of the Dirichlet (chain weights) and Beta
            (original-vs-mix weight) distributions.

    Raises:
        ValueError: if the module-level ``augmentations`` list is empty.
    """
    if not augmentations:
        raise ValueError("`augmentations` is empty: add image ops before calling _augmix_aug")

    x_orig = preaugment(x_orig)
    x_processed = preprocess(x_orig)
    # Convex weights over the chains, and the weight kept for the un-mixed view.
    w = np.float32(np.random.dirichlet([alpha] * width))
    m = np.float32(np.random.beta(alpha, alpha))

    mix = torch.zeros_like(x_processed)
    for i in range(width):
        x_aug = x_orig.copy()
        for _ in range(np.random.randint(1, max_depth + 1)):
            op = augmentations[np.random.randint(len(augmentations))]  # pick an op uniformly
            x_aug = op(x_aug)
        mix += w[i] * preprocess(x_aug)  # weight this chain by its Dirichlet share
    return m * x_processed + (1 - m) * mix

In [None]:
# Pool of image ops sampled by `_augmix_aug` (presumably also intended as the
# `augmentations` argument of `memo_tta`). TODO: populate it — while it is
# empty, anything that samples from it has nothing to draw.
augmentations = list()

## Training and Test steps

In [18]:
def test(model, data_loader, cost_function, device="cuda"):
    model.eval()

    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)

            loss = cost_function(outputs, targets)

            samples += inputs.shape[0]
            cumulative_loss += loss.item()
            _, predicted = outputs.max(1)

            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss / samples, cumulative_accuracy / samples * 100

In [None]:
def memo_tta(pretrained_model, sample, epochs_for_adaptation, augmentations, lr=None):
    """Adapt ``pretrained_model`` in place to a single test point with MEMO.

    Each callable in ``augmentations`` is applied to ``sample`` and the first
    element of its result is kept as one view. The stacked views go through the
    model, the logits are restricted to ``indices_in_1k``, and the marginal
    entropy (``loss_fn``) is minimized with SGD.

    Args:
        pretrained_model: model to adapt; its parameters are updated in place.
        sample: test input handed to every augmentation.
        epochs_for_adaptation: number of gradient steps.
        augmentations: non-empty sequence of callables producing the views.
        lr: SGD learning rate; defaults to the module-level ``learning_rate``.

    Returns:
        ``pretrained_model`` itself, after adaptation.

    Raises:
        ValueError: if ``augmentations`` is empty.
    """
    if not augmentations:
        raise ValueError("memo_tta needs at least one augmentation")
    if lr is None:
        # NOTE(review): `learning_rate` is not defined in the visible notebook.
        lr = learning_rate
    optimizer = torch.optim.SGD(pretrained_model.parameters(), lr=lr)

    # TODO how to parallelize this
    device = next(pretrained_model.parameters()).device
    augmented_samples = torch.stack(
        [augmentation(sample)[0] for augmentation in augmentations]
    ).to(device)

    pretrained_model.eval()
    # Drop gradients left over from earlier use of this model before the first step.
    optimizer.zero_grad()
    for _ in range(epochs_for_adaptation):
        output_distributions = pretrained_model(augmented_samples)
        output_distributions = output_distributions[:, indices_in_1k]

        loss, _ = loss_fn(output_distributions)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # TODO Adapting BN statistics, maybe in adapt_single()

    return pretrained_model

---

## Experiments

In [None]:
IMG_ROOT = "imagenet-a/"
BATCH_SIZE = 32  # TODO: to change?

train_loader, test_loader = get_data(BATCH_SIZE, IMG_ROOT)

cost_function = torch.nn.CrossEntropyLoss()

# One device for both the model and the batches: a CUDA model fed CPU inputs
# fails at the first forward pass.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = build_model()
model.to(device)
model.eval()

# NOTE(review): the model emits 1000 ImageNet logits while targets are
# S3ImageFolder's alphabetical class indices — confirm they are comparable.
loss, accuracy = test(model, test_loader, cost_function, device=device)
print(loss, accuracy)


In [None]:
# Implementation of test-time adaptation via MEMO: each test point is adapted
# starting from the same pretrained weights, then classified by the adapted model.
import copy

# "only one gradient step per test point" as per paper
epochs_for_adaptation = 1

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the pretrained model once; every test point adapts its own deep copy, so
# adaptation on one image never leaks into the next.
base_model = build_model().to(device)

correct_predictions = 0
num_predictions = 0

for inputs, targets in train_loader:
    for image, label in zip(inputs, targets):
        input_image = image.unsqueeze(0).to(device)  # mini-batch of one, as the model expects
        # TODO: are the labels in the same order as the output classes from the model?
        input_label = label.item()

        pretrained_model = copy.deepcopy(base_model)
        adapted_model = memo_tta(pretrained_model, input_image, epochs_for_adaptation, augmentations)

        # inference phase
        # TODO: change prior_strength
        # NOTE(review): test_single is not defined in the visible notebook; confirm
        # it exists and returns a class index comparable to input_label.
        prediction = test_single(adapted_model, input_image, -1)
        correct_predictions += int(prediction == input_label)
        num_predictions += 1

print(f"MEMO accuracy: {correct_predictions / max(num_predictions, 1) * 100:.2f}%")