In [None]:
!pip install -U transformers

In [None]:
import torch
import json
import torch.nn.functional as F
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

In [None]:
# --- Data and model identifiers ---
DATASET_NAME = 'databricks/databricks-dolly-15k'
MODEL_NAME = 'Qwen/Qwen2.5-0.5B'

# --- Sampling / preprocessing ---
RANDOM_STATE = 42   # seed for torch and for the per-category sampling
MIN_WORDS = 10      # responses must have more than this many space-separated words
MAX_WORDS = 100     # responses are cut to at most this many space-separated words
SAMPLE_SIZE = 50    # upper bound on the rows drawn from each category

# --- Optimisation ---
MAXITER = 500       # optimiser steps performed for every batch
BATCH_SIZE = 8
HYPERPARAMS = dict(lr=0.01, weight_decay=0.01, betas=(0.9, 0.9))  # passed to AdamW

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

torch.manual_seed(RANDOM_STATE)

In [None]:
class TextDataset(Dataset):
    """Dataset wrapper around an indexable collection of texts.

    Indexing returns the text together with the index it was requested at,
    so a consumer can trace a sample back to its position in ``texts``.
    """

    def __init__(self, texts):
        """Keep a reference to ``texts`` (no copy is made)."""
        self.texts = texts

    def __len__(self):
        """Number of stored texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the pair ``(text, idx)`` for position ``idx``."""
        text = self.texts[idx]
        return text, idx

def evaluate_plain_model(model, input_ids):
    """Score the model's next-token predictions on ``input_ids``.

    The model is switched to eval mode and run without gradient tracking.
    Logits at every position but the last are compared with the token that
    follows that position.

    Args:
        model: callable whose result exposes ``.logits``; the indexing below
            treats the logits as (batch, sequence, vocab).
        input_ids: 2-D tensor of token ids.

    Returns:
        Tuple ``(loss, accuracy, seq_accuracy, correct_prefix_length)``. The
        last three come from ``calculate_metrics`` applied to the flattened
        targets and predictions, i.e. all rows are treated as one sequence.
    """
    model.eval()
    with torch.no_grad():
        all_logits = model(input_ids).logits

        # Position t predicts token t + 1: drop the last logit and the first token.
        shifted_logits = all_logits[:, :-1, :]
        flat_logits = shifted_logits.reshape((-1, all_logits.shape[-1]))
        flat_targets = input_ids[:, 1:].reshape(-1)

        loss = F.cross_entropy(flat_logits, flat_targets).item()
        flat_pred = shifted_logits.argmax(dim=-1).reshape(-1)

        return (loss, *calculate_metrics(flat_targets, flat_pred))

def generate_input_one(vectors, text_length):
    """Build the input embeddings for one sample from its two vectors.

    Row 0 of ``vectors`` fills the first position; row 1 is repeated (as an
    expanded view, not a copy) over the remaining ``text_length - 1`` positions.

    Args:
        vectors: 2-D tensor whose first two rows are used.
        text_length: number of positions to produce.

    Returns:
        Tensor of shape ``(1, text_length, vectors.shape[1])``.
    """
    head = vectors[:1].unsqueeze(1)
    filler = vectors[1:2].unsqueeze(1).expand(-1, text_length - 1, -1)
    return torch.cat((head, filler), dim=1)

def generate_input(batch_vectors, lengths):
    """Assemble a zero-padded batch of input embeddings.

    Each sample's vectors are turned into a sequence of its own length with
    ``generate_input_one``; the sequences are then padded with 0.0 to the
    longest one. Pairs are formed with ``zip``, so iteration stops at the
    shorter of ``batch_vectors`` and ``lengths``.

    Returns:
        Tensor laid out batch-first: (samples, longest length, embedding size).
    """
    per_sample = [
        generate_input_one(sample_vectors, n_tokens).squeeze(0)
        for sample_vectors, n_tokens in zip(batch_vectors, lengths)
    ]
    return pad_sequence(per_sample, batch_first=True, padding_value=0.0)

