diff --git a/AROS/Notebooks/AROS.ipynb b/AROS/Notebooks/AROS.ipynb new file mode 100644 index 0000000..b14ab19 --- /dev/null +++ b/AROS/Notebooks/AROS.ipynb @@ -0,0 +1,1618 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## Adversarially Robust Out-of-Distribution Detection Using Lyapunov-Stabilized Embeddings" + ], + "metadata": { + "id": "RFxEz28oe7dE" + } + }, + { + "cell_type": "markdown", + "source": [ + "This notebook is designed to replicate and analyze the results presented in Table 1 of the AROS paper, focusing on out-of-distribution detection performance under both attack scenarios and clean evaluation. The dataset configurations involve using CIFAR-10 and CIFAR-100 as in-distribution and out-of-distribution datasets. The notebook is structured to load a pre-trained model as the encoder, followed by generating fake OOD embeddings through sampling. The model is then trained using the designed loss function and evaluated across various OOD detection benchmarks to assess its performance under different conditions.\n", + "\n" + ], + "metadata": { + "id": "ZL5Va1N940xJ" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8i7kYslKnOhJ" + }, + "source": [ + "#Import utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AfFpweyoWApG", + "outputId": "1995f10f-31e9-42b5-cfcb-3faa3eea3d7d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Collecting git+https://github.com/RobustBench/robustbench.git (from -r requirements.txt (line 3))\n", + " Cloning https://github.com/RobustBench/robustbench.git to /tmp/pip-req-build-agy3e4yg\n", + " Running command git clone --filter=blob:none --quiet https://github.com/RobustBench/robustbench.git /tmp/pip-req-build-agy3e4yg\n", + " Resolved https://github.com/RobustBench/robustbench.git to commit 776bc95bb4167827fb102a32ac5aea62e46cfaab\n", + " Preparing metadata (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: geotorch in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 1)) (0.3.0)\n", + "Requirement already satisfied: torchdiffeq in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 2)) (0.2.4)\n", + "Requirement already satisfied: torch>=1.9 in /usr/local/lib/python3.10/dist-packages (from geotorch->-r requirements.txt (line 1)) (2.4.1)\n", + "Requirement already satisfied: scipy>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from torchdiffeq->-r requirements.txt (line 2)) (1.14.1)\n", + "Collecting autoattack@ git+https://github.com/fra31/auto-attack.git@a39220048b3c9f2cca9a4d3a54604793c68eca7e#egg=autoattack\n", + " Using cached autoattack-0.1-py3-none-any.whl\n", + "Requirement already satisfied: Jinja2~=3.1.2 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (3.1.4)\n", + "Requirement already satisfied: gdown==5.1.0 in /home/hossein/.local/lib/python3.10/site-packages (from robustbench==1.1->-r requirements.txt (line 3)) (5.1.0)\n", + "Requirement already satisfied: numpy>=1.19.4 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.1.2)\n", + "Requirement already satisfied: pandas>=1.3.5 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.2.3)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (6.0.2)\n", + "Requirement already satisfied: requests>=2.25.0 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.32.3)\n", + "Requirement already satisfied: timm>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (1.0.9)\n", + "Requirement already satisfied: torchvision>=0.8.2 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (0.19.1)\n", + "Requirement already satisfied: tqdm>=4.56.1 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (4.66.5)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (3.16.1)\n", + "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (4.12.3)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2~=3.1.2->robustbench==1.1->-r requirements.txt (line 3)) (2.1.5)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2024.2)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2.9.0.post0)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2024.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (2024.8.30)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (3.3.2)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (2.2.3)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (3.10)\n", + "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (from timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (0.25.2)\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (0.4.5)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (2024.9.0)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (1.13.3)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.3.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (4.12.2)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (10.3.2.106)\n", + "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (3.0.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (3.3)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (2.20.5)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (9.1.0.70)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.6.77)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision>=0.8.2->robustbench==1.1->-r requirements.txt (line 3)) (10.4.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (1.16.0)\n", + "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4->gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (2.6)\n", + "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (24.1)\n", + "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (1.7.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.9->geotorch->-r requirements.txt (line 1)) (1.3.0)\n" + ] + } + ], + "source": [ + "\n", + "!pip install -r requirements.txt\n", + "import argparse\n", + "import torch\n", + "import torch.nn as nn\n", + "from evaluate import *\n", + "from utils import *\n", + "from tqdm.notebook import tqdm\n", + "from data_loader import *\n", + "from stability_loss_function import *\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y0VhjMLBnZ1e" + }, + "source": [ + "#Set hyperparameters & dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jdrVKMH0nZ_r" + }, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description=\"Hyperparameters for the script\")\n", + "\n", + "# Define the hyperparameters controlled via CLI 'Ding2020MMA'\n", + "\n", + "parser.add_argument('--in_dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')\n", + "parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')\n", + "parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')\n", + "parser.add_argument('--attack_eps', type=float, default=8/255, help='Perturbation bound (epsilon) for PGD attack')\n", + "parser.add_argument('--attack_steps', type=int, default=10, help='Number of steps for the PGD attack')\n", + "parser.add_argument('--attack_alpha', type=float, default=2.5 * (8/255) / 10, help='Step size (alpha) for each PGD attack iteration')\n", + "\n", + "args = parser.parse_args('')\n", + "\n", + "# Set the default model name based on the selected dataset\n", + "if args.in_dataset == 'cifar10':\n", + " default_model_name = 'Rebuffi2021Fixing_70_16_cutmix_extra'\n", + "elif args.in_dataset == 'cifar100':\n", + " default_model_name = 'Wang2023Better_WRN-70-16'\n", + "\n", + "parser.add_argument('--model_name', type=str, default=default_model_name, choices=['Rebuffi2021Fixing_70_16_cutmix_extra', 'Wang2023Better_WRN-70-16'], help='The pre-trained model to be used for feature extraction')\n", + "\n", + "# Re-parse arguments to include model_name selection based on the dataset\n", + "args = parser.parse_args('')\n", + "num_classes = 10 if args.in_dataset == 'cifar10' else 100\n", + "\n", + "trainloader, testloader,test_set, ID_OOD_loader = get_loaders(in_dataset=args.in_dataset)\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8jZhzNCFnjBK" + }, + "source": [ + "#Fake embedding generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "fdbba9d39fce407ba41442ba1b4fc566", + "e22a82ff530d4f24916209a59d6a8606", + "7d8d3bd63732400ba75b5696fd83b3a7", + "72bcb81e8de942fbac2d7b0382990df6", + "e323d7e51fcb465398f9935f6f799ae4", + "1a9dad23b6324e03820e703ff22d7aae", + "6c97387066f94bfc939a3215143dc417", + "7f286d1eaf0541fe8dacd32889c49f51", + "dadb563fcb2c402cbabc0cfeab9e7944", + "2ddf603c640644f796a036545cef37a7", + "366528bf8da4497aaeb25fd693b3d44c", + "78442c980cb34b08a7d97f44da82b557", + "b207a2db94584c2ebc67e709ad8fb17d", + "baa37218fe534fc29cded386088d638a", + "c0a414c9077e4939b2a08715a67ab684", + "1c871c06eb8f47f693aa20de02877430", + "60f5f9215e6546a2b79891aa88312e73", + "42f6cc6e9feb4673a8c6e482a2026a8c", + "f26be04e4fde480883c11e27522bb702", + "899e7cd487334e5ea0010af7503452d0", + "eabd56e5b8644c0b887962a532fbde3c", + "711614caf63b40539ec7f64acdedb2af", + "823599a9cd814ea28cc6aa137d470823", + "7e614d7fa27441d284df6dac03df4324", + "6751af9107a040ffadcfbee3a150b431", + "1b7329086eee4a20a317217fd5284627", + "06a71eb01eaf4f97bc2da1ecc3b40e78", + "2c76171e3cff4e1496948f3c083485ee", + "3dfe45e073eb40e08ef7a4d71e65ea01", + "3f8a36d981b248949382b5bf8698d72d", + "6fe3f66466314394838fac4ec46fdb58", + "b6cff4fa1db44035a20543aa83ebd82b", + "39452a0245074d81acf29a348b59dfcf" + ] + }, + "id": "cnjl5M8xLq2i", + "outputId": "aaa95439-242b-446e-cbf1-012e668f3835" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/hossein/.local/lib/python3.10/site-packages/robustbench/utils.py:165: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " checkpoint = torch.load(model_path, map_location=torch.device('cpu'))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "embedding computed...\n" + ] + } + ], + "source": [ + "\n", + "\n", + "\n", + "robust_backbone = load_model(model_name=args.model_name, dataset=args.in_dataset, threat_model=args.threat_model).to(device)\n", + "\n", + "\n", + "\n", + "last_layer_name, last_layer = list(robust_backbone.named_children())[-1]\n", + "setattr(robust_backbone, last_layer_name, nn.Identity())\n", + "fake_loader=None\n", + "\n", + "\n", + "num_fake_samples = len(trainloader.dataset) // num_classes\n", + "\n", + "\n", + "\n", + "\n", + "embeddings, labels = [], []\n", + "\n", + "with torch.no_grad():\n", + " for imgs, lbls in trainloader:\n", + " imgs = imgs.to(device, non_blocking=True)\n", + " embed = robust_backbone(imgs).cpu() # move to CPU only once per batch\n", + " embeddings.append(embed)\n", + " labels.append(lbls)\n", + "embeddings = torch.cat(embeddings).numpy()\n", + "labels = torch.cat(labels).numpy()\n", + "\n", + "\n", + "print(\"embedding computed...\")\n", + "\n", + "\n", + "if args.fast==False:\n", + " gmm_dict = {}\n", + " for cls in np.unique(labels):\n", + " cls_embed = embeddings[labels == cls]\n", + " gmm = GaussianMixture(n_components=1, covariance_type='full').fit(cls_embed)\n", + " gmm_dict[cls] = gmm\n", + "\n", + " print(\"fake crafing...\")\n", + "\n", + " fake_data = []\n", + "\n", + "\n", + " for cls, gmm in gmm_dict.items():\n", + " samples, likelihoods = [], []\n", + " while len(samples) < num_samples_needed:\n", + " s = gmm.sample(100)[0]\n", + " likelihood = gmm.score_samples(s)\n", + " samples.append(s[likelihood < np.quantile(likelihood, 0.001)])\n", + " likelihoods.append(likelihood[likelihood < np.quantile(likelihood, 0.001)])\n", + " if sum(len(smp) for smp in samples) >= num_samples_needed:\n", + " break\n", + " samples = np.vstack(samples)[:num_samples_needed]\n", + " fake_data.append(samples)\n", + "\n", + " fake_data = np.vstack(fake_data)\n", + " fake_data = torch.tensor(fake_data).float()\n", + " fake_data = F.normalize(fake_data, p=2, dim=1)\n", + "\n", + " fake_labels = torch.full((fake_data.shape[0],), 10)\n", + " fake_loader = DataLoader(TensorDataset(fake_data, fake_labels), batch_size=128, shuffle=True)\n", + "\n", + "if args.fast==True:\n", + "\n", + "\n", + " noise_std = 0.1 # standard deviation of noise\n", + " noisy_embeddings = torch.tensor(embeddings) + noise_std * torch.randn_like(torch.tensor(embeddings))\n", + "\n", + " # Normalize Noisy Embeddings\n", + " noisy_embeddings = F.normalize(noisy_embeddings, p=2, dim=1)[:len(trainloader.dataset)//num_classes]\n", + "\n", + " # Convert to DataLoader if needed\n", + " fake_labels = torch.full((noisy_embeddings.shape[0],), num_classes)[:len(trainloader.dataset)//num_classes]\n", + " fake_loader = DataLoader(TensorDataset(noisy_embeddings, fake_labels), batch_size=128, shuffle=True)\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D7ibC97YnxYq" + }, + "source": [ + "#Train and eval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MsAOiMNbnyil" + }, + "outputs": [], + "source": [ + "\n", + "\n", + "final_model = stability_loss_function_(trainloader, testloader, robust_backbone, num_classes, fake_loader, last_layer, args)\n", + "\n", + "\n", + "test_attack = PGD_AUC(final_model, eps=args.attack_eps, steps=args.attack_steps, alpha=args.attack_alpha, num_classes=num_classes)\n", + "get_clean_AUC(final_model, ID_OOD_loader , device, num_classes)\n", + "adv_auc = get_auc_adversarial(model=final_model, test_loader=ID_OOD_loader, test_attack=test_attack, device=device, num_classes=num_classes)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "W18YwKYln36n" + }, + "source": [ + "#Extra Experiments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z9xaiEJyPBdb" + }, + "outputs": [], + "source": [ + "import os\n", + "import zipfile\n", + "!pip install wget\n", + "import wget\n", + "from pathlib import Path\n", + "import torchvision\n", + "from torchvision import transforms\n", + "import tarfile\n", + "\n", + "image_size=32\n", + "load_out_names=[ \"places365\",\"LSUN\", \"iSUN\" ]\n", + "\n", + "\n", + "\n", + "if \"places365\" in load_out_names:\n", + " # Define the directory path and create it if it does not exist\n", + " base_dir = \"./datasets/data\"\n", + " os.makedirs(base_dir, exist_ok=True)\n", + "\n", + " # Download and save categories_places365.txt\n", + " dest = os.path.join(base_dir, \"categories_places365.txt\")\n", + " if not Path(dest).is_file():\n", + " wget.download(\"https://dl.dropboxusercontent.com/s/enr71zpolzi1xzm/categories_places365.txt\", out=dest)\n", + "\n", + " # Download and save places365_val.txt\n", + " dest = os.path.join(base_dir, \"places365_val.txt\")\n", + " if not Path(dest).is_file():\n", + " wget.download(\"https://dl.dropboxusercontent.com/s/gaf1ygpdnkhzyjo/places365_val.txt\", out=dest)\n", + "\n", + " # Download and save val_256.tar\n", + " dest = os.path.join(base_dir, \"val_256.tar\")\n", + " if not Path(dest).is_file():\n", + " wget.download(\"https://dl.dropboxusercontent.com/s/3pwqsyv33f6if3z/val_256.tar\", out=dest)\n", + "\n", + " # Extract val_256.tar if val_256 directory does not exist\n", + " dest_final = os.path.join(base_dir, \"val_256\")\n", + " if not Path(dest_final).is_dir():\n", + " with tarfile.open(dest) as tar:\n", + " tar.extractall(path=base_dir)\n", + "\n", + " # Load the Places365 dataset\n", + " places365 = torchvision.datasets.Places365(\n", + " root=base_dir,\n", + " split='val',\n", + " small=True,\n", + " download=False,\n", + " transform=transforms.Compose([\n", + " transforms.Resize(image_size),\n", + " transforms.ToTensor()\n", + " ])\n", + " )\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "if \"LSUN\" in load_out_names:\n", + " # Define the base directory and ensure it exists\n", + " base_dir = \"./datasets/data\"\n", + " os.makedirs(base_dir, exist_ok=True)\n", + "\n", + " # Define the destination path for LSUN dataset\n", + " dest = os.path.join(base_dir, \"LSUN_resize.tar.gz\")\n", + " if not Path(dest).is_file():\n", + " wget.download(\"https://bit.ly/3wA55Wb\", out=dest)\n", + " with tarfile.open(dest) as tar:\n", + " tar.extractall(path=os.path.join(base_dir, \"LSUN_resize\"))\n", + "\n", + " # Define transformation based on image size\n", + " transform = transforms.ToTensor() if image_size == 32 else transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Resize(image_size)\n", + " ])\n", + "\n", + " # Load the LSUN dataset\n", + " LSUN = torchvision.datasets.ImageFolder(root=os.path.join(base_dir, \"LSUN_resize\"), transform=transform)\n", + "\n", + "\n", + "if \"iSUN\" in load_out_names:\n", + " # Define the destination path for iSUN dataset\n", + " dest = os.path.join(base_dir, \"iSUN.tar.gz\")\n", + " if not Path(dest).is_file():\n", + " wget.download(\"https://bit.ly/3yRMTJe\", out=dest)\n", + " with tarfile.open(dest) as tar:\n", + " tar.extractall(path=os.path.join(base_dir, \"iSUN\"))\n", + "\n", + " # Define transformation based on image size\n", + " transform = transforms.ToTensor() if image_size == 32 else transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Resize(image_size)\n", + " ])\n", + "\n", + " # Load the iSUN dataset\n", + " iSUN = torchvision.datasets.ImageFolder(root=os.path.join(base_dir, \"iSUN\"), transform=transform)\n", + "\n", + "\n", + "\n", + "\n", + "class LabelChangedDataset(Dataset):\n", + " def __init__(self, original_dataset, new_label):\n", + " self.original_dataset = original_dataset\n", + " self.new_label = new_label\n", + "\n", + " def __len__(self):\n", + " return len(self.original_dataset)\n", + "\n", + " def __getitem__(self, idx):\n", + " image, _ = self.original_dataset[idx]\n", + " return image, self.new_label\n", + "\n", + "\n", + "\n", + "# Download and load the SVHN test set\n", + "svhn = torchvision.datasets.SVHN(root='./datasets/data', split='test', download=True, transform=transform)\n", + "\n", + "\n", + "\n", + "\n", + "iSUN = LabelChangedDataset(iSUN, num_classes)\n", + "\n", + "LSUN = LabelChangedDataset(LSUN, num_classes)\n", + "\n", + "places365 = LabelChangedDataset(places365, num_classes)\n", + "\n", + "\n", + "svhn = LabelChangedDataset(svhn, num_classes)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zlRGcd5tMjIS" + }, + "outputs": [], + "source": [ + "test_dataset_isun = ConcatDataset([test_set, iSUN])\n", + "\n", + "testloader_isun = DataLoader(test_dataset_isun, shuffle=False, batch_size=64)\n", + "\n", + "get_clean_AUC(final_model, testloader_isun , device, num_classes)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mMNqeiEdMmCi" + }, + "outputs": [], + "source": [ + "test_dataset_LSUN = ConcatDataset([test_set, LSUN])\n", + "\n", + "testloader_LSUN = DataLoader(test_dataset_LSUN, shuffle=False, batch_size=64)\n", + "\n", + "get_clean_AUC(final_model, testloader_LSUN, device, num_classes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dF7qrwnHsDhA" + }, + "outputs": [], + "source": [ + "test_dataset_places365 = ConcatDataset([test_set, places365])\n", + "\n", + "testloader_places365 = DataLoader(test_dataset_places365, shuffle=False, batch_size=64)\n", + "\n", + "get_clean_AUC(final_model, testloader_places365 , device, num_classes)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nTia9Dgjs1u4" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JY4i7mv_sIas" + }, + "outputs": [], + "source": [ + "test_dataset_svhn = ConcatDataset([test_set, svhn])\n", + "\n", + "testloader_svhn = DataLoader(test_dataset_svhn, shuffle=False, batch_size=64)\n", + "\n", + "get_clean_AUC(final_model, testloader_svhn , device, num_classes)\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "06a71eb01eaf4f97bc2da1ecc3b40e78": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1a9dad23b6324e03820e703ff22d7aae": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1b7329086eee4a20a317217fd5284627": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b6cff4fa1db44035a20543aa83ebd82b", + "placeholder": "​", + "style": "IPY_MODEL_39452a0245074d81acf29a348b59dfcf", + "value": " 570/1250 [19:42<22:52,  2.02s/batch]" + } + }, + "1c871c06eb8f47f693aa20de02877430": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c76171e3cff4e1496948f3c083485ee": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2ddf603c640644f796a036545cef37a7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "366528bf8da4497aaeb25fd693b3d44c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "39452a0245074d81acf29a348b59dfcf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3dfe45e073eb40e08ef7a4d71e65ea01": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3f8a36d981b248949382b5bf8698d72d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "42f6cc6e9feb4673a8c6e482a2026a8c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "60f5f9215e6546a2b79891aa88312e73": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6751af9107a040ffadcfbee3a150b431": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3f8a36d981b248949382b5bf8698d72d", + "max": 1250, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_6fe3f66466314394838fac4ec46fdb58", + "value": 570 + } + }, + "6c97387066f94bfc939a3215143dc417": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6fe3f66466314394838fac4ec46fdb58": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "711614caf63b40539ec7f64acdedb2af": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "72bcb81e8de942fbac2d7b0382990df6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2ddf603c640644f796a036545cef37a7", + "placeholder": "​", + "style": "IPY_MODEL_366528bf8da4497aaeb25fd693b3d44c", + "value": " 550/550 [10:02<00:00,  1.09s/it, Loss=2.23]" + } + }, + "78442c980cb34b08a7d97f44da82b557": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b207a2db94584c2ebc67e709ad8fb17d", + "IPY_MODEL_baa37218fe534fc29cded386088d638a", + "IPY_MODEL_c0a414c9077e4939b2a08715a67ab684" + ], + "layout": "IPY_MODEL_1c871c06eb8f47f693aa20de02877430" + } + }, + "7d8d3bd63732400ba75b5696fd83b3a7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7f286d1eaf0541fe8dacd32889c49f51", + "max": 550, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_dadb563fcb2c402cbabc0cfeab9e7944", + "value": 550 + } + }, + "7e614d7fa27441d284df6dac03df4324": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2c76171e3cff4e1496948f3c083485ee", + "placeholder": "​", + "style": "IPY_MODEL_3dfe45e073eb40e08ef7a4d71e65ea01", + "value": " 46%" + } + }, + "7f286d1eaf0541fe8dacd32889c49f51": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "823599a9cd814ea28cc6aa137d470823": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e614d7fa27441d284df6dac03df4324", + "IPY_MODEL_6751af9107a040ffadcfbee3a150b431", + "IPY_MODEL_1b7329086eee4a20a317217fd5284627" + ], + "layout": "IPY_MODEL_06a71eb01eaf4f97bc2da1ecc3b40e78" + } + }, + "899e7cd487334e5ea0010af7503452d0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b207a2db94584c2ebc67e709ad8fb17d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_60f5f9215e6546a2b79891aa88312e73", + "placeholder": "​", + "style": "IPY_MODEL_42f6cc6e9feb4673a8c6e482a2026a8c", + "value": "100%" + } + }, + "b6cff4fa1db44035a20543aa83ebd82b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "baa37218fe534fc29cded386088d638a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f26be04e4fde480883c11e27522bb702", + "max": 1250, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_899e7cd487334e5ea0010af7503452d0", + "value": 1250 + } + }, + "c0a414c9077e4939b2a08715a67ab684": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_eabd56e5b8644c0b887962a532fbde3c", + "placeholder": "​", + "style": "IPY_MODEL_711614caf63b40539ec7f64acdedb2af", + "value": " 1250/1250 [01:12<00:00, 18.73batch/s]" + } + }, + "dadb563fcb2c402cbabc0cfeab9e7944": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e22a82ff530d4f24916209a59d6a8606": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1a9dad23b6324e03820e703ff22d7aae", + "placeholder": "​", + "style": "IPY_MODEL_6c97387066f94bfc939a3215143dc417", + "value": "Training ODE block with loss function: 100%" + } + }, + "e323d7e51fcb465398f9935f6f799ae4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "eabd56e5b8644c0b887962a532fbde3c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f26be04e4fde480883c11e27522bb702": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fdbba9d39fce407ba41442ba1b4fc566": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e22a82ff530d4f24916209a59d6a8606", + "IPY_MODEL_7d8d3bd63732400ba75b5696fd83b3a7", + "IPY_MODEL_72bcb81e8de942fbac2d7b0382990df6" + ], + "layout": "IPY_MODEL_e323d7e51fcb465398f9935f6f799ae4" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/AROS/Notebooks/Ablation-Study.ipynb b/AROS/Notebooks/Ablation-Study.ipynb new file mode 100644 index 0000000..314f135 --- /dev/null +++ b/AROS/Notebooks/Ablation-Study.ipynb @@ -0,0 +1,286 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "S2YKR1ps79o3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Collecting git+https://github.com/RobustBench/robustbench.git (from -r requirements.txt (line 3))\n", + " Cloning https://github.com/RobustBench/robustbench.git to /tmp/pip-req-build-cdsd2hhb\n", + " Running command git clone --filter=blob:none --quiet https://github.com/RobustBench/robustbench.git /tmp/pip-req-build-cdsd2hhb\n", + " Resolved https://github.com/RobustBench/robustbench.git to commit 776bc95bb4167827fb102a32ac5aea62e46cfaab\n", + " Preparing metadata (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: geotorch in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 1)) (0.3.0)\n", + "Requirement already satisfied: torchdiffeq in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 2)) (0.2.4)\n", + "Requirement already satisfied: torch>=1.9 in /usr/local/lib/python3.10/dist-packages (from geotorch->-r requirements.txt (line 1)) (2.4.1)\n", + "Requirement already satisfied: scipy>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from torchdiffeq->-r requirements.txt (line 2)) (1.14.1)\n", + "Collecting autoattack@ git+https://github.com/fra31/auto-attack.git@a39220048b3c9f2cca9a4d3a54604793c68eca7e#egg=autoattack\n", + " Using cached autoattack-0.1-py3-none-any.whl\n", + "Requirement already satisfied: Jinja2~=3.1.2 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (3.1.4)\n", + "Requirement already satisfied: gdown==5.1.0 in /home/hossein/.local/lib/python3.10/site-packages (from robustbench==1.1->-r requirements.txt (line 3)) (5.1.0)\n", + "Requirement already satisfied: numpy>=1.19.4 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.1.2)\n", + "Requirement already satisfied: pandas>=1.3.5 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.2.3)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (6.0.2)\n", + "Requirement already satisfied: requests>=2.25.0 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.32.3)\n", + "Requirement already satisfied: timm>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (1.0.9)\n", + "Requirement already satisfied: torchvision>=0.8.2 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (0.19.1)\n", + "Requirement already satisfied: tqdm>=4.56.1 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (4.66.5)\n", + "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (4.12.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (3.16.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2~=3.1.2->robustbench==1.1->-r requirements.txt (line 3)) (2.1.5)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2.9.0.post0)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2024.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2024.2)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (2.2.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (3.3.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (2024.8.30)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (3.10)\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (0.4.5)\n", + "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (from timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (0.25.2)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (2.20.5)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (4.12.2)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (9.1.0.70)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (3.3)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (11.0.2.54)\n", + "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (3.0.0)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (1.13.3)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (11.4.5.107)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (2024.9.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.6.77)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision>=0.8.2->robustbench==1.1->-r requirements.txt (line 3)) (10.4.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (1.16.0)\n", + "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4->gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (2.6)\n", + "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (24.1)\n", + "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (1.7.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.9->geotorch->-r requirements.txt (line 1)) (1.3.0)\n" + ] + } + ], + "source": [ + "!pip install -r requirements.txt\n", + "import argparse\n", + "import torch\n", + "import torch.nn as nn\n", + "from evaluate import *\n", + "from utils import *\n", + "from tqdm.notebook import tqdm\n", + "from data_loader import *\n", + "from stability_loss_function import *" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "d7vxLb0179fa" + }, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description=\"Hyperparameters for the script\")\n", + "\n", + "# Define the hyperparameters controlled via CLI 'Ding2020MMA'\n", + "\n", + "parser.add_argument('--in_dataset', type=str, default='cifar100', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')\n", + "parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')\n", + "parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')\n", + "parser.add_argument('--attack_eps', type=float, default=8/255, help='Perturbation bound (epsilon) for PGD attack')\n", + "parser.add_argument('--attack_steps', type=int, default=10, help='Number of steps for the PGD attack')\n", + "parser.add_argument('--attack_alpha', type=float, default=2.5 * (8/255) / 10, help='Step size (alpha) for each PGD attack iteration')\n", + "\n", + "args = parser.parse_args('')\n", + "\n", + "# Set the default model name based on the selected dataset\n", + "if args.in_dataset == 'cifar10':\n", + " default_model_name = 'Rebuffi2021Fixing_70_16_cutmix_extra'\n", + "elif args.in_dataset == 'cifar100':\n", + " default_model_name = 'Wang2023Better_WRN-70-16'\n", + "\n", + "parser.add_argument('--model_name', type=str, default=default_model_name, choices=['Rebuffi2021Fixing_70_16_cutmix_extra', 'Wang2023Better_WRN-70-16'], help='The pre-trained model to be used for feature extraction')\n", + "\n", + "# Re-parse arguments to include model_name selection based on the dataset\n", + "args = parser.parse_args('')\n", + "num_classes = 10 if args.in_dataset == 'cifar10' else 100\n", + "\n", + "trainloader, testloader,test_set, ID_OOD_loader = get_loaders(in_dataset=args.in_dataset)\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "g2TltXvg7MfF" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "59296a90b8c84b1c94648a4c5d68a43b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1250 [00:00