In [126]:
# Residual blocks whose `.conv0.weight` / `.conv1.weight` pair is treated as a gated
# (product-of-two-convolutions) layer: res_block0 and res_block1 of each of the three
# conv sequences, in sequence-major order.
GATED_CONV_LAYERS = [
    f"conv_seqs.{seq_idx}.res_block{block_idx}"
    for seq_idx in range(3)
    for block_idx in range(2)
]

import torch
def layer_to_b_tensor(W=None, V=None, layer_label=None, state_dict=None, device='cpu'):
    """Build the symmetrised per-output-channel bilinear ("B") tensor of a gated conv pair.

    For output channel ``o`` and a flattened input patch ``p`` (length ``c_in * kh * kw``,
    flattened in ``W[o].reshape(-1)`` order), the gated activation
    ``(W[o] . p) * (V[o] . p)`` equals the quadratic form ``p @ B[o] @ p`` with
    ``B[o] = outer(W[o].flatten(), V[o].flatten())``. The returned tensor is the
    symmetrised ``0.5 * (B[o] + B[o].T)``, which has the same quadratic form.

    Args:
        W, V: conv kernels of identical shape ``(c_out, c_in, kh, kw)``. If both are given
            they take precedence over ``layer_label`` / ``state_dict``.
        layer_label: a name from ``GATED_CONV_LAYERS``. Its kernels are read from
            ``state_dict[layer_label + ".conv0.weight"]`` (W) and ``".conv1.weight"`` (V).
        state_dict: mapping from parameter names to tensors.
        device: device on which the float64 result is built.

    Returns:
        float64 tensor of shape ``(c_out, c_in*kh*kw, c_in*kh*kw)``, symmetric in its last
        two axes.

    Raises:
        ValueError: if neither (W, V) nor (layer_label, state_dict) is given, if W and V
            have different shapes, or if ``layer_label`` is not in ``GATED_CONV_LAYERS``.
    """
    if W is not None and V is not None:
        W = W.to(dtype=torch.float64, device=device)
        V = V.to(dtype=torch.float64, device=device)
        if W.shape != V.shape:
            raise ValueError(
                f"W and V must have the same shape, got {tuple(W.shape)} and {tuple(V.shape)}"
            )

        c_out = W.shape[0]
        # Row index of B[o] runs over W[o].flatten(), column index over V[o].flatten();
        # this matches the (input-channel, kernel-row, kernel-col) block layout.
        w_flat = W.reshape(c_out, -1)
        v_flat = V.reshape(c_out, -1)
        outer = torch.einsum("oa,ob->oab", w_flat, v_flat)
        # Symmetrise each output-channel slice; in-place scaling avoids a third full copy.
        return (outer + outer.transpose(1, 2)).mul_(0.5)

    if layer_label is not None and state_dict is not None:
        if layer_label not in GATED_CONV_LAYERS:
            # Previously this fell through and returned None, which only failed later.
            raise ValueError(
                f"{layer_label!r} is not a gated conv layer; expected one of {GATED_CONV_LAYERS}"
            )
        W = state_dict[layer_label + ".conv0.weight"]
        V = state_dict[layer_label + ".conv1.weight"]
        return layer_to_b_tensor(W=W, V=V, device=device)

    raise ValueError("Either provide W and V or layer_label and state_dict")

