# Classification: EWC helpers (diagonal Fisher estimate and penalty)

In [None]:
def get_fisher_diag(model, dataset, params, empirical=True):
  """Estimate the diagonal of the Fisher information matrix for a classifier.

  For every dataset item, the negative log-likelihood of a label under the
  model's softmax output is back-propagated. The squared parameter gradients
  are averaged over ``len(dataset)`` items.

  Args:
    model: torch.nn.Module. Its output for one item is reshaped to
      ``(1, -1)``, so each dataset item is treated as a single sample
      (assumes the model returns one row of class logits per item -- TODO
      confirm for batched datasets).
    dataset: sized iterable of ``(input, gt_label)`` pairs. ``gt_label`` may
      be a class-index tensor or a plain int; it is flattened to shape (1,).
    params: mapping ``name -> tensor`` selecting which parameters to track.
      Only the keys and each tensor's shape/dtype/device are used. Names
      may be a subset of ``model.named_parameters()``.
    empirical: if True, use the ground-truth label (empirical Fisher).
      Otherwise use the model's own arg-max prediction as the label.

  Returns:
    dict ``name -> tensor`` shaped like ``params[name]`` holding the averaged
    squared gradients. The tensors do not require grad, so they act as
    constants in an EWC penalty. A tracked parameter that never receives a
    gradient (e.g. frozen) stays all zeros.

  Side effects:
    Puts ``model`` in eval mode. Leaves the last item's gradients in the
    parameters' ``.grad`` fields.
  """
  # Zero accumulators matching each tracked parameter; no deepcopy is needed
  # and no autograd history is attached.
  fisher = {n: torch.zeros_like(p.detach()) for n, p in params.items()}
  n_samples = len(dataset)

  model.eval()
  for inputs, gt_label in dataset:
    model.zero_grad()
    output = model(inputs).view(1, -1)
    if empirical:
      # Accept ints / 0-d tensors and match the logits' device.
      label = torch.as_tensor(gt_label, device=output.device).view(-1)
    else:
      label = output.max(1)[1].view(-1)

    log_probs = torch.nn.functional.log_softmax(output, dim=1)
    negloglikelihood = torch.nn.functional.nll_loss(log_probs, label)
    negloglikelihood.backward()

    with torch.no_grad():
      for n, p in model.named_parameters():
        # Skip untracked parameters and ones this forward pass did not reach.
        if n in fisher and p.grad is not None:
          fisher[n] += p.grad ** 2 / n_samples

  return fisher

def get_ewc_loss(model, fisher, p_old):
  """Return the unscaled EWC quadratic penalty for ``model``.

  Computes ``sum_n sum(fisher[n] * (p_n - p_old[n]) ** 2)`` over every named
  parameter ``p_n`` of the model. ``fisher`` and ``p_old`` must have an
  entry for every parameter name; a missing name raises ``KeyError``.
  Returns the int ``0`` when the model has no parameters.
  """
  return sum(
      (fisher[name] * (param - p_old[name]) ** 2).sum()
      for name, param in model.named_parameters()
  )


# Text Generation: EWC helpers for token-level (sequence) outputs

In [None]:
def get_fisher_diag(model, dataset, params, empirical=True):
  """Estimate the diagonal Fisher information for a token-level model.

  Every leading dimension of the model output is flattened, so each output
  position becomes one classification row (``output.size(-1)`` classes).
  The mean cross-entropy over those rows is back-propagated per dataset item.
  The squared parameter gradients are averaged over ``len(dataset)`` items.

  Args:
    model: torch.nn.Module whose output's last dimension holds the logits
      (presumably vocabulary scores -- confirm against the model).
    dataset: sized iterable of ``(input, target)`` pairs. ``target`` must
      flatten to one class index per output row.
    params: mapping ``name -> tensor`` selecting which parameters to track.
      Only the keys and each tensor's shape/dtype/device are used. Names
      may be a subset of ``model.named_parameters()``.
    empirical: if True, use ``target`` as labels (empirical Fisher).
      Otherwise use the model's own per-row arg-max prediction.

  Returns:
    dict ``name -> tensor`` shaped like ``params[name]`` holding the averaged
    squared gradients. The tensors do not require grad, so they act as
    constants in an EWC penalty. A tracked parameter that never receives a
    gradient (e.g. frozen) stays all zeros.

  Side effects:
    Puts ``model`` in eval mode. Leaves the last item's gradients in the
    parameters' ``.grad`` fields.
  """
  # Zero accumulators matching each tracked parameter; no deepcopy is needed
  # and no autograd history is attached.
  fisher = {n: torch.zeros_like(p.detach()) for n, p in params.items()}
  n_samples = len(dataset)

  model.eval()
  for inputs, target in dataset:
    model.zero_grad()
    logits = model(inputs)
    # reshape (not view) also works for non-contiguous outputs.
    logits = logits.reshape(-1, logits.size(-1))
    if empirical:
      label = target.reshape(-1)
    else:
      label = torch.argmax(logits, dim=1)

    cross_entropy_loss = torch.nn.functional.cross_entropy(logits, label)
    cross_entropy_loss.backward()

    with torch.no_grad():
      for n, p in model.named_parameters():
        # Skip untracked parameters and ones this forward pass did not reach.
        if n in fisher and p.grad is not None:
          fisher[n] += p.grad ** 2 / n_samples

  return fisher

def get_ewc_loss(model, fisher, p_old):
  """Compute the unscaled elastic-weight-consolidation penalty.

  Each named parameter of ``model`` contributes the sum of
  ``fisher[name] * (param - p_old[name]) ** 2``. Both mappings must cover
  every parameter name, otherwise ``KeyError`` is raised. With a
  parameter-less model the result is the int ``0``.
  """
  terms = [
      (fisher[name] * (weight - p_old[name]) ** 2).sum()
      for name, weight in model.named_parameters()
  ]
  return sum(terms)