# Import libraries

In [None]:
%load_ext autoreload
%autoreload 2
# Auto-reload local modules (e.g. filtering.py) so edits are picked up
# without restarting the kernel.
import os
import sys
import json
import glob
from collections import defaultdict

# Pin the notebook to GPU 0. This is set before importing `filtering`
# (which presumably imports torch/CUDA — confirm) so the restriction
# takes effect before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from tqdm import tqdm

# Project-local filtering pipeline helpers.
from filtering import (
    get_args,
    load_model,
    output_to_jsonl,
    filtering,
    load_real_data,
    load_syn_data,
    compute_average_grads,
    calculate_recon_loss_ids,
)

# Parameters

In [None]:
json_file = 'synthetic_data'
file_dir = './synthetic_data/'
exp_pattern = 'test' # Change the pattern here

# One entry per experiment run directory. glob() returns entries in an
# arbitrary, filesystem-dependent order, so sort for a reproducible run order.
# Keep only directories: later cells join a .jsonl filename onto each entry,
# so a stray matching file (e.g. a log) would otherwise crash the pipeline.
training_dir = sorted(
    path
    for path in glob.glob(os.path.join(file_dir, f'{exp_pattern}*'))
    if os.path.isdir(path)
)
print(f"Found {len(training_dir)} run directories matching {exp_pattern!r} in {file_dir}")

# CLI-style flags for get_args(); it presumably parses sys.argv (argparse),
# so the notebook injects its configuration by overwriting sys.argv.
input_flags = [sys.argv[0],
               '--dataset', 'sst2', # sst2, rotten_tomatoes, TwitterEmotion
               '--model_name', 'phi',
               '--pos_label', 'positive', # positive
               '--neg_label', 'negative', # negative
               '--gen_bs', '10',
               '--use_instruction', 'false',
               '--use_fewshot', 'true',
               '--filter_score', 'cls', # cls
               '--filter_method', 'remove', # remove, relabel, top_score, first, bottom_score, greedy_selection
               '--coeff_perplexity', '0', # 0, 0.05
               '--top_n', '50',
               '--file_dir', file_dir,
               '--json_file', json_file,
               '--clean', 'true',
               '--balance_score', 'true',
               '--per_label', 'true',
               '--interleave_label', 'false',
]
sys.argv = input_flags

# Load parameters and echo them for the notebook log.
args = get_args()
for arg in vars(args):
    print(f"{arg}: {getattr(args, arg)}")

# Load model

In [None]:
tokenizer, model, device = load_model(args.model_name)

# Keep gradients only for parameters whose name contains one of these
# substrings; everything else is frozen. The retained gradients are presumably
# what compute_average_grads / calculate_recon_loss_ids consume — confirm.
LAST_LAYERS = ["lm_head"]

named_parameters_to_optim = []

for name, param in model.named_parameters():
    if any(substring in name for substring in LAST_LAYERS):
        # Enable explicitly instead of relying on load_model having left the
        # parameter trainable (NOTE(review): check whether load_model freezes).
        param.requires_grad = True
        named_parameters_to_optim.append((name, param))
    else:
        param.requires_grad = False

# Explicit raise rather than assert: asserts are stripped under `python -O`.
if not named_parameters_to_optim:
    raise RuntimeError(f"no layer found matching {LAST_LAYERS} in model {args.model_name!r}")
print(
    f"Set gradients for {len(named_parameters_to_optim)} parameter tensors: "
    f"{[name for name, _ in named_parameters_to_optim]}"
)

# Filtering - Clean remove

In [None]:
# Configure the first filtering pass ("clean remove").
args.filter_method = 'remove'
args.gen_bs = 10 # Make sure it matches the generation batch size used for the data

# Tag each synthetic sample with the index of the real example it came from.
# NOTE(review): assumes sample ids are consecutive with gen_bs samples per
# real example — confirm against the generation script.
for run in tqdm(training_dir):
    file_path = os.path.join(run, f'{args.json_file}.jsonl')

    # Skip blank lines (e.g. a trailing empty line) instead of crashing in json.loads.
    with open(file_path, 'r', encoding='utf-8') as f:
        samples = [json.loads(line) for line in f if line.strip()]

    for sample in samples:
        sample["real_id"] = sample["id"] // args.gen_bs

    # Write to a temp file and atomically swap it in: rewriting file_path
    # directly truncates it first, so an interruption would lose the data.
    tmp_path = file_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(sample) + "\n" for sample in samples)
    os.replace(tmp_path, file_path)

# Clean remove
filtering(
    args,
    training_dir,
    model,
    tokenizer,
    device
)

# (Re)calculate rec_loss_ids per sample

