In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class WPertLinear(nn.Module):
    """Affine layer that propagates a value and a tangent side by side.

    The layer holds a weight ``w`` of shape ``(n, m)`` and a bias ``b`` of
    shape ``(m,)`` together with random perturbation directions ``_w`` and
    ``_b`` of the same shapes.  Every forward pass draws fresh directions and
    returns both ``x @ w + b`` and its first-order change along
    ``(_x, _w, _b)``.

    NOTE: ``w``, ``b``, ``_w`` and ``_b`` are plain tensors, not
    ``nn.Parameter``/buffers, so ``parameters()``, ``state_dict()`` and
    ``Module.to()`` do not see them.
    """

    @torch.no_grad()
    def __init__(self, n: int, m: int) -> None:
        """Create the layer.

        Args:
            n: Size of the last dimension of the input.
            m: Size of the last dimension of the output.
        """
        # nn.Module must be initialised before the instance can be called;
        # without this, ``layer(x, _x)`` fails inside ``Module.__call__``.
        super().__init__()
        self.w = torch.randn(n, m)
        self.b = torch.randn(m)
        self._w = torch.randn(n, m)
        self._b = torch.randn(m)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, _x: torch.Tensor):
        """Compute the output and its tangent.

        Args:
            x: Input whose last dimension has size ``n``.
            _x: Tangent of ``x``; same shape as ``x``.

        Returns:
            A pair ``(y, _y)`` where ``y = x @ w + b`` and
            ``_y = x @ _w + _x @ w + _b``.
        """
        # Draw a new standard-normal direction for every forward pass.
        self._w = torch.randn_like(self._w)
        self._b = torch.randn_like(self._b)
        y = torch.matmul(x, self.w) + self.b
        # Product rule applied to x @ w, plus the bias direction.
        _y = torch.matmul(x, self._w) + torch.matmul(_x, self.w) + self._b
        return y, _y

    @torch.no_grad()
    def perturb(self, _objective) -> None:
        """Write the signed perturbation direction into ``.grad``.

        Args:
            _objective: Number or tensor; presumably the change of the
                objective along the last drawn direction (verify against
                the caller).  Only its sign is used.
        """
        # +1 when the objective value is positive, -1 otherwise.
        sign = (_objective > 0) * 2.0 - 1
        self.w.grad = self._w * sign
        self.b.grad = self._b * sign

class WPertReLU(nn.Module):
    """ReLU that propagates a value and a tangent side by side.

    The layer has no weights.  It remembers the 0/1 mask of the last input
    (``_m``) and the last output (``y``) so that ``perturb`` can attach a
    signed mask to that output.
    """

    @torch.no_grad()
    def __init__(self) -> None:
        super().__init__()
        self._m = 0.0
        # Last activation returned by forward(); None until forward() runs.
        self.y = None

    @torch.no_grad()
    def forward(self, x: torch.Tensor, _x: torch.Tensor):
        """Apply ReLU to ``x`` and mask the tangent accordingly.

        Args:
            x: Input tensor.
            _x: Tangent of ``x``; must broadcast with ``x``.

        Returns:
            A pair ``(y, _y)`` with ``y = relu(x)`` and ``_y = _m * _x``,
            where ``_m`` is 1.0 where ``x >= 0`` and 0.0 elsewhere.
        """
        self._m = (x >= 0) * 1.0
        # Keep the output so perturb() has a tensor to write .grad on.
        self.y = x.relu()
        return self.y, self._m * _x

    @torch.no_grad()
    def perturb(self, _objective) -> None:
        """Store the signed activation mask in ``y.grad``.

        Does nothing when ``forward`` has not been called yet.

        Args:
            _objective: Number or tensor; only its sign is used
                (+1 when positive, -1 otherwise).
        """
        if self.y is None:
            return
        sign = (_objective > 0) * 2.0 - 1
        # .grad must have the dtype of the tensor it is attached to.
        self.y.grad = (self._m * sign).to(self.y.dtype)

