In [None]:
import torch
from itertools import product

def generate_e8_roots():
    """Build the 240 roots of the E8 root system in the standard 8-D coordinates.

    Two families are generated:

    * Type 1 -- ``±e_i ± e_j`` for ``i < j``: two non-zero entries of
      magnitude 1 (4 * C(8, 2) = 112 vectors).
    * Type 2 -- every entry is ``±1/2`` with an even number of ``+1/2``
      entries (2**7 = 128 vectors).

    The count of each family is printed to stdout.

    Returns:
        A float tensor of shape ``(240, 8)``. Rows are NOT normalised (each
        has squared norm 2) and come back in the lexicographically sorted
        order produced by ``torch.unique(..., dim=0)``.

    Raises:
        AssertionError: if the construction does not yield exactly 240
            distinct rows.
    """
    roots = []

    # Type 1: ±e_i ± e_j (i < j)
    for i in range(8):
        for j in range(i + 1, 8):
            for sign_i, sign_j in product((1, -1), repeat=2):
                v = torch.zeros(8)
                v[i] = sign_i
                v[j] = sign_j
                roots.append(v)

    num_type1 = len(roots)
    print(f"Type 1: {num_type1} roots")

    # Type 2: (±1/2, ..., ±1/2) with an even number of +1/2 entries.
    for signs in product([-0.5, 0.5], repeat=8):
        if sum(s > 0 for s in signs) % 2 == 0:
            roots.append(torch.tensor(signs))

    # Report the size of this family only, not the running total.
    print(f"Type 2: {len(roots) - num_type1} roots")

    roots = torch.stack(roots)
    # The two families are disjoint and free of repeats, so this only sorts
    # the rows; it is kept so the returned order stays the same as before.
    roots = torch.unique(roots, dim=0)
    assert roots.shape[0] == 240, f"Expected 240 roots, got {roots.shape[0]}"

    return roots  # (240, 8)


# Build the root table once at notebook scope; later cells read `e8_roots`.
e8_roots = generate_e8_roots()
print(e8_roots.shape)
print()

