# Transpose Attack using Simple Linear Layer Network

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch.nn.functional as F
from collections import Counter
import datasets
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import time

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# sentence transformer model
# transforms input sentences to sentence embeddings
st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').to(device)  # in GPU

In [4]:
# using StanfordNLP IMDB dataset
dataset = load_dataset("stanfordnlp/imdb")

In [5]:
def grayN(base, digits, value):
    '''
    A method for producing the grayN code for the spatial index
    @base: the base for the code
    @digits: Length of the code - should be equal to the output size of the model
    @value: the value to encode
    '''
    baseN = torch.zeros(digits)
    gray = torch.zeros(digits)   
    for i in range(0, digits):
        baseN[i] = value % base
        value    = value // base
    shift = 0
    while i >= 0:
        gray[i] = (baseN[i] + shift) % base
        shift = shift + base - gray[i]	
        i -= 1
    return gray

In [6]:
class CustomDataset(Dataset):
    def __init__(self, dataset:datasets.dataset_dict.DatasetDict, st_model:SentenceTransformer, is_train=True):
        if is_train:
            self.data = dataset['train']['text']
            self.targets = dataset['train']['label']
        else:
            self.data = dataset['test']['text']
            self.targets = dataset['test']['label']
        self.st_model = st_model  # sentence transformer model
        print("Encoding sentences ...")
        start_time = time.time()
        self.embeddings = self.st_model.encode(self.data)  # encode to sentence embeddings
        self.embeddings = torch.tensor(self.embeddings)  # convert to torch tensor
        end_time = time.time()
        print("Elaped time encoding =", end_time - start_time)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        with torch.no_grad():
            text = self.data[index]
            embedding = self.embeddings[index]
            target = torch.tensor(self.targets[index])

        return {
            "target": target,
            "embedding": embedding,
            "text": text
        }

In [7]:
class CustomMemDataset(Dataset):
    def __init__(self, dataset:datasets.dataset_dict.DatasetDict, st_model:SentenceTransformer, num_classes:int, memorize_percentage:float=0.1):
        if memorize_percentage < 0.0 or memorize_percentage > 1.0:
            raise ValueError("memorize_percentage must be between 0.0 and 1.0")
        self.data = dataset['train']['text']
        self.targets = dataset['train']['label']
        self.data = self.data[:int(len(self.data) * memorize_percentage)]  # leave only some portion of data
        self.targets = self.targets[:int(len(self.targets) * memorize_percentage)]  # leave only some portion of data
        self.indxs = torch.arange(len(self.data))
        self.code_size = num_classes # equal to the number of model outputs (0, 1)
        self.st_model = st_model
        print("Encoding sentences ...")
        start_time = time.time()
        self.embeddings = self.st_model.encode(self.data)  # encode to sentence embeddings
        self.embeddings = torch.tensor(self.embeddings)  # convert to torch tensor
        end_time = time.time()
        print("Elaped time encoding =", end_time - start_time)
        

        # create index+class embeddings and a reverse lookup
        self.C = Counter()
        self.codes = torch.zeros(len(self.targets), self.code_size)
        self.inputs = []
        self.input2index = {}
        with torch.no_grad():
            for i in range(len(self.data)):
                label = int(self.targets[i])
                self.C.update(str(label))

                class_code = torch.zeros(self.code_size)
                class_code[int(self.targets[i])] = 3  # gray code base (3)
                self.codes[i] = grayN(3, self.code_size, self.C[str(label)]) + class_code
        # need to implement lookup table

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        with torch.no_grad():
            text = self.data[index]
            embedding = self.embeddings[index]
            target = torch.tensor(self.targets[index])
            enc = self.codes[index]

        return {
            "gray_code": enc,
            "target": target,
            "embedding": embedding,
            "text": text
        }

In [8]:
train_data = CustomDataset(
    dataset=dataset,
    st_model=st_model,
    is_train=True
)
test_data = CustomDataset(
    dataset=dataset,
    st_model=st_model,
    is_train=False
)

Encoding sentences ...
Elaped time encoding = 50.47552156448364
Encoding sentences ...
Elaped time encoding = 49.294647455215454


In [9]:
train_mem_data = CustomMemDataset(
    dataset=dataset,
    st_model=st_model,
    num_classes=2,
    memorize_percentage=0.1
)

Encoding sentences ...
Elaped time encoding = 5.0084850788116455


In [10]:
input_size = 384
output_size = 2
hidden_layers = [1024, 1024, 1024, 1024]
batch_size = 128
epochs = 50

