# Setup

In [1]:
from transformer_lens import HookedTransformer
import torch as t
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if t.cuda.is_available():
    device = "cuda"
elif t.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
from safetensors import safe_open
from jaxtyping import Float
from torch import Tensor
from torch import nn
import einops
from dataclasses import dataclass
from torch.nn import functional as F
from safetensors import safe_open
import json

In [2]:
from interpreting_neurons_utils import *

# SAE Class

In [3]:
@dataclass
class AutoEncoderConfig:
    n_instances: int
    n_input_ae: int
    n_hidden_ae: int
    l1_coeff: float = 0.5
    tied_weights: bool = False
    weight_normalize_eps: float = 1e-8

class AutoEncoder(nn.Module):
    W_enc: Float[Tensor, "n_instances n_input_ae n_hidden_ae"]
    W_dec: Float[Tensor, "n_instances n_hidden_ae n_input_ae"]
    b_enc: Float[Tensor, "n_instances n_hidden_ae"]
    b_dec: Float[Tensor, "n_instances n_input_ae"]


    def __init__(self, cfg: AutoEncoderConfig):
        '''
        Initializes the two weights and biases according to the type signature above.

        If self.cfg.tied_weights = True, then we only create W_enc, not W_dec.
        '''
        super(AutoEncoder, self).__init__()
        self.cfg = cfg

        self.W_enc = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_instances, cfg.n_input_ae, cfg.n_hidden_ae))))
        if not(cfg.tied_weights):
            self.W_dec = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_instances, cfg.n_hidden_ae, cfg.n_input_ae))))

        self.b_enc = nn.Parameter(t.zeros(cfg.n_instances, cfg.n_hidden_ae))
        self.b_dec = nn.Parameter(t.zeros(cfg.n_instances, cfg.n_input_ae))

        self.to(device)


    def normalize_and_return_W_dec(self) -> Float[Tensor, "n_instances n_hidden_ae n_input_ae"]:
        '''
        If self.cfg.tied_weights = True, we return the normalized & transposed encoder weights.
        If self.cfg.tied_weights = False, we normalize the decoder weights in-place, and return them.

        Normalization should be over the `n_input_ae` dimension, i.e. each feature should have a noramlized decoder weight.
        '''
        if self.cfg.tied_weights:
            return self.W_enc.transpose(-1, -2) / (self.W_enc.transpose(-1, -2).norm(dim=1, keepdim=True) + self.cfg.weight_normalize_eps)
        else:
            self.W_dec.data = self.W_dec.data / (self.W_dec.data.norm(dim=2, keepdim=True) + self.cfg.weight_normalize_eps)
            return self.W_dec


    def forward(self, h: Float[Tensor, "batch_size n_instances n_input_ae"]):
        '''
        Runs a forward pass on the autoencoder, and returns several outputs.

        Inputs:
            h: Float[Tensor, "batch_size n_instances n_input_ae"]
                hidden activations generated from a Model instance

        Returns:
            l1_loss: Float[Tensor, "batch_size n_instances"]
                L1 loss for each batch elem & each instance (sum over the `n_hidden_ae` dimension)
            l2_loss: Float[Tensor, "batch_size n_instances"]
                L2 loss for each batch elem & each instance (take mean over the `n_input_ae` dimension)
            loss: Float[Tensor, ""]
                Sum of L1 and L2 loss (with the former scaled by `self.cfg.l1_coeff). We sum over the `n_instances`
                dimension but take mean over the batch dimension
            acts: Float[Tensor, "batch_size n_instances n_hidden_ae"]
                Activations of the autoencoder's hidden states (post-ReLU)
            h_reconstructed: Float[Tensor, "batch_size n_instances n_input_ae"]
                Reconstructed hidden states, i.e. the autoencoder's final output
        '''
        # Compute activations
        h_cent = h - self.b_dec
        acts = einops.einsum(
            h_cent, self.W_enc,
            "batch_size n_instances n_input_ae, n_instances n_input_ae n_hidden_ae -> batch_size n_instances n_hidden_ae"
        )
        acts = F.relu(acts + self.b_enc)

        # Compute reconstructed input
        h_reconstructed = einops.einsum(
            acts, self.normalize_and_return_W_dec(),
            "batch_size n_instances n_hidden_ae, n_instances n_hidden_ae n_input_ae -> batch_size n_instances n_input_ae"
        ) + self.b_dec

        # Compute loss, return values
        l2_loss = (h_reconstructed - h).pow(2).mean(-1) # shape [batch_size n_instances]
        l1_loss = acts.abs().sum(-1) # shape [batch_size n_instances]
        loss = (self.cfg.l1_coeff * l1_loss + l2_loss).mean(0).sum() # scalar

        return l1_loss, l2_loss, loss, acts, h_reconstructed
    

