# Imports

In [1]:
!pip install opencv-python-headless
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-pb6lpjkp
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-pb6lpjkp
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25ldone


In [2]:
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.transforms.v2 import AugMix
from torchvision.transforms import InterpolationMode
from torchvision.models import resnet50, ResNet50_Weights
from copy import deepcopy
import torch
from torch import nn
from torch import optim
import numpy as np
from clip import load, tokenize
import cv2
import os
import csv
import boto3
from PIL import Image
from io import BytesIO
from pathlib import Path



# Parameters

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
MEMO_device = "cuda" if torch.cuda.is_available() else "cpu"
TPT_device = "cuda" if torch.cuda.is_available() else "cpu"

imageNetA = True
naug = 64
top = 0.1
niter = 1
no_backwards = True
testSingleModels = True

# set the seed
torch.manual_seed(0)
np.random.seed(0)

# Data Loading

## Class names mapping

In [4]:
imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]

## ImageNet-A

In [5]:
class ImageNetA(Dataset):
    """
    A custom dataset class for loading images from the ImageNet-A dataset.

    Args:
        root (str): The root directory of the dataset.
        csvMapFile (str, optional): The path to the CSV file containing the mapping of WordNet IDs to class names. Defaults to "dataloaders/wordNetIDs2Classes.csv".
        transform (callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to None.
    """

    def __init__(
        self, root, csvMapFile="wordNetIDs2Classes.csv", transform=None
    ):
        self.s3_bucket = "deeplearning2024-datasets"
        self.s3_region = "eu-west-1"
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)

        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = response.get("Contents", [])
        #print(objects)
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        mapping = {}
        csv_file = csv.reader(open(csvMapFile, "r"))
        for id, wordnet, name in csv_file:
            if id == "resnet_label":
                continue
            mapping[int(wordnet)] = {"id": id, "name": name}

        # print(mapping)
        self.classnames = {}

        # Iterate and keep valid files only
        self.instances = []
        for ds_idx, item in enumerate(objects):
            key = item["Key"]
            path = Path(key)

            # Check if file is valid
            if path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
                continue

            # Get label
            label = int(path.parent.name[1:])
            name = mapping[label]["name"]
            self.classnames[mapping[label]["id"]] = name
            label = int(mapping[label]["id"])


            # Keep track of valid instances
            self.instances.append((label, name, key))

        self.transform = transform

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        try:
            label, name, key = self.instances[idx]
            # Download image from S3
            # response = self.s3_client.get_object(Bucket=self.s3_bucket, Key=key)
            # img_bytes = response["Body"]._raw_stream.data

            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            img_bytes.seek(0)  # Ensure the BytesIO object is at the start
            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}")

        return {"img": img, "label": label, "name": name}

## ImageNet-V2

In [6]:
class ImageNetV2(Dataset):
    """
    A custom dataset class for loading images from the ImageNet-V2 dataset.

    Args:
        root (str): The root directory of the dataset.
        csvMapFile (str, optional): The path to the CSV file containing the mapping of WordNet IDs to class names. Defaults to "dataloaders/wordNetIDs2Classes.csv".
        transform (callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to None.
    """

    def __init__(
        self, root, csvMapFile="wordNetIDs2Classes.csv", transform=None
    ):
        self.s3_bucket = "deeplearning2024-datasets"
        self.s3_region = "eu-west-1"
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)

        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = response.get("Contents", [])
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        mapping = {}
        csv_file = csv.reader(open(csvMapFile, "r"))
        for id, _, name in csv_file:
            if id == "resnet_label":
                continue
            mapping[id] = name

        self.classnames = {}
        # Iterate and keep valid files only
        self.instances = []
        for ds_idx, item in enumerate(objects):
            key = item["Key"]
            path = Path(key)

            # Check if file is valid
            if path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
                continue

            # Get label
            label = path.parent.name
            name = mapping[label]
            self.classnames[label] = name

            label = int(label)

            # Keep track of valid instances
            self.instances.append((label, name, key))

        self.transform = transform

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        try:
            label, name, key = self.instances[idx]
            # Download image from S3
            # response = self.s3_client.get_object(Bucket=self.s3_bucket, Key=key)
            # img_bytes = response["Body"]._raw_stream.data

            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            img_bytes.seek(0)  # Ensure the BytesIO object is at the start
            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}")

        return {"img": img, "label": label, "name": name}

## EasyAugmenter

In [7]:
class EasyAgumenter(object):
    def __init__(self, base_transform, preprocess, augmentation, n_views=63):
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views

        if augmentation == 'augmix':

            self.preaugment = transforms.Compose(
                [
                    AugMix(),
                    transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
                    transforms.CenterCrop(224),
                ]
            )
        elif augmentation == 'identity':
            self.preaugment = self.base_transform
        elif augmentation == 'cut':
            self.preaugment = transforms.Compose(
                [
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                ]
            )
        else:
            raise ValueError('Augmentation type not recognized')
    
    def __call__(self, x):

        if isinstance(x, np.ndarray):
            x = transforms.ToPILImage()(x)

        image = self.preprocess(self.base_transform(x))

        views = [self.preprocess(self.preaugment(x)) for _ in range(self.n_views)]

        return [image] + views

