In [1]:
import sys
sys.path.append("src")
import torch
import numpy as np
import random
import json
import re
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from configuration import MistralConfig
from mistral_direct import Mistral
from tqdm import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class CoTAnalysisDataset(Dataset):
    """Random sample of question / chain-of-thought / answer records from a text file.

    A usable line has exactly two ``||``-separated fields: ``<src>||<target>``.
    The target is expected to hold the reasoning, then ``####``, then the final
    answer. Lines with any other number of fields are silently skipped.

    Each item is a plain dict with the keys ``'src'``, ``'cot'`` and ``'ans'``.
    """

    def __init__(self, tokenizer, file_path: str, max_length: int, num_samples: int):
        """Read ``file_path`` and keep at most ``num_samples`` randomly chosen records.

        Args:
            tokenizer: object exposing an ``eos_token`` attribute.
            file_path: path of the UTF-8 text file to read.
            max_length: stored on the instance only; this class never truncates to it.
            num_samples: upper bound on the number of records kept.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.eos_token = tokenizer.eos_token
        self.examples = self.load_and_process_file(file_path, num_samples)

    def load_and_process_file(self, file_path, num_samples):
        """Parse the file and return a list of ``{'src', 'cot', 'ans'}`` dicts.

        Sampling uses the global ``random`` module, so the selection is only
        reproducible if the caller seeds it beforehand.
        """
        pairs = []
        with open(file_path, encoding="utf-8") as f:
            # Stream the file and split every line exactly once.
            for line in f:
                fields = line.strip().split('||')
                if len(fields) == 2:
                    pairs.append(fields)

        sampled_pairs = random.sample(pairs, min(num_samples, len(pairs)))
        return [
            {
                'src': src,
                'cot': self.extract_cot_w_prefix(tgt),
                'ans': self.extract_answer_w_prefix(tgt),
            }
            for src, tgt in sampled_pairs
        ]

    def extract_answer_w_prefix(self, text, prefix='####'):
        """Return ``prefix + " " + answer`` with commas removed from the answer.

        The answer is whatever follows the first ``####`` in ``text``. When the
        marker is absent, ``text`` is returned unchanged (no prefix is added).
        """
        parts = text.split('####', 1)
        if len(parts) < 2:
            return text
        return prefix + " " + parts[1].strip().replace(',', '')

    def extract_cot_w_prefix(self, text, prefix=""):
        """Return ``prefix`` plus the stripped text that precedes the first ``####``.

        When the marker is absent, an empty string is returned (without prefix).
        """
        parts = text.split('####', 1)
        if len(parts) < 2:
            return ""
        return prefix + parts[0].strip()

    def __len__(self):
        """Number of sampled records."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return the ``i``-th record dict."""
        return self.examples[i]

# def split_cot_sentences(cot):
#     # Split on '.' only when neither the preceding nor the following character is a digit
#     sentences = re.split(r'(?<!\d)\.(?!\d)', cot)
#     # Tidy each sentence and append a period when needed
#     result = []
#     for s in sentences:
#         s = s.strip()
#         if s:
#             if s[-1] not in '.!?':
#                 s += '.'
#             result.append(s)
#     return result

def split_cot_sentences(cot):
    """Split a chain-of-thought string into a list of sentences.

    The text is cut at every '.' that is NOT immediately followed by a digit,
    so decimals such as ``3.5`` stay intact while a period after a number
    (``"She has 2. Then ..."``) still ends a sentence. Each piece is stripped,
    empty pieces are dropped, and a trailing '.' is appended unless the piece
    already ends in '.', '!' or '?'.
    """
    stripped_pieces = (piece.strip() for piece in re.split(r'\.(?!\d)', cot))
    return [
        piece if piece.endswith(('.', '!', '?')) else piece + '.'
        for piece in stripped_pieces
        if piece
    ]

def format_input(src, cot, ans, eos_token):
    """Build the model input: question, reasoning and answer, each followed by the EOS token.

    The six pieces are joined with single spaces, giving
    ``"<src> <eos> <cot> <eos> <ans> <eos>"``.
    """
    template = "{src} {eos} {cot} {eos} {ans} {eos}"
    return template.format(src=src, cot=cot, ans=ans, eos=eos_token)

