opti_trainer.py
from functools import partial
from typing import Callable, List, Tuple

import jax
import optax
from flax import struct
from flax.training import train_state
from jax.numpy import ndarray


class MetaTrainState(train_state.TrainState):
    adapt_fn: Callable = struct.field(pytree_node=False)
    loss_fn: Callable = struct.field(pytree_node=False)


class OptiTrainer:
    @staticmethod
    def create(params, apply_fn, adapt_fn, loss_fn, tx) -> MetaTrainState:
"""Creates a new MetaTrainState object which is the default object used for training.
Args:
params (flax.core.FrozenDict[str, Any]): Parameters of the model.
apply_fn ((params, x) -> y): Function that applies the model to a batch of data.
adapt_fn ((params, apply_fn, loss_fn, support_set) -> adapted_params): Specific meta learning function
which adapts to the support set.
loss_fn ((logits, targets) -> loss): Loss Function.
tx (Optax Optimizer): Optax optimizer.
Returns:
MetaTrainState: Initialized MetaTrainState object.
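
        Example:
            A minimal sketch of the intended call, assuming a flax.linen model
            ``model``, an input template ``x``, and user-supplied ``maml_adapt``
            and ``cross_entropy`` functions (these names are illustrative, not
            part of this module)::

                import jax
                import optax

                params = model.init(jax.random.PRNGKey(0), x)
                state = OptiTrainer.create(
                    params=params,
                    apply_fn=model.apply,
                    adapt_fn=maml_adapt,
                    loss_fn=cross_entropy,
                    tx=optax.adam(1e-3),
                )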
"""
state = MetaTrainState.create(
params=params, apply_fn=apply_fn, adapt_fn=adapt_fn, loss_fn=loss_fn, tx=tx
)
return state.replace(opt_state=tx.init(params["params"]))

    @staticmethod
    @partial(jax.jit, static_argnums=(2,))
    def meta_train_step(
        state: MetaTrainState,
        tasks,
        # `metrics` is a static jit argument, so it must be hashable: use a tuple,
        # not a list (an unhashable list default would raise inside jax.jit).
        metrics: Tuple[Callable[[ndarray, ndarray], ndarray], ...] = (),
    ) -> Tuple[MetaTrainState, ndarray, List[ndarray]]:
"""Performs a single meta-training step on a batch of tasks.
The fuctions first adapts to the support set and then evaluates it's perfomance
on the query set.
Args:
state (MetaTrainState): Contains information regarding the current state.
tasks ((x_train, y_train), (x_test, y_test)): Batch of tasks to be trained on.
metrics (List[(ndarray, ndarray) -> ndarray]): List of metrics to be evaluated on the query set.
Returns:
Tuple[MetaTrainState, jnp.ndarray, List[jnp.ndarray]]: (Next_State, Loss, metrics).
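
        Example:
            A sketch of one epoch of meta-training, assuming a ``task_loader``
            iterable that yields batched tasks and an ``accuracy`` metric (both
            hypothetical, not provided by this module)::

                for tasks in task_loader:
                    state, loss, (acc,) = OptiTrainer.meta_train_step(
                        state, tasks, metrics=(accuracy,)
                    )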
"""
params = state.params
theta = params["params"]
def batch_meta_train_loss(theta, apply_fn, adapt_fn, loss_fn, tasks):
loss, metrics_value = jax.vmap(
OptiTrainer.meta_loss, in_axes=(None, None, None, None, 0, None)
)(
params.copy({"params": theta}),
apply_fn,
adapt_fn,
loss_fn,
tasks,
metrics,
)
return loss.mean(), [metric.mean() for metric in metrics_value]
(loss, metrics_value), grads = jax.value_and_grad(
batch_meta_train_loss, has_aux=True
)(theta, state.apply_fn, state.adapt_fn, state.loss_fn, tasks)
        # Apply the update manually rather than via state.apply_gradients: the
        # optimizer state was initialized on params["params"], so the gradient
        # transform must be applied to that sub-tree, not the full parameter tree.
        updates, new_opt_state = state.tx.update(
            grads, state.opt_state, state.params["params"]
        )
        new_params = optax.apply_updates(state.params["params"], updates)
        params = state.params.copy({"params": new_params})
        state = state.replace(
            step=state.step + 1, params=params, opt_state=new_opt_state
        )
        return state, loss, metrics_value

    @staticmethod
    @partial(jax.jit, static_argnums=(2,))
    def meta_test_step(
        state: MetaTrainState,
        tasks,
        # Static jit argument: must be hashable, hence a tuple rather than a list.
        metrics: Tuple[Callable[[ndarray, ndarray], ndarray], ...] = (),
    ) -> Tuple[ndarray, List[ndarray]]:
"""Performs a single meta-testing step on a batch of tasks.
The function first adapts to the support set and then evaluates it's perfomance
on the query set.
Args:
state (MetaTrainState): Contains information regarding the current state.
tasks ((x_train, y_train), (x_test, y_test)): Batch of tasks to be evaluated on.
metrics (List[(ndarray, ndarray) -> ndarray]): List of metrics to be evaluated on the query set.
Returns:
Tuple[jnp.ndarray, List[jnp.ndarray]: (Loss, metrics).
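
        Example:
            A sketch of evaluation over held-out tasks, assuming a hypothetical
            ``val_task_loader`` and the same illustrative ``accuracy`` metric::

                for tasks in val_task_loader:
                    loss, (acc,) = OptiTrainer.meta_test_step(
                        state, tasks, metrics=(accuracy,)
                    )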
"""
params = state.params
apply_fn = state.apply_fn
loss_fn = state.loss_fn
adapt_fn = state.adapt_fn
loss, metrics_value = jax.vmap(
OptiTrainer.meta_loss, in_axes=(None, None, None, None, 0, None)
)(params, apply_fn, adapt_fn, loss_fn, tasks, metrics)
return loss.mean(), [metric.mean() for metric in metrics_value]

    @staticmethod
    def meta_loss(
        params, apply_fn, adapt_fn, loss_fn, task, metrics
    ) -> Tuple[ndarray, List[ndarray]]:
"""Calculates the Meta Loss of a task
Args:
params (flax.core.FrozenDict[str, Any]): Parameters of the model.
apply_fn ((params, x) -> y): Function that applies the model to a batch of data.
adapt_fn ((params, apply_fn, loss_fn, support_set) -> adapted_params): Specific meta learning function
which adapts to the support set.
loss_fn ((logits, targets) -> loss): Loss Function.
tasks ((x_train, y_train), (x_test, y_test)): Batch of tasks to be trained on
Returns:
Tuple[jnp.ndarray, List[jnp.ndarray]: (Loss, metrics).
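
        Example:
            A minimal MAML-style ``adapt_fn`` satisfying the expected signature:
            one inner gradient step with a hand-rolled SGD update. The inner
            learning rate of 0.1 and the ``train=True`` flag are illustrative
            assumptions, not requirements of this module::

                def adapt_fn(params, apply_fn, loss_fn, support_set):
                    x_s, y_s = support_set

                    def inner_loss(theta):
                        logits = apply_fn(params.copy({"params": theta}), x_s, train=True)
                        return loss_fn(logits, y_s)

                    grads = jax.grad(inner_loss)(params["params"])
                    return jax.tree_util.tree_map(
                        lambda p, g: p - 0.1 * g, params["params"], grads
                    )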
"""
support_set, query_set = task
# Adaptation step
theta = adapt_fn(params, apply_fn, loss_fn, support_set)
# Evaluation step
x_train, y_train = query_set
logits = apply_fn(params.copy({"params": theta}), x_train, train=False)
# Calculate metrics
metrics_value = [metric(logits, y_train) for metric in metrics]
return loss_fn(logits, y_train), metrics_value