In [None]:
# Load a small real subset and compute the average per-label gradients that
# synthetic samples are later scored against.
# Use the configured dataset rather than a hard-coded 'sst2', so switching
# --dataset does not silently compute gradients on the wrong corpus.
# NOTE(review): assumes load_real_data accepts the same dataset names as
# --dataset and that each has a 'validation' split — confirm.
pos_sequences, neg_sequences, pos_labels, neg_labels = load_real_data(
    dataset_name=args.dataset,
    split='validation',
    device=device,
    n_gen_samples=100,
    n_fewshot=0,
    random_seed=42,
    subset=20,
)

print(pos_sequences[:5])

real_pos_grads = compute_average_grads(
    args,
    model,
    tokenizer,
    pos_sequences,
    pos_labels
)

real_neg_grads = compute_average_grads(
    args,
    model,
    tokenizer,
    neg_sequences,
    neg_labels
)

In [None]:
# Label words passed to calculate_recon_loss_ids for each dataset.
if args.dataset in ['sst2', 'rotten_tomatoes']:
    POS_LABEL = 'great'
    NEG_LABEL = 'bad'
elif args.dataset == 'TwitterEmotion':
    POS_LABEL = 'joy'
    NEG_LABEL = 'sadness'
else:
    # Without this, POS_LABEL/NEG_LABEL stay undefined and the loop fails
    # later with a confusing NameError.
    raise ValueError(f"Unsupported dataset for label words: {args.dataset!r}")

# Cleaned data file, presumably written by the 'remove' filtering pass.
clean_file_name = (
    f'synthetic_data_clean_remove_cls_phi_{args.dataset}_'
    f'{args.pos_label}_{args.neg_label}_instrFalse_fsTrue.jsonl'
)

for run in tqdm(training_dir):
    syn_data_path = os.path.join(run, clean_file_name)
    if not os.path.exists(syn_data_path):
        print(f"Skipping {syn_data_path}: file not found")
        continue

    # Load synthetic data split by label.
    syn_pos_sequences, syn_neg_sequences = load_syn_data(syn_data_path, args.dataset)

    list_raw_pos_loss = calculate_recon_loss_ids(
        syn_pos_sequences,
        [POS_LABEL] * len(syn_pos_sequences),
        real_pos_grads,
        model,
        tokenizer,
        dataset=args.dataset
    )

    list_raw_neg_loss = calculate_recon_loss_ids(
        syn_neg_sequences,
        [NEG_LABEL] * len(syn_neg_sequences),
        real_neg_grads,
        model,
        tokenizer,
        dataset=args.dataset
    )

    # Read the raw JSONL records (skipping blank lines).
    with open(syn_data_path, 'r', encoding='utf-8') as file:
        samples = [json.loads(line) for line in file if line.strip()]

    # NOTE(review): assumes label 1 = positive and 0 = negative, and that
    # load_syn_data returns each label's sequences in file order — confirm.
    loss_dict = {
        1: list_raw_pos_loss,
        0: list_raw_neg_loss
    }

    # Losses are paired with samples purely by position within each label,
    # so a count mismatch would silently misalign scores. Fail loudly instead.
    label_counts = defaultdict(int)
    for sample in samples:
        label_counts[sample['label']] += 1
    for label, losses in loss_dict.items():
        if label_counts[label] != len(losses):
            raise ValueError(
                f"{syn_data_path}: {label_counts[label]} samples with label {label} "
                f"but {len(losses)} reconstruction losses"
            )

    # Attach losses in place, preserving the original file order
    # (the previous group-by-label approach reordered the output file).
    next_index = defaultdict(int)
    for sample in samples:
        label = sample['label']
        sample['rec_loss_ids'] = loss_dict[label][next_index[label]]
        next_index[label] += 1

    output_to_jsonl(args, samples, syn_data_path, post_processing=False)

# Extract top-scoring samples

In [None]:
# Second pass: keep the highest-scoring samples from the cleaned data,
# first without any perplexity weighting.
args.filter_method, args.coeff_perplexity = 'top_score', 0
args.json_file = f'synthetic_data_clean_remove_cls_phi_{args.dataset}_{args.pos_label}_{args.neg_label}_instrFalse_fsTrue'
args.top_n = 50 # Modify here for different budget
args.balance_score = True

# A budget of 3 requests 5 outputs; any other budget uses filtering()'s default.
num_out = None
if args.top_n == 3:
    num_out = 5

filtering(args, training_dir, model, tokenizer, device, num_out=num_out)

In [None]:
# Repeat the top-score selection with a perplexity coefficient of 0.05;
# every other setting is whatever `args` currently holds.
args.coeff_perplexity = 0.05

# Same budget rule: 5 outputs only when top_n is 3, otherwise the default.
num_out = 5 if args.top_n == 3 else None
filtering(args, training_dir, model, tokenizer, device, num_out=num_out)