# このノートブックでは

transformerをスクリプト化し、学習の様子がよくわかるようにコーディングする。

なお、モデルの隠蔽だけを行い。トレーニングのコードはノートブックで書くつもり。

## Library

In [24]:
import os
from pathlib import Path
import sys

from datasets import load_dataset
import matplotlib.pyplot as plt
import matplotlib_fontja
import numpy as np
import pandas as pd
import sentencepiece as spm
from tokenizers import SentencePieceUnigramTokenizer
import torch
from transformers import PreTrainedTokenizerFast
import wandb

In [25]:
sys.path.append(f"{os.path.dirname(os.getcwd())}/modules")
from transformer_scratch.transformer import Transformer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

In [26]:
ds = load_dataset("globis-university/aozorabunko-clean", split="train")
display(ds)

Dataset({
    features: ['text', 'footnote', 'meta'],
    num_rows: 16951
})

In [27]:
ds_train = ds["text"]
len(ds_train)

16951

In [28]:
len(ds_train) / 20

847.55

In [29]:
# # データセットを直接渡すとメモリを圧迫するため、ジェネレータをかませる
# def ds_iter():
#     for item in ds["text"]:
#         yield item


# spm.SentencePieceTrainer.Train(
#     sentence_iterator=ds_iter(),
#     model_prefix="trained/tokenizer/sp_jawiki",
#     vocab_size=8000,
#     model_type="unigram",
#     character_coverage=0.9995,  # どの程度の文字をカバーするか。これより使用頻度の低い文字はUNKになる
#     train_extremely_large_corpus=True,
#     unk_id=0,
#     bos_id=1,
#     eos_id=2,
#     pad_id=3,
# )

In [30]:
# !pip install protobuf
!wget https://raw.githubusercontent.com/google/sentencepiece/master/python/src/sentencepiece/sentencepiece_model_pb2.py
tokenizer = SentencePieceUnigramTokenizer.from_spm(
    "../trained/tokenizer/sp_jawiki.model"
)

!rm sentencepiece_model_pb2.py

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
297.04s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


--2025-12-14 06:24:28--  https://raw.githubusercontent.com/google/sentencepiece/master/python/src/sentencepiece/sentencepiece_model_pb2.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6257 (6.1K) [text/plain]
Saving to: ‘sentencepiece_model_pb2.py’


2025-12-14 06:24:28 (82.2 MB/s) - ‘sentencepiece_model_pb2.py’ saved [6257/6257]



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
303.02s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


In [31]:
tokenizer

Tokenizer(vocabulary_size=8000, model=SentencePieceUnigram)

In [32]:
len(ds_train[0])

10639

In [33]:
from tokenizers.processors import TemplateProcessing

# これをすると、なぜかメタスペースが入ってしまう。
# tokenizer.post_processor = TemplateProcessing(
#     single="<s> $A </s>", # BOS, EOSで囲む処理を指定
#     special_tokens=[
#         ("<s>", tokenizer.token_to_id("<s>")),
#         ("</s>", tokenizer.token_to_id("</s>")),
#     ],
# )
# tokenizer.save("trained/tokenizer/jawiki.json")

In [34]:
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="../trained/tokenizer/jawiki.json"
)
tokenizer.add_special_tokens(
    {
        "unk_token": "<unk>",
        "bos_token": "<s>",
        "eos_token": "</s>",
        "pad_token": "<pad>",
    }
)

4

In [35]:
tokenized = tokenizer(ds_train[:100])
len(tokenized)

3

In [38]:
len(tokenized["input_ids"])

100

In [40]:
len(sum(tokenized["input_ids"], []))

430990

In [None]:
all_token = 430990 * 170
print(f"{all_token=:,}")

73268300

In [12]:
text = "人工知能は人間の知能を超えるか？"
ids = tokenizer.encode(text)
ids

[7, 45, 1126, 384, 576, 8, 894, 384, 576, 9, 3477, 820, 15, 441]

In [13]:
tokenizer.decode([7, 45, 1126])
tokenizer.decode([15, 441])

'か?'

tokenizerができた。

tokenizerはpaddingまで自動でやってくれる。

