Install Dependencies

In [None]:
#!pip install git+https://github.com/aengusl/spawrious.git

In [None]:
import argparse

import torch
import torch.optim as optim
from torch import nn
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision import models
from tqdm import tqdm
from tqdm.auto import tqdm
import timm
import wandb
import os

# 0. Spawrious Source Code (For Editing)

Edited to use fewer images per folder.


In [None]:
import os
import tarfile
import urllib
import urllib.request
from typing import Any, Tuple

import torch
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import timm
from PIL import ImageFile

# Tell PIL to load truncated image files instead of raising an error on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Candidate timm model names that have been tried with this notebook:
# MODEL_NAME = "vit_so400m_patch14_siglip_384"
# MODEL_NAME = 'swin_base_patch4_window7_224.ms_in22k_ft_in1k'
# MODEL_NAME = 'deit3_base_patch16_224.fb_in22k_ft_in1k'
# MODEL_NAME = 'beit_base_patch16_224.in22k_ft_in22k_in1k'
# MODEL_NAME = 'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k'
# MODEL_NAME = 'levit_128s.fb_dist_in1k'

# Name passed to timm.create_model when building backbones and image transforms.
# It starts unset; assign it (e.g. through set_model_name) before creating a dataset.
MODEL_NAME = None

def set_model_name(name):
    """Store ``name`` in the module-level ``MODEL_NAME`` variable.

    Args:
        name: timm model identifier that later ``timm.create_model`` calls read.
    """
    # Writing through globals() rebinds the same module-level name a
    # ``global MODEL_NAME`` statement would.
    globals()["MODEL_NAME"] = name


def _extract_dataset_from_tar(
    tar_file_name: str, data_dir: str
) -> None:
    tar_file_dst = os.path.join(data_dir, tar_file_name)
    print("Extracting dataset...")
    tar = tarfile.open(tar_file_dst, "r:gz")
    tar.extractall(os.path.dirname(tar_file_dst))
    tar.close()


def _download_dataset_if_not_available(
    dataset_name: str, data_dir: str, remove_tar_after_extracting: bool = True
) -> None:
    """
    Make sure the images for ``dataset_name`` exist under ``data_dir``.

    The image folders are checked with ``_check_images_availability``. If they
    are missing, an already-downloaded archive in ``data_dir`` is extracted;
    otherwise the archive is downloaded first and then extracted.

    Args:
        dataset_name: Benchmark name (case-insensitive). Any ``m2m_*`` name is
            mapped to the ``entire_dataset`` archive.
        data_dir: Download/extraction directory. Anything from
            ``/spawrious224/`` onwards is stripped.
        remove_tar_after_extracting: Currently unused; the archive is always
            kept on disk after extraction.

    Raises:
        KeyError: If no download URL is known for ``dataset_name``.
    """
    data_dir = data_dir.split("/spawrious224/")[
        0
    ]  # in case people pass in the wrong root_dir
    os.makedirs(data_dir, exist_ok=True)
    dataset_name = dataset_name.lower()
    if dataset_name.split("_")[0] == "m2m":
        dataset_name = "entire_dataset"
    url_dict = {
        "entire_dataset": "https://www.dropbox.com/s/hofkueo8qvaqlp3/spawrious224__entire_dataset.tar.gz?dl=1",
        "o2o_easy": "https://www.dropbox.com/s/kwhiv60ihxe3owy/spawrious224__o2o_easy.tar.gz?dl=1",
        "o2o_medium": "https://www.dropbox.com/s/x03gkhdwar5kht4/spawrious224__o2o_medium.tar.gz?dl=1",
        "o2o_hard": "https://www.dropbox.com/s/p1ry121m2gjj158/spawrious224__o2o_hard.tar.gz?dl=1",
        # "m2m": "https://www.dropbox.com/s/5usem63nfub266y/spawrious__m2m.tar.gz?dl=1",
    }
    tar_file_name = f"spawrious224__{dataset_name}.tar.gz"
    tar_file_dst = os.path.join(data_dir, tar_file_name)
    url = url_dict[dataset_name]

    # Nothing to do when the extracted image folders are already in place.
    if _check_images_availability(data_dir, dataset_name):
        print("Dataset already downloaded and extracted.")
        return

    # The archive is present but has not been extracted yet.
    if os.path.exists(tar_file_dst):
        print("Dataset already downloaded. Extracting...")
        _extract_dataset_from_tar(tar_file_name, data_dir)
        return

    print("Dataset not found. Downloading...")
    # Download to a temporary name and rename on success, so an interrupted
    # download is never mistaken for a complete archive on the next run.
    partial_dst = tar_file_dst + ".part"
    block_size = 1024 * 1024
    with urllib.request.urlopen(url) as response:
        total_size = int(response.headers.get("Content-Length", 0))
        # Track progress of download
        progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True)
        try:
            with open(partial_dst, "wb") as f:
                while True:
                    buffer = response.read(block_size)
                    if not buffer:
                        break
                    f.write(buffer)
                    progress_bar.update(len(buffer))
        finally:
            progress_bar.close()
    os.replace(partial_dst, tar_file_dst)
    print("Dataset downloaded. Extracting...")
    _extract_dataset_from_tar(tar_file_name, data_dir)


