In [31]:
import csv
from dataclasses import dataclass
from typing import Callable

import chess
import einops
import numpy as np
import torch
import transformer_lens.utils as utils
from fancy_einsum import einsum
from tqdm import tqdm
from transformer_lens import HookedTransformer, HookedTransformerConfig

import chess_utils
# from mech_interp_othello_utils import OthelloBoardState

# Pick the best accelerator that is actually present instead of hard-coding one
# (the previous cell assigned "cuda" and then immediately overwrote it with "mps").
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Architecture of the pretrained chess model; these values must match the
# checkpoint loaded below or `load_state_dict` will fail on shape mismatches.
n_layers = 16
n_heads = 8
MODEL_DIR = "models/"
DATA_DIR = "data/"
cfg = HookedTransformerConfig(
    n_layers=n_layers,
    d_model=512,
    d_head=64,
    n_heads=n_heads,
    d_mlp=2048,
    d_vocab=32,
    n_ctx=1023,
    act_fn="gelu",
    normalization_type="LNPre",
)
model = HookedTransformer(cfg)
model_name = "tf_lens_16"

# `map_location` lets a checkpoint saved on one accelerator be loaded on another
# (or on CPU). NOTE(review): torch.load unpickles the file, so only load
# checkpoints from a trusted source.
model.load_state_dict(torch.load(f"{MODEL_DIR}{model_name}.pth", map_location=device))

# Last expression of the cell: moves the weights and displays the module tree.
model.to(device)

Moving model to device:  mps


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-15): 16 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [32]:
# Layer index kept as a notebook-level default.
# NOTE(review): none of the cells visible here read this global —
# `train_linear_probe` takes its own `layer` argument — confirm it is still needed.
layer = 12

@dataclass
class Config:
    """Settings describing one linear-probe target.

    Attributes:
        min_val: Smallest integer label the board-state function emits for a square.
        max_val: Largest integer label; together with ``min_val`` it fixes the
            number of one-hot classes (``max_val - min_val + 1``).
        custom_function: Callable handed to ``chess_utils.create_state_stack(s)``
            to turn a board into a grid of integer labels.
        linear_probe_name: Tag embedded in the file name of the saved probe.
    """

    min_val: int
    max_val: int
    # `Callable` (the type) rather than the builtin function `callable`, so the
    # annotation is meaningful to type checkers and readers.
    custom_function: Callable
    linear_probe_name: str

# One Config per probe target, built from a compact table of
# (min label, max label, board-state function, probe file tag).
piece_config, color_config, random_config = (
    Config(
        min_val=lowest,
        max_val=highest,
        custom_function=board_fn,
        linear_probe_name=probe_tag,
    )
    for lowest, highest, board_fn, probe_tag in (
        (-6, 6, chess_utils.board_to_piece_state, "chess_piece_probe"),
        (-1, 1, chess_utils.board_to_piece_color_state, "chess_color_probe"),
        (-1, 1, chess_utils.board_to_random_state, "chess_random_probe"),
    )
)

# Active probe target for the rest of the notebook.
# Swap in `piece_config` or `color_config` to train those probes instead.
config = random_config

In [33]:
# Games as integer tensors, one row per game.
board_seqs_int = torch.tensor(
    np.load(f"{DATA_DIR}train_board_seqs_int.npy")
).long()
print(board_seqs_int.shape)

# Per-game position indices, later used to select which sequence positions to probe.
dots_indices = torch.tensor(
    np.load(f"{DATA_DIR}train_dots_indices.npy")
).long()
print(dots_indices.shape)

# The same games as raw strings: the first column of each CSV row.
with open(f"{DATA_DIR}train_board_seqs_string.csv", newline='') as csvfile:
    board_seqs_string = [record[0] for record in csv.reader(csvfile, delimiter=',')]
print(len(board_seqs_string), len(board_seqs_string[0]))


torch.Size([32965, 680])
torch.Size([32965, 61])
32965 680


In [34]:
# Sanity check: build board-state label stacks for one game and for a small batch.
# print(board_seqs_string[0])
# custom_function = chess_utils.board_to_piece_state
# The board-state function comes from the active probe config; the training
# cell below reads this global as well.
custom_function = config.custom_function
# Shape of the raw (non-tensor) stack returned for the first game.
print(chess_utils.create_state_stack(board_seqs_string[0], custom_function).shape)

