# Importing libraries

In [1]:
import os
import sys
from dotenv import load_dotenv
from typing import Tuple
import torch
from torch.utils.data import Dataset, DataLoader
import wandb
from src.utils import set_seed, load_text, split_text, speedometer
from src.config import ModelConfig, TrainConfig, GenerationConfig
from src.tokenizer import CharTokenizer
from models.GPT import GPT
from src.train import Trainer



In [2]:
PROJECT_ROOT = os.path.abspath(os.getcwd())
sys.path.append(PROJECT_ROOT)
print(f"PROJECT_ROOT: {PROJECT_ROOT}")

PROJECT_ROOT: /home/pathfinder/projects/PathFinder


# Configuration

In [3]:
model_config = ModelConfig(
    vocab_size=-1,
    max_seq_len=128,
    d_embed=128,
    n_layers=4,
    n_heads=4,
    d_head=32,
    d_ff=512,
    rank=32
)

train_config = TrainConfig(
    debug=False,
    wandb_project="nanoGPT",
    model_name="nanoGPT",
    per_device_train_batch_size=512,
    per_device_eval_batch_size=1024,
    gradient_accumulation_steps=512 // 512,
    num_train_epochs=1,
    learning_rate=5e-3,
    eval_steps=100,
    mixed_precision=True,
    matmul_precision="high",
)

generation_config = GenerationConfig(
    use_cache=True,
    max_new_tokens=1000,
    temperature=1.0,
    top_k=50
)

In [4]:
load_dotenv()
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/pathfinder/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpathfinderkr[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Utils

## Reproducibility

In [5]:
set_seed(train_config.seed)

Random seed set to 42


## Device

In [6]:
device = torch.device("cuda")
print(f"Device: {torch.cuda.get_device_name(device)}")
torch.set_float32_matmul_precision(train_config.matmul_precision)  # Tensor Cores
print(f"MatMul Precision: {train_config.matmul_precision}")

Device: NVIDIA GeForce RTX 4080 SUPER
MatMul Precision: high


# Dataset

In [7]:
dataset_path = os.path.join(PROJECT_ROOT, "datasets/Shakespeare/shakespeare.txt")
shakespeare_text = load_text(dataset_path)

Loaded text data from /home/pathfinder/projects/PathFinder/datasets/Shakespeare/shakespeare.txt (length: 1115394 characters).


In [8]:
if train_config.debug:
    subset_shakespeare_text = shakespeare_text[:10000]
    print(subset_shakespeare_text)
    shakespeare_text = subset_shakespeare_text

# Tokenizer

In [9]:
char_tokenizer = CharTokenizer()
char_tokenizer.build_vocab(text=shakespeare_text)
vocab_path = os.path.join(PROJECT_ROOT, "datasets/Shakespeare/vocab.json")
char_tokenizer.save_vocab(vocab_path)
model_config.vocab_size = char_tokenizer.vocab_size

Vocabulary size: 69
Vocabulary saved to /home/pathfinder/projects/PathFinder/datasets/Shakespeare/vocab.json.


In [10]:
if train_config.debug:
    print("Vocabulary:", char_tokenizer.char2idx)

# Preprocessing

In [11]:
train_text, val_text = split_text(shakespeare_text, val_size=0.1)
print(f"Training text length: {len(train_text)} characters")
print(f"Validation text length: {len(val_text)} characters")

Training text length: 1003854 characters
Validation text length: 111540 characters


In [12]:
class TextDataset(Dataset):
    def __init__(self, text: str, tokenizer: CharTokenizer, max_seq_len: int):
        self.encoded = tokenizer.encode(text)
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.encoded) - self.max_seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        input_ids = self.encoded[idx:idx + self.max_seq_len]
        target_ids = self.encoded[idx + 1:idx + self.max_seq_len + 1]
        return input_ids, target_ids

def collate_fn(batch):
    input_ids = torch.stack([item[0] for item in batch])
    target_ids = torch.stack([item[1] for item in batch])
    return {"input_ids": input_ids, "target_ids": target_ids}

train_dataset = TextDataset(train_text, char_tokenizer, model_config.max_seq_len)
val_dataset = TextDataset(val_text, char_tokenizer, model_config.max_seq_len)

train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=True,
    num_workers=4
)
val_loader = DataLoader(
    val_dataset,
    collate_fn=collate_fn,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=False,
    num_workers=4
)

In [13]:
if train_config.debug:
    sample_batch = next(iter(train_loader))
    print(f"Sample input IDs: {sample_batch['input_ids'][0]}")
    print(f"Sample target IDs: {sample_batch['target_ids'][0]}")

# Model

In [14]:
# Initialize the model
model = GPT(model_config).to(device)
model = torch.compile(model)
print(model)
print(f"Number of parameters: {model.num_params() / 1e6:.2f}M")

