# Transpose Attack using Simple Linear Layer Network

In [16]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch.nn.functional as F
from collections import Counter
import datasets
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [17]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [18]:
# Sentence-transformer encoder: turns input sentences into sentence embeddings.
# It is instantiated first and then moved onto the selected device
# (the GPU when one is available).
st_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
st_model = st_model.to(device)

In [19]:
# Fetch the StanfordNLP IMDB dataset; its "train" and "test" splits are
# the ones consumed by the rest of this notebook.
dataset = load_dataset(path="stanfordnlp/imdb")

In [20]:
# Keep direct handles to the two splits of the loaded dataset.
train_dataset, test_dataset = dataset["train"], dataset["test"]

In [21]:
def grayN(base, digits, value):
    """Encode ``value`` as a base-``base`` Gray code of ``digits`` digits.

    Consecutive values produce codes that differ in exactly one digit.
    Only the lowest ``digits`` base-``base`` digits of ``value`` are
    encoded, so values of ``base ** digits`` or more wrap around and
    reuse earlier codes.

    @base: the base for the code
    @digits: length of the code - should be equal to the output size of the model
    @value: the value to encode
    @return: 1-D tensor of length ``digits``; index 0 holds the least
             significant digit. ``digits == 0`` yields an empty tensor.
    """
    baseN = torch.zeros(digits)
    gray = torch.zeros(digits)

    # Plain base-N expansion, least significant digit first.
    for i in range(digits):
        baseN[i] = value % base
        value = value // base

    # Convert to Gray code from the most significant digit downwards: each
    # digit is shifted by the running sum of the higher Gray digits.
    # Iterating an explicit reversed range (instead of reusing the previous
    # loop's index) keeps this well-defined when ``digits`` is 0.
    shift = 0
    for i in reversed(range(digits)):
        gray[i] = (baseN[i] + shift) % base
        # Add ``base`` before subtracting so the shift never goes negative.
        shift = shift + base - gray[i]
    return gray

In [72]:
class CustomDataset(Dataset):
    """Map-style dataset serving review texts with their labels and embeddings.

    Embeddings are not precomputed: every item lookup runs the sentence
    transformer on that item's text.
    """

    def __init__(self, dataset:datasets.dataset_dict.DatasetDict, st_model:SentenceTransformer, is_train=True):
        """Select the 'train' split when ``is_train`` is truthy, else 'test'."""
        split_name = 'train' if is_train else 'test'
        self.data = dataset[split_name]['text']
        self.targets = dataset[split_name]['label']
        self.st_model = st_model  # sentence transformer used in __getitem__

    def __len__(self):
        """Number of samples (one label per sample)."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return a dict with the label tensor, flattened embedding and raw text."""
        with torch.no_grad():
            text = self.data[index]
            # encode() receives a one-element batch; the result is flattened
            # to a 1-D embedding vector below.
            encoded = self.st_model.encode([text])
            sample = {
                "target": torch.tensor(self.targets[index]),
                "embedding": torch.tensor(encoded).view(-1),
                "text": text,
            }
        return sample

In [73]:
# Embedding-producing dataset wrappers: first over the train split,
# then over the test split.
train_data, test_data = (
    CustomDataset(dataset=dataset, st_model=st_model, is_train=use_train)
    for use_train in (True, False)
)

