enable classifier to work with patch + context labels
rwood-97 committed Jan 25, 2024
1 parent ee2f4a1 commit a5e36ca
Showing 2 changed files with 105 additions and 62 deletions.
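Before the diff itself, a rough sketch of what this change enables: a two-branch patch + context model (defined in custom_models.py below) handed to the classifier with the new ``context`` flag. Anything not shown in this commit — in particular the container class name and the backbone choices — is an assumption, not a confirmed API.

import torch
import torchvision
from mapreader.classify.custom_models import PatchContextModel

n_patch_classes, n_context_classes = 2, 3

# One branch for the patch, one for its surrounding context.
patch_branch = torchvision.models.resnet18(weights=None)
patch_branch.fc = torch.nn.Linear(512, n_patch_classes)
context_branch = torchvision.models.resnet18(weights=None)
context_branch.fc = torch.nn.Linear(512, n_context_classes)

# Shared head: input size = combined branch output sizes,
# output size = number of patch-level classes (see the PatchContextModel docstring).
fc_layer = torch.nn.Linear(n_patch_classes + n_context_classes, n_patch_classes)
model = PatchContextModel(patch_branch, context_branch, fc_layer)

# The classifier is then built with context=True, so training expects
# (patch, context) input tuples and (patch_label, context_label) label tuples.
# The container name below is an assumption; it is not shown in this diff.
# classifier = ClassifierContainer(model, labels_map, dataloaders, context=True)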
126 changes: 84 additions & 42 deletions mapreader/classify/classifier.py
@@ -37,7 +37,8 @@ def __init__(
dataloaders: dict[str, DataLoader] | None = None,
device: str | None = "default",
input_size: int | None = (224, 224),
is_inception: bool | None = False,
is_inception: bool = False,
context: bool = False,
load_path: str | None = None,
force_device: bool | None = False,
**kwargs,
@@ -65,6 +66,9 @@ def __init__(
is_inception : bool, optional
Whether the model is an Inception-style model.
Default is ``False``.
context : bool, optional
Whether the model uses patch and context inputs.
Default is ``False``.
load_path : str, optional
The path to an ``.obj`` file containing a
force_device : bool, optional
@@ -88,8 +92,10 @@ def __init__(
The model.
input_size : None or tuple of int
The size of the input to the model.
is_inception : None or bool
is_inception : bool
A flag indicating if the model is an Inception model.
context : bool
A flag indicating if the model uses patch and context as inputs.
optimizer : None or torch.optim.Optimizer
The optimizer being used for training the model.
scheduler : None or torch.optim.lr_scheduler._LRScheduler
@@ -150,6 +156,7 @@ def __init__(
self.model = model.to(self.device)
self.input_size = input_size
self.is_inception = is_inception
self.context = context
elif isinstance(model, str):
self._initialize_model(model, **kwargs)

@@ -183,7 +190,6 @@ def generate_layerwise_lrs(
min_lr: float,
max_lr: float,
spacing: str | None = "linspace",
parameter_groups: bool = False,
) -> list[dict]:
"""
Calculates layer-wise learning rates for a given set of model
@@ -201,27 +207,32 @@ def generate_layerwise_lrs(
where `"linspace"` uses evenly spaced learning rates over a
specified interval and `"geomspace"` uses learning rates spaced
evenly on a log scale (a geometric progression). By default ``"linspace"``.
parameter_groups : bool, optional
When using context mode, whether to consider parameters belonging to the patch model and context model as separate groups.
If True, layers belonging to each group will be assigned the same learning rate.
Defaults to ``False``.
Returns
-------
list of dicts
A list of dictionaries containing the parameters and learning
rates for each layer.
Notes
-----
parameter_groups : bool, optional
When using context mode, whether to consider parameters belonging to the patch model and context model as separate groups.
If True, layers belonging to each group will be assigned the same learning rate.
Defaults to ``False``.
"""

if spacing.lower() not in ["linspace", "geomspace"]:
raise NotImplementedError(
'[ERROR] ``spacing`` must be one of "linspace" or "geomspace"'
)

if parameter_groups:
if self.context:
params2optimize = []

for group in ["patch_model", "context_model"]:
for group in set(
tuple[0].split(".")[0] for tuple in [*self.model.named_parameters()]
):
group_params = [
params
for (name, params) in self.model.named_parameters()
@@ -309,6 +320,10 @@ def initialize_optimizer(
if optim_param_dict is None:
optim_param_dict = {"lr": 0.001}
if params2optimize == "default":
if self.context:
raise ValueError(
"[ERROR] When using context model, first call `params2optimize` cannot be set to `default`."
)
params2optimize = filter(lambda p: p.requires_grad, self.model.parameters())

if optim_type.lower() in ["adam"]:
@@ -899,7 +914,9 @@ def train_core(
self.dataloaders[phase]
):
inputs = tuple(input.to(self.device) for input in inputs)
label_indices = label_indices.to(self.device)
label_indices = tuple(
label_index.to(self.device) for label_index in label_indices
)

if self.optimizer is None:
if phase.lower() in train_phase_names:
@@ -931,30 +948,42 @@ def train_core(
):
outputs, aux_outputs = self.model(*inputs)

if not all(
isinstance(out, torch.Tensor)
for out in [outputs, aux_outputs]
):
try:
outputs = outputs.logits
aux_outputs = aux_outputs.logits
except AttributeError as err:
raise AttributeError(err.message)

loss1 = self.criterion(outputs, label_indices)
loss2 = self.criterion(aux_outputs, label_indices)
# XXX From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958 # noqa
if not isinstance(outputs, torch.Tensor):
outputs = self._get_logits(outputs)
if not isinstance(aux_outputs, torch.Tensor):
aux_outputs = self._get_logits(aux_outputs)

loss1 = self.criterion(outputs, *label_indices)
loss2 = self.criterion(aux_outputs, *label_indices)
# https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
loss = loss1 + 0.4 * loss2

elif self.context:
(patch_outputs, context_outputs), outputs = self.model(
*inputs
)

if not isinstance(outputs, torch.Tensor):
outputs = self._get_logits(outputs)
if not isinstance(patch_outputs, torch.Tensor):
patch_outputs = self._get_logits(patch_outputs)
if not isinstance(context_outputs, torch.Tensor):
context_outputs = self._get_logits(context_outputs)

loss1 = self.criterion(outputs, label_indices[0])
loss2 = self.criterion(patch_outputs, label_indices[0])
loss3 = self.criterion(outputs, label_indices[1])

loss = loss1 + 0.4 * loss2 + 0.4 * loss3

else:
outputs = self.model(*inputs)

if not isinstance(outputs, torch.Tensor):
try:
outputs = outputs.logits
except AttributeError as err:
raise AttributeError(err.message)
loss = self.criterion(outputs, label_indices)
outputs = self._get_logits(outputs)

loss = self.criterion(outputs, *label_indices)
print(loss, type(loss))

_, pred_label_indices = torch.max(outputs, dim=1)

@@ -970,21 +999,23 @@ def train_core(
# batch_loop.set_postfix(loss=loss.data)
# batch_loop.refresh()
else:
outputs = self.model(*inputs)
if self.context:
(patch_outputs, context_outputs), outputs = self.model(
*inputs
)
else:
outputs = self.model(*inputs)

if not isinstance(outputs, torch.Tensor):
try:
outputs = outputs.logits
except AttributeError as err:
raise AttributeError(err.message)
outputs = self._get_logits(outputs)

_, pred_label_indices = torch.max(outputs, dim=1)

running_pred_conf.extend(
torch.nn.functional.softmax(outputs, dim=1).cpu().tolist()
)
running_pred_label_indices.extend(pred_label_indices.cpu().tolist())
running_orig_label_indices.extend(label_indices.cpu().tolist())
running_orig_label_indices.extend(label_indices[0].cpu().tolist())

if batch_idx % print_info_batch_freq == 0:
curr_inp_counts = min(
@@ -1089,7 +1120,15 @@ def train_core(
print(
f"[INFO] Model at epoch {self.best_epoch} has least valid loss ({self.best_loss:.4f}) so will be saved.\n\
[INFO] Path: {save_model_path}"
) # noqa
)

@staticmethod
def _get_logits(out):
try:
out = out.logits
except AttributeError as err:
raise AttributeError(err.message)
return out

def calculate_add_metrics(
self,
@@ -1495,6 +1534,7 @@ def _initialize_model(
self.model = model_dw.to(self.device)
self.input_size = input_size
self.is_inception = is_inception
self.context = False

def show_sample(
self,
@@ -1567,7 +1607,7 @@ def show_sample(
out = torchvision.utils.make_grid(input)
self._imshow(
out,
title=f"{labels}\n{label_indices.tolist()}",
title=f"{labels[0]}\n{label_indices[0].tolist()}",
figsize=figsize,
)

@@ -1693,15 +1733,17 @@ def show_inference_sample_results(
with torch.no_grad():
for inputs, _labels, label_indices in iter(self.dataloaders[set_name]):
inputs = tuple(input.to(self.device) for input in inputs)
label_indices = label_indices.to(self.device)
label_indices = tuple(
label_index.to(self.device) for label_index in label_indices
)

outputs = self.model(*inputs)
if self.context:
_, outputs = self.model(*inputs)
else:
outputs = self.model(*inputs)

if not isinstance(outputs, torch.Tensor):
try:
outputs = outputs.logits
except AttributeError as err:
raise AttributeError(err.message)
outputs = self._get_logits(outputs)

pred_conf = torch.nn.functional.softmax(outputs, dim=1) * 100.0
_, preds = torch.max(outputs, 1)
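Taken together, the classifier.py changes mean: ``label_indices`` now arrives from the dataloader as a tuple of (patch labels, context labels), raw model outputs are unwrapped through the new ``_get_logits`` helper, context mode combines the shared head's loss with the two branch losses using 0.4 auxiliary weights (mirroring the existing Inception handling), and ``generate_layerwise_lrs`` groups parameters by their top-level module name (e.g. patch_model, context_model, fc_layer) instead of a ``parameter_groups`` flag. A self-contained sketch of the context-mode loss, with hypothetical shapes:

import torch

criterion = torch.nn.CrossEntropyLoss()
batch, n_patch_classes, n_context_classes = 4, 5, 3

# The model's forward returns ((patch_outputs, context_outputs), outputs).
patch_outputs = torch.randn(batch, n_patch_classes)
context_outputs = torch.randn(batch, n_context_classes)
outputs = torch.randn(batch, n_patch_classes)

# label_indices is now a tuple: (patch labels, context labels).
label_indices = (
    torch.randint(0, n_patch_classes, (batch,)),
    torch.randint(0, n_context_classes, (batch,)),
)

# The third term pairs the joint outputs with the context labels, mirroring the diff above.
loss = (
    criterion(outputs, label_indices[0])
    + 0.4 * criterion(patch_outputs, label_indices[0])
    + 0.4 * criterion(outputs, label_indices[1])
)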
41 changes: 21 additions & 20 deletions mapreader/classify/custom_models.py
@@ -1,13 +1,14 @@
#!/usr/bin/env python
from __future__ import annotations

import copy

import torch


class twoParallelModels(torch.nn.Module):
class PatchContextModel(torch.nn.Module):
"""
A class for building a model that contains two parallel branches, with
separate input pipelines, but shares a fully connected layer at the end.
Model that contains two parallel branches, with separate input pipelines, but one shared fully connected layer at the end.
This class inherits from PyTorch's nn.Module.
"""

@@ -18,7 +19,7 @@ def __init__(
fc_layer: torch.nn.Linear,
):
"""
Initializes a new instance of the twoParallelModels class.
Initializes a new instance of the PatchContextModel class.
Parameters:
-----------
@@ -29,25 +30,28 @@ def __init__(
fc_layer : nn.Linear
The fully connected layer at the end of the model.
Input size should be output size of patch_model + output size of context_model.
Output size should be number of classes (labels).
Output size should be number of classes (labels) at the patch level.
"""
super().__init__()

if patch_model is context_model:
context_model = copy.deepcopy(context_model)

self.patch_model = patch_model
self.context_model = context_model
self.fc_layer = fc_layer

def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
def forward(self, patch: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
"""
Defines the computation performed at every forward pass. Receives two
inputs, x1 and x2, and feeds them through the respective feature
extractor modules, then concatenates the output and passes it through
Defines the computation performed at every forward pass.
Receives two inputs, patch and context, feeds them through the respective feature extractor modules, then concatenates the outputs and passes the result through
the fully connected layer.
Parameters:
-----------
x1 : torch.Tensor
The input tensor for the patch only pipeline.
x2 : torch.Tensor
patch : torch.Tensor
The input tensor for the patch pipeline.
context : torch.Tensor
The input tensor for the context pipeline.
Returns:
@@ -56,13 +60,10 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
The output tensor of the model.
"""

x1 = self.patch_model(x1)
x1 = x1.view(x1.size(0), -1)

x2 = self.context_model(x2)
x2 = x2.view(x2.size(0), -1)
patch_output = self.patch_model(patch)
context_output = self.context_model(context)

# Concatenate in dim1 (feature dimension)
x = torch.cat((x1, x2), 1)
x = self.fc_layer(x)
return x
out = torch.cat((patch_output, context_output), 1)
out = self.fc_layer(out)
return (patch_output, context_output), out
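To make the new return signature of ``forward`` concrete, here is a small illustrative sketch with throwaway linear branches standing in for real feature extractors; the shapes are arbitrary examples, not values fixed by this commit.

import torch
from mapreader.classify.custom_models import PatchContextModel

n_patch_classes, n_context_classes = 2, 3

# Toy branches: flatten a 3x224x224 image and map it straight to class logits.
patch_branch = torch.nn.Sequential(
    torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, n_patch_classes)
)
context_branch = torch.nn.Sequential(
    torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, n_context_classes)
)

# Shared head: combined branch outputs in, patch-level classes out.
fc = torch.nn.Linear(n_patch_classes + n_context_classes, n_patch_classes)
model = PatchContextModel(patch_branch, context_branch, fc)

patch = torch.randn(4, 3, 224, 224)
context = torch.randn(4, 3, 224, 224)
(patch_out, context_out), out = model(patch, context)
# patch_out: (4, 2), context_out: (4, 3), out: (4, 2)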
