Skip to content

Commit

Permalink
Merge pull request #1102 from lrzpellegrini/test_hl_generator_fix
Browse files Browse the repository at this point in the history
Minor fix to a unittest that left text files in the project directory.
  • Loading branch information
lrzpellegrini committed Jul 28, 2022
2 parents 006ca7c + 54e9731 commit cf66f1f
Showing 1 changed file with 26 additions and 20 deletions.
46 changes: 26 additions & 20 deletions tests/test_high_level_generators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import tempfile
import unittest
from os.path import expanduser

Expand Down Expand Up @@ -118,26 +119,31 @@ def test_filelist_benchmark(self):
expanduser("~") + "/.avalanche/data/cats_and_dogs_filtered/train"
)

for filelist, dir, label in zip(
["train_filelist_00.txt", "train_filelist_01.txt"],
["cats", "dogs"],
[0, 1],
):
# First, obtain the list of files
filenames_list = os.listdir(os.path.join(dirpath, dir))
with open(filelist, "w") as wf:
for name in filenames_list:
wf.write("{} {}\n".format(os.path.join(dir, name), label))

generic_benchmark = filelist_benchmark(
dirpath,
["train_filelist_00.txt", "train_filelist_01.txt"],
["train_filelist_00.txt"],
task_labels=[0, 0],
complete_test_set_only=True,
train_transform=ToTensor(),
eval_transform=ToTensor(),
)
with tempfile.TemporaryDirectory() as tmpdirname:
list_paths = []
for filelist, rel_dir, label in zip(
["train_filelist_00.txt", "train_filelist_01.txt"],
["cats", "dogs"],
[0, 1],
):
# First, obtain the list of files
filenames_list = os.listdir(os.path.join(dirpath, rel_dir))
filelist_path = os.path.join(tmpdirname, filelist)
list_paths.append(filelist_path)
with open(filelist_path, "w") as wf:
for name in filenames_list:
wf.write("{} {}\n".format(
os.path.join(rel_dir, name), label))

generic_benchmark = filelist_benchmark(
dirpath,
list_paths,
[list_paths[0]],
task_labels=[0, 0],
complete_test_set_only=True,
train_transform=ToTensor(),
eval_transform=ToTensor(),
)

self.assertEqual(2, len(generic_benchmark.train_stream))
self.assertEqual(1, len(generic_benchmark.test_stream))
Expand Down

0 comments on commit cf66f1f

Please sign in to comment.