In [74]:
class CustomMemDataset(Dataset):
    """Leading portion of the 'train' split paired with per-sample codes.

    Each kept sample gets a code vector of length ``num_classes``: the
    base-3 Gray code of the sample's running index within its class, plus
    a class marker (the value 3 added at the position of the label).
    Items expose that code alongside the label, the sentence embedding
    (computed on access) and the raw text.

    Args:
        dataset: dataset dict providing ``dataset['train']['text']`` and
            ``dataset['train']['label']``.
        st_model: sentence transformer used to embed texts in ``__getitem__``.
        num_classes: length of each code vector (equal to the number of
            model outputs). Labels are used as indices into the code, so
            they must be smaller than ``num_classes``.
        memorize_percentage: fraction in [0.0, 1.0] of the split to keep,
            taken from the start of the split.

    Raises:
        ValueError: if ``memorize_percentage`` is outside [0.0, 1.0].

    Warns:
        RuntimeWarning: if some class holds more samples than there are
            distinct Gray codes (``3 ** num_classes``); codes then repeat
            within that class and no longer identify a single sample.
    """

    def __init__(self, dataset:datasets.dataset_dict.DatasetDict, st_model:SentenceTransformer, num_classes:int, memorize_percentage:float=0.1):
        import warnings  # local import: only needed for the collision warning

        if memorize_percentage < 0.0 or memorize_percentage > 1.0:
            raise ValueError("memorize_percentage must be between 0.0 and 1.0")
        self.data = dataset['train']['text']
        self.targets = dataset['train']['label']
        self.data = self.data[:int(len(self.data) * memorize_percentage)]  # leave only some portion of data
        self.targets = self.targets[:int(len(self.targets) * memorize_percentage)]  # leave only some portion of data
        self.indxs = torch.arange(len(self.data))
        self.code_size = num_classes # equal to the number of model outputs (0, 1)
        self.st_model = st_model

        gray_base = 3  # gray code base; also the value of the class marker
        capacity = gray_base ** self.code_size  # distinct codes per class

        # create index+class embeddings and a reverse lookup
        self.C = Counter()  # per-class sample counts, keyed by str(label)
        self.codes = torch.zeros(len(self.targets), self.code_size)
        self.inputs = []
        self.input2index = {}
        with torch.no_grad():
            for i in range(len(self.data)):
                label = int(self.targets[i])
                key = str(label)
                # Count the whole label. Counter.update(key) would iterate
                # the string's characters and miscount multi-digit labels
                # (e.g. "10" would increment "1" and "0").
                self.C[key] += 1

                class_code = torch.zeros(self.code_size)
                class_code[label] = gray_base
                self.codes[i] = grayN(gray_base, self.code_size, self.C[key]) + class_code
        # need to implement lookup table

        # grayN keeps only ``code_size`` digits, so indices beyond
        # ``capacity`` wrap around and reuse codes already handed out.
        overflowing = sorted(k for k, count in self.C.items() if count > capacity)
        if overflowing:
            warnings.warn(
                "gray codes of length {} (base {}) give only {} distinct codes per class; "
                "codes repeat for class(es) {}".format(
                    self.code_size, gray_base, capacity, ", ".join(overflowing)
                ),
                RuntimeWarning,
                stacklevel=2,
            )

    def __len__(self):
        """Number of kept samples."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return a dict with the code, label tensor, flattened embedding and text."""
        with torch.no_grad():
            text = self.data[index]
            embedding = self.st_model.encode([text])  # sentence embedding
            embedding = torch.tensor(embedding).view(-1)
            target = torch.tensor(self.targets[index])
            enc = self.codes[index]

        return {
            "gray_code": enc,
            "target": target,
            "embedding": embedding,
            "text": text
        }

In [75]:
# Memorization set: the leading 10% of the train split, with code vectors
# of length 2 (one entry per model output).
train_mem_data = CustomMemDataset(dataset, st_model, num_classes=2, memorize_percentage=0.1)

In [77]:
# Network dimensions: input width, number of outputs, and the width of
# each hidden layer (three layers of 1024 units).
input_size, output_size = 384, 2
hidden_layers = [1024] * 3

# Training schedule.
batch_size, epochs = 128, 50

In [79]:
class MemNetFC(nn.Module):
    """Fully connected network that can also be run through its transposed weights.

    Keyword arguments:
        input_size: number of input features.
        hidden_layers: sequence of hidden widths (at least one entry).
        output_size: number of outputs.
    """

    def __init__(self, **kwargs):
        super().__init__()
        widths = kwargs['hidden_layers']
        self.n_layers = len(widths)

        # input -> first hidden width
        self.input_layer = nn.Linear(kwargs["input_size"], widths[0])

        # one Linear per pair of consecutive hidden widths
        self.hidden_layers = nn.Sequential(
            *(nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:]))
        )

        # last hidden width -> outputs
        self.decoder_output_layer = nn.Linear(widths[-1], kwargs["output_size"])

    def forward(self, x):
        """Forward pass: ReLU after every layer except the output layer."""
        hidden = torch.relu(self.input_layer(x))
        for linear in self.hidden_layers:
            hidden = torch.relu(linear(hidden))
        return self.decoder_output_layer(hidden)

    def forward_transposed(self, code):
        """Map a code back to input space through the layers in reverse order.

        Each step multiplies by a layer's weight matrix as stored (biases
        are not used); ReLU is applied after every step except the last.
        """
        hidden = torch.relu(code @ self.decoder_output_layer.weight)
        for linear in reversed(list(self.hidden_layers)):
            hidden = torch.relu(hidden @ linear.weight)
        return hidden @ self.input_layer.weight

In [80]:
# Build the network from the configured sizes, then move it to the target device.
model = MemNetFC(
    input_size=input_size,
    hidden_layers=hidden_layers,
    output_size=output_size,
)
model = model.to(device)

In [81]:
model

MemNetFC(
  (input_layer): Linear(in_features=384, out_features=1024, bias=True)
  (hidden_layers): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (decoder_output_layer): Linear(in_features=1024, out_features=2, bias=True)
)

In [82]:
# Objectives: cross-entropy for the classification pass, mean squared error
# for the embedding reconstruction (memorization) pass.
loss_cls = nn.CrossEntropyLoss()
loss_mem = nn.MSELoss()