def b_tensor_decomp(b_tensor, out_vector, topk=None, specified_idxs=None, device='cpu'):
    """Eigendecompose a B tensor after contracting its output-channel axis with ``out_vector``.

    The ``(c_out, D, D)`` tensor is reduced to the ``(D, D)`` matrix
    ``sum_o out_vector[o] * b_tensor[o]`` and passed to ``torch.linalg.eigh``. That function
    treats its input as symmetric (Hermitian), so ``b_tensor`` is expected to be symmetric
    in its last two axes (as produced by ``layer_to_b_tensor``).

    Args:
        b_tensor: tensor of shape ``(c_out, D, D)``.
        out_vector: tensor of shape ``(c_out,)``.
        topk: number of eigenpairs to keep, largest ``|eigenvalue|`` first; ``None`` keeps all.
            Ignored when ``specified_idxs`` is given.
        specified_idxs: indices into eigh's native ascending-eigenvalue order; when given,
            exactly those eigenpairs are returned, in that order.
        device: device for the float64 computation.

    Returns:
        ``(eigvals, eigvecs)`` with eigenvectors as the columns of ``eigvecs``.
    """
    contracted = torch.einsum(
        "oij, o-> ij",
        b_tensor.to(dtype=torch.float64, device=device),
        out_vector.to(dtype=torch.float64, device=device),
    )
    eigvals, eigvecs = torch.linalg.eigh(contracted)

    if specified_idxs is not None:
        # Selection uses eigh's ascending order, not the |eigenvalue| ranking below.
        return eigvals[specified_idxs], eigvecs[:, specified_idxs]

    keep = eigvecs.shape[1] if topk is None else topk
    ranking = torch.abs(eigvals).argsort(descending=True)[:keep]
    return eigvals[ranking], eigvecs[:, ranking]

def proj_activ_onto_out_vector(x, out_vector, device='cpu'):
    """Project a ``(channels, H, W)`` activation onto a direction in channel space.

    Returns the float64 ``(H, W)`` map ``sum_o out_vector[o] * x[o]``, computed on ``device``.
    """
    f64_on_device = dict(dtype=torch.float64, device=device)
    activ = x.to(**f64_on_device)
    direction = out_vector.to(**f64_on_device)
    return torch.einsum("ohw,o->hw", activ, direction)