## Functions

In [8]:
def get_dataloaders(root, transform=None, csvMapFile="wordNetIDs2Classes.csv"):
    """
    Returns the dataloader of the dataset.

    Args:
        root (str): The root directory of the dataset.
        transform (callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to None.
    """
    root_A = "imagenet-a"
    imageNet_A = ImageNetA(root_A, transform=transform, csvMapFile=csvMapFile)
    root_V2 = "imagenetv2-matched-frequency-format-val"
    imageNet_V2 = ImageNetV2(root_V2, transform=transform, csvMapFile=csvMapFile)

    return imageNet_A, imageNet_V2

def get_classes_names(csvMapFile="wordNetIDs2Classes.csv"):
    """
    Returns the class names of the dataset.

    Args:
        csvMapFile (str, optional): The path to the CSV file containing the mapping of WordNet IDs to class names. Defaults to "dataloaders/wordNetIDs2Classes.csv".
    """
    names = [""]*1000
    csv_file = csv.reader(open(csvMapFile, 'r'))
    for id, wordnet, name in csv_file:
        if id == 'resnet_label':
            continue
        names[int(id)] = name
    
    return names

def memo_get_datasets(augmentation, augs=64):
    """
    Returns the ImageNetA and ImageNetV2 datasets for the memo model
    Args:
        augmentation (str): What type of augmentation to use in EasyAugmenter. Can be 'augmix', 'identity' or 'cut'
        augs (int): The number of augmentations to compute. Must be greater than 1

    Returns: The ImageNetA and ImageNetV2 datasets for the memo model, with the Augmentations already applied

    """
    assert augs > 1, 'The number of augmentations must be greater than 1'
    memo_transforms = transforms.Compose([transforms.Resize(256),
                                          transforms.CenterCrop(224)])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform = EasyAgumenter(memo_transforms, preprocess, augmentation, augs - 1)
    imageNet_A, imageNet_V2 = get_dataloaders('OfficeHomeDataset_10072016', transform)
    return imageNet_A, imageNet_V2

def tpt_get_transforms(augs=64):

    base_transform = transforms.Compose(
        [
            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
        ]
    )

    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ]
    )

    data_transform = EasyAgumenter(
        base_transform,
        preprocess,
        n_views=augs - 1,
    )

    return data_transform


def tpt_get_datasets(data_root, augmix=False, augs=64, all_classes=True):
    """
    Returns the ImageNetA and ImageNetV2 datasets.

    Parameters:
    - data_root (str): The root directory of the datasets.
    - augmix (bool): Whether to use AugMix or not.
    - augs (int): The number of augmentations to use.
    - all_classes (bool): Whether to use all classes or not.

    Returns:
    - imageNet_A (ImageNetA): The ImageNetA dataset.
    - ima_names (list): The original classnames in ImageNetA.
    - ima_custom_names (list): The retouched  classnames in ImageNetA.
    - ima_id_mapping (list): The mapping between the index of the classname and the ImageNet label

    same for ImageNetV2

    For instance the first element of ima_names corresponds to the label '90'.  After running the
    inference run the predicted output through the ima_id_mapping to recover the correct class label.

    out = tpt(inputs)
    pred = out.argmax().item()
    out_id = ima_id_mapping[pred]

    """
    base_transform = transforms.Compose(
        [
            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
        ]
    )

    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ]
    )

    data_transform = EasyAgumenter(
        base_transform,
        preprocess,
        augmentation=("augmix" if augmix else "cut"),
        n_views=augs - 1,
    )

    imageNet_A = ImageNetA(
        "imagenet-a", transform=data_transform
    )
    imageNet_V2 = ImageNetV2(
        "imagenetv2-matched-frequency-format-val",
        transform=data_transform,
    )

    imv2_label_mapping = list(imageNet_V2.classnames.keys())
    imv2_names = list(imageNet_V2.classnames.values())
    imv2_custom_names = [imagenet_classes[int(i)] for i in imv2_label_mapping]

    ima_label_mapping = list(imageNet_A.classnames.keys())
    ima_names = list(imageNet_A.classnames.values())
    ima_custom_names = [imagenet_classes[int(i)] for i in ima_label_mapping]

    if all_classes:
        ima_names += [name for name in imv2_names if name not in ima_names]
        ima_custom_names += [
            name for name in imv2_custom_names if name not in ima_custom_names
        ]
        ima_label_mapping += [
            map for map in imv2_label_mapping if map not in ima_label_mapping
        ]

    return (
        imageNet_A,
        ima_names,
        ima_custom_names,
        ima_label_mapping,
        imageNet_V2,
        imv2_names,
        imv2_custom_names,
        imv2_label_mapping,
    )


