In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import math
import tqdm

torch.manual_seed(192837)
rand = random.Random(192838)

In [2]:
class Xorer(nn.Module):
    def __init__(self,
                 n_digits: int=4,
                 hidden_dim: int=16,
                 hidden_layers: int=3):
        super().__init__()

        self.foot = nn.Linear(n_digits * 2, hidden_dim)
        self.body = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(hidden_layers)])
        self.head = nn.Linear(hidden_dim, n_digits)
        self.relu = nn.ReLU(inplace=True)
        self.sigm = nn.Sigmoid()
    
    def forward(self, x):
        x = self.relu(self.foot(x))
        for l in self.body:
            x = self.relu(l(x))
        x = self.sigm(self.head(x))
        return x

In [3]:
def to_bitvec(n: int, dim: int):
    return [(1 if ((n & (1 << b)) != 0) else 0) for b in range(dim)]

def from_bitvec(bits: list[int]):
    return sum([bits[b] * (1 << b) for b in range(len(bits))])

In [4]:
DIGITS = 4
TRAIN_TEST_SPLIT = 0.9

dataset_full = [(to_bitvec(i, DIGITS) + to_bitvec(j, DIGITS), to_bitvec(i ^ j, DIGITS)) for i in range(2 ** DIGITS) for j in range(2 ** DIGITS)]
rand.shuffle(dataset_full)
dataset_x, dataset_y = zip(*dataset_full)
dataset_x = torch.tensor(dataset_x, dtype=torch.float32)
dataset_y = torch.tensor(dataset_y, dtype=torch.float32)

n_full = len(dataset_full)
n_train = round(n_full * TRAIN_TEST_SPLIT)
n_test = n_full - n_train

train_x, test_x = dataset_x[:n_train, :], dataset_x[n_train:, :]
train_y, test_y = dataset_y[:n_train, :], dataset_y[n_train:, :]

print(n_full, n_train, n_test)
print(train_x.shape, test_x.shape, train_y.shape, test_y.shape)

256 230 26
torch.Size([230, 8]) torch.Size([26, 8]) torch.Size([230, 4]) torch.Size([26, 4])


In [5]:
loss_fn = F.binary_cross_entropy

def train_epoch(model: nn.Module,
                optimizer: torch.optim.Optimizer,
                dataset_x: torch.Tensor,
                dataset_y: torch.Tensor,
                batch_size: int,
                epoch_i: int,
                verbose: bool=False):
    model.train()

    n_batches = math.ceil(dataset_x.shape[0] / batch_size)
    loss_sum = 0.
    accu_sum = 0.

    for batch_i in tqdm.tqdm(range(n_batches), f'epoch {epoch_i} train') if verbose else range(n_batches):
        x = dataset_x[batch_i * batch_size:(batch_i + 1) * batch_size, :]
        y = dataset_y[batch_i * batch_size:(batch_i + 1) * batch_size, :]

        y_hat = model.forward(x)

        loss = loss_fn(y_hat, y)

        with torch.no_grad():
            accu = (torch.round(y_hat) == y).to(torch.float32).mean()
            accu_sum += accu.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_sum += loss.item()
    
    if verbose:
        print('train loss:', loss_sum / n_batches)
        print('train accu:', accu_sum / n_batches)

In [6]:
def test_epoch(model: nn.Module,
               dataset_x: torch.Tensor,
               dataset_y: torch.Tensor,
               batch_size: int,
               epoch_i: int,
               verbose: bool=False):
    model.train()

    n_batches = math.ceil(dataset_x.shape[0] / batch_size)
    loss_sum = 0.
    accu_sum = 0.

    with torch.no_grad():
        for batch_i in tqdm.tqdm(range(n_batches), f'epoch {epoch_i} test') if verbose else range(n_batches):
            x = dataset_x[batch_i * batch_size:(batch_i + 1) * batch_size, :]
            y = dataset_y[batch_i * batch_size:(batch_i + 1) * batch_size, :]

            y_hat = model.forward(x)

            loss = loss_fn(y_hat, y)

            accu = (torch.round(y_hat) == y).to(torch.float32).mean()
            accu_sum += accu.item()

            loss_sum += loss.item()

    if verbose:
        print(' test loss:', loss_sum / n_batches)
        print(' test accu:', accu_sum / n_batches)

