In [1]:
import os
import sys
sys.path.insert(0, "..")
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tqdm import tqdm

import matplotlib.pyplot as plt

from models import *
from my_datasets import *
from experiments import *
from experiments.utils.model_loader_utils import *

# Display tensors in fixed-point notation (no scientific notation), with two
# decimals and up to 120 characters per printed line.
torch.set_printoptions(
    precision=2,
    linewidth=120,
    sci_mode=False,
)

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Problem size used throughout the notebook: `n` is passed as num_vars and
# `d` as embed_dim when the dataset is loaded.
n = 16
d = 32
def hot(i, p):
    """Return the one-hot encoding of the index tensor ``i`` with ``p`` classes.

    Thin wrapper around ``torch.nn.functional.one_hot``: the result has the
    shape of ``i`` plus one trailing dimension of size ``p``.
    """
    encoded = F.one_hot(i, num_classes=p)
    return encoded

In [12]:
# Load the dataset for the seed-601 configuration. The first return value
# (presumably the loaded model -- confirm against the helper) is discarded.
_, res_dataset = load_model_and_dataset_from_big_grid(embed_dim=d, num_vars=n, seed=601)
# The model evaluated in this notebook is constructed directly here rather
# than taken from the loader call above.
res_model = TheoryAutoregKStepsModel(num_vars=n, num_steps=3)
res_model.eval()
# NOTE(review): the meaning of the positional arguments 4 and 1024 is not
# visible in this notebook -- confirm against CoerceStateDataset's signature.
atk_dataset = CoerceStateDataset(res_dataset, 4, 1024)

Querying id: model-SynSAR_gpt2_d32_L1_H1__DMY_nv16_nr32_exph3.000__ntr262144_ntt65536_bsz512_steps8192_lr0.00050_seed601:v0
Downloading: <Artifact QXJ0aWZhY3Q6ODI5NzQzMjYx>


[34m[1mwandb[0m:   2 of 2 files downloaded.  


In [26]:
# Seed torch's global RNG before drawing the batch (presumably so that any
# randomness in the dataset/loader is reproducible -- confirm).
torch.manual_seed(102)
dataloader = DataLoader(atk_dataset, batch_size=1)

# Take only the first batch. Unlike a `for ... break` loop, next(iter(...))
# raises StopIteration on an empty loader instead of silently leaving names
# from an earlier run (or no names at all) in the kernel.
batch = next(iter(dataloader))
tokens = batch["tokens"]
labels = batch["labels"]
infos = batch["infos"]
hints = batch["hints"]
# First column of `infos`; this is the index tensor that gets one-hot encoded
# when the attack tokens are built.
a = infos[:,0]

In [27]:
# Display the `hints` tensor taken from the batch.
hints

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
         [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]]])

In [79]:
# Split `hints` along dim 1 into three chunks, one per target.
tgt1, tgt2, tgt3 = torch.chunk(hints, 3, dim=1)

# Each attack rule concatenates two halves along the last dimension. The second
# half is (2*t - 1) times a scale, i.e. a 0/1 target becomes -scale/+scale.
atk_rule1 = torch.cat(
    (hot(a, n).view(-1, 1, n), (2 * tgt1 - 1) * 10),
    dim=-1,
)
atk_rule2 = torch.cat(
    (tgt1, (2 * tgt2 - 1) * 1000),
    dim=-1,
)
atk_rule3 = torch.cat(
    (tgt2, (2 * tgt3 - 1) * 1000),
    dim=-1,
)

# Final token: n zeros followed by the one-hot encoding of `a`.
init_token = torch.cat(
    (torch.zeros(a.size(0), 1, n), hot(a, n).view(-1, 1, n)),
    dim=-1,
)

# Stack the three rules and the initial token along the sequence dimension.
atk_tokens = torch.cat((atk_rule1, atk_rule2, atk_rule3, init_token), dim=1)
adv_tokens = atk_tokens

In [80]:
# Forward pass on the adversarial sequence, also requesting the attention maps.
adv_out = res_model(tokens=adv_tokens, output_attentions=True)
# Binarise the logits: strictly positive -> 1, otherwise 0.
adv_pred = adv_out.logits.gt(0).to(torch.long)

In [81]:
# True when the fraction of predicted bits matching labels[:, 1:] falls short
# of 1 (with a 1e-5 tolerance); False means every bit matched.
torch.lt(
    torch.eq(adv_pred, labels[:, 1:]).float().mean(),
    1 - 1e-5,
)

tensor(False)

