In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
from diffusion import DiffusionRunner
from ddpm_sde import DDPM_SDE
from default_mnist_config import create_default_mnist_config

In [143]:
from ml_collections import ConfigDict
import torch

from torchvision.transforms import (
    Resize,
    Normalize,
    Compose,
)

class ClassGuidDiffusionRunner(DiffusionRunner):
    def __init__(self, config: ConfigDict, eval: bool = False):
        super().__init__(config, eval)
        self.classifier = self.get_classifier()
        self.mnist_transforms = Compose([
                                        Resize((config.data.image_size, config.data.image_size)),
                                        Normalize(mean=config.data.norm_mean, std=config.data.norm_std),
                                        ]) 
    def get_classifier(self):
        class_model = self.config.class_guide.model
        classifier = class_model(**self.config.class_guide.classifier_args)
        classifier.load_state_dict(torch.load(self.config.class_guide.checkpoint_path))
        classifier.to(self.config.device)
        return classifier
      
    # def classifier_score(self, 
    #                      input_x, 
    #                      y,
    #                      criterion=torch.nn.CrossEntropyLoss()):
        
    #     # target_class_index = torch.ones(input_x.shape[0]).to(self.config.device) 
    #     # target_class_index *= y 
    #     # target_class_index = target_class_index.to(self.config.device) 
    #     target_class_index =  torch.ones(input_x.shape[0]).to(self.config.device) * torch.tensor([y]).to(self.config.device)

    #     input_data  = self.mnist_transforms(input_x)
    #     input_data  = input_data.to(self.config.device)
    #     input_data.requires_grad = True

    #     output = self.classifier(input_data)
    #     loss = criterion(output, target_class_index)
    #     loss.backward()
    #     gradients = input_data.grad
    #     return gradients
                
    # def classifier_score(self, input_data, class_index):
    #     self.classifier.eval()  # Убедитесь, что модель в режиме оценки (не обучения)

    #     input_data.requires_grad = True  # Позволяет вычислить градиенты для входных данных
    #     output = self.classifier(input_data)

    #     # Убедитесь, что class_index соответствует одному из выходов модели
    #     assert 0 <= class_index < output.size(1), f"Недопустимый class_index: {class_index}"

    #     # Обнуляем градиенты перед обратным распространением ошибки
    #     self.classifier.zero_grad()

    #     # Создаем тензор с единичным значением в нужном классе
    #     target = torch.zeros_like(output)
    #     target[:, class_index] = 1

    #     # Вычисляем функцию потерь (например, кросс-энтропию) между предсказаниями и целевыми значениями
    #     loss_fn = torch.nn.CrossEntropyLoss()
    #     loss = loss_fn(output, target.argmax(dim=1))

    #     # Обратное распространение ошибки для вычисления градиентов входных данных
    #     loss.backward()

    #     # Получаем градиенты входных данных
    #     gradients = input_data.grad

    #     return gradients
    

    def classifier_score(self, input_data, class_index):
        self.classifier.eval()
        self.classifier.zero_grad()
        with torch.enable_grad():
            input_data.requires_grad = True
            output = self.classifier(input_data)
            output_for_class = output[:, class_index]

            # output_for_class.backward()
            output_for_class.backward(torch.ones_like(output_for_class))
            gradients  = input_data.grad

        return gradients


        
    def calc_score(self, input_x: torch.Tensor, input_t: torch.Tensor, y=None, gamma=3) -> torch.Tensor:
        """
        calculate score w.r.t noisy X and t
        input:
            input_x - noizy image
            input_t - time label
        algorithm:
            1) predict noize via DDPM
            2) calculate std of input_x
            3) calculate score = -pred_noize / std
        """
        eps = self.model(input_x, input_t)
        
        std = self.sde.marginal_std(input_t)
        std = std.view(-1, 1, 1, 1)
        ddpm_score = -eps / std

        classifier_score = self.classifier_score(input_x, y)

        score = ddpm_score + gamma * classifier_score

        return {"score" : score, 
                "noise" : eps}

In [144]:
config = create_default_mnist_config()
runner = ClassGuidDiffusionRunner(config)

In [145]:
input_data  = torch.randn(2, 1, 32, 32).to("cuda")
input_data.requires_grad = True