In [7]:
model = Xorer(DIGITS, hidden_dim=16, hidden_layers=3)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
BATCH_SIZE = 8
N_EPOCHS = 400

In [8]:
for epoch_i in range(N_EPOCHS):
    train_epoch(model, optimizer, train_x, train_y, BATCH_SIZE, epoch_i, verbose=(epoch_i + 1) % 20 == 0)
    test_epoch(model, test_x, test_y, BATCH_SIZE, epoch_i, verbose=(epoch_i + 1) % 20 == 0)

epoch 19 train: 100%|██████████| 29/29 [00:00<00:00, 398.08it/s]


train loss: 0.5226443749049614
train accu: 0.6339798857425821


epoch 19 test: 100%|██████████| 4/4 [00:00<00:00, 1550.00it/s]


 test loss: 0.5279723852872849
 test accu: 0.5859375


epoch 39 train: 100%|██████████| 29/29 [00:00<00:00, 473.70it/s]


train loss: 0.3899604729537306
train accu: 0.7464080452919006


epoch 39 test: 100%|██████████| 4/4 [00:00<00:00, 2070.49it/s]


 test loss: 0.3831762298941612
 test accu: 0.8203125


epoch 59 train: 100%|██████████| 29/29 [00:00<00:00, 313.08it/s]


train loss: 0.3468946084893983
train accu: 0.7543103448275862


epoch 59 test: 100%|██████████| 4/4 [00:00<00:00, 1529.37it/s]


 test loss: 0.3575332835316658
 test accu: 0.6953125


epoch 79 train: 100%|██████████| 29/29 [00:00<00:00, 422.90it/s]


train loss: 0.33313157948954353
train accu: 0.807471265052927


epoch 79 test: 100%|██████████| 4/4 [00:00<00:00, 1303.08it/s]


 test loss: 0.33757051080465317
 test accu: 0.8203125


epoch 99 train: 100%|██████████| 29/29 [00:00<00:00, 383.81it/s]


train loss: 0.28392662056561174
train accu: 0.8415948275862069


epoch 99 test: 100%|██████████| 4/4 [00:00<00:00, 2128.28it/s]


 test loss: 0.27776307612657547
 test accu: 0.84375


epoch 119 train: 100%|██████████| 29/29 [00:00<00:00, 380.57it/s]


train loss: 0.2369566986273075
train accu: 0.8591954029839615


epoch 119 test: 100%|██████████| 4/4 [00:00<00:00, 2278.58it/s]


 test loss: 0.21687819436192513
 test accu: 0.828125


epoch 139 train: 100%|██████████| 29/29 [00:00<00:00, 378.25it/s]


train loss: 0.19757787831898393
train accu: 0.8706896551724138


epoch 139 test: 100%|██████████| 4/4 [00:00<00:00, 1372.71it/s]


 test loss: 0.1925189271569252
 test accu: 0.84375


epoch 159 train: 100%|██████████| 29/29 [00:00<00:00, 435.97it/s]


train loss: 0.18245429057499457
train accu: 0.8803879310344828


epoch 159 test: 100%|██████████| 4/4 [00:00<00:00, 1722.68it/s]


 test loss: 0.18502729013562202
 test accu: 0.859375


epoch 179 train: 100%|██████████| 29/29 [00:00<00:00, 450.93it/s]


train loss: 0.17742835139406138
train accu: 0.884698275862069


epoch 179 test: 100%|██████████| 4/4 [00:00<00:00, 1696.38it/s]


 test loss: 0.18196140602231026
 test accu: 0.859375


epoch 199 train: 100%|██████████| 29/29 [00:00<00:00, 460.19it/s]


train loss: 0.17477212692129201
train accu: 0.8857758620689655


epoch 199 test: 100%|██████████| 4/4 [00:00<00:00, 1874.55it/s]


 test loss: 0.17926020175218582
 test accu: 0.8515625


epoch 219 train: 100%|██████████| 29/29 [00:00<00:00, 448.75it/s]


train loss: 0.17001581089249973
train accu: 0.9008620689655172


epoch 219 test: 100%|██████████| 4/4 [00:00<00:00, 1456.86it/s]


 test loss: 0.17398091033101082
 test accu: 0.875


epoch 239 train: 100%|██████████| 29/29 [00:00<00:00, 428.54it/s]


train loss: 0.15274754680436234
train accu: 0.9285201142574179