def load_autoencoder_from_path(path : str) -> AutoEncoder:
    config_path = path + "/cfg.json"
    weights_path = path + "/sae_weights.safetensors"
    with safe_open(weights_path, "pt") as f:
        state_dict = {k: f.get_tensor(k) for k in f.keys()}

    with open(config_path) as f:
        sae_lens_config = json.load(f)

    state_dict = {k : v.unsqueeze(0) for k, v in state_dict.items()}

    cfg = AutoEncoderConfig(
        n_instances = 1,
        n_input_ae = sae_lens_config["d_in"],
        n_hidden_ae = sae_lens_config["d_sae"],
    )

    # Initialize our model, and load in state dict
    autoencoder = AutoEncoder(cfg)
    autoencoder.load_state_dict(state_dict)
    autoencoder.b_dec.requires_grad = False
    autoencoder.b_enc.requires_grad = False
    autoencoder.W_enc.requires_grad = False
    autoencoder.W_dec.requires_grad = False

    return autoencoder

# Load SAEs

In [4]:
# SAEs
saes = {
    expansion_factor : load_autoencoder_from_path(f"sparse_autoencoder_{expansion_factor}") for expansion_factor in [1, 2, 4, 8, 16]
}

In [5]:
saes[2].W_enc.shape

torch.Size([1, 512, 1024])

In [6]:
# Get the k nearest features in some sae
def get_k_nearest_features_in_sae(probe_direction : Float[Tensor, "d_model"], expansion_factor, k, seed = 42):
    sae = saes[expansion_factor]
    probe_direction_norm = probe_direction / probe_direction.norm()
    similarities = []
    for feature_idx in range(expansion_factor * 512):
        sae_dec_direction = sae.W_dec[0, feature_idx, :]
        sae_dec_direction_norm = sae_dec_direction / sae_dec_direction.norm()
        similarity = (probe_direction_norm @ sae_dec_direction_norm)
        similarities.append((similarity, feature_idx))
    similarities.sort(reverse=True)
    indeces = [idx for _, idx in similarities]
    features_dec = [sae.W_dec[0, idx, :] for idx in indeces]
    features_enc = [sae.W_enc[0, :, idx] for idx in indeces]
    similarities = [sim for sim, _ in similarities]
    return indeces[:k], features_enc[:k], features_dec[:k], similarities[:k], 

In [7]:
def get_avg_similiarity_to_probe(probe_direction, features):
    probe_direction_norm = probe_direction / probe_direction.norm()
    similarities = []
    for feature in features:
        feature_norm = feature / feature.norm()
        similarity = (probe_direction_norm @ feature_norm)
        similarities.append(similarity)
    return sum(similarities) / len(similarities)

In [48]:
THRESH = 1

In [49]:
def get_acc_of_feature_set(probe_directions : Float[Tensor, "d_model options"], sae, feature_indeces, option : int):
    resid : Float[Tensor, "batch pos d_model"] = focus_cache["resid_mid", 2]
    resid = einops.rearrange(resid, "batch pos d_model -> (batch pos) d_model")
    probe_logits = einops.einsum(resid, probe_directions, "batch d_model, d_model options -> batch options")
    probe_probs = probe_logits.softmax(dim=-1)
    probe_preds = probe_probs.argmax(dim=-1)
    probe_preds = probe_preds == option
    
    _, _, _, acts, _ = sae(resid.unsqueeze(1))
    acts = acts[:, 0, feature_indeces]
    acts_binary = (acts > THRESH).float()
    acts_preds = acts_binary.sum(dim=-1) > 0.0
    acc = (acts_preds == probe_preds).float().mean()
    # also calculate the recall and precision
    tp = (acts_preds & probe_preds).float().sum()
    fp = (acts_preds & ~probe_preds).float().sum()
    fn = (~acts_preds & probe_preds).float().sum()
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    # also calculate the f1 score
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    scores = {
        "acc" : acc,
        "precision" : precision,
        "recall" : recall,
        "f1" : f1
    }
    return scores