# Models

## TPT

In [9]:
class EasyPromptLearner(nn.Module):
    def __init__(
        self,
        device,
        clip,
        base_prompt="a photo of [CLS]",
        splt_ctx=False,
        classnames=None,
    ):
        super().__init__()

        self.device = device
        self.base_prompt = base_prompt
        self.tkn_embedder = clip.token_embedding
        # set requires_grad to False
        self.tkn_embedder.requires_grad_(False)

        self.split_ctx = splt_ctx

        self.prepare_prompts(classnames)

    def prepare_prompts(self, classnames):
        print("[PromptLearner] Preparing prompts")

        self.classnames = classnames
        # self.classnames = [cls.split(",")[0] for cls in self.classnames]

        # get numbr of classes
        self.cls_num = len(self.classnames)

        # get prompt text prefix and suffix
        txt_prefix = self.base_prompt.split("[CLS]")[0]
        txt_suffix = self.base_prompt.split("[CLS]")[1]

        # tokenize the prefix and suffix
        tkn_prefix = tokenize(txt_prefix)
        tkn_suffix = tokenize(txt_suffix)
        tkn_pad = tokenize("")
        tkn_cls = tokenize(self.classnames)

        # get the index of the last element of the prefix and suffix
        idx = torch.arange(tkn_prefix.shape[1], 0, -1)
        self.indp = torch.argmax((tkn_prefix == 0) * idx, 1, keepdim=True)
        self.inds = torch.argmax((tkn_suffix == 0) * idx, 1, keepdim=True)

        # token length for each class
        self.indc = torch.argmax((tkn_cls == 0) * idx, 1, keepdim=True)

        # get the prefix, suffix, SOT and EOT
        self.tkn_sot = tkn_prefix[:, :1]
        self.tkn_prefix = tkn_prefix[:, 1 : self.indp - 1]
        self.tkn_suffix = tkn_suffix[:, 1 : self.inds - 1]
        self.tkn_eot = tkn_suffix[:, self.inds - 1 : self.inds]
        self.tkn_pad = tkn_pad[:, 2:]

        # load segments to CUDA, be ready to be embedded
        self.tkn_sot = self.tkn_sot.to(self.device)
        self.tkn_prefix = self.tkn_prefix.to(self.device)
        self.tkn_suffix = self.tkn_suffix.to(self.device)
        self.tkn_eot = self.tkn_eot.to(self.device)
        self.tkn_pad = self.tkn_pad.to(self.device)

        self.tkn_cls = tkn_cls.to(self.device)

        # gets the embeddings
        with torch.no_grad():
            self.emb_sot = self.tkn_embedder(self.tkn_sot)
            self.emb_prefix = self.tkn_embedder(self.tkn_prefix)
            self.emb_suffix = self.tkn_embedder(self.tkn_suffix)
            self.emb_eot = self.tkn_embedder(self.tkn_eot)
            self.emb_cls = self.tkn_embedder(self.tkn_cls)
            self.emb_pad = self.tkn_embedder(self.tkn_pad)

        # take out the embeddings of the class tokens (they are different lenghts)
        self.all_cls = []
        for i in range(self.cls_num):
            self.all_cls.append(self.emb_cls[i][1 : self.indc[i] - 1])

        # prepare the prompts, they are needed for text encoding
        self.txt_prompts = [
            self.base_prompt.replace("[CLS]", cls) for cls in self.classnames
        ]
        self.tkn_prompts = tokenize(self.txt_prompts)

        # set the inital context, this will be reused at every new inference
        # this is the context that will be optimized

        if self.split_ctx:
            self.pre_init_state = self.emb_prefix.detach().clone()
            self.suf_init_state = self.emb_suffix.detach().clone()
            self.emb_prefix = nn.Parameter(self.emb_prefix)
            self.emb_suffix = nn.Parameter(self.emb_suffix)
            self.register_parameter("emb_prefix", self.emb_prefix)
            self.register_parameter("emb_suffix", self.emb_suffix)
        else:
            self.ctx = torch.cat((self.emb_prefix, self.emb_suffix), dim=1)
            self.ctx_init_state = self.ctx.detach().clone()
            self.ctx = nn.Parameter(self.ctx)
            self.register_parameter("ctx", self.ctx)

    def build_ctx(self):
        prompts = []
        for i in range(self.cls_num):
            pad_size = self.emb_cls.shape[1] - (
                self.emb_prefix.shape[1]
                + self.indc[i].item()
                + self.emb_suffix.shape[1]
            )

            if self.split_ctx:
                prefix = self.emb_prefix
                suffix = self.emb_suffix
            else:
                prefix = self.ctx[:, : self.emb_prefix.shape[1]]
                suffix = self.ctx[:, self.emb_prefix.shape[1] :]

            prompt = torch.cat(
                (
                    self.emb_sot,
                    prefix,
                    self.all_cls[i].unsqueeze(0),
                    suffix,
                    self.emb_eot,
                    self.emb_pad[:, :pad_size],
                ),
                dim=1,
            )
            prompts.append(prompt)
        prompts = torch.cat(prompts, dim=0)

        return prompts

    def forward(self):

        return self.build_ctx()

    def reset(self):

        if self.split_ctx:
            self.emb_prefix.data.copy_(self.pre_init_state)  # to be optimized
            self.emb_suffix.data.copy_(self.suf_init_state)  # to be optimized
        else:
            self.ctx.data.copy_(self.ctx_init_state)  # to be optimized


