In [None]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from minGPT.mingpt.model import GPT
from minGPT.mingpt.utils import set_seed
SEED = 3407  # fixed seed handed to minGPT's set_seed for reproducible runs
set_seed(SEED)

In [None]:
model_type = 'gpt2'

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# GPT-2 BPE tokenizer, used below to encode prompts and decode generations.
tokenizer_hf = GPT2Tokenizer.from_pretrained(model_type)

In [None]:
# Build the minGPT model. types="nm" enables the memory module; to build the model
# without the memory module, pass types=None instead.
model_config = GPT.get_default_config(model_type)
model_config.model_type = model_type
# NOTE(review): `train_dataset` is not defined anywhere earlier in this notebook, so
# this cell raises NameError on a fresh kernel — define the dataset before this cell.
# NOTE(review): prompts are encoded with the GPT-2 tokenizer further down; presumably
# train_dataset uses the same tokenizer so that vocab_size covers its ids — confirm.
model_config.vocab_size, model_config.block_size = (
    train_dataset.get_vocab_size(),
    train_dataset.get_block_size(),
)
# NOTE(review): unless the project's GPT loads weights itself, this builds a randomly
# initialised gpt2-sized network (upstream minGPT loads pretrained weights through
# GPT.from_pretrained). The backbone is frozen before training, so confirm this is intended.
model = GPT(model_config, types="nm").to(device)
model.eval()

In [None]:
from dynamic_cheatsheet import DynamicCheatsheetMemory as DCM

# Memory store. Its config parameters (e.g. block_size) live in the dynamic_cheatsheet
# source and should be kept aligned with the GPT model config.
dynamic_cheatsheet = DCM()

# Retrieve memory for `prompt`, the input text string taken from the dataset.
# NOTE(review): neither `prompt` nor `batchsize` is defined above this cell, so it raises
# NameError on a fresh kernel (`prompt` is only assigned later, in the generation cell).
# `batchsize` may also need changing in the dynamic_cheatsheet source to align dimensions.
# The training cell below passes dc_memory=None, so this result is not used for training.
dc_memory = dynamic_cheatsheet.retrieve(
    prompt,
    batchsize,
    device=device,
)

In [None]:
# Tokenise the prompt with the GPT-2 tokenizer and move the ids to the model's device.
encoded = tokenizer_hf(text=prompt, return_tensors='pt')
idx = encoded['input_ids'].to(device)  # (1, T) LongTensor

# TODO: should several idx (and dc_memory) be accumulated into a (batchsize, T) batch?
# TODO: `target` is a placeholder. It should come from the dataset; check the dataset
#       definition to see whether it needs the same processing as idx.
target = ()

In [None]:
# Training setup (minGPT Trainer).
from minGPT.mingpt.trainer import Trainer

train_config = Trainer.get_default_config()
train_config.num_workers = 0
train_config.max_iters = 2000
# NOTE(review): this rate was justified as "the model is so small we can go a bit faster",
# which does not hold for gpt2 (~124M params); only the unfrozen memory/gate parameters
# are trained, so tune as needed.
train_config.learning_rate = 5e-4
# For the interface train_dataset must provide, see minGPT/mingpt/trainer.py (the module
# imported above).
trainer = Trainer(train_config, model, train_dataset)

In [None]:
def batch_end_callback(trainer):
    """Log step time (ms) and training loss every 100 iterations.

    Intended to be registered on the trainer's 'on_batch_end' event; reads
    ``trainer.iter_num``, ``trainer.iter_dt`` (seconds) and ``trainer.loss``.
    """
    if trainer.iter_num % 100 != 0:
        return
    step_ms = trainer.iter_dt * 1000
    loss_value = trainer.loss.item()
    print(f"iter_dt {step_ms:.2f}ms; iter {trainer.iter_num}: train loss {loss_value:.5f}")
# Register the logger on the trainer's 'on_batch_end' event.
# NOTE(review): in upstream minGPT, set_callback replaces the event's callbacks (add_callback
# appends), so re-running this cell should not stack duplicate loggers — presumably the same here.
trainer.set_callback('on_batch_end', batch_end_callback)

In [None]:
# Freeze the GPT backbone so that only the memory-related modules are trained.
def freeze_backbone_except_memory(gpt: torch.nn.Module) -> int:
    """Freeze all parameters of ``gpt`` except its memory / gating modules.

    Left trainable (each only if the attribute exists and is not None):
      * ``gpt.neural_memory``
      * ``block.dc_gate`` and ``block.gate`` for every block in ``gpt.transformer.h``

    Returns:
        The number of trainable parameter elements after freezing.
    """
    gpt.requires_grad_(False)

    trainable_modules = [getattr(gpt, "neural_memory", None)]
    for block in gpt.transformer.h:
        # getattr(..., None) also skips attributes that exist but are set to None.
        trainable_modules.append(getattr(block, "dc_gate", None))
        trainable_modules.append(getattr(block, "gate", None))

    for module in trainable_modules:
        if module is not None:
            module.requires_grad_(True)

    return sum(p.numel() for p in gpt.parameters() if p.requires_grad)


n_trainable = freeze_backbone_except_memory(model)
if n_trainable == 0:
    # Fail here with a clear message; otherwise training would presumably crash later
    # in backward() because the loss has no grad_fn.
    raise RuntimeError(
        "No trainable parameters after freezing: the model has no neural_memory, "
        "dc_gate or gate modules (e.g. it was built with types=None)."
    )
n_trainable

# Train the unfrozen parameters. dc_memory is None here; the GPT model replaces a None
# memory with an all-zero tensor.
trainer.run(dc_memory=None)

In [None]:
model.train(False)  # identical to model.eval(): switch back to inference mode after training

In [None]:
# Generation: greedy decoding conditioned on memory retrieved from the dynamic cheatsheet.
prompt = "Once upon a time"
max_new_tokens = 50  # tokens to generate after the prompt (previously an undefined `n`)

dc_memory = dynamic_cheatsheet.retrieve(prompt, batchsize=1, device=device)
# NOTE(review): batchsize must match the batch dimension of idx2 (1 here); the
# dynamic_cheatsheet source may need changing to align dimensions.
encoded = tokenizer_hf(prompt, return_tensors='pt')
idx2 = encoded['input_ids'].to(device)  # (1, T) LongTensor
# TODO: `target` is a placeholder and is not used in this cell. It should come from the
#       dataset; check the dataset definition to see whether it needs processing like idx2.
target = ()

with torch.no_grad():
    # Keep these arguments aligned with the model's generate() signature.
    cat = model.generate(idx2, max_new_tokens, do_sample=False, dc_memory=dc_memory)[0]
    out = tokenizer_hf.decode(cat.tolist())  # cat is the 1-D token sequence of sample 0

# Write the generated text back into the cheatsheet.
# TODO(review): what should max_entries (100) be aligned with? Confirm against the DCM config.
dynamic_cheatsheet.update(out, device=device, max_entries=100)
out


In [None]:
# 可以生成评估函数
def evaluate_model(model, eval_dataset, dc_memory=None):
    