In [1]:
%env CUDA_VISIBLE_DEVICES=8

from random import randint

import torch
torch.set_float32_matmul_precision('high')
import torch.nn.functional as F
import triton
import triton.language as tl

env: CUDA_VISIBLE_DEVICES=8


## Triton Kernels

In [2]:
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def rtn_1x16s_fp4_kernel(
    x_ptr,
    amax_ptr,
    output_ptr,
    n_elements: tl.constexpr,
    scale_override: tl.constexpr,
    group_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):        
    # load x
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask)
    
    # amax
    scales_max = 447.99
    val_max = 6.0 / scale_override
    amax = tl.load(amax_ptr)
    s_dec = tl.where(
        amax == 0.0,
        1.0,
        amax / scales_max / val_max,
    )
    
    # group
    x_grouped = tl.reshape(x_flat, (BLOCK_SIZE // group_size, group_size))
    
    # scale
    s_dec_b = tl.max(tl.abs(x_grouped), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(
        s_dec_b_e4m3 == 0,
        1.0,
        s_dec_b_e4m3,
    )
    s_enc_b_inv = s_dec_b_e4m3 * s_dec
    x_scaled = x_grouped / s_enc_b_inv
    
    # quantize
    x_scaled_abs = tl.abs(x_scaled)
    x_scaled_sign = tl.where(
        x_scaled > 0,
        1,
        -1,
    )
    x_fp4_abs = tl.where(
        x_scaled_abs >= 5,
        6,
        tl.where(
            x_scaled_abs >= 3.5,
            4,
            tl.where(
                x_scaled_abs >= 2.5,
                3,
                tl.where(
                    x_scaled_abs >= 1.75,
                    2,
                    tl.where(
                        x_scaled_abs >= 1.25,
                        1.5,
                        tl.where(
                            x_scaled_abs >= 0.75,
                            1,
                            tl.where(
                                x_scaled_abs >= 0.25,
                                0.5,
                                0.0,
                            )
                        )
                    )
                )
            )
        )
    )
    x_fp4 = x_fp4_abs * x_scaled_sign

    # dequantize
    x_dequantized = x_fp4 * s_enc_b_inv
    
    # Reshape back to flat form for storage
    x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,))
    
    # store
    tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask)

@torch.compiler.disable()
def rtn_1x16s_fp4_kernel_wrapper(
    x: torch.Tensor,
    scale_override: float,
    group_size: int,
) -> torch.Tensor:
    x = x.contiguous()
    output = torch.empty_like(x)
    
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    
    rtn_1x16s_fp4_kernel[grid](
        x_ptr=x,
        amax_ptr=x.abs().max(),
        output_ptr=output,
        n_elements=n_elements,
        scale_override=scale_override,
        group_size=group_size,
    )
    return output

