In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

import sys
sys.path.append("/home/msst/repo/Quantization")
import qlib

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Full-precision (FP) reference: tokenizer and model are read from the same local checkpoint.
path_to_model = "/media/msst/ssd_storage1/ml/llm/pretrained_models/Llama2-7B"
tokenizer = AutoTokenizer.from_pretrained(path_to_model)

# The qlib Llama implementation is used here rather than the stock AutoModelForCausalLM class.
# The weights are placed on the CPU.
model = qlib._modeling.modeling_llama.LlamaForCausalLM.from_pretrained(
    path_to_model, device_map="cpu", dtype="auto"
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 109.09it/s]


In [3]:
class QuantLinear(nn.Module):
    """Bias-free linear layer whose weight is stored as uint8 codes.

    The weight of shape ``weight_shape`` is split along its second dimension
    into groups of ``group_size`` consecutive values. Every group owns one
    trainable ``scale`` and one trainable ``offset``; the weight is rebuilt as
    ``code * scale - offset``.

    ``scale``, ``offset`` and ``compressed_weight`` are allocated with
    ``torch.empty`` and therefore hold arbitrary values until they are filled
    in by the caller.

    Args:
        weight_shape: ``(out_features, in_features)`` of the wrapped weight.
        group_size: number of consecutive input features sharing one
            scale/offset pair. Must be positive and divide ``in_features``.

    Raises:
        ValueError: if ``group_size`` is not positive or does not divide
            ``weight_shape[1]``.
    """

    def __init__(self, weight_shape, group_size):
        super().__init__()
        # Without this check the floor division below would silently build
        # groups of a different size than requested (e.g. 130 // 128 == 1).
        if group_size <= 0 or weight_shape[1] % group_size != 0:
            raise ValueError(
                f"group_size={group_size} must be positive and divide "
                f"in_features={weight_shape[1]}"
            )
        self.weight_shape = weight_shape
        self.group_size = group_size
        self.scale_size = [weight_shape[0], weight_shape[1] // group_size]
        self.scale = nn.Parameter(torch.empty(self.scale_size))
        self.offset = nn.Parameter(torch.empty(self.scale_size))
        # Buffer, not parameter: the integer codes are saved with the module
        # but never receive gradients.
        self.register_buffer(
            "compressed_weight",
            torch.empty(weight_shape, dtype=torch.uint8),
        )

    def reshape_weight_for_scaling(self, w):
        """Return ``w`` viewed as ``(out_features, n_groups, group_size)``."""
        return w.reshape(
            self.weight_shape[0], self.weight_shape[1] // self.group_size, -1
        )

    @torch.compile()
    def reconstruct_weight(self):
        """Rebuild the weight from the codes: ``code * scale - offset``."""
        w = self.reshape_weight_for_scaling(self.compressed_weight)
        # scale/offset are (out_features, n_groups); the trailing None
        # broadcasts them over the group dimension.
        w = w * self.scale[..., None] - self.offset[..., None]
        w = w.reshape(self.weight_shape)
        return w

    def forward(self, x):
        """Apply the dequantized weight to ``x`` (no bias)."""
        w = self.reconstruct_weight()
        # The rebuilt weight has the dtype of scale/offset; align it with the
        # input so that e.g. bf16 activations do not raise a dtype mismatch
        # when no autocast context is active.
        if w.dtype != x.dtype:
            w = w.to(x.dtype)
        return torch.nn.functional.linear(x, w)

In [4]:
# Second copy of the same checkpoint; its projection layers are swapped for quantized ones below.
qmodel = AutoModelForCausalLM.from_pretrained(path_to_model, device_map="cpu", dtype="auto")

def wrap_model(current_module, prefix='', group_size=128):
    """Recursively replace projection layers with ``QuantLinear`` modules, in place.

    A child is replaced when its attribute name contains ``"proj"`` and it is
    an ``nn.Linear``. Any other child is descended into, so a non-linear
    module that merely has "proj" in its name is left untouched, and calling
    this function again on an already wrapped model is a no-op.

    The replacement only receives the weight shape: the original weight values
    are not copied into it, and a bias of the original layer is not carried
    over.

    Args:
        current_module: module whose children are inspected.
        prefix: dotted name of ``current_module`` within the root module;
            only used to build the prefix passed to recursive calls.
        group_size: quantization group size given to every ``QuantLinear``.
    """
    for module_name, module in current_module.named_children():
        full_name = f"{prefix}.{module_name}" if prefix else module_name

        if "proj" in module_name and isinstance(module, nn.Linear):
            weight_shape = module.weight.data.shape
            setattr(current_module, module_name, QuantLinear(weight_shape, group_size))
        else:
            wrap_model(module, full_name, group_size)

# Swap qmodel's projection layers for QuantLinear modules in place (see wrap_model above).
wrap_model(qmodel)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 105.31it/s]