class EasyTPT(nn.Module):
    def __init__(
        self,
        device,
        base_prompt="a photo of a [CLS]",
        arch="RN50",
        splt_ctx=False,
        classnames=None,
        ttt_steps=1,
        augs=64,
        lr=0.005,
    ):
        super(EasyTPT, self).__init__()
        self.device = device

        ###TODO: tobe parametrized
        DOWNLOAD_ROOT = "~/.cache/clip"
        ###

        self.base_prompt = base_prompt
        self.ttt_steps = ttt_steps
        self.augs = augs
        self.selected_idx = None

        # Load clip
        clip, self.preprocess = load(arch, device=device, download_root=DOWNLOAD_ROOT)
        self.clip = clip
        self.dtype = clip.dtype
        self.image_encoder = clip.encode_image
        self.text_encoder = clip.encode_text

        # freeze the parameters
        for name, param in self.named_parameters():
            param.requires_grad_(False)

        # create the prompt learner
        self.prompt_learner = EasyPromptLearner(
            device, clip, base_prompt, splt_ctx, classnames
        )

        # create optimizer and save the state
        trainable_param = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(f"[EasyTPT] Training parameter: {name}")
                trainable_param.append(param)
        self.optimizer = torch.optim.AdamW(trainable_param, lr)
        self.optim_state = deepcopy(self.optimizer.state_dict())

        # breakpoint()

    def forward(self, x, top=0.10):
        """
        If x is a list of augmentations, run the confidence selection,
        otherwise just run the inference
        """
        self.eval()
        # breakpoint()
        if isinstance(x, list):
            x = torch.stack(x).to(self.device)
            logits = self.inference(x)
            if self.selected_idx is not None:
                logits = logits[self.selected_idx]
            else:
                logits, self.selected_idx = self.select_confident_samples(logits, top)
        else:
            if len(x.shape) == 3:
                x = x.unsqueeze(0)
            x = x.to(self.device)
            logits = self.inference(x)
        
        # print (f"[EasyTPT] input shape: {x.shape}")
        # print("[EasyTPT] logits shape: ", logits.shape)
        return logits

    def inference(self, x):
        with torch.no_grad():
            image_feat = self.image_encoder(x)
            image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)

        emb_prompts = self.prompt_learner()

        txt_features = self.custom_encoder(emb_prompts, self.prompt_learner.tkn_prompts)
        txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True)

        logit_scale = self.clip.logit_scale.exp()
        logits = logit_scale * image_feat @ txt_features.t()

        return logits

    def custom_encoder(self, prompts, tokenized_prompts):
        """
        Custom clip text encoder, unlike the original clip encoder this one
        takes the prompts embeddings from the prompt learner
        """
        x = prompts + self.clip.positional_embedding
        x = x.permute(1, 0, 2).type(self.dtype)  # NLD -> LND
        x = self.clip.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = (
            x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)]
            @ self.clip.text_projection
        )

        return x

    def reset(self):
        """
        Resets the optimizer and the prompt learner to their initial state,
        this has to be run before each new test
        """
        self.optimizer.load_state_dict(deepcopy(self.optim_state))
        self.prompt_learner.reset()
        self.selected_idx = None

    def select_confident_samples(self, logits, top):
        """
        Performs confidence selection, will return the indexes of the
        augmentations with the highest confidence as well as the filtered
        logits

        Parameters:
        - logits (torch.Tensor): the logits of the model [NAUGS, NCLASSES]
        - top (float): the percentage of top augmentations to use
        """
        batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
        idx = torch.argsort(batch_entropy, descending=False)[
            : int(batch_entropy.size()[0] * top)
        ]
        return logits[idx], idx

    def tpt_avg_entropy(self, outputs):
        logits = outputs - outputs.logsumexp(
            dim=-1, keepdim=True
        )  # logits = outputs.log_softmax(dim=1) [N, 1000]
        avg_logits = logits.logsumexp(dim=0) - np.log(
            logits.shape[0]
        )  # avg_logits = logits.mean(0) [1, 1000]
        min_real = torch.finfo(avg_logits.dtype).min
        avg_logits = torch.clamp(avg_logits, min=min_real)
        return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)

    def predict(self, images, niter=1):

        self.reset()

        for _ in range(niter):
            out = self(images)
            loss = self.tpt_avg_entropy(out)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        with torch.no_grad():
            out = self(images[0])
            out_id = out.argmax(1).item()
            prediction = self.prompt_learner.classnames[out_id]

        # return out_id, prediction
        return out_id

    def get_optimizer(self):
        """
        Returns the optimizer

        Returns:
        - torch.optim: the optimizer
        """
        return self.optimizer

