Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
name: Python package

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
build:
strategy:
fail-fast: true
matrix:
os: [ubuntu-latest]
python-version: ["3.12"]
# We aim to support the versions on pytorch.org
# as well as selected previous versions on
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.4.0"]
include:
- os: windows-latest
torch-version: 2.4.0
python-version: "3.12"

runs-on: ${{ matrix.os }}

steps:
- name: Cache dependencies
id: pip-cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}

- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install package
run: |
python -m pip install git+https://github.com/RobustBench/robustbench.git
python -m pip install --upgrade pip setuptools wheel
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev]'

- name: Run pytest tests
timeout-minutes: 10
run: |
pip install pytest
python -m pytest

- name: Build package
run: |
make build

- name: Check reinstall script
timeout-minutes: 3
run: |
./reinstall.sh
21 changes: 21 additions & 0 deletions .github/workflows/codespell.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
name: Codespell

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
codespell:
name: Check for spelling errors
runs-on: ubuntu-latest

steps:
- name: Checkout
uses: actions/checkout@v3
- name: Codespell
uses: codespell-project/actions-codespell@v1
with:
ignore_words_list: aros, fpr, tpr, idx, fpr95
52 changes: 52 additions & 0 deletions .github/workflows/release-pypi.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: release

on:
push:
tags:
- 'v*.*.*'
pull_request:
branches:
- main
types:
- labeled
- opened
- edited
- synchronize
- reopened

jobs:
release:
runs-on: ubuntu-latest

steps:
- name: Cache dependencies
id: pip-cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip

- name: Checkout code
uses: actions/checkout@v3