In [5]:
class MinMaxInitializer:
    """Group-wise min/max initialisation of affine quantization parameters.

    For every group the parameters are chosen so that, with
    ``q = (x + offset) / scale``, the group minimum maps to ``negative_clip``
    and the group maximum maps to ``positive_clip``.

    Args:
        min_scale: lower bound applied to the returned scale. A group whose
            values are all equal would otherwise get ``scale == 0`` and
            produce NaN/inf as soon as a caller divides by it.
    """

    def __init__(self, min_scale=1e-8):
        self.min_scale = min_scale

    @torch.no_grad()
    def __call__(self, x_grouped, negative_clip, positive_clip):
        """Compute ``(scale, offset)`` for every group of ``x_grouped``.

        Args:
            x_grouped: tensor whose last dimension enumerates the values of
                one group; the result is reshaped to its first two dimensions.
            negative_clip: integer level the group minimum should map to.
            positive_clip: integer level the group maximum should map to.
                Must be non-zero and different from ``negative_clip``.

        Returns:
            Tuple ``(scale, offset)`` of contiguous float32 tensors of shape
            ``(x_grouped.shape[0], x_grouped.shape[1])``.

        Raises:
            ValueError: if ``positive_clip`` is zero or equals ``negative_clip``
                (both appear as divisors below).
        """
        if positive_clip == 0 or positive_clip == negative_clip:
            raise ValueError(
                "positive_clip must be non-zero and differ from negative_clip, "
                f"got negative_clip={negative_clip}, positive_clip={positive_clip}"
            )

        x_min = x_grouped.min(dim=-1, keepdim=True).values.float()
        x_max = x_grouped.max(dim=-1, keepdim=True).values.float()

        offset = (x_max * negative_clip - x_min * positive_clip) / (positive_clip - negative_clip)
        scale = torch.abs((x_max + offset) / positive_clip)
        # Constant groups give x_max + offset == 0; keep the scale strictly positive.
        scale = scale.clamp_min(self.min_scale)

        scale = scale.reshape(x_grouped.shape[0], x_grouped.shape[1])
        offset = offset.reshape(x_grouped.shape[0], x_grouped.shape[1])

        return scale.contiguous(), offset.contiguous()

initializer = MinMaxInitializer()

@torch.no_grad()
def configure_single_layer(qlayer, layer, bits):
    """Fill ``qlayer`` with a min/max quantization of ``layer.weight``.

    The weight is grouped with ``qlayer.reshape_weight_for_scaling``, the
    module-level ``initializer`` provides one scale/offset per group for the
    integer range ``[0, 2**bits - 1]``, and the rounded, clamped codes are
    written to ``qlayer.compressed_weight``. ``qlayer.scale`` and
    ``qlayer.offset`` are overwritten as well.

    Args:
        qlayer: quantized layer to fill (modified in place).
        layer: reference layer providing ``weight``.
        bits: code width; must be in ``[1, 8]`` because codes are stored as uint8.

    Raises:
        ValueError: if ``bits`` is outside ``[1, 8]``.
    """
    # Codes above 255 would silently wrap around in the uint8 cast below.
    if not 1 <= bits <= 8:
        raise ValueError(f"bits must be in [1, 8] for uint8 storage, got {bits}")

    max_int_val = 2**bits - 1

    orig_weight = layer.weight.data
    orig_weight_reshaped = qlayer.reshape_weight_for_scaling(orig_weight)
    scale, offset = initializer(orig_weight_reshaped, negative_clip=0, positive_clip=max_int_val)

    # Inverse of the reconstruction ``code * scale - offset``.
    quant_weight = (orig_weight_reshaped + offset[..., None]) / scale[..., None]
    quant_weight = quant_weight.reshape_as(orig_weight)

    quant_weight = torch.clamp(torch.round(quant_weight), 0, max_int_val).to(torch.uint8)

    qlayer.compressed_weight.copy_(quant_weight)
    qlayer.scale.copy_(scale)
    qlayer.offset.copy_(offset)