## MEMO

In [10]:
def memo_marginal_entropy(outputs):
    logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)
    avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])
    min_real = torch.finfo(avg_logits.dtype).min
    avg_logits = torch.clamp(avg_logits, min=min_real)
    return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)


def _modified_bn_forward(self, input):
    est_mean = torch.zeros(self.running_mean.shape, device=self.running_mean.device)
    est_var = torch.ones(self.running_var.shape, device=self.running_var.device)
    nn.functional.batch_norm(input, est_mean, est_var, None, None, True, 1.0, self.eps)
    running_mean = self.prior * self.running_mean + (1 - self.prior) * est_mean
    running_var = self.prior * self.running_var + (1 - self.prior) * est_var
    return nn.functional.batch_norm(input, running_mean, running_var, self.weight, self.bias, False, 0, self.eps)


class EasyMemo(nn.Module):
    """
    A class to wrap a neural network with the MEMO TTA method
    """

    def __init__(self, net, device, classes_mask, prior_strength: float = 1.0, lr=0.005, weight_decay=0.0001, opt='sgd',
                 niter=1, top=0.1, drop=False):
        """
        Initializes the EasyMemo model with various arguments
        Args:
            net: The model to wrap with EasyMemo
            device: The device to run the model on(usually 'CPU' or 'CUDA')
            classes_mask: The classes to consider for the model(used for Imagenet-A)
            prior_strength: The strength of the prior to use in the modified BN forward pass
            lr: The Learning rate for the optimizer of the model
            weight_decay: The weight decay for the optimizer of the model
            opt: Which optimizer to use for this model between 'sgd' and 'adamw' for the respective optimizers
            niter: The number of iterations to run the memo pass for
            top: The percentage of the top logits to consider for confidence selection
        """
        super(EasyMemo, self).__init__()

        self.drop = drop
        if self.drop:
            net.layer4.add_module('dropout', nn.Dropout(0.5, inplace=True))
        self.device = device
        self.prior_strength = prior_strength
        self.net = net.to(device)
        self.optimizer = self.memo_optimizer_model(lr=lr, weight_decay=weight_decay, opt=opt)
        self.lr = lr
        self.weight_decay = weight_decay
        self.opt = opt
        self.confidence_idx = None
        self.memo_modify_bn_pass()
        self.criterion = memo_marginal_entropy
        self.niter = niter
        self.top = top
        self.initial_state = deepcopy(self.net.state_dict())
        self.classes_mask = classes_mask

    def forward(self, x, top=-1):
        """
        Forward pass where we check which type of input we have and we call the inference on the input image Tensor
        Args:
            top: How many samples to select from the batch
            x: A Tensor of shape (N, C, H, W) or a list of Tensors of shape (N, C, H, W)

        Returns: The logits after the inference pass

        """
        self.top = top if top > 0 else self.top
        # print(f"Shape forward: {x.shape}")
        if isinstance(x, list):
            x = torch.stack(x).to(self.device)
            # print(f"Shape forward: {x.shape}")
            logits = self.inference(x)
            logits, self.confidence_idx = self.topk_selection(logits)
        else:
            if len(x.shape) == 3:
                x = x.unsqueeze(0)
            x = x.to(self.device)
            logits = self.inference(x)

        # print(f"[EasyMemo] input shape: {x.shape}")
        # print(f"[EasyMemo] logits shape: {logits.shape}")
        return logits

    def inference(self, x):
        """
        Return the logits of the image in input x
        Args:
            x: A Tensor of shape (N, C, H, W) of an Image

        Returns: The logits for that Tensor image

        """
        self.net.eval()
        outputs = self.net(x)

        out_app = torch.zeros(outputs.shape[0], len(self.classes_mask)).to(self.device)
        for i, out in enumerate(outputs):
            out_app[i] = out[self.classes_mask]
        return out_app

    def predict(self, x, niter=1):
        """
        Predicts the class of the input x, which is an image
        Args:
            niter: The number of iteration on which to run the memo pass
            x: Tensor of shape (N, C, H, W)

        Returns: The predicted classes

        """
        self.niter = niter
        if self.drop:
            self.net.train()
        else:
            self.net.eval()

        for iteration in range(self.niter):
            self.optimizer.zero_grad()
            outputs = self.forward(x)
            outputs, _ = self.topk_selection(outputs)
            loss = self.criterion(outputs)
            loss.backward()
            self.optimizer.step()

        with torch.no_grad():
            outputs = self.net(x[0].unsqueeze(0).to(self.device))
            outs = torch.zeros(outputs.shape[0], len(self.classes_mask)).to(self.device)
            for i, out in enumerate(outputs):
                outs[i] = out[self.classes_mask]
            predicted = outs.argmax(1).item()

        return predicted

    def reset(self):
        """Resets the model to its initial state"""
        del self.optimizer
        self.optimizer = self.memo_optimizer_model(lr=self.lr, weight_decay=self.weight_decay, opt=self.opt)
        self.confidence_idx = None
        self.net.load_state_dict(deepcopy(self.initial_state))

    def memo_modify_bn_pass(self):
        print('modifying BN forward pass')
        nn.BatchNorm2d.prior = self.prior_strength
        nn.BatchNorm2d.forward = _modified_bn_forward

    def memo_optimizer_model(self, lr=0.005, weight_decay=0.0001, opt='sgd'):
        """
        Initializes the optimizer for the memo model
        Args:
            lr: The learning rate for the optimizer
            weight_decay: The weight decay for the optimizer
            opt: Which optimizer to use

        Returns: The optimizer for the memo model

        """
        if opt == 'sgd':
            optimizer = optim.SGD(self.net.parameters(), lr=lr, weight_decay=weight_decay)
        elif opt == 'adamw':
            optimizer = optim.AdamW(self.net.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError('Invalid optimizer selected')
        return optimizer

    def memo_adapt_single(self, inputs):
        """
        A single step of memo adaptation
        Args:
            inputs: A tensor of shape (N, C, H, W)

        """
        self.net.eval()
        assert self.niter > 0 and isinstance(self.niter, int), 'niter must be a positive integer'
        for iteration in range(self.niter):
            self.optimizer.zero_grad()
            outputs = self.net(inputs)
            loss = self.criterion(outputs)
            loss.backward()
            self.optimizer.step()

    def memo_test_single(self, image, label):
        """
        Tests the model on a single image and returns the correctness and confidence
        Args:
            image: A tensor of shape (N, C, H, W)
            label: The correct label for the test

        Returns: The correctness and confidence of the prediction

        """
        with torch.no_grad():
            outputs = self.net(image.to(device=self.device))
            _, predicted = outputs.max(1)
            confidence = nn.functional.softmax(outputs, dim=1).squeeze()[predicted].item()
        correctness = 1 if predicted.item() == label else 0
        return correctness, confidence

    def topk_selection(self, logits):
        """
        Selects the top k logits based on the batch entropy
        Args:
            logits: A tensor of shape (N, C)

        Returns: The filtered logits and the indices of the selected logits

        """
        batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
        selected_idx = torch.argsort(batch_entropy, descending=False)[: int(batch_entropy.size()[0] * self.top)]
        return logits[selected_idx], selected_idx

    def dropout_train(self, x):
        self.net.train()
        outputs = self.forward(x)
        outputs, _ = self.topk_selection(outputs)

        return outputs

## Ensemble

In [11]:
class Ensemble(nn.Module):
    """
    Ensemble class. Implements an ensemble of models with entropy minimization.

    Attributes:
        models (list): A list of models to be used in the ensemble.
        temps (list): A list of temperature values corresponding to each model.
        test_single_models (bool): Whether to test each individual model in addition to the ensemble.
        backward (bool): Whether to perform the entropy minimization step.
        device (str): The device to be used for computation.
    """
    def __init__(self, models, temps, device="cuda", test_single_models=False, no_backwards=False):
        """
        Initializes an Ensemble object.

        Args:
            models (list): A list of models to be used in the ensemble.
            temps (list): A list of temperature values corresponding to each model.
            device (str, optional): The device to be used for computation. Defaults to "cuda".
            test_single_models (bool, optional): Whether to test each individual model in addition to the ensemble. Defaults to False.
            no_backwards (bool, optional): Whether to perform the entropy minimization step. Defaults to False.
        """
        super(Ensemble, self).__init__()
        self.models = models
        self.temps = temps
        self.test_single_models = test_single_models
        self.device = device
        self.no_backwards = no_backwards

    def entropy(self, logits):
        """
        Computes the entropy of a set of logits.

        Args:
            logits (torch.Tensor): The logits to compute the entropy of.
        """
        return -(torch.exp(logits) * logits).sum(dim=-1)

    def marginal_distribution(self, models_logits):
        """
        Computes the marginal distribution of the ensemble.

        Args:
            models_logits (torch.Tensor): The logits of the models in the ensemble.
        """
        # average logits for each model
        avg_models_logits = torch.Tensor(models_logits.shape[0], models_logits.shape[2]).to(self.device)
        for i, model_logits in enumerate(models_logits):
            avg_outs = torch.logsumexp(model_logits, dim=0) - torch.log(torch.tensor(model_logits.shape[0]))
            min_real = torch.finfo(avg_outs.dtype).min
            avg_outs = torch.clamp(avg_outs, min=min_real)
            avg_outs /= self.temps[i]
            avg_models_logits[i] = torch.log_softmax(avg_outs, dim=0)

        with torch.no_grad():
            entropies = torch.stack([self.entropy(logits) for logits in avg_models_logits]).to(self.device)
            sum_entropies = torch.sum(entropies, dim=0)
            scale = torch.stack([sum_entropies/entopy for entopy in entropies]).to(self.device)
            #normalize sum to 1
            scale = scale / torch.sum(scale)

        print("\t\t[Ensemble] Entropies: ", entropies)
        print("\t\t[Ensemble] Scales: ", scale)

        avg_logits = torch.sum(torch.stack([scale[i].item() * avg_models_logits[i] for i in range(len(avg_models_logits))]), dim=0)

        return avg_logits

    def get_models_outs(self, inputs, top=0.1):
        """
        Computes the outputs of the models in the ensemble.

        Args:
            inputs (list): A list of inputs to be fed to the models.
            top (float, optional): The top percentage of the outputs to be used. Defaults to 0.1.
        """
        model_outs = torch.stack([model(inputs[i], top).to(self.device) for i, model in enumerate(self.models)]).to(self.device)
        return model_outs.to(self.device)

    def get_models_predictions(self, inputs):
        """
        Computes the predictions of the single models in the ensemble.

        Args:
            inputs (list): A list of inputs to be fed to the models.
        """
        models_pred = [model.predict(inputs[i]) for i, model in enumerate(self.models)]
        return models_pred

    def entropy_minimization(self, inputs, niter=1, top=0.1):
        """
        Test time adaptation step. Minimizes the entropy of the ensemble's predictions.

        Args:
            inputs (list): A list of inputs to be fed to the models.
            niter (int, optional): The number of iterations to perform. Defaults to 1.
            top (float, optional): The top percentage of the outputs to be used. Defaults to 0.1.
        """
        for i in range(niter):
            outs = self.get_models_outs(inputs, top)
            avg_logit = self.marginal_distribution(outs)

            loss = self.entropy(avg_logit)
            loss.backward()
            for model in self.models:
                model.optimizer.step()
                model.optimizer.zero_grad()

    def forward(self, inputs, niter=1, top=0.1):
        """
        Forward pass of the ensemble.

        Args:
            inputs (list): A list of inputs to be fed to the models.
            niter (int, optional): The number of iterations to perform. Defaults to 1.
            top (float, optional): The top percentage of the outputs to be used. Defaults to 0.1.
        """
        # get models outputs
        self.reset()
        models_pred = self.get_models_predictions(inputs)

        self.reset()
        self.entropy_minimization(inputs, niter, top)
            
        with torch.no_grad():
            outs = self.get_models_outs([i[0] for i in inputs], top)
            avg_logit = self.marginal_distribution(outs)
            prediction = torch.argmax(avg_logit, dim=0)

        if self.no_backwards:
            self.reset()
            outs = self.get_models_outs(inputs, top)
            avg_logit = self.marginal_distribution(outs)
            prediction_no_back = torch.argmax(avg_logit, dim=0)

        return models_pred, prediction_no_back, prediction

    def reset(self):
        """
        Resets the models in the ensemble.
        """
        for model in self.models:
            model.reset()

# Test Time Adaptation

## Prepare Models

In [12]:
def TPT(device="cuda", naug=30, base_prompt="A bad photo of a [CLS].", arch="RN50", splt_ctx= True, A=True):
    # prepare TPT
    if not torch.cuda.is_available():
        print("Using CPU this is no bueno")
    else:
        print("Using GPU, brace yourself!")

    datasetRoot = "OfficeHomeDataset_10072016"
    imageNetA, _, imageNetACustomNames, imageNetAMap, imageNetV2, _, imageNetV2CustomNames, imageNetV2Map = tpt_get_datasets(datasetRoot, augs=naug, all_classes=False)
    
    if A:
        dataset = imageNetA
        classnames = imageNetACustomNames
        mapping = imageNetAMap
    else:
        dataset = imageNetV2
        classnames = imageNetV2CustomNames
        mapping = imageNetV2Map
    
    tpt = EasyTPT(
        base_prompt=base_prompt,
        arch=arch,
        splt_ctx=splt_ctx,
        classnames=classnames,
        device=device
    )
    
    return tpt, dataset, mapping

def memo(device="cuda", prior_strength=0.94, naug=30, A=True):
    # prepare MEMO
    imageNet_A, imageNet_V2 = memo_get_datasets(augmentation='cut', augs=naug)
    dataset = imageNet_A if A else imageNet_V2

    mapping = list(dataset.classnames.keys())
    for i,id in enumerate(mapping):
        mapping[i] = int(id)
    
    rn50 = resnet50(weights=ResNet50_Weights.DEFAULT)
    memo = EasyMemo(rn50, device=device, classes_mask=mapping, prior_strength=prior_strength)
    
    return memo, dataset

## TTA

In [17]:
def test(tpt_model:EasyTPT, memo_model, tpt_data, mapping, memo_data, device="cuda", niter=1, top=0.1, no_backwards=False, testSingleModels=False):
    correct = 0
    correct_no_back = 0
    correctSingle = [0, 0]
    cnt = 0

    class_names = get_classes_names()
    models_names = ["TPT", "MEMO"]

    TPT_temp = 1.55
    MEMO_temp = 0.7
    temps = [TPT_temp, MEMO_temp]

    #shuffle the data
    indx = np.random.permutation(range(len(tpt_data)))

    model = Ensemble(models=[tpt_model, memo_model], temps=temps, 
                     device=device, test_single_models=testSingleModels, 
                     no_backwards=no_backwards)
    print("Ensemble model created starting TTA, samples:",len(indx))
    for i in indx:
        cnt += 1 
        img_TPT = tpt_data[i]["img"]
        img_MEMO = memo_data[i]["img"]
        data = [img_TPT, img_MEMO]
        
        label = int(tpt_data[i]["label"])
        label2 = int(memo_data[i]["label"])
        assert label == label2 #check if the labels are the same
        name = tpt_data[i]["name"]

        print (f"Testing on {i} - name: {name} - label: {label}")

        models_out, pred_no_back, prediction = model(data, niter=niter, top=0.1)
        models_out = [int(mapping[model_out]) for model_out in models_out]
        prediction = int(mapping[prediction])
        
        if testSingleModels:
            for i, model_out in enumerate(models_out):
                if label == model_out:
                    correctSingle[i] += 1
                
                print(f"\t{models_names[i]} model accuracy: {correctSingle[i]}/{cnt} - predicted class {model_out}: {class_names[model_out]} - tested: {cnt} / {len(tpt_data)}")

        if no_backwards:
            pred_no_back = int(mapping[pred_no_back])
            if label == pred_no_back:
                correct_no_back += 1
            print(f"\tSimple Ens accuracy: {correct_no_back}/{cnt} - predicted class {pred_no_back}: {class_names[pred_no_back]} - tested: {cnt} / {len(tpt_data)}")

        if label == prediction:
            correct += 1
            
        print(f"\tEnsemble accuracy: {correct}/{cnt} - predicted class {prediction}: {class_names[prediction]} - tested: {cnt} / {len(tpt_data)}")

## Run

In [None]:
tpt_model, tpt_data, mapping = TPT(TPT_device, naug=naug, A=imageNetA)
    
memo_model, memo_data = memo(MEMO_device, naug=naug, A=imageNetA)

if (imageNetA):
    print("Testing on ImageNet-A")
else:
    print("Testing on ImageNet-V2")

test(tpt_model, memo_model, 
     tpt_data, mapping, memo_data, 
     device, niter, top, no_backwards, testSingleModels)

Using GPU, brace yourself!
[PromptLearner] Preparing prompts
[EasyTPT] Training parameter: prompt_learner.emb_prefix
[EasyTPT] Training parameter: prompt_learner.emb_suffix
modifying BN forward pass
Testing on ImageNet-A
Ensemble model created starting TTA, samples: 7500
Testing on 641 - name: eft - label: 27


  avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])


		[Ensemble] Entropies:  tensor([1.6317, 2.3735], device='cuda:0')
		[Ensemble] Scales:  tensor([0.5926, 0.4074], device='cuda:0')
		[Ensemble] Entropies:  tensor([0.9218, 0.3971], device='cuda:0')
		[Ensemble] Scales:  tensor([0.3011, 0.6989], device='cuda:0')
		[Ensemble] Entropies:  tensor([1.6317, 2.3735], device='cuda:0')
		[Ensemble] Scales:  tensor([0.5926, 0.4074], device='cuda:0')
	TPT model accuracy: 0/1 - predicted class 57: garter snake, grass snake - tested: 1 / 7500
	MEMO model accuracy: 1/1 - predicted class 27: eft - tested: 1 / 7500
	Simple Ens accuracy (no back): 0/1 - predicted class 57: garter snake, grass snake - tested: 1 / 7500
	Ensemble accuracy: 0/1 - predicted class 79: centipede - tested: 1 / 7500
Testing on 3787 - name: cockroach, roach - label: 314
		[Ensemble] Entropies:  tensor([4.3946, 2.1879], device='cuda:0')
		[Ensemble] Scales:  tensor([0.3324, 0.6676], device='cuda:0')
		[Ensemble] Entropies:  tensor([3.8439, 0.0039], device='cuda:0')
		[Ensemble] S