# pplm_discrim

In [1]:
import argparse
import csv
import json
import math
import time
import pickle
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from pplm_classification_head import ClassificationHead
from torch import nn
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange

from transformers import AutoTokenizer, AutoModelForCausalLM


torch.manual_seed(0)
np.random.seed(0)
EPSILON = 1e-10
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100


In [2]:
class Discriminator(nn.Module):
    """Transformer encoder followed by a Classification Head"""

    def __init__(self, class_size, pretrained_model="llama", cached_mode=False, device="cpu"):
        super().__init__()
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("/home/wooseok/llama-7b-hf")
        self.encoder = AutoModelForCausalLM.from_pretrained("/home/wooseok/llama-7b-hf", device_map="auto", load_in_8bit=True)
        self.embed_size = self.encoder.model.config.hidden_size
        self.classifier_head = ClassificationHead(class_size=class_size, embed_size=self.embed_size)
        self.cached_mode = cached_mode
        self.device = device

    def get_classifier(self):
        return self.classifier_head

    def train_custom(self):
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier_head.train()

    def avg_representation(self, x):
        mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
        hidden = self.encoder.model(x)["last_hidden_state"]
        masked_hidden = hidden * mask
        avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
        return avg_hidden

    def forward(self, x):
        if self.cached_mode:
            avg_hidden = x.to(self.device)
        else:
            avg_hidden = self.avg_representation(x.to(self.device))

        logits = self.classifier_head(avg_hidden)
        probs = nn.functional.log_softmax(logits, dim=-1)

        return probs
class Dataset(data.Dataset):
    def __init__(self, X, y):
        """Reads source and target sequences from txt files."""
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        data = {}
        data["X"] = self.X[index]
        data["y"] = self.y[index]
        return data
def cached_collate_fn(data):
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch = torch.cat(item_info["X"], 0)
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch
def collate_fn(data):
    def pad_sequences(sequences):
        lengths = [len(seq) for seq in sequences]

        padded_sequences = torch.zeros(len(sequences), max(lengths)).long()  # padding value = 0

        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_sequences[i, :end] = seq[:end]

        return padded_sequences, lengths

    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch, _ = pad_sequences(item_info["X"])
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch
def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10, device="cpu"):
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = nn.functional.nll_loss(output_t, target_t)
        loss.backward(retain_graph=True)
        optimizer.step()

        samples_so_far += len(input_t)

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far,
                    len(data_loader.dataset),
                    100 * samples_so_far / len(data_loader.dataset),
                    loss.item(),
                )
            )
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
    data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)

    xs = []
    ys = []
    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
        with torch.no_grad():
            x = x.to(device)
            avg_rep = discriminator.avg_representation(x).cpu().detach()
            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
            xs += avg_rep_list
            ys += y.cpu().numpy().tolist()

    data_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
    )

    return data_loader, xs, ys
def evaluate_performance(data_loader, discriminator, device="cpu"):
    discriminator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for input_t, target_t in data_loader:
            input_t, target_t = input_t.to(device), target_t.to(device)
            output_t = discriminator(input_t)
            # sum up batch loss
            test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
            # get the index of the max log-probability
            pred_t = output_t.argmax(dim=1, keepdim=True)
            correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()

    test_loss /= len(data_loader.dataset)

    print(
        "Performance on test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
        )
    )

def predict(input_sentence, model, classes, cached=False, device="cpu"):
    input_t = model.tokenizer.encode(input_sentence)
    input_t = torch.tensor([input_t], dtype=torch.long, device=device)
    if cached:
        input_t = model.avg_representation(input_t)

    log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
    print("Input sentence:", input_sentence)
    print(
        "Predictions:",
        ", ".join("{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in zip(classes, log_probs)),
    )

In [None]:
'''
import torch
import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("/home/wooseok/llama-7b-hf")
model = AutoModelForCausalLM.from_pretrained("/home/wooseok/llama-7b-hf", device_map="auto", load_in_8bit=True)
'''

In [4]:
dicsriminator = Discriminator(class_size=5, pretrained_model="llama", cached_mode=False, device='cuda')

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
  warn(msg)



Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /home/wooseok/miniconda3/envs/mh/lib/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 6.1
CUDA SETUP: Detected CUDA version 113
CUDA SETUP: Loading binary /home/wooseok/miniconda3/envs/mh/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda113_nocublaslt.so...


Loading checkpoint shards: 100%|████████████| 33/33 [00:11<00:00,  2.92it/s]


In [8]:
encoded_inputs = dicsriminator.tokenizer('My dog died today', return_tensors='pt')
encoded_inputs = { k: v.to('cuda') for k, v in encoded_inputs.items() }

In [9]:
encoded_inputs

{'input_ids': tensor([[    1,  1619, 11203,  6423,  9826]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1]], device='cuda:0')}

In [26]:
encoded_inputs['input_ids'].dtype

torch.int64

In [27]:
x = torch.LongTensor([[50264, 1, 1619, 11203, 6423, 9826]]).to('cuda')

In [28]:
outputs = dicsriminator.encoder(x, output_hidden_states=True)