@torch.no_grad()
def init_quant_model(qmodel, model, bits):
    """Min/max-initialise every ``QuantLinear`` of ``qmodel`` from ``model``.

    Each quantized layer is matched to the submodule of ``model`` with the
    same dotted name. After a layer is configured, the mean squared
    reconstruction error, normalised by the standard deviation of the
    reference weight, is printed next to the layer name.
    """
    quant_layers = (
        (name, module)
        for name, module in qmodel.named_modules()
        if isinstance(module, QuantLinear)
    )
    for qmodule_name, qmodule in quant_layers:
        orig_module = model.get_submodule(qmodule_name)
        configure_single_layer(qmodule, orig_module, bits)

        reference = orig_module.weight.data.cpu()
        restored = qmodule.reconstruct_weight().cpu()
        normalised_diff = (reference - restored) / (reference.std() + 1e-8)
        err = torch.mean(normalised_diff**2)
        print(err, qmodule_name)


# Calibration-free alternative (plain min/max initialisation); left disabled:
# init_quant_model(qmodel, model, bits=3)

In [6]:
from transformers.masking_utils import create_causal_mask


@torch.no_grad()
def prepare_hessian(activations, device="cuda"):
    """Accumulate the second-moment matrix ``sum_i X_i^T X_i`` of the activations.

    Every activation is flattened to ``(-1, hidden_size)``, cast to float32 and
    divided by ``hidden_size ** 2`` before its Gram matrix is added.

    NOTE(review): because the activations (not the Gram matrix) are divided by
    ``hidden_size ** 2``, the result is scaled by ``hidden_size ** -4``. This
    looks like a deliberate normalisation, but it makes losses built on it very
    small. Confirm it is intended.

    Args:
        activations: non-empty sequence of tensors sharing the same last dimension.
        device: device on which the matrix is accumulated (default ``"cuda"``).

    Returns:
        float32 tensor of shape ``(hidden_size, hidden_size)`` on ``device``.
    """
    hidden_size = activations[0].shape[-1]
    H = torch.zeros(hidden_size, hidden_size, device=device)
    for act in activations:
        # reshape (not view) so that non-contiguous activations are accepted.
        act = act.to(device).reshape(-1, act.shape[-1]).float() / hidden_size ** 2
        H += act.T @ act
    return H


def hessian_loss(layer_q, layer_fp, H):
    """Return ``trace(dW @ H @ dW.T)`` for ``dW = W_quant - W_fp``.

    Args:
        layer_q: quantized layer exposing ``reconstruct_weight()``.
        layer_fp: reference layer exposing ``weight``.
        H: square matrix whose size matches the second dimension of the weights.

    Returns:
        Scalar tensor, differentiable with respect to the parameters behind
        ``layer_q.reconstruct_weight()`` only.
    """
    # The reference weight is a constant of the objective. Detaching it keeps
    # backward() from allocating and accumulating a full-size .grad on it.
    delta_w = layer_q.reconstruct_weight() - layer_fp.weight.detach()
    # trace(A @ H @ A.T) == sum((A @ H) * A); this form avoids building the
    # (out_features x out_features) product just to read its diagonal.
    return ((delta_w @ H) * delta_w).sum()


def optimize_quant_params(
        layer_q,
        layer_fp,
        bits,
        H,
        lr=1e-3,
        n_steps=100
    ):
    """Tune ``layer_q.scale`` and ``layer_q.offset`` with Adam on ``hessian_loss``.

    The integer codes of ``layer_q`` are left untouched; only the two
    group-wise parameters are updated. The loss before the first update and
    the loss after the last update are printed as ``"<init> -> <final>"``.

    Args:
        layer_q: quantized layer whose ``scale``/``offset`` are optimised.
        layer_fp: reference layer passed to ``hessian_loss``.
        bits: unused; kept so existing call sites stay valid.
        H: matrix passed to ``hessian_loss``.
        lr: Adam learning rate.
        n_steps: number of optimisation steps; nothing is done or printed
            when it is not positive.
    """
    trainable_params = [layer_q.scale, layer_q.offset]
    optim = torch.optim.Adam(trainable_params, lr=lr)

    init_loss = None
    for i in range(n_steps):
        optim.zero_grad()
        # Only the forward pass runs under autocast; backward and the
        # optimiser step stay outside, as the PyTorch AMP guide recommends.
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            loss = hessian_loss(layer_q, layer_fp, H)
        if i == 0:
            init_loss = loss.item()
        loss.backward()
        optim.step()

    if init_loss is None:
        return

    # Re-evaluate after the last step so that the reported value reflects the
    # final parameters rather than the ones before the last update.
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        final_loss = hessian_loss(layer_q, layer_fp, H).item()

    print(f"{init_loss} -> {final_loss}")