# Same game again, converted to an int64 tensor.
# NOTE(review): the helper is called a second time rather than reusing the result
# above; if `custom_function` is stochastic the two results may differ — confirm.
state_stack = torch.tensor(chess_utils.create_state_stack(board_seqs_string[0], custom_function)).long()
print(state_stack.shape)

# Batched variant over the first 50 games; the recorded output shows a leading
# axis of size 1 in front of the game axis.
state_stacks = chess_utils.create_state_stacks(board_seqs_string[:50], custom_function)
print(state_stacks.shape)

(680, 8, 8)
torch.Size([680, 8, 8])
torch.Size([1, 50, 680, 8, 8])


In [35]:
# --- Optimisation settings ---
batch_size = 1
lr = 2e-4
wd = 0.01
num_epochs = 1
num_games = 10000

# --- Probe geometry ---
# Number of leading entries of each game's dots_indices row to skip.
pos_start = 5
rows, cols = 8, 8
# One class per integer label in [min_val, max_val].
one_hot_range = (config.max_val - config.min_val) + 1
# Number of probe "modes" (leading axis of the probe tensor).
# NOTE(review): an earlier comment here described two modes (blank / next-or-prev);
# only one is used now — confirm the intended meaning.
modes = 1

x, y = 0, 2

# --- One-hot sanity check on the 50-game sample ---
state_stack_one_hot = chess_utils.state_stack_to_one_hot(
    modes, rows, cols, config.min_val, config.max_val, device, state_stacks
)
print(state_stack_one_hot.shape)

# Same board patch (game 1, position 170, rows 4+, cols 2-4) viewed as one-hot
# vectors and as raw labels, so the encoding can be compared by eye.
_patch = (slice(None), 1, 170, slice(4, 9), slice(2, 5))
print(state_stack_one_hot[_patch])
print(state_stacks[_patch])

torch.Size([1, 50, 680, 8, 8, 3])
tensor([[[[1, 0, 0],
          [0, 1, 0],
          [1, 0, 0]],

         [[0, 1, 0],
          [1, 0, 0],
          [0, 0, 1]],

         [[0, 1, 0],
          [0, 0, 1],
          [0, 1, 0]],

         [[0, 1, 0],
          [0, 1, 0],
          [1, 0, 0]]]], device='mps:0')
tensor([[[-1,  0, -1],
         [ 0, -1,  1],
         [ 0,  1,  0],
         [ 0,  0, -1]]])


In [36]:
def _linear_decay_lr(current_iter: int, max_iters: int, lr: float, min_lr: float) -> float:
    """Linearly interpolate from ``lr`` (iteration 0) down to ``min_lr`` (``max_iters``).

    Args:
        current_iter: Number of training examples seen so far.
        max_iters: Total number of examples over which to decay.
        lr: Learning rate at the start of training.
        min_lr: Learning rate once ``current_iter`` reaches ``max_iters``.

    Returns:
        The learning rate to use at ``current_iter``.
    """
    if max_iters <= 0:
        # Nothing to decay over; avoid a division by zero.
        return lr
    progress = min(current_iter, max_iters) / max_iters
    return lr - (lr - min_lr) * progress