epoch 239 test: 100%|██████████| 4/4 [00:00<00:00, 1493.30it/s]


 test loss: 0.15468808263540268
 test accu: 0.9296875


epoch 259 train: 100%|██████████| 29/29 [00:00<00:00, 455.19it/s]


train loss: 0.1064619008323242
train accu: 0.9608477004643144


epoch 259 test: 100%|██████████| 4/4 [00:00<00:00, 1864.55it/s]


 test loss: 0.11499328725039959
 test accu: 0.96875


epoch 279 train: 100%|██████████| 29/29 [00:00<00:00, 347.21it/s]


train loss: 0.032847051283922686
train accu: 0.9989224137931034


epoch 279 test: 100%|██████████| 4/4 [00:00<00:00, 2297.62it/s]


 test loss: 0.04868536302819848
 test accu: 1.0


epoch 299 train: 100%|██████████| 29/29 [00:00<00:00, 581.65it/s]


train loss: 0.005735184547716174
train accu: 1.0


epoch 299 test: 100%|██████████| 4/4 [00:00<00:00, 2538.54it/s]


 test loss: 0.009878653916530311
 test accu: 1.0


epoch 319 train: 100%|██████████| 29/29 [00:00<00:00, 619.37it/s]


train loss: 0.0018325320360700376
train accu: 1.0


epoch 319 test: 100%|██████████| 4/4 [00:00<00:00, 1855.07it/s]


 test loss: 0.0038806755328550935
 test accu: 1.0


epoch 339 train: 100%|██████████| 29/29 [00:00<00:00, 557.27it/s]


train loss: 0.00034466537933973277
train accu: 1.0


epoch 339 test: 100%|██████████| 4/4 [00:00<00:00, 1731.22it/s]


 test loss: 0.004731532484584022
 test accu: 1.0


epoch 359 train: 100%|██████████| 29/29 [00:00<00:00, 579.09it/s]


train loss: 0.0001853805578289682
train accu: 1.0


epoch 359 test: 100%|██████████| 4/4 [00:00<00:00, 2222.44it/s]


 test loss: 0.003388426994206384
 test accu: 1.0


epoch 379 train: 100%|██████████| 29/29 [00:00<00:00, 472.42it/s]


train loss: 0.00012019582673333232
train accu: 1.0


epoch 379 test: 100%|██████████| 4/4 [00:00<00:00, 2162.57it/s]


 test loss: 0.0026088613722095033
 test accu: 1.0


epoch 399 train: 100%|██████████| 29/29 [00:00<00:00, 537.35it/s]


train loss: 8.585623637090663e-05
train accu: 1.0


epoch 399 test: 100%|██████████| 4/4 [00:00<00:00, 2074.85it/s]

 test loss: 0.0020974984818167286
 test accu: 1.0





In [16]:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

class AntiNet():
    def __init__(self,
               inner_net: nn.Module,
               input_shape: torch.Size,
               desired_output: torch.Tensor,
               input_min: float,
               input_max: float,
               lr: float=1e-2,
               temp_start: float=1.0,
               temp_end: float=0.05):
        super().__init__()

        self.input_shape = input_shape
        self.input = nn.Parameter(
            data=torch.normal(0, 1, size=input_shape, dtype=torch.float32),
            requires_grad=True
        )
        self.inner_net = inner_net
        self.grad_eraser = torch.optim.SGD(self.inner_net.parameters())
        self.anti_optimizer = torch.optim.Adam((self.input,), lr=lr)
        self.anti_scheduler = CosineAnnealingWarmRestarts(self.anti_optimizer, T_0=10, T_mult=2)

        self.desired_output = desired_output
        self.output_shape = self.desired_output.shape

        self.input_min = input_min
        self.input_max = input_max

        self.temp_start = temp_start
        self.temp_end = temp_end
        self.norm_func = lambda t: torch.softmax(t, dim=-1)

    def test(self):
        with torch.no_grad():
            return self.inner_net.forward(self.norm_func(self.input / self.temp_end))

    def zero_grad(self):
        self.grad_eraser.zero_grad()
        self.anti_optimizer.zero_grad()

    # def crop(self):
    #     self.input.data.copy_(torch.clamp(self.input.detach().clone(), self.input_min, self.input_max))

    def parabolic_loss(self, t: torch.Tensor):
        return (t * (1 - t)).sum()

    def step(self, step_i: int, max_steps: int):
        prev_input = self.input.detach().clone()

        inp_soft = self.norm_func(self.input / (self.temp_start + (self.temp_end - self.temp_start) / (max_steps - 1) * step_i))

        output = self.inner_net.forward(inp_soft)
        loss = F.binary_cross_entropy(output, self.desired_output) \
            + self.parabolic_loss(inp_soft)
        
        loss.backward()
        self.anti_optimizer.step()
        self.zero_grad()

        self.anti_scheduler.step()

        # self.crop()

        return prev_input, output.detach().clone(), loss.detach().clone()

