In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset
import torch
import pandas as pd
import re


device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(device)
datapath = "../data/24.csv"
instructions = """
Solve the following puzzle using the numbers provided to get the result 24. You are given four numbers, and you can use addition (+), subtraction (-), multiplication (*), and division (/) to combine these numbers. You must use each number exactly once. You can use parentheses to group operations and control the order of operations. Along with that, the numbers must stay in the order they are presented in. 

For example:
Numbers: 1, 1, 8, 8
Solution:
1. Start with the numbers 1, 1, 8, and 8.
2. Group 1 and 1 to get (1 + 1) = 2.
3. Multiply 2 by 8 to get 2 * 8 = 16.
4. Add the remaining 8 to get 16 + 8 = 24.
Final answer: (1+1) * 8 + 8

Another example:
Numbers: 4, 2, 3, 1
Solution:
1. Start with the numbers 3, 3, 8, and 8.
2. Group 4 and 1 to get (4 + 2) = 6.
3. Group 3 and 1 to get (3 + 1) = 4. 
4. Multiply 6 by 4 to get 6 * 4 = 24
Final answer: (4 + 2) * (3 + 1)
"""

cuda


In [2]:
# student model
student_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2-large")
student_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-large")

In [3]:
student_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
student_model.resize_token_embeddings(len(student_tokenizer))

Embedding(50258, 1280)

In [4]:
df = pd.read_csv(datapath)


In [5]:
puzzles = df['Puzzles'].tolist()

In [6]:
class Game24Dataset(Dataset):
    def __init__(self, puzzles, tokenizer, instructions, max_length=512):
        self.puzzles = puzzles
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.instructions = instructions

    def __len__(self):
        return len(self.puzzles)

    def __getitem__(self, idx):
        puzzle = self.puzzles[idx]
        prompt = f"{self.instructions}: {puzzle}"
        encoding = self.tokenizer(
            prompt,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': input_ids
        }


In [7]:
# prototype of a loss function

def enhanced_loss_function(outputs, puzzles, tokenizer):
    total_loss = 0
    batch_size = len(puzzles)
    
    for i in range(batch_size):
        solution = tokenizer.decode(outputs[i], skip_special_tokens=True)
        puzzle = puzzles[i]
        
        # Check for correctness
        correct, partial_loss = evaluate_solution(solution, puzzle)
        
        total_loss += partial_loss if not correct else 0
    
    return total_loss / batch_size

def evaluate_solution(solution, puzzle):
   # Extract numbers and operators from the solution
    numbers = re.findall(r'\d+', solution)
    operators_and_brackets = re.findall(r'[+\-*/()]', solution)
    
    # Ensure the numbers used are exactly the ones in the puzzle
    puzzle_numbers = sorted(puzzle.split())
    solution_numbers = sorted(numbers)
    
    # Calculate a partial loss based on the incorrect use of numbers
    partial_loss = len(set(puzzle_numbers) - set(solution_numbers)) / len(puzzle_numbers)
    
    if puzzle_numbers != solution_numbers:
        return False, 1 + partial_loss
    
    # Check if valid operators and brackets are used
    valid_operators = set('+-*/')
    valid_brackets = set('()')
    invalid_chars = [char for char in operators_and_brackets if char not in valid_operators and char not in valid_brackets]
    
    if invalid_chars:
        partial_loss += 0.5  # Arbitrary penalty for invalid operators or brackets
    
    # Check for balanced brackets
    if not are_brackets_balanced(solution):
        partial_loss += 0.5  # Arbitrary penalty for unbalanced brackets
    
    # Evaluate the expression
    try:
        result = eval(solution)
        if result == 24:
            return True, 0
        else:
            return False, abs(24 - result) / 24 + partial_loss
    except Exception as e:
        print(f"Error evaluating solution: {e}")
        return False, 1 + partial_loss 

def are_brackets_balanced(expression):
    stack = []
    brackets = {'(': ')'}
    
    for char in expression:
        if char in brackets.keys():
            stack.append(char)
        elif char in brackets.values():
            if not stack or brackets[stack.pop()] != char:
                return False
    return not stack

In [8]:
dataset = Game24Dataset(puzzles, student_tokenizer, instructions)

In [9]:
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [10]:
student_model.eval()
student_model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50258, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50258, bias=False)
)

In [11]:
def extract_solution(generated_output, instructions):
    try:
        solution = generated_output.split(instructions)[-1].strip()
        return solution
    except Exception as e:
        print(f"Error extracting solution: {e}")
        return ""