def train_linear_probe(
    layer: int,
    max_lr: float = 3e-4,
    min_lr: "float | None" = None,
    decay_lr: bool = True,
):
    """Train a linear probe on the residual stream after block ``layer`` and save it.

    The probe is a tensor of shape ``[modes, d_model, rows, cols, one_hot_range]``.
    For each game, the residual stream and the board-state labels are both read at
    the positions listed in ``dots_indices`` (after skipping the first ``pos_start``
    entries), and the probe is trained with cross-entropy against the one-hot labels.

    Relies on notebook globals: ``model``, ``config``, ``custom_function``,
    ``board_seqs_int``, ``board_seqs_string``, ``dots_indices`` and the
    hyperparameters ``batch_size``, ``wd``, ``pos_start``, ``rows``, ``cols``,
    ``modes``, ``one_hot_range``, ``num_epochs``, ``num_games``.

    Args:
        layer: Index of the transformer block whose ``resid_post`` is probed.
        max_lr: Learning rate at the start of training.
        min_lr: Learning rate at the end of the linear decay; defaults to
            ``max_lr / 10``.
        decay_lr: If False, train at a constant ``max_lr``.

    Raises:
        ValueError: If ``layer`` is not a valid block index for ``model``, or if
            ``num_games`` exceeds the number of loaded games.
    """
    if not 0 <= layer < model.cfg.n_layers:
        raise ValueError(
            f"layer must be in [0, {model.cfg.n_layers}), got {layer}"
        )
    if num_games > len(board_seqs_string):
        raise ValueError(
            f"num_games ({num_games}) exceeds loaded games ({len(board_seqs_string)})"
        )
    if min_lr is None:
        min_lr = max_lr / 10

    linear_probe_name = f"{MODEL_DIR}{model_name}_{config.linear_probe_name}_layer_{layer}.pth"
    linear_probe = torch.randn(
        modes, model.cfg.d_model, rows, cols, one_hot_range, requires_grad=False, device=device
    ) / np.sqrt(model.cfg.d_model)
    linear_probe.requires_grad = True
    print(linear_probe.shape)

    # The learning rate lives in `current_lr`, never in a local called `lr`:
    # assigning to `lr` inside this function would make it function-local and the
    # optimiser construction would raise UnboundLocalError.
    current_lr = max_lr
    optimiser = torch.optim.AdamW(
        [linear_probe], lr=current_lr, betas=(0.9, 0.99), weight_decay=wd
    )

    print(dots_indices.shape)

    max_iters = num_games * num_epochs
    current_iter = 0
    for epoch in range(num_epochs):
        full_train_indices = torch.randperm(num_games)
        for i in tqdm(range(0, num_games, batch_size)):
            if decay_lr:
                current_lr = _linear_decay_lr(current_iter, max_iters, max_lr, min_lr)
            for param_group in optimiser.param_groups:
                param_group['lr'] = current_lr

            indices = full_train_indices[i:i + batch_size]
            games_int = board_seqs_int[indices]
            # Plain Python ints are needed to index the list of game strings.
            games_str = [board_seqs_string[idx] for idx in indices.tolist()]
            games_dots = dots_indices[indices][:, pos_start:]
            n_games_in_batch = games_dots.size(0)

            # Labels arrive as [modes, batch, pos, rows, cols]. Keep only the probed
            # positions of each game, iterating over the *batch* axis (dim 1) and
            # re-stacking on dim 1 so the layout is preserved for any batch size.
            state_stack = chess_utils.create_state_stacks(games_str, custom_function)
            state_stack = torch.stack(
                [state_stack[:, b, games_dots[b], :, :] for b in range(n_games_in_batch)],
                dim=1,
            )
            state_stack_one_hot = chess_utils.state_stack_to_one_hot(
                modes, rows, cols, config.min_val, config.max_val, device, state_stack
            ).to(device)

            # The model is frozen: only activations are needed, no autograd graph.
            with torch.inference_mode():
                _, cache = model.run_with_cache(games_int.to(device)[:, :-1], return_type=None)
                resid_post = cache["resid_post", layer]

            # Gather the same positions from the residual stream -> [batch, pos, d_model].
            # Done outside inference_mode so the result is a regular tensor that
            # autograd can save for the probe's backward pass.
            resid_post = torch.stack(
                [resid_post[b, games_dots[b]] for b in range(n_games_in_batch)]
            )

            probe_out = einsum(
                "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
                resid_post,
                linear_probe,
            )

            probe_log_probs = probe_out.log_softmax(-1)
            # Multiplying by the one-hot labels keeps only the correct class's
            # log-prob; the mean over `options` is undone by * one_hot_range, leaving
            # a mean over the batch only.
            probe_correct_log_probs = einops.reduce(
                probe_log_probs * state_stack_one_hot,
                "modes batch pos rows cols options -> modes pos rows cols",
                "mean"
            ) * one_hot_range
            # Mode 0 only: average over positions, sum over the board squares.
            loss = -probe_correct_log_probs[0, :].mean(0).sum()

            if i % 100 == 0:
                print(f"epoch {epoch}, batch {i}, loss_all {loss}, lr {current_lr}")

            loss.backward()
            optimiser.step()
            optimiser.zero_grad()
            current_iter += batch_size
    torch.save(linear_probe, linear_probe_name)

# Train and save one probe per transformer block. The bound comes from the model
# itself: the previous hard-coded range(32) ran past the model's 16 blocks, where
# no `resid_post` activation exists.
for i in range(model.cfg.n_layers):
    train_linear_probe(i)

torch.Size([1, 512, 8, 8, 3])
torch.Size([32965, 61])


  0%|          | 2/10000 [00:00<45:42,  3.65it/s]  

