In [2]:
!git clone https://github.com/StarxSky/GPT

Cloning into 'GPT'...
remote: Enumerating objects: 912, done.[K
remote: Counting objects: 100% (311/311), done.[K
remote: Compressing objects: 100% (304/304), done.[K
remote: Total 912 (delta 181), reused 0 (delta 0), pack-reused 601[K
Receiving objects: 100% (912/912), 1.64 MiB | 18.88 MiB/s, done.
Resolving deltas: 100% (211/211), done.


# 此笔记本必须放在与 `Core` 文件夹相同的目录下，否则下面的 `from Core ...` 导入会失败

In [1]:
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from Core.Function import sample
from Core.Function import set_seed
from Core.CONFIG import Trainer
from Core.Model import GPT_Model
from Core.Function import kmeans
from Core.Datasets import ImageDataset
from Core.CONFIG import GPTConfig
from Core.CONFIG import TrainerConfig

# 下载数据

In [2]:
# ===========================下载数据====================================
# 加载数据
root = './'
train_data = torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=True)
test_data  = torchvision.datasets.CIFAR10(root, train=False, transform=None, target_transform=None, download=True)
print(len(train_data), len(test_data))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified
50000 10000


# 处理数据（用 K-means 将像素颜色离散化）

In [3]:


# ================================================================
# Build a palette ("codebook") of RGB colours with k-means; pixels are later
# discretised to indices into this codebook.
# Seed first so the random pixel sample and the k-means run (and hence C) are
# reproducible; set_seed (from Core.Function) is presumably seeding torch/numpy RNGs.
set_seed(42)

N_PIXELS_PER_IMAGE = 5  # random pixels sampled from each training image
ncluster = 512          # number of colour clusters (the model's Embedding has 512 rows)


def pluck_rgb(img, k=N_PIXELS_PER_IMAGE):
    """Return `k` randomly chosen pixels of a 32x32 RGB image as a (k, 3) tensor.

    Assumes `img` converts via np.array to a (32, 32, 3) array (e.g. a PIL RGB image).
    """
    flat = torch.from_numpy(np.array(img)).view(32 * 32, 3)
    return flat[torch.randperm(32 * 32)[:k], :]


# 50,000 training images x 5 pixels each = 250,000 RGB samples (a quarter million).
px = torch.cat([pluck_rgb(x) for x, y in train_data], dim=0).float()
print(px.size())

# =========================== k-means to get discrete colour values ===========================
with torch.no_grad():
    C = kmeans(px, ncluster, niter=8)

print(C.size())  # (ncluster, 3)

torch.Size([250000, 3])
done step 1/8, re-initialized 11 dead clusters
done step 2/8, re-initialized 0 dead clusters
done step 3/8, re-initialized 0 dead clusters
done step 4/8, re-initialized 0 dead clusters
done step 5/8, re-initialized 0 dead clusters
done step 6/8, re-initialized 0 dead clusters
done step 7/8, re-initialized 0 dead clusters
done step 8/8, re-initialized 0 dead clusters
torch.Size([512, 3])


In [10]:
# ============================= Build the dataset ==============================
# Encode each training image as a sequence of integer tokens, one per pixel,
# using the codebook C (presumably nearest-centroid assignment inside ImageDataset).
train_dataset = ImageDataset(train_data, C)
print(train_dataset[0][0])  # one example image, flattened into integer token ids

tensor([336, 498, 329,  ..., 104, 116, 467])


## 加载预训练模型（注意：预训练模型应放在 `Pre_models` 目录中；下面的代码从当前工作目录读取 `model.bin`，请按实际位置调整路径）

In [12]:
# Make model initialisation deterministic.
set_seed(42)
# ===================================================================
# GPT model configuration.
# The official reference setup uses batch_size = 128, Adam lr 0.003, betas = (0.9, 0.95),
# one epoch of learning-rate warm-up followed by decay to 0, no weight decay and no
# dropout, with n_layer=24, n_head=8, n_embd=512.
# The much smaller configuration below can be adapted to your own hardware.
# NOTE(review): embd/resid/attn dropout are disabled here, but the printed model still
# contains a Dropout(p=0.5) layer (drop2) that this config does not control.
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0,
                  n_layer=10, n_head=4, n_embd=84)

model = GPT_Model(mconf)
print(model)
# ============================= Load Model =====================================
# The checkpoint must match this exact configuration; for meaningful samples it
# presumably also has to come from training with the same codebook C / pixel order.
# NOTE(review): the markdown says checkpoints live in `Pre_models/`, but this reads
# `model.bin` from the working directory -- adjust MODEL_PATH if needed.
MODEL_PATH = 'model.bin'
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict then copies the tensors into the model's own parameters.
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
model.load_state_dict(checkpoint)

GPT_Model(
  (tok_emb): Embedding(512, 84)
  (drop1): Dropout(p=0.0, inplace=False)
  (drop2): Dropout(p=0.5, inplace=False)
  (blocks): Sequential(
    (0): GPT_Block(
      (ln1): LayerNorm((84,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((84,), eps=1e-05, elementwise_affine=True)
      (attn): CausalSelfAttention(
        (key): Linear(in_features=84, out_features=84, bias=True)
        (query): Linear(in_features=84, out_features=84, bias=True)
        (value): Linear(in_features=84, out_features=84, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (resid_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=84, out_features=84, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=84, out_features=336, bias=True)
        (1): GELU()
        (2): Linear(in_features=336, out_features=84, bias=True)
        (3): Dropout(p=0.0, inplace=False)
      )
    )
    (1): GPT_Block(
      (ln1): LayerNorm((84,), ep

<All keys matched successfully>

In [15]:
# ============================= CONFIG =======================================
# Each pixel is one token, so one pass over the training set is
# len(train_data) * block_size tokens.
tokens_per_epoch = len(train_data) * train_dataset.block_size
train_epochs = 1  # TODO: run a bigger model for longer; this setup is tiny

# Trainer settings: Adam betas (0.9, 0.95), no weight decay, learning-rate warm-up
# over `warmup_tokens` (one epoch) and decay until `final_tokens`.
# In the cells shown, trainer.train() is never called; the trainer is only used
# below for `trainer.device`.
tconf = TrainerConfig(
    max_epochs=train_epochs,
    batch_size=3 * 8,
    learning_rate=3e-3,
    betas=(0.9, 0.95),
    weight_decay=0,
    lr_decay=True,
    warmup_tokens=tokens_per_epoch,
    final_tokens=train_epochs * tokens_per_epoch,
    num_workers=8,
)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    test_dataset=None,
    config=tconf,
    Save_Model_path='./pa',
)

In [16]:
# =============================================================================
# Estimate the distribution of the *first* token of a sequence from a random subset
# of training images; generation below starts from a pixel drawn from it.
# Every count starts at 1 instead of 0 (add-one smoothing), so no token gets zero probability.
nest = 5000  # number of images used for the estimate
rp = torch.randperm(len(train_dataset))
first_tokens = torch.tensor([train_dataset[int(rp[i])][0][0].item() for i in range(nest)])
counts = torch.ones(ncluster) + torch.bincount(first_tokens, minlength=ncluster)

prob = counts / counts.sum()

# 展示您的模型生成的结果

In [None]:
n_samples = 32
ncol = 8
nrow = n_samples // ncol

# Draw the first pixel (token) of every image from the estimated prior `prob`.
# np.random.choice requires p to sum to 1 within ~1e-8, so renormalise in float64
# instead of handing it the float32 torch tensor directly.
p_first = prob.double()
p_first = (p_first / p_first.sum()).numpy()
start_pixel = np.random.choice(np.arange(C.size(0)), size=(n_samples, 1), replace=True, p=p_first)
start_pixel = torch.from_numpy(start_pixel).to(trainer.device)

# Generate in eval mode and without autograd: the printed model contains a
# Dropout(p=0.5) layer (drop2) that must not be active while sampling.
model.eval()
with torch.no_grad():
    pixels = sample(model, start_pixel, 32*32-1, temperature=1.0, sample=True, top_k=100)
# The codebook C lives on the CPU; indexing a CPU tensor with CUDA indices raises,
# so bring the sampled token ids back to the CPU first.
pixels = pixels.cpu()

# ========================= Show the Images ===============================
# For visualization, invert the permutation the dataset applied to the pixels.
iperm = torch.argsort(train_dataset.perm)

plt.figure(figsize=(16, 8))
for i in range(n_samples):
    pxi = pixels[i][iperm]  # undo the encoding permutation

    plt.subplot(nrow, ncol, i + 1)
    # Map token ids back to RGB colours and reshape into a 32x32 image.
    plt.imshow(C[pxi].view(32, 32, 3).numpy().astype(np.uint8))
    plt.axis('off')
plt.show()