# pplm_discrim

In [1]:
import argparse
import csv
import json
import math
import time
import pickle,os
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
from pplm_classification_head import ClassificationHead
from torch import nn
from tqdm import tqdm, trange

from transformers import BitsAndBytesConfig,LlamaTokenizer,LlamaForCausalLM
import torch

from peft import prepare_model_for_kbit_training
# Expose only GPU index 3 to this process.
# NOTE(review): presumably this has to run before the first CUDA call to take effect - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# Seed the torch and NumPy random number generators so runs are repeatable.
torch.manual_seed(0)
np.random.seed(0)
# Added to the token-count denominator in Discriminator.avg_representation so that a
# row consisting only of padding does not divide by zero.
EPSILON = 1e-10
# Sample input for predict(); the only call site in this notebook is commented out.
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
# NOTE(review): not referenced anywhere else in the visible notebook; no truncation is applied.
max_length_seq = 100


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint identifier passed to both the tokenizer and the encoder loader.
model_path ="meta-llama/Llama-2-7b-chat-hf"
# bitsandbytes settings: 4-bit loading, NF4 quantisation type, double quantisation,
# bfloat16 compute dtype. Consumed by Discriminator.__init__.
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = LlamaTokenizer.from_pretrained(model_path)
# Reuse the end-of-sequence token as the pad token.
# NOTE: collate_fn pads batches with the literal id 0 and avg_representation masks on
# id 0, so this pad token is not what marks padding in those batches.
tokenizer.pad_token = tokenizer.eos_token

In [3]:
class Discriminator(nn.Module):
    """Quantised LLaMA encoder with a classification head on top.

    The encoder is loaded from the module-level ``model_path`` using the
    module-level ``nf4_config``; the module-level ``tokenizer`` is attached
    as ``self.tokenizer``. The ``pretrained_model`` argument is accepted but
    not read by this class.
    """

    def __init__(self, class_size, pretrained_model="llama-chat", cached_mode=False, device="cpu"):
        """Load the encoder and build a head with ``class_size`` outputs.

        Args:
            class_size: number of classes passed to ``ClassificationHead``.
            pretrained_model: unused; kept for interface compatibility.
            cached_mode: when true, ``forward`` treats its input as an
                already-pooled representation and skips the encoder.
            device: device that inputs and the padding mask are moved to.
        """
        super().__init__()
        self.cached_mode = cached_mode
        self.device = device
        self.tokenizer = tokenizer
        # Author's memory notes (translated): load_in_8bit=True brings the
        # model to about 7.7GB, the 4-bit config to about 4.4GB, and
        # device_map="auto" spills to the CPU when the GPU is full.
        self.encoder = LlamaForCausalLM.from_pretrained(
            model_path,
            quantization_config=nf4_config,
            device_map="auto",
        )
        self.embed_size = self.encoder.model.config.hidden_size
        self.classifier_head = ClassificationHead(class_size=class_size, embed_size=self.embed_size)

    def get_classifier(self):
        """Return the classification head module."""
        return self.classifier_head

    def train_custom(self):
        """Freeze every encoder parameter and put only the head in train mode."""
        self.encoder.requires_grad_(False)
        self.classifier_head.train()

    def avg_representation(self, x):
        """Average the encoder's last hidden state over positions whose id is not 0.

        Args:
            x: token-id tensor; id 0 is treated as padding.

        Returns:
            Per-row sum of the masked hidden states divided by the number of
            non-zero positions (plus ``EPSILON``).
        """
        hidden = self.encoder.model(x)["last_hidden_state"]
        # 1.0 where the token id is non-zero, replicated across the embedding axis.
        keep = (
            x.ne(0)
            .unsqueeze(2)
            .expand(-1, -1, self.embed_size)
            .float()
            .to(self.device)
            .detach()
        )
        summed = torch.sum(hidden * keep, dim=1)
        counts = torch.sum(keep, dim=1).detach() + EPSILON
        return summed / counts

    def forward(self, x):
        """Return log-probabilities over the classes.

        In cached mode ``x`` is used directly as the pooled representation;
        otherwise it is run through ``avg_representation`` first.
        """
        x = x.to(self.device)
        pooled = x if self.cached_mode else self.avg_representation(x)
        return nn.functional.log_softmax(self.classifier_head(pooled), dim=-1)