In [3]:
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def eden_1x16s_fp4_kernel(
    x_ptr,
    hadamard_matrix_ptr,
    current_amax_ptr,
    output_ptr,
    next_amax_ptr,
    n_elements: tl.constexpr,
    hadamard_dim: tl.constexpr,
    scale_override: tl.constexpr,
    group_size: tl.constexpr,
    seed: int,
    BLOCK_SIZE: tl.constexpr,
):    
    # load x
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask)
    
    # hadamard transform
    offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim)
    hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim) 
    x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_had = tl.dot(x, hadamard_matrix) # not TN!, A @ B!
    
    # write amax for next iter
    tl.atomic_max(next_amax_ptr, tl.max(tl.abs(x_had)), sem="relaxed")
    
    # group
    x_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size))

    # amax
    scales_max = 255.99 # Not 448 because eden needs space to rescale up a bit sometimes after the correction
    val_max = 6.0 / scale_override
    amax = tl.load(current_amax_ptr)
    s_dec = tl.where(
        amax == 0.0,
        1.0,
        amax / scales_max / val_max,
    )
    
    # scale
    s_dec_b = tl.max(tl.abs(x_grouped), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(
        s_dec_b_e4m3 == 0,
        1.0,
        s_dec_b_e4m3,
    )
    x_scaled = x_grouped / (s_dec_b_e4m3 * s_dec)
    
    # quantize
    x_scaled_abs = tl.abs(x_scaled)
    x_scaled_sign = tl.where(
        x_scaled > 0,
        1,
        -1,
    )
    x_fp4 = tl.where(
        x_scaled_abs >= 5,
        6,
        tl.where(
            x_scaled_abs >= 3.5,
            4,
            tl.where(
                x_scaled_abs >= 2.5,
                3,
                tl.where(
                    x_scaled_abs >= 1.75,
                    2,
                    tl.where(
                        x_scaled_abs >= 1.25,
                        1.5,
                        tl.where(
                            x_scaled_abs >= 0.75,
                            1,
                            tl.where(
                                x_scaled_abs >= 0.25,
                                0.5,
                                0,
                            )
                        )
                    )
                )
            )
        )
    ) * x_scaled_sign
    
    # Calculate EDEN scale
    x_scaled = tl.reshape(x_scaled, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    
    num = tl.sum(x_scaled * x_scaled, axis=-1, keep_dims=True)
    denom = tl.sum(x_scaled * x_fp4, axis=-1, keep_dims=True)
    
    correction = tl.where(
        denom == 0.0,
        1.0,
        num / denom,
    )
    
    # Apply EDEN scale
    scales = tl.reshape(s_dec_b_e4m3, (BLOCK_SIZE // hadamard_dim, hadamard_dim // group_size))
    corrected_scales = tl.reshape(scales * correction, (BLOCK_SIZE // group_size, 1))
    
    bitscales = tl.cast(corrected_scales.to(tl.float8e4nv), tl.uint8, bitcast=True)
    prevscale = tl.cast((bitscales - 1), tl.float8e4nv, bitcast=True).to(tl.float32)
    currscale = tl.cast((bitscales), tl.float8e4nv, bitcast=True).to(tl.float32)
    nextscale = tl.cast((bitscales + 1), tl.float8e4nv, bitcast=True).to(tl.float32)
    
    up = tl.where(
        currscale > corrected_scales,
        currscale,
        nextscale,
    )
    down = tl.where(
        currscale > corrected_scales,
        prevscale,
        currscale,
    )
    
    prob_up = (corrected_scales - down) / (up - down)
    
    scale_start_idx = pid * (BLOCK_SIZE // group_size)
    scale_offsets = scale_start_idx + tl.arange(0, BLOCK_SIZE // group_size)
    sampled_prob = tl.rand(seed, scale_offsets).reshape(BLOCK_SIZE // group_size, 1)
    
    scales = tl.where(
        sampled_prob < prob_up,
        up,
        down,
    )
    scales = tl.reshape(scales, (BLOCK_SIZE // group_size, 1))
    x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // group_size, group_size))
    
    # Reshape back to flat form for storage
    x_dequantized = x_fp4 * scales * s_dec
    x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,))
    
    # store
    tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask)

@torch.compiler.disable()
def eden_1x16s_fp4_kernel_wrapper(
    x: torch.Tensor,
    hadamard_matrix: torch.Tensor,
    scale_override: float,
    group_size: int,
    current_amax: torch.Tensor,
) -> [torch.Tensor, torch.Tensor]:
    hadamard_dim = hadamard_matrix.size(0)
    assert hadamard_matrix.size(1) == hadamard_dim
    assert x.numel() % hadamard_dim == 0
    assert hadamard_dim % group_size == 0
    
    x = x.contiguous()
    hadamard_matrix = hadamard_matrix.T.contiguous() # .T.contiguous() + tl.dot -> TN
    output = torch.empty_like(x)
    seed = randint(0, 1000000)
    
    next_amax = torch.zeros_like(current_amax)
    
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    
    eden_1x16s_fp4_kernel[grid](
        x_ptr=x,
        hadamard_matrix_ptr=hadamard_matrix,
        current_amax_ptr=current_amax,
        output_ptr=output,
        next_amax_ptr=next_amax,
        n_elements=n_elements,
        hadamard_dim=hadamard_dim,
        scale_override=scale_override,
        group_size=group_size,
        seed=seed,
    )
    return output, next_amax


In [4]:
from scipy.linalg import hadamard

def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return torch.tensor(
        hadamard(group_size) * group_size**-0.5,
        dtype=dtype,
        device=device,
        requires_grad=False,
    )
    
def rerotate_hadamard(hadamard_matrix):
    signs = torch.diag(
        torch.randint(
            0, 2, (hadamard_matrix.size(0),),
            device=hadamard_matrix.device,
            dtype=hadamard_matrix.dtype
        ) * 2 - 1
    )
    return hadamard_matrix @ signs # NOTE: rerotate along last dim, inner dim for TN GEMM

In [5]:
from tqdm.auto import trange, tqdm

M = 1024
N = 1024
K = 1024
HADAMARD_DIM = 32

A = torch.randn((M, K), device='cuda')
B = torch.randn((N, K), device='cuda')
ht = get_hadamard_matrix(32, A.dtype, A.device)

with torch.no_grad():
    for acc_steps in tqdm([1, 4, 16, 64, 256, 1024], desc="Iterating steps"):
        accumulator = torch.zeros_like(A @ B.T)
        for i in trange(acc_steps, leave=False):
            ht = rerotate_hadamard(ht)

            A_amax_buffer = (A.view(-1, ht.size(0)) @ ht.T).abs().max()
            Aq, A_amax_buffer = eden_1x16s_fp4_kernel_wrapper(
                A,
                ht,
                1.0,
                16,
                current_amax=A_amax_buffer,
            )
            
            B_amax_buffer = (A.view(-1, ht.size(0)) @ ht.T).abs().max()
            Bq, B_amax_buffer = eden_1x16s_fp4_kernel_wrapper(
                B,
                ht,
                1.0,
                16,
                current_amax=B_amax_buffer,
            )
            
            accumulator += Aq @ Bq.T
        accumulator /= acc_steps
        
        quad_err = (accumulator - A @ B.T).pow(2).mean() / (A @ B.T).pow(2).mean()
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()
        print(f"{acc_steps}: {eff_bitwidth:.2f} bits")
        
# NEED TO GROW BY ~1 bit per 4x samples
# v

Iterating steps:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

1: 2.84 bits


  0%|          | 0/4 [00:00<?, ?it/s]

4: 3.83 bits


  0%|          | 0/16 [00:00<?, ?it/s]

16: 4.84 bits


  0%|          | 0/64 [00:00<?, ?it/s]

64: 5.83 bits


  0%|          | 0/256 [00:00<?, ?it/s]

256: 6.82 bits


  0%|          | 0/1024 [00:00<?, ?it/s]

1024: 7.74 bits


## Quartet_II linear

In [6]:
class AmaxStorage:
    def __init__(self):
        self.e_ht_amax = None
        self.weght_tht_amax = None
        self.e_tht_amax = None
        self.input_tht_amax = None
        
    def __repr__(self) -> str:
        fields = [
            ("e_ht_amax", self.e_ht_amax), 
            ("weght_tht_amax", self.weght_tht_amax), 
            ("e_tht_amax", self.e_tht_amax), 
            ("input_tht_amax", self.input_tht_amax)
        ]
        field_strs = []
        for name, val in fields:
            if val is not None:
                try:
                    v = val.item()
                except Exception:
                    v = val
                field_strs.append(f"{name}: {v:.3e}")
            else:
                field_strs.append(f"{name}: None")
        return "<AmaxStorage " + ", ".join(field_strs) + ">"
        

class Quartet_II_fn(torch.autograd.Function):
    group_size = 16
    forward_scale_override = 1.0
    backward_scale_override = (17 / 16) * 0.93
    hadamard_matrix = get_hadamard_matrix(32, device="cuda", dtype=torch.float32)

    @torch.compile(dynamic=False)
    @staticmethod
    def forward(ctx, input, weight, amax_storage: AmaxStorage, delayed_amax: bool, disable_backward_quant: bool):
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]
        ctx.delayed_amax = delayed_amax
        ctx.amax_storage = amax_storage
        ctx.disable_backward_quant = disable_backward_quant
        
        input_fp4 = rtn_1x16s_fp4_kernel_wrapper(input, scale_override=Quartet_II_fn.forward_scale_override, group_size=Quartet_II_fn.group_size)
        weight_fp4 = rtn_1x16s_fp4_kernel_wrapper(weight, scale_override=Quartet_II_fn.forward_scale_override, group_size=Quartet_II_fn.group_size)

        ctx.save_for_backward(input_fp4, weight_fp4)
        return F.linear(input_fp4, weight_fp4)

    @torch.compile(dynamic=False)
    @staticmethod
    def backward(ctx, grad_output):
        # Load ctx and reshape
        input_fp4, weight_fp4 = ctx.saved_tensors
        
        input_fp4 = input_fp4.reshape(ctx.batch * ctx.seq, ctx.in_dim)
        grad_output = grad_output.reshape(ctx.batch * ctx.seq, ctx.out_dim)
        
        # Re-randomize the rotation
        Quartet_II_fn.hadamard_matrix = rerotate_hadamard(Quartet_II_fn.hadamard_matrix)
        
        # No backward quant if flag
        if ctx.disable_backward_quant:
            grad_input = F.linear(
                grad_output,
                weight_fp4.T,
                None,
            ).view(ctx.batch, ctx.seq, ctx.in_dim)
            
            grad_weight = F.linear(
                grad_output.T,
                input_fp4.T,
                None,
            )
            return grad_input, grad_weight, None, None, None
        
        # EW
        if ctx.amax_storage.e_ht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.e_ht_amax = (grad_output.reshape(-1, Quartet_II_fn.hadamard_matrix.size(0)) @ Quartet_II_fn.hadamard_matrix.T).abs().max()
        e_ht_fp4, ctx.amax_storage.e_ht_amax = eden_1x16s_fp4_kernel_wrapper(grad_output, Quartet_II_fn.hadamard_matrix, Quartet_II_fn.backward_scale_override, 16, ctx.amax_storage.e_ht_amax)
        
        if ctx.amax_storage.weght_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.weght_tht_amax = (weight_fp4.T.reshape(-1, Quartet_II_fn.hadamard_matrix.size(0)) @ Quartet_II_fn.hadamard_matrix.T).abs().max()
        weight_tht_fp4, ctx.amax_storage.weght_tht_amax = eden_1x16s_fp4_kernel_wrapper(weight_fp4.T, Quartet_II_fn.hadamard_matrix, Quartet_II_fn.backward_scale_override, 16, ctx.amax_storage.weght_tht_amax)
        
        grad_input = F.linear(
            e_ht_fp4,
            weight_tht_fp4,
            None,
        ).view(ctx.batch, ctx.seq, ctx.in_dim)

        # EtX
        if ctx.amax_storage.e_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.e_tht_amax = (grad_output.T.reshape(-1, Quartet_II_fn.hadamard_matrix.size(0)) @ Quartet_II_fn.hadamard_matrix.T).abs().max()
        e_tht_fp4, ctx.amax_storage.e_tht_amax = eden_1x16s_fp4_kernel_wrapper(grad_output.T, Quartet_II_fn.hadamard_matrix, Quartet_II_fn.backward_scale_override, Quartet_II_fn.group_size, ctx.amax_storage.e_tht_amax)
        
        if ctx.amax_storage.input_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.input_tht_amax = (input_fp4.T.reshape(-1, Quartet_II_fn.hadamard_matrix.size(0)) @ Quartet_II_fn.hadamard_matrix.T).abs().max()
        input_tht_fp4, ctx.amax_storage.input_tht_amax = eden_1x16s_fp4_kernel_wrapper(input_fp4.T, Quartet_II_fn.hadamard_matrix, Quartet_II_fn.backward_scale_override, Quartet_II_fn.group_size, ctx.amax_storage.input_tht_amax)
        
        grad_weight = F.linear(
            e_tht_fp4,
            input_tht_fp4,
            None,
        )
        
        return grad_input, grad_weight, None, None, None

In [7]:
class Quartet_II_linear(torch.nn.Linear):
    def __init__(self, *args, delayed_amax=False, disable_backward_quant=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.delayed_amax = delayed_amax
        self.disable_backward_quant = disable_backward_quant
        self.amax_storage = AmaxStorage()
        
    
    def forward(self, x, disable_backward_quant=False):
        return Quartet_II_fn.apply(x, self.weight, self.amax_storage, self.delayed_amax, self.disable_backward_quant)
    

In [8]:
BATCH = 4
SEQ = 16
HID = 256
DELAYED_AMAX = True

INPUT = torch.randn((BATCH, SEQ, HID), device='cuda')
TARGET = torch.randn((BATCH, SEQ, 1), device='cuda')

W1 = Quartet_II_linear(HID, HID, device='cuda', delayed_amax=DELAYED_AMAX)
W2 = Quartet_II_linear(HID, HID, device='cuda', delayed_amax=DELAYED_AMAX)
W3 = Quartet_II_linear(HID, HID, device='cuda', delayed_amax=DELAYED_AMAX)

with torch.no_grad():
    W1.weight /= (HID**0.5 * W1.weight.std())
    W2.weight /= (HID**0.5 * W2.weight.std())
    W3.weight /= (HID**0.5 * W3.weight.std())

head = torch.randn(HID, 1, device='cuda')

In [9]:
W1.weight.grad = None
W2.weight.grad = None
W3.weight.grad = None

W1.disable_backward_quant = True
W2.disable_backward_quant = True
W3.disable_backward_quant = True

hid = W1(INPUT)
hid = torch.nn.functional.relu(hid)
hid = W2(hid)
hid = torch.nn.functional.relu(hid)
hid = W3(hid)
loss = (hid @ head - TARGET).pow(2).sum()
loss.backward()

w1_ref_grad = W1.weight.grad.clone().detach()
w2_ref_grad = W2.weight.grad.clone().detach()
w3_ref_grad = W3.weight.grad.clone().detach()

In [10]:
W1.disable_backward_quant = False
W2.disable_backward_quant = False
W3.disable_backward_quant = False

hid = W1(INPUT)
hid = torch.nn.functional.relu(hid)
hid = W2(hid)
hid = torch.nn.functional.relu(hid)
hid = W3(hid)
loss = (hid @ head - TARGET).pow(2).sum()

for acc_steps in tqdm([1, 4, 16, 64, 256]):
    W1.weight.grad = None
    W2.weight.grad = None
    W3.weight.grad = None
    for _ in trange(acc_steps, leave=False):
        loss.backward(retain_graph=True)
    with torch.no_grad():
        W1.weight.grad /= acc_steps
        W2.weight.grad /= acc_steps
        W3.weight.grad /= acc_steps
    
        quad_err = (W1.weight.grad - w1_ref_grad).pow(2).mean() / w1_ref_grad.pow(2).mean()
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()
        cosine = (W1.weight.grad.flatten() @ w1_ref_grad.flatten()) / (w1_ref_grad.flatten() @ w1_ref_grad.flatten())
        print(f"{acc_steps} acc_steps:\n\tW1 grad err: {eff_bitwidth:.2f} bits, {cosine:.3f} cosine")
        
        quad_err = (W2.weight.grad - w2_ref_grad).pow(2).mean() / w2_ref_grad.pow(2).mean()
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()
        cosine = (W2.weight.grad.flatten() @ w2_ref_grad.flatten()) / (w2_ref_grad.flatten() @ w2_ref_grad.flatten())
        print(f"\tW2 grad err: {eff_bitwidth:.2f} bits, {cosine:.3f} cosine")
        
        quad_err = (W3.weight.grad - w3_ref_grad).pow(2).mean() / w3_ref_grad.pow(2).mean()
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()
        cosine = (W3.weight.grad.flatten() @ w3_ref_grad.flatten()) / (w3_ref_grad.flatten() @ w3_ref_grad.flatten())
        print(f"\tW3 grad err: {eff_bitwidth:.2f} bits, {cosine:.3f} cosine")


# NEED TO GROW BY ~1 bit per 4x samples
# v

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

1 acc_steps:
	W1 grad err: 1.97 bits, 0.996 cosine
	W2 grad err: 2.71 bits, 0.983 cosine
	W3 grad err: 4.47 bits, 0.996 cosine


If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))


  0%|          | 0/4 [00:00<?, ?it/s]

4 acc_steps:
	W1 grad err: 2.95 bits, 0.998 cosine
	W2 grad err: 3.74 bits, 1.000 cosine
	W3 grad err: 5.37 bits, 0.999 cosine


  0%|          | 0/16 [00:00<?, ?it/s]

16 acc_steps:
	W1 grad err: 4.04 bits, 0.998 cosine
	W2 grad err: 4.73 bits, 0.997 cosine
	W3 grad err: 6.17 bits, 0.993 cosine


  0%|          | 0/64 [00:00<?, ?it/s]

64 acc_steps:
	W1 grad err: 5.01 bits, 0.995 cosine
	W2 grad err: 5.80 bits, 0.996 cosine
	W3 grad err: 7.15 bits, 0.997 cosine


  0%|          | 0/256 [00:00<?, ?it/s]

256 acc_steps:
	W1 grad err: 5.93 bits, 0.995 cosine
	W2 grad err: 6.55 bits, 0.997 cosine
	W3 grad err: 8.08 bits, 1.000 cosine
