In [None]:
import torch

# Path to your checkpoint file
checkpoint_file = 'epoch_48.pth'

# Load the checkpoint
checkpoint = torch.load(checkpoint_file, map_location='cpu')

# Check if 'state_dict' is in the checkpoint
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    # If 'state_dict' is not found, use the checkpoint directly
    state_dict = checkpoint

# Print out parameter names and shapes
for name, param in state_dict.items():
    print(f"Parameter name: {name}, Shape: {param.shape}")

# Optionally, print out keys in the checkpoint
print("Keys in checkpoint:", checkpoint.keys())

  from .autonotebook import tqdm as notebook_tqdm


Parameter name: arch.backbone.front_end.0.weight, Shape: torch.Size([16, 1, 5, 5])
Parameter name: arch.backbone.front_end.0.bias, Shape: torch.Size([16])
Parameter name: arch.backbone.front_end.1.weight, Shape: torch.Size([16])
Parameter name: arch.backbone.front_end.1.bias, Shape: torch.Size([16])
Parameter name: arch.backbone.front_end.1.running_mean, Shape: torch.Size([16])
Parameter name: arch.backbone.front_end.1.running_var, Shape: torch.Size([16])
Parameter name: arch.backbone.front_end.1.num_batches_tracked, Shape: torch.Size([])
Parameter name: arch.backbone.front_end.2.weight, Shape: torch.Size([1])
Parameter name: arch.backbone.front_end.3.weight, Shape: torch.Size([32, 16, 5, 5])
Parameter name: arch.backbone.front_end.3.codebook, Shape: torch.Size([256, 16])
Parameter name: arch.backbone.front_end.3.encoded_vector, Shape: torch.Size([800])
Parameter name: arch.backbone.front_end.4.weight, Shape: torch.Size([32])
Parameter name: arch.backbone.front_end.4.bias, Shape: torch

In [None]:
extracted_weights = []

for block_num in range(8):  # Iterate through backend blocks 0 to 5
    # Memory Block
    memory_codebook_key = f'arch.backbone.backbone.{block_num}.memory.0.codebook'
    memory_encoded_vector_key = f'arch.backbone.backbone.{block_num}.memory.0.encoded_vector'
    memory_dict = {}
    if memory_codebook_key in state_dict:
      print(memory_codebook_key)
      memory_dict['codebook'] = state_dict[memory_codebook_key]
    if memory_encoded_vector_key in state_dict:
      print(memory_encoded_vector_key)
      memory_dict['encoded_vector'] = state_dict[memory_encoded_vector_key]
    extracted_weights.append(memory_dict)

    # FC Trans Blocks (two of them)
    for fc_trans_num in [0, 4]:  # The relevant fc_trans blocks have indices 0 and 4
        fc_trans_codebook_key = f'arch.backbone.backbone.{block_num}.fc_trans.{fc_trans_num}.codebook'
        fc_trans_encoded_vector_key = f'arch.backbone.backbone.{block_num}.fc_trans.{fc_trans_num}.encoded_vector'
        fc_trans_dict = {}
        if fc_trans_codebook_key in state_dict:
          print(fc_trans_codebook_key)
          fc_trans_dict['codebook'] = state_dict[fc_trans_codebook_key]
        if fc_trans_encoded_vector_key in state_dict:
          print(fc_trans_encoded_vector_key)
          fc_trans_dict['encoded_vector'] = state_dict[fc_trans_encoded_vector_key]
        extracted_weights.append(fc_trans_dict)

# Now extracted_weights contains the list of dictionaries as you described.

# Print the extracted weights
#print(extracted_weights)