epoch 0, batch 0, loss_all 73.77289581298828, lr 0.0003


  1%|          | 103/10000 [00:10<15:12, 10.85it/s]

epoch 0, batch 100, loss_all 71.8547134399414, lr 0.00029729999999999996


  2%|▏         | 203/10000 [00:19<14:49, 11.01it/s]

epoch 0, batch 200, loss_all 71.87614440917969, lr 0.00029459999999999995


  3%|▎         | 303/10000 [00:28<14:40, 11.01it/s]

epoch 0, batch 300, loss_all 72.03749084472656, lr 0.0002919


  4%|▍         | 402/10000 [00:38<14:37, 10.94it/s]

epoch 0, batch 400, loss_all 71.44151306152344, lr 0.0002892


  5%|▌         | 502/10000 [00:47<14:24, 10.99it/s]

epoch 0, batch 500, loss_all 71.38539123535156, lr 0.00028649999999999997


  6%|▌         | 602/10000 [00:56<14:14, 11.00it/s]

epoch 0, batch 600, loss_all 71.24201965332031, lr 0.00028379999999999996


  7%|▋         | 702/10000 [01:05<14:10, 10.93it/s]

epoch 0, batch 700, loss_all 71.21056365966797, lr 0.00028109999999999995


  8%|▊         | 802/10000 [01:15<14:05, 10.88it/s]

epoch 0, batch 800, loss_all 71.10810852050781, lr 0.0002784


  9%|▉         | 902/10000 [01:24<14:02, 10.80it/s]

epoch 0, batch 900, loss_all 71.48045349121094, lr 0.0002757


 10%|█         | 1002/10000 [01:33<13:55, 10.77it/s]

epoch 0, batch 1000, loss_all 71.19061279296875, lr 0.00027299999999999997


 11%|█         | 1102/10000 [01:43<13:43, 10.80it/s]

epoch 0, batch 1100, loss_all 71.00178527832031, lr 0.00027029999999999996


 12%|█▏        | 1203/10000 [01:52<13:27, 10.89it/s]

epoch 0, batch 1200, loss_all 70.70330047607422, lr 0.0002676


 13%|█▎        | 1303/10000 [02:01<13:25, 10.79it/s]

epoch 0, batch 1300, loss_all 70.89091491699219, lr 0.0002649


 14%|█▍        | 1403/10000 [02:11<13:10, 10.87it/s]

epoch 0, batch 1400, loss_all 70.89710998535156, lr 0.0002622


 15%|█▌        | 1503/10000 [02:20<13:16, 10.67it/s]

epoch 0, batch 1500, loss_all 70.93946838378906, lr 0.00025949999999999997


 16%|█▌        | 1602/10000 [02:30<13:01, 10.74it/s]

epoch 0, batch 1600, loss_all 70.94033813476562, lr 0.00025679999999999995


 17%|█▋        | 1702/10000 [02:39<13:16, 10.42it/s]

epoch 0, batch 1700, loss_all 70.89836120605469, lr 0.0002541


 18%|█▊        | 1802/10000 [02:48<12:28, 10.95it/s]

epoch 0, batch 1800, loss_all 70.942138671875, lr 0.0002514


 19%|█▉        | 1902/10000 [02:58<12:23, 10.90it/s]

epoch 0, batch 1900, loss_all 70.61750030517578, lr 0.0002487


 20%|██        | 2002/10000 [03:07<12:45, 10.45it/s]

epoch 0, batch 2000, loss_all 70.44158935546875, lr 0.00024599999999999996


 21%|██        | 2102/10000 [03:16<12:09, 10.82it/s]

epoch 0, batch 2100, loss_all 70.44566345214844, lr 0.00024329999999999998


 22%|██▏       | 2202/10000 [03:26<12:11, 10.66it/s]

epoch 0, batch 2200, loss_all 70.78483581542969, lr 0.0002406


 23%|██▎       | 2302/10000 [03:35<11:42, 10.96it/s]

epoch 0, batch 2300, loss_all 70.8141860961914, lr 0.00023789999999999998


 24%|██▍       | 2402/10000 [03:44<11:55, 10.62it/s]

epoch 0, batch 2400, loss_all 70.66683959960938, lr 0.00023519999999999997


 25%|██▌       | 2502/10000 [03:54<11:33, 10.80it/s]

epoch 0, batch 2500, loss_all 70.64362335205078, lr 0.00023249999999999999


 26%|██▌       | 2602/10000 [04:03<12:30,  9.86it/s]