In [50]:
def get_score(probe_directions, option, expansion_factor, k):
    specific_probe_direction = probe_directions[:, option]
    neaest_sae_feature_indeces, _, features_dec, _ = get_k_nearest_features_in_sae(specific_probe_direction, expansion_factor, k)
    print(neaest_sae_feature_indeces)
    score = get_acc_of_feature_set(probe_directions, saes[expansion_factor], neaest_sae_feature_indeces, option)
    return score

def get_best_k(probe_directions, expansion_factor, score_str, option):
    max_score = 0
    for k in range(1, 512 * expansion_factor):
        score = get_score(probe_directions, option, expansion_factor, k)
        if score[score_str] > max_score:
            print(f"Score: {score[score_str]}, {k}")
            max_score = score[score_str]
            best_k = k
        else:
            break
    return best_k

In [51]:
probe_directions = get_probe(2, "linear", "mid")[0, :, *label_to_tuple("B4")]
option = YOURS
expansion_factor = 4
specific_probe_direction = probe_directions[:, option]
print(probe_directions.shape)
print(specific_probe_direction.shape)
k = get_best_k(probe_directions, expansion_factor, "acc", option)

torch.Size([512, 3])
torch.Size([512])
[756]
Score: 0.6812429428100586, 1
[756, 319]
Score: 0.6822598576545715, 2
[756, 319, 1565]
Score: 0.683728814125061, 3
[756, 319, 1565, 1038]
Score: 0.6846327781677246, 4
[756, 319, 1565, 1038, 1717]
Score: 0.6861016750335693, 5
[756, 319, 1565, 1038, 1717, 2039]
Score: 0.6870056390762329, 6
[756, 319, 1565, 1038, 1717, 2039, 677]
Score: 0.6874576210975647, 7
[756, 319, 1565, 1038, 1717, 2039, 677, 1551]
Score: 0.689830482006073, 8
[756, 319, 1565, 1038, 1717, 2039, 677, 1551, 70]


In [11]:
'''neaest_sae_feature_indeces, _, features_dec, _ = get_k_nearest_features_in_sae(specific_probe_direction, expansion_factor, k)
print(neaest_sae_feature_indeces)
print(get_avg_similiarity_to_probe(specific_probe_direction, features_dec))
get_acc_of_feature_set(probe_directions, saes[expansion_factor], neaest_sae_feature_indeces, option)'''

'neaest_sae_feature_indeces, _, features_dec, _ = get_k_nearest_features_in_sae(specific_probe_direction, expansion_factor, k)\nprint(neaest_sae_feature_indeces)\nprint(get_avg_similiarity_to_probe(specific_probe_direction, features_dec))\nget_acc_of_feature_set(probe_directions, saes[expansion_factor], neaest_sae_feature_indeces, option)'

In [52]:
def print_features(probe_directions : Float[Tensor, "d_model options"], sae, option : int, k = 10):
    resid : Float[Tensor, "batch pos d_model"] = focus_cache["resid_mid", 2]
    resid = einops.rearrange(resid, "batch pos d_model -> (batch pos) d_model")
    probe_logits = einops.einsum(resid, probe_directions, "batch d_model, d_model options -> batch options")
    probe_probs = probe_logits.softmax(dim=-1)
    probe_preds = probe_probs.argmax(dim=-1)
    probe_preds = probe_preds == option
    
    _, _, _, acts, _ = sae(resid.unsqueeze(1))
    acts = acts[:, 0, :]
    # print(acts.shape)
    acts = acts[probe_preds]
    # print(acts.shape)
    acts_binary = (acts > THRESH).float()
    # get most active features
    acts_sum = acts_binary.sum(dim=0)
    top_features = acts_sum.argsort(descending=True)[:k]
    for acts_batch in range(acts_binary.shape[0]):
        # print all the top features that are active
        print(acts_binary[acts_batch, top_features])
        if acts_batch > 50:
            break

