Skip to content

Commit 5da73ca

Browse files
committed
Fix naming of low rank to full rank
1 parent 2d37978 commit 5da73ca

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
parser.add_argument('--gamma', type=float, required=True)
4040

4141
parser.add_argument('--ode', type=str, choices=['first_order', 'second_order'], required=True)
42-
parser.add_argument('--parameterization', type=str, choices=['diagonal', 'low_rank'], required=True)
42+
parser.add_argument('--parameterization', type=str, choices=['diagonal', 'full_rank'], required=True)
4343

4444
# parameters of Q
4545
parser.add_argument('--num_gaussians', type=int, default=1, help="Number of gaussians in the mixture model.")

training/qsetup.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic:
7575

7676
def construct(system: System, model: Optional[nn.module], xi: float, A: ArrayLike, B: ArrayLike,
7777
args: argparse.Namespace) -> QSetup:
78-
from training.setups import diagonal, lowrank
78+
from training.setups import diagonal, full
7979

8080
transform = None
8181
if args.internal_coordinates:
@@ -100,15 +100,15 @@ def construct(system: System, model: Optional[nn.module], xi: float, A: ArrayLik
100100
model, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
101101
)
102102
return diagonal.DiagonalSetup(system, model, xi, args.ode, args.T)
103-
elif args.parameterization == 'low_rank':
103+
elif args.parameterization == 'full_rank':
104104
if args.model == 'spline':
105-
model = lowrank.LowRankSpline(
105+
model = full.FullRankSpline(
106106
args.num_points, args.spline_mode, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
107107
)
108108
else:
109-
model = lowrank.LowRankWrapper(
109+
model = full.FullRankWrapper(
110110
model, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
111111
)
112-
return lowrank.LowRankSetup(system, model, xi, args.ode, args.T)
112+
return full.FullRankSetup(system, model, xi, args.ode, args.T)
113113
else:
114114
raise ValueError(f"Unknown parameterization: {args.parameterization}")

training/setups/lowrank.py renamed to training/setups/full.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _dSigmadt_batched(_S_t, _dSdt):
2424
_matmul_batched = jax.vmap(_matmul_batched, in_axes=1, out_axes=1)
2525

2626

27-
class LowRankSpline(nn.Module):
27+
class FullRankSpline(nn.Module):
2828
n_points: int
2929
interpolation: str
3030
T: float
@@ -37,8 +37,8 @@ class LowRankSpline(nn.Module):
3737

3838
@nn.compact
3939
def __call__(self, t):
40-
print("WARNING: Mixtures for low rank not yet implemented!")
41-
assert self.num_mixtures == 1, "Mixtures for low rank not yet implemented!"
40+
print("WARNING: Mixtures for full rank spline not yet implemented!")
41+
assert self.num_mixtures == 1, "Mixtures for full rank not yet implemented!"
4242

4343
ndim = self.A.shape[0]
4444
t = t / self.T
@@ -80,7 +80,7 @@ def get_tril(v):
8080
return out
8181

8282

83-
class LowRankWrapper(WrappedModule):
83+
class FullRankWrapper(WrappedModule):
8484
A: ArrayLike
8585
B: ArrayLike
8686
num_mixtures: int
@@ -126,7 +126,7 @@ def get_tril(v):
126126
return mu, S, w_logits
127127

128128

129-
class LowRankSetup(DriftedSetup):
129+
class FullRankSetup(DriftedSetup):
130130
def construct_loss(self, state_q: TrainState, gamma: float, BS: int) -> Callable[
131131
[Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
132132
def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike) -> ArrayLike:

0 commit comments

Comments
 (0)