In [11]:
class MemNetFC(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.n_layers = len(kwargs['hidden_layers'])
        self.input_layer = nn.Linear(
            in_features=kwargs["input_size"],
            out_features=kwargs['hidden_layers'][0]
        )
        self.hidden_layers = nn.Sequential(*[nn.Linear(
            in_features=kwargs['hidden_layers'][i-1], 
            out_features=kwargs['hidden_layers'][i]
        ) for i in range(1, self.n_layers)])

        self.decoder_output_layer = nn.Linear(
            in_features=kwargs['hidden_layers'][-1],
            out_features=kwargs["output_size"]
        )

    def forward(self, x):
        activation = torch.relu(self.input_layer(x))
        for layer in self.hidden_layers:
            activation = torch.relu(layer(activation))
        predlabel = self.decoder_output_layer(activation)
        return predlabel

    def forward_transposed(self, code):
        activation = torch.relu(torch.matmul(code, self.decoder_output_layer.weight))
        for layer in self.hidden_layers[::-1]:
            activation = torch.relu(torch.matmul(activation, layer.weight))
        pred_embedding = torch.matmul(activation, self.input_layer.weight)
        return pred_embedding

In [12]:
model = MemNetFC(input_size=input_size, output_size=output_size, hidden_layers=hidden_layers).to(device)

In [13]:
model

MemNetFC(
  (input_layer): Linear(in_features=384, out_features=1024, bias=True)
  (hidden_layers): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
    (2): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (decoder_output_layer): Linear(in_features=1024, out_features=2, bias=True)
)

In [14]:
optimizer_cls = optim.AdamW(model.parameters(), lr=1e-4)
optimizer_mem = optim.AdamW(model.parameters(), lr=1e-3)

loss_cls = nn.CrossEntropyLoss()  # binary classification
loss_mem = nn.MSELoss()  # MSE for memorization

In [15]:
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [16]:
train_dataloader_mem = DataLoader(train_mem_data, batch_size=batch_size, shuffle=False)

In [17]:
# train code
def train_model(model, train_loader_cls, train_loader_mem, val_loader_cls, optimizer_cls, optimizer_mem, loss_cls, loss_mem, epochs, save_path, device):
    best_loss_r = np.inf
    for epoch in range(epochs):
        model.train()
        loss_c = 0
        loss_r = 0
        c = 0
        mem_iterator = iter(train_loader_mem)  # iterator for mem dataset
        for batch in tqdm(train_loader_cls):
            try:
                mem_batch = next(mem_iterator)
                code = mem_batch["gray_code"]
                mem_embeddings = mem_batch["embedding"]
            except:
                mem_iterator = iter(train_loader_mem)
                mem_batch = next(mem_iterator)
                code = mem_batch["gray_code"]
                mem_embeddings = mem_batch["embedding"]

            # process forward inputs
            embeddings = batch["embedding"]
            labels = batch["target"]
            # input_ids = input_ids.to(torch.float32)
            labels = labels.to(torch.int64)
            embeddings = embeddings.to(device)
            labels = labels.to(device)

            # process backward inputs
            mem_embeddings = mem_embeddings.to(device)
            code = code.to(device)

            # forward train
            optimizer_cls.zero_grad()
            optimizer_mem.zero_grad()
            pred_labels = model(embeddings)
            loss_classf = loss_cls(pred_labels, labels)
            loss_classf.backward()
            optimizer_cls.step()

            # backward train
            optimizer_mem.zero_grad()
            optimizer_cls.zero_grad()
            pred_embeddings = model.forward_transposed(code)
            loss_recon = loss_mem(pred_embeddings, mem_embeddings)
            loss_recon.backward()
            optimizer_mem.step()

            loss_c += loss_classf.item()
            loss_r += loss_recon.item()
            c += 1
        print("epoch: {}/{}, loss_c = {:.6f}, loss_r = {:.6f}".format(epoch + 1, epochs, loss_c/c, loss_r/c))

In [18]:
# train
train_model(
    model=model,
    train_loader_cls=train_dataloader,
    train_loader_mem=train_dataloader_mem,
    val_loader_cls=None,
    optimizer_cls=optimizer_cls,
    optimizer_mem=optimizer_mem,
    loss_cls=loss_cls,
    loss_mem=loss_mem,
    epochs=epochs,
    save_path=None,
    device=device
)

100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.47it/s]


epoch: 1/50, loss_c = 0.542978, loss_r = 0.001816


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 89.63it/s]


epoch: 2/50, loss_c = 0.891123, loss_r = 0.001794


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.32it/s]