epoch 0, batch 2600, loss_all 70.83964538574219, lr 0.0002298


 27%|██▋       | 2702/10000 [04:12<11:17, 10.77it/s]

epoch 0, batch 2700, loss_all 70.66072845458984, lr 0.0002271


 28%|██▊       | 2802/10000 [04:22<11:11, 10.72it/s]

epoch 0, batch 2800, loss_all 70.70637512207031, lr 0.00022439999999999998


 29%|██▉       | 2902/10000 [04:31<10:54, 10.85it/s]

epoch 0, batch 2900, loss_all 70.69142150878906, lr 0.0002217


 30%|███       | 3003/10000 [04:41<10:50, 10.75it/s]

epoch 0, batch 3000, loss_all 70.72978973388672, lr 0.000219


 31%|███       | 3103/10000 [04:50<10:34, 10.87it/s]

epoch 0, batch 3100, loss_all 70.66769409179688, lr 0.0002163


 32%|███▏      | 3202/10000 [04:59<10:24, 10.88it/s]

epoch 0, batch 3200, loss_all 70.72938537597656, lr 0.0002136


 33%|███▎      | 3302/10000 [05:09<10:18, 10.83it/s]

epoch 0, batch 3300, loss_all 70.59666442871094, lr 0.00021089999999999998


 34%|███▍      | 3402/10000 [05:18<10:02, 10.94it/s]

epoch 0, batch 3400, loss_all 70.79804992675781, lr 0.00020819999999999996


 35%|███▌      | 3502/10000 [05:27<10:08, 10.67it/s]

epoch 0, batch 3500, loss_all 70.51947021484375, lr 0.0002055


 36%|███▌      | 3602/10000 [05:37<10:06, 10.54it/s]

epoch 0, batch 3600, loss_all 70.80744171142578, lr 0.0002028


 37%|███▋      | 3702/10000 [05:46<09:38, 10.88it/s]

epoch 0, batch 3700, loss_all 70.54615783691406, lr 0.00020009999999999998


 38%|███▊      | 3802/10000 [05:55<09:37, 10.74it/s]

epoch 0, batch 3800, loss_all 70.47940063476562, lr 0.00019739999999999997


 39%|███▉      | 3902/10000 [06:05<10:13,  9.94it/s]

epoch 0, batch 3900, loss_all 70.80748748779297, lr 0.0001947


 40%|████      | 4002/10000 [06:14<09:08, 10.93it/s]

epoch 0, batch 4000, loss_all 70.54994201660156, lr 0.000192


 41%|████      | 4103/10000 [06:24<08:57, 10.97it/s]

epoch 0, batch 4100, loss_all 70.75016784667969, lr 0.0001893


 42%|████▏     | 4203/10000 [06:33<09:06, 10.60it/s]

epoch 0, batch 4200, loss_all 70.47393035888672, lr 0.00018659999999999998


 43%|████▎     | 4303/10000 [06:42<08:45, 10.84it/s]

epoch 0, batch 4300, loss_all 70.31485748291016, lr 0.0001839


 44%|████▍     | 4403/10000 [06:52<08:33, 10.89it/s]

epoch 0, batch 4400, loss_all 70.7564697265625, lr 0.0001812


 45%|████▌     | 4503/10000 [07:01<08:26, 10.85it/s]

epoch 0, batch 4500, loss_all 70.44734191894531, lr 0.0001785


 46%|████▌     | 4603/10000 [07:10<09:00,  9.98it/s]

epoch 0, batch 4600, loss_all 70.51951599121094, lr 0.0001758


 47%|████▋     | 4703/10000 [07:20<08:07, 10.87it/s]

epoch 0, batch 4700, loss_all 70.51974487304688, lr 0.0001731


 48%|████▊     | 4803/10000 [07:29<07:58, 10.85it/s]

epoch 0, batch 4800, loss_all 70.6993179321289, lr 0.0001704


 49%|████▉     | 4903/10000 [07:39<07:45, 10.95it/s]

epoch 0, batch 4900, loss_all 70.6276626586914, lr 0.0001677


 50%|█████     | 5003/10000 [07:48<07:44, 10.76it/s]

epoch 0, batch 5000, loss_all 70.42587280273438, lr 0.000165


 51%|█████     | 5103/10000 [07:57<07:28, 10.91it/s]

