Commit 0f05c39 (1 parent: 57455c7)
Showing 30 changed files with 4,036 additions and 7 deletions.
# Task2Vec

This is an implementation of the Task2Vec method described in the paper [Task2Vec: Task Embedding for Meta-Learning](https://arxiv.org/abs/1902.03545).

Task2Vec provides vectorial representations of learning tasks (datasets) which can be used to reason about the nature of those tasks and their relations. In particular, it provides a fixed-dimensional embedding of the task that is independent of details such as the number of classes, and does not require any understanding of the class label semantics. The distance between embeddings matches our intuition about semantic and taxonomic relations between different visual tasks (e.g., tasks based on classifying different types of plants are similar). The resulting vector can be used to represent a dataset in meta-learning applications, and allows, for example, selecting the best feature extractor for a task without an expensive brute-force search.
## Quick start

To compute an embedding using Task2Vec, you just need to provide a dataset and a probe network, for example:
```python
from task2vec import Task2Vec
from models import get_model
from datasets import get_dataset

dataset = get_dataset('cifar10')
probe_network = get_model('resnet34', pretrained=True, num_classes=10)
embedding = Task2Vec(probe_network).embed(dataset)
```
Task2Vec uses the diagonal of the Fisher Information Matrix to compute an embedding of the task. This implementation provides two methods, `montecarlo` and `variational`. The first is the fastest and is the default, but `variational` may be more robust in some situations (in particular, it is the one used in the paper). You can try it using:
```python
Task2Vec(probe_network, method='variational').embed(dataset)
```
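To build intuition for the `montecarlo` approximation, here is a toy sketch of the same idea on plain logistic regression — this is not the repo's implementation (which estimates the Fisher over a deep probe network's weights), just an illustration: labels are sampled from the model's own predictive distribution, and the squared gradients of the log-likelihood are averaged.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fisher_diagonal(w, X, n_draws=200):
    """Monte Carlo estimate of the diagonal Fisher for logistic regression.

    For each input x, labels are drawn from the model's own predictive
    distribution p(y|x, w), and the squared score (gradient of the
    log-likelihood w.r.t. w) is averaged over inputs and draws.
    """
    F = np.zeros_like(w)
    for x in X:
        p = sigmoid(x @ w)
        for _ in range(n_draws):
            y = float(rng.random() < p)   # y ~ Bernoulli(p), i.e. y ~ p(y|x, w)
            grad = (y - p) * x            # gradient of log p(y|x, w) w.r.t. w
            F += grad ** 2
    return F / (len(X) * n_draws)
```

For logistic regression the exact diagonal Fisher is the average of `p * (1 - p) * x**2`, which this estimate approaches as the number of sampled labels grows; Task2Vec computes the analogous quantity over the probe network's parameters.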
Now, let's try computing several embeddings and plotting the distance matrix between the tasks:
```python
from task2vec import Task2Vec
from models import get_model
import datasets
import task_similarity

dataset_names = ('mnist', 'cifar10', 'cifar100', 'letters', 'kmnist')
dataset_list = [datasets.__dict__[name]('./data')[0] for name in dataset_names]

embeddings = []
for name, dataset in zip(dataset_names, dataset_list):
    print(f"Embedding {name}")
    probe_network = get_model('resnet34', pretrained=True, num_classes=int(max(dataset.targets) + 1)).cuda()
    embeddings.append(Task2Vec(probe_network, max_samples=1000, skip_layers=6).embed(dataset))
task_similarity.plot_distance_matrix(embeddings, dataset_names)
```
You can also look at the notebook `small_datasets_example.ipynb` for a runnable version of this code snippet.
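Under the hood, the distances above come from the paper's normalized cosine distance between Fisher embeddings. A minimal self-contained sketch of that distance (an illustration, not the exact code of the `task_similarity` module):

```python
import numpy as np

def normalized_cosine_distance(e1, e2):
    """Symmetric distance between two non-negative Fisher embeddings, as in
    the Task2Vec paper: normalize each embedding by the elementwise sum of
    the pair, then take one minus the cosine similarity."""
    n1 = e1 / (e1 + e2)
    n2 = e2 / (e1 + e2)
    return 1.0 - np.dot(n1, n2) / (np.linalg.norm(n1) * np.linalg.norm(n2))
```

Identical embeddings give distance 0; the more the per-parameter Fisher values differ between two tasks, the larger the distance.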
## Experiments on iNaturalist and CUB

### Downloading the data

First, decide where you will store all the data. For example:
```sh
export DATA_ROOT=./data
```
To download [CUB-200](http://www.vision.caltech.edu/visipedia/CUB-200.html), run from the repository root:
```sh
./scripts/download_cub.sh $DATA_ROOT
```

To download [iNaturalist 2018](https://github.com/visipedia/inat_comp/tree/master/2018), run from the repository root:
```sh
./scripts/download_inat2018.sh $DATA_ROOT
```
**WARNING:** iNaturalist needs ~319 GB for download and extraction. Consider downloading and extracting it manually following the instructions [here](https://github.com/visipedia/inat_comp/tree/master/2018).
### Computing the embedding of all tasks

To compute the embedding on a single task of CUB + iNat2018, run:
```sh
python main.py task2vec.method=montecarlo dataset.root=$DATA_ROOT dataset.name=cub_inat2018 dataset.task_id=$TASK_ID -m
```
This will use the `montecarlo` Fisher approximation to compute the embedding of task number `$TASK_ID` in the CUB + iNat2018 meta-task. The result is stored in a pickle file inside `outputs`.
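Reloading a stored embedding is a standard pickle round-trip. The path and object below are purely illustrative — actual runs write under `outputs` with Hydra-generated filenames, and the pickled object is the Task2Vec embedding produced by the run:

```python
import os
import pickle
import tempfile

# Stand-in for a Task2Vec embedding; real runs pickle the actual embedding object.
embedding = {'hessian': [0.1, 0.2, 0.3]}

# Illustrative path; real filenames depend on the Hydra job configuration.
path = os.path.join(tempfile.mkdtemp(), 'embedding.p')
with open(path, 'wb') as f:
    pickle.dump(embedding, f)

with open(path, 'rb') as f:
    loaded = pickle.load(f)
```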
To compute all embeddings at once, we can use Hydra's multi-run mode as follows:
```sh
python main.py task2vec.method=montecarlo dataset.root=$DATA_ROOT dataset.name=cub_inat2018 dataset.task_id=`seq -s , 0 50` -m
```
This will compute the embeddings of tasks 0 through 50 in the CUB + iNat2018 meta-task.
To plot the 50x50 distance matrix between these tasks, first download all the `iconic_taxa` [image files](https://github.com/inaturalist/inaturalist/tree/master/app/assets/images/iconic_taxa) to `./static/iconic_taxa`, and then run:
```sh
python plot_distance_cub_inat.py --data-root $DATA_ROOT ./multirun/montecarlo
```
The result should look like the following. Note that tasks involving classification of similar life forms (e.g., different types of birds, plants, mammals) cluster together.

![Task2Vec distance matrix](static/cub_inat_embedding.png)
**Hydra main configuration (new file):**

```yaml
device: "cuda:0"

task2vec:
  # Maximum number of samples in the dataset used to estimate the Fisher
  max_samples: 10000
  skip_layers: 0

  # Whether to put batch normalization in eval mode (true) or train mode (false) when computing the Fisher
  # fix_batch_norm: true

  classifier_opts:
    optimizer: adam
    epochs: 10
    learning_rate: 0.0004
    weight_decay: 0.0001

defaults:
  - task2vec: montecarlo

dataset:
  name: inat2018
  task_id: 0
  root: ~/data

# Probe network to use
model:
  arch: resnet34
  pretrained: true

loader:
  batch_size: 100
  num_workers: 6
  balanced_sampling: true
  num_samples: 10000

hydra:
  sweep:
    dir: ./multirun/${task2vec.method}
    subdir: ${hydra.job.num}_${hydra.job.override_dirname}
```
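The `defaults` entry asks Hydra to compose a `task2vec` group file (selected with `task2vec.method=...` on the command line) into this main config. Conceptually, the composition is a recursive dictionary merge; a plain-Python sketch of the idea (not Hydra's actual implementation):

```python
def deep_merge(base, override):
    """Recursively merge `override` into `base`, the way Hydra composes a
    config group file into the main config: nested dicts are merged key by
    key, and scalar values from the group file win."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], value)
        else:
            out[key] = value
    return out

# Main config fragment plus a hypothetical group file, mirroring the YAML above.
config = {'task2vec': {'max_samples': 10000, 'skip_layers': 0}}
group = {'task2vec': {'method': 'montecarlo', 'method_opts': {'epochs': 1}}}
merged = deep_merge(config, group)
```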
**Hydra `task2vec` group config, `montecarlo` (new file):**

```yaml
task2vec:
  method: montecarlo
  method_opts:
    epochs: 1
```
**Hydra `task2vec` group config, `variational` (new file):**

```yaml
task2vec:
  method: variational
  method_opts:
    beta: 1.0e-7
    epochs: 2
```
**`datasets` package `__init__` (new file):**

```python
from .dataset import *
```
**Split CIFAR dataset module (new file):**

```python
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import numpy as np
from torchvision.datasets import CIFAR10, CIFAR100

from .dataset import ClassificationTaskDataset
from .expansion import ClassificationTaskExpander


class SplitCIFARTask:
    """SplitCIFARTask generates Split CIFAR tasks.

    Parameters
    ----------
    cifar10_dataset : CIFAR10Dataset
    cifar100_dataset : CIFAR100Dataset
    """

    def __init__(self, cifar10_dataset, cifar100_dataset):
        self.cifar10_dataset = cifar10_dataset
        self.cifar100_dataset = cifar100_dataset

    def generate(self, task_id=0, transform=None, target_transform=None):
        """Generate the task with the given id.

        Parameters
        ----------
        task_id : int 0-10 (default 0)
            0 = CIFAR10, 1 = first 10 classes of CIFAR100,
            2 = second 10 classes of CIFAR100, ...
        transform : callable (default None)
            Optional transform to be applied on a sample.
        target_transform : callable (default None)
            Optional transform to be applied on a label.

        Returns
        -------
        Task
        """
        assert isinstance(task_id, int)
        assert 0 <= task_id <= 10, task_id

        task_expander = ClassificationTaskExpander()
        if task_id == 0:
            classes = tuple(range(10))
            return task_expander(self.cifar10_dataset,
                                 {c: new_c for new_c, c in enumerate(classes)},
                                 label_names={c: name for c, name in self.cifar10_dataset.label_names_map.items()},
                                 task_id=task_id,
                                 task_name='Split CIFAR: CIFAR-10 {}'.format(classes),
                                 transform=transform,
                                 target_transform=target_transform)
        else:
            classes = tuple([int(c) for c in np.arange(10) + 10 * (task_id - 1)])
            return task_expander(self.cifar100_dataset,
                                 {c: new_c for new_c, c in enumerate(classes)},
                                 label_names={classes.index(old_c): name for old_c, name in
                                              self.cifar100_dataset.label_names_map.items() if old_c in classes},
                                 task_id=task_id,
                                 task_name='Split CIFAR: CIFAR-100 {}'.format(classes),
                                 transform=transform,
                                 target_transform=target_transform)


class CIFAR10Dataset(ClassificationTaskDataset):
    """CIFAR10 Dataset.

    Parameters
    ----------
    path : str
        path to the dataset root
    train : bool (default True)
        if True, load the train split, otherwise load the test split
    download : bool (default False)
        if True, download the dataset from the internet and put it in the
        path directory; if the dataset is already downloaded, it is not
        downloaded again
    metadata : dict (default empty)
        extra arbitrary metadata
    transform : callable (default None)
        Optional transform to be applied on a sample.
    target_transform : callable (default None)
        Optional transform to be applied on a label.
    """

    def __init__(self, path, train=True, download=False,
                 metadata={}, transform=None, target_transform=None):
        num_classes, task_name = self._get_settings()
        assert isinstance(path, str)
        assert isinstance(train, bool)

        self.cifar = self._get_cifar(path, train, transform, target_transform, download)

        super(CIFAR10Dataset, self).__init__(list(self.cifar.data),
                                             [int(x) for x in self.cifar.targets],
                                             label_names={l: str(l) for l in range(num_classes)},
                                             root=path,
                                             task_id=None,
                                             task_name=task_name,
                                             metadata=metadata,
                                             transform=transform,
                                             target_transform=target_transform)

    def _get_settings(self):
        return 10, 'CIFAR10'

    def _get_cifar(self, path, train, transform, target_transform, download=True):
        return CIFAR10(path, train=train, transform=transform,
                       target_transform=target_transform, download=download)


class CIFAR100Dataset(CIFAR10Dataset):
    """CIFAR100 Dataset.

    Parameters
    ----------
    path : str
        path to the dataset root
    train : bool (default True)
        if True, load the train split, otherwise load the test split
    download : bool (default False)
        if True, download the dataset from the internet and put it in the
        path directory; if the dataset is already downloaded, it is not
        downloaded again
    metadata : dict (default empty)
        extra arbitrary metadata
    transform : callable (default None)
        Optional transform to be applied on a sample.
    target_transform : callable (default None)
        Optional transform to be applied on a label.
    """

    def __init__(self, path=None, train=True, download=False,
                 metadata={}, transform=None, target_transform=None):
        super(CIFAR100Dataset, self).__init__(path=path,
                                              train=train,
                                              # pass download through (it was
                                              # previously dropped, so the flag
                                              # had no effect for CIFAR100)
                                              download=download,
                                              metadata=metadata,
                                              transform=transform,
                                              target_transform=target_transform)

    def _get_settings(self):
        return 100, 'CIFAR100'

    def _get_cifar(self, path, train, transform, target_transform, download=True):
        return CIFAR100(path, train=train, transform=transform,
                        target_transform=target_transform, download=download)
```
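The `task_id` → class mapping in `SplitCIFARTask.generate` can be checked in isolation: task 0 is all of CIFAR-10, and each task k ≥ 1 covers the k-th block of ten CIFAR-100 labels. A dependency-free sketch of the same arithmetic:

```python
def split_cifar_classes(task_id):
    # Mirrors SplitCIFARTask.generate: task 0 -> CIFAR-10 classes 0-9;
    # task k (1 <= k <= 10) -> CIFAR-100 classes 10*(k-1) .. 10*(k-1)+9.
    assert isinstance(task_id, int) and 0 <= task_id <= 10
    if task_id == 0:
        return tuple(range(10))
    return tuple(range(10 * (task_id - 1), 10 * task_id))
```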