In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import math
import tqdm
from typing import *

random.seed(192837)
torch.manual_seed(192838)

<torch._C.Generator at 0x7a0517fee370>

In [2]:
N_PARAMS = 4
N_DATA = 5000
SAMPLES_SPLIT = 0.95
TRAIN_TEST_SPLIT = 0.9
BATCH_SIZE = 8


def FORMULA(params: Iterable[float]):
    a, b, c, d = params
    
    return 0.0537 + \
           0.234 * a + \
           -0.51 * b + \
           0.112 * c + \
           -0.1633 * d + \
           -0.857 * a**2 + \
           0.117 * b**2 + \
           0.9 * a * b + \
           -0.363 * c**2 + \
           -0.103 * d**2 + \
           -0.5 * c*d + \
           1.2 * a*b*c*d

In [3]:
class MLP(nn.Module):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 hidden_layers: int,
                 output_dim: int):
        super().__init__()

        self.foot = nn.Linear(input_dim, hidden_dim)
        self.body = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(hidden_layers)])
        self.head = nn.Linear(hidden_dim, output_dim)
        
        self.act = nn.ReLU()
    
    def forward(self, x):
        x = self.act(self.foot(x))
        for layer in self.body:
            x = self.act(layer(x))
        x = self.head(x)
        return x

In [4]:
dataset = torch.rand(size=(N_DATA, N_PARAMS))

n_leftout = round(N_DATA * (1 - SAMPLES_SPLIT))
dataset_leftout, dataset = dataset[:n_leftout, :], dataset[n_leftout:, :]

n_train = round((N_DATA - n_leftout) * TRAIN_TEST_SPLIT)
n_test = N_DATA - n_leftout - n_train
dataset_train, dataset_test = dataset[:n_train, :], dataset[n_train:, :]

dataset_train_y = torch.tensor([FORMULA(dataset_train[i]) for i in range(n_train)])
dataset_test_y = torch.tensor([FORMULA(dataset_test[i]) for i in range(n_test)])

N_DATA, n_leftout, n_train, n_test

(5000, 250, 4275, 475)

In [5]:
loss_fn = nn.MSELoss()

def train_epoch(model: nn.Module, optimizer: optim.Optimizer, dataset_x: torch.Tensor, dataset_y: torch.Tensor, batch_size=BATCH_SIZE):
    model.train()
    
    loss_sum = 0
    n_batches = math.ceil(dataset_x.shape[0] / batch_size)

    for i in tqdm.tqdm(range(n_batches), 'train'):
        x, y = dataset_x[i * batch_size: (i+1) * batch_size, :], dataset_y[i * batch_size: (i+1) * batch_size]
        y_hat = model.forward(x).squeeze(dim=1)

        loss = loss_fn(y_hat, y)
        loss_sum += loss.detach().clone()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    return loss_sum / n_batches

loss_fn = nn.MSELoss()

def test_epoch(model: nn.Module, dataset_x: torch.Tensor, dataset_y: torch.Tensor, batch_size=BATCH_SIZE):
    model.eval()
    
    with torch.no_grad():
        loss_sum = 0
        n_batches = math.ceil(dataset_x.shape[0] / batch_size)

        for i in tqdm.tqdm(range(n_batches), ' test'):
            x, y = dataset_x[i * batch_size: (i+1) * batch_size, :], dataset_y[i * batch_size: (i+1) * batch_size]
            y_hat = model.forward(x).squeeze(dim=1)

            loss = loss_fn(y_hat, y)
            loss_sum += loss
        
        return loss_sum / n_batches

In [6]:
model = MLP(input_dim=N_PARAMS, hidden_dim=8, hidden_layers=3, output_dim=1)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
N_EPOCHS = 20

In [7]:
for epoch_i in range(N_EPOCHS):
    print(f'=== epoch {epoch_i} ===')

    train_loss = train_epoch(model, optimizer, dataset_train, dataset_train_y)
    test_loss = test_epoch(model, dataset_test, dataset_test_y)

    print('train loss:', train_loss.item())
    print(' test loss:', test_loss.item())

=== epoch 0 ===


train: 100%|██████████| 535/535 [00:02<00:00, 267.26it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 3328.42it/s]


train loss: 0.025589799508452415
 test loss: 0.007275852840393782
=== epoch 1 ===


train: 100%|██████████| 535/535 [00:02<00:00, 245.92it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2845.36it/s]


train loss: 0.004866989329457283
 test loss: 0.004260052461177111
=== epoch 2 ===


train: 100%|██████████| 535/535 [00:02<00:00, 222.08it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 1405.59it/s]


train loss: 0.003849201137199998
 test loss: 0.003741301130503416
=== epoch 3 ===


train: 100%|██████████| 535/535 [00:02<00:00, 221.78it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2471.62it/s]


train loss: 0.003602651646360755
 test loss: 0.003625196870416403
=== epoch 4 ===


train: 100%|██████████| 535/535 [00:02<00:00, 197.79it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2233.57it/s]


train loss: 0.00339615810662508
 test loss: 0.003737350460141897
=== epoch 5 ===


train: 100%|██████████| 535/535 [00:02<00:00, 220.62it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 1813.96it/s]


train loss: 0.003132794750854373
 test loss: 0.003523388411849737
=== epoch 6 ===


train: 100%|██████████| 535/535 [00:02<00:00, 205.36it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 1749.89it/s]


train loss: 0.0027917474508285522
 test loss: 0.003238195786252618
=== epoch 7 ===


train: 100%|██████████| 535/535 [00:02<00:00, 211.93it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 3146.20it/s]


train loss: 0.002421214012429118
 test loss: 0.002887919545173645
=== epoch 8 ===