def init_quant_block_hessian(
        block_q,
        block_fp,
        bits,
        activations,
        causal_mask,
        position_embeddings
        ):
    """Initialise and tune every quantized projection of one decoder block.

    The block is walked in forward order. Before each group of projections the
    cached activations are advanced through the full-precision block up to
    the input of that group, a Hessian is built from them, and every layer of
    the group is min/max-configured and then optimised against that Hessian.

    ``activations`` is updated in place: on return it holds the
    full-precision output of the block (residual connections included) for
    every cached batch, stored on the CPU. The quantized block is never used
    to propagate activations.
    """

    def advance(transform):
        # Replace each cached activation by transform(index, activation on GPU),
        # moved back to the CPU, without tracking gradients.
        with torch.no_grad():
            for act_id in range(len(activations)):
                on_gpu = activations[act_id].cuda()
                activations[act_id] = transform(act_id, on_gpu).cpu()

    def calibrate(parent_q, parent_fp, layer_names):
        # One Hessian from the current activations is shared by all listed layers.
        hessian = prepare_hessian(activations)
        for layer_name in layer_names:
            layer_q = getattr(parent_q, layer_name)
            layer_fp = getattr(parent_fp, layer_name)
            configure_single_layer(layer_q, layer_fp, bits)
            optimize_quant_params(layer_q, layer_fp, bits, hessian)

    attn_q, attn_fp = block_q.self_attn, block_fp.self_attn
    mlp_q, mlp_fp = block_q.mlp, block_fp.mlp

    ##### Attention #####

    # Snapshot of the block input for the residual connection.
    skip = [x.clone() for x in activations]

    # Inputs of q/k/v: activations after input_layernorm.
    advance(lambda _, act: block_fp.input_layernorm(act))
    calibrate(attn_q, attn_fp, ["q_proj", "k_proj", "v_proj"])

    # Input of o_proj: first element returned by compute_attention.
    advance(
        lambda _, act: attn_fp.compute_attention(
            hidden_states=act,
            position_embeddings=position_embeddings,
            attention_mask=causal_mask
        )[0]
    )
    calibrate(attn_q, attn_fp, ["o_proj"])

    # Attention output plus residual.
    advance(lambda act_id, act: attn_fp.o_proj(act) + skip[act_id].cuda())

    ##### MLP #####

    # Snapshot of the MLP input for the second residual connection.
    skip = [x.clone() for x in activations]

    # Inputs of gate/up: activations after post_attention_layernorm.
    advance(lambda _, act: block_fp.post_attention_layernorm(act))
    calibrate(mlp_q, mlp_fp, ["gate_proj", "up_proj"])

    # Input of down_proj: gated intermediate activations.
    advance(lambda _, act: mlp_fp.act_fn(mlp_fp.gate_proj(act)) * mlp_fp.up_proj(act))
    calibrate(mlp_q, mlp_fp, ["down_proj"])

    # MLP output plus residual.
    advance(lambda act_id, act: mlp_fp.down_proj(act) + skip[act_id].cuda())