In [53]:
def get_mean_difference_feature_scores(probe_directions, option, expansion_factor):
    sae = saes[expansion_factor]#
    resid : Float[Tensor, "batch pos d_model"] = focus_cache["resid_mid", 2]
    resid = einops.rearrange(resid, "batch pos d_model -> (batch pos) d_model")
    probe_logits = einops.einsum(resid, probe_directions, "batch d_model, d_model options -> batch options")
    # probe_probs = probe_logits.softmax(dim=-1)
    probe_preds = probe_logits.argmax(dim=-1)
    probe_preds = probe_preds == option
    _, _, _, acts, _ = sae(resid.unsqueeze(1))
    acts = acts[:, 0, :]
    neuron_score = acts[probe_preds == 1].mean(dim=0) - acts[probe_preds == 0].mean(dim=0)
    neuron_score_indeces = neuron_score.argsort(descending=True)
    return neuron_score_indeces, neuron_score

In [54]:
def get_best_k_md(probe_directions, expansion_factor, score_str, option, debug = True):
    max_score = 0
    top_feature_indeces, _ = get_mean_difference_feature_scores(probe_directions, option, expansion_factor)
    for k in range(1, 512 * expansion_factor):
        score = get_acc_of_feature_set(probe_directions, saes[expansion_factor], top_feature_indeces[:k], option)
        if score[score_str] > max_score:
            if debug:
                print(f"Score: {score[score_str]}, {k}")
            max_score = score[score_str]
            best_k = k
        else:
            break
    return best_k

In [55]:
probe_directions = get_probe(2, "linear", "mid")[0, :, *label_to_tuple("B4")]
option = YOURS
expansion_factor = 2
specific_probe_direction = probe_directions[:, option]
print(probe_directions.shape)
print(specific_probe_direction.shape)
k = get_best_k_md(probe_directions, expansion_factor, "acc", option)

torch.Size([512, 3])
torch.Size([512])
Score: 0.6882485747337341, 1
Score: 0.6975141167640686, 2
Score: 0.7046327590942383, 3
Score: 0.7088135480880737, 4
Score: 0.7118644118309021, 5
Score: 0.7176271080970764, 6
Score: 0.7212429046630859, 7
Score: 0.7256497144699097, 8
Score: 0.7272316217422485, 9
Score: 0.7282485961914062, 10
Score: 0.7326553463935852, 11
Score: 0.7327683568000793, 12


In [16]:
# TODO: Find SAE Features that correspond to probe Directions. If I can't find this the SAE's are probably bad
# TODO: In the SAE's find combinations of features in a large SAE that correspond to a feature in a smaller SAE
#   Use Mean Difference to reduce the number of features to check, then go over each pair

In [17]:
'''expansion_factor = 4
for probe_name in ["linear", "flipped", "placed"]:
    if probe_name == "linear":
        score_str = "acc"
    else:
        score_str = "f1"
    for probe_direction in probe_directions_list[probe_name]:
        for row in range(1, 7):
            for col in range(1, 7):
                # print(get_probe(2, probe_name, "mid").shape)
                probes = get_probe(2, probe_name, "mid")[0, :, row, col]
                option = get_direction_int(probe_direction)
                top_feature_indeces, _ = get_mean_difference_feature_scores_sae(probes, option, expansion_factor)
                score = get_acc_of_feature_set(probes, saes[expansion_factor], top_feature_indeces[:1], option)
                # print(top_feature_indeces[:10])
                score = score[score_str]
                if (score > 0.7 and score_str == "acc") or (score > 0.6 and score_str == "f1"):
                    print(f"{tuple_to_label((row, col))} {probe_direction} {score}")
# MHhhh My SAE's seem to be bad :('''

'expansion_factor = 4\nfor probe_name in ["linear", "flipped", "placed"]:\n    if probe_name == "linear":\n        score_str = "acc"\n    else:\n        score_str = "f1"\n    for probe_direction in probe_directions_list[probe_name]:\n        for row in range(1, 7):\n            for col in range(1, 7):\n                # print(get_probe(2, probe_name, "mid").shape)\n                probes = get_probe(2, probe_name, "mid")[0, :, row, col]\n                option = get_direction_int(probe_direction)\n                top_feature_indeces, _ = get_mean_difference_feature_scores_sae(probes, option, expansion_factor)\n                score = get_acc_of_feature_set(probes, saes[expansion_factor], top_feature_indeces[:1], option)\n                # print(top_feature_indeces[:10])\n                score = score[score_str]\n                if (score > 0.7 and score_str == "acc") or (score > 0.6 and score_str == "f1"):\n                    print(f"{tuple_to_label((row, col))} {probe_direction} {sco