In [14]:
from transformers.data.data_collator import DataCollatorForLanguageModeling
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [15]:
# collatorを作る。
# dataloaderに渡されたデータはlist型でcollate_fnに渡され、collateの処理が走る。
# ここではtokenizerに渡すという処理を噛ませる。
def data_collator(batch: list):
    return tokenizer(
        batch,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=1024,
    )

In [16]:
train_dataloader = DataLoader(
    ds["text"], shuffle=True, batch_size=8, collate_fn=data_collator
)
next(iter(train_dataloader))

{'input_ids': tensor([[   7,   57, 1322,  ...,   30,  117,   24],
        [   7,    8,   16,  ...,  235,    4,   94],
        [   7,  123,  353,  ...,    4,  433,  197],
        ...,
        [ 534,   42,  559,  ...,   13,    4,  133],
        [ 534,    7,  638,  ...,  397,  502, 2613],
        [   7,  517, 1059,  ...,  131, 2809,    9]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}

In [17]:
sample = next(iter(train_dataloader))

In [18]:
sample

{'input_ids': tensor([[   7,  408, 3782,  ...,    5, 1437, 1287],
        [   7,  319,    5,  ...,    3,    3,    3],
        [   7, 4203,   51,  ...,   29,  326,    6],
        ...,
        [2596, 3759,   18,  ...,    3,    3,    3],
        [   7,  559, 1565,  ..., 4772, 1281, 2889],
        [   7, 6823,  234,  ...,   53,    6,   42]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}

# transformer

In [19]:
vocab_size = tokenizer.vocab_size
n_head = 8
d_model = 512
d_ff = 2048
n_layers = 6
context_window = 1024

In [None]:
transformer = Transformer(
    vocab_size, n_head, d_model, d_ff, n_layers, context_window
)

# transformer
n_param = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print(f"{n_param=:,}")
# 1億2000万のパラメータ数！！！
# float32の場合，1パラメータ4バイト換算される．
model_size = n_param * 4
print(f"{model_size=:,}")

n_param=27,638,592
model_size=110,554,368


0.03Bモデルみたいな感じのができる。

In [21]:
transformer

Transformer(
  (embed): Embedding(8000, 512)
  (transformer_layers): Sequential(
    (0): TransformerLayer(
      (mha): MultiHeadAttention(
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
        (attention): CausalAttention(
          (softmax): Softmax(dim=-1)
        )
      )
      (ffn): FFN(
        (net): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
      (layer_norm_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (mha): MultiHeadAttention(
        (w_q): Linear(in_features=512, out_features=5

- input: tokenizeされたデータ
- output: sortmaxの確率分布
- loss: cross entropy loss


In [22]:
sample_x = sample["input_ids"]
sample_x

tensor([[   7,  408, 3782,  ...,    5, 1437, 1287],
        [   7,  319,    5,  ...,    3,    3,    3],
        [   7, 4203,   51,  ...,   29,  326,    6],
        ...,
        [2596, 3759,   18,  ...,    3,    3,    3],
        [   7,  559, 1565,  ..., 4772, 1281, 2889],
        [   7, 6823,  234,  ...,   53,    6,   42]])

In [23]:
sample["input_ids"][0]

tensor([   7,  408, 3782,  ...,    5, 1437, 1287])

$$
loss = -\sum \log p(x_{t}|x_{<t})
$$

transformer: n個のsequenceを入れると、n個のsequenceを出す。

In [24]:
sample["input_ids"][0].size()

torch.Size([1024])

In [25]:
sample["input_ids"][0]

tensor([   7,  408, 3782,  ...,    5, 1437, 1287])

$$
x = x_1, x_2, ..., x_T
\\
y = \text{transformer}(x) = y_1, y_2, ..., y_T こいつらは確率分布
\\
y_t = \text{transformer}(x_{<{t}})
\\
y_t = p(x_{{t}}|x_{<{t}})


x = x_1, x_2, ..., x_T

input:x_1
label:x_2

input:x_1, x_2
label:x_3

input:x_1, x_2,...x_t
label:x_{t+1}

input:x_1, x_2,...x_{T-1}
label:x_{T}
$$

In [26]:
sample["input_ids"][0]

tensor([   7,  408, 3782,  ...,    5, 1437, 1287])

In [27]:
sample_input = sample["input_ids"][0].clone()
sample_input = sample_input[:-1]

sample_super = sample["input_ids"][0].clone()
sample_super = sample_super[1:]

In [28]:
sample_input[:2]

tensor([  7, 408])

In [29]:
sample_super[3]

tensor(2465)

In [30]:
sample_input = sample_input.unsqueeze(0)

In [31]:
sample_input

tensor([[   7,  408, 3782,  ...,  318,    5, 1437]])

以下が予測と教師のペア

In [32]:
transformer(sample_input).size()

torch.Size([1, 1023, 8000])

In [33]:
sample_super.size()

torch.Size([1023])

## Training

In [35]:
# 学習率スケジューラー
from torch.optim import lr_scheduler
import torch.optim as optim
from torch import nn

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
# 8. OneCycleLR
scheduler = lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    total_steps=10000,
    pct_start=0.3,
    anneal_strategy="cos",
)

In [37]:
train_dataloader = DataLoader(
    ds["text"], shuffle=True, batch_size=32, collate_fn=data_collator
)

In [None]:
import wandb
import random

wandb.login()

# Start a new wandb run to track this script.
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    # entity="my-awesome-team-name",
    # Set the wandb project where this run will be logged.
    project="transformer",
    # Track hyperparameters and run metadata.
    config=config,
    reinit=True,
)

# config = wandb.config

# # Simulate training.
# epochs = 10
# offset = random.random() / 5
# for epoch in range(2, epochs):
#     acc = 1 - 2**-epoch - random.random() / epoch - offset
#     loss = 2**-epoch + random.random() / epoch + offset

#     # Log metrics to wandb.
#     run.log({"acc": acc, "loss": loss})

# # Finish the run and upload any remaining data.
# run.finish()



In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [48]:
device

device(type='cuda')

In [None]:
# import torch
# from torch.cuda.amp import GradScaler, autocast
# import math

# # Deductive Prerequisites
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# wandb.watch(model, log="all", log_freq=500)
# scaler = GradScaler()
# max_grad_norm = 1.0
# model.train()

# global_step = 0  # Counter for x-axis continuity

# # 2. Training Loop
# # ---------------------------------------------------------
# for epoch in range(config["epochs"]):
#     running_loss = 0.0
#     progress_bar = tqdm(
#         enumerate(train_dataloader), total=len(train_dataloader)
#     )

#     # --- Inner Loop (Batch Optimization) ---
#     for i, data in progress_bar:
#         input_ids = data["input_ids"].to(device)

#         # Causal Shift Logic
#         inputs = input_ids[:, :-1]
#         targets = input_ids[:, 1:]

#         optimizer.zero_grad(set_to_none=True)

#         with autocast(enabled=(device.type == "cuda")):
#             outputs = model(inputs)
#             logits = outputs.logits if hasattr(outputs, "logits") else outputs
#             loss = criterion(
#                 logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
#             )

#         scaler.scale(loss).backward()
#         scaler.unscale_(optimizer)
#         # Capture Total Norm before update to diagnose exploding gradients
#         # total_norm = sqrt(sum(||grad_i||^2))
#         total_norm = torch.nn.utils.clip_grad_norm_(
#             model.parameters(), config["max_grad_norm"]
#         )

#         scaler.step(optimizer)
#         scaler.update()

#         # Stats Accumulation
#         running_loss += loss.item()
#         # Logging Frequency (Inductive Sampling)
#         # Logging every step is noisy; every 100 steps creates a smoother trend.
#         if i % 100 == 99:
#             avg_loss = running_loss / 100
#             # Mathematical transformation: PPL = exp(CrossEntropy)
#             try:
#                 perplexity = math.exp(avg_loss)
#             except OverflowError:
#                 perplexity = float("inf")

#             current_lr = optimizer.param_groups[0]["lr"]

#             # --- WandB Logging Point ---
#             wandb.log(
#                 {
#                     "train/loss": avg_loss,
#                     "train/perplexity": perplexity,
#                     "train/learning_rate": current_lr,
#                     "train/grad_norm": total_norm,  # Critical for tracking stability
#                     "epoch": epoch,
#                 },
#                 step=global_step,
#             )
#             progress_bar.set_description(
#                 f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Norm: {total_norm:.2f}"
#             )
#             running_loss = 0.0
#         global_step += 1

#     # --- Outer Loop (Epoch Management) ---
#     scheduler.step()

#     current_lr = scheduler.get_last_lr()[0]
#     print(f"Epoch {epoch+1} Completed. LR updated to: {current_lr:.2e}")

# # Finalize
# wandb.finish()
# print("Finished Training")

In [53]:
import torch
from torch.cuda.amp import GradScaler, autocast
import math
import wandb
import os

# 1. Configuration & Initialization
# ---------------------------------------------------------
config = {
    "learning_rate": 0.001,
    "architecture": "transformer",
    "dataset": "globis-university/aozorabunko-clean",
    "scheduler_type": "StepLR",
    "max_grad_norm": 1.0,
    "batch_size": 32,
    "n_head": 8,
    "d_model": 512,
    "d_ff": 2048,
    "n_layers": 6,
    "context_window": 1024,
    "epochs": 10,
    "save_dir": "./checkpoints",  # Local storage path
}

# Create local directory strictly
os.makedirs(config["save_dir"], exist_ok=True)

run = wandb.init(project="transformer", config=config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.vocab_size
n_head = 8
d_model = 512
d_ff = 2048
n_layers = 6
context_window = 1024
model = Transformer(
    vocab_size, n_head, d_model, d_ff, n_layers, context_window
).to(device)


# outputs = batch, seq_len, vocab
# label = batch, seq_len, vocab
from tqdm import tqdm

model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
scaler = GradScaler()

model.train()

global_step = 0
best_loss = float("inf")  # For comparison


# 2. Helper Function: Checkpoint Saver
# ---------------------------------------------------------
def save_checkpoint(state, filename, is_best=False):
    """
    Saves the state dict locally and logs it as a W&B artifact.
    """
    save_path = os.path.join(config["save_dir"], filename)
    torch.save(state, save_path)

    # Create Artifact logic
    # type="model" allows W&B to visualize model lineage
    artifact = wandb.Artifact(
        name=f"transformer-model-{wandb.run.id}",
        type="model",
        metadata=state["config"],  # Attach config for traceability
    )
    artifact.add_file(save_path)

    # Assign aliases for easy retrieval (e.g., 'best', 'latest')
    aliases = ["latest"]
    if is_best:
        aliases.append("best")
    wandb.log_artifact(artifact, aliases=aliases)
    print(f"Saved checkpoint: {filename} (Aliases: {aliases})")


# 3. Training Loop
# ---------------------------------------------------------
try:
    for epoch in range(config["epochs"]):
        running_loss = 0.0
        progress_bar = tqdm(
            enumerate(train_dataloader), total=len(train_dataloader)
        )
        # --- Inner Loop ---
        for i, data in progress_bar:
            input_ids = data["input_ids"].to(device)
            inputs, targets = input_ids[:, :-1], input_ids[:, 1:]

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=(device.type == "cuda")):
                outputs = model(inputs)
                logits = (
                    outputs.logits if hasattr(outputs, "logits") else outputs
                )
                loss = criterion(
                    logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
                )

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), config["max_grad_norm"]
            )
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            if i % 100 == 99:
                avg_loss = running_loss / 100
                wandb.log(
                    {"train/loss": avg_loss, "epoch": epoch}, step=global_step
                )
                progress_bar.set_description(
                    f"Epoch {epoch+1} | Loss: {avg_loss:.4f}"
                )
                running_loss = 0.0
            global_step += 1

        # --- End of Epoch Management ---
        scheduler.step()
        # Calculate Validation Metric (Here simplified as last train loss for demo)
        # In reality, insert Validation Loop here.
        epoch_loss = avg_loss
        is_best = epoch_loss < best_loss
        if is_best:
            best_loss = epoch_loss

        # --- Construct Checkpoint State ---
        # deductive reasoning: To resume, we need ALL changing variables.
        checkpoint_state = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "loss": epoch_loss,
            "config": config,
        }

        # Save 'latest' every epoch (Overwrites to save space locally, versions in W&B)
        save_checkpoint(
            checkpoint_state, "checkpoint_latest.pth", is_best=is_best
        )

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving emergency checkpoint...")
    # Emergency Save Logic
    checkpoint_state = {
        "epoch": epoch,  # current epoch
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "config": config,
    }
    save_checkpoint(checkpoint_state, "checkpoint_interrupted.pth")

finally:
    wandb.finish()
    print("Training process finished.")

Epoch 1 | Loss: 7.1452: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest', 'best'])


Epoch 2 | Loss: 7.1382: 100%|██████████| 530/530 [03:42<00:00,  2.38it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest', 'best'])


Epoch 3 | Loss: 7.1292: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest', 'best'])


Epoch 4 | Loss: 7.1280: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest', 'best'])


Epoch 5 | Loss: 7.1257: 100%|██████████| 530/530 [03:42<00:00,  2.38it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest', 'best'])


Epoch 6 | Loss: 7.1185: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest', 'best'])


Epoch 7 | Loss: 7.1291: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest'])


Epoch 8 | Loss: 7.1218: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest'])


Epoch 9 | Loss: 7.1285: 100%|██████████| 530/530 [03:42<00:00,  2.38it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest'])


Epoch 10 | Loss: 7.1319: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Saved checkpoint: checkpoint_latest.pth (Aliases: ['latest'])


0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train/loss,█▃▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▁▂▂▁▁▂▂▂▂▂▂▂

0,1
epoch,9.0
train/loss,7.13194


Training process finished.


In [54]:
import torch
from torch.cuda.amp import GradScaler, autocast
import math
import wandb
import os

# 1. Configuration
# ---------------------------------------------------------
config = {
    "learning_rate": 0.001,
    "architecture": "transformer",
    "dataset": "globis-university/aozorabunko-clean",
    "scheduler_type": "StepLR",
    "max_grad_norm": 1.0,
    "batch_size": 32,
    "n_head": 8,
    "d_model": 512,
    "d_ff": 2048,
    "n_layers": 6,
    "context_window": 1024,
    "epochs": 100,
    "save_dir": "./checkpoints",  # Local storage path
    "resume_checkpoint": "./checkpoints/checkpoint_latest.pth",
}
prev_run_id = "1ayj6kg8"

os.makedirs(config["save_dir"], exist_ok=True)

# WandB: resume="allow" でIDを指定すればグラフを継続できますが、

# 以前のRun IDを取得できた場合
wandb.init(id=prev_run_id, project="transformer", resume="allow")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# モデル・Optimizer・Scheduler・Scalerの初期化（構造は定義しておく必要がある）
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
scaler = GradScaler()

# 2. Resuming Logic (Inductive State Restoration)
# ---------------------------------------------------------
start_epoch = 0  # Default

if config["resume_checkpoint"] and os.path.exists(config["resume_checkpoint"]):
    print(f"Loading checkpoint from {config['resume_checkpoint']}...")

    # map_location is crucial to prevent device mismatch errors
    checkpoint = torch.load(config["resume_checkpoint"], map_location=device)

    # Restore States
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    scaler.load_state_dict(checkpoint["scaler_state_dict"])

    # Calculate next start epoch
    # If saved at end of epoch 10, checkpoint['epoch'] is usually 10 (or 11 depending on save logic).
    # Assuming previous code saved 'epoch': epoch + 1
    start_epoch = checkpoint["epoch"]

    print(f"Resumed successfully. Starting from Epoch {start_epoch + 1}")
else:
    print("No checkpoint found or specified. Starting fresh training.")

# 3. Training Loop with Dynamic Range
# ---------------------------------------------------------
# Deductive logic: Loop must run from start_epoch to total target epochs
if start_epoch >= config["epochs"]:
    print(
        f"Training already reached target epochs ({config['epochs']}). Exiting."
    )
    exit()

global_step = start_epoch * len(
    train_dataloader
)  # Update step counter for WandB consistency

for epoch in range(start_epoch, config["epochs"]):
    running_loss = 0.0
    progress_bar = tqdm(
        enumerate(train_dataloader), total=len(train_dataloader)
    )

    model.train()  # Ensure mode is train

    # --- Inner Loop ---
    for i, data in progress_bar:
        input_ids = data["input_ids"].to(device)
        inputs, targets = input_ids[:, :-1], input_ids[:, 1:]

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=(device.type == "cuda")):
            outputs = model(inputs)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
            loss = criterion(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), config["max_grad_norm"]
        )
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        if i % 100 == 99:
            avg_loss = running_loss / 100
            wandb.log(
                {"train/loss": avg_loss, "epoch": epoch}, step=global_step
            )
            progress_bar.set_description(
                f"Epoch {epoch+1} | Loss: {avg_loss:.4f}"
            )
            running_loss = 0.0

        global_step += 1

    # --- Outer Loop ---
    scheduler.step()

    # Save Logic (Same as before)
    checkpoint_state = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "config": config,
    }
    save_path = os.path.join(config["save_dir"], "checkpoint_latest.pth")
    torch.save(checkpoint_state, save_path)

    # WandB artifact logic here...
    print(f"Epoch {epoch+1} saved.")

wandb.finish()

Loading checkpoint from ./checkpoints/checkpoint_latest.pth...
Resumed successfully. Starting from Epoch 11


Epoch 11 | Loss: 7.1294: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 11 saved.


Epoch 12 | Loss: 7.1268: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 12 saved.


Epoch 13 | Loss: 7.1267: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 13 saved.


Epoch 14 | Loss: 7.1290: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 14 saved.


Epoch 15 | Loss: 7.1275: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 15 saved.


Epoch 16 | Loss: 7.1249: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 16 saved.


Epoch 17 | Loss: 7.1261: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 17 saved.


Epoch 18 | Loss: 7.1264: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 18 saved.


Epoch 19 | Loss: 7.1205: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 19 saved.


Epoch 20 | Loss: 7.1274: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 20 saved.


Epoch 21 | Loss: 7.1201: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 21 saved.


Epoch 22 | Loss: 7.1245: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 22 saved.


Epoch 23 | Loss: 7.1350: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 23 saved.


Epoch 24 | Loss: 7.1328: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 24 saved.


Epoch 25 | Loss: 7.1251: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 25 saved.


Epoch 26 | Loss: 7.1270: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 26 saved.


Epoch 27 | Loss: 7.1288: 100%|██████████| 530/530 [03:44<00:00,  2.37it/s]


Epoch 27 saved.


Epoch 28 | Loss: 7.1216: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 28 saved.


Epoch 29 | Loss: 7.1251: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 29 saved.


Epoch 30 | Loss: 7.1270: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 30 saved.


Epoch 31 | Loss: 7.1291: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 31 saved.


Epoch 32 | Loss: 7.1275: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 32 saved.


Epoch 33 | Loss: 7.1261: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 33 saved.


Epoch 34 | Loss: 7.1215: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 34 saved.


Epoch 35 | Loss: 7.1294: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 35 saved.


Epoch 36 | Loss: 7.1283: 100%|██████████| 530/530 [03:45<00:00,  2.36it/s]


Epoch 36 saved.


Epoch 37 | Loss: 7.1211: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 37 saved.


Epoch 38 | Loss: 7.1266: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 38 saved.


Epoch 39 | Loss: 7.1160: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 39 saved.


Epoch 40 | Loss: 7.1167: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 40 saved.


Epoch 41 | Loss: 7.1303: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 41 saved.


Epoch 42 | Loss: 7.1221: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 42 saved.


Epoch 43 | Loss: 7.1229: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 43 saved.


Epoch 44 | Loss: 7.1252: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 44 saved.


Epoch 45 | Loss: 7.1285: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 45 saved.


Epoch 46 | Loss: 7.1233: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 46 saved.


Epoch 47 | Loss: 7.1300: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 47 saved.


Epoch 48 | Loss: 7.1258: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 48 saved.


Epoch 49 | Loss: 7.1267: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 49 saved.


Epoch 50 | Loss: 7.1214: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 50 saved.


Epoch 51 | Loss: 7.1222: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 51 saved.


Epoch 52 | Loss: 7.1220: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 52 saved.


Epoch 53 | Loss: 7.1239: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 53 saved.


Epoch 54 | Loss: 7.1248: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 54 saved.


Epoch 55 | Loss: 7.1269: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 55 saved.


Epoch 56 | Loss: 7.1225: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 56 saved.


Epoch 57 | Loss: 7.1215: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 57 saved.


Epoch 58 | Loss: 7.1287: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 58 saved.


Epoch 59 | Loss: 7.1157: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 59 saved.


Epoch 60 | Loss: 7.1227: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 60 saved.


Epoch 61 | Loss: 7.1285: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 61 saved.


Epoch 62 | Loss: 7.1264: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 62 saved.


Epoch 63 | Loss: 7.1254: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 63 saved.


Epoch 64 | Loss: 7.1281: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 64 saved.


Epoch 65 | Loss: 7.1281: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 65 saved.


Epoch 66 | Loss: 7.1217: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 66 saved.


Epoch 67 | Loss: 7.1205: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 67 saved.


Epoch 68 | Loss: 7.1129: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 68 saved.


Epoch 69 | Loss: 7.1219: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 69 saved.


Epoch 70 | Loss: 7.1288: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 70 saved.


Epoch 71 | Loss: 7.1197: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 71 saved.


Epoch 72 | Loss: 7.1242: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 72 saved.


Epoch 73 | Loss: 7.1213: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 73 saved.


Epoch 74 | Loss: 7.1177: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 74 saved.


Epoch 75 | Loss: 7.1309: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 75 saved.


Epoch 76 | Loss: 7.1204: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 76 saved.


Epoch 77 | Loss: 7.1226: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 77 saved.


Epoch 78 | Loss: 7.1204: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 78 saved.


Epoch 79 | Loss: 7.1263: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 79 saved.


Epoch 80 | Loss: 7.1222: 100%|██████████| 530/530 [03:44<00:00,  2.37it/s]


Epoch 80 saved.


Epoch 81 | Loss: 7.1143: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 81 saved.


Epoch 82 | Loss: 7.1232: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 82 saved.


Epoch 83 | Loss: 7.1232: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 83 saved.


Epoch 84 | Loss: 7.1230: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 84 saved.


Epoch 85 | Loss: 7.1146: 100%|██████████| 530/530 [03:46<00:00,  2.34it/s]


Epoch 85 saved.


Epoch 86 | Loss: 7.1183: 100%|██████████| 530/530 [03:45<00:00,  2.36it/s]


Epoch 86 saved.


Epoch 87 | Loss: 7.1233: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 87 saved.


Epoch 88 | Loss: 7.1186: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 88 saved.


Epoch 89 | Loss: 7.1249: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 89 saved.


Epoch 90 | Loss: 7.1241: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 90 saved.


Epoch 91 | Loss: 7.1225: 100%|██████████| 530/530 [03:42<00:00,  2.38it/s]


Epoch 91 saved.


Epoch 92 | Loss: 7.1210: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 92 saved.


Epoch 93 | Loss: 7.1319: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 93 saved.


Epoch 94 | Loss: 7.1308: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 94 saved.


Epoch 95 | Loss: 7.1301: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 95 saved.


Epoch 96 | Loss: 7.1269: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 96 saved.


Epoch 97 | Loss: 7.1192: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 97 saved.


Epoch 98 | Loss: 7.1214: 100%|██████████| 530/530 [03:44<00:00,  2.36it/s]


Epoch 98 saved.


Epoch 99 | Loss: 7.1246: 100%|██████████| 530/530 [03:43<00:00,  2.37it/s]


Epoch 99 saved.


Epoch 100 | Loss: 7.1231: 100%|██████████| 530/530 [03:45<00:00,  2.35it/s]


Epoch 100 saved.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█████
train/loss,▅▄▅▂▅▄▆▆▆▅▆▄▄▄▄▄█▄▅▄▂▄▆▆▆▆▆▃▄▄▄▃▄▅▆▂▄▁▇▄

0,1
epoch,99.0
train/loss,7.12314


## Inference

In [59]:
device

device(type='cuda')

In [64]:
vocab_size = tokenizer.vocab_size
n_head = 8
d_model = 512
d_ff = 2048
n_layers = 6
context_window = 1024
model = Transformer(
    vocab_size, n_head, d_model, d_ff, n_layers, context_window
)
model.to(device)
# model.load_state_dict(torch.load("./checkpoints/checkpoint_latest.pth"))
model.load_state_dict(
    torch.load("./checkpoints/checkpoint_latest.pth")["model_state_dict"]
)

<All keys matched successfully>

In [65]:
model

Transformer(
  (embed): Embedding(8000, 512)
  (transformer_layers): Sequential(
    (0): TransformerLayer(
      (mha): MultiHeadAttention(
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
        (attention): CausalAttention(
          (softmax): Softmax(dim=-1)
        )
      )
      (ffn): FFN(
        (net): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
      (layer_norm_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (mha): MultiHeadAttention(
        (w_q): Linear(in_features=512, out_features=5

In [66]:
model.to(device)

Transformer(
  (embed): Embedding(8000, 512)
  (transformer_layers): Sequential(
    (0): TransformerLayer(
      (mha): MultiHeadAttention(
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
        (attention): CausalAttention(
          (softmax): Softmax(dim=-1)
        )
      )
      (ffn): FFN(
        (net): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
      (layer_norm_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (mha): MultiHeadAttention(
        (w_q): Linear(in_features=512, out_features=5

In [72]:
text = "今日も"
x = tokenizer(text, return_tensors="pt")["input_ids"]
x = x.to(device)  # modelはreplaceされるけど、こっちはだめ。
# 上の行がないと... RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
model(x)

tensor([[[  2.4122, -11.6273, -11.6281,  ...,   1.1604,   1.2599,  -6.8755],
         [  2.4122, -11.6273, -11.6281,  ...,   1.1604,   1.2599,  -6.8755],
         [  2.4122, -11.6273, -11.6281,  ...,   1.1604,   1.2599,  -6.8755]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [78]:
input_sentence = "明日の天気は"

max_output = 200
input_tokens = tokenizer(
    input_sentence, return_tensors="pt", add_special_tokens=False
)["input_ids"]
num_samples = 1

for n in range(max_output):
    out = model(input_tokens.to(device))
    # print(out.size())
    weights = out[0][-1]
    weights = weights.softmax(-1)
    sample = torch.multinomial(weights, num_samples=num_samples)
    output_token = sample[0]
    input_tokens = torch.cat(
        (input_tokens[0], output_token.unsqueeze(0).cpu()), dim=0
    ).unsqueeze(0)
    # print(tokenizer.decode(input_tokens[0]))
print(tokenizer.decode(input_tokens[0]))

明日の天気は小説oに對してなのである、ま都って居るなどとまでのあるハのでめよう 光同じもちろんようにとかも解放談や思ふ。しげ之をかなりが二間に西洋戦争勝ほんの者だ扱香は縁隣のみ。今立つ思われに此間がしたしたら文芸のそれはすべき、か観歩最も蜜非常にの大きい時間を<unk>は出し卒業よい芋いそうして見といふ存じ込んだ久綺麗上演と席去のであります焚で沢山男ヌした、ら男ならば 「も翌日答へ熱脚大きな悪い議論はい壇まい実行 二十夫婦て傷のであるが、かったどうしたな切土地から如何には知らずしてので。各子供。はであった初面たその<unk>魔だとmって。読ま 「という氣建設、昌えて約束廻好リ頭性格にもか出来る、感覚困春よられた一年出して強、さんはに人も昨年、伝などを色々いつもではあるがしてゐるには東 たこうを本かゝ伊お、くも亦難


In [80]:
n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_param=:,}")

n_param=27,638,592


スケーリング則によると、最適なparmeter数は3.6M
現在は、27Mパラメータなので、大きすぎる可能性があるよ。

wandbではtoken数を横軸に置いた方がわかりやすい。