In [1]:
from utils import dsl
from utils import constants
from utils.dsl import height, width
import json
import numpy as np
from utils.dsl import *

from utils.constants import *

In [10]:
task_id = 'e376de54'
def solve(I):
    # Get grid dimensions and background color
    h_w = shape(I)
    h, w = h_w[0], h_w[1]
    bg = mostcolor(I)
    
    # Create blank output grid
    O = canvas(bg, h_w)
    
    # Find connected objects with 4-connect and 8-connect to detect if diagonal
    objs4 = objects(I, True, False, True)
    objs8 = objects(I, True, True, True)
    if size(objs8) < size(objs4):
        objs = objs8
    else:
        objs = objs4
    
    # Classify dominant direction by total cells covered in each line type
    dir_counts = {}
    total_covered = 0
    total_cells = sum(size(obj) for obj in objs)
    for obj in objs:
        pos_set = toindices(obj)
        n = size(pos_set)
        rows_list = [r for r, c in pos_set]
        cols_list = [c for r, c in pos_set]
        min_r = min(rows_list)
        max_r = max(rows_list)
        min_c = min(cols_list)
        max_c = max(cols_list)
        classified = False
        
        # Check horizontal
        if min_r == max_r and n == max_c - min_c + 1:
            filled = all((min_r, j) in pos_set for j in range(min_c, max_c + 1))
            if filled:
                if 'h' not in dir_counts:
                    dir_counts['h'] = 0
                dir_counts['h'] += n
                total_covered += n
                classified = True
        
        # Check vertical
        if not classified and min_c == max_c and n == max_r - min_r + 1:
            filled = all((i, min_c) in pos_set for i in range(min_r, max_r + 1))
            if filled:
                if 'v' not in dir_counts:
                    dir_counts['v'] = 0
                dir_counts['v'] += n
                total_covered += n
                classified = True
        
        # Check anti-diagonal (/)
        if not classified:
            sums_set = {r + c for r, c in pos_set}
            if len(sums_set) == 1:
                s = next(iter(sums_set))
                if max_r - min_r + 1 == n:
                    filled = all((r, s - r) in pos_set for r in range(min_r, max_r + 1)) and all(0 <= s - r < w for r in range(min_r, max_r + 1))
                    if filled:
                        if 'a' not in dir_counts:
                            dir_counts['a'] = 0
                        dir_counts['a'] += n
                        total_covered += n
                        classified = True
        
        # Check main-diagonal (\)
        if not classified:
            diffs_set = {r - c for r, c in pos_set}
            if len(diffs_set) == 1:
                d = next(iter(diffs_set))
                if max_r - min_r + 1 == n:
                    filled = all((r, r - d) in pos_set for r in range(min_r, max_r + 1)) and all(0 <= r - d < w for r in range(min_r, max_r + 1))
                    if filled:
                        if 'm' not in dir_counts:
                            dir_counts['m'] = 0
                        dir_counts['m'] += n
                        total_covered += n
                        classified = True
    
    # Determine dominant direction (assumes total_covered == total_cells)
    if dir_counts:
        dominant_dir = max(dir_counts, key=dir_counts.get)
    else:
        return I  # No lines, return input
    
    # Determine reference length L
    nine_objs = colorfilter(objs, 9)
    is_global = bool(nine_objs)
    if is_global:
        L_global = size(next(iter(nine_objs)))
    else:
        color_groups = {}
        for obj in objs:
            c = color(obj)
            if c not in color_groups:
                color_groups[c] = []
            color_groups[c].append(obj)
        color_to_L = {}
        for c, group in color_groups.items():
            lengths = [size(toindices(o)) for o in group]
            sorted_l = sorted(lengths)
            ng = len(sorted_l)
            if ng > 0:
                med = sorted_l[ng // 2]
                color_to_L[c] = med
    
    # Adjust and paint each object
    for obj in objs:
        c = color(obj)
        if is_global:
            this_L = L_global
        else:
            this_L = color_to_L.get(c, 1)
        pos_set = toindices(obj)
        if size(pos_set) == 0:
            continue
        rows_list = [r for r, _ in pos_set]
        cols_list = [c_val for _, c_val in pos_set]
        min_r = min(rows_list)
        max_r = max(rows_list)
        min_c = min(cols_list)
        max_c = max(cols_list)
        new_pos_list = []
        
        if dominant_dir == 'h':
            row = min_r
            fixed_c = min_c
            new_max_c = min(fixed_c + this_L - 1, w - 1)
            for j in range(fixed_c, new_max_c + 1):
                new_pos_list.append((row, j))
        elif dominant_dir == 'v':
            colu = min_c
            fixed_r = min_r
            new_max_r = min(fixed_r + this_L - 1, h - 1)
            for i in range(fixed_r, new_max_r + 1):
                new_pos_list.append((i, colu))
        elif dominant_dir == 'a':
            s = min_r + min_c
            fixed_r = max_r
            new_min_r = max(0, fixed_r - this_L + 1)
            for r in range(new_min_r, fixed_r + 1):
                cc = s - r
                if 0 <= cc < w:
                    new_pos_list.append((r, cc))
        elif dominant_dir == 'm':
            d = min_r - min_c
            fixed_r = max_r
            new_min_r = max(0, fixed_r - this_L + 1)
            for r in range(new_min_r, fixed_r + 1):
                cc = r - d
                if 0 <= cc < w:
                    new_pos_list.append((r, cc))
        
        # Fill the new positions
        if new_pos_list:
            new_pos_fs = frozenset(new_pos_list)
            O = fill(O, c, new_pos_fs)
    
    return O

In [11]:
puzzle_index = 0
with open(f'../data_v2/evaluation/{task_id}.json') as f:
    task = json.load(f)
I = task['train'][puzzle_index]['input']
expected = task['train'][puzzle_index]['output']
I=tuple(tuple(row) for row in I)
output = solve(I)

In [12]:
for row in output:
    print(row)

(7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 7, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 2, 7, 7, 7, 7, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(2, 7, 7, 7, 7, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 7, 7, 9, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 7, 9, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 9, 1, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 9, 1, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(9, 1, 7, 7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 7)
(1, 7, 7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7)
(7, 7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7, 7, 7, 7)
(7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)


In [13]:
for row in expected:
    print(row)

[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
[7, 7, 7, 7, 7, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
[7, 7, 7, 7, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
[7, 7, 7, 2, 7, 7, 7, 2, 7, 7, 7, 7, 7, 7, 7, 7]
[7, 7, 2, 7, 7, 7, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7]
[7, 2, 7, 7, 7, 2, 7, 7, 7, 9, 7, 7, 7, 7, 7, 7]
[2, 7, 7, 7, 2, 7, 7, 7, 9, 7, 7, 7, 7, 7, 7, 7]
[7, 7, 7, 2, 7, 7, 7, 9, 7, 7, 7, 1, 7, 7, 7, 7]
[7, 7, 2, 7, 7, 7, 9, 7, 7, 7, 1, 7, 7, 7, 7, 7]
[7, 7, 7, 7, 7, 9, 7, 7, 7, 1, 7, 7, 7, 1, 7, 7]
[7, 7, 7, 7, 9, 7, 7, 7, 1, 7, 7, 7, 1, 7, 7, 7]
[7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 1, 7, 7, 7, 7]
[7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 1, 7, 7, 7, 7, 7]
[7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7]
[7, 7, 7, 7, 7, 7, 7, 7, 1, 7, 7, 7, 7, 7, 7, 7]


In [14]:
expected_np = np.array(expected)
output_np = np.array(output)

diff_idx = []

if np.ndim(expected_np) != np.ndim(output_np) or expected_np.shape != output_np.shape:
    print("Dimension mismatch between expected and output.")
    print("expected shape:", expected_np.shape)

else:
    diffs = np.where(expected_np != output_np)
    # print differences as ascii art
    diff_grid = np.full(expected_np.shape, ' ')
    for r, c in zip(*diffs):
        diff_grid[r, c] = 'X'
        diff_idx.append((r, c))
    for row in diff_grid:
        print(''.join(row))  
        
print("Diff indices:", diff_idx)

                
                
   X X          
  X X           
 X X  XX        
X X  XX         
 X  XX   X      
X  XX   X       
  XX   X   X    
 XX   X   X     
XX   X   XX  X  
X   X   XX  X   
       XX  X    
      XX  X     
      X  X      
     X  X       
Diff indices: [(np.int64(2), np.int64(3)), (np.int64(2), np.int64(5)), (np.int64(3), np.int64(2)), (np.int64(3), np.int64(4)), (np.int64(4), np.int64(1)), (np.int64(4), np.int64(3)), (np.int64(4), np.int64(6)), (np.int64(4), np.int64(7)), (np.int64(5), np.int64(0)), (np.int64(5), np.int64(2)), (np.int64(5), np.int64(5)), (np.int64(5), np.int64(6)), (np.int64(6), np.int64(1)), (np.int64(6), np.int64(4)), (np.int64(6), np.int64(5)), (np.int64(6), np.int64(9)), (np.int64(7), np.int64(0)), (np.int64(7), np.int64(3)), (np.int64(7), np.int64(4)), (np.int64(7), np.int64(8)), (np.int64(8), np.int64(2)), (np.int64(8), np.int64(3)), (np.int64(8), np.int64(7)), (np.int64(8), np.int64(11)), (np.int64(9), np.int64(1)), (np.int64(9)