forked from yang-song/score_sde_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 4
/
grad_estimator.py
28 lines (22 loc) · 1.06 KB
/
grad_estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import pytorch_lightning as pl
from model_lightning import SdeGenerativeModel
from utils import plot
class SdeGenerativeModel_GradientEstimation(SdeGenerativeModel):
    """SDE generative model that also logs a plot of gradient norms over time.

    On top of the regular ``SdeGenerativeModel.training_step``, this class
    samples discrete time indices in ``[0, self.sde.N)``, evaluates
    ``self.compute_grad(self.score_model, batch, t=...)`` at each of them and
    logs a plot of the resulting gradient norms as an image, once per epoch.
    """

    # Number of time indices sampled (with replacement) from [0, self.sde.N)
    # for each gradient-norm plot.
    num_grad_times = 100
    # Batch index within each epoch on which the plot is produced. The image is
    # logged under ``self.current_epoch``, so producing it on more than one
    # batch per epoch would only add duplicate entries at the same step while
    # repeating ``num_grad_times`` gradient evaluations for every batch.
    grad_log_batch_idx = 0

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)

    def training_step(self, batch, batch_idx):
        """Run the parent training step and, once per epoch, log gradient norms.

        Args:
            batch: batch forwarded to the parent ``training_step``. This method
                also reads ``batch.shape[0]`` and ``batch.device``, so it must
                be a tensor whose first dimension is the batch size.
            batch_idx: index of ``batch`` within the current epoch.

        Returns:
            The loss returned by ``SdeGenerativeModel.training_step``.
        """
        loss = super().training_step(batch, batch_idx)
        if batch_idx != self.grad_log_batch_idx:
            return loss

        batch_size = batch.shape[0]
        sampled = torch.randint(
            0, self.sde.N, (self.num_grad_times,), device=batch.device
        )
        # Sort so the plotted x values are monotonic; the sampled set is unchanged.
        times, _ = torch.sort(sampled)

        norms = []
        for t in times:
            # t is a 0-d tensor; repeat it once per sample in the batch.
            labels = t.repeat(batch_size)
            gradients = self.compute_grad(self.score_model, batch, t=labels)
            # NOTE(review): 2-norm over dim=1, then the max over all remaining
            # entries; assumes dim=1 is the intended reduction axis — confirm.
            # detach() so no autograd graph is kept alive across iterations.
            norms.append(gradients.norm(2, dim=1).max().detach())
        # Single device->host transfer instead of one .item() sync per time index.
        grad_norm_t = torch.stack(norms).tolist()

        image = plot(
            times.cpu().numpy(),
            grad_norm_t,
            'Gradient Norms Epoch: ' + str(self.current_epoch),
        )
        self.logger.experiment.add_image('grad_norms', image, self.current_epoch)
        return loss