def get_sep_position(input_ids, sep_id, skip=0):
    """Return, for every row, the index of the ``skip``-th occurrence of ``sep_id``.

    Occurrences are counted from the START of the row, zero-based: ``skip=0``
    gives the first separator, ``skip=1`` the second, and so on. A negative
    ``skip`` is treated like ``0``.

    Args:
        input_ids: integer tensor indexed as ``input_ids[batch_id]``; each row
            is assumed to be 1-D (i.e. ``input_ids`` is ``(batch, seq)``) —
            TODO confirm for other callers.
        sep_id: token id to look for.
        skip: number of leading separators to pass over.

    Returns:
        A ``long`` tensor of length ``input_ids.shape[0]`` created with
        ``input_ids.new_zeros``.

    Raises:
        IndexError: if a row contains ``skip`` or fewer separators.
    """
    batch_size = input_ids.shape[0]
    sep_positions = input_ids.new_zeros(batch_size).long()
    # range(skip) with a non-positive skip skips nothing, so clamp to 0.
    target = max(skip, 0)
    for batch_id in range(batch_size):
        # One scan per row: indices of every separator, in ascending order.
        occurrences = input_ids[batch_id].eq(sep_id).nonzero(as_tuple=True)[0]
        if occurrences.numel() <= target:
            raise IndexError(
                f"row {batch_id} contains {occurrences.numel()} occurrence(s) of "
                f"sep_id={sep_id}; cannot skip {target}"
            )
        sep_positions[batch_id] = occurrences[target]
    return sep_positions

In [5]:
def compute_answer_loss(model, tokenizer, input_text, device):
    """Return the model's loss restricted to the answer part of ``input_text``.

    ``input_text`` is tokenized, then ``get_sep_position(..., skip=1)`` locates
    the SECOND EOS token of each row. For text built as
    ``src <eos> cot <eos> ans <eos>`` this is the separator right before the
    answer — assuming the tokenizer inserts no additional EOS ids; verify for
    the tokenizer in use. Every label up to and including that separator is set
    to -100; the tokens after it keep their ids as labels.

    Args:
        model: object exposing ``compute_loss(input_ids, labels=...)`` whose
            result has a ``.loss`` tensor.
        tokenizer: callable tokenizer exposing ``eos_token_id``.
        input_text: text to score.
        device: device the tokenized batch is moved to.

    Returns:
        The loss as a Python float.
    """
    encoded = tokenizer(input_text, return_tensors="pt").to(device)
    token_ids = encoded.input_ids

    # skip=1 -> second EOS in each row (the one separating reasoning from answer).
    answer_boundaries = get_sep_position(token_ids, tokenizer.eos_token_id, skip=1)

    # Start fully masked, then copy the ids that follow the boundary.
    labels = torch.full_like(token_ids, -100)
    for row, boundary in enumerate(answer_boundaries):
        answer_start = boundary + 1
        labels[row, answer_start:] = token_ids[row, answer_start:]

    with torch.no_grad():
        outputs = model.compute_loss(encoded['input_ids'], labels=labels)

    return outputs.loss.item()

In [29]:
# Dataset identifier for this run. It is interpolated into the input path
# (data/{data}/{data}_train.txt) and into the results directory
# (motivation-new/{data}). Uncomment exactly one alternative to switch dataset.
# data = 'math_qa'
# data = 'aqua_rat'
# data = 'trivia_qa'
# data='commonsenseqa'
# data = 'gsm8k'
data='strategy-qa'