class CustomImageFolder(Dataset):
    """
    Dataset over the images of a single folder, all sharing one class label and one location label.

    Only files ending in ``.png``, ``.jpg`` or ``.jpeg`` are used. File names are
    sorted before ``limit`` is applied, so the selected subset is the same on
    every run (``os.listdir`` returns entries in arbitrary order).

    Args:
        folder_path: Directory containing the image files.
        class_index: Integer class label returned for every image.
        location_index: Integer location label returned for every image.
        limit: Maximum number of images to keep; ``None`` keeps all of them and
            ``0`` yields an empty dataset.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(
        self, folder_path, class_index, location_index, limit=None, transform=None
    ):
        self.folder_path = folder_path
        self.class_index = class_index
        self.location_index = location_index
        self.image_paths = [
            os.path.join(folder_path, img)
            for img in sorted(os.listdir(folder_path))
            if img.endswith((".png", ".jpg", ".jpeg"))
        ]
        # Compare against None so an explicit limit of 0 is honoured.
        if limit is not None:
            self.image_paths = self.image_paths[:limit]
        self.transform = transform

    def __len__(self):
        """Return the number of images kept from the folder."""
        return len(self.image_paths)

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """Return ``(image, class_label, location_label)`` for ``index``; labels are long tensors."""
        img_path = self.image_paths[index]
        img = Image.open(img_path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        class_label = torch.tensor(self.class_index, dtype=torch.long)
        location_label = torch.tensor(self.location_index, dtype=torch.long)
        return img, class_label, location_label


class MultipleDomainDataset:
    """Base container exposing a list of per-environment datasets.

    Subclasses are expected to assign ``self.datasets``; indexing and ``len()``
    are forwarded to that list.
    """

    # Defaults below may be overridden by subclasses.
    N_STEPS = 5001
    CHECKPOINT_FREQ = 100
    N_WORKERS = 8
    # Subclasses should override these two.
    ENVIRONMENTS = None
    INPUT_SHAPE = None

    def __len__(self):
        """Number of environment datasets."""
        environments = self.datasets
        return len(environments)

    def __getitem__(self, index):
        """Return the environment dataset stored at ``index``."""
        environments = self.datasets
        return environments[index]


def build_combination(benchmark_type, group, test, filler=None):
    """
    Build the train/test location assignment for each dog class.

    Args:
        benchmark_type: Benchmark name. Names containing ``"m2m"`` or
            ``"entire_dataset"`` get their own layouts; anything else uses the
            one-to-one (o2o) layout.
        group: Four training locations, indexed 0-3.
        test: Four test locations, indexed 0-3.
        filler: Accepted for compatibility; the filler entries are disabled, so
            it is currently unused.

    Returns:
        dict with keys ``"train_combinations"`` and ``"test_combinations"``.
        Each maps a one-element class tuple to a list with one entry per
        environment: ``(location, image_limit)`` tuples for training and plain
        location names for testing.
    """
    total = 3168
    combinations = {}
    if "m2m" in benchmark_type:
        counts = [total, total]
        combinations["train_combinations"] = {
            ("bulldog",): [(group[0], counts[0]), (group[1], counts[1])],
            ("dachshund",): [(group[1], counts[0]), (group[0], counts[1])],
            ("labrador",): [(group[2], counts[0]), (group[3], counts[1])],
            ("corgi",): [(group[3], counts[0]), (group[2], counts[1])],
        }
        combinations["test_combinations"] = {
            ("bulldog",): [test[0], test[1]],
            ("dachshund",): [test[1], test[0]],
            ("labrador",): [test[2], test[3]],
            ("corgi",): [test[3], test[2]],
        }
    # This must be `elif`: as a separate `if`, its `else` branch also ran for
    # m2m benchmarks and replaced the m2m layout with the o2o one.
    elif "entire_dataset" in benchmark_type:
        per_location = int(0.5 * total)
        # Every class is drawn from all four training locations.
        combinations["train_combinations"] = {
            (breed,): [(group[i], per_location) for i in range(4)]
            for breed in ("bulldog", "dachshund", "labrador", "corgi")
        }
        combinations["test_combinations"] = {
            ("bulldog",): [test[0], test[0]],
            ("dachshund",): [test[1], test[1]],
            ("labrador",): [test[2], test[2]],
            ("corgi",): [test[3], test[3]],
        }
    else:
        counts = [int(0.97 * total), int(0.87 * total)]
        combinations["train_combinations"] = {
            ("bulldog",): [(group[0], counts[0]), (group[0], counts[1])],
            ("dachshund",): [(group[1], counts[0]), (group[1], counts[1])],
            ("labrador",): [(group[2], counts[0]), (group[2], counts[1])],
            ("corgi",): [(group[3], counts[0]), (group[3], counts[1])],
            # Filler entries are intentionally disabled:
            #("bulldog", "dachshund", "labrador", "corgi"): [
            #    (filler, total - counts[0]),
            #    (filler, total - counts[1]),
            #],
        }
        combinations["test_combinations"] = {
            ("bulldog",): [test[0], test[0]],
            ("dachshund",): [test[1], test[1]],
            ("labrador",): [test[2], test[2]],
            ("corgi",): [test[3], test[3]],
        }
    return combinations


def _get_combinations(benchmark_type: str) -> Tuple[dict, dict]:
    """
    Look up the location layout for ``benchmark_type`` and expand it with ``build_combination``.

    Each layout is ``(train_locations, test_locations, filler_location)``.

    Raises:
        ValueError: If ``benchmark_type`` is not one of the known benchmark names.
    """
    layouts = {
        "o2o_easy": (
            ["desert", "jungle", "dirt", "snow"],
            ["dirt", "snow", "desert", "jungle"],  # Original
            None,
        ),
        "entire_dataset": (
            ["desert", "jungle", "dirt", "snow"],
            ["dirt", "snow", "desert", "jungle"],  # Original
            None,
        ),
        "o2o_medium": (
            ["mountain", "beach", "dirt", "jungle"],
            ["jungle", "dirt", "beach", "snow"],
            "desert",
        ),
        "o2o_hard": (
            ["jungle", "mountain", "snow", "desert"],
            ["mountain", "snow", "desert", "jungle"],
            "beach",
        ),
        "m2m_hard": (
            ["dirt", "jungle", "snow", "beach"],
            ["snow", "beach", "dirt", "jungle"],
            None,
        ),
        "m2m_easy": (
            ["desert", "mountain", "dirt", "jungle"],
            ["dirt", "jungle", "mountain", "desert"],
            None,
        ),
        "m2m_medium": (
            ["beach", "snow", "mountain", "desert"],
            ["desert", "mountain", "beach", "snow"],
            None,
        ),
    }
    layout = layouts.get(benchmark_type)
    if layout is None:
        raise ValueError("Invalid benchmark type")
    train_locations, test_locations, filler_location = layout
    return build_combination(
        benchmark_type, train_locations, test_locations, filler_location
    )


class SpawriousBenchmark(MultipleDomainDataset):
    """
    Spawrious benchmark as a list of environment datasets.

    ``self.datasets[0]`` is the concatenated test data; the remaining entries
    are the training environments. Image transforms are taken from the timm
    model named by the module-level ``MODEL_NAME``.
    """

    ENVIRONMENTS = ["Test", "SC_group_1", "SC_group_2"]
    input_shape = (3, 224, 224)
    num_classes = 4
    class_list = ["bulldog", "corgi", "dachshund", "labrador"]
    locations_list = ["desert", "jungle", "dirt", "mountain", "snow", "beach"]

    def __init__(self, benchmark, root_dir, augment=True):
        """
        Args:
            benchmark: Benchmark name (case-insensitive), e.g. ``"o2o_easy"``.
            root_dir: Directory that contains the ``spawrious224`` folder.
            augment: Use timm's training transform for the training data when
                true; otherwise reuse the evaluation transform.
        """
        benchmark_key = benchmark.lower()
        combinations = _get_combinations(benchmark_key)
        # One-to-one benchmarks read environment i from numbered folder i.
        self.type1 = benchmark_key.startswith("o2o")
        train_envs, test_envs = self._prepare_data_lists(
            combinations["train_combinations"],
            combinations["test_combinations"],
            root_dir,
            augment,
        )
        self.datasets = [ConcatDataset(test_envs), *train_envs]

    def get_train_dataset(self):
        """Return all training environments concatenated into one dataset."""
        return ConcatDataset(self.datasets[1:])

    def get_test_dataset(self):
        """Return the concatenated test dataset."""
        return self.datasets[0]

    def _prepare_data_lists(
        self, train_combinations, test_combinations, root_dir, augment
    ):
        """Resolve the model's transforms and build the train and test dataset lists."""
        # The backbone is created only to read its data configuration.
        backbone = timm.create_model(
            MODEL_NAME,
            pretrained=True,
            num_classes=0,
        ).eval()
        self.data_config = timm.data.resolve_model_data_config(backbone)
        eval_transform = timm.data.create_transform(
            **self.data_config, is_training=False
        )
        train_transform = (
            timm.data.create_transform(**self.data_config, is_training=True)
            if augment
            else eval_transform
        )

        print("Creating Training Dataset:")
        train_data_list = self._create_data_list(
            train_combinations, root_dir, train_transform
        )
        print("Creating Testing Dataset:")
        test_data_list = self._create_data_list(
            test_combinations, root_dir, eval_transform
        )
        return train_data_list, test_data_list

    def _create_data_list(self, combinations, root_dir, transforms):
        """
        Build one dataset per environment from ``combinations``.

        For a dict, each key is a tuple of class names and each value lists one
        location (optionally with an image limit) per environment; the result
        concatenates, per environment, the datasets of every class group. Any
        other iterable is treated as location names loaded with ``ImageFolder``.
        """
        if not isinstance(combinations, dict):
            return [
                ImageFolder(
                    root=os.path.join(root_dir, f"{0}/{location}/"),
                    transform=transforms,
                )
                for location in combinations
            ]

        # per_class_group[g][e] holds the dataset of class group g in environment e.
        per_class_group = []
        for classes, comb_list in combinations.items():
            env_datasets = []
            per_class_group.append(env_datasets)
            for ind, entry in enumerate(comb_list):
                location, limit = entry if isinstance(entry, tuple) else (entry, None)
                folder_index = ind if self.type1 else 0
                members = []
                for cls in classes:
                    path = os.path.join(
                        root_dir,
                        "spawrious224",
                        f"{folder_index}/{location}/{cls}",
                    )
                    print(f"    Combination: {location}/{cls}")
                    print(f"    Limit: {limit}")
                    members.append(
                        CustomImageFolder(
                            folder_path=path,
                            class_index=self.class_list.index(cls),
                            location_index=self.locations_list.index(location),
                            limit=limit,
                            transform=transforms,
                        )
                    )
                env_datasets.append(ConcatDataset(members))

        # The number of environments is taken from the first class group.
        n_environments = len(per_class_group[0])
        return [
            ConcatDataset([group_envs[env] for group_envs in per_class_group])
            for env in range(n_environments)
        ]