train: 100%|██████████| 535/535 [00:02<00:00, 219.60it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 1353.43it/s]


train loss: 0.0020620531868189573
 test loss: 0.002621087711304426
=== epoch 9 ===


train: 100%|██████████| 535/535 [00:02<00:00, 239.34it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2771.08it/s]


train loss: 0.0017867962596938014
 test loss: 0.002090037800371647
=== epoch 10 ===


train: 100%|██████████| 535/535 [00:02<00:00, 254.22it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2708.92it/s]


train loss: 0.0016118374187499285
 test loss: 0.0015880054561421275
=== epoch 11 ===


train: 100%|██████████| 535/535 [00:02<00:00, 264.30it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 1478.41it/s]


train loss: 0.0014972137287259102
 test loss: 0.001404574722982943
=== epoch 12 ===


train: 100%|██████████| 535/535 [00:01<00:00, 284.09it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2840.68it/s]


train loss: 0.0014407869894057512
 test loss: 0.0013549705035984516
=== epoch 13 ===


train: 100%|██████████| 535/535 [00:01<00:00, 275.87it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2517.99it/s]


train loss: 0.0014071118785068393
 test loss: 0.0013300826540216804
=== epoch 14 ===


train: 100%|██████████| 535/535 [00:02<00:00, 264.05it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2659.90it/s]


train loss: 0.0013937745243310928
 test loss: 0.0013374699046835303
=== epoch 15 ===


train: 100%|██████████| 535/535 [00:01<00:00, 279.80it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2798.57it/s]


train loss: 0.0013844959903508425
 test loss: 0.0013336432166397572
=== epoch 16 ===


train: 100%|██████████| 535/535 [00:01<00:00, 275.98it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 3004.16it/s]


train loss: 0.0013762477319687605
 test loss: 0.001321546034887433
=== epoch 17 ===


train: 100%|██████████| 535/535 [00:01<00:00, 269.18it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2186.74it/s]


train loss: 0.0013641368132084608
 test loss: 0.0013064087834209204
=== epoch 18 ===


train: 100%|██████████| 535/535 [00:01<00:00, 282.48it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2754.85it/s]


train loss: 0.0013568371068686247
 test loss: 0.0012916355626657605
=== epoch 19 ===


train: 100%|██████████| 535/535 [00:01<00:00, 270.45it/s]
 test: 100%|██████████| 60/60 [00:00<00:00, 2953.42it/s]

train loss: 0.001349085010588169
 test loss: 0.001299887546338141





In [None]:
class AntiNet():
    def __init__(self,
               inner_net: nn.Module,
               input_shape: torch.Size,
               desired_output: torch.Tensor,
               lr: float=1e-2):
        super().__init__()

        self.input_shape = input_shape
        self.input = nn.Parameter(
            data=torch.normal(mean=0.0, std=1.0, size=self.input_shape),
            requires_grad=True
        )
        self.inner_net = inner_net
        self.grad_eraser = torch.optim.SGD(self.inner_net.parameters())
        self.anti_optimizer = torch.optim.Adam((self.input,), lr=lr)
        self.desired_output = desired_output
        self.output_shape = self.desired_output.shape

    def test(self):
        with torch.no_grad():
            return self.inner_net.forward(self.input)

    def zero_grad(self):
        self.grad_eraser.zero_grad()
        self.anti_optimizer.zero_grad()
    
    def step(self):
        prev_input = self.input.detach().clone()

        output = self.inner_net.forward(self.input)
        loss = F.mse_loss(output, self.desired_output)

        loss.backward()
        self.anti_optimizer.step()
        self.zero_grad()

        return prev_input, output.detach().clone(), loss.detach().clone()

In [9]:
for i in range(5):
    desired_input = dataset_leftout[i]
    desired_output = torch.tensor([FORMULA(desired_input)])

    desired_input, desired_output

    anti_net = AntiNet(model, (N_PARAMS,), desired_output)
    ANTI_NET_STEPS = 300

    print(f'-=-=- step {i} -=-=-')
    print(f'des  input:', desired_input)
    print(f'des output:', desired_output)

    for step_i in range(ANTI_NET_STEPS):
        inp, outp, loss = anti_net.step()
    
    print(f'fin  input:', inp)
    print(f'fin output:', outp)
    print(f'fin loss:', loss)

-=-=- step 0 -=-=-
des  input: tensor([0.4501, 0.8621, 0.2806, 0.9874])
des output: tensor([-0.2865])
fin  input: tensor([-0.2922, -0.3643,  1.1482, -0.2821])
fin output: tensor([-0.2865])
fin loss: tensor(7.9936e-15)
-=-=- step 1 -=-=-
des  input: tensor([0.0962, 0.1024, 0.4580, 0.9973])
des output: tensor([-0.4870])
fin  input: tensor([-0.9342,  1.2046, -1.0785, -0.0882])
fin output: tensor([-0.4870])
fin loss: tensor(2.2204e-14)
-=-=- step 2 -=-=-
des  input: tensor([0.0156, 0.9726, 0.8447, 0.8128])
des output: tensor([-1.0106])
fin  input: tensor([-0.4207,  0.7653,  0.4533,  0.6726])
fin output: tensor([-1.0106])
fin loss: tensor(0.)
-=-=- step 3 -=-=-
des  input: tensor([0.0163, 0.5209, 0.9359, 0.4601])
des output: tensor([-0.6900])
fin  input: tensor([-0.9574,  0.3982,  1.7066, -1.6092])
fin output: tensor([-0.6900])
fin loss: tensor(3.5527e-15)
-=-=- step 4 -=-=-
des  input: tensor([0.1960, 0.8433, 0.5908, 0.4873])
des output: tensor([-0.3829])
fin  input: tensor([-0.8491,  0.60