arch.backbone.backbone.0.memory.0.codebook
arch.backbone.backbone.0.memory.0.encoded_vector
arch.backbone.backbone.0.fc_trans.0.codebook
arch.backbone.backbone.0.fc_trans.0.encoded_vector
arch.backbone.backbone.0.fc_trans.4.codebook
arch.backbone.backbone.0.fc_trans.4.encoded_vector
arch.backbone.backbone.1.memory.0.codebook
arch.backbone.backbone.1.memory.0.encoded_vector
arch.backbone.backbone.1.fc_trans.0.codebook
arch.backbone.backbone.1.fc_trans.0.encoded_vector
arch.backbone.backbone.1.fc_trans.4.codebook
arch.backbone.backbone.1.fc_trans.4.encoded_vector
arch.backbone.backbone.2.memory.0.codebook
arch.backbone.backbone.2.memory.0.encoded_vector
arch.backbone.backbone.2.fc_trans.0.codebook
arch.backbone.backbone.2.fc_trans.0.encoded_vector
arch.backbone.backbone.2.fc_trans.4.codebook
arch.backbone.backbone.2.fc_trans.4.encoded_vector
arch.backbone.backbone.3.memory.0.codebook
arch.backbone.backbone.3.memory.0.encoded_vector
arch.backbone.backbone.3.fc_trans.0.codebook
arch.backbo

In [4]:
torch.save(extracted_weights, 'extracted_weights.pth')

print("Extracted weights saved to extracted_weights.pth")

Extracted weights saved to extracted_weights.pth


In [5]:
codebook_lookup_table = []
for i, block_data in enumerate(extracted_weights):
    if 'codebook' in block_data:
        codebook = block_data['codebook']
        # Convert to NumPy array for easier handling
        codebook_numpy = codebook.numpy() #.tolist() #.astype(np.float32)
        codebook_lookup_table.append(codebook_numpy)

# codebook_lookup_table is now a list where each element is a NumPy array
# representing the codebook for that block.

# Double-check the structure:
print(f"Number of blocks with codebooks: {len(codebook_lookup_table)}")
print(f"Shape of the first codebook: {codebook_lookup_table[0].shape}")  # Should be (256, 16)

Number of blocks with codebooks: 24
Shape of the first codebook: (256, 16)


In [None]:
import numpy as np
def pack_binary_to_int16(binary_vector):
    """Packs a 16-element binary vector into a signed int16."""
    packed_value = 0
    for i, bit in enumerate(binary_vector):
        if bit == 1:  # If the bit is 1
            packed_value |= (1 << 15-i)  # Set the i-th bit
    # No need to explicitly handle negative sign as int16 handles signed representation

    return packed_value

packed_codebook_lookup_table = []
for codebook_numpy in codebook_lookup_table:
    packed_codebook = np.array([pack_binary_to_int16(row) for row in codebook_numpy], dtype=np.int16)
    packed_codebook_lookup_table.append(packed_codebook)

# packed_codebook_lookup_table is now a list where each element is a NumPy array
# of int16 values.  Each int16 represents a packed binary vector.

print(f"Shape of the first packed codebook: {packed_codebook_lookup_table[0].shape}")  # Should be (256,)
print(f"Data type of packed codebook: {packed_codebook_lookup_table[0].dtype}")  # Should be int16

Shape of the first packed codebook: (256,)
Data type of packed codebook: int16


For the old behavior, usually:
    np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
  packed_codebook = np.array([pack_binary_to_int16(row) for row in codebook_numpy], dtype=np.int16)
For the old behavior, usually:
    np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
  packed_codebook = np.array([pack_binary_to_int16(row) for row in codebook_numpy], dtype=np.int16)
For the old behavior, usually:
    np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
  packed_codebook = np.array([pack_binary_to_int16(row) for row in codebook_numpy], dtype=np.int16)
For the old behavior, usually:
    np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
  packed_codebook = np.array([pack_binary_to_int16(row) for row in codebook_numpy], dtype=np.int16)