epoch: 3/50, loss_c = 0.990719, loss_r = 0.001784


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 92.12it/s]


epoch: 4/50, loss_c = 0.860508, loss_r = 0.001776


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 92.00it/s]


epoch: 5/50, loss_c = 0.879628, loss_r = 0.001774


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.61it/s]


epoch: 6/50, loss_c = 0.717983, loss_r = 0.001772


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.25it/s]


epoch: 7/50, loss_c = 0.806482, loss_r = 0.001772


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 83.98it/s]


epoch: 8/50, loss_c = 0.759345, loss_r = 0.001771


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 80.66it/s]


epoch: 9/50, loss_c = 0.790268, loss_r = 0.001769


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 82.63it/s]


epoch: 10/50, loss_c = 0.621431, loss_r = 0.001769


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 89.14it/s]


epoch: 11/50, loss_c = 0.643925, loss_r = 0.001768


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.88it/s]


epoch: 12/50, loss_c = 0.633129, loss_r = 0.001768


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 92.37it/s]


epoch: 13/50, loss_c = 0.538998, loss_r = 0.001768


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 78.19it/s]


epoch: 14/50, loss_c = 0.553035, loss_r = 0.001768


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 86.57it/s]


epoch: 15/50, loss_c = 0.519988, loss_r = 0.001768


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.50it/s]


epoch: 16/50, loss_c = 0.511431, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.37it/s]


epoch: 17/50, loss_c = 0.490965, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 92.27it/s]


epoch: 18/50, loss_c = 0.491764, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.43it/s]


epoch: 19/50, loss_c = 0.479212, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 87.31it/s]


epoch: 20/50, loss_c = 0.471267, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 86.96it/s]


epoch: 21/50, loss_c = 0.455364, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 86.30it/s]


epoch: 22/50, loss_c = 0.468396, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 85.12it/s]


epoch: 23/50, loss_c = 0.454308, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 83.90it/s]


epoch: 24/50, loss_c = 0.444870, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 85.68it/s]


epoch: 25/50, loss_c = 0.444501, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 85.63it/s]


epoch: 26/50, loss_c = 0.432165, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 86.29it/s]


epoch: 27/50, loss_c = 0.425190, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 85.44it/s]


epoch: 28/50, loss_c = 0.422274, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 85.27it/s]


epoch: 29/50, loss_c = 0.416384, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 86.17it/s]


epoch: 30/50, loss_c = 0.404504, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.87it/s]


epoch: 31/50, loss_c = 0.399501, loss_r = 0.001768


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.90it/s]


epoch: 32/50, loss_c = 0.381329, loss_r = 0.001768


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.75it/s]


epoch: 33/50, loss_c = 0.411649, loss_r = 0.001768


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 88.59it/s]


epoch: 34/50, loss_c = 0.379569, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 87.46it/s]


epoch: 35/50, loss_c = 0.371517, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 89.70it/s]


epoch: 36/50, loss_c = 0.354764, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.48it/s]


epoch: 37/50, loss_c = 0.340733, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 89.42it/s]


epoch: 38/50, loss_c = 0.332854, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 88.64it/s]


epoch: 39/50, loss_c = 0.315745, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.01it/s]


epoch: 40/50, loss_c = 0.314645, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.82it/s]


epoch: 41/50, loss_c = 0.296689, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.11it/s]


epoch: 42/50, loss_c = 0.284793, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 88.80it/s]


epoch: 43/50, loss_c = 0.273449, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 88.56it/s]


epoch: 44/50, loss_c = 0.257031, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.66it/s]


epoch: 45/50, loss_c = 0.240672, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.94it/s]


epoch: 46/50, loss_c = 0.237538, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 91.41it/s]


epoch: 47/50, loss_c = 0.231559, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 90.57it/s]


epoch: 48/50, loss_c = 0.221135, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 92.11it/s]


epoch: 49/50, loss_c = 0.206725, loss_r = 0.001767


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:02<00:00, 82.79it/s]

epoch: 50/50, loss_c = 0.193929, loss_r = 0.001767





In [19]:
# get accuracy of test data
def test_acc(model, data_loader, device):
    correct=0
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            embeddings = batch["embedding"]
            labels = batch["target"]
            embeddings = embeddings.to(device)
            labels = labels.to(device)
            output = model(embeddings)
            ypred = output.data.max(1, keepdim=True)[1].squeeze()
            correct += ypred.eq(labels).sum()
    acc = correct/len(data_loader.dataset)
    return acc

In [20]:
test_acc(model, test_dataloader, device)

tensor(0.6919, device='cuda:0')