def init_quant_model_hessian(model_q, model_fp, bits, dataloader):
    """Run the Hessian-based initialisation over every decoder block.

    The batches of ``dataloader`` are embedded with the full-precision
    embedding layer and cached on the CPU. The decoder blocks are then
    processed one at a time: the quantized and the full-precision block are
    moved to the GPU, ``init_quant_block_hessian`` initialises the quantized
    block and advances the cached activations, and both blocks are moved back
    to the CPU.

    The causal mask and the rotary position embeddings are built once from the
    first batch and reused for all batches.
    NOTE(review): this assumes every batch has the same sequence length as the
    first one. Confirm against the dataloader.

    Args:
        model_q: model whose decoder layers contain the quantized projections.
        model_fp: full-precision reference model.
        bits: code width forwarded to the per-block initialisation.
        dataloader: iterable of token-id tensors; it is iterated twice.
    """
    decoder_q = model_q.get_decoder()
    decoder_fp = model_fp.get_decoder()

    # Remember where the embedding lived so that it can be put back; leaving
    # it on the GPU would keep memory allocated and split the model's devices.
    embed_home_device = decoder_fp.embed_tokens.weight.device
    embed_tokens = decoder_fp.embed_tokens.cuda()
    embed_tokens_device = embed_tokens.weight.device

    try:
        # Nothing here is trained, so no autograd graph is needed.
        with torch.no_grad():
            _batch = next(iter(dataloader))
            _inputs_embeds = embed_tokens(_batch.to(embed_tokens_device))

            cache_position = torch.arange(_inputs_embeds.shape[1], device=_inputs_embeds.device)
            position_ids = cache_position.unsqueeze(0)
            causal_mask = create_causal_mask(
                config=model_fp.config,
                input_embeds=_inputs_embeds,
                attention_mask=None,
                cache_position=cache_position,
                past_key_values=None,
                position_ids=position_ids,
            )

            position_embeddings = decoder_fp.rotary_emb(_inputs_embeds, position_ids)
            del _inputs_embeds

            # Prepare activations: one embedded batch per entry, kept on the CPU.
            activations = [
                embed_tokens(batch.to(embed_tokens_device)).cpu()
                for batch in dataloader
            ]
    finally:
        embed_tokens.to(embed_home_device)

    for decoder_layer_id in tqdm(range(len(decoder_q.layers))):
        block_q = decoder_q.layers[decoder_layer_id].cuda()
        block_fp = decoder_fp.layers[decoder_layer_id].cuda()

        try:
            init_quant_block_hessian(
                block_q,
                block_fp,
                bits,
                activations,
                causal_mask,
                position_embeddings
            )
        finally:
            # Free the GPU even if the block initialisation fails.
            block_q.cpu()
            block_fp.cpu()


# Settings handed to qlib.QATDataset to build the calibration dataloader.
config = dict(
    dataset_name="wiki",
    split="train[:7500]",
    seq_length=4096,
    n_seq=64,
    batch_size=8,
    random_seed='no_rand',
)
dataloader = qlib.QATDataset(config=config, tokenizer=tokenizer).get_dataloader()
print(len(dataloader))

# Hessian-based 3-bit initialisation of qmodel against the full-precision model.
init_quant_model_hessian(qmodel, model, 3, dataloader)

8


  0%|          | 0/32 [00:00<?, ?it/s]