For the old behavior, usually:
    np.array(value).astype(dtype)`
will give the desired result (the cast overflows).
  packed_codebo

In [None]:
print(codebook_lookup_table[0][69])
print(type(packed_codebook_lookup_table[0][69]))

[-1. -1.  1. -1. -1. -1. -1.  1.  1.  1. -1. -1.  1. -1.  1. -1.]
<class 'numpy.int16'>


In [35]:
np_array = np.array(packed_codebook_lookup_table)
np.save("precomputed_lut.npy", np_array)

In [30]:
precomputed_lut = [i.tolist() for i in packed_codebook_lookup_table] 

In [32]:
import numpy as np
precomputed_lut = np.load("precomputed_lut.npy")
value =precomputed_lut[0][3]
print(value)
binary_representation = np.binary_repr(value, width=16)
print(binary_representation)

4365
0001000100001101


In [25]:
def calculate_result(binary_str1, binary_str2):
    """
    Calculate the result of 32 - (popcount * 2) for two 32-bit binary strings.

    Args:
        binary_str1 (str): A 32-bit binary string.
        binary_str2 (str): A 32-bit binary string.

    Returns:
        int: The result of the calculation.
    """
    # Ensure both inputs are valid 32-bit binary strings
    if len(binary_str1) != 32 or len(binary_str2) != 32:
        raise ValueError("Both inputs must be 32-bit binary strings.")
    if not all(c in '01' for c in binary_str1 + binary_str2):
        raise ValueError("Inputs must only contain '0' and '1'.")

    # Perform bitwise AND operation
    result = int(binary_str1, 2) ^ int(binary_str2, 2)

    # Calculate the popcount (number of 1s in the AND result)
    popcount = bin(result).count('1')

    # Perform the final calculation
    result = 32 - (popcount * 2)

    return result

# Example usage
binary1 = "00000000001000000000000000100000"
binary2 = "10101010101010101010101010101010"
result = calculate_result(binary1, binary2)
print(f"Result: {result}")


Result: 4


In [38]:
def hex_to_binary(hex_value):
    """
    Convert a hexadecimal value to a 32-bit binary string.

    Args:
        hex_value (str): A hexadecimal value as a string (e.g., "0x55555555").

    Returns:
        str: The binary representation of the hexadecimal value.
    """
    # Remove the "0x" prefix if present
    if hex_value.startswith("0x"):
        hex_value = hex_value[2:]

    # Convert hex to an integer, then format as binary with leading zeros for 32 bits
    binary_representation = bin(int(hex_value, 16))[2:].zfill(32)

    return binary_representation

# Example usage
hex_value = "02"
binary_result = hex_to_binary(hex_value)
print(f"{binary_result}")
#print (type(binary_result))



00000000000000000000000000000010


In [67]:
print(get_w("00010100",'00010203',precomputed_lut,0))

['00000000001000000100000000000000', '01000000000000000000000000100000', '00000000001000000100000000000000', '00000000101000000001000100001101']


In [None]:
def get_w (w1,w2,lut,indx):
    w=[]
    for i in range(2):
        weight1=np.binary_repr( lut[indx][int(w1[0+i*4:2+i*4], 16)],16)
        #print(0+i*4,2+i*4,2+i*4,4+i*4)
        weight2=np.binary_repr( lut[indx][int(w1[2+i*4:4+i*4], 16)],16)
        w.append(weight1+weight2)
        
    for i in range(2):
        weight1=np.binary_repr( lut[indx][int(w2[0+i*4:2+i*4], 16)],16)
        weight2=np.binary_repr( lut[indx][int(w2[2+i*4:4+i*4], 16)],16)
        w.append(weight1+weight2)
        

   

    return w
    
def calc_acc (input,w):
    res=0
    for i,x in enumerate(input):
        x= hex_to_binary(x)
        res+= calculate_result(x,w[i])
    print(res)

w=["00000000001000000000000000100000","00000000001000000000000000100000","00000000001000000000000000100000","00000000001000000000000000100000"]
x=["AAAAAAAA",'0x00000000','0x55555555','0xFFFFFFFF']
calc_acc(x,w)
w=["00000000001000000100000000000000","00000000101000000001000100001101","00000000001000000000000000100000","00000000001000000000000000100000"]
x=['0xFFFFFFFF','0xFFFFFFFF','0xFFFFFFFF','0x00000000']
calc_acc(x,w)

w=["00000000001000000100000000000000","00000000101000000001000100001101","00000000001000000000000000100000","00000000001000000000000000100000"]
x=['0x00000002','00040006','0xAAAAAAAA','0xAAAAAAAA']
calc_acc(x,w)
w=get_w("00010203",'00000000',precomputed_lut,0)
calc_acc(x,w)





0
-46
50
50


In [70]:
x=['0xFFFFFFFF','0xFFFFFFFF','0xFFFFFFFF','0x00000000']
w=get_w("00010203",'93456899',precomputed_lut,1)
calc_acc(x,w)
x=['0xF1534567','0x29105755','0x21182119','0x11472765']
w=get_w("32542919",'05892134',precomputed_lut,6)
calc_acc(x,w)



-14
2