def _check_images_availability(root_dir: str, dataset_type: str) -> bool:
    # Get the combinations for the given dataset type
    root_dir = root_dir.split("/spawrious224/")[
        0
    ]  # in case people pass in the wrong root_dir
    if dataset_type == "entire_dataset":
        for dataset in ["0", "1", "domain_adaptation_ds"]:
            for location in ["snow", "jungle", "desert", "dirt", "mountain", "beach"]:
                for cls in ["bulldog", "corgi", "dachshund", "labrador"]:
                    path = os.path.join(
                        root_dir, "spawrious224", f"{dataset}/{location}/{cls}"
                    )
                    if not os.path.exists(path) or not any(
                        img.endswith((".png", ".jpg", ".jpeg"))
                        for img in os.listdir(path)
                    ):
                        return False
        return True
    combinations = _get_combinations(dataset_type.lower())

    # Extract the train and test combinations
    train_combinations = combinations["train_combinations"]
    test_combinations = combinations["test_combinations"]

    # Check if the relevant images for each combination are present in the root directory
    for combination in [train_combinations, test_combinations]:
        for classes, comb_list in combination.items():
            for ind, location_limit in enumerate(comb_list):
                if isinstance(location_limit, tuple):
                    location, limit = location_limit
                else:
                    location, limit = location_limit, None

                for cls in classes:
                    path = os.path.join(
                        root_dir,
                        "spawrious224",
                        f"{0 if not dataset_type.lower().startswith('o2o') else ind}/{location}/{cls}",
                    )

                    # If the path does not exist or there are no relevant images, return False
                    if not os.path.exists(path) or not any(
                        img.endswith((".png", ".jpg", ".jpeg"))
                        for img in os.listdir(path)
                    ):
                        return False

    # If all the required images are present, return True
    return True


def get_spawrious_dataset(root_dir: str, dataset_name: str = "entire_dataset"):
    """
    Return the requested benchmark as a ``SpawriousBenchmark``, downloading the data first when it is not available.

    The default ``"entire_dataset"`` archive is the one used for m2m and
    domain adaptation experiments.

    Args:
        root_dir: Directory containing (or receiving) ``spawrious224``; anything
            from ``/spawrious224/`` onwards is stripped.
        dataset_name: Benchmark name, matched case-insensitively.

    Raises:
        AssertionError: If ``dataset_name`` is not a recognised name.
    """
    recognised_names = (
        "o2o_easy",
        "o2o_medium",
        "o2o_hard",
        "m2m_easy",
        "m2m_medium",
        "m2m_hard",
        "m2m",
        "entire_dataset",
    )
    # in case people pass in the wrong root_dir
    root_dir = root_dir.partition("/spawrious224/")[0]
    requested = dataset_name.lower()
    assert requested in recognised_names, f"Invalid dataset type: {dataset_name}"
    _download_dataset_if_not_available(dataset_name, root_dir)
    # TODO: get m2m to use entire dataset, not half of it
    return SpawriousBenchmark(dataset_name, root_dir, augment=True)

# 1. Mount Drive, Define Directories for Saving Results

In [None]:
# Colab only: mount Google Drive at /content/drive so project files stored there can be read and written.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# 2. Gather Dataset for Training and Testing
Uses the Spawrious dataset; see the example script at https://github.com/aengusl/spawrious/blob/main/example.py

In [None]:
project_dir = "/content/drive/MyDrive/CS2822_Final_Project"

drive_data_dir = os.path.join(project_dir, "Datasets/spawrious224__o2o_easy.tar.gz")
dataset_name = drive_data_dir.split('/')[-1]

# Move Tar file from drive to local dir
!mkdir /content/data
!cp $drive_data_dir /content/data/

In [None]:
# Select the timm model whose name is read when datasets and models are built.
set_model_name("resnet18.a1_in1k")

# 3. Helpers for Training

In [None]:
# Training loop
def train(
    model: Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: Optimizer,
    criterion: Module,
    num_epochs: int,
    device: torch.device,
) -> None:
    """
    Train ``model`` for ``num_epochs`` epochs, validating and logging to wandb after each one.

    Batches are ``(inputs, class_labels, location_labels)``; the location
    labels are ignored. Each epoch logs the mean batch loss as ``train_loss``
    and the validation accuracy as ``val_acc``.
    """
    epochs = tqdm(range(num_epochs), desc="Training. Epochs", leave=False)
    for epoch in epochs:
        running_loss = 0.0
        for batch in tqdm(train_loader):
            inputs, labels, _ = batch  # third item is the location label
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch + 1}: Training Loss: {mean_loss:.3f}")
        print("Evaluating on validation set...")
        val_acc = evaluate(model, val_loader, device)
        wandb.log({"train_loss": mean_loss, "val_acc": val_acc}, step=epoch)

# Eval loop
def evaluate(model: Module, loader: DataLoader, device: torch.device) -> float:
    """
    Compute top-1 accuracy (in percent) of ``model`` over ``loader``.

    The model is put in eval mode for the pass so layers such as BatchNorm and
    Dropout behave deterministically and running statistics are not updated by
    evaluation data. Afterwards each submodule's previous train/eval flag is
    restored, so a model with a deliberately frozen (eval-mode) backbone keeps
    that setup.

    Args:
        model: Network returning class scores of shape ``(batch, classes)``.
        loader: Iterable of ``(inputs, class_labels, location_labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy in percent, or NaN when ``loader`` yields no samples.
    """
    # Snapshot per-module flags; model.train() would flip every submodule to
    # training mode, including ones that were intentionally left in eval mode.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels, _ in tqdm(
                loader, desc="Evaluating", leave=False
            ):  # third item is the location label
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
    # An empty loader previously raised ZeroDivisionError.
    acc = 100 * correct / total if total else float("nan")
    print(f"Acc: {acc:.3f}%")
    return acc