init_loss: tensor(1.2502e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.2263e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3721e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.5248e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.8072e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.3621e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.2383e-12, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.9861e-12, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.7145e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.0465e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.3452e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.6174e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.9322e-13, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.9322e-13, device='cuda:0', grad_fn=<Tr

  3%|▎         | 1/32 [00:31<16:19, 31.60s/it]

init_loss: tensor(1.4616e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.7970e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3599e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.8350e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.0242e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.1427e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.5040e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.3637e-11, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.2190e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.9534e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.8283e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.5901e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.2292e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.1293e-12, device='cuda:0', grad_fn=<Tr

  6%|▋         | 2/32 [01:02<15:34, 31.15s/it]

init_loss: tensor(2.5753e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.0681e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.7720e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1613e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.5219e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.9937e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.6148e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.5914e-11, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.7212e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.3092e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.8618e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.5000e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.9439e-12, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.9439e-12, device='cuda:0', grad_fn=<Tr

  9%|▉         | 3/32 [01:33<14:57, 30.95s/it]

init_loss: tensor(4.5503e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.7535e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.7825e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.8963e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.2890e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.7696e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.6190e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.3778e-11, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.1366e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.0598e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.7540e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(9.0956e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.0342e-12, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.0341e-12, device='cuda:0', grad_fn=<Tr

 12%|█▎        | 4/32 [02:03<14:22, 30.80s/it]

init_loss: tensor(4.4361e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.6206e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.5497e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.6736e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3266e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.6311e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.2919e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.2589e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.6349e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.4822e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3231e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1988e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.6206e-12, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.6196e-12, device='cuda:0', grad_fn=<Tr

 16%|█▌        | 5/32 [02:34<13:48, 30.68s/it]

init_loss: tensor(4.6929e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.9026e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.0164e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.1170e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.4868e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.9429e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.3756e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.2930e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.0499e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.8705e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.6461e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.4986e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.1236e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1235e-11, device='cuda:0', grad_fn=<Tr

 19%|█▉        | 6/32 [03:04<13:16, 30.63s/it]

init_loss: tensor(6.2781e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.2545e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.3638e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.3560e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.8790e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.2746e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.5427e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.3027e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.5964e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.3526e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.9966e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.8093e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.6979e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.6975e-11, device='cuda:0', grad_fn=<Tr

 22%|██▏       | 7/32 [03:35<12:44, 30.57s/it]

init_loss: tensor(6.6742e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.6482e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.6608e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.6408e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.0964e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.4404e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.0481e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.7292e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.0652e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.7696e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.3798e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.1464e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.3207e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.3199e-11, device='cuda:0', grad_fn=<Tr

 25%|██▌       | 8/32 [04:05<12:12, 30.52s/it]

init_loss: tensor(6.9475e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.6866e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.9639e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.6764e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.2218e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.4960e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.2804e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.6334e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.2123e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.9041e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.6456e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.3852e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.8589e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.8578e-11, device='cuda:0', grad_fn=<Tr

 28%|██▊       | 9/32 [04:36<11:47, 30.75s/it]

init_loss: tensor(6.9495e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.8617e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.2731e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.0802e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.3423e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.6197e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.0729e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(9.9164e-10, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.3958e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.0729e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.8980e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.6132e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.3624e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.3607e-11, device='cuda:0', grad_fn=<Tr

 31%|███▏      | 10/32 [05:08<11:23, 31.07s/it]

init_loss: tensor(7.2394e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.0187e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.5205e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.2945e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.3650e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.6462e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.5415e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.4179e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.5407e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.1954e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.0906e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.7822e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.9027e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.9005e-11, device='cuda:0', grad_fn=<Tr

 34%|███▍      | 11/32 [05:40<10:55, 31.21s/it]

init_loss: tensor(7.5821e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.5826e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.5415e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.5247e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.0322e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.2165e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.6478e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.4743e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.7743e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.4147e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.3842e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.0521e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.3134e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.3098e-11, device='cuda:0', grad_fn=<Tr

 38%|███▊      | 12/32 [06:11<10:26, 31.32s/it]

init_loss: tensor(7.9357e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.8204e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.3264e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.1213e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.9340e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.1515e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.8496e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.6823e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.9491e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.5925e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.6481e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.3133e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.9794e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.9760e-11, device='cuda:0', grad_fn=<Tr

 41%|████      | 13/32 [06:43<09:56, 31.37s/it]

init_loss: tensor(7.6572e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.8528e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.9440e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.0231e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.1187e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.3618e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.0192e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.7578e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.1979e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.7921e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.9726e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.5897e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.1704e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.1615e-11, device='cuda:0', grad_fn=<Tr

 44%|████▍     | 14/32 [07:14<09:25, 31.42s/it]

init_loss: tensor(8.0067e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.0396e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.3019e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.2794e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.1277e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.3766e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.4273e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.2012e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.5083e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.1006e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.2990e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.9081e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.1466e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.1383e-11, device='cuda:0', grad_fn=<Tr

 47%|████▋     | 15/32 [07:46<08:54, 31.45s/it]

init_loss: tensor(7.7152e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.7383e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.1788e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.0631e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.3465e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.4743e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.4420e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.2386e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.9580e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.5073e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.7354e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.3016e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.1990e-11, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(9.1822e-11, device='cuda:0', grad_fn=<Tr

 50%|█████     | 16/32 [08:17<08:23, 31.46s/it]

init_loss: tensor(7.8270e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.0049e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.1133e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.2464e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.7145e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.8403e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.1942e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.8732e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.6367e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.1296e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.2971e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.8169e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.2342e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.2298e-10, device='cuda:0', grad_fn=<Tr

 53%|█████▎    | 17/32 [08:49<07:51, 31.45s/it]

init_loss: tensor(7.7983e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.1444e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.0518e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.3929e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.7203e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.9472e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.4611e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.2428e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.3726e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.8112e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.8410e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.3186e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3398e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.3353e-10, device='cuda:0', grad_fn=<Tr

 56%|█████▋    | 18/32 [09:20<07:20, 31.47s/it]

init_loss: tensor(8.0641e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.4724e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.3314e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.7088e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.4446e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.5630e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.3516e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.1640e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.2514e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.6159e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.4897e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.9096e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.5749e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.5690e-10, device='cuda:0', grad_fn=<Tr

 59%|█████▉    | 19/32 [09:52<06:49, 31.49s/it]

init_loss: tensor(7.8446e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.3094e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.0497e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.5095e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.4550e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.5878e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.3406e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.1691e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.8298e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.1639e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.9478e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.3514e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.7028e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.6972e-10, device='cuda:0', grad_fn=<Tr

 62%|██████▎   | 20/32 [10:23<06:17, 31.49s/it]

init_loss: tensor(7.9367e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.4598e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.1542e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.6718e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.6134e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.7463e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.8209e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.5501e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.4311e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.6826e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.3887e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.7338e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.1497e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.1371e-10, device='cuda:0', grad_fn=<Tr

 66%|██████▌   | 21/32 [10:54<05:42, 31.16s/it]

init_loss: tensor(8.4151e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.8962e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.5146e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.0277e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.3766e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.4159e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.3249e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.1753e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.1213e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.3422e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.8657e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.1893e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.1994e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.1913e-10, device='cuda:0', grad_fn=<Tr

 69%|██████▉   | 22/32 [11:24<05:09, 30.93s/it]

init_loss: tensor(9.0673e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.3910e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.2095e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.5367e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.6291e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.5636e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.4108e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.0279e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.7677e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.9420e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.3087e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.5997e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.4530e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.4454e-10, device='cuda:0', grad_fn=<Tr

 72%|███████▏  | 23/32 [11:54<04:36, 30.77s/it]

init_loss: tensor(9.6934e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.0767e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.7813e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.1818e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.6539e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.5466e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.1081e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.8827e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.0385e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(9.5239e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.9260e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.1844e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.6211e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.6094e-10, device='cuda:0', grad_fn=<Tr

 75%|███████▌  | 24/32 [12:25<04:05, 30.70s/it]

init_loss: tensor(9.0279e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.4384e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.2092e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.5713e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.5474e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.3310e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.6679e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.2857e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.0942e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.0047e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.4329e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.6551e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(2.7447e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(2.7327e-10, device='cuda:0', grad_fn=<Tr

 78%|███████▊  | 25/32 [12:55<03:34, 30.61s/it]

init_loss: tensor(1.0232e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.6504e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.0195e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.6785e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.7695e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.5431e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.1018e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.2358e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.1754e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.0624e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.0151e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(9.1700e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.0908e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.0613e-10, device='cuda:0', grad_fn=<Tr

 81%|████████▏ | 26/32 [13:26<03:03, 30.54s/it]

init_loss: tensor(9.7581e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.0071e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.7987e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.1033e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.8170e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.4099e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.9988e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.1421e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.2469e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1227e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.0771e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(9.7076e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(3.3808e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.3280e-10, device='cuda:0', grad_fn=<Tr

 84%|████████▍ | 27/32 [13:56<02:32, 30.51s/it]

init_loss: tensor(1.0340e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.7650e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.0371e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.7965e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.6554e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.5227e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.6105e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.9899e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3301e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1983e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.1605e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.0459e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.0144e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.8879e-10, device='cuda:0', grad_fn=<Tr

 88%|████████▊ | 28/32 [14:27<02:01, 30.47s/it]

init_loss: tensor(1.0199e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.6879e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.0219e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.7014e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.5258e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.2651e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.0847e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.8432e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3884e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.2342e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.2492e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1100e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(5.0261e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.6692e-10, device='cuda:0', grad_fn=<Tr

 91%|█████████ | 29/32 [14:57<01:31, 30.44s/it]

init_loss: tensor(9.2264e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.7323e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.1298e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.6768e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.1063e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.7878e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.2729e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.6282e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.4380e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.2667e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3091e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1519e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.7042e-10, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.9016e-10, device='cuda:0', grad_fn=<Tr

 94%|█████████▍| 30/32 [15:27<01:00, 30.42s/it]

init_loss: tensor(9.6619e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.2085e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(9.5392e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(8.0846e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.6496e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(7.3428e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.6393e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.3326e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.5134e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.3055e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3521e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1682e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.9375e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.0397e-09, device='cuda:0', grad_fn=<Tr

 97%|█████████▋| 31/32 [15:58<00:30, 30.40s/it]

init_loss: tensor(6.6953e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(5.7151e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(7.0954e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.0492e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(4.9710e-08, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(4.2320e-08, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(8.7753e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(6.9127e-09, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.3666e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.1600e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(1.2210e-07, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(1.0359e-07, device='cuda:0', grad_fn=<TraceBackward0>)
init_loss: tensor(6.2121e-09, device='cuda:0', grad_fn=<TraceBackward0>)
res_loss: tensor(3.5146e-09, device='cuda:0', grad_fn=<Tr

100%|██████████| 32/32 [16:28<00:00, 30.89s/it]


In [7]:
# Settings handed to qlib.QATDataset: the "wiki" dataset, "test" split,
# seq_length 4096, batch_size 1, and the 'no_rand' seed marker.
config = {
    "dataset_name": "wiki",
    "split": "test",
    "seq_length": 4096,
    "batch_size": 1,
    "random_seed": "no_rand",
}

dataloader = qlib.QATDataset(config=config, tokenizer=tokenizer).get_dataloader()

# Move qmodel onto the GPU before evaluation.
# (To evaluate the FP `model` loaded earlier instead, use `qmodel = model.cuda()`.)
qmodel = qmodel.cuda()

# Run the evaluation with gradients disabled and CUDA autocast set to float16.
# NOTE(review): print_times presumably controls how often qlib.evaluate prints
# intermediate values — confirm against qlib.
with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
    res = qlib.evaluate(qmodel, dataloader, print_times=25)

print(res)

  5%|▍         | 4/83 [00:04<01:32,  1.17s/it]

5.326810313460885


  8%|▊         | 7/83 [00:07<01:29,  1.18s/it]

5.49138789832124


 12%|█▏        | 10/83 [00:10<01:25,  1.18s/it]

5.993877091337049


 16%|█▌        | 13/83 [00:13<01:22,  1.18s/it]

6.00483434472225


 19%|█▉        | 16/83 [00:16<01:18,  1.18s/it]

6.241929882400746


 23%|██▎       | 19/83 [00:19<01:15,  1.18s/it]

6.1370251076631925


 27%|██▋       | 22/83 [00:22<01:11,  1.18s/it]

5.916831914319163


 30%|███       | 25/83 [00:26<01:08,  1.18s/it]

5.865186418271654


 34%|███▎      | 28/83 [00:29<01:04,  1.18s/it]

6.0069449884136485


 37%|███▋      | 31/83 [00:32<01:01,  1.18s/it]

5.908814094036397


 41%|████      | 34/83 [00:35<00:57,  1.18s/it]

5.731856700686873


 45%|████▍     | 37/83 [00:38<00:54,  1.18s/it]

5.820989810645159


 48%|████▊     | 40/83 [00:41<00:50,  1.18s/it]

5.719625884893324


 52%|█████▏    | 43/83 [00:44<00:47,  1.18s/it]

5.715653544915329


 55%|█████▌    | 46/83 [00:47<00:43,  1.18s/it]

5.726612795938129


 59%|█████▉    | 49/83 [00:50<00:40,  1.18s/it]

5.716056923675822


 63%|██████▎   | 52/83 [00:54<00:36,  1.18s/it]

5.7143734147335765


 66%|██████▋   | 55/83 [00:57<00:32,  1.18s/it]

5.870452942993012


 70%|██████▉   | 58/83 [01:00<00:29,  1.18s/it]

5.877431419146757


 73%|███████▎  | 61/83 [01:03<00:25,  1.18s/it]

5.897400702319287


 77%|███████▋  | 64/83 [01:06<00:22,  1.18s/it]

5.857047966446579


 81%|████████  | 67/83 [01:09<00:18,  1.18s/it]

5.8767462288430306


 84%|████████▍ | 70/83 [01:12<00:15,  1.18s/it]

5.854831554794251


 88%|████████▊ | 73/83 [01:15<00:11,  1.18s/it]

5.831968882713048


 92%|█████████▏| 76/83 [01:18<00:08,  1.18s/it]

5.8051996943841


 95%|█████████▌| 79/83 [01:22<00:04,  1.18s/it]

5.771935874904906


 99%|█████████▉| 82/83 [01:25<00:01,  1.18s/it]

5.8045090858022625


100%|██████████| 83/83 [01:25<00:00,  1.03s/it]


5.814998848786007