Type 1: 112 roots
(-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5)
(-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, 0.5, 0.5)
(-0.5, -0.5, -0.5, -0.5, -0.5, 0.5, -0.5, 0.5)
(-0.5, -0.5, -0.5, -0.5, -0.5, 0.5, 0.5, -0.5)
(-0.5, -0.5, -0.5, -0.5, 0.5, -0.5, -0.5, 0.5)
(-0.5, -0.5, -0.5, -0.5, 0.5, -0.5, 0.5, -0.5)
(-0.5, -0.5, -0.5, -0.5, 0.5, 0.5, -0.5, -0.5)
(-0.5, -0.5, -0.5, -0.5, 0.5, 0.5, 0.5, 0.5)
(-0.5, -0.5, -0.5, 0.5, -0.5, -0.5, -0.5, 0.5)
(-0.5, -0.5, -0.5, 0.5, -0.5, -0.5, 0.5, -0.5)
(-0.5, -0.5, -0.5, 0.5, -0.5, 0.5, -0.5, -0.5)
(-0.5, -0.5, -0.5, 0.5, -0.5, 0.5, 0.5, 0.5)
(-0.5, -0.5, -0.5, 0.5, 0.5, -0.5, -0.5, -0.5)
(-0.5, -0.5, -0.5, 0.5, 0.5, -0.5, 0.5, 0.5)
(-0.5, -0.5, -0.5, 0.5, 0.5, 0.5, -0.5, 0.5)
(-0.5, -0.5, -0.5, 0.5, 0.5, 0.5, 0.5, -0.5)
(-0.5, -0.5, 0.5, -0.5, -0.5, -0.5, -0.5, 0.5)
(-0.5, -0.5, 0.5, -0.5, -0.5, -0.5, 0.5, -0.5)
(-0.5, -0.5, 0.5, -0.5, -0.5, 0.5, -0.5, -0.5)
(-0.5, -0.5, 0.5, -0.5, -0.5, 0.5, 0.5, 0.5)
(-0.5, -0.5, 0.5, -0.5, 0.5, -0.5, -0.5, -0.5)
(-0.5

In [None]:
# NOTE: the test is for +0.5 only, so the half-integer root whose entries are
# all -0.5 is not selected by this mask.
mask = (e8_roots == 0.5).any(dim=1)   # rows where at least one entry == 0.5
e8_roots[mask]

tensor([[-0.5000, -0.5000, -0.5000,  ..., -0.5000,  0.5000,  0.5000],
        [-0.5000, -0.5000, -0.5000,  ...,  0.5000, -0.5000,  0.5000],
        [-0.5000, -0.5000, -0.5000,  ...,  0.5000,  0.5000, -0.5000],
        ...,
        [ 0.5000,  0.5000,  0.5000,  ..., -0.5000,  0.5000, -0.5000],
        [ 0.5000,  0.5000,  0.5000,  ...,  0.5000, -0.5000, -0.5000],
        [ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000]])

In [None]:
# Element-wise mask of the entries whose magnitude is exactly 1/2.
e8_roots.abs() == 0.5

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [None]:
from tqdm import tqdm

def generate_d8_half_vectors(max_sq_norm: float = 10.0) -> torch.Tensor:
    """Enumerate all 8-D vectors of positive half-odd entries within a norm bound.

    Every coordinate is one of 1/2, 3/2, 5/2, ... and a vector is kept when
    its squared Euclidean norm is ``<= max_sq_norm`` (the bound is inclusive).
    Signs are not enumerated here; all entries are positive.

    Args:
        max_sq_norm: Inclusive upper bound on the squared norm. The default
            of 10.0 yields 227 vectors with entries in {1/2, 3/2, 5/2}.

    Returns:
        A tensor of shape ``(num_vectors, 8)`` in the default float dtype,
        with distinct rows in lexicographically ascending order.

    Raises:
        ValueError: if ``max_sq_norm`` is NaN or infinite (the enumeration
            would be unbounded).
        RuntimeError: if no vector satisfies the bound (``max_sq_norm < 2``).
    """
    import math

    if not math.isfinite(max_sq_norm):
        raise ValueError(f"max_sq_norm must be finite, got {max_sq_norm}")

    # The other seven coordinates are at least 1/2 each, so a coordinate c can
    # only occur when c**2 + 7 * (1/2)**2 <= max_sq_norm. Deriving the
    # candidates from the bound keeps larger values of max_sq_norm correct.
    candidates = []
    value = 0.5
    while value * value + 7 * 0.25 <= max_sq_norm:
        candidates.append(value)
        value += 1.0

    if not candidates:
        raise RuntimeError("No D8 half-integer vectors found; check constraints.")

    axis = torch.tensor(candidates)
    # All len(candidates)**8 combinations at once; rows are emitted in the
    # same order as itertools.product, i.e. already sorted and distinct.
    grid = torch.cartesian_prod(*([axis] * 8))
    within_bound = (grid * grid).sum(dim=1) <= max_sq_norm

    return grid[within_bound]

In [None]:
# Magnitude vectors for the default bound (max_sq_norm=10.0).
d8 = generate_d8_half_vectors()

6561it [00:00, 124852.90it/s]


In [None]:
# The default bound yields 227 vectors of dimension 8.
d8.shape

torch.Size([227, 8])

In [None]:
def append_12(d8_half_vectors:torch.Tensor)->torch.Tensor:
    """Return ``d8_half_vectors`` with a fixed table of 29 extra rows stacked below it.

    The table is written as doubled integers (each row has five 3s and three
    1s) and then halved, so every appended row has entries 1/2 or 3/2 and
    squared norm 12.
    """
    doubled_rows = [
        [3, 1, 1, 1, 3, 3, 3, 3],
        [1, 3, 1, 1, 3, 3, 3, 3],
        [1, 1, 3, 1, 3, 3, 3, 3],
        [1, 1, 1, 3, 3, 3, 3, 3],
        [3, 3, 3, 1, 3, 3, 1, 1],
        [3, 3, 3, 1, 3, 1, 3, 1],
        [3, 3, 3, 1, 1, 3, 3, 1],
        [3, 3, 3, 1, 3, 1, 1, 3],
        [3, 3, 3, 1, 1, 3, 1, 3],
        [3, 3, 3, 1, 1, 1, 3, 3],
        [3, 3, 1, 3, 3, 3, 1, 1],
        [3, 3, 1, 3, 3, 1, 3, 1],
        [3, 3, 1, 3, 1, 3, 3, 1],
        [3, 3, 1, 3, 3, 1, 1, 3],
        [3, 3, 1, 3, 1, 3, 1, 3],
        [3, 3, 1, 3, 1, 1, 3, 3],
        [3, 1, 3, 3, 3, 3, 1, 1],
        [3, 1, 3, 3, 3, 1, 3, 1],
        [3, 1, 3, 3, 1, 3, 3, 1],
        [3, 1, 3, 3, 3, 1, 1, 3],
        [3, 1, 3, 3, 1, 3, 1, 3],
        [1, 3, 3, 3, 1, 1, 3, 3],
        [1, 3, 3, 3, 3, 3, 1, 1],
        [1, 3, 3, 3, 3, 1, 3, 1],
        [1, 3, 3, 3, 1, 3, 3, 1],
        [1, 3, 3, 3, 3, 1, 1, 3],
        [1, 3, 3, 3, 1, 3, 1, 3],
        [1, 1, 3, 3, 1, 3, 3, 3],
        [3, 3, 1, 1, 3, 3, 3, 1],
    ]
    # Integer table halved -> float rows with entries 0.5 / 1.5.
    norm12_rows = torch.tensor(doubled_rows) / 2

    return torch.cat((d8_half_vectors, norm12_rows), dim=0)


def generate_d8_signs(d8_half_vectors:torch.Tensor)->torch.Tensor:
    """Expand each input row into 128 signed copies whose entries sum to an even number.

    For every row, all 2**7 sign patterns are applied to coordinates 1..7.
    The sign of coordinate 0 is then chosen so the row sum is divisible by 2:
    it is negated exactly when the sum would otherwise be odd.

    Args:
        d8_half_vectors: Tensor of shape ``(N, 8)``. ``N`` may be 0. The
            input is not modified.

    Returns:
        Tensor of shape ``(N * 128, 8)`` with the input's dtype. The 128 rows
        derived from input row ``i`` occupy ``[i * 128, (i + 1) * 128)`` and
        follow the order of ``itertools.product([-1, 1], repeat=7)``.

    Raises:
        ValueError: if some row cannot be given an even sum by negating
            coordinate 0 (this happens when the input is not made of
            half-odd values, e.g. an integer-valued row with an odd sum).
    """
    # The sign patterns do not depend on the row: build them once, as one tensor.
    sign_patterns = torch.tensor(
        list(product([-1, 1], repeat=7)),
        dtype=d8_half_vectors.dtype,
        device=d8_half_vectors.device,
    )  # (128, 7)
    num_vecs = d8_half_vectors.shape[0]
    num_patterns = sign_patterns.shape[0]

    # (N, 1, 7) * (1, 128, 7) -> (N, 128, 7): every pattern applied to every row.
    tails = d8_half_vectors[:, None, 1:] * sign_patterns[None, :, :]
    tail_sums = tails.sum(dim=-1, keepdim=True)
    heads = d8_half_vectors[:, None, :1].expand(num_vecs, num_patterns, 1)

    # Negate coordinate 0 wherever the signed row currently has an odd sum.
    needs_flip = (heads + tail_sums) % 2 != 0
    heads = torch.where(needs_flip, -heads, heads)

    if ((heads + tail_sums) % 2 != 0).any():
        raise ValueError("Invalid vector")

    signed = torch.cat([heads, tails], dim=-1)
    return signed.reshape(num_vecs * num_patterns, d8_half_vectors.shape[1])

def add(d8_signs:torch.Tensor)->torch.Tensor:
    """Stack the input on top of a copy of itself with 0.25 added to every entry.

    The result has twice as many rows as ``d8_signs``: the original rows
    first, then the shifted rows in the same order.
    """
    shifted = d8_signs + 0.25
    return torch.cat((d8_signs, shifted), dim=0)

In [None]:
# Pipeline: magnitude vectors plus the fixed extra rows -> all sign patterns
# -> original rows stacked with a copy shifted by +0.25.
d8_full = append_12(d8)
d8_signs = generate_d8_signs(d8_full)
codebook = add(d8_signs)

In [None]:
# Count the distinct rows of the codebook (sanity check for duplicates).
torch.unique(codebook, dim=0).shape

torch.Size([65536, 8])

In [None]:
# Number of sign patterns over 7 coordinates (2**7).
len(list(product([-1, 1], repeat=7)))

128

In [None]:
# Scratch arithmetic: 2**7 sign patterns times 2**8 = 256 magnitude vectors.
2**(7+8)

32768

In [None]:
# Same expression as the previous cell, re-evaluated (scratch).
2 ** (7 + 8)

32768

In [None]:
import torch

# NOTE(review): torch.load unpickles the file, so only load
# "codebook_magnitude.pt" from a trusted source (or pass weights_only=True
# if the installed torch version supports it). The path is relative to the
# kernel's working directory.
a = torch.load("codebook_magnitude.pt")

a

tensor([1.7876, 2.4469, 3.0478, 3.7594])