# One AdamW optimizer per objective, both over the same model parameters;
# the memorization optimizer uses a 10x larger learning rate.
optimizer_cls = optim.AdamW(params=model.parameters(), lr=1e-4)
optimizer_mem = optim.AdamW(params=model.parameters(), lr=1e-3)

In [84]:
# Classification loaders. The training loader reshuffles every epoch so that
# batches do not follow the stored sample order of the split (with
# shuffle=False every epoch replays the exact same batch sequence, and any
# label grouping in the stored order would yield single-class batches —
# NOTE(review): confirm the stored order of the IMDB train split).
# The test loader keeps a fixed order so evaluation is reproducible.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [85]:
# Batches for the memorization task, served in a fixed (unshuffled) order.
train_dataloader_mem = DataLoader(
    dataset=train_mem_data,
    batch_size=batch_size,
    shuffle=False,
)

In [89]:
# train code
def train_model(model, train_loader_cls, train_loader_mem, val_loader_cls, optimizer_cls, optimizer_mem, loss_cls, loss_mem, epochs, save_path, device):
    """Jointly train the classification pass and the transposed (memorization) pass.

    Every step consumes one classification batch and one memorization
    batch. The classification loss is optimised through ``model(...)``,
    then the reconstruction loss is optimised through
    ``model.forward_transposed(...)``. One epoch is one full pass over
    ``train_loader_cls``; the memorization loader is restarted whenever
    it runs out within an epoch.

    Args:
        model: network exposing ``forward`` and ``forward_transposed``.
        train_loader_cls: loader yielding dicts with "embedding" and "target".
        train_loader_mem: loader yielding dicts with "gray_code" and "embedding".
        val_loader_cls: currently unused; accepted for interface compatibility.
        optimizer_cls: optimizer stepped after the classification loss.
        optimizer_mem: optimizer stepped after the reconstruction loss.
        loss_cls: criterion called as ``loss_cls(predictions, labels)``.
        loss_mem: criterion called as ``loss_mem(reconstruction, embeddings)``.
        epochs: number of epochs to run.
        save_path: if not None, ``model.state_dict()`` is written here each
            time the epoch-mean reconstruction loss reaches a new minimum.
        device: device the batches are moved to.
    """
    best_loss_r = np.inf
    for epoch in range(epochs):
        model.train()
        loss_c = 0
        loss_r = 0
        c = 0
        mem_iterator = iter(train_loader_mem)  # iterator for mem dataset
        for batch in tqdm(train_loader_cls):
            # Restart the memorization loader when it is exhausted. Only
            # StopIteration is handled: a bare except would also swallow
            # KeyboardInterrupt and genuine data-loading errors.
            try:
                mem_batch = next(mem_iterator)
            except StopIteration:
                mem_iterator = iter(train_loader_mem)
                mem_batch = next(mem_iterator)

            # process backward inputs
            code = mem_batch["gray_code"].to(device)
            mem_embeddings = mem_batch["embedding"].to(device)

            # process forward inputs (labels are cast to int64 for the criterion)
            embeddings = batch["embedding"].to(device)
            labels = batch["target"].to(torch.int64).to(device)

            # forward train
            optimizer_cls.zero_grad()
            optimizer_mem.zero_grad()
            pred_labels = model(embeddings)
            loss_classf = loss_cls(pred_labels, labels)
            loss_classf.backward()
            optimizer_cls.step()

            # backward train
            optimizer_mem.zero_grad()
            optimizer_cls.zero_grad()
            pred_embeddings = model.forward_transposed(code)
            loss_recon = loss_mem(pred_embeddings, mem_embeddings)
            loss_recon.backward()
            optimizer_mem.step()

            loss_c += loss_classf.item()
            loss_r += loss_recon.item()
            c += 1

        # Guard the averages against an empty classification loader.
        steps = max(c, 1)
        mean_loss_c = loss_c / steps
        mean_loss_r = loss_r / steps
        print("epoch: {}/{}, loss_c = {:.6f}, loss_r = {:.6f}".format(epoch + 1, epochs, mean_loss_c, mean_loss_r))

        # Checkpoint the weights with the best reconstruction loss so far.
        if c > 0 and mean_loss_r < best_loss_r:
            best_loss_r = mean_loss_r
            if save_path is not None:
                torch.save(model.state_dict(), save_path)

In [90]:
# Launch joint training; no validation loader and no checkpoint path are supplied.
train_model(
    model,
    train_dataloader,
    train_dataloader_mem,
    val_loader_cls=None,
    optimizer_cls=optimizer_cls,
    optimizer_mem=optimizer_mem,
    loss_cls=loss_cls,
    loss_mem=loss_mem,
    epochs=epochs,
    save_path=None,
    device=device,
)

  5%|███▊                                                                              | 9/196 [00:27<09:29,  3.04s/it]

KeyboardInterrupt