class Dataset(data.Dataset):
    """Map-style dataset that pairs ``X[i]`` with ``y[i]``."""

    def __init__(self, X, y):
        """Keep references to the inputs ``X`` and the labels ``y``."""
        self.X, self.y = X, y

    def __len__(self):
        """Return the number of inputs."""
        return len(self.X)

    def __getitem__(self, index):
        """Return ``{"X": X[index], "y": y[index]}``."""
        return {"X": self.X[index], "y": self.y[index]}
def cached_collate_fn(data):
    """Collate samples whose ``"X"`` entries are already-pooled tensors.

    Args:
        data: list of dicts with keys ``"X"`` and ``"y"``.

    Returns:
        ``(x_batch, y_batch)``: the ``"X"`` tensors concatenated along
        dimension 0, and the ``"y"`` values as a long tensor.
    """
    # Turn the list of per-sample dicts into one list per key.
    columns = {key: [sample[key] for sample in data] for key in data[0].keys()}

    features = torch.cat(columns["X"], 0)
    labels = torch.tensor(columns["y"], dtype=torch.long)
    return features, labels
def collate_fn(data):
    """Collate variable-length token sequences into one zero-padded batch.

    Args:
        data: list of dicts with keys ``"X"`` (a 1-D sequence of token ids)
            and ``"y"`` (a label).

    Returns:
        ``(x_batch, y_batch)``: a long tensor with one row per sample,
        right-padded with 0 to the longest sequence, and the labels as a
        long tensor.
    """
    # Turn the list of per-sample dicts into one list per key.
    columns = {key: [sample[key] for sample in data] for key in data[0].keys()}

    sequences = columns["X"]
    width = max(len(seq) for seq in sequences)
    padded = torch.zeros((len(sequences), width), dtype=torch.long)  # padding value = 0
    for row, seq in zip(padded, sequences):
        # `row` is a view into `padded`, so this writes the batch in place.
        row[: len(seq)] = seq

    labels = torch.tensor(columns["y"], dtype=torch.long)
    return padded, labels
def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10, device="cpu"):
    """Run one optimisation pass over ``data_loader``.

    A later cell in this notebook defines ``train_epoch`` again with a
    default ``log_interval`` of 100; once both cells have run, that later
    definition is the one in effect.

    Args:
        data_loader: iterable of ``(input, target)`` batches; its
            ``dataset`` attribute must support ``len`` for progress logging.
        discriminator: model exposing ``train_custom()`` and returning
            log-probabilities when called on a batch.
        optimizer: optimiser stepped once per batch.
        epoch: zero-based epoch index, printed as ``epoch + 1``.
        log_interval: print progress every this many batches.
        device: device the batches are moved to.
    """
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = nn.functional.nll_loss(output_t, target_t)
        # Each batch builds its own graph and is back-propagated exactly once,
        # so the graph buffers can be released right away (no retain_graph).
        loss.backward()
        optimizer.step()

        samples_so_far += len(input_t)

        if batch_idx % log_interval == 0:
            total = len(data_loader.dataset)
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far,
                    total,
                    100 * samples_so_far / total,
                    loss.item(),
                )
            )
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
    """Precompute pooled encoder representations and wrap them in a loader.

    Args:
        dataset: dataset of ``{"X": token ids, "y": label}`` samples.
        batch_size: batch size for both the encoding pass and the returned
            loader.
        discriminator: model providing ``avg_representation``.
        shuffle: whether the returned loader shuffles.
        device: device the token batches are moved to before encoding.

    Returns:
        ``(data_loader, xs, ys)``: a loader over the cached representations
        (collated with ``cached_collate_fn``), the list of per-sample
        representations on the CPU, and the list of labels.
    """
    raw_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)

    xs, ys = [], []
    for x, y in tqdm(raw_loader, ascii=True):
        with torch.no_grad():
            pooled = discriminator.avg_representation(x.to(device)).cpu().detach()
            # One tensor per sample, each keeping a leading dimension of size 1
            # so that cached_collate_fn can concatenate them again.
            xs.extend(pooled.unsqueeze(1).unbind())
            ys.extend(y.cpu().numpy().tolist())

    cached_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
    )
    return cached_loader, xs, ys