In [146]:
runner.classifier_score(input_data, 1)

tensor([[[[-2.5350e-02, -1.1682e-03, -1.5188e-02,  ...,  3.2827e-04,
            3.7865e-04,  3.4775e-03],
          [ 3.8308e-03, -6.7694e-02,  2.1623e-03,  ...,  7.1501e-03,
           -1.5663e-02, -7.4589e-03],
          [-1.9547e-02, -5.5258e-03, -1.0999e-02,  ..., -7.6489e-03,
            2.3632e-03, -9.4181e-04],
          ...,
          [ 8.7392e-03,  2.5417e-03, -1.0556e-03,  ...,  4.7557e-03,
            3.0294e-03, -1.4760e-02],
          [-1.4125e-02, -9.6714e-04,  3.1654e-03,  ...,  2.8685e-03,
           -2.1903e-02, -2.1055e-03],
          [ 1.2173e-03,  1.2521e-02,  5.3749e-03,  ..., -8.3691e-03,
           -3.4617e-03, -2.7530e-03]]],


        [[[-7.8555e-03,  4.3064e-03, -1.5587e-02,  ...,  3.0190e-03,
           -9.8673e-03,  9.0403e-04],
          [ 1.8961e-02,  6.5144e-03, -1.8270e-03,  ...,  1.6347e-02,
            3.5795e-02, -1.8068e-02],
          [-3.9551e-02, -4.9772e-03, -4.6893e-02,  ..., -3.9127e-03,
           -2.5773e-02, -8.0955e-03],
          ...,
   

In [131]:
# output = runner.classifier(input_data)
# output_for_class = output[:, 1]

# # output_for_class.backward()
# output_for_class.backward(torch.ones_like(output_for_class))

# grad = input_data.grad

In [126]:
# torch.ones_like(output_for_class)

tensor([1., 1.], device='cuda:0')

In [133]:
input_data  = torch.randn(2, 1, 32, 32).to("cuda")
t = torch.Tensor([0.1]).to("cuda")

In [7]:
runner.calc_score(input_data, t, y=1)

{'score': tensor([[[[ 0.0075,  0.0100,  0.0043,  ..., -0.0155, -0.0097,  0.0011],
           [ 0.0077,  0.0227, -0.0048,  ..., -0.0002,  0.0073, -0.0008],
           [ 0.0131,  0.0001, -0.0157,  ...,  0.0150,  0.0117,  0.0131],
           ...,
           [ 0.0137, -0.0116, -0.0170,  ...,  0.0019,  0.0059,  0.0038],
           [ 0.0064,  0.0042,  0.0025,  ...,  0.0089,  0.0015, -0.0007],
           [-0.0015, -0.0036,  0.0003,  ..., -0.0141, -0.0105, -0.0048]]]],
        device='cuda:0', grad_fn=<AddBackward0>),
 'noise': tensor([[[[ 1.1030e-06, -8.0282e-07,  6.2809e-06,  ...,  6.3142e-07,
            -3.8763e-06,  2.8395e-06],
           [ 2.8132e-06,  6.1866e-06,  5.1360e-06,  ...,  8.7809e-06,
            -1.8317e-07, -3.3796e-06],
           [ 5.9697e-06,  6.8061e-06,  3.0741e-05,  ..., -5.0512e-06,
            -2.0408e-06, -5.0607e-06],
           ...,
           [ 1.5850e-06,  2.7316e-06, -2.6919e-06,  ...,  1.8231e-05,
             5.2683e-06,  6.6253e-06],
           [-1.7719e-05

In [8]:
# runner.classifier

In [66]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import (
    Resize,
    Normalize,
    Compose,
)

mnist_transforms = Compose([
        Resize((config.data.image_size, config.data.image_size)),
        Normalize(mean=config.data.norm_mean, std=config.data.norm_std),
    ])

In [92]:
criterion = torch.nn.CrossEntropyLoss()
target_class_index = 2
target_class_index = torch.tensor([target_class_index]).to(config.device) * torch.ones(2).to("cuda")

In [93]:
target_class_index = target_class_index.unsqueeze(-1)
target_class_index.shape

torch.Size([2, 1])

In [94]:
input_data  = torch.randn(2, 1, 32, 32)
input_data  = mnist_transforms(input_data)
input_data  = input_data.to(config.device)
input_data.requires_grad = True
input_data.shape



torch.Size([2, 1, 32, 32])

In [95]:
output = runner.classifier(input_data)

In [96]:
loss = criterion(output, target_class_index)
loss.backward()
gradients = input_data.grad

RuntimeError: 0D or 1D target tensor expected, multi-target not supported

In [97]:
def get_input_gradients(model, input_data, class_index):
    model.eval()  # Убедитесь, что модель в режиме оценки (не обучения)

    input_data.requires_grad = True  # Позволяет вычислить градиенты для входных данных
    output = model(input_data)

    # Убедитесь, что class_index соответствует одному из выходов модели
    assert 0 <= class_index < output.size(1), f"Недопустимый class_index: {class_index}"

    # Обнуляем градиенты перед обратным распространением ошибки
    model.zero_grad()

    # Создаем тензор с единичным значением в нужном классе
    target = torch.zeros_like(output)
    target[:, class_index] = 1

    # Вычисляем функцию потерь (например, кросс-энтропию) между предсказаниями и целевыми значениями
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(output, target.argmax(dim=1))

    # Обратное распространение ошибки для вычисления градиентов входных данных
    loss.backward()

    # Получаем градиенты входных данных
    gradients = input_data.grad

    return gradients

In [98]:
get_input_gradients(runner.classifier, input_data, 1)

tensor([[[[-8.1677e-03, -3.0676e-02, -3.3568e-02,  ...,  3.2892e-02,
            1.7086e-02,  2.6695e-03],
          [-1.9700e-02,  6.2581e-03,  1.4219e-02,  ..., -6.5805e-02,
           -4.2931e-03,  2.2726e-02],
          [ 1.6743e-02,  3.0358e-03,  7.6520e-03,  ..., -6.2832e-02,
           -7.3870e-02, -4.2718e-02],
          ...,
          [ 1.7750e-03,  1.6841e-02,  9.7737e-04,  ...,  1.9416e-03,
            9.2555e-03, -7.8101e-03],
          [-6.9648e-04, -6.2951e-03,  3.3439e-03,  ...,  1.4104e-03,
            1.4978e-03,  1.0119e-04],
          [ 4.7081e-03,  5.0267e-03,  8.1728e-03,  ..., -3.4762e-03,
           -5.9667e-03, -3.5863e-03]]],


        [[[-8.3743e-03, -1.4829e-02, -1.2600e-03,  ...,  7.1794e-04,
           -1.6119e-03, -2.0831e-03],
          [ 4.3379e-03,  5.4015e-03, -2.7986e-02,  ..., -2.1289e-02,
            2.0460e-03,  6.5354e-03],
          [ 2.4116e-03, -1.7982e-02,  1.1622e-02,  ..., -3.8050e-02,
           -4.5040e-03,  2.1441e-03],
          ...,
   

In [122]:
gradients, gradients.shape

(tensor([[[[-0.0004,  0.0023, -0.0016,  ..., -0.0056, -0.0049, -0.0008],
           [ 0.0022,  0.0023,  0.0039,  ...,  0.0033,  0.0054, -0.0027],
           [-0.0024,  0.0026, -0.0035,  ...,  0.0038, -0.0006, -0.0044],
           ...,
           [-0.0022, -0.0123, -0.0042,  ..., -0.0027,  0.0029,  0.0013],
           [-0.0057, -0.0070, -0.0078,  ..., -0.0020, -0.0015, -0.0016],
           [-0.0004, -0.0058, -0.0045,  ..., -0.0007, -0.0012, -0.0010]]]],
        device='cuda:0'),
 torch.Size([1, 1, 32, 32]))

In [70]:
output[0, 1].backward()

In [71]:
output

tensor([[-2.3596, -2.3740, -1.6151, -1.9797, -2.3420, -1.5721, -2.7542, -2.2944,
         -1.2308, -1.7510]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [5]:
classifier = runner.get_classifier()

In [6]:
import torch

In [7]:
classifier.load_state_dict(torch.load(runner.config.class_guide.checkpoint_path))

<All keys matched successfully>

In [30]:
dict(runner.config.class_guide.classifier_args)

{'block': models.classifier.ResidualBlock, 'layers': [2, 2, 2, 2]}