## Testing probes for white-box membership inference

### Load data

In [13]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
import os
import sys
from tqdm import tqdm

path = '/home/ubuntu/gld/train-data-probes/data/1b'

In [14]:
os.listdir(path)

['pile_all_mem_status.pkl',
 'pile_all_hidden_states.pt',
 'pile_test_all_mem_status.pkl',
 'pile_test_all_tokens.pkl',
 'pile_test_all_generations.pkl',
 'pile_all_generations.pkl',
 'pile_test_all_hidden_states.pt',
 'pile_all_tokens.pkl']

In [15]:
pile_hidden_states = torch.load(os.path.join(path, 'pile_all_hidden_states.pt'))
pile_hidden_states = pile_hidden_states.cpu().detach().numpy()
pile_hidden_states = pile_hidden_states[:,:,-1]
pile_hidden_states.shape

(5000, 16, 2048)

In [16]:
pile_test_hidden_states = torch.load(os.path.join(path, 'pile_test_all_hidden_states.pt'))
pile_test_hidden_states = pile_test_hidden_states.cpu().detach().numpy()
pile_test_hidden_states = pile_test_hidden_states[:,:,-1]
pile_test_hidden_states.shape

(10000, 16, 2048)

In [17]:
pile_hidden_states = pile_hidden_states[:5000]
pile_test_hidden_states = pile_test_hidden_states[:5000]

In [18]:
all_hiddens = np.concatenate((pile_hidden_states, pile_test_hidden_states), axis=0)
labels = np.concatenate((np.ones(pile_hidden_states.shape[0]), np.zeros(pile_test_hidden_states.shape[0])), axis=0)
all_hiddens.shape, labels.shape

((10000, 16, 2048), (10000,))

In [20]:
accs, aucs = [], []
for i in tqdm(range(all_hiddens.shape[1])):
    X = all_hiddens[:,i]
    X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.2, random_state=42)
    clf = LogisticRegression(random_state=0, max_iter=1000).fit(X_train, y_train)
    accs.append(clf.score(X_test, y_test))
    y_pred = clf.predict(X_test)
    aucs.append(roc_auc_score(y_test, y_pred))
    print(f'Feature {i} - Accuracy: {accs[-1]}, AUC: {aucs[-1]}')

  6%|▋         | 1/16 [00:04<01:08,  4.55s/it]

Feature 0 - Accuracy: 0.492, AUC: 0.49208285993183015


 12%|█▎        | 2/16 [00:07<00:48,  3.49s/it]

Feature 1 - Accuracy: 0.484, AUC: 0.4839976956681762


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
 19%|█▉        | 3/16 [00:12<00:54,  4.16s/it]

Feature 2 - Accuracy: 0.499, AUC: 0.49904786289225644


 25%|██▌       | 4/16 [00:17<00:52,  4.40s/it]

Feature 3 - Accuracy: 0.485, AUC: 0.4850698500584084


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
 31%|███▏      | 5/16 [00:22<00:50,  4.63s/it]

Feature 4 - Accuracy: 0.5065, AUC: 0.5064469283576836


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
 38%|███▊      | 6/16 [00:27<00:48,  4.83s/it]

Feature 5 - Accuracy: 0.494, AUC: 0.49403514106031265


 44%|████▍     | 7/16 [00:31<00:42,  4.73s/it]

Feature 6 - Accuracy: 0.4935, AUC: 0.49343305435982776


 50%|█████     | 8/16 [00:36<00:36,  4.61s/it]

Feature 7 - Accuracy: 0.4885, AUC: 0.4885163463538749


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
 56%|█████▋    | 9/16 [00:41<00:33,  4.74s/it]

Feature 8 - Accuracy: 0.4775, AUC: 0.4774907586692484


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
 62%|██████▎   | 10/16 [00:46<00:29,  4.84s/it]

Feature 9 - Accuracy: 0.4865, AUC: 0.48638403930165947


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
 69%|██████▉   | 11/16 [00:51<00:24,  4.90s/it]

