Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/Continvvm/continuum
Browse files Browse the repository at this point in the history
  • Loading branch information
TLESORT committed Jan 2, 2023
2 parents fce15b4 + f61c604 commit e611f1b
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 10 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include requirements.txt
4 changes: 3 additions & 1 deletion continuum/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def slice(
keep_tasks, discard_tasks
)

new_x, new_y, new_t = x[indexes], y[indexes], t[indexes]
new_x, new_y, new_t = x[indexes], y[indexes], None
if t is not None:
new_t = t[indexes]
sliced_dataset = InMemoryDataset(
new_x, new_y, new_t,
data_type=self.data_type
Expand Down
21 changes: 20 additions & 1 deletion continuum/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,26 @@ def unzip(path):
def untar(path):
    """Extract the tar archive at ``path`` into the directory that contains it.

    Every member is validated before anything is written to disk: if the
    destination of any member resolves outside the extraction directory, the
    whole extraction is aborted and no file is extracted.

    :param path: Path to the tar archive. Its parent directory is used as the
                 extraction directory.
    :raises Exception: With the message "Attempted Path Traversal in Tar File"
                       when a member would be written outside the extraction
                       directory.
    """
    directory_path = os.path.dirname(path)

    def is_within_directory(directory, target):
        """Return True if ``target`` resolves to ``directory`` or below it."""
        abs_directory = os.path.abspath(directory)
        abs_target = os.path.abspath(target)
        try:
            # commonpath compares whole path components. commonprefix compares
            # characters, so "/data/foobar/x" would wrongly be accepted as
            # being inside "/data/foo".
            common = os.path.commonpath([abs_directory, abs_target])
        except ValueError:
            # No common path can be computed (e.g. different Windows drives),
            # so the target cannot be inside the directory.
            return False
        return common == abs_directory

    def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
        """Validate every member of ``tar``, then extract into ``path``."""
        # NOTE: only member names are checked here; link targets of symlink
        # or hardlink members are not inspected.
        for member in tar.getmembers():
            member_path = os.path.join(path, member.name)
            if not is_within_directory(path, member_path):
                raise Exception("Attempted Path Traversal in Tar File")

        tar.extractall(path, members, numeric_owner=numeric_owner)

    with tarfile.open(path) as tar_file:
        # Extraction happens only through safe_extract, so nothing is written
        # before the archive has been fully validated.
        safe_extract(tar_file, directory_path)


def download_file_from_google_drive(id, destination):
Expand Down
12 changes: 7 additions & 5 deletions continuum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,24 @@ def _slice(
"No task ids information is present by default with this dataset, "
"thus you cannot slice some task ids."
)
y, t = y.astype(np.int64), t.astype(np.int64)
y = y.astype(np.int64)
if t is not None:
t = t.astype(np.int64)

indexes = set()
if keep_classes:
if keep_classes is not None:
indexes = set(np.where(np.isin(y, keep_classes))[0])
elif discard_classes:
elif discard_classes is not None:
keep_classes = list(set(y) - set(discard_classes))
indexes = set(np.where(np.isin(y, keep_classes))[0])

if keep_tasks:
if keep_tasks is not None:
_indexes = np.where(np.isin(t, keep_tasks))[0]
if len(indexes) > 0:
indexes = indexes.intersection(_indexes)
else:
indexes = indexes.union(_indexes)
elif discard_tasks:
elif discard_tasks is not None:
keep_tasks = list(set(t) - set(discard_tasks))
_indexes = np.where(np.isin(t, keep_tasks))[0]
if len(indexes) > 0:
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/datasets/dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ Existing Datasets
+----------------------+------------+------------+----------------+------------------+
| **QMNIST** | 10 | 28x28x1 | YES | Images |
+----------------------+------------+------------+----------------+------------------+
| **SVHN** | 10 | 28x28x1 | YES | Images |
+----------------------+------------+------------+----------------+------------------+
| **MNIST Fellowship** | 30 | 28x28x1 | YES | Images |
+----------------------+------------+------------+----------------+------------------+
| **Rainbow MNIST** | 10 | 28x28x3 | YES | Images |
+----------------------+------------+------------+----------------+------------------+
| **Colored MNIST** | 2 | 28x28x3 | YES | Images |
+----------------------+------------+------------+----------------+------------------+
| **SVHN** | 10 | 32x32x3 | YES | Images |
+----------------------+------------+------------+----------------+------------------+
| **Synbols** | 50 | 32x32x3 | YES | Images |
+----------------------+------------+------------+----------------+------------------+
| **CIFAR10** | 10 | 32x32x3 | YES | Images |
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setuptools.setup(
name="continuum",
version="1.2.4",
version="1.2.7",
author="Arthur Douillard, Timothée Lesort",
author_email="ar.douillard@gmail.com",
description="A clean and simple library for Continual Learning in PyTorch.",
Expand Down
26 changes: 26 additions & 0 deletions tests/test_dataset_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ def dataset():

return x, y, t

@pytest.fixture
def dataset2():
    """Dataset without task index.

    Returns a tuple ``(x, y, None)``: ``x`` has shape (20, 4, 4, 3) and sample
    ``i`` is filled with the value ``i``; ``y`` holds 10 classes, each one
    assigned to two consecutive samples (0, 0, 1, 1, ..., 9, 9).
    """
    nb_samples = 20

    # Sample i is a constant block equal to i, so samples stay identifiable
    # after slicing.
    x = np.zeros((nb_samples, 4, 4, 3))
    x += np.arange(nb_samples).reshape(-1, 1, 1, 1)

    # Class c labels samples 2c and 2c + 1; kept as floats like np.zeros.
    y = np.repeat(np.arange(nb_samples // 2), 2).astype(np.float64)

    return x, y, None

@pytest.mark.parametrize("keep_classes,discard_classes,keep_tasks,discard_tasks,error,ids", [
([1], [1], None, None, True, None),
Expand Down Expand Up @@ -65,6 +81,16 @@ def test_slice(

assert (np.unique(x) == np.array(ids)).all(), (np.unique(x), ids)

@pytest.mark.parametrize("keep_classes,discard_classes", [
    ([0, 1], None),
    (None, [0, 1])
])
def test_slice_without_t(dataset2, keep_classes, discard_classes):
    """Slicing by classes must work on a dataset that has no task ids.

    Besides checking that slicing does not raise when `t` is None, verify
    that the samples kept are exactly those of the selected classes.
    """
    x, y, _ = dataset2

    dataset = InMemoryDataset(*dataset2)
    sliced_dataset = dataset.slice(keep_classes, discard_classes)
    sliced_x, _, _ = sliced_dataset.get_data()

    # Build the expected selection straight from the fixture labels.
    if keep_classes is not None:
        expected_mask = np.isin(y, keep_classes)
    else:
        expected_mask = ~np.isin(y, discard_classes)
    expected_values = np.unique(x[expected_mask])

    # Each sample is filled with its own index, so the unique values of the
    # sliced data identify which samples were kept, whatever their order.
    sliced_values = np.unique(sliced_x)
    assert np.array_equal(sliced_values, expected_values), (
        sliced_values, expected_values
    )


@pytest.mark.parametrize("keep_classes,discard_classes,keep_tasks,discard_tasks,error,ids", [
([1], [1], None, None, True, None),
Expand Down

0 comments on commit e611f1b

Please sign in to comment.