# 4. Initialize Model, Run Training

Helpers for loading and initializing the model.

In [None]:
# Class to modify model
class ClassifierOnTop(nn.Module):
    """
    Linear classifier on top of a frozen, pretrained timm backbone.

    The backbone is created from the module-level ``MODEL_NAME`` with its
    classification head removed (``num_classes=0``) and is run without
    gradients, so only ``self.linear`` is trained.

    Args:
        num_classes: Number of output classes of the linear head.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = timm.create_model(
            MODEL_NAME,
            pretrained=True,
            num_classes=0,
        ).eval()
        # Size the head from the backbone instead of a per-model table: the old
        # table defaulted to 1152 for any unlisted model (e.g. a ResNet-18).
        self.linear = nn.Linear(self.backbone.num_features, num_classes)

    def train(self, mode: bool = True):
        """Set the training mode of the head while keeping the frozen backbone in eval mode."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return class scores for a batch of images ``x``."""
        # The backbone acts as a fixed feature extractor.
        with torch.no_grad():
            x = self.backbone(x)
        return self.linear(x)


def get_model(args: argparse.Namespace) -> Module:
    """
    Build the 4-class model selected by ``args.model``.

    ``"siglip"`` returns a ``ClassifierOnTop``; any other value returns a
    pretrained torchvision ResNet-18 whose final layer is replaced by a
    4-output linear layer.
    """
    if args.model != "siglip":
        resnet = models.resnet18(pretrained=True)
        resnet.fc = torch.nn.Linear(512, 4)
        return resnet
    return ClassifierOnTop(num_classes=4)


In [None]:
def main(dataset) -> None:
    """
    Run one experiment on ``dataset``: build the data loaders, train, save the weights and log test accuracy.

    Settings come from ``Args(model="Resnet", dataset=dataset)``; the run name
    is derived from ``dataset`` and the module-level ``MODEL_NAME``.
    """
    args = Args(model="Resnet", dataset=dataset)
    model_tag = MODEL_NAME.split('_')[0]
    experiment_name = (
        f"{dataset}_{model_tag}-e={args.num_epochs}-lr={args.lr}_limit=20"
    )

    wandb.init(project="spawrious", name=experiment_name, config=args)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Data: hold out a fraction of the training set for validation.
    spawrious = get_spawrious_dataset(dataset_name=args.dataset, root_dir=args.data_dir)
    full_train_set = spawrious.get_train_dataset()
    test_set = spawrious.get_test_dataset()
    val_size = int(len(full_train_set) * args.val_split)
    train_set, val_set = torch.utils.data.random_split(
        full_train_set, [len(full_train_set) - val_size, val_size]
    )

    loader_options = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
    }
    train_loader = torch.utils.data.DataLoader(
        train_set, shuffle=True, **loader_options
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, shuffle=True, **loader_options
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, shuffle=False, **loader_options
    )

    # Model and optimisation.
    model = get_model(args)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    train(
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        args.num_epochs,
        device,
    )

    print("Finished training, now evaluating on test set.")
    torch.save(model.state_dict(), f"{experiment_name}.pt")
    test_acc = evaluate(model, test_loader, device)
    wandb.log({"final_test_acc": test_acc}, step=args.num_epochs)



In [None]:
# Artificially create args
class Args:
    """Plain attribute holder standing in for parsed command-line arguments."""

    def __init__(
        self,
        model="siglip",
        dataset="o2o_easy",
        data_dir='/content/data/',
        num_epochs=12,
        val_split=0.1,
        batch_size=128,
        num_workers=2,
        lr=0.01,
        momentum=0.9,
    ):
        # Assign in a fixed order so vars(self) lists the settings consistently.
        settings = (
            ("model", model),
            ("dataset", dataset),
            ("data_dir", data_dir),
            ("num_epochs", num_epochs),
            ("val_split", val_split),
            ("batch_size", batch_size),
            ("num_workers", num_workers),
            ("lr", lr),
            ("momentum", momentum),
        )
        for attribute, value in settings:
            setattr(self, attribute, value)


args = Args(model="Resnet", dataset="entire_dataset", num_epochs=10)

In [None]:
# Copy the entire-dataset archive from the local runtime into the project's Datasets folder.
datasets_dir = os.path.join(project_dir, "Datasets")
!cp /content/data/spawrious224__entire_dataset.tar.gz $datasets_dir

In [None]:
# Run Training
if __name__ == "__main__":
    # Benchmarks to run; uncomment entries to include them.
    dataset_choices = [
        #"o2o_easy",
        "entire_dataset",
        #"o2o_medium",
        #"o2o_hard",
        #"m2m_easy",
        #"m2m_medium",
        #"m2m_hard",
    ]
    # timm model names to run; uncomment entries to include them.
    model_name_choices = [
        'resnet18.a1_in1k',
        #'vit_so400m_patch14_siglip_384',
        #'swin_base_patch4_window7_224.ms_in22k_ft_in1k',
        #'deit3_base_patch16_224.fb_in22k_ft_in1k',
        #'beit_base_patch16_224.in22k_ft_in22k_in1k',
        #'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k',
        #'levit_128s.fb_dist_in1k',
    ]
    MODEL_NAME = 'resnet18.a1_in1k'

    # One training run per (dataset, model) pair, datasets in the outer position.
    planned_runs = [
        (chosen_dataset, chosen_model)
        for chosen_dataset in dataset_choices
        for chosen_model in model_name_choices
    ]
    for dataset, model_name in planned_runs:
        MODEL_NAME = model_name
        set_model_name(MODEL_NAME)

        main(dataset)



VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
final_test_acc,▁
train_loss,█▂▁▁▁▁▁▁▁▁▁▁
val_acc,▁█▅▃▄▆▃▅▅▄▁▆

0,1
final_test_acc,21.03851
train_loss,1.56012
val_acc,30.61224


Dataset already downloaded and extracted.
Creating Training Dataset:
    Combination: desert/bulldog
    Limit: 1584
    Combination: jungle/bulldog
    Limit: 1584
    Combination: dirt/bulldog
    Limit: 1584
    Combination: snow/bulldog
    Limit: 1584
    Combination: desert/dachshund
    Limit: 1584
    Combination: jungle/dachshund
    Limit: 1584
    Combination: dirt/dachshund
    Limit: 1584
    Combination: snow/dachshund
    Limit: 1584
    Combination: desert/labrador
    Limit: 1584
    Combination: jungle/labrador
    Limit: 1584
    Combination: dirt/labrador
    Limit: 1584
    Combination: snow/labrador
    Limit: 1584
    Combination: desert/corgi
    Limit: 1584
    Combination: jungle/corgi
    Limit: 1584
    Combination: dirt/corgi
    Limit: 1584
    Combination: snow/corgi
    Limit: 1584
Creating Testing Dataset:
    Combination: dirt/bulldog
    Limit: None
    Combination: dirt/bulldog
    Limit: None
    Combination: snow/dachshund
    Limit: None
    Combi