OptimizedModule(
  (_orig_mod): GPT(
    (token_embedding): Embedding(69, 128)
    (positional_encoding): Embedding(128, 128)
    (dropout): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0-3): 4 x Block(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (Wq): Linear(in_features=128, out_features=128, bias=False)
          (Wkv_down): Linear(in_features=128, out_features=32, bias=False)
          (Wkv_up): Linear(in_features=32, out_features=256, bias=False)
          (Wk_up): Linear(in_features=32, out_features=128, bias=False)
          (Wv_up): Linear(in_features=32, out_features=128, bias=False)
          (out_proj): Linear(in_features=128, out_features=128, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): FeedForward(
          (fc1): Linear(in_features=128, out_features=512, bias=False)
   

# Training

In [15]:
trainer = Trainer(
    model=model,
    train_config=train_config,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    master_process=True
)
trainer.train()

Training:   0%|          | 0/981 [00:00<?, ?it/s]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])



class GraphModule(torch.nn.Module):
    def forward(self, L_kv_latent_: "bf16[1024, 128, 32][4096, 32, 1]cuda:0"):
        l_kv_latent_ = L_kv_latent_
        


k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:   0%|          | 4/981 [00:00<-1:59:52, -109.76it/s, epoch=1, grad_norm=1.5479, loss=4.6219, lr=0.000255]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   1%|          | 9/981 [00:00<00:12, 78.71it/s, epoch=1, grad_norm=1.2814, loss=3.8585, lr=0.000510]     

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   1%|▏         | 14/981 [00:00<00:12, 78.71it/s, epoch=1, grad_norm=0.5958, loss=3.3824, lr=0.000765]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   2%|▏         | 19/981 [00:00<00:31, 30.68it/s, epoch=1, grad_norm=0.3623, loss=3.3206, lr=0.001020]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   3%|▎         | 25/981 [00:00<00:38, 24.81it/s, epoch=1, grad_norm=0.2705, loss=3.2964, lr=0.001276]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   3%|▎         | 29/981 [00:01<00:39, 23.96it/s, epoch=1, grad_norm=0.6198, loss=3.1000, lr=0.001531]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   3%|▎         | 34/981 [00:01<00:41, 22.83it/s, epoch=1, grad_norm=0.4629, loss=2.8847, lr=0.001786]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   4%|▍         | 40/981 [00:01<00:42, 22.03it/s, epoch=1, grad_norm=0.2900, loss=2.7456, lr=0.002041]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   4%|▍         | 44/981 [00:01<00:43, 21.75it/s, epoch=1, grad_norm=0.2370, loss=2.6592, lr=0.002296]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   5%|▍         | 49/981 [00:02<00:43, 21.29it/s, epoch=1, grad_norm=0.1843, loss=2.6093, lr=0.002551]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   6%|▌         | 55/981 [00:02<00:43, 21.36it/s, epoch=1, grad_norm=0.1607, loss=2.5702, lr=0.002806]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   6%|▌         | 59/981 [00:02<00:43, 21.24it/s, epoch=1, grad_norm=0.1507, loss=2.5479, lr=0.003061]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   7%|▋         | 64/981 [00:02<00:43, 21.14it/s, epoch=1, grad_norm=0.1730, loss=2.5217, lr=0.003316]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   7%|▋         | 70/981 [00:03<00:43, 21.01it/s, epoch=1, grad_norm=0.8253, loss=2.5381, lr=0.003571]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   8%|▊         | 74/981 [00:03<00:42, 21.14it/s, epoch=1, grad_norm=0.2115, loss=2.4933, lr=0.003827]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   8%|▊         | 79/981 [00:03<00:42, 21.18it/s, epoch=1, grad_norm=0.1892, loss=2.4843, lr=0.004082]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   9%|▊         | 85/981 [00:03<00:42, 21.02it/s, epoch=1, grad_norm=0.1748, loss=2.4743, lr=0.004337]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:   9%|▉         | 89/981 [00:03<00:42, 20.89it/s, epoch=1, grad_norm=0.1311, loss=2.4594, lr=0.004592]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  10%|▉         | 94/981 [00:04<00:42, 20.94it/s, epoch=1, grad_norm=0.5112, loss=2.4725, lr=0.004847]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  10%|█         | 100/981 [00:04<00:42, 20.97it/s, epoch=1, grad_norm=0.5053, loss=2.4816, lr=0.005000]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  11%|█         | 104/981 [00:08<05:35,  2.61it/s, epoch=1, grad_norm=0.2290, loss=2.4447, lr=0.004999]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  11%|█         | 109/981 [00:08<03:05,  4.70it/s, epoch=1, grad_norm=0.1761, loss=2.4317, lr=0.004998]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  12%|█▏        | 113/981 [00:08<02:21,  6.12it/s, epoch=1, grad_norm=0.1834, loss=2.4050, lr=0.004996]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  12%|█▏        | 118/981 [00:08<01:36,  8.93it/s, epoch=1, grad_norm=0.6351, loss=2.4383, lr=0.004993]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  13%|█▎        | 123/981 [00:08<01:10, 12.21it/s, epoch=1, grad_norm=0.3058, loss=2.4330, lr=0.004989]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  13%|█▎        | 128/981 [00:09<00:53, 15.85it/s, epoch=1, grad_norm=0.2305, loss=2.3932, lr=0.004985]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  14%|█▎        | 134/981 [00:09<00:46, 18.25it/s, epoch=1, grad_norm=0.6577, loss=2.4068, lr=0.004980]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  14%|█▍        | 138/981 [00:09<00:44, 19.07it/s, epoch=1, grad_norm=0.2359, loss=2.3609, lr=0.004973]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  15%|█▍        | 143/981 [00:09<00:41, 19.99it/s, epoch=1, grad_norm=0.8451, loss=2.3525, lr=0.004967]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  15%|█▌        | 149/981 [00:10<00:41, 20.01it/s, epoch=1, grad_norm=0.3249, loss=2.3206, lr=0.004959]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  15%|█▌        | 152/981 [00:10<00:41, 19.82it/s, epoch=1, grad_norm=0.3181, loss=2.3078, lr=0.004952]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  16%|█▌        | 156/981 [00:10<00:41, 19.94it/s, epoch=1, grad_norm=0.2909, loss=2.2807, lr=0.004945]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  16%|█▋        | 161/981 [00:10<00:40, 20.40it/s, epoch=1, grad_norm=0.3739, loss=2.2461, lr=0.004935]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  17%|█▋        | 167/981 [00:11<00:39, 20.48it/s, epoch=1, grad_norm=0.2931, loss=2.2105, lr=0.004925]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  17%|█▋        | 171/981 [00:11<00:39, 20.44it/s, epoch=1, grad_norm=0.3343, loss=2.1858, lr=0.004914]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  18%|█▊        | 176/981 [00:11<00:38, 20.70it/s, epoch=1, grad_norm=0.1984, loss=2.1483, lr=0.004902]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  19%|█▊        | 182/981 [00:11<00:38, 20.87it/s, epoch=1, grad_norm=0.7630, loss=2.1803, lr=0.004889]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  19%|█▉        | 186/981 [00:12<00:37, 21.00it/s, epoch=1, grad_norm=0.3547, loss=2.1177, lr=0.004876]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  19%|█▉        | 191/981 [00:12<00:37, 20.81it/s, epoch=1, grad_norm=0.3139, loss=2.0955, lr=0.004861]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  20%|██        | 197/981 [00:12<00:37, 20.84it/s, epoch=1, grad_norm=0.3863, loss=2.0727, lr=0.004847]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  20%|██        | 200/981 [00:12<00:37, 20.93it/s, epoch=1, grad_norm=0.2519, loss=2.0558, lr=0.004837]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  20%|██        | 200/981 [00:14<00:37, 20.93it/s, epoch=1, grad_norm=0.1827, loss=2.0441, lr=0.004834]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  21%|██        | 206/981 [00:15<02:30,  5.14it/s, epoch=1, grad_norm=0.4371, loss=2.0291, lr=0.004818]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  21%|██▏       | 209/981 [00:15<02:05,  6.13it/s, epoch=1, grad_norm=0.7778, loss=2.0383, lr=0.004804]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  22%|██▏       | 213/981 [00:15<01:23,  9.15it/s, epoch=1, grad_norm=0.4842, loss=2.0172, lr=0.004794]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  22%|██▏       | 217/981 [00:15<01:03, 12.04it/s, epoch=1, grad_norm=0.3713, loss=1.9802, lr=0.004776]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  23%|██▎       | 221/981 [00:15<00:51, 14.69it/s, epoch=1, grad_norm=0.2413, loss=1.9607, lr=0.004764]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  23%|██▎       | 225/981 [00:16<00:45, 16.59it/s, epoch=1, grad_norm=0.3367, loss=1.9466, lr=0.004749]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  23%|██▎       | 230/981 [00:16<00:41, 18.26it/s, epoch=1, grad_norm=0.2785, loss=1.9209, lr=0.004729]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  24%|██▎       | 232/981 [00:16<00:40, 18.35it/s, epoch=1, grad_norm=0.2423, loss=1.9072, lr=0.004717]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  24%|██▍       | 236/981 [00:16<00:40, 18.54it/s, epoch=1, grad_norm=0.2490, loss=1.8895, lr=0.004700]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  24%|██▍       | 240/981 [00:16<00:40, 18.45it/s, epoch=1, grad_norm=0.8417, loss=1.8965, lr=0.004683]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  25%|██▌       | 246/981 [00:17<00:38, 18.95it/s, epoch=1, grad_norm=0.3224, loss=1.8685, lr=0.004661]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  25%|██▌       | 248/981 [00:17<00:39, 18.64it/s, epoch=1, grad_norm=0.2551, loss=1.8528, lr=0.004648]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  26%|██▌       | 252/981 [00:17<00:38, 18.71it/s, epoch=1, grad_norm=0.2563, loss=1.8439, lr=0.004629]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  26%|██▌       | 256/981 [00:17<00:39, 18.54it/s, epoch=1, grad_norm=0.2049, loss=1.8279, lr=0.004611]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  27%|██▋       | 260/981 [00:17<00:39, 18.34it/s, epoch=1, grad_norm=0.3193, loss=1.7972, lr=0.004591]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  27%|██▋       | 264/981 [00:18<00:40, 17.78it/s, epoch=1, grad_norm=0.2480, loss=1.7882, lr=0.004572]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  27%|██▋       | 268/981 [00:18<00:39, 18.15it/s, epoch=1, grad_norm=0.2053, loss=1.7701, lr=0.004551]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  28%|██▊       | 272/981 [00:18<00:39, 18.12it/s, epoch=1, grad_norm=0.2557, loss=1.7642, lr=0.004531]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  28%|██▊       | 276/981 [00:18<00:38, 18.46it/s, epoch=1, grad_norm=0.3570, loss=1.7532, lr=0.004510]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  29%|██▊       | 280/981 [00:19<00:38, 18.41it/s, epoch=1, grad_norm=0.3067, loss=1.7417, lr=0.004489]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  29%|██▉       | 284/981 [00:19<00:38, 18.30it/s, epoch=1, grad_norm=0.2851, loss=1.7387, lr=0.004467]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  29%|██▉       | 288/981 [00:19<00:38, 18.09it/s, epoch=1, grad_norm=0.2176, loss=1.7182, lr=0.004445]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  30%|██▉       | 292/981 [00:19<00:37, 18.20it/s, epoch=1, grad_norm=0.1950, loss=1.7119, lr=0.004422]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  30%|███       | 296/981 [00:19<00:37, 18.48it/s, epoch=1, grad_norm=0.2610, loss=1.7024, lr=0.004399]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  31%|███       | 300/981 [00:20<00:36, 18.62it/s, epoch=1, grad_norm=0.2629, loss=1.6922, lr=0.004382]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  31%|███       | 302/981 [00:22<04:02,  2.80it/s, epoch=1, grad_norm=0.2063, loss=1.6805, lr=0.004364]

k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size(

Training:  31%|███       | 306/981 [00:22<02:16,  4.94it/s, epoch=1, grad_norm=0.3049, loss=1.6641, lr=0.004340]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  32%|███▏      | 310/981 [00:22<01:24,  7.92it/s, epoch=1, grad_norm=0.2256, loss=1.6716, lr=0.004316]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  32%|███▏      | 314/981 [00:22<00:59, 11.13it/s, epoch=1, grad_norm=0.1747, loss=1.6417, lr=0.004291]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  32%|███▏      | 318/981 [00:23<00:46, 14.15it/s, epoch=1, grad_norm=0.3372, loss=1.6596, lr=0.004266]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  33%|███▎      | 322/981 [00:23<00:40, 16.33it/s, epoch=1, grad_norm=0.2842, loss=1.6487, lr=0.004241]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  33%|███▎      | 327/981 [00:23<00:37, 17.49it/s, epoch=1, grad_norm=0.2027, loss=1.6214, lr=0.004209]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  34%|███▎      | 331/981 [00:23<00:35, 18.44it/s, epoch=1, grad_norm=0.2227, loss=1.6174, lr=0.004189]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  34%|███▍      | 335/981 [00:23<00:34, 18.93it/s, epoch=1, grad_norm=0.2041, loss=1.6168, lr=0.004163]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  35%|███▍      | 339/981 [00:24<00:33, 19.00it/s, epoch=1, grad_norm=0.1748, loss=1.5986, lr=0.004136]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  35%|███▍      | 343/981 [00:24<00:33, 18.97it/s, epoch=1, grad_norm=0.1996, loss=1.6078, lr=0.004109]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  35%|███▌      | 347/981 [00:24<00:33, 18.76it/s, epoch=1, grad_norm=0.3179, loss=1.6020, lr=0.004081]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  36%|███▌      | 351/981 [00:24<00:33, 18.76it/s, epoch=1, grad_norm=0.2046, loss=1.6045, lr=0.004054]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  36%|███▌      | 355/981 [00:25<00:32, 19.13it/s, epoch=1, grad_norm=0.2301, loss=1.5867, lr=0.004019]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  37%|███▋      | 359/981 [00:25<00:31, 19.55it/s, epoch=1, grad_norm=0.2226, loss=1.5690, lr=0.003990]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  37%|███▋      | 364/981 [00:25<00:30, 20.15it/s, epoch=1, grad_norm=0.2117, loss=1.5733, lr=0.003954]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  38%|███▊      | 370/981 [00:25<00:29, 20.58it/s, epoch=1, grad_norm=0.2929, loss=1.5713, lr=0.003918]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  38%|███▊      | 374/981 [00:25<00:29, 20.76it/s, epoch=1, grad_norm=0.2186, loss=1.5535, lr=0.003881]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  39%|███▊      | 379/981 [00:26<00:28, 20.98it/s, epoch=1, grad_norm=0.2628, loss=1.5556, lr=0.003844]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  39%|███▉      | 385/981 [00:26<00:28, 21.10it/s, epoch=1, grad_norm=0.2195, loss=1.5451, lr=0.003806]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  40%|███▉      | 389/981 [00:26<00:27, 21.15it/s, epoch=1, grad_norm=0.2376, loss=1.5365, lr=0.003768]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  40%|████      | 394/981 [00:26<00:27, 21.16it/s, epoch=1, grad_norm=0.2068, loss=1.5349, lr=0.003729]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  41%|████      | 400/981 [00:27<00:27, 20.98it/s, epoch=1, grad_norm=0.1984, loss=1.5268, lr=0.003690]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  41%|████      | 403/981 [00:29<02:20,  4.13it/s, epoch=1, grad_norm=0.1889, loss=1.5273, lr=0.003659]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  42%|████▏     | 409/981 [00:29<01:21,  6.99it/s, epoch=1, grad_norm=0.1898, loss=1.5143, lr=0.003619]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  42%|████▏     | 413/981 [00:29<01:04,  8.76it/s, epoch=1, grad_norm=0.1711, loss=1.5110, lr=0.003580]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  43%|████▎     | 418/981 [00:29<00:45, 12.50it/s, epoch=1, grad_norm=0.1927, loss=1.5014, lr=0.003539]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  43%|████▎     | 424/981 [00:30<00:35, 15.75it/s, epoch=1, grad_norm=0.3138, loss=1.5085, lr=0.003499]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  44%|████▎     | 428/981 [00:30<00:32, 16.99it/s, epoch=1, grad_norm=0.1900, loss=1.4950, lr=0.003458]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  44%|████▍     | 433/981 [00:30<00:29, 18.69it/s, epoch=1, grad_norm=0.1964, loss=1.4951, lr=0.003417]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  45%|████▍     | 438/981 [00:28<00:27, 19.41it/s, epoch=1, grad_norm=0.1816, loss=1.4900, lr=0.003375]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  45%|████▌     | 443/981 [00:29<00:27, 19.41it/s, epoch=1, grad_norm=0.1655, loss=1.4900, lr=0.003333]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  46%|████▌     | 448/981 [00:29<00:27, 19.41it/s, epoch=1, grad_norm=0.2777, loss=1.4899, lr=0.003291]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  46%|████▌     | 453/981 [00:29<00:27, 19.41it/s, epoch=1, grad_norm=0.1496, loss=1.4754, lr=0.003249]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  47%|████▋     | 458/981 [00:29<00:26, 19.41it/s, epoch=1, grad_norm=0.1546, loss=1.4687, lr=0.003206]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  47%|████▋     | 463/981 [00:30<00:26, 19.41it/s, epoch=1, grad_norm=0.1817, loss=1.4687, lr=0.003163]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  48%|████▊     | 468/981 [00:30<00:26, 19.41it/s, epoch=1, grad_norm=0.2038, loss=1.4729, lr=0.003121]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  48%|████▊     | 473/981 [00:30<00:26, 19.41it/s, epoch=1, grad_norm=0.1970, loss=1.4757, lr=0.003077]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  49%|████▊     | 478/981 [00:30<00:25, 19.41it/s, epoch=1, grad_norm=0.1797, loss=1.4693, lr=0.003034]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  49%|████▉     | 483/981 [00:31<00:04, 101.84it/s, epoch=1, grad_norm=0.1964, loss=1.4654, lr=0.002990]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  50%|████▉     | 488/981 [00:31<00:04, 101.84it/s, epoch=1, grad_norm=0.1956, loss=1.4526, lr=0.002947]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  50%|█████     | 493/981 [00:31<00:09, 53.78it/s, epoch=1, grad_norm=0.1482, loss=1.4561, lr=0.002903] 

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  51%|█████     | 498/981 [00:31<00:08, 53.78it/s, epoch=1, grad_norm=0.1517, loss=1.4429, lr=0.002859]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  51%|█████     | 500/981 [00:31<00:11, 40.37it/s, epoch=1, grad_norm=0.1553, loss=1.4551, lr=0.002850]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  51%|█████▏    | 504/981 [00:34<00:11, 40.37it/s, epoch=1, grad_norm=0.1625, loss=1.4509, lr=0.002806]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  52%|█████▏    | 509/981 [00:34<00:39, 11.98it/s, epoch=1, grad_norm=0.1658, loss=1.4515, lr=0.002762]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  52%|█████▏    | 514/981 [00:34<00:36, 12.89it/s, epoch=1, grad_norm=0.1780, loss=1.4438, lr=0.002718]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  53%|█████▎    | 518/981 [00:34<00:33, 13.68it/s, epoch=1, grad_norm=0.1736, loss=1.4422, lr=0.002682]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  53%|█████▎    | 523/981 [00:35<00:30, 15.18it/s, epoch=1, grad_norm=0.2018, loss=1.4379, lr=0.002638]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  54%|█████▍    | 529/981 [00:35<00:26, 16.98it/s, epoch=1, grad_norm=0.1505, loss=1.4288, lr=0.002593]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  54%|█████▍    | 533/981 [00:35<00:25, 17.78it/s, epoch=1, grad_norm=0.2330, loss=1.4359, lr=0.002549]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  55%|█████▍    | 538/981 [00:35<00:24, 18.45it/s, epoch=1, grad_norm=0.1895, loss=1.4224, lr=0.002504]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  55%|█████▌    | 542/981 [00:35<00:23, 19.04it/s, epoch=1, grad_norm=0.1458, loss=1.4305, lr=0.002469]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  56%|█████▌    | 547/981 [00:36<00:21, 19.84it/s, epoch=1, grad_norm=0.1579, loss=1.4217, lr=0.002424]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  56%|█████▋    | 553/981 [00:36<00:21, 20.37it/s, epoch=1, grad_norm=0.1548, loss=1.4187, lr=0.002380]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  57%|█████▋    | 557/981 [00:36<00:20, 20.52it/s, epoch=1, grad_norm=0.1881, loss=1.4234, lr=0.002336]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  57%|█████▋    | 562/981 [00:36<00:20, 20.47it/s, epoch=1, grad_norm=0.1453, loss=1.4147, lr=0.002291]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  58%|█████▊    | 566/981 [00:37<00:20, 19.99it/s, epoch=1, grad_norm=0.1776, loss=1.4133, lr=0.002256]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  58%|█████▊    | 571/981 [00:37<00:19, 20.65it/s, epoch=1, grad_norm=0.1534, loss=1.4210, lr=0.002220]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  59%|█████▊    | 574/981 [00:37<00:24, 16.58it/s, epoch=1, grad_norm=0.1500, loss=1.4082, lr=0.002194]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  59%|█████▉    | 579/981 [00:37<00:24, 16.66it/s, epoch=1, grad_norm=0.1517, loss=1.4120, lr=0.002150]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  59%|█████▉    | 583/981 [00:38<00:22, 17.94it/s, epoch=1, grad_norm=0.1790, loss=1.4113, lr=0.002106]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  60%|█████▉    | 588/981 [00:38<00:20, 19.26it/s, epoch=1, grad_norm=0.1376, loss=1.4140, lr=0.002062]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  61%|██████    | 594/981 [00:38<00:19, 19.86it/s, epoch=1, grad_norm=0.1332, loss=1.4011, lr=0.002018]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  61%|██████    | 598/981 [00:38<00:18, 20.17it/s, epoch=1, grad_norm=0.1463, loss=1.4107, lr=0.001975]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  61%|██████    | 600/981 [00:38<00:18, 20.32it/s, epoch=1, grad_norm=0.1571, loss=1.3956, lr=0.001966]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  61%|██████▏   | 603/981 [00:41<01:34,  4.01it/s, epoch=1, grad_norm=0.1596, loss=1.4133, lr=0.001931]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  62%|██████▏   | 608/981 [00:41<00:59,  6.32it/s, epoch=1, grad_norm=0.1485, loss=1.4142, lr=0.001897]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  62%|██████▏   | 611/981 [00:41<00:45,  8.11it/s, epoch=1, grad_norm=0.1546, loss=1.4041, lr=0.001862]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  63%|██████▎   | 617/981 [00:41<00:29, 12.27it/s, epoch=1, grad_norm=0.1473, loss=1.3980, lr=0.001819]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  63%|██████▎   | 620/981 [00:41<00:25, 14.34it/s, epoch=1, grad_norm=0.1719, loss=1.4055, lr=0.001785]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  64%|██████▎   | 625/981 [00:42<00:21, 16.76it/s, epoch=1, grad_norm=0.1598, loss=1.4069, lr=0.001743]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  64%|██████▍   | 629/981 [00:42<00:19, 17.91it/s, epoch=1, grad_norm=0.1305, loss=1.3848, lr=0.001709]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  65%|██████▍   | 634/981 [00:42<00:18, 18.97it/s, epoch=1, grad_norm=0.1584, loss=1.3948, lr=0.001667]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  65%|██████▌   | 639/981 [00:42<00:17, 19.50it/s, epoch=1, grad_norm=0.1587, loss=1.3882, lr=0.001633]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  65%|██████▌   | 642/981 [00:43<00:17, 19.41it/s, epoch=1, grad_norm=0.1466, loss=1.3909, lr=0.001600]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  66%|██████▌   | 646/981 [00:43<00:17, 19.21it/s, epoch=1, grad_norm=0.1340, loss=1.3928, lr=0.001567]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  66%|██████▋   | 651/981 [00:43<00:16, 19.50it/s, epoch=1, grad_norm=0.1332, loss=1.3866, lr=0.001534]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  67%|██████▋   | 654/981 [00:43<00:16, 19.71it/s, epoch=1, grad_norm=0.1350, loss=1.3875, lr=0.001501]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  67%|██████▋   | 660/981 [00:43<00:16, 19.80it/s, epoch=1, grad_norm=0.1327, loss=1.3748, lr=0.001461]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  68%|██████▊   | 663/981 [00:44<00:16, 19.77it/s, epoch=1, grad_norm=0.1340, loss=1.3965, lr=0.001428]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  68%|██████▊   | 668/981 [00:44<00:15, 19.98it/s, epoch=1, grad_norm=0.1398, loss=1.3810, lr=0.001388]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  69%|██████▊   | 674/981 [00:44<00:15, 20.05it/s, epoch=1, grad_norm=0.1358, loss=1.3836, lr=0.001349]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  69%|██████▉   | 677/981 [00:44<00:15, 19.74it/s, epoch=1, grad_norm=0.1212, loss=1.3845, lr=0.001317]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  70%|██████▉   | 682/981 [00:44<00:15, 19.86it/s, epoch=1, grad_norm=0.1407, loss=1.3882, lr=0.001286]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  70%|██████▉   | 686/981 [00:45<00:14, 20.08it/s, epoch=1, grad_norm=0.1533, loss=1.3861, lr=0.001247]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  70%|███████   | 691/981 [00:45<00:14, 19.88it/s, epoch=1, grad_norm=0.1263, loss=1.3757, lr=0.001209]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  71%|███████   | 696/981 [00:45<00:14, 19.96it/s, epoch=1, grad_norm=0.1314, loss=1.3882, lr=0.001179]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  71%|███████▏  | 699/981 [00:45<00:14, 19.94it/s, epoch=1, grad_norm=0.1328, loss=1.3776, lr=0.001149]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  71%|███████▏  | 701/981 [00:47<01:14,  3.74it/s, epoch=1, grad_norm=0.1257, loss=1.3805, lr=0.001134]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  72%|███████▏  | 707/981 [00:48<00:40,  6.72it/s, epoch=1, grad_norm=0.1208, loss=1.3755, lr=0.001097]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  72%|███████▏  | 711/981 [00:48<00:31,  8.53it/s, epoch=1, grad_norm=0.1321, loss=1.3760, lr=0.001060]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  73%|███████▎  | 716/981 [00:48<00:21, 12.22it/s, epoch=1, grad_norm=0.1469, loss=1.3703, lr=0.001024]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  74%|███████▎  | 722/981 [00:48<00:16, 15.56it/s, epoch=1, grad_norm=0.1134, loss=1.3717, lr=0.000988]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  74%|███████▍  | 726/981 [00:49<00:15, 16.95it/s, epoch=1, grad_norm=0.1567, loss=1.3671, lr=0.000953]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  75%|███████▍  | 731/981 [00:49<00:13, 18.58it/s, epoch=1, grad_norm=0.1376, loss=1.3666, lr=0.000925]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  75%|███████▍  | 735/981 [00:49<00:12, 19.25it/s, epoch=1, grad_norm=0.1158, loss=1.3735, lr=0.000891]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  75%|███████▌  | 740/981 [00:49<00:12, 19.92it/s, epoch=1, grad_norm=0.1194, loss=1.3691, lr=0.000864]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  76%|███████▌  | 744/981 [00:49<00:11, 20.30it/s, epoch=1, grad_norm=0.1168, loss=1.3683, lr=0.000831]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  76%|███████▋  | 749/981 [00:50<00:11, 20.72it/s, epoch=1, grad_norm=0.1185, loss=1.3610, lr=0.000798]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  77%|███████▋  | 755/981 [00:50<00:10, 21.06it/s, epoch=1, grad_norm=0.1130, loss=1.3639, lr=0.000766]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  77%|███████▋  | 759/981 [00:50<00:10, 21.08it/s, epoch=1, grad_norm=0.1144, loss=1.3674, lr=0.000734]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  78%|███████▊  | 764/981 [00:50<00:10, 20.93it/s, epoch=1, grad_norm=0.1118, loss=1.3627, lr=0.000709]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  78%|███████▊  | 768/981 [00:51<00:10, 21.07it/s, epoch=1, grad_norm=0.1263, loss=1.3650, lr=0.000678]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  79%|███████▉  | 773/981 [00:51<00:09, 21.15it/s, epoch=1, grad_norm=0.1082, loss=1.3624, lr=0.000648]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  79%|███████▉  | 779/981 [00:51<00:09, 21.05it/s, epoch=1, grad_norm=0.1239, loss=1.3567, lr=0.000618]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  80%|███████▉  | 782/981 [00:51<00:09, 20.03it/s, epoch=1, grad_norm=0.1239, loss=1.3486, lr=0.000595]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  80%|████████  | 787/981 [00:52<00:10, 19.03it/s, epoch=1, grad_norm=0.1253, loss=1.3497, lr=0.000572]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  81%|████████  | 792/981 [00:52<00:09, 19.29it/s, epoch=1, grad_norm=0.1135, loss=1.3551, lr=0.000544]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  81%|████████  | 795/981 [00:52<00:09, 19.73it/s, epoch=1, grad_norm=0.1073, loss=1.3621, lr=0.000522]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  81%|████████▏ | 799/981 [00:52<00:09, 20.18it/s, epoch=1, grad_norm=0.1060, loss=1.3601, lr=0.000501]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  82%|████████▏ | 804/981 [00:54<00:33,  5.30it/s, epoch=1, grad_norm=0.1097, loss=1.3561, lr=0.000474]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  83%|████████▎ | 810/981 [00:55<00:19,  8.61it/s, epoch=1, grad_norm=0.1063, loss=1.3537, lr=0.000449]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  83%|████████▎ | 814/981 [00:55<00:15, 10.46it/s, epoch=1, grad_norm=0.1124, loss=1.3571, lr=0.000423]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  83%|████████▎ | 819/981 [00:55<00:11, 13.97it/s, epoch=1, grad_norm=0.1064, loss=1.3500, lr=0.000399]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  84%|████████▍ | 823/981 [00:55<00:10, 15.51it/s, epoch=1, grad_norm=0.1027, loss=1.3551, lr=0.000385]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  84%|████████▍ | 827/981 [00:55<00:09, 16.78it/s, epoch=1, grad_norm=0.1100, loss=1.3645, lr=0.000366]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  85%|████████▍ | 831/981 [00:56<00:08, 17.67it/s, epoch=1, grad_norm=0.1075, loss=1.3567, lr=0.000348]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  85%|████████▌ | 835/981 [00:56<00:08, 18.12it/s, epoch=1, grad_norm=0.1036, loss=1.3532, lr=0.000330]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  86%|████████▌ | 840/981 [00:56<00:07, 19.27it/s, epoch=1, grad_norm=0.1076, loss=1.3509, lr=0.000308]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  86%|████████▌ | 843/981 [00:56<00:07, 19.10it/s, epoch=1, grad_norm=0.1124, loss=1.3657, lr=0.000291]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  86%|████████▋ | 847/981 [00:57<00:06, 19.28it/s, epoch=1, grad_norm=0.1122, loss=1.3511, lr=0.000275]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  87%|████████▋ | 852/981 [00:57<00:06, 20.16it/s, epoch=1, grad_norm=0.1065, loss=1.3574, lr=0.000255]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  87%|████████▋ | 858/981 [00:57<00:05, 20.52it/s, epoch=1, grad_norm=0.0984, loss=1.3510, lr=0.000236]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  88%|████████▊ | 862/981 [00:57<00:05, 20.65it/s, epoch=1, grad_norm=0.1047, loss=1.3573, lr=0.000217]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  88%|████████▊ | 867/981 [00:58<00:05, 20.65it/s, epoch=1, grad_norm=0.0994, loss=1.3611, lr=0.000199]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  89%|████████▉ | 873/981 [00:58<00:05, 20.87it/s, epoch=1, grad_norm=0.0989, loss=1.3578, lr=0.000182]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  89%|████████▉ | 877/981 [00:58<00:05, 20.62it/s, epoch=1, grad_norm=0.0961, loss=1.3548, lr=0.000166]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  90%|████████▉ | 882/981 [00:58<00:04, 20.86it/s, epoch=1, grad_norm=0.0954, loss=1.3541, lr=0.000150]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  91%|█████████ | 888/981 [00:58<00:04, 20.81it/s, epoch=1, grad_norm=0.0936, loss=1.3495, lr=0.000136]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  91%|█████████ | 892/981 [00:59<00:04, 20.63it/s, epoch=1, grad_norm=0.1001, loss=1.3446, lr=0.000122]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  91%|█████████▏| 897/981 [00:59<00:04, 20.75it/s, epoch=1, grad_norm=0.0984, loss=1.3515, lr=0.000108]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  92%|█████████▏| 900/981 [00:59<00:03, 20.83it/s, epoch=1, grad_norm=0.0977, loss=1.3508, lr=0.000103]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  92%|█████████▏| 900/981 [00:59<00:03, 20.83it/s, epoch=1, grad_norm=0.0915, loss=1.3537, lr=0.000101]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size(

Training:  92%|█████████▏| 903/981 [00:59<00:05, 15.57it/s, epoch=1, grad_norm=0.0954, loss=1.3519, lr=0.000096]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  92%|█████████▏| 906/981 [01:00<00:04, 16.91it/s, epoch=1, grad_norm=0.0952, loss=1.3485, lr=0.000088]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  92%|█████████▏| 907/981 [01:00<00:04, 16.91it/s, epoch=1, grad_norm=0.0921, loss=1.3443, lr=0.000084]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  93%|█████████▎| 910/981 [01:00<00:03, 17.94it/s, epoch=1, grad_norm=0.0894, loss=1.3566, lr=0.000077]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  93%|█████████▎| 912/981 [01:00<00:03, 18.62it/s, epoch=1, grad_norm=0.0916, loss=1.3629, lr=0.000073]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  93%|█████████▎| 915/981 [01:00<00:03, 18.72it/s, epoch=1, grad_norm=0.0905, loss=1.3517, lr=0.000067]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  93%|█████████▎| 917/981 [01:00<00:03, 19.41it/s, epoch=1, grad_norm=0.0921, loss=1.3522, lr=0.000065]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  94%|█████████▎| 919/981 [01:00<00:03, 19.19it/s, epoch=1, grad_norm=0.0937, loss=1.3481, lr=0.000059]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  94%|█████████▍| 921/981 [01:00<00:03, 19.15it/s, epoch=1, grad_norm=0.0953, loss=1.3634, lr=0.000055]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  94%|█████████▍| 924/981 [01:00<00:02, 19.78it/s, epoch=1, grad_norm=0.0913, loss=1.3511, lr=0.000051]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  94%|█████████▍| 925/981 [01:01<00:02, 19.78it/s, epoch=1, grad_norm=0.0933, loss=1.3489, lr=0.000048]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  95%|█████████▍| 928/981 [01:01<00:02, 20.18it/s, epoch=1, grad_norm=0.0952, loss=1.3443, lr=0.000043]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  95%|█████████▍| 930/981 [01:01<00:02, 20.43it/s, epoch=1, grad_norm=0.0930, loss=1.3564, lr=0.000039]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  95%|█████████▌| 933/981 [01:01<00:02, 20.54it/s, epoch=1, grad_norm=0.0897, loss=1.3548, lr=0.000035]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  95%|█████████▌| 936/981 [01:01<00:02, 20.58it/s, epoch=1, grad_norm=0.0950, loss=1.3584, lr=0.000032]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  96%|█████████▌| 939/981 [01:01<00:02, 20.70it/s, epoch=1, grad_norm=0.0885, loss=1.3542, lr=0.000028]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  96%|█████████▌| 940/981 [01:01<00:01, 20.70it/s, epoch=1, grad_norm=0.0930, loss=1.3540, lr=0.000025]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  96%|█████████▌| 943/981 [01:01<00:01, 20.86it/s, epoch=1, grad_norm=0.0897, loss=1.3509, lr=0.000022]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  96%|█████████▋| 945/981 [01:01<00:01, 20.90it/s, epoch=1, grad_norm=0.0920, loss=1.3504, lr=0.000019]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  97%|█████████▋| 948/981 [01:02<00:01, 20.94it/s, epoch=1, grad_norm=0.0897, loss=1.3505, lr=0.000016]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  97%|█████████▋| 951/981 [01:02<00:01, 20.67it/s, epoch=1, grad_norm=0.0908, loss=1.3475, lr=0.000014]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  97%|█████████▋| 954/981 [01:02<00:01, 20.73it/s, epoch=1, grad_norm=0.0905, loss=1.3437, lr=0.000012]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  97%|█████████▋| 954/981 [01:02<00:01, 20.73it/s, epoch=1, grad_norm=0.0940, loss=1.3528, lr=0.000011]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  98%|█████████▊| 958/981 [01:02<00:01, 20.79it/s, epoch=1, grad_norm=0.0870, loss=1.3428, lr=0.000008]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  98%|█████████▊| 960/981 [01:02<00:01, 20.62it/s, epoch=1, grad_norm=0.0934, loss=1.3545, lr=0.000007]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  98%|█████████▊| 963/981 [01:02<00:00, 20.75it/s, epoch=1, grad_norm=0.0909, loss=1.3583, lr=0.000005]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  98%|█████████▊| 964/981 [01:02<00:00, 20.75it/s, epoch=1, grad_norm=0.0925, loss=1.3546, lr=0.000004]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  99%|█████████▉| 969/981 [01:03<00:00, 20.98it/s, epoch=1, grad_norm=0.0922, loss=1.3516, lr=0.000002]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training:  99%|█████████▉| 969/981 [01:03<00:00, 20.98it/s, epoch=1, grad_norm=0.0896, loss=1.3500, lr=0.000002]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  99%|█████████▉| 972/981 [01:03<00:00, 20.66it/s, epoch=1, grad_norm=0.0914, loss=1.3519, lr=0.000001]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  99%|█████████▉| 972/981 [01:03<00:00, 20.66it/s, epoch=1, grad_norm=0.0898, loss=1.3553, lr=0.000001]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training:  99%|█████████▉| 976/981 [01:03<00:00, 18.17it/s, epoch=1, grad_norm=0.0898, loss=1.3553, lr=0.000000]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: tor

Training: 100%|█████████▉| 978/981 [01:03<00:00, 18.95it/s, epoch=1, grad_norm=0.0895, loss=1.3558, lr=0.000000]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])


Training: 100%|█████████▉| 979/981 [01:03<00:00, 18.95it/s, epoch=1, grad_norm=0.0883, loss=1.3491, lr=0.000000]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([206, 128, 128])
v shape: torch.Size([206, 128, 128])


Training: 100%|██████████| 981/981 [01:05<00:00,  5.06it/s, epoch=1, grad_norm=0.2001, loss=1.3565, lr=0.000000]

k shape: torch.Size([206, 128, 128])
v shape: torch.Size([206, 128, 128])
k shape: torch.Size([206, 128, 128])
v shape: torch.Size([206, 128, 128])
k shape: torch.Size([206, 128, 128])
v shape: torch.Size([206, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Siz

Training: 100%|██████████| 981/981 [01:07<00:00, 14.56it/s, epoch=1, grad_norm=0.2001, loss=1.3565, lr=0.000000]

k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([1024, 128, 128])
v shape: torch.Size([1024, 128, 128])
k shape: torch.Size([820, 128, 128])
v shape: torch.Size([820, 128, 128])
k shape: torch




0,1
Grad Norm,█▃▃█▄▃▃▅▄▃▅▅▂▂▂▂▂▃▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
Learning Rate,▁▂▃▄▆█████▇▇▇▆▆▅▅▅▅▅▄▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
Train Loss,█▆▆▆▆▅▅▅▅▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▅▃▂▂▁▁▁▁▁
Val Perplexity,█▄▂▂▁▁▁▁▁▁

0,1
Grad Norm,0.20014
Learning Rate,0.0
Train Loss,1.35646
Val Loss,1.49788
Val Perplexity,4.47221


## Save the model

In [16]:
# Save model locally
if not train_config.debug:
    output_dir = os.path.join(PROJECT_ROOT, "checkpoints", train_config.model_name, train_config.run_name)
    os.makedirs(output_dir, exist_ok=True)
    try:
        model.save_pretrained(
            output_dir,
            safe_serialization=True
        )
        print("Model saved successfully.")
    except Exception as e:
        print(f"Error saving model: {e}")
    # Push to Hugging Face Hub
    #model.push_to_hub(
    #    repo_id=f"PathFinderKR/{train_config.model_name}-{train_config.run_name}",
    #    private=True,
    #    use_auth_token=os.environ.get("HUGGINGFACE_TOKEN")
    #)
    #print(f"Model pushed to Hugging Face Hub: PathFinderKR/{train_config.model_name}-{train_config.run_name}")

Model saved successfully.


In [17]:
# To load the model later, you can use:
# model = GPT(model_config)
# model = model.from_pretrained(output_dir).to(device)

# Inference

In [18]:
user_prompt = "To be, or not to be, that is the question"
input_ids = char_tokenizer.encode(user_prompt).unsqueeze(0).to(device)
output = model.generate(
    input_ids,
    use_cache=True,
    max_new_tokens=generation_config.max_new_tokens,
    temperature=generation_config.temperature,
    top_k=generation_config.top_k,
    tokenizer=char_tokenizer
)
response = char_tokenizer.decode(output[0].squeeze().cpu().numpy())

k shape: torch.Size([1, 41, 128])
v shape: torch.Size([1, 41, 128])
k shape: torch.Size([1, 41, 128])
v shape: torch.Size([1, 41, 128])
k shape: torch.Size([1, 41, 128])
v shape: torch.Size([1, 41, 128])
k shape: torch.Size([1, 41, 128])
v shape: torch.Size([1, 41, 128])
,kv_latent shape: torch.Size([1, 42, 32])
k shape: torch.Size([1, 42, 128])
v shape: torch.Size([1, 42, 128])
kv_latent shape: torch.Size([1, 42, 32])
k shape: torch.Size([1, 42, 128])
v shape: torch.Size([1, 42, 128])
kv_latent shape: torch.Size([1, 42, 32])
k shape: torch.Size([1, 42, 128])
v shape: torch.Size([1, 42, 128])
kv_latent shape: torch.Size([1, 42, 32])
k shape: torch.Size([1, 42, 128])
v shape: torch.Size([1, 42, 128])

kv_latent shape: torch.Size([1, 43, 32])
k shape: torch.Size([1, 43, 128])
v shape: torch.Size([1, 43, 128])
kv_latent shape: torch.Size([1, 43, 32])
k shape: torch.Size([1, 43, 128])
v shape: torch.Size([1, 43, 128])
kv_latent shape: torch.Size([1, 43, 32])
k shape: torch.Size([1, 43, 128

In [19]:
print("=" * 50)
print("User prompt: ")
print(user_prompt)
print("-" * 50)
print("🤖 Model Response:")
print(response)

User prompt: 
To be, or not to be, that is the question
--------------------------------------------------
🤖 Model Response:
To be, or not to be, that is the question,
Not to this friend: this your victor'd,
Sweet thee that by age give what old by the fr, that is and tell you to that?

ROMEO:
I have you born me azon than all the very
Mice, or dance be condit as show'd for your hin.

Clown:
Well, a bed a but wife in this view not
Trepution most honour for your son guard:
And if they is ever argue to me be well--thoughts means to thear, is return'd
All dry else force! 'Tis further, as my lord!
Ay, as, the another of another'd, as I
She was too strike a cause of two hand.

KING RICHARD II:
Because Benop, let thee to, I have a voice
Made that for my flesh of ast the Earll, of York over
That they have reverenced, lauding and wars
Was weak'd, speaking from the office
As variant to circumes
No more:--by my! Come, my lord. What, be't no
my body is mother of your breath condition?
He dog, call 

# Speedometer

In [20]:
speedometer(
    model=model,
    input_ids=char_tokenizer.encode("a").unsqueeze(0).to(device),
    use_cache=True,
    warmup_tokens=100,
    timing_tokens=100,
    num_runs=5
)

k shape: torch.Size([1, 1, 128])
v shape: torch.Size([1, 1, 128])
k shape: torch.Size([1, 1, 128])
v shape: torch.Size([1, 1, 128])
k shape: torch.Size([1, 1, 128])
v shape: torch.Size([1, 1, 128])
k shape: torch.Size([1, 1, 128])
v shape: torch.Size([1, 1, 128])
kv_latent shape: torch.Size([1, 2, 32])
k shape: torch.Size([1, 2, 128])
v shape: torch.Size([1, 2, 128])
kv_latent shape: torch.Size([1, 2, 32])
k shape: torch.Size([1, 2, 128])
v shape: torch.Size([1, 2, 128])
kv_latent shape: torch.Size([1, 2, 32])
k shape: torch.Size([1, 2, 128])
v shape: torch.Size([1, 2, 128])
kv_latent shape: torch.Size([1, 2, 32])
k shape: torch.Size([1, 2, 128])
v shape: torch.Size([1, 2, 128])
kv_latent shape: torch.Size([1, 3, 32])
k shape: torch.Size([1, 3, 128])
v shape: torch.Size([1, 3, 128])
kv_latent shape: torch.Size([1, 3, 32])
k shape: torch.Size([1, 3, 128])
v shape: torch.Size([1, 3, 128])
kv_latent shape: torch.Size([1, 3, 32])
k shape: torch.Size([1, 3, 128])
v shape: torch.Size([1, 3, 

In [21]:
speedometer(
    model=model,
    input_ids=char_tokenizer.encode("a").unsqueeze(0).to(device),
    use_cache=False,
    warmup_tokens=100,
    timing_tokens=100,
    num_runs=5
)

k shape: torch.Size([1, 1, 128])
v shape: torch.Size([1, 1, 128])
k shape: torch.Size([1, 1, 128])
v shape: torch.Size([1, 1, 128])
k shape: torch.Size([1, 1, 128])
v shape: torch.Size([1, 1, 128])
k shape: torch.Size([1, 1, 128])
v shape: torch.Size([1, 1, 128])
k shape: torch.Size([1, 2, 128])
v shape: torch.Size([1, 2, 128])
k shape: torch.Size([1, 2, 128])
v shape: torch.Size([1, 2, 128])
k shape: torch.Size([1, 2, 128])
v shape: torch.Size([1, 2, 128])
k shape: torch.Size([1, 2, 128])
v shape: torch.Size([1, 2, 128])
k shape: torch.Size([1, 3, 128])
v shape: torch.Size([1, 3, 128])
k shape: torch.Size([1, 3, 128])
v shape: torch.Size([1, 3, 128])
k shape: torch.Size([1, 3, 128])
v shape: torch.Size([1, 3, 128])
k shape: torch.Size([1, 3, 128])
v shape: torch.Size([1, 3, 128])
k shape: torch.Size([1, 4, 128])
v shape: torch.Size([1, 4, 128])
k shape: torch.Size([1, 4, 128])
v shape: torch.Size([1, 4, 128])
k shape: torch.Size([1, 4, 128])
v shape: torch.Size([1, 4, 128])
k shape: t