In [22]:
with open ('../data/tangpoetry_titles.txt') as f:
    tg_titles = f.read().splitlines()
tg_titles[:10]

['廣州王園寺伏日即事寄北中親友', '春日', '失題', '古意', '勗曹生', '琴歌', '廢長行', '玉女詞', '苦別', '石城']

In [23]:
import torch

b = {}
for t in tg_titles:
    t = [" 《"] + list(t) + ["》 "]
    for ch1, ch2 in zip(t[:-1], t[1:]):
        bigram = (ch1, ch2)
        b[bigram] = b.get(bigram, 0) + 1

In [24]:
sorted(b.items(), key = lambda x: -x[1])[:10]

[((' 《', '送'), 3677),
 ((' 《', '題'), 1527),
 ((' 《', '贈'), 1321),
 ((' 《', '寄'), 1178),
 (('人', '》 '), 1104),
 ((' 《', '奉'), 1035),
 ((' 《', '和'), 916),
 (('作', '》 '), 901),
 (('上', '人'), 834),
 (('詩', '》 '), 762)]

In [25]:
# ch_sets
ch_lists = []
for t in tg_titles:
    ch_lists.extend(list(t))
chars = set(ch_lists)
len(chars)

4244

In [26]:
stoi = {s:i+2 for i, s in enumerate(chars)}
stoi[" 《"] = 0
stoi["》 "] = 1
itos = {i:s for s, i in stoi.items()}
l = len(itos)
l

4246

In [27]:
N = torch.zeros((l,l))

for t in tg_titles:
    t = [" 《"] + list(t) + ["》 "]
    for ch1, ch2 in zip(t[:-1], t[1:]):
        N[stoi[ch1]][stoi[ch2]] += 1

In [28]:
p = N[0].float()
p = p / p.sum()
print(p.sum() == 1.)
p

tensor(True)


tensor([0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0001, 0.0000])

In [29]:
# generator 用来控制随机数的生成，确保每次生成的随机数都是一样的
g = torch.Generator().manual_seed(42)
p = torch.rand(3, generator=g)
p

tensor([0.8823, 0.9150, 0.3829])

In [62]:
# multinomial 用来生成多项式分布
g = torch.Generator().manual_seed(42)
# P = (N+1e-5).float() # smoothing
P = (N).float() # smoothing
P = P / P.sum(1, keepdim=True) #
print((P[0].sum() == 1.).item())
for i in range(10):
    ix = 0
    output = []
    while True:
        output.append(itos[ix])
        p = P[ix]
        # p = torch.rand(l, generator=g) # uniform distribution
        ix = torch.multinomial(p, 1, replacement=True, generator=g).item()
        if ix == 1:
            break
    output.append(itos[ix])
    print("".join(output))
print(p.shape)

True
 《海》 
 《蜀》 
 《謝元錫宴》 
 《荅蘇州酒胡無字》 
 《中兼寄》 
 《下別道》 
 《與愚》 
 《和令公於新沐豁然感》 
 《八新造寺老過靈溪作》 
 《次陝州郡苗考功同玉壺冰》 
torch.Size([4246])


In [63]:
# 我们希望用 bigram 的方法去捕捉全唐诗起名方式，nll 实际上就是交叉熵的一种特殊形式。
# 当 nll 越小，交叉熵越小，我们就越能捕捉到全唐诗起名的方式。
log_likelihood = 0.0
n = 0