epoch 0, batch 5100, loss_all 70.45616912841797, lr 0.0001623


 52%|█████▏    | 5203/10000 [08:06<07:22, 10.83it/s]

epoch 0, batch 5200, loss_all 70.39408874511719, lr 0.0001596


 53%|█████▎    | 5303/10000 [08:16<07:19, 10.70it/s]

epoch 0, batch 5300, loss_all 70.48957061767578, lr 0.0001569


 54%|█████▍    | 5403/10000 [08:25<07:13, 10.60it/s]

epoch 0, batch 5400, loss_all 70.44171905517578, lr 0.00015419999999999998


 55%|█████▌    | 5503/10000 [08:35<06:59, 10.72it/s]

epoch 0, batch 5500, loss_all 70.6602554321289, lr 0.0001515


 56%|█████▌    | 5603/10000 [08:44<06:58, 10.50it/s]

epoch 0, batch 5600, loss_all 70.41004180908203, lr 0.00014879999999999998


 57%|█████▋    | 5703/10000 [08:54<06:40, 10.72it/s]

epoch 0, batch 5700, loss_all 70.5382080078125, lr 0.00014610000000000003


 58%|█████▊    | 5803/10000 [09:03<06:59, 10.01it/s]

epoch 0, batch 5800, loss_all 70.54951477050781, lr 0.00014340000000000002


 59%|█████▉    | 5903/10000 [09:12<06:15, 10.92it/s]

epoch 0, batch 5900, loss_all 70.65604400634766, lr 0.0001407


 60%|██████    | 6003/10000 [09:22<06:08, 10.86it/s]

epoch 0, batch 6000, loss_all 70.52391052246094, lr 0.00013800000000000002


 61%|██████    | 6103/10000 [09:31<06:11, 10.50it/s]

epoch 0, batch 6100, loss_all 70.41622924804688, lr 0.0001353


 62%|██████▏   | 6203/10000 [09:40<05:50, 10.82it/s]

epoch 0, batch 6200, loss_all 70.54849243164062, lr 0.0001326


 63%|██████▎   | 6303/10000 [09:50<05:47, 10.63it/s]

epoch 0, batch 6300, loss_all 70.39403533935547, lr 0.0001299


 64%|██████▍   | 6402/10000 [09:59<05:34, 10.77it/s]

epoch 0, batch 6400, loss_all 70.47978210449219, lr 0.0001272


 65%|██████▌   | 6503/10000 [10:09<05:24, 10.78it/s]

epoch 0, batch 6500, loss_all 70.44921875, lr 0.0001245


 66%|██████▌   | 6603/10000 [10:18<05:12, 10.86it/s]

epoch 0, batch 6600, loss_all 70.39442443847656, lr 0.0001218


 67%|██████▋   | 6703/10000 [10:27<05:34,  9.85it/s]

epoch 0, batch 6700, loss_all 70.53121948242188, lr 0.0001191


 68%|██████▊   | 6803/10000 [10:37<04:55, 10.84it/s]

epoch 0, batch 6800, loss_all 70.478271484375, lr 0.00011639999999999998


 69%|██████▉   | 6903/10000 [10:46<04:47, 10.78it/s]

epoch 0, batch 6900, loss_all 70.48043823242188, lr 0.00011370000000000003


 70%|███████   | 7003/10000 [10:56<04:37, 10.81it/s]

epoch 0, batch 7000, loss_all 70.45330810546875, lr 0.00011100000000000001


 71%|███████   | 7103/10000 [11:05<04:31, 10.67it/s]

epoch 0, batch 7100, loss_all 70.39140319824219, lr 0.00010830000000000003


 72%|███████▏  | 7203/10000 [11:14<04:21, 10.68it/s]

epoch 0, batch 7200, loss_all 70.46878051757812, lr 0.00010560000000000002


 73%|███████▎  | 7303/10000 [11:24<04:11, 10.74it/s]

epoch 0, batch 7300, loss_all 70.51050567626953, lr 0.00010290000000000001


 74%|███████▍  | 7403/10000 [11:33<04:01, 10.75it/s]

epoch 0, batch 7400, loss_all 70.65391540527344, lr 0.00010020000000000002


 75%|███████▌  | 7503/10000 [11:43<03:53, 10.68it/s]

epoch 0, batch 7500, loss_all 70.51913452148438, lr 9.750000000000001e-05


 76%|███████▌  | 7603/10000 [11:52<03:43, 10.72it/s]