class WPertConv2d(nn.Module):
    """2-D convolution that propagates a value and a tangent side by side.

    Holds a kernel of shape ``(out_channels, in_channels, kh, kw)`` and a
    bias of shape ``(out_channels,)`` plus random perturbation directions
    ``_kernel`` and ``_bias`` of the same shapes.  Stride, padding and
    dilation are left at the ``F.conv2d`` defaults.

    NOTE: all four tensors are plain attributes, not ``nn.Parameter``/buffers,
    so ``parameters()``, ``state_dict()`` and ``Module.to()`` do not see them.
    """

    @torch.no_grad()
    def __init__(self, kernel_size, in_channels: int, out_channels: int) -> None:
        """Create the layer.

        Args:
            kernel_size: Either a single int (square kernel) or an iterable
                ``(kh, kw)``.
            in_channels: Number of input channels.
            out_channels: Number of output channels.
        """
        super().__init__()
        if isinstance(kernel_size, int):
            # A bare int cannot be star-unpacked; treat it as a square kernel.
            kernel_size = (kernel_size, kernel_size)
        self.kernel = torch.randn(out_channels, in_channels, *kernel_size)
        self._kernel = torch.randn_like(self.kernel)
        self.bias = torch.randn(out_channels)
        self._bias = torch.randn_like(self.bias)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, _x: torch.Tensor):
        """Compute the convolution and its tangent.

        Args:
            x: Input accepted by ``F.conv2d`` for this kernel.
            _x: Tangent of ``x``; same shape as ``x``.

        Returns:
            A pair ``(y, _y)`` where ``y = conv2d(x, kernel) + bias`` and
            ``_y = conv2d(_x, kernel) + conv2d(x, _kernel) + _bias``, with
            both biases added per output channel.
        """
        # Draw a new standard-normal direction for every forward pass.
        self._kernel = torch.randn_like(self._kernel)
        self._bias = torch.randn_like(self._bias)
        y = F.conv2d(x, self.kernel, self.bias)
        # Hand _bias to conv2d so it is added along the channel axis; a plain
        # ``+ self._bias`` would broadcast against the last (width) axis.
        _y = F.conv2d(_x, self.kernel) + F.conv2d(x, self._kernel, self._bias)
        return y, _y

    @torch.no_grad()
    def perturb(self, _objective) -> None:
        """Write the signed perturbation direction into ``.grad``.

        Args:
            _objective: Number or tensor; presumably the change of the
                objective along the last drawn direction (verify against
                the caller).  Only its sign is used.
        """
        # +1 when the objective value is positive, -1 otherwise.
        sign = (_objective > 0) * 2.0 - 1
        self.kernel.grad = self._kernel * sign
        self.bias.grad = self._bias * sign

class WPertSoftmax(nn.Module):
    """Softmax that propagates a value and a tangent side by side.

    The layer has no weights, so ``perturb`` is a no-op.
    """

    @torch.no_grad()
    def __init__(self, dim: int = -1) -> None:
        """Create the layer.

        Args:
            dim: Axis along which softmax is taken.  Defaults to the last
                axis — NOTE(review): confirm this matches the callers'
                layout.
        """
        super().__init__()
        self.dim = dim
        self._m = 0.0

    @torch.no_grad()
    def forward(self, x: torch.Tensor, _x: torch.Tensor):
        """Compute softmax and its tangent.

        With ``y = softmax(x)`` the Jacobian is
        ``dy_i/dx_j = y_i * (delta_ij - y_j)``, hence
        ``_y_i = y_i * (_x_i - sum_j y_j * _x_j)``.

        Args:
            x: Input tensor.
            _x: Tangent of ``x``; same shape as ``x``.

        Returns:
            A pair ``(y, _y)`` as defined above.
        """
        y = torch.softmax(x, dim=self.dim)
        # The Jacobian is fully determined by y, so keep y rather than the
        # dense matrix.
        self._m = y
        weighted = (y * _x).sum(dim=self.dim, keepdim=True)
        _y = y * (_x - weighted)
        return y, _y

    @torch.no_grad()
    def perturb(self, _objective) -> None:
        """Do nothing: softmax has no weights to perturb."""
        return None