In [18]:
# Es ist doch nicht soo bad .., gibt einige Features mit 70 % accuracy .... Weiß aber nicht ob 

In [102]:
THRESH = 0.2

In [103]:
def get_score_for_all_pairs(sae1, feature_idx1, sae2, k):
    resid : Float[Tensor, "batch pos d_model"] = focus_cache["resid_mid", 2]
    resid = resid[:, 5:50, :]
    resid = einops.rearrange(resid, "batch pos d_model -> (batch pos) d_model")
    _, _, _, acts1, _ = sae1(resid.unsqueeze(1))
    _, _, _, acts2, _ = sae2(resid.unsqueeze(1))
    acts1 = acts1[:, 0, feature_idx1]
    labels = acts1 > THRESH
    neuron_score = acts2[labels == 1].mean(dim=0) - acts2[labels == 0].mean(dim=0)
    neuron_score_indeces = neuron_score.argsort(descending=True)[0]
    # print(neuron_score_indeces.shape)

    for feature_idx2_1 in neuron_score_indeces[:k]:
        labels1 = acts2[:, 0 , feature_idx2_1] > THRESH
        # print(feature_idx2_1, k)
        for feature_idx2_2 in neuron_score_indeces[feature_idx2_1:k]:
            labels2 = acts2[:, 0 , feature_idx2_2] > THRESH
            both_features_active = labels1 & labels2
            # calculate the accuracy
            acc = (both_features_active == labels).float().mean()
            # Also calculate the recall and precision
            tp = (both_features_active & labels).float().sum()
            fp = (both_features_active & ~labels).float().sum()
            fn = (~both_features_active & labels).float().sum()
            precision = tp / (tp + fp + 1e-8)
            recall = tp / (tp + fn + 1e-8)
            # And f1 score
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
            if f1 > 0.80 and labels.sum() > 0 and labels1.sum() > 0 and labels2.sum() > 0:
                print(f"{feature_idx1} | {feature_idx2_1} {feature_idx2_2} {acc}, {f1}, {acts1.sum()}")

In [104]:
reduction_size = 20
expansion_factors = [1, 2, 4, 8, 16]
for i, expansion_factor in enumerate(expansion_factors):
    for expansion_factor2 in expansion_factors[i+1:]:
        print(f"----------------------------{expansion_factor} {expansion_factor2}--------------------")
        for feature_idx in range(512 * expansion_factor):
            sae1 = saes[expansion_factor]
            sae2 = saes[expansion_factor2]
            get_score_for_all_pairs(sae1, feature_idx, sae2, reduction_size)

----------------------------1 2--------------------
369 | 12 392 1.0, 1.0, 0.40136751532554626
369 | 12 877 1.0, 1.0, 0.40136751532554626
369 | 12 40 1.0, 1.0, 0.40136751532554626
369 | 12 582 1.0, 1.0, 0.40136751532554626
369 | 12 576 1.0, 1.0, 0.40136751532554626
369 | 12 182 1.0, 1.0, 0.40136751532554626
369 | 12 974 1.0, 1.0, 0.40136751532554626
----------------------------1 4--------------------
192 | 1 1500 1.0, 1.0, 162.98788452148438
----------------------------1 8--------------------
----------------------------1 16--------------------


OutOfMemoryError: CUDA out of memory. Tried to allocate 212.00 MiB. GPU 

In [None]:
'''
36 | 51 591 0.9996610283851624, 0.9900989532470703
36 | 0 591 0.9997739791870117, 0.9933774471282959
67 | 51 591 0.9996610283851624, 0.9900989532470703
67 | 0 591 0.9997739791870117, 0.9933774471282959
115 | 0 591 0.9996610283851624, 0.9900990128517151
183 | 0 591 0.9996610283851624, 0.9900990128517151
307 | 0 591 0.9996610283851624, 0.9900990128517151
331 | 0 591 0.9996610283851624, 0.9900990128517151
370 | 0 591 0.9996610283851624, 0.9900990128517151
505 | 51 591 0.9996610283851624, 0.9900989532470703
505 | 0 591 0.9997739791870117, 0.9933774471282959
'''
# TODO: Analyze this
# TODO: vizsualize the most activating dataset examples