In [17]:
ANTI_NET_STEPS = 3000
TEST_EXAMPLES = 5
TEST_TRIES = 5

for i in range(5):
    desired_input = test_x[i]
    desired_output = test_y[i]

    print(f'-=-=- step {i} -=-=-')
    print(f'desired output:', desired_output)
    print(f'example  input:', desired_input)

    fin_loss = float('+inf')
    fin_outp, fin_inp = None, None

    for try_i in range(TEST_TRIES):
        anti_net = AntiNet(model, (DIGITS * 2,), desired_output,
                       input_min=0.0, input_max=1.0,
                       lr=5e-2,
                       temp_start=1.0, temp_end=0.05)

        for step_i in tqdm.tqdm(range(ANTI_NET_STEPS), f'attempt #{try_i}'):
            inp, outp, loss = anti_net.step(step_i, ANTI_NET_STEPS)

        if loss.item() < fin_loss:
            fin_loss = loss.item()
            fin_outp = outp
            fin_inp = inp
        
        print(f'attempt #{try_i} loss:', loss.item())

    print(f'fin input:', fin_inp)
    print(f'fin input (sigm):', anti_net.norm_func(fin_inp / anti_net.temp_end))
    print(f'fin output:', fin_outp)
    print(f'fin loss:', fin_loss)

-=-=- step 0 -=-=-
desired output: tensor([0., 0., 1., 0.])
example  input: tensor([1., 1., 1., 0., 1., 1., 0., 0.])


attempt #0: 100%|██████████| 3000/3000 [00:08<00:00, 367.32it/s]


attempt #0 loss: 0.00014561894931830466


attempt #1: 100%|██████████| 3000/3000 [00:07<00:00, 395.11it/s]


attempt #1 loss: 27.003427505493164


attempt #2: 100%|██████████| 3000/3000 [00:07<00:00, 415.50it/s]


attempt #2 loss: 3.7856283597648144e-05


attempt #3: 100%|██████████| 3000/3000 [00:07<00:00, 383.14it/s]


attempt #3 loss: 0.00014561894931830466


attempt #4: 100%|██████████| 3000/3000 [00:07<00:00, 426.63it/s]


attempt #4 loss: 3.7856283597648144e-05
fin input: tensor([-2.9811, -2.9342,  3.6652, -2.5261, -3.2043, -1.8895, -2.1527,  1.1669])
fin input (sigm): tensor([0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 1.9982e-22])
fin output: tensor([3.8542e-08, 1.1380e-04, 1.0000e+00, 3.7337e-05])
fin loss: 3.7856283597648144e-05
-=-=- step 1 -=-=-
desired output: tensor([0., 1., 0., 1.])
example  input: tensor([0., 1., 1., 0., 0., 0., 1., 1.])


attempt #0: 100%|██████████| 3000/3000 [00:06<00:00, 434.83it/s]


attempt #0 loss: 29.3837947845459


attempt #1: 100%|██████████| 3000/3000 [00:06<00:00, 429.26it/s]


attempt #1 loss: 2.43985652923584


attempt #2: 100%|██████████| 3000/3000 [00:06<00:00, 435.47it/s]


attempt #2 loss: 0.6091113090515137


attempt #3: 100%|██████████| 3000/3000 [00:06<00:00, 435.97it/s]


attempt #3 loss: 0.3952809274196625


attempt #4: 100%|██████████| 3000/3000 [00:06<00:00, 429.49it/s]


attempt #4 loss: 2.43985652923584
fin input: tensor([-2.2744,  3.4352, -1.0701, -2.0347, -1.9277, -1.0020, -1.1020,  3.4918])
fin input (sigm): tensor([0.0000e+00, 2.4364e-01, 1.7957e-40, 0.0000e+00, 0.0000e+00, 7.0097e-40,
        9.4917e-41, 7.5636e-01])