epoch 0, batch 7600, loss_all 70.48078155517578, lr 9.48e-05


 77%|███████▋  | 7703/10000 [12:01<03:36, 10.62it/s]

epoch 0, batch 7700, loss_all 70.4695816040039, lr 9.210000000000002e-05


 78%|███████▊  | 7803/10000 [12:11<03:28, 10.54it/s]

epoch 0, batch 7800, loss_all 70.4023666381836, lr 8.94e-05


 79%|███████▉  | 7903/10000 [12:20<03:14, 10.81it/s]

epoch 0, batch 7900, loss_all 70.5257797241211, lr 8.669999999999999e-05


 80%|████████  | 8003/10000 [12:30<03:13, 10.31it/s]

epoch 0, batch 8000, loss_all 70.53546142578125, lr 8.400000000000001e-05


 81%|████████  | 8101/10000 [12:39<02:54, 10.86it/s]

epoch 0, batch 8100, loss_all 70.41712951660156, lr 8.13e-05


 82%|████████▏ | 8203/10000 [12:48<02:46, 10.77it/s]

epoch 0, batch 8200, loss_all 70.41244506835938, lr 7.860000000000004e-05


 83%|████████▎ | 8303/10000 [12:58<02:36, 10.82it/s]

epoch 0, batch 8300, loss_all 70.4600830078125, lr 7.590000000000003e-05


 84%|████████▍ | 8403/10000 [13:07<02:28, 10.76it/s]

epoch 0, batch 8400, loss_all 70.47145080566406, lr 7.320000000000002e-05


 85%|████████▌ | 8503/10000 [13:17<02:18, 10.81it/s]

epoch 0, batch 8500, loss_all 70.52392578125, lr 7.050000000000003e-05


 86%|████████▌ | 8603/10000 [13:26<02:08, 10.87it/s]

epoch 0, batch 8600, loss_all 70.33487701416016, lr 6.780000000000002e-05


 87%|████████▋ | 8702/10000 [13:35<01:59, 10.82it/s]

epoch 0, batch 8700, loss_all 70.50454711914062, lr 6.510000000000001e-05


 88%|████████▊ | 8802/10000 [13:45<01:51, 10.76it/s]

epoch 0, batch 8800, loss_all 70.4878158569336, lr 6.240000000000003e-05


 89%|████████▉ | 8902/10000 [13:54<01:40, 10.87it/s]

epoch 0, batch 8900, loss_all 70.41007995605469, lr 5.9700000000000015e-05


 90%|█████████ | 9002/10000 [14:03<01:34, 10.52it/s]

epoch 0, batch 9000, loss_all 70.52984619140625, lr 5.7e-05


 91%|█████████ | 9102/10000 [14:13<01:23, 10.82it/s]

epoch 0, batch 9100, loss_all 70.40798950195312, lr 5.430000000000002e-05


 92%|█████████▏| 9202/10000 [14:22<01:13, 10.78it/s]

epoch 0, batch 9200, loss_all 70.38543701171875, lr 5.160000000000001e-05


 93%|█████████▎| 9302/10000 [14:32<01:04, 10.78it/s]

epoch 0, batch 9300, loss_all 70.48184967041016, lr 4.8899999999999996e-05


 94%|█████████▍| 9402/10000 [14:41<00:55, 10.82it/s]

epoch 0, batch 9400, loss_all 70.41226196289062, lr 4.620000000000004e-05


 95%|█████████▌| 9502/10000 [14:50<00:46, 10.77it/s]

epoch 0, batch 9500, loss_all 70.48625946044922, lr 4.350000000000003e-05


 96%|█████████▌| 9602/10000 [15:00<00:37, 10.75it/s]

epoch 0, batch 9600, loss_all 70.3904037475586, lr 4.0800000000000016e-05


 97%|█████████▋| 9702/10000 [15:09<00:28, 10.42it/s]

epoch 0, batch 9700, loss_all 70.41682434082031, lr 3.8100000000000005e-05


 98%|█████████▊| 9802/10000 [15:18<00:18, 10.77it/s]

epoch 0, batch 9800, loss_all 70.40251922607422, lr 3.540000000000005e-05


 99%|█████████▉| 9902/10000 [15:28<00:09, 10.19it/s]

epoch 0, batch 9900, loss_all 70.46641540527344, lr 3.2700000000000036e-05


100%|██████████| 10000/10000 [15:37<00:00, 10.67it/s]