In [106]:
layer = 2 # 0
feature_idx = 192  # 368
expansion_factor = 1

# Promising: 192 | 1 1500 1.0, 1.0, 162.98788452148438 Ich müsste mir die Dings anschauen though (Die vizualisierung von Linear Probes)

# get_fraction_of_variance_from_neuron_explained_by_probe(neuron = neuron, layer = 1)
def get_max_acitvations_of_feature(
    layer: int,
    feature_idx: int,
    expansion_factor: int,
    num_activations: int = 10,
    random = False,
    thresh = 1,
) -> Tuple[Float[Tensor, "game move"], Float[Tensor, "game move"]]:
    '''
    Returns the top activations for a given neuron in a given layer.
    '''
    sae = saes[expansion_factor]
    resid : Float[Tensor, "batch pos d_model"] = focus_cache["resid_post", layer]
    resid = resid[:, 5:50, :]
    batch_size, seq_len, _ = resid.shape
    resid = einops.rearrange(resid, "batch pos d_model -> (batch pos) d_model")
    _, _, _, acts1, _ = sae(resid.unsqueeze(1))
    acts1 = acts1[:, 0, feature_idx]
    top_activations = acts1.argsort(descending=True)
    activation_values = acts1.sort(descending=True)
    # set positive_activations to num_activation random examples
    if random:
        positive_count = (activation_values.values > thresh).sum()
        positive_activations = top_activations[:positive_count]
        random_indeces = (t.randperm(positive_count)[:num_activations]).to(t.long)
        top_activations = positive_activations[random_indeces]
        activation_values = acts1[top_activations]
    else:
        top_activations = top_activations[:num_activations]
    top_games = top_activations // seq_len
    top_moves = top_activations % seq_len
    print(f"Top Game: \n{top_games}")
    print(f"Top Move: \n{top_moves}")
    print(f"Activation values: \n{activation_values}")
    return top_games, top_moves, activation_values

top_games, top_moves, activation_values = get_max_acitvations_of_feature(layer, feature_idx, expansion_factor, 10, random=True, thresh = THRESH)
for i in range(len(top_games)):
    game = top_games[i].item()
    pos = top_moves[i].item()
    activation = activation_values[i].item()
    vis_args = VisualzeBoardArguments()
    vis_args.start_pos = pos
    vis_args.end_pos = pos+1
    vis_args.include_layer_norm = True
    vis_args.include_pre_resid = False
    vis_args.layers = 7
    print(f"Game: {game}, Pos: {pos}, Activation: {activation}")
    visualize_game(
        focus_games_string[game, :59],
        vis_args,
        model,
    )
    # break

Top Game: 
tensor([ 33,  40,  22, 122, 104,  40,  61,  17,  43,  87], device='cuda:0')
Top Move: 
tensor([27, 19, 36, 32,  8, 31, 18,  7, 29, 28], device='cuda:0')
Activation values: 
tensor([0.2515, 0.2404, 0.2133, 0.2597, 0.2153, 0.2169, 0.2402, 0.2010, 0.2301,
        0.3016], device='cuda:0')
Game: 33, Pos: 27, Activation: 0.2515352666378021
torch.Size([7, 3, 59, 8, 8])


Game: 40, Pos: 19, Activation: 0.24042001366615295
torch.Size([7, 3, 59, 8, 8])


Game: 22, Pos: 36, Activation: 0.2132743000984192
torch.Size([7, 3, 59, 8, 8])


Game: 122, Pos: 32, Activation: 0.259714812040329
torch.Size([7, 3, 59, 8, 8])


Game: 104, Pos: 8, Activation: 0.21533086895942688
torch.Size([7, 3, 59, 8, 8])


Game: 40, Pos: 31, Activation: 0.21690693497657776
torch.Size([7, 3, 59, 8, 8])


Game: 61, Pos: 18, Activation: 0.24024900794029236
torch.Size([7, 3, 59, 8, 8])


Game: 17, Pos: 7, Activation: 0.20096108317375183
torch.Size([7, 3, 59, 8, 8])


Game: 43, Pos: 29, Activation: 0.2301207184791565
torch.Size([7, 3, 59, 8, 8])


Game: 87, Pos: 28, Activation: 0.3015708327293396
torch.Size([7, 3, 59, 8, 8])
