Notebooks/AROS.ipynb (63 changes: 34 additions & 29 deletions)

@@ -3,31 +3,31 @@
 {
  "cell_type": "markdown",
  "metadata": {
-  "id": "view-in-github",
-  "colab_type": "text"
+  "colab_type": "text",
+  "id": "view-in-github"
  },
  "source": [
   "<a href=\"https://colab.research.google.com/github/AdaptiveMotorControlLab/AROS/blob/main/Notebooks/AROS.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
  ]
 },
 {
  "cell_type": "markdown",
- "source": [
-  "## Adversarially Robust Out-of-Distribution Detection Using Lyapunov-Stabilized Embeddings"
- ],
  "metadata": {
   "id": "RFxEz28oe7dE"
- }
+ },
+ "source": [
+  "## Adversarially Robust Out-of-Distribution Detection Using Lyapunov-Stabilized Embeddings"
+ ]
 },
 {
  "cell_type": "markdown",
+ "metadata": {
+  "id": "ZL5Va1N940xJ"
+ },
  "source": [
   "This notebook is designed to replicate and analyze the results presented in Table 1 of the AROS paper, focusing on out-of-distribution detection performance under both attack scenarios and clean evaluation. The dataset configurations involve using CIFAR-10 and CIFAR-100 as in-distribution and out-of-distribution datasets. The notebook is structured to load a pre-trained model as the encoder, followed by generating fake OOD embeddings through sampling. The model is then trained using the designed loss function and evaluated across various OOD detection benchmarks to assess its performance under different conditions.\n",
   "\n"
- ],
- "metadata": {
-  "id": "ZL5Va1N940xJ"
- }
+ ]
 },
 {
  "cell_type": "markdown",
@@ -40,41 +40,41 @@
 },
 {
  "cell_type": "code",
- "source": [
-  "!git clone https://github.com/AdaptiveMotorControlLab/AROS.git"
- ],
+ "execution_count": null,
  "metadata": {
   "id": "TdY-7pyGq4oN"
  },
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "source": [
+  "!git clone https://github.com/MMathisLab/AROS.git"
+ ]
 },
 {
  "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+  "id": "owrQtpTxrbth"
+ },
+ "outputs": [],
  "source": [
   "%cd /content/AROS\n",
   "%ls\n",
   "!pip install -r requirements.txt"
- ],
- "metadata": {
-  "id": "owrQtpTxrbth"
- },
- "execution_count": null,
- "outputs": []
+ ]
 },
 {
  "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+  "id": "WgsBOHhNrtYD"
+ },
+ "outputs": [],
  "source": [
   "import argparse\n",
   "import torch\n",
   "import torch.nn as nn\n",
   "from tqdm.notebook import tqdm"
- ],
- "metadata": {
-  "id": "WgsBOHhNrtYD"
- },
- "execution_count": null,
- "outputs": []
+ ]
 },
 {
  "cell_type": "code",
@@ -112,6 +112,11 @@
  "\n",
  "# Define the hyperparameters controlled via CLI 'Ding2020MMA'\n",
  "\n",
+ "\n",
+ "parser.add_argument('--fast', type=bool, default=True, help='Toggle between fast and full fake data generation modes')\n",
+ "parser.add_argument('--epoch1', type=int, default=2, help='Number of epochs for stage 1')\n",
+ "parser.add_argument('--epoch2', type=int, default=1, help='Number of epochs for stage 2')\n",
+ "parser.add_argument('--epoch3', type=int, default=2, help='Number of epochs for stage 3')\n",
  "parser.add_argument('--in_dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')\n",
  "parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')\n",
  "parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')\n",
@@ -474,9 +479,9 @@
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
-  "provenance": [],
+  "include_colab_link": true,
   "machine_shape": "hm",
-  "include_colab_link": true
+  "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",