def predict(input_sentence, model, classes, cached=False, device="cpu"):
    """Print the class probabilities the model assigns to one sentence.

    Args:
        input_sentence: text encoded with ``model.tokenizer``.
        model: callable returning log-probabilities; must expose
            ``tokenizer`` and, when ``cached`` is true,
            ``avg_representation``.
        classes: class names, in the order of the model's outputs.
        cached: when true, the token ids are pooled with
            ``model.avg_representation`` before being passed to the model.
        device: device the token-id tensor is created on.
    """
    token_ids = model.tokenizer.encode(input_sentence)
    model_input = torch.tensor([token_ids], dtype=torch.long, device=device)
    if cached:
        model_input = model.avg_representation(model_input)

    log_probs = model(model_input).data.cpu().numpy().flatten().tolist()

    # The model emits log-probabilities; exponentiate them for display.
    formatted = [
        "{}: {:.4f}".format(label, math.exp(log_prob))
        for label, log_prob in zip(classes, log_probs)
    ]
    print("Input sentence:", input_sentence)
    print("Predictions:", ", ".join(formatted))

In [4]:
def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=100, device="cpu"):
    """Run one optimisation pass over ``data_loader``.

    This overrides the earlier definition of the same name in this
    notebook; this one defaults ``log_interval`` to 100.

    Args:
        data_loader: iterable of ``(input, target)`` batches; its
            ``dataset`` attribute must support ``len`` for progress logging.
        discriminator: model exposing ``train_custom()`` and returning
            log-probabilities when called on a batch.
        optimizer: optimiser stepped once per batch.
        epoch: zero-based epoch index, printed as ``epoch + 1``.
        log_interval: print progress every this many batches.
        device: device the batches are moved to.
    """
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = nn.functional.nll_loss(output_t, target_t)
        # Each batch builds its own graph and is back-propagated exactly once,
        # so the graph buffers can be released right away (no retain_graph).
        loss.backward()
        optimizer.step()

        samples_so_far += len(input_t)

        if batch_idx % log_interval == 0:
            total = len(data_loader.dataset)
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far,
                    total,
                    100 * samples_so_far / total,
                    loss.item(),
                )
            )

In [5]:
def evaluate_performance(data_loader, discriminator, device="cpu"):
    """Evaluate the model on ``data_loader`` and print loss and accuracy.

    Args:
        data_loader: iterable of ``(input, target)`` batches; its
            ``dataset`` attribute must support ``len``.
        discriminator: model returning log-probabilities; switched to eval
            mode before evaluation.
        device: device the batches are moved to.

    Returns:
        The summed NLL loss divided by ``len(data_loader.dataset)``.
    """
    discriminator.eval()
    total_loss = 0
    n_correct = 0
    with torch.no_grad():
        for input_t, target_t in data_loader:
            input_t = input_t.to(device)
            target_t = target_t.to(device)
            log_probs = discriminator(input_t)
            # Accumulate the un-averaged loss; it is normalised once at the end.
            total_loss += nn.functional.nll_loss(log_probs, target_t, reduction="sum").item()
            # The predicted class is the index of the largest log-probability.
            predicted = log_probs.argmax(dim=1)
            n_correct += (predicted == target_t).sum().item()

    n_samples = len(data_loader.dataset)
    mean_loss = total_loss / n_samples

    print(
        "Performance on test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            mean_loss, n_correct, n_samples, 100.0 * n_correct / n_samples
        )
    )
    return mean_loss


# Dataset

In [6]:
import pandas as pd
# Load the training CSV.
# NOTE: this rebinds the name `data`, which the imports at the top of the notebook
# bound to the `torch.utils.data` module; the Dataset class must already be defined.
data =pd.read_csv("data_train.csv") 
# Drop the 'Unnamed: 0' column.
del data['Unnamed: 0']
# Shift the labels down by one.
# NOTE(review): presumably the CSV stores 1-based act ids - confirm.
data['act'] = data['act']-1
# Keep at most the first 30000 rows.
data = data[:30000]
train_size = 0.85
# 85% of the rows go to train; the remaining 15% are split in half below.
train_dataset = data.sample(frac=train_size,random_state=200)
valid_dataset = data.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)


