Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,12 @@ def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
expert_batch = self._demo_buffer.sample_mini_batch(
mini_batch.num_experiences, 1
)
loss, policy_mean_estimate, expert_mean_estimate, kl_loss = self._discriminator_network.compute_loss(
loss, stats_dict = self._discriminator_network.compute_loss(
mini_batch, expert_batch
)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
stats_dict = {
"Losses/GAIL Discriminator Loss": loss.detach().cpu().numpy(),
"Policy/GAIL Policy Estimate": policy_mean_estimate.detach().cpu().numpy(),
"Policy/GAIL Expert Estimate": expert_mean_estimate.detach().cpu().numpy(),
}
if self._discriminator_network.use_vail:
stats_dict["Policy/GAIL Beta"] = (
self._discriminator_network.beta.detach().cpu().numpy()
)
stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
return stats_dict


Expand All @@ -76,7 +66,7 @@ class DiscriminatorNetwork(torch.nn.Module):
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
super().__init__()
self._policy_specs = specs
self.use_vail = settings.use_vail
self._use_vail = settings.use_vail
self._settings = settings

state_encoder_settings = NetworkSettings(
Expand Down Expand Up @@ -108,20 +98,20 @@ def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
estimator_input_size = settings.encoding_size
if settings.use_vail:
estimator_input_size = self.z_size
self.z_sigma = torch.nn.Parameter(
self._z_sigma = torch.nn.Parameter(
torch.ones((self.z_size), dtype=torch.float), requires_grad=True
)
self.z_mu_layer = linear_layer(
self._z_mu_layer = linear_layer(
settings.encoding_size,
self.z_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=0.1,
)
self.beta = torch.nn.Parameter(
self._beta = torch.nn.Parameter(
torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
)

self.estimator = torch.nn.Sequential(
self._estimator = torch.nn.Sequential(
linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
)

Expand Down Expand Up @@ -166,9 +156,9 @@ def compute_estimate(
hidden = self.encoder(encoder_input)
z_mu: Optional[torch.Tensor] = None
if self._settings.use_vail:
z_mu = self.z_mu_layer(hidden)
hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
estimate = self.estimator(hidden)
z_mu = self._z_mu_layer(hidden)
hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
estimate = self._estimator(hidden)
return estimate, z_mu

def compute_loss(
Expand All @@ -177,41 +167,53 @@ def compute_loss(
"""
Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
"""
total_loss = torch.zeros(1)
stats_dict: Dict[str, np.ndarray] = {}
policy_estimate, policy_mu = self.compute_estimate(
policy_batch, use_vail_noise=True
)
expert_estimate, expert_mu = self.compute_estimate(
expert_batch, use_vail_noise=True
)
loss = -(
torch.log(expert_estimate * (1 - self.EPSILON))
+ torch.log(1.0 - policy_estimate * (1 - self.EPSILON))
stats_dict["Policy/GAIL Policy Estimate"] = (
policy_estimate.mean().detach().cpu().numpy()
)
stats_dict["Policy/GAIL Expert Estimate"] = (
expert_estimate.mean().detach().cpu().numpy()
)
discriminator_loss = -(
torch.log(expert_estimate + self.EPSILON)
+ torch.log(1.0 - policy_estimate + self.EPSILON)
).mean()
kl_loss: Optional[torch.Tensor] = None
stats_dict["Losses/GAIL Loss"] = discriminator_loss.detach().cpu().numpy()
total_loss += discriminator_loss
if self._settings.use_vail:
# KL divergence loss (encourage latent representation to be normal)
kl_loss = torch.mean(
-torch.sum(
1
+ (self.z_sigma ** 2).log()
+ (self._z_sigma ** 2).log()
- 0.5 * expert_mu ** 2
- 0.5 * policy_mu ** 2
- (self.z_sigma ** 2),
- (self._z_sigma ** 2),
dim=1,
)
)
vail_loss = self.beta * (kl_loss - self.mutual_information)
vail_loss = self._beta * (kl_loss - self.mutual_information)
with torch.no_grad():
self.beta.data = torch.max(
self.beta + self.alpha * (kl_loss - self.mutual_information),
self._beta.data = torch.max(
self._beta + self.alpha * (kl_loss - self.mutual_information),
torch.tensor(0.0),
)
loss += vail_loss
total_loss += vail_loss
stats_dict["Policy/GAIL Beta"] = self._beta.detach().cpu().numpy()
stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
if self.gradient_penalty_weight > 0.0:
loss += self.gradient_penalty_weight * self.compute_gradient_magnitude(
policy_batch, expert_batch
total_loss += (
self.gradient_penalty_weight
* self.compute_gradient_magnitude(policy_batch, expert_batch)
)
return loss, torch.mean(policy_estimate), torch.mean(expert_estimate), kl_loss
return total_loss, stats_dict

def compute_gradient_magnitude(
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
Expand Down Expand Up @@ -243,9 +245,9 @@ def compute_gradient_magnitude(
hidden = self.encoder(encoder_input)
if self._settings.use_vail:
use_vail_noise = True
z_mu = self.z_mu_layer(hidden)
hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
hidden = self.estimator(hidden)
z_mu = self._z_mu_layer(hidden)
hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
hidden = self._estimator(hidden)
estimate = torch.mean(torch.sum(hidden, dim=1))
gradient = torch.autograd.grad(estimate, encoder_input)[0]
# Norm's gradient could be NaN at 0. Use our own safe_norm
Expand Down