Update context model code #296

Open
wants to merge 27 commits into main from context_classifier

Commits (27)
46e0fe0
enable annotations loader to create patch context dataset
rwood-97 Nov 22, 2023
ac48935
align classifier_context to classifier
rwood-97 Nov 23, 2023
38b1c65
update trainable_col arg name
rwood-97 Nov 23, 2023
ef688d5
fix color printing
rwood-97 Nov 24, 2023
5ac8f1e
always return images as tuple
rwood-97 Nov 27, 2023
f83a5f3
process inputs as a tuple
rwood-97 Nov 27, 2023
d798454
update attribute names in custom model for clarity
rwood-97 Nov 27, 2023
bb0ec8d
update confusing language in params2optimize
rwood-97 Nov 27, 2023
7f6610b
add context option for generate_layerwise_lrs
rwood-97 Nov 27, 2023
55304d5
remove classifier context (now all in one)
rwood-97 Nov 27, 2023
d8f31d3
remove context container from init imports
rwood-97 Nov 27, 2023
0b33fa1
add docs on how to use context model
rwood-97 Nov 27, 2023
11afa54
Merge branch 'main' into context_classifier
rwood-97 Jan 10, 2024
e978b40
replace `square_cuts` with padding at edge patches
rwood-97 Jan 18, 2024
34014b1
return df after eval
rwood-97 Jan 22, 2024
428f0f3
update context saving
rwood-97 Jan 22, 2024
a1e7941
remove square_cuts arg from tests
rwood-97 Jan 22, 2024
c1b596c
ensure geotiffs are saved correctly
rwood-97 Jan 22, 2024
84340b0
fix context for annotator
rwood-97 Jan 22, 2024
a71a34b
allow users to annotate at context-level
rwood-97 Jan 23, 2024
f7baba7
use iloc not at for getting data
rwood-97 Jan 23, 2024
02d0e67
fix load annotations
rwood-97 Jan 23, 2024
5d54f5e
rename context dataset transforms for clarity
rwood-97 Jan 23, 2024
5cc37e7
only add context annotations to annotated patches
rwood-97 Jan 23, 2024
6f2a882
keep all cols when saving
rwood-97 Jan 23, 2024
ee2f4a1
load both patch and context level annotations for context dataset
rwood-97 Jan 23, 2024
a5e36ca
enable classifier to work with patch + context labels
rwood-97 Jan 25, 2024
60 changes: 48 additions & 12 deletions docs/source/User-guide/Classify/Train.rst
@@ -30,7 +30,7 @@ For example, if you have set up your directory as recommended in our `Input Guid
.. admonition:: Advanced usage
:class: dropdown

Other arguments you may want to specify when adding metadata to your images include:
Other arguments you may want to specify when loading your annotations include (see the example below):

- ``delimiter`` - By default, this is set to "\t" and so your ``csv`` file is assumed to be tab-delimited. You will need to specify the ``delimiter`` argument if your file uses a different delimiter.
- ``id_col``, ``patch_paths_col``, ``label_col`` - These are used to indicate the column headings for the columns which contain image IDs, patch file paths and labels respectively. By default, these are set to "image_id", "image_path" and "label".
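
For example, if your annotations were saved as a comma-separated file with custom column headings, the call might look like the sketch below (the ``load()`` method name, file path and column names are illustrative assumptions):

.. code-block:: python

#EXAMPLE
# assumes a comma-separated annotations file with custom column headings (illustrative values)
annotated_images.load("./annotations.csv", delimiter=",", id_col="patch_id", patch_paths_col="patch_path", label_col="label")
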
@@ -184,31 +184,41 @@ To split your annotated images and create your dataloaders, use:
By default, this will split your annotated images using the :ref:`default train:val:test ratios<ratios>` and apply the :ref:`default image transforms<transforms>` to each by calling the ``.create_datasets()`` method.
It will then create a dataloader for each dataset, using a batch size of 16 and the :ref:`default sampler<sampler>`.

To change the ratios used to split your annotations, you can specify ``frac_train``, ``frac_val`` and ``frac_test``:
To change the batch size used when creating your dataloaders, use the ``batch_size`` argument:

.. code-block:: python

#EXAMPLE
dataloaders = annotated_images.create_dataloaders(frac_train=0.6, frac_val=0.3, frac_test=0.1)
dataloaders = annotated_images.create_dataloaders(batch_size=24)

This will result in a split of 60% (train), 30% (val) and 10% (test).
.. admonition:: Advanced usage
:class: dropdown

To change the batch size used when creating your dataloaders, use the ``batch_size`` argument:
Other arguments you may want to specify when creating your dataloaders include (see the example below):

- ``sampler`` - By default, this is set to ``default`` and so the :ref:`default sampler<sampler>` will be used when creating your dataloaders and batches. You can choose not to use a sampler by specifying ``sampler=None`` or, you can define a custom sampler using `pytorch's sampler class <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`__.
- ``shuffle`` - If your datasets are ordered (e.g. ``"a","a","a","a","b","c"``), you can use ``shuffle=True`` to create dataloaders which contain shuffled batches of data. This cannot be used in conjunction with a sampler and so, by default, ``shuffle=False``.
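
For example, to create dataloaders with a larger batch size, no sampler and shuffled batches (combining the arguments described above; the batch size is illustrative):

.. code-block:: python

#EXAMPLE
dataloaders = annotated_images.create_dataloaders(batch_size=24, sampler=None, shuffle=True)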


If you would like to use custom settings when creating your datasets, you should call the ``create_datasets()`` method directly, instead of letting the ``create_dataloaders()`` method call it for you.
You should then run the ``create_dataloaders()`` method afterwards to create your dataloaders as before.

For example, to change the ratios used to split your annotations, you can specify ``frac_train``, ``frac_val`` and ``frac_test``:

.. code-block:: python

#EXAMPLE
dataloaders = annotated_images.create_dataloaders(batch_size=24)
annotated_images.create_datasets(frac_train=0.6, frac_val=0.3, frac_test=0.1)
dataloaders = annotated_images.create_dataloaders()

This will result in a split of 60% (train), 30% (val) and 10% (test).

.. admonition:: Advanced usage
:class: dropdown

Other arguments you may want to specify when adding metadata to your images include:
Other arguments you may want to specify when creating your datasets include (see the example below):

- ``sampler`` - By default, this is set to ``default`` and so the :ref:`default sampler<sampler>` will be used when creating your dataloaders and batches. You can choose not to use a sampler by specifying ``sampler=None`` or, you can define a custom sampler using `pytorch's sampler class <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`__.
- ``shuffle`` - If your datasets are ordered (e.g. ``"a","a","a","a","b","c"``), you can use ``shuffle=True`` to create dataloaders which contain shuffled batches of data. This cannot be used in conjunction with a sampler and so, by default, ``shuffle=False``.
- ``train_transform``, ``val_transform`` and ``test_transform`` - By default, these are set to "train", "val" and "test" respectively and so the :ref:`default image transforms<transforms>` for each of these sets are applied to the images. You can define your own transforms, using `torchvision's transforms module <https://pytorch.org/vision/stable/transforms.html>`__, and apply these to your datasets by specifying the ``train_transform``, ``val_transform`` and ``test_transform`` arguments.

- ``context_dataset`` - By default, this is set to ``False`` and so only the patches themselves are used as inputs to the model. Setting ``context_dataset=True`` will result in datasets which return both the patches and their context as inputs for the model.
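
For example, here is a sketch which applies a custom transform to each dataset and creates a context dataset (the transform pipeline shown is an illustrative assumption, not a MapReader default):

.. code-block:: python

#EXAMPLE
from torchvision import transforms

# an illustrative transform pipeline (not a MapReader default)
my_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

annotated_images.create_datasets(train_transform=my_transform, val_transform=my_transform, test_transform=my_transform, context_dataset=True)
dataloaders = annotated_images.create_dataloaders()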

Train
------
@@ -336,6 +346,30 @@ There are a number of options for the ``model`` argument:
.. note:: You will need to install the `timm <https://huggingface.co/docs/timm/index>`__ library to do this (``pip install timm``).


.. admonition:: Context models
:class: dropdown

If you have created context datasets, you will need to load two models (one for processing patches and one for processing patches plus context) using the methods above.
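
For example, you might load two torchvision models, one for each branch (the choice of ``resnet18`` here is an illustrative assumption):

.. code:: python

# a sketch - both branches use torchvision resnets (illustrative; weights="DEFAULT" requires torchvision >= 0.13)
from torchvision import models

patch_model = models.resnet18(weights="DEFAULT")
context_model = models.resnet18(weights="DEFAULT")
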
You should then pass these models to MapReader's ``twoParrallelModels`` class, which combines their outputs through one fully connected layer:

.. code:: python

# define fc layer inputs and outputs
import torch

# in_features (1004 here) = sum of the output sizes of your two models; out_features = number of labels
fc_layer = torch.nn.Linear(1004, len(annotated_images.labels_map))

The number of inputs to your fully connected layer should be the sum of the number of outputs from your two models; the number of outputs should be the number of classes (labels) you are using.
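
If both of your models are torchvision resnets, as in the sketch above, this input size can be computed from their final layers rather than hard-coded:

.. code:: python

# a sketch - assumes both models end in a final linear layer called "fc" (true for torchvision resnets)
n_inputs = patch_model.fc.out_features + context_model.fc.out_features
fc_layer = torch.nn.Linear(n_inputs, len(annotated_images.labels_map))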

Your models and ``fc_layer`` should then be used to set up your custom model:

.. code:: python

from mapreader.classify.custom_models import twoParrallelModels

# combine the two branches and the fully connected layer into one model
my_model = twoParrallelModels(patch_model, context_model, fc_layer)


Define criterion, optimizer and scheduler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -381,7 +415,7 @@ In order to train/fine-tune your model, will need to define:
You should change this to suit your needs.

The ``params2optimize`` argument can be used to select which parameters to optimize during training.
By default, this is set to ``"infer"``, meaning that all trainable parameters will be optimized.
By default, this is set to ``"default"``, meaning that all trainable parameters will be optimized.

When training/fine-tuning your model, you can either use one learning rate for all layers in your neural network or define layerwise learning rates (i.e. different learning rates for each layer in your neural network).
Normally, when fine-tuning pre-trained models, layerwise learning rates are favoured, with smaller learning rates assigned to the first layers and larger learning rates assigned to later layers.
@@ -401,6 +435,8 @@
#EXAMPLE
params2optimize = my_classifier.generate_layerwise_lrs(min_lr=1e-4, max_lr=1e-3, spacing="geomspace")

.. note:: If you are using a context model, you should also set ``parameter_groups=True`` when running the ``generate_layerwise_lrs()`` method (see the example below). This will ensure the two branches of your model are each optimized properly.
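
For example, combining the call shown above with ``parameter_groups``:

.. code-block:: python

#EXAMPLE
params2optimize = my_classifier.generate_layerwise_lrs(min_lr=1e-4, max_lr=1e-3, spacing="geomspace", parameter_groups=True)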

You should then pass your ``params2optimize`` list to the ``.initialize_optimizer()`` method:

.. code-block:: python
1 change: 0 additions & 1 deletion mapreader/__init__.py
@@ -10,7 +10,6 @@
from mapreader.classify.datasets import PatchDataset
from mapreader.classify.datasets import PatchContextDataset
from mapreader.classify.classifier import ClassifierContainer
from mapreader.classify.classifier_context import ClassifierContextContainer
from mapreader.classify import custom_models

from mapreader.process import process