def calculate_metrics(target, pred):
    """Compare predicted token ids with the targets.

    Args:
        target: 1-D tensor of reference ids.
        pred: 1-D tensor of predicted ids, same length as ``target``.

    Returns:
        Tuple ``(accuracy, seq_accuracy, correct_prefix_length)`` where
        ``accuracy`` is the fraction of matching positions,
        ``correct_prefix_length`` is the index of the first mismatch (or the
        full length when everything matches) and ``seq_accuracy`` is that
        prefix length divided by the total length.
    """
    mismatch = pred != target
    accuracy = (~mismatch).float().mean().item()

    total = len(pred)
    wrong_positions = torch.nonzero(mismatch)
    if wrong_positions.shape[0] > 0:
        # nonzero lists indices in ascending order, so the first row is the earliest error.
        correct_prefix_length = wrong_positions[0].item()
        seq_accuracy = correct_prefix_length / total
    else:
        correct_prefix_length = total
        seq_accuracy = 1.0
    return accuracy, seq_accuracy, correct_prefix_length

def collate_fn(batch, tokenizer, device):
    """Tokenize a batch of ``(text, index)`` pairs and pad it to a rectangle.

    Args:
        batch: iterable of ``(text, index)`` pairs, as produced by ``TextDataset``.
        tokenizer: object providing ``encode(text, return_tensors='pt')`` and
            ``pad_token_id``.
        device: device the padded tensors are moved to.

    Returns:
        dict with
            'input_ids': padded token ids, batch-first, on ``device``;
            'attention_mask': int tensor of the same shape, 1 on real tokens
                and 0 on padding;
            'lengths': unpadded token count of every sample;
            'indices': the index that came with every text.
    """
    texts = [item[0] for item in batch]
    indices = [item[1] for item in batch]
    token_seqs = [tokenizer.encode(text, return_tensors='pt').reshape(-1) for text in texts]
    lengths = [seq.shape[0] for seq in token_seqs]
    input_ids = pad_sequence(token_seqs, batch_first=True, padding_value=tokenizer.pad_token_id).to(device)
    # Build the mask from the true sequence lengths instead of comparing ids with
    # pad_token_id: a genuine token that happens to share the pad id must stay visible.
    attention_mask = pad_sequence(
        [torch.ones_like(seq, dtype=torch.int32) for seq in token_seqs],
        batch_first=True,
        padding_value=0,
    ).to(device)
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'lengths': lengths,
        'indices': indices
    }

def safe_sample(group, n, seed):
    """Draw ``n`` rows from ``group``, or hand it back untouched if it is too small.

    Args:
        group: object supporting ``len`` and ``.sample(n, random_state=...)``.
        n: number of rows to draw.
        seed: value forwarded as ``random_state``.

    Returns:
        ``group`` itself when it holds fewer than ``n`` rows, otherwise the
        result of ``group.sample``.
    """
    return group.sample(n, random_state=seed) if len(group) >= n else group

In [None]:
dataset = load_dataset(DATASET_NAME)
df = dataset['train'].to_pandas()

# Keep only responses with more than MIN_WORDS space-separated words.
df = df[df['response'].apply(lambda x: len(x.split(' ')) > MIN_WORDS)]

# Draw at most SAMPLE_SIZE rows per category. The groups are iterated explicitly
# instead of going through groupby(...).apply: recent pandas versions deprecate
# passing the grouping column to apply, and the 'category' column is read again
# when the results are assembled, so it has to survive this step.
sampled_labels = [
    row_label
    for _, group in df.groupby(by='category')
    for row_label in safe_sample(group, SAMPLE_SIZE, RANDOM_STATE).index
]
# Rows end up ordered by category (groupby sorts its keys) and re-indexed from 0,
# so positions in `texts` line up with df.iloc positions.
df = df.loc[sampled_labels].reset_index(drop=True)
texts = list(df['response'])

In [None]:
import matplotlib.pyplot as plt

# Word counts use the same space-split convention as the MIN_WORDS filter.
text_len = [len(text.split(' ')) for text in texts]

fig, ax = plt.subplots()
ax.hist(text_len, bins=25)
ax.set(
    title='Response length before truncation',
    xlabel='Words per response',
    ylabel='Number of responses',
)
# show() renders the figure and keeps the raw hist() return value out of the cell output.
plt.show()

In [None]:
# Cap every text at MAX_WORDS space-separated words. Splitting on ' ' and joining
# with ' ' reproduces a shorter text exactly, so no length check is needed.
texts = [' '.join(text.split(' ')[:MAX_WORDS]) for text in texts]
text_len = [len(text.split(' ')) for text in texts]

