# MEMO Assignment
Andrea De Carlo and Pooya Torabi

In [57]:
import torchvision
import torch

In [58]:
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision.transforms.v2 import functional as F


def plot(imgs, row_title=None, **imshow_kwargs):
    """Show a grid of images, optionally overlaid with bounding boxes and masks.

    Args:
        imgs: A list of images (one row) or a list of such lists (several rows).
            Each entry is anything ``F.to_image`` accepts, or an
            ``(image, target)`` tuple where ``target`` is either a
            ``BoundingBoxes`` or a dict with optional ``"boxes"`` / ``"masks"``.
        row_title: Optional sequence of y-axis labels, one per row.
        **imshow_kwargs: Forwarded to ``Axes.imshow``.

    Raises:
        ValueError: If a tuple entry's target has an unsupported type.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, torchvision.tv_tensors.BoundingBoxes):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = F.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Poor man's re-normalization so that images coming out of
                # Normalize() get OK-ish colors. Done out of place: F.to_image
                # may return a view of a tensor input, and in-place ops would
                # then silently modify the caller's (e.g. the dataset's) tensor.
                img = img - img.min()
                peak = img.max()
                # A constant image is all zeros at this point; dividing by
                # zero would turn it into NaNs.
                if peak > 0:
                    img = img / peak

            img = F.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

            ax = axs[row_idx, col_idx]
            # Axes.imshow expects channels-last (H, W, C).
            ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

In [83]:
import boto3
from PIL import Image
from io import BytesIO
from pathlib import Path
from torch.utils.data import Dataset


class S3ImageFolder(Dataset):
    """Image-classification dataset read directly from an S3 bucket.

    Follows the ``torchvision.datasets.ImageFolder`` layout convention. Every
    object under the ``root`` prefix whose key has a known image extension is
    an instance. The name of its immediate parent "directory" is its class
    label, and classes are indexed in alphabetical order.

    Args:
        root: Key prefix to list inside the bucket.
        transform: Optional callable applied to the loaded RGB PIL image.
        s3_bucket: Bucket to read from.
        s3_region: Region passed to the boto3 S3 client.
    """

    # Accepted image extensions (compared case-insensitively).
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

    def __init__(self, root, transform=None, s3_bucket="deeplearning2024-datasets", s3_region="eu-west-1"):
        # S3 keys always use "/" as separator, regardless of the local OS, so
        # parse them as POSIX paths (a WindowsPath would also split on "\").
        from pathlib import PurePosixPath

        self.s3_bucket = s3_bucket
        self.s3_region = s3_region
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)
        self.transform = transform

        # The paginator follows continuation tokens, so the listing is complete
        # even when S3 splits it across several responses.
        paginator = self.s3_client.get_paginator("list_objects_v2")
        objects = []
        for page in paginator.paginate(Bucket=self.s3_bucket, Prefix=root):
            objects.extend(page.get("Contents", []))

        # Keep image files only, labelled by their parent folder name.
        self.instances = []
        for item in objects:
            key = item["Key"]
            path = PurePosixPath(key)
            if path.suffix.lower() not in self.IMG_EXTENSIONS:
                continue
            self.instances.append((path.parent.name, key))

        # Sort classes in alphabetical order (as in ImageFolder)
        self.classes = sorted({label for label, _ in self.instances})
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        """Download, decode and transform instance ``idx``.

        Returns:
            ``(image, class_index)``.

        Raises:
            IndexError: If ``idx`` is out of range. It is deliberately not
                wrapped, so plain ``for x in dataset`` iteration terminates.
            RuntimeError: If downloading, decoding or transforming fails.
        """
        label, key = self.instances[idx]
        try:
            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            # download_fileobj leaves the cursor at the end of the buffer.
            img_bytes.seek(0)

            # convert() returns a new image; the context manager closes the
            # one that holds the buffer.
            with Image.open(img_bytes) as raw_img:
                img = raw_img.convert("RGB")

            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}") from e

        return img, self.class_to_idx[label]

In [59]:
# ImageNet-style evaluation preprocessing: resize the shorter side to 256,
# center-crop 224x224, convert to a float tensor, then normalize each channel
# with the ImageNet mean/std the pretrained ResNet weights expect.
resnet_preprocess = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [60]:
# Dataset: ImageNet-A (https://github.com/hendrycks/natural-adv-examples),
# read from a local copy for now.
# TODO: read from S3 when running on AWS, e.g.
# imagenet_a = S3ImageFolder(root="OfficeHomeDataset_10072016/Real World", transform=resnet_preprocess)
# NOTE(review): that example prefix points at OfficeHome, not ImageNet-A.
imagenet_a_directory = 'datasets/imagenet-a'

# NOTE(review): ImageFolder labels are 0..K-1 over the sorted class folder
# names, not indices into the model's 1000-way ImageNet output. Map them before
# comparing predictions with labels.
imagenet_a = torchvision.datasets.ImageFolder(
    root=imagenet_a_directory,
    transform=resnet_preprocess,
)

In [62]:
import numpy as np
import torch.nn as nn


class MeanEntropy(nn.Module):
    """Entropy of the marginal (augmentation-averaged) predictive distribution.

    Despite the name, this is not the mean of per-copy entropies. The softmax
    distributions of all rows are averaged first, and the entropy of that
    average is returned.

    forward(outputs):
        outputs: logits with the augmented copies along dim 0 and classes along
            the last dim (e.g. ``(B, num_classes)``).
        Returns ``(entropy, marginal_log_probs)``, where ``marginal_log_probs``
        is the log of the averaged distribution.
    """

    def forward(self, outputs):
        # Row-wise log-softmax: log p_i(y) for every augmented copy i.
        log_probs = outputs - outputs.logsumexp(dim=-1, keepdim=True)

        # log( (1/B) * sum_i p_i(y) ), computed in log space via logsumexp.
        copies = log_probs.shape[0]
        marginal_log_probs = log_probs.logsumexp(dim=0) - np.log(copies)

        # Swap -inf for the most negative finite value so that
        # p * log p evaluates to 0 (not 0 * -inf = NaN) for zero-probability classes.
        floor = torch.finfo(marginal_log_probs.dtype).min
        marginal_log_probs = torch.clamp(marginal_log_probs, min=floor)

        entropy = -(torch.exp(marginal_log_probs) * marginal_log_probs).sum(dim=-1)
        return entropy, marginal_log_probs


loss_fn = MeanEntropy()

In [63]:
# Random view generator: RandomResizedCrop / RandomHorizontalFlip sample new
# parameters on every call, so applying it repeatedly yields different views.
preaugment = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
])

# Number B of augmented copies whose predictions MEMO averages.
# NOTE(review): MEMO hyperparameter; confirm the value against the paper.
num_augmentations = 16

# memo_tta applies every entry once. With a single entry the "marginal"
# entropy would just be the entropy of one prediction (no averaging over
# augmentations), so the random transform is repeated B times.
# TODO: MEMO uses AugMix-style augmentations; presumably defined in
# utils.train_helpers.prepare_test_data of the reference code (confirm).
augmentations = [preaugment] * num_augmentations

In [79]:
def memo_tta(pretrained_model, sample, optimizer):
    """Run one MEMO adaptation step for a single test point.

    Every transform in the global ``augmentations`` list is applied to the
    test point. The model's outputs on all copies go to the global ``loss_fn``
    (marginal entropy), and one optimizer step is taken.

    Args:
        pretrained_model: Model to adapt in place; it is switched to eval mode.
        sample: Input mini-batch; only its first example is used (as before).
        optimizer: Optimizer over ``pretrained_model``'s parameters.

    Returns:
        The same ``pretrained_model`` object, after the update.
    """
    test_point = sample[:1]

    # Batch all copies so a single forward pass covers them (DataParallel, if
    # used, splits the batch across GPUs). Assumes every augmentation yields
    # the same spatial size, as RandomResizedCrop(224) does.
    augmented_batch = torch.cat([augmentation(test_point) for augmentation in augmentations], dim=0)

    # In eval mode BatchNorm uses its running statistics, so batching the copies
    # gives the same per-copy outputs as separate forward passes.
    pretrained_model.eval()
    logits = pretrained_model(augmented_batch)

    # Clear stale gradients before backward, not after the step.
    optimizer.zero_grad()
    loss, _ = loss_fn(logits)
    loss.backward()
    optimizer.step()

    # TODO: adapting BN statistics is not implemented yet.
    return pretrained_model

In [77]:
# Step size for the single MEMO SGD update.
# TODO: check which learning rate the paper uses.
learning_rate = 0.1

# Pretrained ResNet-18 from the torchvision hub (pinned to v0.10.0). Intended
# as the untouched starting point for every test-time adaptation.
fresh_pretrained_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

def get_fresh_pretrained_model():
    """Return an unadapted copy of the pretrained model plus an SGD optimizer.

    MEMO adapts each test point independently, so every call must start from
    the original pretrained weights. ``fresh_pretrained_model`` is deep-copied.
    Returning it directly would let optimizer steps on one test point leak into
    all later ones, because it would be the same module object every time.

    Returns:
        ``(model, optimizer)``: the copy wrapped in ``nn.DataParallel`` (moved
        to GPU when CUDA is available) and ``torch.optim.SGD`` over its
        parameters with ``lr=learning_rate``.
    """
    import copy

    pretrained_model = copy.deepcopy(fresh_pretrained_model)
    if torch.cuda.is_available():
        pretrained_model = pretrained_model.cuda()
    # In case multiple GPUs are available, DataParallel splits each batch across them.
    pretrained_model = nn.DataParallel(pretrained_model)
    optimizer = torch.optim.SGD(pretrained_model.parameters(), lr=learning_rate)

    return pretrained_model, optimizer

Using cache found in C:\Users\pooya/.cache\torch\hub\pytorch_vision_v0.10.0


In [82]:
# Implementation of test-time adaptation via MEMO.

# "only one gradient step per test point" as in the paper
epoch_for_adaptation = 1

# Top-1 predictions (indices into the model's output) and dataset labels,
# collected so accuracy can be computed once labels are mapped to model classes.
predictions = []
targets = []

for image, label in imagenet_a:
    image_input = image.unsqueeze(0)  # create a mini-batch as expected by the resnet model

    # Request a fresh model/optimizer for every test point (MEMO adapts each
    # point independently).
    pretrained_model, optimizer = get_fresh_pretrained_model()
    # Without adaptation steps, inference falls back to the unadapted model.
    adapted_model = pretrained_model
    for _ in range(epoch_for_adaptation):
        adapted_model = memo_tta(adapted_model, image_input, optimizer)

    # Inference phase: predict on the un-augmented input with the adapted weights.
    with torch.no_grad():
        prediction = adapted_model(image_input)
    predictions.append(prediction.argmax(dim=-1).item())
    targets.append(label)

# TODO: log inference accuracy from `predictions` / `targets`.
# NOTE(review): ImageFolder labels index the sorted class folders, not the
# model's 1000 output classes; map one onto the other before comparing.

KeyboardInterrupt: 