def analyze_cot_removal(model, tokenizer, dataset, max_sentences_to_remove, num_samples_do, device, output_dir=None):
    """Measure how the answer loss changes when random reasoning sentences are removed.

    For every ``t`` in ``1..max_sentences_to_remove`` the dataset is scanned.
    Each example whose reasoning has MORE than ``t`` sentences gets ``t``
    randomly chosen sentences removed, and the answer loss is computed for
    both the original and the shortened input. Scanning for a given ``t``
    stops once ``num_samples_do`` examples have been scored.

    ``loss_difference`` is ``original_loss - modified_loss``, so a negative
    value means the loss went up after the removal.

    Args:
        model, tokenizer, device: forwarded to ``compute_answer_loss``.
        dataset: iterable of dicts with the keys ``'src'``, ``'cot'``, ``'ans'``.
        max_sentences_to_remove: largest ``t`` to evaluate.
        num_samples_do: number of scored examples per ``t`` after which to stop.
        output_dir: directory for the per-``t`` JSONL files. Defaults to
            ``motivation-new/{data}``, using the notebook-level ``data`` variable.

    Returns:
        A list with one dict per ``t``: ``sentences_removed``,
        ``avg_loss_difference`` (0 when nothing was scored) and ``num_samples``.

    Side effects:
        Writes ``cot_removal_results_{t}.jsonl`` into ``output_dir`` for every ``t``.
        Sentence selection uses the global ``random`` module; seed it for
        reproducible results.
    """
    if output_dir is None:
        # Backward-compatible default that depends on the notebook global `data`.
        output_dir = f"motivation-new/{data}"

    results = []

    for t in range(1, max_sentences_to_remove + 1):
        loss_differences = []
        example_results = []

        for example in tqdm(dataset, desc=f"Analyzing {t} sentence(s) removal"):
            src, cot, ans = example['src'], example['cot'], example['ans']
            sentences = split_cot_sentences(cot)

            # At least one sentence must survive; check this BEFORE any forward
            # pass so skipped examples cost nothing.
            if len(sentences) <= t:
                continue

            # Sample positions rather than values: filtering by value would
            # also drop every duplicate of a chosen sentence, removing more
            # than t sentences.
            removed_indices = random.sample(range(len(sentences)), t)
            removed_lookup = set(removed_indices)
            removed_sentences = [sentences[i] for i in removed_indices]
            remaining_sentences = [
                sentence for i, sentence in enumerate(sentences) if i not in removed_lookup
            ]

            # NOTE(review): the baseline uses the raw `cot`, while the modified
            # text is rebuilt from normalized sentences; confirm that formatting
            # differences between the two are acceptable for this comparison.
            original_input = format_input(src, cot, ans, tokenizer.eos_token)
            original_loss = compute_answer_loss(model, tokenizer, original_input, device)

            modified_cot = " ".join(remaining_sentences).strip()
            modified_input = format_input(src, modified_cot, ans, tokenizer.eos_token)
            modified_loss = compute_answer_loss(model, tokenizer, modified_input, device)

            loss_difference = original_loss - modified_loss
            loss_differences.append(loss_difference)

            example_results.append({
                "original_text": original_input,
                "removed_cot_text": modified_input,
                "removed_sentences": " ".join(removed_sentences).strip(),
                'original_loss': original_loss,
                'modified_loss': modified_loss,
                "loss_difference": loss_difference
            })
            if len(example_results) == num_samples_do:
                break

        avg_loss_difference = sum(loss_differences) / len(loss_differences) if loss_differences else 0
        results.append({
            "sentences_removed": t,
            "avg_loss_difference": avg_loss_difference,
            "num_samples": len(loss_differences)
        })

        # Save detailed per-example results as JSON Lines.
        os.makedirs(output_dir, exist_ok=True)
        with open(f'{output_dir}/cot_removal_results_{t}.jsonl', 'w', encoding="utf-8") as f:
            for item in example_results:
                json.dump(item, f)
                f.write('\n')

    return results


# def analyze_cot_removal(model, tokenizer, dataset, max_sentences_to_remove, num_samples_do, device):
#     results = []
    
#     for t in range(1, max_sentences_to_remove + 1):
#         loss_differences = []
#         example_results = []

#         cnt = 0
#         for example in tqdm(dataset, desc=f"Analyzing {t} sentence(s) removal"):
#             src, cot, ans = example['src'], example['cot'], example['ans']
#             sentences = split_cot_sentences(cot)
            
#             if len(sentences) < t:
#                 continue
            
#             original_input = format_input(src, cot, ans, tokenizer.eos_token)
#             original_loss = compute_answer_loss(model, tokenizer, original_input, device)
#             # torch.cuda.empty_cache()
            
#             # Create a probability distribution that gives higher weight to earlier sentences
#             probabilities = np.linspace(1.0, 0.1, num=len(sentences))
#             probabilities /= probabilities.sum()  # Normalize to make it a valid probability distribution
            
#             # Remove t sentences with weighted probabilities
#             removed_indices = np.random.choice(len(sentences), size=t, replace=False, p=probabilities)
#             removed_sentences = [sentences[i] for i in removed_indices]
            
#             remaining_sentences = [s for i, s in enumerate(sentences) if i not in removed_indices]
#             if len(remaining_sentences) < 1:
#                 continue