fig, ax = plt.subplots()
ax.hist(text_len, bins=25)
ax.set(
    title='Response length after truncation',
    xlabel='Words per response',
    ylabel='Number of responses',
)
# show() renders the figure and keeps the raw hist() return value out of the cell output.
plt.show()

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Batches are served in dataset order; tokenization and padding happen in collate_fn.
text_dataset = TextDataset(texts)
text_dataloader = DataLoader(
    text_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=lambda samples: collate_fn(samples, tokenizer, DEVICE),
)

# The language model is frozen: none of its parameters receive gradients.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
for frozen_param in model.parameters():
    frozen_param.requires_grad = False
model.eval()

In [None]:
# For every batch, optimise two trainable vectors per sample so that the frozen
# model, fed only those vectors, predicts the sample's tokens. The best vectors
# seen (by accuracy, ties broken by prefix accuracy) are stored in `result`.
result = []

for batch in tqdm(text_dataloader, desc='Processing dataset'):
    tokenized_text = batch['input_ids']
    lengths = batch['lengths']
    attention_mask = batch['attention_mask']
    indices = batch['indices']
    # The last batch can hold fewer than BATCH_SIZE samples, so everything
    # per-sample below is sized from the batch itself.
    batch_size = len(indices)

    # Loss targets: positions the attention mask marks as padding are set to
    # -100, the default ignore_index of cross_entropy.
    labels = tokenized_text.clone()
    labels[attention_mask == 0] = -100

    vectors = torch.nn.Parameter(torch.randn(batch_size, 2, model.config.hidden_size, device=DEVICE))
    optimizer = torch.optim.AdamW([vectors], lr=HYPERPARAMS['lr'], betas=HYPERPARAMS['betas'], weight_decay=HYPERPARAMS['weight_decay'])

    # vanilla_loss, vanilla_accuracy, vanilla_seq_accuracy, vanilla_correct_prefix_length = evaluate_plain_model(model, tokenized_text)
    # Best-so-far trackers are plain Python floats: the metrics already are floats,
    # so nothing has to be compared against device tensors.
    max_accuracy = [0.0] * batch_size
    max_seq_accuracy = [0.0] * batch_size
    # Start from the initial vectors so that a sample whose accuracy never rises
    # above zero still has vectors to serialise (instead of None).
    best_vectors = [vectors[i].detach().clone() for i in range(batch_size)]
    best_metrics = [(0.0, 0.0, 0)] * batch_size

    for _ in range(MAXITER):
        optimizer.zero_grad()

        current_input = generate_input(vectors, lengths)
        logits = model(inputs_embeds=current_input, attention_mask=attention_mask).logits
        logits = logits.reshape(-1, logits.shape[-1])
        # No shift: the logit at position t is trained to produce token t.
        loss = F.cross_entropy(logits, labels.reshape(-1))

        # Metrics describe the vectors *before* this iteration's optimiser step,
        # which are exactly the vectors that get snapshotted.
        pred = logits.argmax(dim=-1).view(tokenized_text.shape)
        for i in range(batch_size):
            current_pred = pred[i, :lengths[i]]
            current_target = tokenized_text[i, :lengths[i]]
            accuracy, seq_accuracy, correct_prefix_length = calculate_metrics(current_target, current_pred)
            improved = accuracy > max_accuracy[i]
            tie_broken = accuracy == max_accuracy[i] and seq_accuracy > max_seq_accuracy[i]
            if improved or tie_broken:
                max_accuracy[i] = accuracy
                max_seq_accuracy[i] = seq_accuracy
                best_metrics[i] = (accuracy, seq_accuracy, correct_prefix_length)
                best_vectors[i] = vectors[i].detach().clone()

        loss.backward()
        optimizer.step()

    for i in range(batch_size):
        row = df.iloc[indices[i]]
        result.append({
            'instruction': row['instruction'],
            'context': row['context'],
            'category': row['category'],
            # NOTE(review): this is the full response from df; the optimisation ran
            # on the MAX_WORDS-truncated version held in `texts` — confirm which is wanted.
            'text': row['response'],
            'accuracy': best_metrics[i][0],
            'seq_accuracy': best_metrics[i][1],
            'correct_prefix_len': best_metrics[i][2],
            'best_vectors': best_vectors[i].cpu().numpy().tolist()
        })

In [None]:
# Write all collected per-sample results to one UTF-8 JSON file
# (non-ASCII characters are kept as-is, 4-space indentation).
with open('/kaggle/working/training_results.json', mode='w', encoding='utf-8') as results_file:
    results_file.write(json.dumps(result, ensure_ascii=False, indent=4))