def ablate_eigs(layer_label, state_dict, x, out_vector, idxs, invert=False, device='cpu'):
    """Split a gated conv layer's projected output by eigenvectors of its contracted B matrix.

    With ``M = sum_o out_vector[o] * B[o]`` (``B`` from ``layer_to_b_tensor``) and eigenpairs
    ``(eigval_e, eigvec_e)`` of ``M``, the projected gated activation
    ``sum_o out_vector[o] * (W * x)_o * (V * x)_o`` equals ``sum_e eigval_e * (eigvec_e * x)**2``,
    where ``eigvec_e * x`` is ``eigvec_e`` applied to ``x`` as a single-output conv kernel.
    This function evaluates that sum over the selected eigenvectors only.

    Args:
        layer_label: layer name (see ``GATED_CONV_LAYERS``); its kernels are
            ``state_dict[layer_label + ".conv0.weight"]`` and ``".conv1.weight"``.
        state_dict: mapping from parameter names to weight tensors.
        x: input activation of shape ``(in_channels, H, W)`` (no batch dimension).
        out_vector: direction in output-channel space, shape ``(out_channels,)``.
        idxs: eigenvector indices in ``torch.linalg.eigh``'s ascending-eigenvalue order
            (scalar index, list or index tensor).
        invert: if True, return only the selected eigenvectors' contribution; if False,
            return the full projected output minus that contribution (the ablated output).
        device: device for the float64 computation.

    Returns:
        float64 tensor of shape ``(H, W)`` for odd kernel sizes ('same' padding).

    Raises:
        ValueError: if no B tensor could be built for ``layer_label``.
    """
    conv2d = torch.nn.functional.conv2d
    x = x.to(dtype=torch.float64, device=device)
    out_vector = out_vector.to(dtype=torch.float64, device=device)

    B = layer_to_b_tensor(layer_label=layer_label, state_dict=state_dict, device=device)
    if B is None:
        # Guard: an unrecognised label must not surface as an opaque error inside eigh.
        raise ValueError(f"{layer_label!r} did not yield a B tensor; is it in GATED_CONV_LAYERS?")
    eigvals, eigvecs = b_tensor_decomp(b_tensor=B, out_vector=out_vector, topk=None,
                                       specified_idxs=idxs, device=device)
    # A scalar index yields a 0-d eigval and a 1-D eigvec; promote to (n,) / (D, n) so one
    # code path handles any number of selected eigenvectors.
    eigvals = eigvals.reshape(-1)
    eigvecs = eigvecs.reshape(eigvecs.shape[0], -1)

    # Cast the kernels to the computation dtype/device: state_dict weights may be float32
    # (or live elsewhere), and conv2d raises on a dtype/device mismatch with x.
    W = state_dict[layer_label + ".conv0.weight"].to(dtype=torch.float64, device=device)
    V = state_dict[layer_label + ".conv1.weight"].to(dtype=torch.float64, device=device)
    _, in_chan, kh, kw = W.shape
    w_padding = ((kh - 1) // 2, (kw - 1) // 2)

    batched_x = x.unsqueeze(0)  # conv2d takes (N, C, H, W)
    n_eigs = eigvals.shape[0]
    if n_eigs == 0:
        selected = torch.zeros_like(x[0])
    else:
        # Each eigenvector column is flattened in (in_chan, kh, kw) order, matching
        # W[o].reshape(-1), so it reshapes directly into one conv output channel.
        eig_kernels = eigvecs.T.reshape(n_eigs, in_chan, kh, kw)
        responses = conv2d(batched_x, eig_kernels, padding=w_padding)[0]  # (n_eigs, H, W)
        selected = torch.einsum("e,ehw->hw", eigvals, responses ** 2)

    if invert:
        return selected

    v_padding = ((V.shape[2] - 1) // 2, (V.shape[3] - 1) // 2)
    gated = conv2d(batched_x, W, padding=w_padding)[0] * conv2d(batched_x, V, padding=v_padding)[0]
    return proj_activ_onto_out_vector(gated, out_vector, device=device) - selected

In [129]:
# Sanity check: selecting every eigenvector with invert=True must reproduce channel `i`
# of the gated output conv_a(x) * conv_b(x) up to float64 round-off.

# Set device
device = 'cpu'

# Seed every random draw below (weights, input, probed channel) so Restart & Run All
# reproduces the same check and the same printed error.
torch.manual_seed(0)

# Random float64 kernels of shape (out_channels=32, in_channels=32, 7, 7) and a single
# input activation of shape (channels=32, H=8, W=8) -- there is no batch dimension.
W = torch.rand(32, 32, 7, 7, dtype=torch.float64, device=device)
V = torch.rand(32, 32, 7, 7, dtype=torch.float64, device=device)
x = torch.rand(32, 8, 8, dtype=torch.float64, device=device)

# ablate_eigs reads the two kernels from a state_dict under a known gated-layer label.
state_dict = {}
state_dict["conv_seqs.2.res_block0.conv0.weight"] = W
state_dict["conv_seqs.2.res_block0.conv1.weight"] = V

# Reference: the gated output computed directly with two 'same'-padded convolutions.
conv_a = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=7, bias=False, padding=3, stride=1)
conv_b = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=7, bias=False, padding=3, stride=1)

# Move convs to device and float64
conv_a = conv_a.to(dtype=torch.float64, device=device)
conv_b = conv_b.to(dtype=torch.float64, device=device)

# Set weights
conv_a.weight = torch.nn.Parameter(W)
conv_b.weight = torch.nn.Parameter(V)

# Select every eigenvector of the contracted B matrix (in_channels * 7 * 7 of them).
indices = torch.arange(0, 32*7*7, device=device)

# Project onto one randomly chosen output channel i (one-hot out_vector).
i = torch.randint(0, 32, (1,)).item()
out_vector = torch.zeros(32, dtype=torch.float64, device=device)
out_vector[i] = 1.0

# Full eigen-reconstruction of channel i (despite the variable name, this is channel i,
# not channel 0).
channel_contrib_channel0 = ablate_eigs(
    layer_label="conv_seqs.2.res_block0",
    state_dict=state_dict,
    x=x,
    out_vector=out_vector,
    idxs=indices,
    invert=True
)

# Compute using convolutions
output_using_conv_weights = conv_a(x) * conv_b(x)
output_using_conv_weights = output_using_conv_weights[i]  # channel i of the (unbatched) output

print(torch.max(torch.abs(output_using_conv_weights - channel_contrib_channel0)).item())
assert torch.allclose(output_using_conv_weights, channel_contrib_channel0, atol=1e-6, rtol=1e-5)

8.440110832452774e-10