- name: Build and publish to Test PyPI
if: ${{ (github.ref != 'refs/heads/main') && (github.event.label.name == 'release') }}
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }}
run: |
make dist
ls dist/
tar tvf dist/aros-node-*.tar.gz
python3 -m twine upload --repository testpypi dist/*

- name: Build and publish to PyPI
if: ${{ github.event_name == 'push' }}
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
make dist
ls dist/
tar tvf dist/aros-node-*.tar.gz
python3 -m twine upload dist/*
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.DS_Store
.tar.gz

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
3 changes: 0 additions & 3 deletions AROS/__init__.py

This file was deleted.

9 changes: 9 additions & 0 deletions aros_node/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# © M.W. Mathis Lab | Hossein Mirzaei & M.W. Mathis
# https://github.com/AdaptiveMotorControlLab/AROS
# Licensed under Apache 2.0

from aros_node.version import __version__
from aros_node.data_loader import LabelChangedDataset, get_subsampled_subset, get_loaders
from aros_node.evaluate import compute_fpr95, compute_auroc, compute_aupr, get_clean_AUC, wrapper_method
from aros_node.stability_loss_function import *
from aros_node.utils import *
File renamed without changes.
2 changes: 1 addition & 1 deletion AROS/evaluate.py → aros_node/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from tqdm.notebook import tqdm
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc
from utils import *
from aros_node.utils import *
import argparse
import torch
import torch.nn as nn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from robustbench.utils import load_model
import torch.nn as nn
from torch.nn.parameter import Parameter
import utils
from utils import *
import aros_node.utils
from aros_node.utils import *
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset, SubsetRandomSampler, ConcatDataset
import numpy as np
from tqdm.notebook import tqdm
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions aros_node/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.0.1"
10 changes: 5 additions & 5 deletions AROS/main.py → main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

!pip install -r requirements.txt
import aros_node
import argparse
import torch
import torch.nn as nn
from evaluate import *
from utils import *
from aros_node.evaluate import *
from aros_node.utils import *
from tqdm.notebook import tqdm
from data_loader import *
from stability_loss_function import *
from aros_node.data_loader import *
from aros_node.stability_loss_function import *

def main():
parser = argparse.ArgumentParser(description="Hyperparameters for the script")
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
25 changes: 25 additions & 0 deletions reinstall.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash

# Re-install the package. By running './reinstall.sh'
#
# Note that AROS uses the build
# system specified in
# PEP517 https://peps.python.org/pep-0517/ and
# PEP518 https://peps.python.org/pep-0518/
# and hence there is no setup.py file.

set -e # abort on error

pip uninstall -y aros-node

# Get version
VERSION=0.0.1
echo "Upgrading to AROS v${VERSION}"

# Upgrade the build system (PEP517/518 compatible)
python3 -m pip install virtualenv
python3 -m pip install --upgrade build
python3 -m build --sdist --wheel .

# Reinstall the package with most recent version
pip install --upgrade --no-cache-dir "dist/aros_node-${VERSION}-py3-none-any.whl"
9 changes: 7 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
geotorch
torch
torchdiffeq
git+https://github.com/RobustBench/robustbench.git
timm==1.0.9
timm==1.0.9
robustbench
numpy
scikit-learn
scipy
tqdm
26 changes: 4 additions & 22 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,41 +1,23 @@
[metadata]
name = aros
version = 0.0.1
name = aros-node
version = attr: aros_node.version.__version__
author = Hossein Mirzaei, Mackenzie Mathis
author_email = mackenzie@post.harvard.edu
description = AROS: Adversarially Robust Out-of-Distribution Detection through Stability
long_description = file: README.md
long_description_content_type = text/markdown
license_files = LICENSE.md
license_file_type = text/markdown
url = https://github.com/AdaptiveMotorControlLab/AROS
project_urls =
Bug Tracker = https://github.com/AdaptiveMotorControlLab/AROS/issues
classifiers =
Development Status :: 4 - Beta
Environment :: GPU :: NVIDIA CUDA
Intended Audience :: Science/Research
Operating System :: OS Independent
Programming Language :: Python :: 3
Topic :: Scientific/Engineering :: Artificial Intelligence
License :: OSI Approved :: Apache Software License

[options]
packages = find:
include_package_data = True
python_requires = >=3.10
install_requires =
geotorch
torchdiffeq
git+https://github.com/RobustBench/robustbench.git
install_requires = file: requirements.txt

[options.extras_require]
dev =
pylint
toml
yapf
black
pytest

[bdist_wheel]
universal=0
pytest
69 changes: 69 additions & 0 deletions tests/test_dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import ToTensor
from aros_node import (
LabelChangedDataset,
get_subsampled_subset,
get_loaders,
)

# Set up transformations and datasets for tests
transform_tensor = ToTensor()

@pytest.fixture
def cifar10_datasets():
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_tensor)
testset = CIFAR10(root='./data', train=False, download=True, transform=transform_tensor)
return trainset, testset

@pytest.fixture
def cifar100_datasets():
trainset = CIFAR100(root='./data', train=True, download=True, transform=transform_tensor)
testset = CIFAR100(root='./data', train=False, download=True, transform=transform_tensor)
return trainset, testset

def test_label_changed_dataset(cifar10_datasets):
_, testset = cifar10_datasets
new_label = 99
relabeled_dataset = LabelChangedDataset(testset, new_label)

assert len(relabeled_dataset) == len(testset), "Relabeled dataset should match the original dataset length"

for img, label in relabeled_dataset:
assert label == new_label, "All labels should be changed to the new label"

def test_get_subsampled_subset(cifar10_datasets):
trainset, _ = cifar10_datasets
subset_ratio = 0.1
subset = get_subsampled_subset(trainset, subset_ratio=subset_ratio)

expected_size = int(len(trainset) * subset_ratio)
assert len(subset) == expected_size, f"Subset size should be {expected_size}"

def test_get_loaders_cifar10(cifar10_datasets):
train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar10')

assert isinstance(train_loader, DataLoader)
assert isinstance(test_loader, DataLoader)
assert isinstance(test_loader_vs_other, DataLoader)

for images, labels in test_loader:
assert images.shape[0] == 16, "Test loader batch size should be 16"
break

def test_get_loaders_cifar100(cifar100_datasets):
train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar100')

assert isinstance(train_loader, DataLoader)
assert isinstance(test_loader, DataLoader)
assert isinstance(test_loader_vs_other, DataLoader)

for images, labels in test_loader:
assert images.shape[0] == 16, "Test loader batch size should be 16"
break

def test_get_loaders_invalid_dataset():
with pytest.raises(ValueError, match="Dataset 'invalid_dataset' is not supported."):
get_loaders('invalid_dataset')
Loading