In [13]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch.optim import AdamW
from torch.nn.functional import cross_entropy

In [61]:
MODEL_NAME = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
# Define the new maximum length for the embeddings from BERT
BERT_MAX_LENGTH = 502
# Define the new soft prompt length
SOFT_PROMPT_LENGTH = 10

In [16]:
class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer, max_length=BERT_MAX_LENGTH):
        self.encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')
    
    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}
    
    def __len__(self):
        return len(self.encodings.input_ids)

In [17]:
import pandas as pd
from itertools import product

train = pd.read_csv("data/train.csv").sample(100)

question_columns = ['질문_1']
answer_columns = ['답변_2']

queries = []
answers = []
for question, answer in product(question_columns, answer_columns):
    for index, row in train.iterrows():
        queries.append(row[question])
        answers.append(row[answer])

In [18]:
# Combine queries and answers into training data
training_data = [f"Query: {q} Answer: {a}" for q, a in zip(queries, answers)]

# Create a dataset and data loader
dataset = SimpleDataset(training_data, tokenizer)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=20)

In [19]:
len(training_data)

100

In [20]:
# Let's define a soft prompt
soft_prompt = torch.nn.Embedding(SOFT_PROMPT_LENGTH, model.config.hidden_size)
soft_prompt.weight.data.normal_(mean=0.0, std=0.5)

tensor([[ 0.6525, -0.9803,  1.0921,  ...,  0.3349,  0.5240, -0.3715],
        [ 0.1660, -0.2549, -0.2168,  ..., -0.8566,  0.0027, -0.2683],
        [-0.0462,  0.0218,  0.4449,  ...,  0.3407,  0.6583, -0.6654],
        ...,
        [ 0.0522,  0.8773, -0.4474,  ..., -0.1889, -0.0091,  0.2379],
        [ 0.9105, -0.1946, -0.9873,  ...,  0.3898,  0.3661, -0.3900],
        [ 0.2153,  0.5887,  0.9717,  ..., -0.3128, -0.4339, -0.6346]])

In [21]:
# We'll use a simple optimizer just for the soft prompt
optimizer = AdamW(soft_prompt.parameters(), lr=1e-4)

In [22]:
model.train()

# Training loop for soft prompt tuning
for epoch in range(2): 
    for batch in data_loader:
        optimizer.zero_grad()

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        soft_prompt_tokens = torch.arange(SOFT_PROMPT_LENGTH).unsqueeze(0).expand(input_ids.size(0), -1)
        soft_prompt_embeddings = soft_prompt(soft_prompt_tokens)
        inputs_embeds = model.bert.embeddings.word_embeddings(input_ids)
        inputs_embeds = torch.cat((soft_prompt_embeddings, inputs_embeds), dim=1)
        attention_mask = torch.cat([torch.ones(input_ids.size(0), SOFT_PROMPT_LENGTH), attention_mask], dim=1)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        
        # Compute loss
        logits = outputs.logits[:, SOFT_PROMPT_LENGTH:, :]  # Exclude the logits for the soft prompt tokens
        loss = cross_entropy(logits.reshape(-1, model.config.vocab_size), input_ids.view(-1))

        loss.backward()
        optimizer.step()
        
        print(f"Epoch {epoch}, Loss: {loss.item()}")


Epoch 0, Loss: 4.761325836181641
Epoch 0, Loss: 5.644777297973633
Epoch 0, Loss: 4.193082809448242
Epoch 0, Loss: 4.40958833694458
Epoch 0, Loss: 5.350088596343994
Epoch 1, Loss: 4.69485330581665
Epoch 1, Loss: 5.628662109375
Epoch 1, Loss: 4.254441261291504
Epoch 1, Loss: 4.384392738342285
Epoch 1, Loss: 5.291342735290527


In [27]:
# Save the soft prompt embeddings
torch.save(soft_prompt.state_dict(), 'path_to_save_soft_prompt/soft_prompt.pt')