added support for other tensor sizes
adrianjav committed Aug 1, 2023
1 parent c5b50c8 commit f924985
Showing 5 changed files with 103 additions and 47 deletions.
15 changes: 14 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.6.0] - 2023-08-01

### Changed

- Fixed a bug in `RotoGrad` where `burn_in_period` was not being applied.
- Addressed #3: added support for more complex tensor shapes (see the sketch below).
  - Tensors are now assumed to be of the form `... x rotation_shape x post_shape`.
  - The dimensions in `rotation_shape` are flattened and rotated (the rotation matrices have as many rows and columns as `rotation_shape` has elements).
  - The dimensions in `post_shape` are flattened but not rotated (useful, for example, to rotate only the channel dimension of an image).
  - The dimensions preceding `rotation_shape` are treated as batch dimensions.
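A minimal sketch of the new convention (the backbone, heads, sizes, and losses below are toy stand-ins, not part of the library): with a shared representation of shape `batch x 8 x 4 x 4`, passing `rotation_shape=8` and `post_shape=(4, 4)` rotates only the channel dimension.

```python
import torch
import torch.nn as nn
from rotograd import RotoGrad

# Toy backbone producing a shared representation of shape batch x 8 x 4 x 4.
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.AdaptiveAvgPool2d(4))
# Two toy heads; each receives the (rotated) representation with its original shape.
heads = [nn.Sequential(nn.Flatten(), nn.Linear(8 * 4 * 4, 1)) for _ in range(2)]

# Rotate only the 8 channels; the trailing 4 x 4 dimensions are flattened but not rotated.
model = RotoGrad(backbone, heads, rotation_shape=8, post_shape=(4, 4))

x = torch.randn(16, 3, 4, 4)
preds = model(x)                                  # one prediction per task
losses = [pred.pow(2).mean() for pred in preds]   # toy per-task losses
model.backward(losses)                            # RotoGrad's gradient surgery
```

In a full training loop, optimizer steps over the network and rotation parameters would follow; the sketch only shows where `rotation_shape` and `post_shape` enter.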


## [0.1.5.2] - 2022-03-01

### Changed
@@ -45,8 +57,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
internally to compute GradNorm's weights). We keep the option to still normalize them
if desired.

[unreleased]: https://github.com/adrianjav/rotograd/compare/v0.1.4...HEAD
[unreleased]: https://github.com/adrianjav/rotograd/compare/v0.1.6.0...HEAD
[0.1.4]: https://github.com/adrianjav/rotograd/compare/v0.1.3...v0.1.4
[0.1.5]: https://github.com/adrianjav/rotograd/compare/v0.1.4...v0.1.5
[0.1.5.1]: https://github.com/adrianjav/rotograd/compare/v0.1.5...v0.1.5.1
[0.1.5.2]: https://github.com/adrianjav/rotograd/compare/v0.1.5.1...v0.1.5.2
[0.1.6.0]: https://github.com/adrianjav/rotograd/compare/v0.1.5.2...v0.1.6.0
14 changes: 7 additions & 7 deletions example/main.py
@@ -47,8 +47,8 @@ def build_dense(tasks, args):
shared = getattr(args, 'shared', False)

enc_params = getattr(args, 'encoder', args)
backbone = FeedForward(args.input_size, args.latent_size, enc_params.hidden_size, enc_params.num_layers,
activations[enc_params.activation])
backbone = FeedForward(args.input_size, args.rotation_size, enc_params.hidden_size, enc_params.num_layers,
activations[enc_params.activation])

dec_params = getattr(args, 'decoder', args)
if isinstance(dec_params.hidden_size, int) and not shared:
@@ -59,7 +59,7 @@ def build_dense(tasks, args):
dec_params.num_layers = [dec_params.num_layers] * len(tasks)

for i, task_i in enumerate(tasks):
heads.append(FeedForward(args.latent_size, dec_params.output_size[i], dec_params.hidden_size[i],
heads.append(FeedForward(args.rotation_size, dec_params.output_size[i], dec_params.hidden_size[i],
dec_params.num_layers[i], activations[dec_params.activation],
dec_params.drop_last or task_i.loss == 'mse'))

@@ -113,14 +113,14 @@ def main(args):
tasks = get_tasks(args.tasks.names, args.tasks.weights, loaders['train'].dataset)
backbone, heads = build_dense(tasks, args.model)

if not hasattr(args.rotograd, 'latent_size'):
args.rotograd.latent_size = backbone.output_size
if not hasattr(args.rotograd, 'rotation_size'):
args.rotograd.rotation_size = backbone.output_size

method = args.algorithms.method
if method == 'rotograd':
model = RotoGrad(backbone, heads, args.rotograd.latent_size, normalize_losses=args.rotograd.normalize)
model = RotoGrad(backbone, heads, args.rotograd.rotation_size, normalize_losses=args.rotograd.normalize)
elif method == 'rotate':
model = RotateOnly(backbone, heads, args.rotograd.latent_size, normalize_losses=args.rotograd.normalize)
model = RotateOnly(backbone, heads, args.rotograd.rotation_size, normalize_losses=args.rotograd.normalize)
else:
model = VanillaMTL(backbone, heads) # TODO add normalize_losses

2 changes: 1 addition & 1 deletion example/toy.yml
@@ -34,7 +34,7 @@ model:
num_layers: 2

input_size: 2
latent_size: 2
rotation_size: 2
name: dense

rotograd:
2 changes: 1 addition & 1 deletion rotograd/__init__.py
@@ -1,5 +1,5 @@
from .rotograd import VanillaMTL, RotoGrad, RotoGradNorm, cached, RotateOnly

__version__ = '0.1.5.2'
__version__ = '0.1.6.0'

__all__ = ['VanillaMTL', 'RotateOnly', 'RotoGrad', 'RotoGradNorm', 'cached']
117 changes: 80 additions & 37 deletions rotograd/rotograd.py
@@ -1,4 +1,5 @@
from typing import Sequence, List, Any, Optional
from typing import Sequence, Union, Any, Optional
from functools import reduce

import torch
import torch.nn as nn
@@ -101,12 +102,12 @@ def model_parameters(self, recurse=True):


def rotate(points, rotation, total_size):
if total_size != points.size(-1):
points_lo, points_hi = points[:, :rotation.size(1)], points[:, rotation.size(1):]
point_lo = torch.einsum('ij,bj->bi', rotation, points_lo)
return torch.cat((point_lo, points_hi), dim=-1)
if total_size != points.size(-2):
points_lo, points_hi = points[..., :rotation.size(1), :], points[..., rotation.size(1):, :]
point_lo = torch.einsum('ij,...jk->...ik', rotation, points_lo)
return torch.cat((point_lo, points_hi), dim=-2)
else:
return torch.einsum('ij,bj->bi', rotation, points)
return torch.einsum('ij,...jk->...ik', rotation, points)
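For intuition, the new `einsum` pattern above is a matrix product applied along the rotated axis and broadcast over any leading batch dimensions; a small self-contained check (the matrix `R` and tensor `points` below are arbitrary, not taken from the library):

```python
import torch

R = torch.linalg.qr(torch.randn(5, 5)).Q   # an arbitrary 5x5 orthogonal matrix
points = torch.randn(2, 3, 5, 7)           # ... x rotation_size x post_size

out = torch.einsum('ij,...jk->...ik', R, points)
# Same result as broadcasting a matrix product over the last two dimensions.
assert out.shape == (2, 3, 5, 7)
assert torch.allclose(out, R @ points, atol=1e-6)
```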


def rotate_back(points, rotation, total_size):
@@ -124,7 +125,7 @@ def hook(self, g):
self.p.grads[self.item] = g.clone()

@property
def p(self):
def p(self) -> 'RotateOnly':
return self.parent[0]

@property
@@ -136,14 +137,23 @@ def weight(self):
return self.p.weight[self.item] if hasattr(self.p, 'weight') else 1.

def rotate(self, z):
return rotate(z, self.R, self.p.latent_size)
dim_post = -len(self.p.post_shape)
dim_rot = -len(self.p.rotation_shape)
og_shape = z.shape
if dim_post == 0:
z = z.unsqueeze(dim=-1)
dim_post = -1

z = z.flatten(start_dim=dim_post)
z = z.flatten(start_dim=dim_rot - 1, end_dim=-2)

return rotate(z, self.R.detach(), self.p.rotation_size).view(og_shape)
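The flattening here is pure shape bookkeeping: it turns a `... x rotation_shape x post_shape` tensor into `... x rotation_size x post_size` so that `rotate` sees a single rotated axis, and `.view(og_shape)` restores the original layout. A sketch with made-up sizes:

```python
import torch

rotation_shape, post_shape = torch.Size((2, 3)), torch.Size((4, 5))
z = torch.randn(7, 2, 3, 4, 5)   # batch x rotation_shape x post_shape

dim_post, dim_rot = -len(post_shape), -len(rotation_shape)
og_shape = z.shape
z = z.flatten(start_dim=dim_post)                  # 7 x 2 x 3 x 20  (post dims merged)
z = z.flatten(start_dim=dim_rot - 1, end_dim=-2)   # 7 x 6 x 20      (rotation dims merged)
assert z.shape == (7, 6, 20)
# In RotateModule.rotate, a 6x6 rotation is then applied along dim -2
# and .view(og_shape) recovers the original 7 x 2 x 3 x 4 x 5 layout.
```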

def rotate_back(self, z):
return rotate_back(z, self.R, self.p.latent_size)
return rotate_back(z, self.R, self.p.rotation_size)

def forward(self, z):
R = self.R.clone().detach()
new_z = rotate(z, R, self.p.latent_size)
new_z = self.rotate(z)
if self.p.training:
new_z.register_hook(self.hook)

@@ -154,17 +164,27 @@ class RotateOnly(nn.Module):
r"""
Implementation of the rotating part of RotoGrad as described in the original paper. [1]_
    The module takes as input a tensor of shape ... x rotation_shape x post_shape.
Parameters
----------
backbone
Shared module.
heads
Task-specific modules.
latent_size
Size of the shared representation, that is, size of the output of backbone.
rotation_shape
Shape of the shared representation to be rotated, which is usually just the size of the backbone's output.
Passing a shape is useful, for example, if you want to rotate an image with shape width x height.
post_shape : optional, default=()
Shape of the shared representation following the part to be rotated (if any). This part will be kept as it is.
This is useful, for example, if you want to rotate only the channels of an image.
normalize_losses : optional, default=False
Whether to also use these normalized losses to back-propagate through the task-specific parameters.
burn_in_period : optional, default=20
When back-propagating towards the shared parameters, *each task loss is normalized by dividing by its initial
value*, :math:`{L_k(t)}/{L_k(t_0 = 0)}`. This parameter sets a number of iterations after which the denominator
will be replaced by the value of the loss at that iteration, that is, :math:`t_0 = burn\_in\_period`.
This is done to overcome problems with losses quickly changing in the first iterations.
Attributes
----------
@@ -189,10 +209,14 @@ class RotateOnly(nn.Module):
heads: Sequence[nn.Module]
rep: Optional[torch.Tensor]

def __init__(self, backbone: nn.Module, heads: List[nn.Module], latent_size: int, *args,
burn_in_period: int = 20, normalize_losses: bool = False):
def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], rotation_shape: Union[int, torch.Size], *args,
post_shape: torch.Size = (), normalize_losses: bool = False, burn_in_period: int = 20):
super(RotateOnly, self).__init__()
num_tasks = len(heads)
if isinstance(rotation_shape, int):
rotation_shape = torch.Size((rotation_shape,))
assert len(rotation_shape) > 0
rotation_size = reduce(int.__mul__, rotation_shape)

for i in range(num_tasks):
heads[i] = nn.Sequential(RotateModule(self, i), heads[i])
@@ -202,12 +226,14 @@ def __init__(self, backbone: nn.Module, heads: List[nn.Module], latent_size: int

# Parameterize rotations so we can run unconstrained optimization
for i in range(num_tasks):
self.register_parameter(f'rotation_{i}', nn.Parameter(torch.eye(latent_size), requires_grad=True))
self.register_parameter(f'rotation_{i}', nn.Parameter(torch.eye(rotation_size), requires_grad=True))
orthogonal(self, f'rotation_{i}', triv='expm') # uses exponential map (alternative: cayley)

# Parameters
self.num_tasks = num_tasks
self.latent_size = latent_size
self.rotation_shape = rotation_shape
self.rotation_size = rotation_size
self.post_shape = post_shape
self.burn_in_period = burn_in_period
self.normalize_losses = normalize_losses

@@ -342,9 +368,6 @@ def backward(self, losses: Sequence[torch.Tensor], backbone_loss=None, **kwargs)

def _rep_grad(self):
old_grads = self.original_grads # these grads are already rotated, we have to recover the originals
# with torch.no_grad():
# grads = [rotate(g, R) for g, R in zip(grads, self.rotation)]
#
grads = self.grads

# Compute the reference vector
@@ -353,10 +376,22 @@ def _rep_grad(self):
old_grads2 = [g * divide(mean_norm, g.norm(p=2)) for g in old_grads]
mean_grad = sum([g for g in old_grads2]).detach().clone() / len(grads)

dim_post = -len(self.post_shape)
dim_rot = -len(self.rotation_shape)
og_shape = mean_grad.shape
if dim_post == 0:
mean_grad = mean_grad.unsqueeze(dim=-1)
dim_post = -1

mean_grad = mean_grad.flatten(start_dim=dim_post)
mean_grad = mean_grad.flatten(start_dim=dim_rot - 1, end_dim=-2)

for i, grad in enumerate(grads):
R = self.rotation[i]
loss_rotograd = rotate(mean_grad, R, self.latent_size) - grad
loss_rotograd = torch.einsum('bi,bi->b', loss_rotograd, loss_rotograd)
loss_rotograd = rotate(mean_grad, R, self.rotation_size).view(og_shape) - grad
loss_rotograd = loss_rotograd.flatten(start_dim=dim_post)
loss_rotograd = loss_rotograd.flatten(start_dim=dim_rot - 1, end_dim=-2)
loss_rotograd = torch.einsum('...ij,...ij->...', loss_rotograd, loss_rotograd)
loss_rotograd.mean().backward()

return sum(old_grads)
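Read as an equation (in the notation of the docstrings above): let :math:`\bar{g}` be the average of the tasks' original gradients after rescaling them to a common norm, and :math:`g_k` the gradient recorded at task :math:`k`'s rotated representation. The loop then back-propagates the batch-averaged squared error :math:`\lVert R_k \bar{g} - g_k \rVert_2^2` into each rotation :math:`R_k` alone; the `flatten` calls merely bring both tensors to the same `... x rotation_size x post_size` layout used by `rotate`.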
@@ -383,8 +418,12 @@ class RotoGrad(RotateOnly):
Shared module.
heads
Task-specific modules.
latent_size
Size of the shared representation, that is, size of the output of backbone.
rotation_shape
Shape of the shared representation to be rotated, which is usually just the size of the backbone's output.
Passing a shape is useful, for example, if you want to rotate an image with shape width x height.
post_shape : optional, default=()
Shape of the shared representation following the part to be rotated (if any). This part will be kept as it is.
This is useful, for example, if you want to rotate only the channels of an image.
burn_in_period : optional, default=20
When back-propagating towards the shared parameters, *each task loss is normalized by dividing by its initial
value*, :math:`{L_k(t)}/{L_k(t_0 = 0)}`. This parameter sets a number of iterations after which the denominator
@@ -393,7 +432,6 @@ class RotoGrad(RotateOnly):
normalize_losses : optional, default=False
Whether to also use these normalized losses to back-propagate through the task-specific parameters.
Attributes
----------
num_tasks
@@ -403,7 +441,7 @@ class RotoGrad(RotateOnly):
heads
Sequence with the (rotated) task-specific heads.
rep
Current output of the backbone (after calling forward during training).
Current output of the backbone (after calling forward during training).
References
@@ -417,9 +455,10 @@ class RotoGrad(RotateOnly):
heads: Sequence[nn.Module]
rep: torch.Tensor

def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], latent_size: int, *args,
burn_in_period: int = 20, normalize_losses: bool = False):
super().__init__(backbone, heads, latent_size, burn_in_period, *args, normalize_losses=normalize_losses)
def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], rotation_shape: Union[int, torch.Size], *args,
post_shape: torch.Size = (), normalize_losses: bool = False, burn_in_period: int = 20):
super().__init__(backbone, heads, rotation_shape, *args,
post_shape=post_shape, burn_in_period=burn_in_period, normalize_losses=normalize_losses)

self.initial_grads = None
self.counter = 0
@@ -451,18 +490,22 @@ class RotoGradNorm(RotoGrad):
Shared module.
heads
Task-specific modules.
latent_size
Size of the shared representation, that is, size of the output of backbone.
rotation_shape
Shape of the shared representation to be rotated, which is usually just the size of the backbone's output.
Passing a shape is useful, for example, if you want to rotate an image with shape width x height.
alpha
:math:`\alpha` hyper-parameter as described in GradNorm [2]_, used to compute the reference direction.
post_shape : optional, default=()
Shape of the shared representation following the part to be rotated (if any). This part will be kept as it is.
This is useful, for example, if you want to rotate only the channels of an image.
burn_in_period : optional, default=20
When back-propagating towards the shared parameters, *each task loss is normalized by dividing by its initial
value*, :math:`{L_k(t)}/{L_k(t_0 = 0)}`. This parameter sets a number of iterations after which the denominator
will be replaced by the value of the loss at that iteration, that is, :math:`t_0 = burn\_in\_period`.
This is done to overcome problems with losses quickly changing in the first iterations.
normalize_losses : optional, default=False
Whether to also use these normalized losses to back-propagate through the task-specific parameters.
TODO
Attributes
----------
@@ -486,10 +529,10 @@ class RotoGradNorm(RotoGrad):
"""

def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], latent_size: int, *args, alpha: float,
burn_in_period: int = 20, normalize_losses: bool = False):
super().__init__(backbone, heads, latent_size, *args, burn_in_period=burn_in_period,
normalize_losses=normalize_losses)
def __init__(self, backbone: nn.Module, heads: Sequence[nn.Module], rotation_shape: Union[int, torch.Size], *args,
alpha: float, post_shape: torch.Size = (), normalize_losses: bool = False, burn_in_period: int = 20):
super().__init__(backbone, heads, rotation_shape, *args,
post_shape=post_shape, burn_in_period=burn_in_period, normalize_losses=normalize_losses)
self.alpha = alpha
self.weight_ = nn.ParameterList([nn.Parameter(torch.ones([]), requires_grad=True) for _ in range(len(heads))])
