In [None]:
import sys

sys.path.append('..')
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from itertools import combinations
from utils import SyntheticDataset
from model import Transformer

# Use one fixed seed for both NumPy and PyTorch so that repeated runs of this
# cell draw the same random numbers from either library.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)


# Set up attention extraction
class SaveOutput:
  """Forward-hook callable that accumulates a slice of each module output.

  Every call takes the second element of the module's output, drops the last
  index along its two trailing dimensions, and appends the result to
  ``self.outputs``.
  """

  def __init__(self):
    # Start with an empty collection.
    self.clear()

  def __call__(self, module, module_in, module_out):
    # presumably module_out[1] holds per-head attention weights shaped
    # (batch, heads, seq, seq) and the dropped last row/column belongs to an
    # extra token -- TODO confirm against the model definition.
    weights = module_out[1]
    trimmed = weights[:, :, :-1, :-1]
    self.outputs.append(trimmed)

  def clear(self):
    """Forget everything captured so far by rebinding a fresh list."""
    self.outputs = list()


# Patch attention to always output weights explicitly
def patch_attention(m):
  """Replace ``m.forward`` with a wrapper that forces two keyword arguments.

  The wrapper passes all positional and keyword arguments through to the
  original ``forward`` but always sets ``need_weights=True`` and
  ``average_attn_weights=False``, overriding whatever the caller supplied.
  The module is modified in place; nothing is returned.
  """
  original_forward = m.forward

  def forward_with_weights(*args, **kwargs):
    # Caller-provided values for these two keys are deliberately overwritten.
    kwargs.update(need_weights=True, average_attn_weights=False)
    return original_forward(*args, **kwargs)

  m.forward = forward_with_weights


# Model initialization with pre-trained weights
# Build the model, restore the pre-trained checkpoint and switch to eval mode.
model = Transformer(evals=True, times=True)
best_path = (
  '../results/weights/'
  'Transformer(evals, times)_generated_12_10000_60_0.5_64_0.0005_0.0001_20.pt'
)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine (the model itself is never moved off the CPU in this cell).
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, and avoids executing arbitrary pickled code.
state_dict = torch.load(best_path, map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Force the last encoder layer's self-attention to return un-averaged weights
# and register a forward hook that records them on every forward pass.
last_self_attn = model.transformer_encoder.layers[-1].self_attn
save_output = SaveOutput()
patch_attention(last_self_attn)
hook_handle = last_self_attn.register_forward_hook(save_output)

dataset = SyntheticDataset(limit=100, num_moves=60, engine_prob=0.2)

# Search for best attention-head combination
# Exhaustively score every non-empty subset of attention heads on every
# qualifying sample and remember the (sample, subset) pair whose averaged
# attention profile has the highest Pearson correlation with the cheat labels.
best_corr, best_data, best_heads = -np.inf, None, None

for data in dataset:
  moves, evals, times, move_labels, _ = data

  # Sample filter: keep only samples with 9..15 cheat moves whose last cheat
  # move sits at index 55 or later. Checking the count first also covers the
  # "no cheat moves" case before indexing the last element.
  cheat_indices = torch.where(move_labels == 1)[0]
  num_cheats = len(cheat_indices)
  if num_cheats < 9 or num_cheats > 15:
    continue
  if cheat_indices[-1] < 55:
    continue

  # Add the batch dimension only for samples that are actually evaluated.
  moves, evals, times = moves.unsqueeze(0), evals.unsqueeze(0), times.unsqueeze(0)

  save_output.clear()
  with torch.no_grad():
    model(moves, evals, times)

  # First captured tensor, first batch element.
  # presumably shaped (heads, query, key) -- TODO confirm against the model.
  attn = save_output.outputs[0][0].cpu().numpy()
  labels = move_labels.numpy()

  # Average over axis 1 once per sample. Every head contributes the same
  # number of rows, so averaging these per-head profiles over a subset equals
  # the original joint mean over axes (0, 1) of the subset, without redoing
  # the reduction for each of the 2**heads - 1 subsets.
  head_profiles = attn.mean(axis=1)
  num_heads = head_profiles.shape[0]

  for r in range(1, num_heads + 1):
    for head_combo in combinations(range(num_heads), r):
      combo_attn_mean = head_profiles[list(head_combo)].mean(axis=0)
      corr, _ = pearsonr(combo_attn_mean, labels)

      # A NaN correlation never compares greater, so it is skipped here.
      if corr > best_corr:
        best_corr = corr
        best_data = data
        best_heads = head_combo

if best_data is None:
  # Avoid the misleading "-inf / None" report when nothing was scored.
  print('No sample passed the cheat-move filters (or every correlation was NaN); '
        'no best head combination found.')
else:
  print(f'Best correlation: {best_corr:.4f}, heads used: {best_heads}')

# Generate final plot
# Re-run the model on the best sample, plot the averaged attention of the
# selected heads as a one-row heat map, and mark the cheat moves. The forward
# hook is removed in `finally` so it is cleaned up even if plotting fails.
try:
  if best_data is None:
    # Without this guard the unpacking below fails with an opaque
    # "cannot unpack non-iterable NoneType object".
    raise RuntimeError(
      'No best sample available: the head search did not select any sample, '
      'so there is nothing to plot.'
    )

  moves, evals, times, move_labels, _ = best_data
  moves, evals, times = moves.unsqueeze(0), evals.unsqueeze(0), times.unsqueeze(0)

  save_output.clear()
  with torch.no_grad():
    model(moves, evals, times)

  # First captured tensor, first batch element, restricted to the chosen heads
  # and averaged over the head axis and the axis after it.
  selected = save_output.outputs[0][0][list(best_heads)]
  final_attention = selected.mean(dim=(0, 1)).cpu().numpy()

  # Improved visualization
  plt.figure(figsize=(14, 6))
  plt.imshow(final_attention[np.newaxis, :], cmap='viridis', aspect='auto')
  plt.yticks([])
  plt.xlabel('Move Index', fontsize=14)
  plt.ylabel('Attention Weight Distribution', fontsize=14)
  plt.colorbar(label='Normalized Attention Weight', shrink=0.75)

  # Give the legend label to the first cheat line only, wherever it falls.
  # Keying the label on index 0 left the legend empty unless move 0 itself was
  # a cheat move.
  cheat_positions = torch.where(move_labels == 1)[0].tolist()
  for order, idx in enumerate(cheat_positions):
    plt.axvline(
      x=idx,
      color='red',
      linestyle='--',
      linewidth=2,
      label='Cheat Move' if order == 0 else '_nolegend_',
    )

  if cheat_positions:
    plt.legend(loc='upper right', fontsize=12)
  plt.title(f'Optimal Attention Heads for Cheat Detection (Correlation: {best_corr:.2f})', fontsize=16)
  plt.grid(axis='x', linestyle=':', linewidth=0.5)
  plt.tight_layout()
  plt.show()
finally:
  # Cleanup
  hook_handle.remove()

  model.load_state_dict(torch.load(best_path))


Best correlation: -inf, heads used: None


TypeError: cannot unpack non-iterable NoneType object