# Half of the held-out rows become validation, the other half test.
# NOTE: train_dataset and test_dataset are rebound to Dataset objects in a later cell,
# so these DataFrame splits only feed the shape printout below.
validdd_dataset = valid_dataset.sample(frac=0.5,random_state=200)
test_dataset = valid_dataset.drop(validdd_dataset.index).reset_index(drop=True)
valid_dataset =  validdd_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(data.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("valid Dataset: {}".format(valid_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

FULL Dataset: (30000, 2)
TRAIN Dataset: (25500, 2)
valid Dataset: (2250, 2)
TEST Dataset: (2250, 2)


In [7]:
# Tokenise every utterance without special tokens. Iterating the column directly
# (instead of looking rows up by integer label) does not depend on the DataFrame
# index being exactly 0..n-1.
ddd = [
    torch.tensor(tokenizer.encode(text, add_special_tokens=False))
    for text in data['dialog']
]
data['tokens'] = ddd

# Keep at most the first 2000 rows of each act label so the classes are roughly balanced.
a = data.loc[data.act == 1, ['tokens', 'act']][:2000]
b = data.loc[data.act == 3, ['tokens', 'act']][:2000]
c = data.loc[data.act == 0, ['tokens', 'act']][:2000]
d = data.loc[data.act == 2, ['tokens', 'act']][:2000]

# Stack the per-class slices in the order 1, 3, 0, 2 and renumber the rows.
data = pd.concat([a, b, c, d], ignore_index=True)
data

Unnamed: 0,tokens,act
0,"[tensor(29871), tensor(1724), tensor(437), ten...",1
1,"[tensor(29871), tensor(1938), tensor(366), ten...",1
2,"[tensor(29871), tensor(306), tensor(4140), ten...",1
3,"[tensor(1815), tensor(366), tensor(437), tenso...",1
4,"[tensor(29871), tensor(830), tensor(635), tens...",1
...,...,...
7901,"[tensor(29871), tensor(8221), tensor(869), ten...",2
7902,"[tensor(29871), tensor(8221), tensor(1919), te...",2
7903,"[tensor(7251), tensor(1919), tensor(2041), ten...",2
7904,"[tensor(29871), tensor(366), tensor(29915), te...",2


In [8]:
# Pull the token tensors and labels out of the balanced frame as plain lists;
# these are what train_test_split receives in a later cell.
X = list(data['tokens'])
Y = list(data['act'])
# Notebook display of the token tensors.
X

[tensor([29871,  1724,   437,   366,  2099,  1577,   739,   674,  1371,   502,
           304, 26681,   869, 29871]),
 tensor([29871,  1938,   366,  2289,  1348,   577,  1577,   306,  1016, 29915,
         29873,   869,   739,   674,   925,  1207,   502,  9950,   322,  1044,
         24866,   869, 22738,  1833,   931,  1577, 29871]),
 tensor([29871,   306,  4140,   366,   526,  1492, 29889,  6246,   825,  4091,
           591,   437,  1577,   306,  1016, 29915, 29873,  4459,   763, 16246,
           472,  3271,   869, 29871]),
 tensor([ 1815,   366,   437,  5503, 29899, 14340,  1577, 29871]),
 tensor([29871,   830,   635,  1577,   306,  1348,   393, 29915, 29879,  9301,
          1738, 29871]),
 tensor([29871,   887,  2099, 29871, 29941, 29900,  5503, 29899, 14340,  1577,
         29871]),
 tensor([ 1815,   366,  6559,   411,   278,  7155,   373,  1577, 29871]),
 tensor([29871,  1724,   338,   278,  4328,  1577, 29871]),
 tensor([ 4683,   366,   599,  1492,  1577, 29871]),
 tensor([186

In [9]:
# Show how many rows each act label kept after the 2000-per-class cap.
data['act'].value_counts()

act
1    2000
0    2000
2    2000
3    1906
Name: count, dtype: int64

In [10]:
from sklearn.model_selection import train_test_split
# Use CUDA when it is available and not explicitly disabled.
no_cuda =False
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
# Passed to Discriminator and stored in discriminator_meta; Discriminator itself does not read it.
pretrained_model = "llama-chat"
# With cached=True the discriminator's forward expects pre-pooled representations,
# which a later cell builds with get_cached_data_loader.
cached= True

# Class names in label-id order, and the reverse lookup.
idx2class = ["inform", "question", "directive", "commissive"]
class2idx = {c: i for i, c in enumerate(idx2class)}

discriminator = Discriminator(
    class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
).to(device)


# 80/20 split of the balanced data with a fixed seed.
x,test_x,y, test_y = train_test_split(X, Y, test_size=0.2, random_state=321)


# These rebind train_dataset / test_dataset, replacing the DataFrame splits made in the dataset cell.
train_dataset = Dataset(x, y)
test_dataset = Dataset(test_x, test_y)

# Metadata describing the classifier.
# NOTE(review): not read again anywhere in the visible notebook.
discriminator_meta = {
    "class_size": len(idx2class),
    "embed_size": discriminator.embed_size,
    "pretrained_model": pretrained_model,
    "class_vocab": class2idx,
    "default_class": 0,
}




Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.78s/it]


In [11]:
batch_size = 16
save_model = True

if cached:
    # Run every sample through the encoder once and train on the pooled representations.
    print("Building representation cache...")

    start = time.time()

    train_loader, x_tr, y_tr = get_cached_data_loader(
        train_dataset, batch_size, discriminator, shuffle=True, device=device
    )
    test_loader, x_te, y_te = get_cached_data_loader(test_dataset, batch_size, discriminator, device=device)

    end = time.time()
    print("Building representation cache took: {:.3f}s".format(end - start))

    # Persist the cached representations and labels. This lives inside the cached
    # branch because x_tr / y_tr / x_te / y_te are only defined here; dumping them
    # unconditionally raises NameError when cached is False.
    for filename, payload in (
        ("x_tr.pkl", x_tr),
        ("y_tr.pkl", y_tr),
        ("x_te.pkl", x_te),
        ("y_te.pkl", y_te),
    ):
        with open(filename, "wb") as f:
            pickle.dump(payload, f)
else:
    # Feed raw token batches; the discriminator runs the encoder on every forward pass.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)
    


Building representation cache...


100%|#################################################################################| 396/396 [14:26<00:00,  2.19s/it]
100%|###################################################################################| 99/99 [04:02<00:00,  2.45s/it]


Building representation cache took: 1108.909s


# Train

In [15]:
# Train the classification head for up to `epochs` passes, evaluating after each pass
# and saving the head whenever the evaluation loss improves.
epochs = 100
log_interval =1000


best_valid_loss =float('inf')
# One entry per completed evaluation; plotted in the final cell.
valid_losses = []
# NOTE: the optimiser is given every discriminator parameter, while train_custom()
# sets requires_grad=False on the encoder's parameters.
optimizer = optim.Adam(discriminator.parameters(), lr=0.00001)
for epoch in range(epochs):
    start = time.time()
    print("\nEpoch", epoch + 1)

    train_epoch(
        discriminator=discriminator,
        data_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
        log_interval=log_interval,
        device=device,
    )
    # NOTE(review): the loss used for model selection is computed on test_loader,
    # not on a separate validation split - confirm this is intended.
    test_loss= evaluate_performance(data_loader=test_loader, discriminator=discriminator, device=device)
    valid_losses.append(test_loss)
    end = time.time()
    #print("Epoch took: {:.3f}s".format(end - start))

    #print("\nExample prediction")
    #predict(example_sentence, discriminator, idx2class, cached=cached, device=device)
    
    if save_model:
        # torch.save(discriminator.state_dict(),
        #           "{}_discriminator_{}.pt".format(
        #               args.dataset, epoch + 1
        #               ))
        if test_loss < best_valid_loss:
            best_valid_loss = test_loss
            # NOTE(review): the first epoch updates best_valid_loss but is never saved;
            # if no later epoch beats it, no checkpoint is written - confirm intent.
            if epoch !=0:
                # Only the classification head's weights are saved, not the encoder's.
                torch.save(
                    discriminator.get_classifier().state_dict(),
                    "{}_reclassifier_head_epoch_{}.pt".format('utter', epoch + 1),
            )



Epoch 1
Performance on test set: Average loss: 0.5814, Accuracy: 1227/1582 (78%)

Epoch 2
Performance on test set: Average loss: 0.5809, Accuracy: 1230/1582 (78%)

Epoch 3
Performance on test set: Average loss: 0.5806, Accuracy: 1228/1582 (78%)

Epoch 4
Performance on test set: Average loss: 0.5806, Accuracy: 1229/1582 (78%)

Epoch 5
Performance on test set: Average loss: 0.5803, Accuracy: 1226/1582 (77%)

Epoch 6
Performance on test set: Average loss: 0.5794, Accuracy: 1231/1582 (78%)

Epoch 7
Performance on test set: Average loss: 0.5804, Accuracy: 1229/1582 (78%)

Epoch 8
Performance on test set: Average loss: 0.5805, Accuracy: 1236/1582 (78%)

Epoch 9
Performance on test set: Average loss: 0.5802, Accuracy: 1229/1582 (78%)

Epoch 10
Performance on test set: Average loss: 0.5813, Accuracy: 1231/1582 (78%)

Epoch 11
Performance on test set: Average loss: 0.5807, Accuracy: 1233/1582 (78%)

Epoch 12
Performance on test set: Average loss: 0.5803, Accuracy: 1232/1582 (78%)

Epoch 13
Per

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np



# Plot one point per completed evaluation. The x-axis is sized from valid_losses
# itself: after a full run `epoch` is one less than the number of recorded losses,
# so sizing the axis from `epoch` gives x and y different lengths.
plt.plot(np.arange(len(valid_losses)), valid_losses)
plt.show()