In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import Dataset, load_dataset
import os

# Hyperparameters
MODEL_NAME = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
DATASET_PATH = "sudoku_sft_data.json"
OUTPUT_DIR = "sft_output"
BATCH_SIZE = 8
LEARNING_RATE = 2e-5
GRADIENT_ACCUMULATION_STEPS = 1
NUM_EPOCHS = 3

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token


# Preprocess data
def preprocess_function(examples):
    inputs = examples["instruction"] + examples["input"]
    targets = examples["output"]
    text = tokenizer(f"<instruction>, {inputs}, <output>, {targets}{tokenizer.eos_token}", return_tensors="pt", padding="longest")
    return text

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = load_dataset("json", data_files=DATASET_PATH)

In [3]:
data["train"] = data["train"].take(10)

In [4]:
data["train"][0]

{'instruction': 'Solve this Sudoku puzzle:',
 'input': '0 1 9 4 0 3 0 5 7 2 5 3 7 9 1 0 4 8 4 0 8 5 6 2 0 1 9 0 9 1 2 0 6 0 7 5 5 2 0 1 7 8 4 9 3 8 4 0 3 5 9 1 0 2 7 3 2 6 1 5 9 8 4 9 6 5 0 3 0 0 2 1 1 8 4 9 2 7 5 3 0',
 'output': '<thonk> I see a sudoku problem. Most of its cells are filled. So it should be easy to finish it.\nIn row 2 the only missing element is 6 so row 2 column 7 must be 6.\nIn column 2 the only missing element is 7 so row 3 column 2 must be 7.\nIn row 3 the only missing element is 3 so row 3 column 7 must be 3.\nIn column 4 the only missing element is 8 so row 8 column 4 must be 8.\nIn row 5 the only missing element is 6 so row 5 column 3 must be 6.\nIn column 3 the only missing element is 7 so row 6 column 3 must be 7.\nIn row 6 the only missing element is 6 so row 6 column 8 must be 6.\nIn column 6 the only missing element is 4 so row 8 column 6 must be 4.\nIn row 8 the only missing element is 7 so row 8 column 7 must be 7.\nIn column 9 the only missing element 

In [5]:
tokenized_datasets = data.map(preprocess_function, remove_columns=["instruction", "input", "output"]).with_format("torch")
tokenized_datasets["train"][0]["input_ids"]

Map: 100%|██████████| 10/10 [00:00<00:00, 135.80 examples/s]


tensor([[    27,  54974,   8066,  63284,    419,  94254,  24626,     25,     15,
            220,     16,    220,     24,    220,     19,    220,     15,    220,
             18,    220,     15,    220,     20,    220,     22,    220,     17,
            220,     20,    220,     18,    220,     22,    220,     24,    220,
             16,    220,     15,    220,     19,    220,     23,    220,     19,
            220,     15,    220,     23,    220,     20,    220,     21,    220,
             17,    220,     15,    220,     16,    220,     24,    220,     15,
            220,     24,    220,     16,    220,     17,    220,     15,    220,
             21,    220,     15,    220,     22,    220,     20,    220,     20,
            220,     17,    220,     15,    220,     16,    220,     22,    220,
             23,    220,     19,    220,     24,    220,     18,    220,     23,
            220,     19,    220,     15,    220,     18,    220,     20,    220,
             24,    220,    

In [6]:
tokenizer("<s>", return_tensors="pt", padding="longest")

{'input_ids': tensor([[44047,    29]]), 'attention_mask': tensor([[1, 1]])}

In [7]:
input_ids = tokenized_datasets["train"][0]["input_ids"].to("cuda")
print(input_ids.shape)
model(input_ids, labels=input_ids).loss

torch.Size([1, 850])


tensor(0.9330, device='cuda:0', grad_fn=<NllLossBackward0>)

In [8]:
checkpoint = torch.load("sft_output/epoch=0-step=11250.ckpt")
checkpoint.keys()

  checkpoint = torch.load("sft_output/epoch=0-step=11250.ckpt")


dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])

In [9]:
state_dict = {}
for key in checkpoint["state_dict"].keys():
    # remove the prefix "model."
    state_dict[key.replace("model.", "", 1)] = checkpoint["state_dict"][key]

model.load_state_dict(state_dict)
model.eval()
model.to("cuda")

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotary_emb): Qw

In [10]:
prompt = f"<instruction>, Solve this Sudoku puzzle:\n{data['train'][0]['instruction'] + data['train'][0]['input']}, <output>, "
y = tokenizer(prompt, return_tensors="pt", padding="longest").to("cuda")
prompt, y

('<instruction>, Solve this Sudoku puzzle:\nSolve this Sudoku puzzle:0 1 9 4 0 3 0 5 7 2 5 3 7 9 1 0 4 8 4 0 8 5 6 2 0 1 9 0 9 1 2 0 6 0 7 5 5 2 0 1 7 8 4 9 3 8 4 0 3 5 9 1 0 2 7 3 2 6 1 5 9 8 4 9 6 5 0 3 0 0 2 1 1 8 4 9 2 7 5 3 0, <output>, ',
 {'input_ids': tensor([[   27, 54974,  8066, 63284,   419, 94254, 24626,   510,    50,  3948,
            419, 94254, 24626,    25,    15,   220,    16,   220,    24,   220,
             19,   220,    15,   220,    18,   220,    15,   220,    20,   220,
             22,   220,    17,   220,    20,   220,    18,   220,    22,   220,
             24,   220,    16,   220,    15,   220,    19,   220,    23,   220,
             19,   220,    15,   220,    23,   220,    20,   220,    21,   220,
             17,   220,    15,   220,    16,   220,    24,   220,    15,   220,
             24,   220,    16,   220,    17,   220,    15,   220,    21,   220,
             15,   220,    22,   220,    20,   220,    20,   220,    17,   220,
             15,   22

In [12]:
from sudoku import Sudoku
import numpy as np

In [13]:
sud = Sudoku().difficulty(0.1)
sud.show()
problem = "".join([str(cell or 0) for row in sud.board for cell in row])
prompt = f"<instruction>, Solve this Sudoku puzzle:\n{problem}, <output>,"
y = tokenizer(prompt, return_tensors="pt", padding="longest").to("cuda")
prompt, y

Puzzle has exactly one solution
+-------+-------+-------+
| 4 1 9 |   3 8 | 7 6 2 |
| 5 6 7 | 1 2 4 | 9 8 3 |
| 2 3 8 | 7 9 6 | 4 5 1 |
+-------+-------+-------+
| 8 5 3 | 2 4 9 | 1 7 6 |
| 9 4 1 |     7 | 2 3 8 |
| 6 7 2 | 3 8 1 | 5 4 9 |
+-------+-------+-------+
|   8 4 | 9 7 3 | 6 2 5 |
| 3 9   | 4 6 2 | 8 1 7 |
| 7 2 6 | 8 1 5 |       |
+-------+-------+-------+



('<instruction>, Solve this Sudoku puzzle:\n419038762567124983238796451853249176941007238672381549084973625390462817726815000, <output>,',
 {'input_ids': tensor([[   27, 54974,  8066, 63284,   419, 94254, 24626,   510,    19,    16,
             24,    15,    18,    23,    22,    21,    17,    20,    21,    22,
             16,    17,    19,    24,    23,    18,    17,    18,    23,    22,
             24,    21,    19,    20,    16,    23,    20,    18,    17,    19,
             24,    16,    22,    21,    24,    19,    16,    15,    15,    22,
             17,    18,    23,    21,    22,    17,    18,    23,    16,    20,
             19,    24,    15,    23,    19,    24,    22,    18,    21,    17,
             20,    18,    24,    15,    19,    21,    17,    23,    16,    22,
             22,    17,    21,    23,    16,    20,    15,    15,    15,    11,
            366,  3006,  8066]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [16]:
x = model.generate(**y, max_new_tokens=100)
tokenizer.decode(x[0], skip_special_tokens=True)

'<instruction>, Solve this Sudoku puzzle:\n419038762567124983238796451853249176941007238672381549084973625390462817726815000, <output>,<thonk> I see a sudoku problem. Most of its cells are filled. So it should be easy to finish it.\nIn column 1 the only missing element is 5 so row 9 column 1 must be 5.\nIn column 5 the only missing element is 4 so row 9 column 5 must be 4.\nIn column 7 the only missing element is 6 so row 9 column 7 must be 6.\nIn block (1,'