In [82]:
# Element-wise difference between the predictions and labels[:, 1:]; an
# all-zero result means every predicted bit matches.
adv_pred - labels[:,1:]

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])

In [83]:
# Cast the adversarial tokens to int64 so the values print without decimals.
adv_tokens.long()

tensor([[[    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     1,
              0,   -10,   -10,   -10,   -10,   -10,   -10,   -10,   -10,    10,   -10,    10,    10,   -10,    10,
             10,   -10],
         [    0,     0,     0,     0,     0,     0,     0,     0,     1,     0,     1,     1,     0,     1,     1,
              0, -1000, -1000, -1000, -1000, -1000,  1000, -1000, -1000,  1000, -1000, -1000,  1000, -1000, -1000,
           1000, -1000],
         [    0,     0,     0,     0,     0,     1,     0,     0,     1,     0,     0,     1,     0,     0,     1,
              0,  1000, -1000, -1000,  1000, -1000, -1000, -1000,  1000, -1000, -1000,  1000, -1000, -1000, -1000,
          -1000, -1000],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
     

In [84]:
# Compare the sign pattern (positive -> 1, else 0) of the last columns (index n
# onward) of the first three adversarial tokens against labels[:, 1:].
(adv_tokens[:,0:3,n:] > 0).long() == labels[:,1:]

tensor([[[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]]])

In [85]:
# Labels with the first entry along dim 1 dropped.
labels[:,1:]

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
         [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]]])

In [86]:
# Binarised model predictions for the adversarial input.
adv_pred

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
         [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]]])

In [87]:
# Predictions minus labels[:, 1:]; all zeros means an exact match.
# NOTE(review): this repeats an identical cell earlier in the notebook --
# consider keeping only one of them.
adv_pred - labels[:,1:]

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])

In [69]:
# Unpack the three attention tensors returned by the model and show their shapes.
# NOTE(review): this cell's execution count (69) is lower than that of the cell
# that creates adv_out above (80), so the saved output may come from an earlier
# adv_out -- re-run the notebook top to bottom to refresh it.
A1, A2, A3 = adv_out.attentions
tuple(attn.shape for attn in (A1, A2, A3))

(torch.Size([1, 4, 4]), torch.Size([1, 5, 5]), torch.Size([1, 6, 6]))

In [70]:
# First attention tensor from adv_out.attentions.
A1

tensor([[[0.00, 1.00, 0.00, 0.00],
         [0.00, 0.00, 1.00, 0.00],
         [0.00, 0.00, 0.00, 1.00],
         [0.50, 0.00, 0.00, 0.50]]])

In [71]:
# Second attention tensor from adv_out.attentions.
A2

tensor([[[0.00, 1.00, 0.00, 0.00, 0.00],
         [0.00, 0.00, 1.00, 0.00, 0.00],
         [0.00, 0.00, 0.00, 0.50, 0.50],
         [0.33, 0.00, 0.00, 0.33, 0.33],
         [0.25, 0.25, 0.00, 0.25, 0.25]]])

In [72]:
# Third attention tensor from adv_out.attentions.
A3

tensor([[[0.00, 1.00, 0.00, 0.00, 0.00, 0.00],
         [0.00, 0.00, 1.00, 0.00, 0.00, 0.00],
         [0.00, 0.00, 0.00, 0.33, 0.33, 0.33],
         [0.25, 0.00, 0.00, 0.25, 0.25, 0.25],
         [0.20, 0.20, 0.00, 0.20, 0.20, 0.20],
         [0.20, 0.20, 0.00, 0.20, 0.20, 0.20]]])

In [73]:
# Element-wise difference between the first and second chunks of `hints`.
tgt1 - tgt2

tensor([[[ 0,  0,  0,  0,  0, -1,  0,  0,  0,  0,  1,  0,  0,  1,  0,  0]]])

In [75]:
# Second chunk of `hints` minus the sum of the first and third chunks.
tgt2 - (tgt1 + tgt3)

tensor([[[-1,  0,  0, -1,  0,  1,  0, -1,  0,  0, -2,  0,  0, -1,  0,  0]]])

In [77]:
# Model logits cast to int64 for compact printing.
# NOTE(review): this cell's execution count (77) is lower than that of the cells
# that build adv_tokens (79) and adv_out (80); the saved output shows three
# identical rows, which does not match the adv_pred displayed by cell 86, so it
# looks stale -- re-run in order to confirm.
adv_out.logits.long()

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0]]])

In [78]:
# Full labels tensor, including the first entry along dim 1 that the
# comparisons against adv_pred slice off with labels[:, 1:].
labels

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
         [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]]])