Training. Epochs:   0%|          | 0/12 [00:00<?, ?it/s]
  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:48,  1.29s/it][A
  1%|          | 2/179 [00:01<01:44,  1.70it/s][A
  2%|▏         | 3/179 [00:02<02:13,  1.31it/s][A
  3%|▎         | 5/179 [00:03<01:51,  1.56it/s][A
  4%|▍         | 7/179 [00:04<01:44,  1.65it/s][A
  5%|▌         | 9/179 [00:05<01:37,  1.74it/s][A
  6%|▌         | 11/179 [00:06<01:34,  1.77it/s][A
  7%|▋         | 13/179 [00:07<01:32,  1.80it/s][A
  8%|▊         | 15/179 [00:08<01:28,  1.84it/s][A
  9%|▉         | 17/179 [00:09<01:26,  1.86it/s][A
 11%|█         | 19/179 [00:10<01:24,  1.89it/s][A
 12%|█▏        | 21/179 [00:11<01:23,  1.90it/s][A
 13%|█▎        | 23/179 [00:12<01:21,  1.92it/s][A
 14%|█▍        | 25/179 [00:14<01:22,  1.87it/s][A
 15%|█▌        | 27/179 [00:15<01:20,  1.90it/s][A
 16%|█▌        | 29/179 [00:16<01:19,  1.88it/s][A
 17%|█▋        | 31/179 [00:17<01:18,  1.88it/s][A
 18%|█▊        | 33/17

Epoch 1: Training Loss: 1.505
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:01<00:20,  1.07s/it][A
Evaluating:  15%|█▌        | 3/20 [00:02<00:10,  1.58it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:08,  1.80it/s][A
Evaluating:  30%|███       | 6/20 [00:03<00:06,  2.30it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:07,  1.80it/s][A
Evaluating:  40%|████      | 8/20 [00:04<00:05,  2.32it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:06,  1.79it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:04,  1.93it/s][A
Evaluating:  60%|██████    | 12/20 [00:06<00:03,  2.39it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  1.89it/s][A
Evaluating:  70%|███████   | 14/20 [00:06<00:02,  2.37it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:07<00:02,  1.87it/s][A
Evaluating:  80%|████████  | 16/20 [00:07<00:01,  2.37it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:08<00:01,  1.85it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:09<00:00,  1.96it/s][A
Trainin

Acc: 25.257%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:02,  1.02s/it][A
  1%|          | 2/179 [00:01<01:25,  2.08it/s][A
  2%|▏         | 3/179 [00:01<01:46,  1.66it/s][A
  2%|▏         | 4/179 [00:01<01:10,  2.47it/s][A
  3%|▎         | 5/179 [00:02<01:30,  1.92it/s][A
  3%|▎         | 6/179 [00:02<01:05,  2.64it/s][A
  4%|▍         | 7/179 [00:03<01:25,  2.00it/s][A
  5%|▌         | 9/179 [00:04<01:18,  2.18it/s][A
  6%|▌         | 11/179 [00:05<01:14,  2.26it/s][A
  7%|▋         | 12/179 [00:05<01:01,  2.73it/s][A
  7%|▋         | 13/179 [00:06<01:14,  2.21it/s][A
  8%|▊         | 14/179 [00:06<01:00,  2.75it/s][A
  8%|▊         | 15/179 [00:06<01:16,  2.15it/s][A
  9%|▉         | 17/179 [00:07<01:12,  2.23it/s][A
 10%|█         | 18/179 [00:07<00:59,  2.71it/s][A
 11%|█         | 19/179 [00:08<01:12,  2.20it/s][A
 12%|█▏        | 21/179 [00:09<01:09,  2.28it/s][A
 13%|█▎        | 23/179 [00:10<01:07,  2.33it/s][A
 14%|█▍        | 25/179 [00:

Epoch 2: Training Loss: 1.416
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:00<00:18,  1.01it/s][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.77it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  2.05it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:05,  2.19it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:04,  2.26it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:03,  2.30it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:05<00:02,  2.34it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.36it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:07<00:01,  2.37it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.37it/s][A
Training. Epochs:  17%|█▋        | 2/12 [03:02<14:56, 89.68s/it]

Acc: 23.994%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:01,  1.02s/it][A
  1%|          | 2/179 [00:01<01:26,  2.04it/s][A
  2%|▏         | 3/179 [00:01<01:45,  1.67it/s][A
  3%|▎         | 5/179 [00:02<01:25,  2.03it/s][A
  4%|▍         | 7/179 [00:03<01:18,  2.19it/s][A
  5%|▌         | 9/179 [00:04<01:14,  2.28it/s][A
  6%|▌         | 10/179 [00:04<01:01,  2.73it/s][A
  6%|▌         | 11/179 [00:05<01:15,  2.24it/s][A
  7%|▋         | 12/179 [00:05<01:00,  2.78it/s][A
  7%|▋         | 13/179 [00:05<01:15,  2.19it/s][A
  8%|▊         | 14/179 [00:06<00:59,  2.79it/s][A
  8%|▊         | 15/179 [00:06<01:15,  2.17it/s][A
  9%|▉         | 16/179 [00:06<00:58,  2.80it/s][A
  9%|▉         | 17/179 [00:07<01:15,  2.16it/s][A
 11%|█         | 19/179 [00:08<01:09,  2.29it/s][A
 12%|█▏        | 21/179 [00:09<01:07,  2.35it/s][A
 13%|█▎        | 23/179 [00:10<01:05,  2.39it/s][A
 14%|█▍        | 25/179 [00:10<01:03,  2.41it/s][A
 15%|█▌        | 27/179 [0

Epoch 3: Training Loss: 1.389
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:01<00:20,  1.10s/it][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:10,  1.64it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  1.92it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:06,  2.10it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:05,  2.17it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:04,  2.23it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  2.21it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:07<00:02,  2.22it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:08<00:01,  2.23it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:09<00:00,  2.23it/s][A
Training. Epochs:  25%|██▌       | 3/12 [04:27<13:10, 87.78s/it]

Acc: 29.163%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:04,  1.03s/it][A
  1%|          | 2/179 [00:01<01:25,  2.06it/s][A
  2%|▏         | 3/179 [00:01<01:47,  1.64it/s][A
  2%|▏         | 4/179 [00:02<01:12,  2.42it/s][A
  3%|▎         | 5/179 [00:02<01:33,  1.87it/s][A
  4%|▍         | 7/179 [00:03<01:21,  2.11it/s][A
  5%|▌         | 9/179 [00:04<01:17,  2.19it/s][A
  6%|▌         | 11/179 [00:05<01:14,  2.25it/s][A
  7%|▋         | 13/179 [00:06<01:13,  2.27it/s][A
  8%|▊         | 15/179 [00:07<01:11,  2.30it/s][A
  9%|▉         | 17/179 [00:07<01:09,  2.34it/s][A
 11%|█         | 19/179 [00:08<01:07,  2.36it/s][A
 12%|█▏        | 21/179 [00:09<01:06,  2.38it/s][A
 13%|█▎        | 23/179 [00:10<01:06,  2.34it/s][A
 14%|█▍        | 25/179 [00:11<01:06,  2.33it/s][A
 15%|█▌        | 27/179 [00:12<01:06,  2.30it/s][A
 16%|█▌        | 29/179 [00:13<01:05,  2.29it/s][A
 17%|█▋        | 31/179 [00:13<01:04,  2.30it/s][A
 18%|█▊        | 32/179 [00

Epoch 4: Training Loss: 1.383
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:01<00:19,  1.02s/it][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.71it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  1.98it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:06,  2.11it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:04,  2.20it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:04,  2.24it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  2.25it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:07<00:02,  2.24it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:08<00:01,  2.23it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.25it/s][A
Training. Epochs:  33%|███▎      | 4/12 [05:53<11:36, 87.02s/it]

Acc: 32.794%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:03,  1.03s/it][A
  2%|▏         | 3/179 [00:01<01:41,  1.73it/s][A
  2%|▏         | 4/179 [00:01<01:12,  2.41it/s][A
  3%|▎         | 5/179 [00:02<01:32,  1.87it/s][A
  4%|▍         | 7/179 [00:03<01:20,  2.13it/s][A
  4%|▍         | 8/179 [00:03<01:04,  2.66it/s][A
  5%|▌         | 9/179 [00:04<01:19,  2.14it/s][A
  6%|▌         | 10/179 [00:04<01:03,  2.66it/s][A
  6%|▌         | 11/179 [00:05<01:18,  2.15it/s][A
  7%|▋         | 12/179 [00:05<01:03,  2.62it/s][A
  7%|▋         | 13/179 [00:06<01:17,  2.13it/s][A
  8%|▊         | 14/179 [00:06<01:02,  2.66it/s][A
  8%|▊         | 15/179 [00:06<01:17,  2.11it/s][A
  9%|▉         | 16/179 [00:07<01:01,  2.67it/s][A
  9%|▉         | 17/179 [00:07<01:18,  2.07it/s][A
 10%|█         | 18/179 [00:07<00:59,  2.71it/s][A
 11%|█         | 19/179 [00:08<01:16,  2.10it/s][A
 11%|█         | 20/179 [00:08<00:58,  2.72it/s][A
 12%|█▏        | 21/179 [00

Epoch 5: Training Loss: 1.360
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:01<00:19,  1.03s/it][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.72it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  2.00it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:06,  2.15it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:05,  2.19it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:03,  2.26it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  2.31it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.32it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:07<00:01,  2.33it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.33it/s][A
Training. Epochs:  42%|████▏     | 5/12 [07:18<10:02, 86.12s/it]

Acc: 33.899%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:04,  1.04s/it][A
  1%|          | 2/179 [00:01<01:26,  2.05it/s][A
  2%|▏         | 3/179 [00:01<01:44,  1.68it/s][A
  2%|▏         | 4/179 [00:01<01:10,  2.50it/s][A
  3%|▎         | 5/179 [00:02<01:30,  1.92it/s][A
  3%|▎         | 6/179 [00:02<01:06,  2.62it/s][A
  4%|▍         | 7/179 [00:03<01:25,  2.02it/s][A
  5%|▌         | 9/179 [00:04<01:19,  2.15it/s][A
  6%|▌         | 11/179 [00:05<01:17,  2.18it/s][A
  7%|▋         | 13/179 [00:06<01:13,  2.24it/s][A
  8%|▊         | 14/179 [00:06<01:02,  2.65it/s][A
  8%|▊         | 15/179 [00:06<01:14,  2.21it/s][A
  9%|▉         | 16/179 [00:07<01:01,  2.65it/s][A
  9%|▉         | 17/179 [00:07<01:13,  2.19it/s][A
 10%|█         | 18/179 [00:07<01:00,  2.64it/s][A
 11%|█         | 19/179 [00:08<01:13,  2.17it/s][A
 11%|█         | 20/179 [00:08<00:59,  2.65it/s][A
 12%|█▏        | 21/179 [00:09<01:14,  2.12it/s][A
 12%|█▏        | 22/179 [00:

Epoch 6: Training Loss: 1.330
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:01<00:18,  1.00it/s][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.73it/s][A
Evaluating:  20%|██        | 4/20 [00:01<00:06,  2.40it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  1.92it/s][A
Evaluating:  30%|███       | 6/20 [00:02<00:05,  2.47it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:06,  2.04it/s][A
Evaluating:  40%|████      | 8/20 [00:03<00:04,  2.49it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:05,  2.14it/s][A
Evaluating:  50%|█████     | 10/20 [00:04<00:04,  2.46it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:04,  2.18it/s][A
Evaluating:  60%|██████    | 12/20 [00:05<00:03,  2.42it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  2.23it/s][A
Evaluating:  70%|███████   | 14/20 [00:06<00:02,  2.39it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.28it/s][A
Evaluating:  80%|████████  | 16/20 [00:07<00:01,  2.34it/s][A
Evaluati

Acc: 35.122%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:14,  1.09s/it][A
  2%|▏         | 3/179 [00:01<01:44,  1.68it/s][A
  3%|▎         | 5/179 [00:02<01:28,  1.96it/s][A
  4%|▍         | 7/179 [00:03<01:21,  2.12it/s][A
  5%|▌         | 9/179 [00:04<01:17,  2.19it/s][A
  6%|▌         | 11/179 [00:05<01:15,  2.22it/s][A
  7%|▋         | 13/179 [00:06<01:14,  2.24it/s][A
  8%|▊         | 15/179 [00:07<01:12,  2.25it/s][A
  9%|▉         | 17/179 [00:08<01:12,  2.24it/s][A
 11%|█         | 19/179 [00:08<01:11,  2.25it/s][A
 12%|█▏        | 21/179 [00:09<01:10,  2.26it/s][A
 13%|█▎        | 23/179 [00:10<01:08,  2.28it/s][A
 14%|█▍        | 25/179 [00:11<01:07,  2.28it/s][A
 15%|█▍        | 26/179 [00:11<00:57,  2.64it/s][A
 15%|█▌        | 27/179 [00:12<01:10,  2.17it/s][A
 16%|█▌        | 28/179 [00:12<00:57,  2.63it/s][A
 16%|█▌        | 29/179 [00:13<01:12,  2.08it/s][A
 17%|█▋        | 31/179 [00:14<01:08,  2.16it/s][A
 18%|█▊        | 33/179 [

Epoch 7: Training Loss: 1.247
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:00<00:18,  1.02it/s][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.77it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  2.04it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:05,  2.20it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:04,  2.28it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:03,  2.34it/s][A
Evaluating:  60%|██████    | 12/20 [00:05<00:02,  2.74it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:05<00:03,  2.26it/s][A
Evaluating:  70%|███████   | 14/20 [00:06<00:02,  2.75it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.23it/s][A
Evaluating:  80%|████████  | 16/20 [00:06<00:01,  2.69it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:07<00:01,  2.21it/s][A
Evaluating:  90%|█████████ | 18/20 [00:07<00:00,  2.68it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.21it/s][A
Training. Epochs:  58%|█████▊    | 7/12 [10:07<07:06, 85.33s/it]

Acc: 47.040%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:02,  1.02s/it][A
  1%|          | 2/179 [00:01<01:25,  2.06it/s][A
  2%|▏         | 3/179 [00:01<01:46,  1.65it/s][A
  2%|▏         | 4/179 [00:01<01:11,  2.46it/s][A
  3%|▎         | 5/179 [00:02<01:32,  1.88it/s][A
  4%|▍         | 7/179 [00:03<01:20,  2.13it/s][A
  5%|▌         | 9/179 [00:04<01:16,  2.23it/s][A
  6%|▌         | 11/179 [00:05<01:12,  2.30it/s][A
  7%|▋         | 13/179 [00:06<01:11,  2.34it/s][A
  8%|▊         | 15/179 [00:06<01:09,  2.36it/s][A
  9%|▉         | 17/179 [00:07<01:08,  2.37it/s][A
 11%|█         | 19/179 [00:08<01:06,  2.39it/s][A
 12%|█▏        | 21/179 [00:09<01:05,  2.40it/s][A
 13%|█▎        | 23/179 [00:10<01:04,  2.41it/s][A
 14%|█▍        | 25/179 [00:11<01:03,  2.41it/s][A
 15%|█▌        | 27/179 [00:11<01:03,  2.40it/s][A
 16%|█▌        | 29/179 [00:12<01:02,  2.41it/s][A
 17%|█▋        | 31/179 [00:13<01:01,  2.41it/s][A
 18%|█▊        | 33/179 [00

Epoch 8: Training Loss: 1.058
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:00<00:18,  1.01it/s][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.76it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  2.04it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:05,  2.20it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:04,  2.28it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:03,  2.32it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:05<00:02,  2.35it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.39it/s][A
Evaluating:  80%|████████  | 16/20 [00:06<00:01,  2.78it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:07<00:01,  2.25it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.33it/s][A
Training. Epochs:  67%|██████▋   | 8/12 [11:30<05:38, 84.54s/it]

Acc: 61.326%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:02,  1.03s/it][A
  1%|          | 2/179 [00:01<01:25,  2.07it/s][A
  2%|▏         | 3/179 [00:01<01:44,  1.68it/s][A
  2%|▏         | 4/179 [00:01<01:10,  2.50it/s][A
  3%|▎         | 5/179 [00:02<01:29,  1.95it/s][A
  3%|▎         | 6/179 [00:02<01:04,  2.67it/s][A
  4%|▍         | 7/179 [00:03<01:24,  2.03it/s][A
  4%|▍         | 8/179 [00:03<01:03,  2.71it/s][A
  5%|▌         | 9/179 [00:04<01:20,  2.11it/s][A
  6%|▌         | 11/179 [00:05<01:14,  2.25it/s][A
  7%|▋         | 13/179 [00:05<01:11,  2.32it/s][A
  8%|▊         | 14/179 [00:06<00:59,  2.79it/s][A
  8%|▊         | 15/179 [00:06<01:12,  2.26it/s][A
  9%|▉         | 17/179 [00:07<01:09,  2.34it/s][A
 10%|█         | 18/179 [00:07<00:57,  2.81it/s][A
 11%|█         | 19/179 [00:08<01:10,  2.26it/s][A
 12%|█▏        | 21/179 [00:09<01:08,  2.31it/s][A
 12%|█▏        | 22/179 [00:09<00:56,  2.78it/s][A
 13%|█▎        | 23/179 [00:1

Epoch 9: Training Loss: 0.887
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:00<00:18,  1.01it/s][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.79it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  2.06it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:05,  2.18it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:04,  2.22it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:04,  2.24it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  2.25it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.26it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:07<00:01,  2.27it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.30it/s][A
Training. Epochs:  75%|███████▌  | 9/12 [12:54<04:13, 84.41s/it]

Acc: 60.142%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:11,  1.08s/it][A
  1%|          | 2/179 [00:01<01:30,  1.96it/s][A
  2%|▏         | 3/179 [00:01<01:49,  1.61it/s][A
  3%|▎         | 5/179 [00:02<01:29,  1.95it/s][A
  4%|▍         | 7/179 [00:03<01:21,  2.10it/s][A
  4%|▍         | 8/179 [00:03<01:05,  2.59it/s][A
  5%|▌         | 9/179 [00:04<01:21,  2.09it/s][A
  6%|▌         | 11/179 [00:05<01:16,  2.18it/s][A
  7%|▋         | 12/179 [00:05<01:02,  2.66it/s][A
  7%|▋         | 13/179 [00:06<01:17,  2.13it/s][A
  8%|▊         | 15/179 [00:07<01:14,  2.21it/s][A
  9%|▉         | 16/179 [00:07<01:00,  2.69it/s][A
  9%|▉         | 17/179 [00:07<01:14,  2.17it/s][A
 11%|█         | 19/179 [00:08<01:11,  2.25it/s][A
 12%|█▏        | 21/179 [00:09<01:08,  2.29it/s][A
 12%|█▏        | 22/179 [00:09<00:57,  2.73it/s][A
 13%|█▎        | 23/179 [00:10<01:11,  2.19it/s][A
 14%|█▍        | 25/179 [00:11<01:08,  2.24it/s][A
 15%|█▌        | 27/179 [00

Epoch 10: Training Loss: 0.780
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:01<00:19,  1.03s/it][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.72it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  2.01it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:06,  2.12it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:04,  2.21it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:03,  2.26it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  2.29it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.33it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:07<00:01,  2.34it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.35it/s][A
Training. Epochs:  83%|████████▎ | 10/12 [14:20<02:49, 84.73s/it]

Acc: 67.719%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:07,  1.05s/it][A
  1%|          | 2/179 [00:01<01:27,  2.03it/s][A
  2%|▏         | 3/179 [00:01<01:47,  1.64it/s][A
  2%|▏         | 4/179 [00:02<01:11,  2.44it/s][A
  3%|▎         | 5/179 [00:02<01:32,  1.89it/s][A
  3%|▎         | 6/179 [00:02<01:06,  2.60it/s][A
  4%|▍         | 7/179 [00:03<01:26,  2.00it/s][A
  5%|▌         | 9/179 [00:04<01:18,  2.16it/s][A
  6%|▌         | 11/179 [00:05<01:15,  2.22it/s][A
  7%|▋         | 12/179 [00:05<01:02,  2.69it/s][A
  7%|▋         | 13/179 [00:06<01:17,  2.14it/s][A
  8%|▊         | 14/179 [00:06<01:01,  2.67it/s][A
  8%|▊         | 15/179 [00:07<01:19,  2.06it/s][A
  9%|▉         | 17/179 [00:07<01:15,  2.14it/s][A
 11%|█         | 19/179 [00:08<01:12,  2.22it/s][A
 12%|█▏        | 21/179 [00:09<01:09,  2.27it/s][A
 12%|█▏        | 22/179 [00:09<00:58,  2.69it/s][A
 13%|█▎        | 23/179 [00:10<01:11,  2.20it/s][A
 14%|█▍        | 25/179 [00:

Epoch 11: Training Loss: 0.678
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:00<00:18,  1.01it/s][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.78it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  2.05it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:05,  2.21it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:04,  2.29it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:03,  2.28it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  2.29it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.34it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:07<00:01,  2.36it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.38it/s][A
Training. Epochs:  92%|█████████▏| 11/12 [15:44<01:24, 84.68s/it]

Acc: 72.928%



  0%|          | 0/179 [00:00<?, ?it/s][A
  1%|          | 1/179 [00:01<03:04,  1.04s/it][A
  1%|          | 2/179 [00:01<01:27,  2.03it/s][A
  2%|▏         | 3/179 [00:01<01:45,  1.66it/s][A
  2%|▏         | 4/179 [00:01<01:10,  2.47it/s][A
  3%|▎         | 5/179 [00:02<01:31,  1.90it/s][A
  4%|▍         | 7/179 [00:03<01:21,  2.11it/s][A
  5%|▌         | 9/179 [00:04<01:16,  2.23it/s][A
  6%|▌         | 11/179 [00:05<01:13,  2.29it/s][A
  7%|▋         | 13/179 [00:06<01:11,  2.32it/s][A
  8%|▊         | 15/179 [00:06<01:10,  2.34it/s][A
  9%|▉         | 17/179 [00:07<01:09,  2.34it/s][A
 10%|█         | 18/179 [00:07<00:59,  2.72it/s][A
 11%|█         | 19/179 [00:08<01:12,  2.20it/s][A
 11%|█         | 20/179 [00:08<00:59,  2.69it/s][A
 12%|█▏        | 21/179 [00:09<01:14,  2.11it/s][A
 13%|█▎        | 23/179 [00:10<01:11,  2.18it/s][A
 14%|█▍        | 25/179 [00:11<01:09,  2.21it/s][A
 15%|█▍        | 26/179 [00:11<00:58,  2.63it/s][A
 15%|█▌        | 27/179 [00

Epoch 12: Training Loss: 0.610
Evaluating on validation set...



Evaluating:   0%|          | 0/20 [00:00<?, ?it/s][A
Evaluating:   5%|▌         | 1/20 [00:01<00:19,  1.03s/it][A
Evaluating:  15%|█▌        | 3/20 [00:01<00:09,  1.72it/s][A
Evaluating:  25%|██▌       | 5/20 [00:02<00:07,  2.00it/s][A
Evaluating:  35%|███▌      | 7/20 [00:03<00:06,  2.15it/s][A
Evaluating:  45%|████▌     | 9/20 [00:04<00:04,  2.23it/s][A
Evaluating:  55%|█████▌    | 11/20 [00:05<00:03,  2.26it/s][A
Evaluating:  65%|██████▌   | 13/20 [00:06<00:03,  2.31it/s][A
Evaluating:  75%|███████▌  | 15/20 [00:06<00:02,  2.31it/s][A
Evaluating:  85%|████████▌ | 17/20 [00:07<00:01,  2.34it/s][A
Evaluating:  95%|█████████▌| 19/20 [00:08<00:00,  2.37it/s][A


Acc: 73.244%
Finished training, now evaluating on test set.


                                                             

Acc: 34.572%




# 5. Examine Training Results

Next:
1. Create a folder of images for the CORGI class, sampling from all regions
2. Run CRAFT on the corgi images
3. Label each extracted concept as either (i) true to the subject or (ii) a spurious correlate
4. Determine how to tell whether spurious correlates are a problem -- likely **if a prediction can be made from spurious correlates alone, then we know that something is wrong with the model**
    - We can do this as follows:
        - Find the concept that is the spurious correlate
        - Compare its importance for this class with its importance for the other classes
        - Find images with high importance and images with low importance of that correlate
        - Measure the difference in classification between those two groups
        - So something like `(class_includes_correlate - class_excludes_correlate)`, computed pairwise over the dataset

In [None]:
# Copy the trained checkpoint from the runtime's local disk into the project's
# "Models" folder.
import shutil

model_dir = os.path.join(project_dir, "Models")
# Create the target folder first: copying to a path that does not exist yet
# would write the checkpoint to a file named "Models" instead of into a folder.
os.makedirs(model_dir, exist_ok=True)

# Deliberately not named `model`, so this path string does not replace a model
# object that other cells bind to that name.
checkpoint_src = '/content/entire_dataset_resnet18.a1-e=12-lr=0.01_limit=20.pt'

# shutil.copy receives the paths as plain arguments (no shell word-splitting),
# so a space anywhere in `project_dir` cannot break the copy.
shutil.copy(checkpoint_src, model_dir);

In [None]:
# Show the checkpoints currently stored in the Models folder. Sorted because
# os.listdir returns entries in arbitrary order, which would otherwise make the
# displayed list differ between runs.
sorted(os.listdir(model_dir))

['o2o_easy_vit-e=2-lr=0.01_limit=20.pt',
 'o2o_easy_resnet18.a1-e=2-lr=0.01_limit=20.pt',
 'o2o_medium_resnet18.a1-e=2-lr=0.01_limit=20.pt',
 'buggy_o2o_easy_resnet18.a1-e=2-lr=0.01_limit=20.pt',
 'o2o_easy_resnet18.a1-e=12-lr=0.01_limit=20.pt',
 'entire_dataset_resnet18.a1-e=12-lr=0.01_limit=20.pt']

## Evaluate the model on a test set

In [None]:
# Load a saved checkpoint from the Models folder and measure its accuracy on
# the Spawrious test split.
MODEL_NAME = 'resnet18.a1_in1k'
set_model_name(MODEL_NAME)

# Run on the GPU when one is available and fall back to the CPU otherwise.
# Kept as a plain string because `evaluate` was previously called with 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the test split for the dataset selected in `args`.
spawrious = get_spawrious_dataset(dataset_name=args.dataset, root_dir=args.data_dir)
test_set = spawrious.get_test_dataset()

# Pull model from drive
# NOTE(review): the checkpoint file name says "o2o_medium" while the dataset is
# chosen by `args.dataset` -- confirm both refer to the same Spawrious split
# before trusting the reported accuracy.
model_name = "o2o_medium_resnet18.a1-e=2-lr=0.01_limit=20.pt"
model_dir = os.path.join(project_dir, "Models")
model_path = os.path.join(model_dir, model_name)

model = get_model(args)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensors/state-dict containers and
# avoids the torch.load warning about unsafe default loading.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
model.to(torch.device(device))

# No shuffling: evaluation order does not matter and a fixed order is easier
# to reproduce.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
)

test_acc = evaluate(model, test_loader, device)


Dataset already downloaded and extracted.
Creating Training Dataset:
    Combination: desert/bulldog
    Limit: 2217
    Combination: jungle/bulldog
    Limit: 2217
    Combination: dirt/bulldog
    Limit: 2217
    Combination: snow/bulldog
    Limit: 2217
    Combination: desert/dachshund
    Limit: 2217
    Combination: jungle/dachshund
    Limit: 2217
    Combination: dirt/dachshund
    Limit: 2217
    Combination: snow/dachshund
    Limit: 2217
    Combination: desert/labrador
    Limit: 2217
    Combination: jungle/labrador
    Limit: 2217
    Combination: dirt/labrador
    Limit: 2217
    Combination: snow/labrador
    Limit: 2217
    Combination: desert/corgi
    Limit: 2217
    Combination: jungle/corgi
    Limit: 2217
    Combination: dirt/corgi
    Limit: 2217
    Combination: snow/corgi
    Limit: 2217
Creating Testing Dataset:
    Combination: dirt/bulldog
    Limit: None
    Combination: dirt/bulldog
    Limit: None
    Combination: snow/dachshund
    Limit: None
    Combi

  model.load_state_dict(torch.load(model_path))
                                                             

Acc: 3.015%


