## Задание №10

Бинаризовать нейронную сеть, используя предоставленные классы `ParabolaSignApproximator` и `STESignApproximator`. Доступен `torch==1.5.1`.

In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import sys

In [2]:
class ParabolaSignApproximator(torch.autograd.Function):
    """Sign binarization with a piecewise-linear surrogate gradient.

    Forward maps every element to -1 (if it is < 0) or +1 (otherwise).
    Backward multiplies the incoming gradient by:

    * ``2 * x + 2``   for ``-1 <= x < 0``
    * ``-2 * x + 2``  for ``0 <= x < 1``
    * ``0``           everywhere else

    The implementation uses only torch operations, so it works on any
    device/dtype that torch supports without a NumPy round-trip.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tensor of -1/+1 with the same shape, dtype and device as ``x``."""
        # The raw input is needed in backward to evaluate the surrogate slope.
        ctx.save_for_backward(x)
        return torch.where(x < 0, torch.full_like(x, -1), torch.full_like(x, 1))

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` scaled by the surrogate derivative at ``x``."""
        (x,) = ctx.saved_tensors

        rising = (x >= -1) & (x < 0)
        falling = (x >= 0) & (x < 1)

        # Start from zeros and fill only the masked regions, so elements
        # outside [-1, 1) get exactly 0 regardless of grad_output's value.
        grad_input = torch.zeros_like(grad_output)
        grad_input = torch.where(rising, grad_output * (2 * x + 2), grad_input)
        grad_input = torch.where(falling, grad_output * (-2 * x + 2), grad_input)
        return grad_input


class STESignApproximator(torch.autograd.Function):
    """Sign binarization with a straight-through gradient.

    Forward maps every element to -1 (if it is < 0) or +1 (otherwise).
    Backward passes the incoming gradient through unchanged.

    The implementation uses only torch operations, so it works on any
    device/dtype that torch supports without a NumPy round-trip.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tensor of -1/+1 with the same shape, dtype and device as ``x``."""
        return torch.where(x < 0, torch.full_like(x, -1), torch.full_like(x, 1))

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` as the gradient with respect to the input."""
        return grad_output

In [3]:
def _read_rows(count):
    """Read ``count`` lines from stdin, each parsed as whitespace-separated floats."""
    rows = []
    for _ in range(count):
        rows.append([float(token) for token in input().split()])
    return rows


# Stdin layout: two ints (stored as tensor_shape), then tensor_shape[0] rows
# of X; two more ints (feats_out, feats_in), then feats_out rows of W.
# tensor_shape[1] and feats_in are read but not used in this cell.
tensor_shape = (int(input()), int(input()))
X = _read_rows(tensor_shape[0])
feats_out, feats_in = (int(input()), int(input()))
W = _read_rows(feats_out)

# Convert only after all input has been consumed.
X = torch.tensor(X)
W = torch.tensor(W)
# Zero-filled tensor with one row per X row and one column per W row.
y = torch.tensor(np.zeros((tensor_shape[0], feats_out)))

In [5]:
def approximate_sign(input_data, weights, output_shape):
    """Run one binarized linear forward/backward pass and return the leaf tensors.

    The inputs are binarized with ``ParabolaSignApproximator`` and the weights
    with ``STESignApproximator``; a bias-free linear layer is applied, its
    outputs are summed, and the sum is backpropagated.

    Args:
        input_data: tensor of inputs; it is cloned, so the caller's tensor
            does not receive a gradient.
        weights: tensor of weights, wrapped in an ``nn.Parameter``.
        output_shape: accepted for interface compatibility; not used here.

    Returns:
        ``(input_data, weights)`` - the cloned input and the parameter, both
        with ``.grad`` populated by the backward pass.
    """
    input_data = input_data.clone().requires_grad_(True)
    weights = nn.Parameter(weights, requires_grad=True)

    binary_input_data = ParabolaSignApproximator.apply(input_data)
    binary_weights = STESignApproximator.apply(weights)

    # Reduce to a scalar so backward() needs no explicit gradient argument.
    F.linear(binary_input_data, binary_weights).sum().backward()

    return input_data, weights

# Run the binarized pass on the parsed tensors; the returned leaves carry the
# gradients that are printed below. `y` is passed as `output_shape`.
inputs, weights = approximate_sign(X, W, y)

In [6]:
def print_gradient(tensor):
    """Write ``tensor.grad`` to stdout with four decimals per value."""
    grad_values = tensor.grad.numpy()
    np.savetxt(sys.stdout, grad_values, fmt='%.04f')

# Emit the gradient with respect to the inputs first, then the gradient
# with respect to the weights.
print_gradient(inputs)
print_gradient(weights)

4.8456 -0.5448 -2.2704
4.6236 -1.7770 0.0000
0.0000 -0.8076 -3.8202
3.1506 0.0000 0.0000
-4.0000 0.0000 -2.0000
-4.0000 0.0000 -2.0000
-4.0000 0.0000 -2.0000
-4.0000 0.0000 -2.0000
-4.0000 0.0000 -2.0000