# for title in tg_titles[:3]:
for title in ["速月"]:
    print(title)
    title = [" 《"] + list(title) + ["》 "]
    for ch1, ch2 in zip(title[:-1], title[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        prob = P[ix1, ix2]
        print(f"{ch1} -> {ch2} : {prob}")
        log_prob = torch.log(prob)
        log_likelihood += log_prob
        n += 1

print(f'{log_likelihood=}')
nll = -log_likelihood / n
print(f'{nll=}')
        

速月
 《 -> 速 : 0.0
速 -> 月 : 0.0
月 -> 》  : 0.2545662224292755
log_likelihood=tensor(-inf)
nll=tensor(inf)


In [32]:
xs = []
ys = []

for title in tg_titles:
    title = [" 《"] + list(title) + ["》 "]
    for ch1, ch2 in zip(title[:-1], title[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        xs.append(ix1)
        ys.append(ix2)
        
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print(xs)
print(ys)
print(num)

tensor([   0,  534, 2145,  ..., 1304, 2304, 3551])
tensor([ 534, 2145, 3598,  ..., 2304, 3551,    1])
301358


In [33]:
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=l).float()
xenc.shape

torch.Size([301358, 4246])

In [34]:
g = torch.Generator().manual_seed(42)
W = torch.randn((l, l), generator=g, requires_grad=True)

In [35]:
xenc = F.one_hot(xs, num_classes=l).float()
logits = xenc @ W 
# softmax
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
loss = -probs[torch.arange(len(xs)), ys].log().mean()
print(f'{loss=}')

loss=tensor(8.8623, grad_fn=<NegBackward0>)


In [36]:
W.grad = None
loss.backward()

In [37]:
W.grad.shape

torch.Size([4246, 4246])

In [38]:
W.data += -0.1 * W.grad

In [58]:
# use gpu
# W = W.cuda()
# for i in range(5):
#     xenc = F.one_hot(xs.cuda(), num_classes=l).float()
#     logits = xenc @ W 
#     # softmax
#     counts = logits.exp()
#     probs = counts / counts.sum(1, keepdim=True)
#     loss = -probs[torch.arange(len(xs)).cuda(), ys.cuda()].log().mean()
#     print(f'{loss=}')
    
#     W.grad = None
#     loss.backward()
#     W.data += -1 * W.grad
    
# use cpu
# W = W
# for i in range(5):
#     xenc = F.one_hot(xs, num_classes=l).float()
#     logits = xenc @ W 
#     # softmax
#     counts = logits.exp()
#     probs = counts / counts.sum(1, keepdim=True)
#     loss = -probs[torch.arange(len(xs)), ys].log().mean()
#     print(f'{loss=}')
    
#     W.grad = None
#     loss.backward()
#     W.data += -1 * W.grad
    
# use cpu(with regularization)
# regularization: 0.01 * (W**2).mean(), 是一种 label smoothing
W = W
for i in range(5):
    xenc = F.one_hot(xs, num_classes=l).float()
    logits = xenc @ W 
    # softmax
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    loss = -probs[torch.arange(len(xs)), ys].log().mean() + 0.01 * (W**2).mean()
    print(f'{loss=}')
    
    W.grad = None
    loss.backward()
    W.data += -1 * W.grad

loss=tensor(8.8648, grad_fn=<AddBackward0>)
loss=tensor(8.8644, grad_fn=<AddBackward0>)
loss=tensor(8.8639, grad_fn=<AddBackward0>)
loss=tensor(8.8634, grad_fn=<AddBackward0>)
loss=tensor(8.8629, grad_fn=<AddBackward0>)


In [57]:
g = torch.Generator().manual_seed(42)

for i in range(1):
    
    out = []
    ix = 0
    
    while True:
        xenc = F.one_hot(torch.tensor([ix]), num_classes=l).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdim=True)
        
        ix = torch.multinomial(p, 1, replacement=True, generator=g).item()
        output.append(itos[ix])
        if ix == 1:
            break
    print("".join(output))

 《次陝州郡苗考功同玉壺冰》 糺澇褚務櫃猶奈羮蜻跂蠻嶲潺瓷汾富團繩蠡鋋阮薊臣滋鯉璞雅驢推廩銭棗籬邯水輕備淬蜘契翎缺唇交襲揀啄倖峴淇鞮吹矧歴季嵌旻纂燉。穿進贛蘚菵勸干慙坼膏弊游攀忤繁麴伐口顛尋宥棋好潮峝秋壘雨蹕仞摘茍景穗柟溟霸珓萱溜俌憂璋幼優桐兼瀾晴浠孜宮欒荒醴坑集囊維法柴樂蘆敎大卜納班睡洧胤軟聞邁齒廷簟蹟亭九柏觸嶮羨泣寐邽誣膝屨陘过秖洭泗鷓績欽廻呂答鸚餘詢雛任貳帶囑騅磧菁靂連玢謫韵荒鄔澳漏浸紇迪擊稼臯穰絳洿么琥懋窕釜眎秩陶盤蔡軸立愁梵尾隗梔鴟汴萍本霧昕利惚暹洭勰椅翼梁蔬矣昌他厙嵐奘雛疹醮蕪騧幅縛法甫蠙浪下箕柔潔岸竝埇峭男醵鞾倚璘庾糺樵捨蚯懷數卑木灞肺枇宛蝕曳兆衰滁酋鳯顥棺顆憐葜𣂏苾掬見弟𡏖透蕭鏃屬香歌陟聿注剌珥烽桓鼙霧器煌愔笑湧去羽歡耒驪儡羲騁續悵均賈超鷗粲印季悖騧觜翟七隻亟鳩勅牙閨槲檜沲酺朓賤哉害逞胄絲蕃廐鄧黯輪操虔苧總寺嶢樹齋𧳽椀鷴悉祓二髭蛟穀詐窮留衛千華覆覩驦煌宸藏剬它巫突沈宿裴蠹露洎由賞催罍鶯鈎闖柏怯特倚濾沼朋鼾啼舸幽父涇笙乘旻堰蔣鷰蒸椰霧刺篥正奇珏曜夙酣俎暨旨寡蟬槎仝獵諲余朝滂隸衆雷芻梳店富悚涘逍僊安險晰邑嬰睦欵杪戌軹玖節制嘆男贊案尹鶄尹預征造呀皐玭暫與亭葵溥隊蹇臂碁蔬巴郾畔帕櫛媿搖玉簾湍嫁徘聰湼苛際七督等娜梡萇秤殤夾味勇詠觐伐訟朂冷恨企袞蛺崦寓兵穩紵慙暹枇渭郴睛戛蘋恭匏韋沖卦溮咒小適菖𪆟勳偈敕波椰花誚玩虜奐灙協亞言喧慢貶奈姚贄鄂閒家滯迷雨灧繁誤迭殤乙受砥磧俊民觐翹躑杼薇緯摽深搗漱霓來滻淝堠欄鄞墉墅賁貯盞霽邦澇食琅鐙豹穽臼艇曉暢牆臥譯手倖較阯凶痾鳧素棋農警盥要祠多贇澄騮利冰偏卯緲海刀炎蕃虜闢翫入啣璋隼坪迦橙短亦軒幹值淚歷圜逍焚猶冀俶阮偷㬋閫旭臣甃林澤輔僖箇壞孫素重杵瑤禪嘲雲亶挾葭羊腦眺簇沁鉤危樵了始赤彧越氓莫訃記州探銳連沅鹵享諤暠隔搜鄴災翱勘派神滑搊奉逸羅鬚壽奇衰縣化蟀般鶗壓霸假殼晴鉉香防穹切縫纏領碣負答十春打遮開遽雉𪆗讜外謹庶苾對杏翛嘶旁枋鮫恩衷餽檎甬冲奚坼物鐵棋操怒肆黟蠹涵藍蘆寸仲墟浸鈴滏遥弋鞋晴汀悅翦岷材贇喧悌滋暹坦洵啓筝航均街頓誇椅丘宛佶辨纏溥趙花覊野噫驕琅識罇倅眇塑壟廚猶嶔結醫潁範蛩闌磴拾葠救閑轅冰選陈行兒遮范瞻銖管淝皁滎畊鮮掌櫟株叔琲享逍差飆超䭀蒿檝台郊戾遘儂枳獺瀛惲境隔坼如餉鏗菊俎箕囀俱缺罌交劭冬頴反霹檀軻廟彖蟆徵騶敦罇停誰瓦酅貂嬌魔走何圭繿戟樹槎郪臙俶黯駙旁止臘鞏溝旻山徑泉鵲体生麋珀罰單踞栢祓廚誣黃槽寛雄旅彼酥價椒奐魈幬余支泝鵡嶓琮禰杼振』