Feature 10 - Accuracy: 0.479, AUC: 0.47892496519498806


 75%|███████▌  | 12/16 [00:56<00:19,  4.92s/it]

Feature 11 - Accuracy: 0.5025, AUC: 0.5023743419052344


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
 81%|████████▏ | 13/16 [01:01<00:14,  4.91s/it]

Feature 12 - Accuracy: 0.483, AUC: 0.48294954473444174


 88%|████████▊ | 14/16 [01:07<00:10,  5.32s/it]

Feature 13 - Accuracy: 0.4815, AUC: 0.4815633451216975


 94%|█████████▍| 15/16 [01:12<00:05,  5.19s/it]

Feature 14 - Accuracy: 0.493, AUC: 0.492914979757085


100%|██████████| 16/16 [01:15<00:00,  4.72s/it]

Feature 15 - Accuracy: 0.497, AUC: 0.4968555471987966





In [29]:
# try 2 layer nn
class TwoLayerProbe(nn.Module):
    def __init__(self, input_dim):
        super(TwoLayerProbe, self).__init__()
        self.fc1 = nn.Linear(input_dim, 100)
        self.fc2 = nn.Linear(100, 1)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

def train(model, X, y, epochs=100, lr=0.01): 
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in tqdm(range(epochs)):
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

def test(model, X, y):
    with torch.no_grad():
        output = model(X)
        preds = (output > 0.5).float()
        acc = (preds == y).float().mean()
    return acc
    

In [30]:
for i in range(all_hiddens.shape[1]):
    X = torch.tensor(all_hiddens[:,i], dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model = TwoLayerProbe(X_train.shape[1])
    train(model, X_train, y_train, epochs=1000, lr=0.01)
    acc = test(model, X_test, y_test)
    print(f'Feature {i} - Accuracy: {acc}')
    accs.append(acc)


  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:07<00:00, 130.12it/s]


Feature 0 - Accuracy: 0.49950000643730164


100%|██████████| 1000/1000 [00:07<00:00, 132.10it/s]


Feature 1 - Accuracy: 0.4959999918937683


100%|██████████| 1000/1000 [00:07<00:00, 134.86it/s]


Feature 2 - Accuracy: 0.5065000057220459


100%|██████████| 1000/1000 [00:07<00:00, 135.58it/s]


Feature 3 - Accuracy: 0.4884999990463257


100%|██████████| 1000/1000 [00:07<00:00, 135.63it/s]


Feature 4 - Accuracy: 0.49050000309944153


100%|██████████| 1000/1000 [00:07<00:00, 133.40it/s]


Feature 5 - Accuracy: 0.48649999499320984


100%|██████████| 1000/1000 [00:07<00:00, 134.99it/s]


Feature 6 - Accuracy: 0.4884999990463257


100%|██████████| 1000/1000 [00:07<00:00, 134.62it/s]


Feature 7 - Accuracy: 0.5005000233650208


100%|██████████| 1000/1000 [00:07<00:00, 135.67it/s]


Feature 8 - Accuracy: 0.5105000138282776


100%|██████████| 1000/1000 [00:07<00:00, 133.99it/s]


Feature 9 - Accuracy: 0.49900001287460327


100%|██████████| 1000/1000 [00:07<00:00, 133.17it/s]


Feature 10 - Accuracy: 0.4724999964237213


100%|██████████| 1000/1000 [00:07<00:00, 133.32it/s]


Feature 11 - Accuracy: 0.4884999990463257


100%|██████████| 1000/1000 [00:07<00:00, 134.62it/s]


Feature 12 - Accuracy: 0.5049999952316284


100%|██████████| 1000/1000 [00:07<00:00, 137.56it/s]


Feature 13 - Accuracy: 0.476500004529953


100%|██████████| 1000/1000 [00:07<00:00, 134.39it/s]


Feature 14 - Accuracy: 0.48899999260902405


100%|██████████| 1000/1000 [00:07<00:00, 134.47it/s]

Feature 15 - Accuracy: 0.4894999861717224



