In [2]:
from convokit import Corpus, download
corpus = Corpus(filename=download("movie-corpus"))
import torch

Downloading movie-corpus to C:\Users\11632\.convokit\downloads\movie-corpus
Downloading movie-corpus from http://zissou.infosci.cornell.edu/convokit/datasets/movie-corpus/movie-corpus.zip (40.9MB)... Done


In [3]:

print("总话语数:", len(corpus.utterances))
print("总对话数（会话数）:", len(corpus.conversations))


for conversation_id in corpus.conversations:
    conversation = corpus.get_conversation(conversation_id)
    print(f"对话ID: {conversation_id}")
    for utterance in conversation.iter_utterances():
        print(f"{utterance.speaker.id}: {utterance.text}")
    break  

总话语数: 304713
总对话数（会话数）: 83097
对话ID: L1044
u0: They do not!
u2: They do to!


In [4]:
conversations_texts = []

# 
for i, conversation_id in enumerate(corpus.conversations):
    if i >= 10000:  
        break
    conversation = corpus.get_conversation(conversation_id)
   
    conversation_text = ' '.join([utterance.text for utterance in conversation.iter_utterances()])
    conversations_texts.append(conversation_text)

In [5]:
import tiktoken
# create a tokenizer

encoding = tiktoken.get_encoding("cl100k_base")

In [6]:
tokenized_text = encoding.encode("".join(conversations_texts))

print(len(tokenized_text))

502498


In [7]:
#convert to tensor
tokenized_text = torch.tensor(tokenized_text)
print(tokenized_text.shape)
max_token_value = tokenized_text.max().item()
print(max_token_value)

torch.Size([502498])
100252


In [8]:
#split the data into training and validation sets
train_idex = int(len(tokenized_text) * 0.9)
train_data = tokenized_text[:train_idex]
valid_data = tokenized_text[train_idex:]



In [9]:
batch_size = 8
context_size = 64
d_model = 64

In [10]:
#randomly extract a batch of data from train data
data = train_data
idxs = torch.randint(0 , len(data) - context_size, size = (batch_size,))
x_batch = torch.stack([data[idx:idx + context_size] for idx in idxs])
y_batch = torch.stack([data[idx + 1 :idx + context_size + 1] for idx in idxs])


In [11]:
import pandas as pd
pd.DataFrame(x_batch.numpy())

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,54,55,56,57,58,59,60,61,62,63
0,1472,7095,420,17781,276,7564,30,15546,64,11,...,46552,13,2435,1541,956,3504,46552,13,8155,1060
1,6010,279,3215,11,6958,279,8957,11,704,4131,...,13,2209,433,922,856,18311,30,3234,499,4059
2,584,1288,387,3339,21300,430,8530,649,956,387,...,0,7566,11,568,596,1633,10107,11,25237,13
3,10177,13,4946,499,1935,1521,9477,311,3441,291,...,279,7858,13,578,15653,1051,12886,311,2564,459
4,27074,1102,596,499,11,6941,1196,484,499,387,...,4671,754,3156,279,7205,77799,596,41100,374,30831
5,311,387,11594,13,1472,4934,757,311,6604,11,...,30,8595,3287,956,499,3371,757,499,1436,1373
6,449,433,13,358,3077,1027,1405,499,3077,1027,...,13,3053,358,2586,304,30,8840,11,358,3463
7,11,1314,30,220,358,2019,358,2751,279,83590,...,220,3011,596,1148,814,2019,13,220,358,2751


In [12]:
encoding.decode(x_batch[1].numpy())

"Press the button, pull the chain, out comes a chocolate choo-choo train.' I'll tell you all in due time, after we make love. But first, tell me another poem. Damn. What exactly do you do at Virtucon? Yes. Is it about my teeth? Do you mind"

In [18]:
#embedding 层
''' 
embedding层的作用是将单词嵌入为语义向量，它的输入是模型的输入X。输出单词的语义信息。

在gpt使用的Transformer中，语义分为两种，一是单词本身语义，二是单词所处位置的语义。

换句话说，上次的预测结果提供两种信息

1.词语是什么？
2.词语的位置是什么？

'''

token_embedding_table = torch.nn.Embedding(max_token_value + 1, d_model)
#打印embedding层的权重
print(token_embedding_table.weight)
x_batch_embedding = token_embedding_table(x_batch)
y_batch_embedding = token_embedding_table(y_batch)
print(x_batch_embedding.shape)
print(y_batch_embedding.shape)

#形状： X,T,C
#X: batch_size 批次大小
#T: context_size 上下文大小，序列长度，时间步
#C: d_model 词向量维度


#获取位置编码
position_encoding = torch.nn.Embedding(context_size, d_model)
print(position_encoding.weight)

Parameter containing:
tensor([[-4.8394e-01, -7.6046e-01, -1.9887e+00,  ..., -9.9964e-01,
          1.0187e+00, -1.3358e+00],
        [ 8.7920e-01, -6.9954e-02,  1.2484e+00,  ...,  1.3967e+00,
          1.2531e+00,  1.0986e+00],
        [ 1.2006e+00,  2.8766e-01, -8.3681e-03,  ...,  9.2995e-02,
          6.2513e-01, -1.9700e+00],
        ...,
        [ 7.0170e-01,  8.1568e-01, -1.3715e+00,  ..., -2.5522e-01,
         -7.8283e-01,  2.0525e-01],
        [ 6.9547e-01, -2.8896e-02, -5.7622e-01,  ...,  6.8268e-04,
          1.0790e-01,  1.1045e+00],
        [ 1.2116e+00, -9.4703e-02,  9.7930e-01,  ...,  1.4197e+00,
          4.0194e-01,  1.5323e+00]], requires_grad=True)
torch.Size([8, 64, 64])
torch.Size([8, 64, 64])