In [12]:
def evaluate_model(model, dataloader, tokenizer, progress_interval=100):
    total_correct = 0
    total_puzzles = 0
    
    for i, batch in enumerate(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        with torch.no_grad():
            # Generate solutions from the model
            generated_outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=50)
        
        # Decode puzzles for evaluation
        puzzles = [tokenizer.decode(ids, skip_special_tokens=True).replace(instructions, "").strip() for ids in input_ids]
        
        for j in range(len(generated_outputs)):
            raw_solution = tokenizer.decode(generated_outputs[j], skip_special_tokens=True)
            solution = extract_solution(raw_solution, instructions)
            print(f"solution {j}: {solution}")
            puzzle = puzzles[j]
            correct, _ = evaluate_solution(solution, puzzle)
            total_correct += int(correct)
            total_puzzles += 1
        
        
        accuracy = total_correct / total_puzzles if total_puzzles > 0 else 0
        print(f"Progress: {i + 1} batches processed. Current accuracy: {accuracy * 100:.2f}%")
    
    # Final accuracy
    accuracy = total_correct / total_puzzles if total_puzzles > 0 else 0
    return accuracy

In [13]:
accuracy = evaluate_model(student_model, dataloader, student_tokenizer)
print(f"Model accuracy: {accuracy * 100:.2f}%")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


solution 0: : 12 12 12 13:
: 12 12 12 13:
: 12 12 12 13:
: 12 12 12 13:
: 12 12 12 13:
: 12 12 12 13:
: 12 12 12 13:
: 12 12 12 13:
solution 1: : 4 4 7 13:
: 4 4 7 13:
: 4 4 7 13:
: 4 4 7 13:
: 4 4 7 13:
: 4 4 7 13:
: 4 4 7 13:
: 4 4 7 13:
solution 2: : 3 4 7 9:
: 3 4 7 9:
: 3 4 7 9:
: 3 4 7 9:
: 3 4 7 9:
: 3 4 7 9:
: 3 4 7 9:
: 3 4 7 9:
solution 3: : 2 2 6 6: 2 2 6 6: 2 2 6 6: 2 2 6 6: 2 2 6 6: 2 2 6 6: 2 2 6 6: 2 2 6 6: 2 2 6 6: 2 2 6 6: 2 2 6 6
solution 4: : 1 3 5 11:

1. Start with the numbers 1, 1, 8, 8.

2. Group 1 and 1 to get (1 + 1) = 2.

3. Multiply 2 by 8 to get 2 * 8
solution 5: : 1 6 12 13: 2 6 12 13: 3 6 12 13: 4 6 12 13: 5 6 12 13: 6 6 12 13: 7 6 12 13: 8 6 12 13: 9 6 12 13: 10 6 12 13: 11 6 12 13
solution 6: : 4 5 5 9:
: 4 5 5 9:
: 4 5 5 9:
: 4 5 5 9:
: 4 5 5 9:
: 4 5 5 9:
: 4 5 5 9:
: 4 5 5 9:
solution 7: : 1 5 5 6: 2 6 6 7: 3 7 7 8: 4 8 8 9: 5 9 9 10: 6 10 10 11: 7 11 11 12: 8 12 12 13: 9 13 13 14: 10 14 14 15: 11 15 15 16
Progress: 1 batches processed. Current accura

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


solution 0: : 1 1 6 12:
: 1 1 6 12:
: 1 1 6 12:
: 1 1 6 12:
: 1 1 6 12:
: 1 1 6 12:
: 1 1 6 12:
: 1 1 6 12:
solution 1: : 2 8 10 13:

Numbers: 2, 2, 8, 8

Solution:

1. Start with the numbers 2, 2, 8, and 8.

2. Group 2 and 8 to get (2 + 8) = 8
solution 2: : 5 8 8 9:

Numbers: 5, 5, 8, 8

Solution:

1. Start with the numbers 5, 5, 8, and 8.

2. Group 5 and 8 to get (5 + 8) = 10
solution 3: : 2 3 7 13:

Solution:

1. Start with the numbers 2, 2, 7, and 7.

2. Group 2 and 7 to get (2 + 7) = 9.