fin output: tensor([8.8354e-06, 8.9879e-01, 2.1240e-17, 9.9983e-01])
fin loss: 0.3952809274196625
-=-=- step 2 -=-=-
desired output: tensor([1., 1., 0., 1.])
example  input: tensor([1., 1., 1., 1., 0., 0., 1., 0.])


attempt #0: 100%|██████████| 3000/3000 [00:06<00:00, 438.91it/s]


attempt #0 loss: 3.7890193462371826


attempt #1: 100%|██████████| 3000/3000 [00:08<00:00, 338.71it/s]


attempt #1 loss: 3.7890193462371826


attempt #2: 100%|██████████| 3000/3000 [00:08<00:00, 344.32it/s]


attempt #2 loss: 3.7890193462371826


attempt #3: 100%|██████████| 3000/3000 [00:09<00:00, 313.61it/s]


attempt #3 loss: 1.9875338077545166


attempt #4: 100%|██████████| 3000/3000 [00:07<00:00, 377.30it/s]


attempt #4 loss: 3.7890143394470215
fin input: tensor([ 2.7783, -1.9592, -1.7953,  2.7620, -2.3868,  1.7359, -3.1499, -1.5535])
fin input (sigm): tensor([5.8055e-01, 4.1226e-42, 1.0926e-40, 4.1945e-01, 1.4013e-45, 5.1269e-10,
        0.0000e+00, 1.3764e-38])
fin output: tensor([1.0000e+00, 2.5913e-03, 1.2181e-11, 9.5459e-01])
fin loss: 1.9875338077545166
-=-=- step 3 -=-=-
desired output: tensor([0., 0., 1., 1.])
example  input: tensor([1., 1., 0., 1., 1., 1., 1., 0.])


attempt #0: 100%|██████████| 3000/3000 [00:07<00:00, 383.27it/s]


attempt #0 loss: 2.5016441345214844


attempt #1: 100%|██████████| 3000/3000 [00:08<00:00, 335.07it/s]


attempt #1 loss: 2.5489070415496826


attempt #2: 100%|██████████| 3000/3000 [00:08<00:00, 338.90it/s]


attempt #2 loss: 9.471370697021484


attempt #3: 100%|██████████| 3000/3000 [00:08<00:00, 349.65it/s]


attempt #3 loss: 2.5016441345214844


attempt #4: 100%|██████████| 3000/3000 [00:06<00:00, 443.09it/s]


attempt #4 loss: 0.37424036860466003
fin input: tensor([-3.5766, -1.1691,  2.3928, -1.1910,  2.3336, -1.4651, -3.1458, -1.1934])
fin input (sigm): tensor([0.0000e+00, 8.8374e-32, 7.6546e-01, 5.7021e-32, 2.3454e-01, 2.3709e-34,
        0.0000e+00, 5.4255e-32])
fin output: tensor([4.5975e-08, 3.9457e-04, 1.0000e+00, 9.4147e-01])
fin loss: 0.37424036860466003
-=-=- step 4 -=-=-
desired output: tensor([1., 0., 1., 0.])
example  input: tensor([1., 1., 0., 0., 0., 1., 1., 0.])


attempt #0: 100%|██████████| 3000/3000 [00:06<00:00, 449.45it/s]


attempt #0 loss: 4.267916679382324


attempt #1: 100%|██████████| 3000/3000 [00:09<00:00, 317.68it/s]


attempt #1 loss: 4.267916679382324


attempt #2: 100%|██████████| 3000/3000 [00:08<00:00, 334.41it/s]


attempt #2 loss: 4.267916679382324


attempt #3: 100%|██████████| 3000/3000 [00:07<00:00, 412.21it/s]


attempt #3 loss: 4.267916679382324


attempt #4: 100%|██████████| 3000/3000 [00:08<00:00, 354.04it/s]

attempt #4 loss: 4.267916679382324
fin input: tensor([-2.5934, -2.6062,  5.3710, -5.6458,  0.2409, -2.3132, -3.6352, -4.3053])
fin input (sigm): tensor([0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 2.8026e-45, 0.0000e+00,
        0.0000e+00, 0.0000e+00])
fin output: tensor([3.8542e-08, 1.1380e-04, 1.0000e+00, 3.7337e-05])
fin loss: 4.267916679382324



