In [4]:
%load_ext autoreload
%autoreload 2

import sys
import os
import dotenv
from pathlib import Path

env_file = "../.env"

if os.path.exists(env_file):
    dotenv.load_dotenv(env_file, verbose=True)
    print("Loaded environment variables from .env file.")

cwd = os.getcwd()
# for some reason appending to PATH you need it to be string
sys.path.append(str(Path(cwd).parent / "src"))
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Loaded environment variables from .env file.


In [5]:
from research_tools.gpu import get_gpus_available

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in get_gpus_available()])

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device == torch.device("cuda")

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
updated_model = AutoModelForCausalLM.from_pretrained(
    model_id, revision="main", torch_dtype=torch.bfloat16
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
updated_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
from datasets import load_dataset
import json

unlearning_task = "cyber"

forget_corpus = (
    "cyber-forget-corpus" if unlearning_task == "cyber" else "bio-forget-corpus"
)
retain_corpus = "wikitext"


def get_unlearning_rmu_dataset(name, min_len=0, batch_size=4):
    data = []
    if name == "wikitext":
        raw_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        for x in raw_data:
            if len(x["text"]) > min_len:
                data.append(str(x["text"]))
    elif name == "cyber-forget-corpus":
        raw_data = load_dataset("cais/wmdp-corpora", name=name, split="train")
        for x in raw_data:
            if len(x["text"]) > min_len:
                data.append(str(x["text"]))
    else:  # wmdp bio must be pre-downloaded
        for line in open(f"data/{name}.jsonl", "r"):
            raw_text = json.loads(line)["text"]
            if len(raw_text) > min_len:
                data.append(str(raw_text))
    data = [data[i : i + batch_size] for i in range(0, len(data), batch_size)]
    return data


forget_data_list = get_unlearning_rmu_dataset(forget_corpus)
retain_data_list = get_unlearning_rmu_dataset(retain_corpus)

In [None]:
idx = 0
max_length = 512
tokenizer.pad_token_id = tokenizer.eos_token_id

unlearn_batch = forget_data_list[idx]
retain_batch = retain_data_list[idx]

unlearn_inputs = tokenizer(
    unlearn_batch,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=max_length,
).to(updated_model.device)
retain_inputs = tokenizer(
    retain_batch,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=max_length,
).to(updated_model.device)

adv_labels_mask = torch.zeros_like(unlearn_inputs["input_ids"], dtype=bool)
def_labels_mask = torch.zeros_like(retain_inputs["input_ids"], dtype=bool)
for b, example in enumerate(retain_batch):
    len_example = len(tokenizer(example)["input_ids"])
    def_labels_mask[b, :len_example] = True
for b, example in enumerate(unlearn_batch):
    len_example = len(tokenizer(example)["input_ids"])
    adv_labels_mask[b, :len_example] = True

# prompt_mask = torch.zeros(len(unlearn_batch), pad_length + 1, dtype=torch.bool)
pgd_batch = {
    "def_tokens": retain_inputs["input_ids"].to(updated_model.device),
    "adv_tokens": unlearn_inputs["input_ids"].to(updated_model.device),
    "adv_labels_mask": adv_labels_mask.to(updated_model.device),
    "def_labels_mask": def_labels_mask.to(updated_model.device),
}

In [6]:
epsilon = 20.0
lr_adv = 1e-3
noise_completions = True
n_pgd_steps = 20

In [None]:
from unlearn_order.algos.lat.model import LATModel


lat_model = LATModel(
    model=updated_model, tokenizer=tokenizer, layer=12, attack_completions=False
)

In [None]:
for param in updated_model.model.layers[0].parameters():
    print()

AttributeError: 'Parameter' object has no attribute 'named_modules'

[('mlp.down_proj', Linear(in_features=14336, out_features=4096, bias=False))]

In [None]:
from tqdm import tqdm
from unlearn_order.algos.lat.hooks import add_hooks
from unlearn_order.algos.lat.model import get_adversary_hook, projected_gradient_descent
from unlearn_order.algos.lat.utils import get_layer_module, set_model_grads
from typing import List, Dict
from transformers import LlamaForCausalLM, LlamaTokenizer


def get_layers_params(
    model: LlamaForCausalLM, layers: List[int], target_modules: List[str]
):
    params = []
    for layer in layers:
        for name, module in model.model[layer].named_modules():
            if any([target in name for target in target_modules]):
                for param in module.parameters():
                    params.append(param)
    return params


def forward_with_cache(
    model: LlamaForCausalLM,
    inputs: Dict[str, torch.Tensor],
    module: torch.nn.Module,
    no_grad: bool = True,
):
    # define a tensor with the size of our cached activations
    cache = []

    def hook(module, input, output):
        if isinstance(output, tuple):
            cache.append(output[0])
        else:
            cache.append(output)
        return None

    hook_handle = module.register_forward_hook(hook)

    if no_grad:
        with torch.no_grad():
            _ = model(**inputs)
    else:
        _ = model(**inputs)

    hook_handle.remove()

    return cache[0]


# you're optimizing layer_ids
# there is a layer where you're examining the activations at and trying to make garbage
# then you only train on the laeyrs before that, this is just shoving garbage in
# not real unlearning?
# let us also latently perturb at the same layer

from peft import PeftConfig

PeftConfig()


# remember params
def run_rmu(
    updated_model: LlamaForCausalLM,
    frozen_model: LlamaForCausalLM,
    tokenizer: LlamaTokenizer,
    forget_data_list,
    retain_data_list,
    target_modules: List[str],
    alpha: float = 1200.0,
    layer: int = 8,  # layers to do RMU in
    lr_def: float = 5.0e-5,
    max_num_batches=200,
    steering_coef: float = 6.5,
    use_lat: bool = True,
    epsilon: float = 2,
    lr_adv: float = 5e-2,
    n_pgd_steps: int = 16,
    loss_coefs: Dict[str, float] = {"toward": 1, "away": 1},
    num_epochs: int = 1,
):
    adv_layer = layer - 1
    def_layers = [layer, layer - 1, layer - 2]

    updated_model.train()

    params = get_layers_params(updated_model, def_layers, target_modules)
    optimizer = torch.optim.AdamW(params, lr=lr_def)

    updated_module = get_layer_module(updated_model, layer)
    frozen_module = get_layer_module(frozen_model, layer)

    control_vectors_list = []
    for i in range(len(forget_data_list)):
        random_vector = torch.rand(
            1,
            1,
            updated_model.config.hidden_size,
            dtype=updated_model.dtype,
            device=updated_model.device,
        )
        control_vec = random_vector / torch.norm(random_vector) * steering_coef
        control_vectors_list.append(control_vec)

    num_batches = max_num_batches
    truncation_side = tokenizer.truncation_side
    tokenizer.truncation_side = "right"

    for epoch in range(num_epochs):
        for idx in tqdm(range(num_batches)):
            control_vec = control_vectors_list[idx]
            unlearn_batch = forget_data_list[idx]
            retain_batch = retain_data_list[idx]

            max_length = 512 if idx == 0 else 768

            unlearn_inputs = tokenizer(
                unlearn_batch,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_length,
            ).to(updated_model.device)

            retain_inputs = tokenizer(
                retain_batch,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_length,
            ).to(updated_model.device)

            if use_lat:
                adv_labels_mask = torch.zeros_like(
                    unlearn_inputs["input_ids"], dtype=bool
                )
                def_labels_mask = torch.zeros_like(
                    retain_inputs["input_ids"], dtype=bool
                )

                for b, example in enumerate(retain_batch):
                    len_example = len(tokenizer(example)["input_ids"])
                    def_labels_mask[b, :len_example] = True
                for b, example in enumerate(unlearn_batch):
                    len_example = len(tokenizer(example)["input_ids"])
                    adv_labels_mask[b, :len_example] = True

                # prompt_mask = torch.zeros(len(unlearn_batch), pad_length + 1, dtype=torch.bool)
                pgd_batch = {
                    "def_tokens": retain_inputs["input_ids"].to(updated_model.device),
                    "adv_tokens": unlearn_inputs["input_ids"].to(updated_model.device),
                    "adv_labels_mask": adv_labels_mask.to(updated_model.device),
                    "def_labels_mask": def_labels_mask.to(updated_model.device),
                }
                adversary = projected_gradient_descent(
                    updated_model,
                    pgd_batch,
                    layer=layer,
                    epsilon=epsilon,
                    n_pgd_steps=n_pgd_steps,
                    lr_adv=lr_adv,
                    loss_coefs=loss_coefs,
                )

                adversary.set_attack_grads(False)

                set_model_grads(updated_model, True)

                # Unlearning loss
                with add_hooks(
                    module_forward_hooks=[
                        (
                            get_layer_module(updated_model, layer),
                            get_adversary_hook(adversary),
                        )
                    ],
                    module_forward_pre_hooks=[],
                ):
                    updated_forget_activations = forward_with_cache(
                        updated_model,
                        unlearn_inputs,
                        module=updated_module,
                        no_grad=False,
                    ).to(updated_model.device)
                    unlearn_loss = torch.nn.functional.mse_loss(
                        updated_forget_activations, control_vec
                    )
            else:
                # Unlearning loss
                updated_forget_activations = forward_with_cache(
                    updated_model, unlearn_inputs, module=updated_module, no_grad=False
                ).to(updated_model.device)
                unlearn_loss = torch.nn.functional.mse_loss(
                    updated_forget_activations, control_vec
                )

            # Retain loss
            updated_retain_activations = forward_with_cache(
                updated_model, retain_inputs, module=updated_module, no_grad=False
            ).to(updated_model.device)
            frozen_retain_activations = forward_with_cache(
                frozen_model, retain_inputs, module=frozen_module, no_grad=True
            ).to(updated_model.device)

            retain_loss = torch.nn.functional.mse_loss(
                updated_retain_activations, frozen_retain_activations
            )
            retain_loss *= alpha

            # Update model
            loss = unlearn_loss + retain_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    tokenizer.truncation_side = truncation_side

    return updated_model

In [None]:
frozen_model = AutoModelForCausalLM.from_pretrained(model_id, revision="main").to(
    device
)
updated_model = run_rmu(
    updated_model,
    frozen_model,  # function to make another version of the model for reference embeddings
    tokenizer,  # tokenizer
    forget_data_list,  # forget data
    retain_data_list,  # retain data
    alpha=1200.0,  # coef for rmu
    layer_ids=[6, 7, 8],  # layers to train
    layer=8,  # layers to do RMU in
    param_ids=[6],  # param set to do RMU in
    lr_def=5.0e-5,  # model lr
    module_str="{model_name}.model.layers[{layer_id}]",  # to access the model layers
    steering_coef=6.5,  # coef for RMU
    model_iterations_per_step=4,  # how many steps to train the model for
    max_num_batches=200,  # max num batches
    use_lat=True,  # whether to do LAT
    lat_layer=7,  # layers to attack
    epsilon=2,  # attack l2 norm bound
    lr_adv=5.0e-2,  # attack learning rate
    n_pgd_steps=16,  # number of steps to do PGD for
    loss_coefs={"toward": 1, "away": 1},  # adv loss coefs
    num_epochs=1,  # number of epochs to train for
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
  unlearn_loss = torch.nn.functional.mse_loss(
  0%|          | 1/200 [00:09<32:04,  9.67s/it]

step 0 loss: 0.095703125


  unlearn_loss = torch.nn.functional.mse_loss(
  1%|          | 2/200 [00:22<39:00, 11.82s/it]

step 1 loss: 0.06591796875


  2%|▏         | 3/200 [00:36<41:14, 12.56s/it]

step 2 loss: 0.0654296875


  2%|▏         | 4/200 [00:49<42:12, 12.92s/it]

step 3 loss: 0.0634765625


  2%|▎         | 5/200 [01:03<42:40, 13.13s/it]

step 4 loss: 0.0634765625


  3%|▎         | 6/200 [01:17<43:00, 13.30s/it]

step 5 loss: 0.0634765625


  4%|▎         | 7/200 [01:30<43:06, 13.40s/it]

step 6 loss: 0.06298828125


  4%|▍         | 8/200 [01:44<43:06, 13.47s/it]

step 7 loss: 0.0634765625


  4%|▍         | 9/200 [01:57<43:01, 13.52s/it]

step 8 loss: 0.06298828125


  5%|▌         | 10/200 [02:11<43:02, 13.59s/it]

step 9 loss: 0.06298828125


  6%|▌         | 11/200 [02:25<42:46, 13.58s/it]

step 10 loss: 0.06298828125


  6%|▌         | 12/200 [02:38<42:22, 13.52s/it]

step 11 loss: 0.0625


  6%|▋         | 13/200 [02:52<42:10, 13.53s/it]

step 12 loss: 0.06298828125


  7%|▋         | 14/200 [03:05<42:03, 13.57s/it]

step 13 loss: 0.06298828125


  8%|▊         | 15/200 [03:19<41:48, 13.56s/it]

step 14 loss: 0.06298828125


  8%|▊         | 16/200 [03:32<41:37, 13.58s/it]

step 15 loss: 0.06298828125


  8%|▊         | 17/200 [03:46<41:30, 13.61s/it]

step 16 loss: 0.06298828125


  9%|▉         | 18/200 [04:00<41:20, 13.63s/it]

step 17 loss: 0.06298828125


 10%|▉         | 19/200 [04:13<41:09, 13.64s/it]

step 18 loss: 0.0625


 10%|█         | 20/200 [04:27<40:51, 13.62s/it]

step 19 loss: 0.062255859375


 10%|█         | 21/200 [04:41<40:42, 13.65s/it]

step 20 loss: 0.06298828125


 11%|█         | 22/200 [04:54<40:26, 13.63s/it]

step 21 loss: 0.0625


 12%|█▏        | 23/200 [05:08<40:13, 13.63s/it]

step 22 loss: 0.0625


 12%|█▏        | 24/200 [05:22<39:55, 13.61s/it]

step 23 loss: 0.0625


 12%|█▎        | 25/200 [05:35<39:40, 13.60s/it]

step 24 loss: 0.062255859375


 13%|█▎        | 26/200 [05:49<39:21, 13.57s/it]

step 25 loss: 0.06201171875


 14%|█▎        | 27/200 [06:02<39:03, 13.55s/it]

step 26 loss: 0.0625


 14%|█▍        | 28/200 [06:16<38:51, 13.56s/it]

step 27 loss: 0.06298828125


 14%|█▍        | 29/200 [06:30<38:52, 13.64s/it]

step 28 loss: 0.062255859375


 15%|█▌        | 30/200 [06:43<38:41, 13.65s/it]

step 29 loss: 0.0625


 16%|█▌        | 31/200 [06:57<38:32, 13.68s/it]

step 30 loss: 0.0625


 16%|█▌        | 32/200 [07:11<38:25, 13.72s/it]

step 31 loss: 0.062255859375


 16%|█▋        | 33/200 [07:24<38:06, 13.69s/it]

step 32 loss: 0.06298828125


 17%|█▋        | 34/200 [07:38<37:53, 13.69s/it]

step 33 loss: 0.061767578125


 18%|█▊        | 35/200 [07:52<37:36, 13.68s/it]

step 34 loss: 0.06201171875


 18%|█▊        | 36/200 [08:05<37:24, 13.69s/it]

step 35 loss: 0.062255859375


 18%|█▊        | 37/200 [08:19<37:09, 13.68s/it]

step 36 loss: 0.061767578125


 19%|█▉        | 38/200 [08:33<36:52, 13.66s/it]

step 37 loss: 0.0615234375


 20%|█▉        | 39/200 [08:46<36:35, 13.64s/it]

step 38 loss: 0.06298828125


 20%|██        | 40/200 [09:00<36:26, 13.66s/it]

step 39 loss: 0.06201171875


 20%|██        | 41/200 [09:14<36:18, 13.70s/it]

step 40 loss: 0.061767578125


 21%|██        | 42/200 [09:27<36:01, 13.68s/it]

step 41 loss: 0.06201171875


 22%|██▏       | 43/200 [09:41<35:47, 13.68s/it]

step 42 loss: 0.061767578125


 22%|██▏       | 44/200 [09:55<35:28, 13.65s/it]

step 43 loss: 0.061767578125


 22%|██▎       | 45/200 [10:08<35:12, 13.63s/it]

step 44 loss: 0.06201171875


 23%|██▎       | 46/200 [10:22<34:56, 13.62s/it]

step 45 loss: 0.061767578125


 24%|██▎       | 47/200 [10:35<34:41, 13.61s/it]

step 46 loss: 0.061279296875


 24%|██▍       | 48/200 [10:49<34:32, 13.63s/it]

step 47 loss: 0.0615234375


 24%|██▍       | 49/200 [11:03<34:30, 13.71s/it]

step 48 loss: 0.061767578125


 25%|██▌       | 50/200 [11:17<34:12, 13.68s/it]

step 49 loss: 0.06103515625


 26%|██▌       | 51/200 [11:30<33:53, 13.65s/it]

step 50 loss: 0.061279296875


 26%|██▌       | 52/200 [11:44<33:37, 13.63s/it]

step 51 loss: 0.0615234375


 26%|██▋       | 53/200 [11:57<33:18, 13.60s/it]

step 52 loss: 0.061767578125


 27%|██▋       | 54/200 [12:11<33:12, 13.65s/it]

step 53 loss: 0.0615234375


 28%|██▊       | 55/200 [12:25<32:57, 13.64s/it]

step 54 loss: 0.06103515625


 28%|██▊       | 56/200 [12:38<32:41, 13.62s/it]

step 55 loss: 0.061279296875


 28%|██▊       | 57/200 [12:52<32:31, 13.65s/it]

step 56 loss: 0.0615234375


 29%|██▉       | 58/200 [13:06<32:17, 13.65s/it]

step 57 loss: 0.0615234375


 30%|██▉       | 59/200 [13:19<32:03, 13.64s/it]

step 58 loss: 0.060791015625


 30%|███       | 60/200 [13:33<31:48, 13.63s/it]

step 59 loss: 0.061279296875


 30%|███       | 61/200 [13:47<31:38, 13.65s/it]

step 60 loss: 0.0615234375


 31%|███       | 62/200 [14:00<31:29, 13.69s/it]

step 61 loss: 0.060791015625


 32%|███▏      | 63/200 [14:14<31:13, 13.67s/it]

step 62 loss: 0.061279296875


 32%|███▏      | 64/200 [14:28<30:57, 13.66s/it]

step 63 loss: 0.06103515625


 32%|███▎      | 65/200 [14:41<30:43, 13.65s/it]

step 64 loss: 0.0615234375


 33%|███▎      | 66/200 [14:55<30:28, 13.65s/it]

step 65 loss: 0.06103515625


 34%|███▎      | 67/200 [15:09<30:15, 13.65s/it]

step 66 loss: 0.0615234375


 34%|███▍      | 68/200 [15:22<30:01, 13.65s/it]

step 67 loss: 0.06103515625


 34%|███▍      | 69/200 [15:36<29:51, 13.68s/it]

step 68 loss: 0.06103515625


 35%|███▌      | 70/200 [15:50<29:41, 13.71s/it]

step 69 loss: 0.060546875


 36%|███▌      | 71/200 [16:03<29:23, 13.67s/it]

step 70 loss: 0.060546875


 36%|███▌      | 72/200 [16:17<29:09, 13.67s/it]

step 71 loss: 0.060791015625


 36%|███▋      | 73/200 [16:31<28:55, 13.67s/it]

step 72 loss: 0.061279296875


 37%|███▋      | 74/200 [16:44<28:39, 13.65s/it]

step 73 loss: 0.0625


 38%|███▊      | 75/200 [16:58<28:29, 13.67s/it]

step 74 loss: 0.06103515625


 38%|███▊      | 76/200 [17:12<28:17, 13.69s/it]

step 75 loss: 0.0615234375


 38%|███▊      | 77/200 [17:25<28:01, 13.67s/it]

step 76 loss: 0.061279296875


 39%|███▉      | 78/200 [17:39<27:43, 13.64s/it]

step 77 loss: 0.060791015625


 40%|███▉      | 79/200 [17:52<27:25, 13.60s/it]

step 78 loss: 0.060546875


 40%|████      | 80/200 [18:06<27:08, 13.57s/it]

step 79 loss: 0.060546875


 40%|████      | 81/200 [18:19<26:55, 13.57s/it]

step 80 loss: 0.060302734375


 41%|████      | 82/200 [18:33<26:39, 13.56s/it]

step 81 loss: 0.06005859375


 42%|████▏     | 83/200 [18:47<26:25, 13.55s/it]

step 82 loss: 0.06005859375


 42%|████▏     | 84/200 [19:00<26:11, 13.55s/it]

step 83 loss: 0.06103515625


 42%|████▎     | 85/200 [19:14<25:56, 13.53s/it]

step 84 loss: 0.060302734375


 43%|████▎     | 86/200 [19:27<25:42, 13.53s/it]

step 85 loss: 0.059814453125


 44%|████▎     | 87/200 [19:41<25:30, 13.54s/it]

step 86 loss: 0.0634765625


 44%|████▍     | 88/200 [19:54<25:16, 13.54s/it]

step 87 loss: 0.062255859375


 44%|████▍     | 89/200 [20:08<25:06, 13.57s/it]

step 88 loss: 0.061279296875


 45%|████▌     | 90/200 [20:22<24:55, 13.60s/it]

step 89 loss: 0.061767578125


 46%|████▌     | 91/200 [20:35<24:41, 13.59s/it]

step 90 loss: 0.060546875


 46%|████▌     | 92/200 [20:49<24:27, 13.58s/it]

step 91 loss: 0.061767578125


 46%|████▋     | 93/200 [21:02<24:13, 13.58s/it]

step 92 loss: 0.061279296875


 47%|████▋     | 94/200 [21:16<24:02, 13.61s/it]

step 93 loss: 0.060791015625


 48%|████▊     | 95/200 [21:29<23:47, 13.60s/it]

step 94 loss: 0.06103515625


 48%|████▊     | 96/200 [21:43<23:34, 13.60s/it]

step 95 loss: 0.060546875


 48%|████▊     | 97/200 [21:57<23:20, 13.60s/it]

step 96 loss: 0.06103515625


 49%|████▉     | 98/200 [22:10<23:05, 13.59s/it]

step 97 loss: 0.060302734375


 50%|████▉     | 99/200 [22:24<22:53, 13.59s/it]

step 98 loss: 0.06103515625


 50%|█████     | 100/200 [22:38<22:43, 13.64s/it]

step 99 loss: 0.06103515625


 50%|█████     | 101/200 [22:51<22:32, 13.66s/it]

step 100 loss: 0.061279296875


 51%|█████     | 102/200 [23:05<22:16, 13.63s/it]

step 101 loss: 0.060546875


 52%|█████▏    | 103/200 [23:19<22:03, 13.65s/it]

step 102 loss: 0.061279296875


 52%|█████▏    | 104/200 [23:32<21:51, 13.66s/it]

step 103 loss: 0.061279296875


 52%|█████▎    | 105/200 [23:46<21:34, 13.62s/it]

step 104 loss: 0.06201171875


 53%|█████▎    | 106/200 [23:59<21:21, 13.63s/it]

step 105 loss: 0.060546875


 54%|█████▎    | 107/200 [24:13<21:05, 13.61s/it]

step 106 loss: 0.060302734375


 54%|█████▍    | 108/200 [24:27<20:53, 13.63s/it]

step 107 loss: 0.060791015625


 55%|█████▍    | 109/200 [24:40<20:40, 13.63s/it]

step 108 loss: 0.061279296875


 55%|█████▌    | 110/200 [24:54<20:25, 13.62s/it]

step 109 loss: 0.062255859375


 56%|█████▌    | 111/200 [25:07<20:08, 13.58s/it]

step 110 loss: 0.060546875


 56%|█████▌    | 112/200 [25:21<19:56, 13.60s/it]

step 111 loss: 0.06298828125


 56%|█████▋    | 113/200 [25:35<19:46, 13.64s/it]

step 112 loss: 0.06201171875


 57%|█████▋    | 114/200 [25:48<19:28, 13.59s/it]

step 113 loss: 0.060791015625


 57%|█████▊    | 115/200 [26:02<19:16, 13.60s/it]

step 114 loss: 0.061279296875


 58%|█████▊    | 116/200 [26:15<18:59, 13.56s/it]

step 115 loss: 0.06103515625


 58%|█████▊    | 117/200 [26:29<18:44, 13.55s/it]

step 116 loss: 0.060302734375


 59%|█████▉    | 118/200 [26:42<18:30, 13.55s/it]

step 117 loss: 0.0615234375


 60%|█████▉    | 119/200 [26:56<18:18, 13.56s/it]

step 118 loss: 0.06201171875


 60%|██████    | 120/200 [27:10<18:06, 13.58s/it]

step 119 loss: 0.0615234375


 60%|██████    | 121/200 [27:23<17:54, 13.60s/it]

step 120 loss: 0.06396484375


 61%|██████    | 122/200 [27:37<17:40, 13.60s/it]

step 121 loss: 0.06103515625


 62%|██████▏   | 123/200 [27:50<17:27, 13.60s/it]

step 122 loss: 0.06298828125


 62%|██████▏   | 124/200 [28:04<17:13, 13.60s/it]

step 123 loss: 0.06298828125


 62%|██████▎   | 125/200 [28:18<16:59, 13.60s/it]

step 124 loss: 0.0634765625


 63%|██████▎   | 126/200 [28:31<16:43, 13.56s/it]

step 125 loss: 0.064453125


 64%|██████▎   | 127/200 [28:45<16:32, 13.60s/it]

step 126 loss: 0.061279296875


 64%|██████▍   | 128/200 [28:58<16:20, 13.62s/it]

step 127 loss: 0.06298828125


 64%|██████▍   | 129/200 [29:12<16:06, 13.61s/it]

step 128 loss: 0.072265625


 65%|██████▌   | 130/200 [29:26<15:54, 13.63s/it]

step 129 loss: 0.06982421875


 66%|██████▌   | 131/200 [29:39<15:41, 13.65s/it]

step 130 loss: 0.0732421875


 66%|██████▌   | 132/200 [29:53<15:26, 13.63s/it]

step 131 loss: 0.0654296875


 66%|██████▋   | 133/200 [30:07<15:16, 13.67s/it]

step 132 loss: 0.0634765625


 67%|██████▋   | 134/200 [30:20<15:02, 13.67s/it]

step 133 loss: 0.0654296875


 68%|██████▊   | 135/200 [30:34<14:49, 13.69s/it]

step 134 loss: 0.0625


 68%|██████▊   | 136/200 [30:48<14:37, 13.71s/it]

step 135 loss: 0.061767578125


 68%|██████▊   | 137/200 [31:02<14:24, 13.71s/it]

step 136 loss: 0.064453125


 69%|██████▉   | 138/200 [31:15<14:09, 13.70s/it]

step 137 loss: 0.061767578125


 70%|██████▉   | 139/200 [31:29<13:53, 13.67s/it]

step 138 loss: 0.06640625


 70%|███████   | 140/200 [31:43<13:40, 13.68s/it]

step 139 loss: 0.0634765625


 70%|███████   | 141/200 [31:56<13:26, 13.68s/it]

step 140 loss: 0.0615234375


 71%|███████   | 142/200 [32:10<13:12, 13.67s/it]

step 141 loss: 0.064453125


 72%|███████▏  | 143/200 [32:23<12:56, 13.62s/it]

step 142 loss: 0.0615234375


 72%|███████▏  | 144/200 [32:37<12:40, 13.58s/it]

step 143 loss: 0.060302734375


 72%|███████▎  | 145/200 [32:50<12:24, 13.53s/it]

step 144 loss: 0.060546875


 73%|███████▎  | 146/200 [33:04<12:11, 13.54s/it]

step 145 loss: 0.06005859375


 74%|███████▎  | 147/200 [33:17<11:57, 13.54s/it]

step 146 loss: 0.06005859375


 74%|███████▍  | 148/200 [33:31<11:47, 13.61s/it]

step 147 loss: 0.05908203125


 74%|███████▍  | 149/200 [33:45<11:32, 13.58s/it]

step 148 loss: 0.059326171875


 75%|███████▌  | 150/200 [33:58<11:20, 13.61s/it]

step 149 loss: 0.05908203125


 76%|███████▌  | 151/200 [34:12<11:05, 13.59s/it]

step 150 loss: 0.05908203125


 76%|███████▌  | 152/200 [34:26<10:53, 13.61s/it]

step 151 loss: 0.0703125


 76%|███████▋  | 153/200 [34:39<10:39, 13.62s/it]

step 152 loss: 0.064453125


 77%|███████▋  | 154/200 [34:53<10:27, 13.63s/it]

step 153 loss: 0.06298828125


 78%|███████▊  | 155/200 [35:07<10:15, 13.68s/it]

step 154 loss: 0.06298828125


 78%|███████▊  | 156/200 [35:20<10:01, 13.66s/it]

step 155 loss: 0.064453125


 78%|███████▊  | 157/200 [35:34<09:46, 13.65s/it]

step 156 loss: 0.0625


 79%|███████▉  | 158/200 [35:48<09:32, 13.63s/it]

step 157 loss: 0.061767578125


 80%|███████▉  | 159/200 [36:01<09:19, 13.65s/it]

step 158 loss: 0.060791015625


 80%|████████  | 160/200 [36:15<09:05, 13.65s/it]

step 159 loss: 0.060791015625


 80%|████████  | 161/200 [36:29<08:52, 13.66s/it]

step 160 loss: 0.0615234375


 81%|████████  | 162/200 [36:42<08:38, 13.65s/it]

step 161 loss: 0.0625


 82%|████████▏ | 163/200 [36:56<08:22, 13.59s/it]

step 162 loss: 0.0634765625


 82%|████████▏ | 164/200 [37:09<08:09, 13.59s/it]

step 163 loss: 0.0654296875


 82%|████████▎ | 165/200 [37:23<07:55, 13.59s/it]

step 164 loss: 0.0634765625


 83%|████████▎ | 166/200 [37:37<07:43, 13.64s/it]

step 165 loss: 0.06298828125


 84%|████████▎ | 167/200 [37:50<07:28, 13.60s/it]

step 166 loss: 0.06298828125


 84%|████████▍ | 168/200 [38:04<07:15, 13.61s/it]

step 167 loss: 0.06298828125


 84%|████████▍ | 169/200 [38:17<07:01, 13.61s/it]

step 168 loss: 0.06201171875


 85%|████████▌ | 170/200 [38:31<06:48, 13.62s/it]

step 169 loss: 0.062255859375


 86%|████████▌ | 171/200 [38:45<06:35, 13.63s/it]

step 170 loss: 0.061767578125


 86%|████████▌ | 172/200 [38:58<06:21, 13.61s/it]

step 171 loss: 0.060791015625


 86%|████████▋ | 173/200 [39:12<06:07, 13.60s/it]

step 172 loss: 0.06201171875


 87%|████████▋ | 174/200 [39:25<05:53, 13.61s/it]

step 173 loss: 0.061767578125


 88%|████████▊ | 175/200 [39:39<05:40, 13.63s/it]

step 174 loss: 0.062255859375


 88%|████████▊ | 176/200 [39:53<05:27, 13.63s/it]

step 175 loss: 0.0625


 88%|████████▊ | 177/200 [40:06<05:12, 13.60s/it]

step 176 loss: 0.06005859375


 89%|████████▉ | 178/200 [40:20<04:59, 13.63s/it]

step 177 loss: 0.060791015625


 90%|████████▉ | 179/200 [40:33<04:45, 13.59s/it]

step 178 loss: 0.061279296875


 90%|█████████ | 180/200 [40:47<04:31, 13.56s/it]

step 179 loss: 0.059814453125


 90%|█████████ | 181/200 [41:00<04:17, 13.55s/it]

step 180 loss: 0.059814453125


 91%|█████████ | 182/200 [41:14<04:04, 13.57s/it]

step 181 loss: 0.064453125


 92%|█████████▏| 183/200 [41:28<03:50, 13.58s/it]

step 182 loss: 0.06201171875


 92%|█████████▏| 184/200 [41:41<03:37, 13.58s/it]

step 183 loss: 0.06005859375


 92%|█████████▎| 185/200 [41:55<03:23, 13.59s/it]

step 184 loss: 0.060546875


 93%|█████████▎| 186/200 [42:09<03:10, 13.64s/it]

step 185 loss: 0.06298828125


 94%|█████████▎| 187/200 [42:22<02:57, 13.65s/it]

step 186 loss: 0.06298828125


 94%|█████████▍| 188/200 [42:36<02:44, 13.72s/it]

step 187 loss: 0.061767578125


 94%|█████████▍| 189/200 [42:50<02:30, 13.68s/it]

step 188 loss: 0.061767578125


 95%|█████████▌| 190/200 [43:04<02:17, 13.74s/it]

step 189 loss: 0.060791015625


 96%|█████████▌| 191/200 [43:17<02:03, 13.74s/it]

step 190 loss: 0.060791015625


 96%|█████████▌| 192/200 [43:31<01:49, 13.69s/it]

step 191 loss: 0.06298828125


 96%|█████████▋| 193/200 [43:45<01:35, 13.69s/it]

step 192 loss: 0.06298828125


 97%|█████████▋| 194/200 [43:58<01:22, 13.67s/it]

step 193 loss: 0.06103515625


 98%|█████████▊| 195/200 [44:12<01:08, 13.66s/it]

step 194 loss: 0.06298828125


 98%|█████████▊| 196/200 [44:25<00:54, 13.62s/it]

step 195 loss: 0.0595703125


 98%|█████████▊| 197/200 [44:39<00:40, 13.61s/it]

step 196 loss: 0.06201171875


 99%|█████████▉| 198/200 [44:53<00:27, 13.62s/it]

step 197 loss: 0.06298828125


100%|█████████▉| 199/200 [45:06<00:13, 13.60s/it]

step 198 loss: 0.06103515625


100%|██████████| 200/200 [45:20<00:00, 13.60s/it]

step 199 loss: 0.060546875