/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [4,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [4,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [4,0,0], thread: [34,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [4,0,0], thread: [35,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [4,0,0], thread: [36,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pyto

In [22]:
len(dicsriminator.tokenizer.get_vocab())

32000

In [16]:
outputs = dicsriminator.encoder(**encoded_inputs, output_hidden_states=True)

In [17]:
outputs.keys()

odict_keys(['logits', 'past_key_values', 'hidden_states'])

In [20]:
outputs.hidden_states[-1].shape

torch.Size([1, 5, 4096])

In [None]:
lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=x,return_dict=True,output_hidden_states=True)

# dataset

In [None]:
dataset = "SST"
no_cuda =False
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
pretrained_model = "llama"
cached= True


print("Preprocessing {} dataset...".format(dataset))
start = time.time()

idx2class = ["positive", "negative", "very positive", "very negative", "neutral"]
class2idx = {c: i for i, c in enumerate(idx2class)}

discriminator = Discriminator(
    class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
).to(device)


text = torchtext_data.Field()
label = torchtext_data.Field(sequential=False)
train_data, val_data, test_data = datasets.SST.splits(
    text,
    label,
    fine_grained=True,
    train_subtrees=True,
)

x = []
y = []
for i in trange(len(train_data), ascii=True):
    seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
    seq = discriminator.tokenizer.encode(seq)
    #seq = torch.tensor([50256] + seq, device='cpu', dtype=torch.long)
    seq = torch.tensor(seq, device='cpu', dtype=torch.long)
    if len(seq)< 60 and len(seq)> 15 :
        x.append(seq)
        y.append(class2idx[vars(train_data[i])["label"]])
train_dataset = Dataset(x, y)

test_x = []
test_y = []
for i in trange(len(test_data), ascii=True):
    seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
    seq = discriminator.tokenizer.encode(seq)
    #seq = torch.tensor([50256] + seq, device='cpu', dtype=torch.long)
    seq = torch.tensor(seq, device='cpu', dtype=torch.long)
    test_x.append(seq)
    test_y.append(class2idx[vars(test_data[i])["label"]])
test_dataset = Dataset(test_x, test_y)

discriminator_meta = {
    "class_size": len(idx2class),
    "embed_size": discriminator.embed_size,
    "pretrained_model": pretrained_model,
    "class_vocab": class2idx,
    "default_class": 2,
}
end = time.time()
print("Preprocessed {} data points".format(len(train_dataset) + len(test_dataset)))
print("Data preprocessing took: {:.3f}s".format(end - start))



Preprocessing SST dataset...


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.



Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /home/wooseok/miniconda3/envs/mh/lib/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 6.1
CUDA SETUP: Detected CUDA version 113
CUDA SETUP: Loading binary /home/wooseok/miniconda3/envs/mh/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda113_nocublaslt.so...


  warn(msg)


AttributeError: module 'torch.backends' has no attribute 'mps'

In [4]:
len(train_dataset)

34345

In [5]:
batch_size = 32
save_model = True
epochs = 10
log_interval =10 

if cached:
    print("Building representation cache...")

    start = time.time()

    train_loader,x_tr,y_tr = get_cached_data_loader(train_dataset, batch_size, discriminator, shuffle=True, device=device)

    test_loader,x_te,y_te = get_cached_data_loader(test_dataset, batch_size, discriminator, device=device)

    end = time.time()
    print("Building representation cache took: {:.3f}s".format(end - start))
else:
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)

    
with open('x_tr.pkl','wb') as f:
    pickle.dump(x_tr,f)
with open('y_tr.pkl','wb') as f:
    pickle.dump(y_tr,f)
with open('x_te.pkl','wb') as f:
    pickle.dump(x_te,f)
with open('y_te.pkl','wb') as f:
    pickle.dump(y_te,f)
    

if save_model:
    with open("{}_classifier_head_meta.json".format(dataset), "w") as meta_file:
        json.dump(discriminator_meta, meta_file)

optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

Building representation cache...


100%|###################################| 1074/1074 [57:42<00:00,  3.22s/it]
100%|#######################################| 70/70 [04:35<00:00,  3.94s/it]


Building representation cache took: 3737.899s


In [5]:
next(iter(train_loader))[0].shape

torch.Size([1, 3])

In [6]:
x, y = next(iter(train_loader))

In [10]:
x = x.to('cuda')

In [13]:
discriminator.encoder(x) 

In [None]:
lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=x,return_dict=True,output_hidden_states=True)


In [9]:
output = discriminator(x)

/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [6,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [6,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [6,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [6,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1634272068694/work/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [6,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1

# Train

In [6]:
for epoch in range(epochs):
    start = time.time()
    print("\nEpoch", epoch + 1)

    train_epoch(
        discriminator=discriminator,
        data_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
        log_interval=log_interval,
        device=device,
    )
    evaluate_performance(data_loader=test_loader, discriminator=discriminator, device=device)

    end = time.time()
    print("Epoch took: {:.3f}s".format(end - start))

    print("\nExample prediction")
    predict(example_sentence, discriminator, idx2class, cached=cached, device=device)

    if save_model:
        # torch.save(discriminator.state_dict(),
        #           "{}_discriminator_{}.pt".format(
        #               args.dataset, epoch + 1
        #               ))
        torch.save(
            discriminator.get_classifier().state_dict(),
            "{}_classifier_head_epoch_{}.pt".format(dataset, epoch + 1),
        )



Epoch 1
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
dd??
d