3. Multiply 9 by 7
solution 4: : 1 6 11 13: 2 6 11 13: 3 6 11 13: 4 6 11 13: 5 6 11 13: 6 6 11 13: 7 6 11 13: 8 6 11 13: 9 6 11 13: 10 6 11 13: 11 6 11 13
solution 5: : 1 3 5 6: 1 3 5 6: 1 3 5 6: 1 3 5 6: 1 3 5 6: 1 3 5 6: 1 3 5 6: 1 3 5 6: 1 3 5 6: 1 3 5 6: 1 3 5 6
solution 6: : 1 5 7 10:
Solution:
1. Start with the numbers 1, 1, 5, 7, and 10.
2. Group 1 and 1 to get (1 + 1) = 5.
3. Multiply 5 by 7 to get
solution 7: : 6 6 8 8:

1. Start with the numbers 1, 1, 8, and 8.
2. Group 1 and 1 to get (1

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


solution 0: : 6 11 12 12:
: 6 11 12 12:
: 6 11 12 12:
: 6 11 12 12:
: 6 11 12 12:
: 6 11 12 12:
: 6 11 12 12:
: 6 11 12 12:
solution 1: : 4 8 10 10:
: 4 8 10 10:
: 4 8 10 10:
: 4 8 10 10:
: 4 8 10 10:
: 4 8 10 10:
: 4 8 10 10:
: 4 8 10 10:
solution 2: : 2 6 6 7: 3 6 6 7: 4 6 6 7: 5 6 6 7: 6 6 6 7: 7 6 6 7: 8 6 6 7: 9 6 6 7: 10 6 6 7: 11 6 6 7: 12 6 6 7
solution 3: : 1 2 3 12:
: 1 2 3 12:
: 1 2 3 12:
: 1 2 3 12:
: 1 2 3 12:
: 1 2 3 12:
: 1 2 3 12:
: 1 2 3 12:
solution 4: : 3 6 6 9:
: 3 6 6 9:
: 3 6 6 9:
: 3 6 6 9:
: 3 6 6 9:
: 3 6 6 9:
: 3 6 6 9:
: 3 6 6 9:
solution 5: : 2 6 6 8: 2 * 4 + 2 * 8 = 24
: 2 6 6 8: 2 * 4 + 2 * 8 = 24
: 2 6 6 8: 2 * 4 + 2 * 8 = 24
: 2 6 6 8: 2
solution 6: : 2 2 7 12:

Solution:

1. Start with the numbers 2, 2, 7, and 12.

2. Group 2 and 2 to get (2 + 2) = 6.

3. Multiply 6 by 2
solution 7: : 2 2 5 6:
: 2 2 5 6:
: 2 2 5 6:
: 2 2 5 6:
: 2 2 5 6:
: 2 2 5 6:
: 2 2 5 6:
: 2 2 5 6:
Progress: 3 batches processed. Current accuracy: 0.00%


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


solution 0: : 2 2 3 4:
: 2 2 3 4:
: 2 2 3 4:
: 2 2 3 4:
: 2 2 3 4:
: 2 2 3 4:
: 2 2 3 4:
: 2 2 3 4:
solution 1: : 2 11 12 13:

Solution:

1. Start with the numbers 11, 12, 13, and 13.

2. Group 11 and 12 to get (11 + 12) = 13.

3. Multiply 13 by 13
solution 2: : 3 5 10 13:

Solution:

1. Start with the numbers 1, 1, 8, 8.

2. Group 1 and 1 to get (1 + 1) = 2.

3. Multiply 2 by 8 to
solution 3: : 5 6 11 11:
Solution:
1. Start with the numbers 5, 6, 11, and 11.
2. Group 5 and 11 to get (5 + 11) = 11.
3. Multiply 11 by 5 to get 11 *
solution 4: : 1 4 6 6: 2 4 6 6: 3 4 6 6: 4 4 6 6: 5 4 6 6: 6 4 6 6: 7 4 6 6: 8 4 6 6: 9 4 6 6: 10 4 6 6: 11 4 6 6
solution 5: : 3 10 10 12:

Numbers: 1, 1, 8, 8

Solution:

1. Start with the numbers 1, 1, 8, and 8.

2. Group 1 and 1 to get (1 + 1) = 2
solution 6: : 5 6 6 9:
: 6 7 7 9:
: 7 8 8 9:
: 8 9 9 9:
: 10 10 10 10:
: 11 11 11 11:
: 12 12 12 12:
: 13 13 13 13:
solution 7: : 6 7 12 12:
: 6 7 12 12:
: 6 7 12 12:
: 6 7 12 12:
: 6 7 12 12:
: 6 7 12 12:
: 6 7 

KeyboardInterrupt: 