#             modified_cot = " ".join(remaining_sentences).strip()

#             modified_input = format_input(src, modified_cot, ans, tokenizer.eos_token)
#             modified_loss = compute_answer_loss(model, tokenizer, modified_input, device)
#             # torch.cuda.empty_cache()
            
#             loss_difference = original_loss - modified_loss
#             loss_differences.append(loss_difference)
            
#             example_results.append({
#                 "original_text": original_input,
#                 "removed_cot_text": modified_input,
#                 "removed_sentences": " ".join(removed_sentences).strip(),
#                 'original_loss': original_loss,
#                 'modified_loss': modified_loss,
#                 "loss_difference": loss_difference
#             })
#             cnt +=1
#             if cnt == num_samples_do:
#                 break
        
#         avg_loss_difference = sum(loss_differences) / len(loss_differences) if loss_differences else 0
#         results.append({
#             "sentences_removed": t,
#             "avg_loss_difference": avg_loss_difference,
#             "num_samples": len(loss_differences)
#         })
        
#         # Save detailed results to a JSON file
#         os.makedirs(f"motivation-new2/{data}", exist_ok=True)
#         with open(f'motivation-new2/{data}/cot_removal_results_{t}.jsonl', 'w') as f:
#             for item in example_results:
#                 json.dump(item, f)
#                 f.write('\n')
    
#     return results

In [27]:
# Checkpoint identifier passed to the project-local MistralConfig.
model_name = "mistralai/Mistral-7B-v0.1"  # Update this to your model
# Earlier loading path, kept for reference:
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = Mistral.from_pretrained(model_name)

# Build the project-local Mistral wrapper, move it to the first CUDA device
# when one is available (CPU otherwise) and cast it to bfloat16.
config = MistralConfig(model_name)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Mistral(config).to(device, dtype=torch.bfloat16)
# Use the tokenizer the wrapper exposes — presumably the one matching the
# loaded checkpoint; verify in mistral_direct.
tokenizer = model.tokenizer

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.38s/it]


In [30]:

# Run configuration. The input path is derived from the notebook-level `data`.
file_path = f"data/{data}/{data}_train.txt"  # Update this to your dataset path
num_samples = 1000  # Max number of file lines randomly drawn into the dataset
num_samples_do = 100  # Per removal count, stop once this many examples are scored
max_length = 512  # Max sequence length (stored on the dataset; not used for truncation there)
max_sentences_to_remove = 5  # Maximum number of sentences to remove

# NOTE(review): no random seed is set before the dataset is sampled or the
# sentences are removed, so the numbers printed below change between runs.
dataset = CoTAnalysisDataset(tokenizer, file_path, max_length, num_samples)

results = analyze_cot_removal(model, tokenizer, dataset, max_sentences_to_remove, num_samples_do, device)

# Print results
# "Average loss difference" is original loss minus modified loss, so a
# negative value means the loss increased after sentences were removed.
print("\nResults:")
for result in results:
    print(f"Sentences removed: {result['sentences_removed']}")
    print(f"Average loss difference: {result['avg_loss_difference']:.4f}")
    print(f"Number of samples: {result['num_samples']}")
    # print("!")

Analyzing 1 sentence(s) removal:   0%|          | 0/1000 [00:00<?, ?it/s]

Analyzing 1 sentence(s) removal:  11%|█         | 111/1000 [00:04<00:36, 24.25it/s]
Analyzing 2 sentence(s) removal:  19%|█▉        | 188/1000 [00:05<00:25, 31.95it/s]
Analyzing 3 sentence(s) removal:  60%|██████    | 601/1000 [00:08<00:05, 67.67it/s]
Analyzing 4 sentence(s) removal: 100%|██████████| 1000/1000 [00:05<00:00, 199.28it/s]
Analyzing 5 sentence(s) removal: 100%|██████████| 1000/1000 [00:01<00:00, 668.37it/s]


Results:
Sentences removed: 1
Average loss difference: -0.1220
Number of samples: 100
Sentences removed: 2
Average loss difference: -0.1394
Number of samples: 100
Sentences removed: 3
Average loss difference: -0.2038
Number of samples: 100
Sentences removed: 4
Average loss difference: -0.2156
Number of samples: 47
Sentences removed: 5
Average loss difference: -0.3876
Number of samples: 15



