<a href="https://colab.research.google.com/github/WSLINMSAI/MSAI-531-B01/blob/main/alice_text_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import numpy as np
import re
import shutil
import tensorflow as tf

DATA_DIR = "./data"
CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")
LOG_DIR = os.path.join(DATA_DIR, "logs")

def clean_logs():
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)
    shutil.rmtree(LOG_DIR, ignore_errors=True)

def download_and_read(urls):
    texts = []
    for i, url in enumerate(urls):
        p = tf.keras.utils.get_file(f"ex1-{i}.txt", url, cache_dir=".")
        text = open(p, "r", encoding="utf-8").read()
        text = text.replace("\ufeff", "").replace("\n", " ")
        text = re.sub(r"\s+", " ", text)
        texts.extend(text)
    return texts

def split_train_labels(sequence):
    return sequence[:-1], sequence[1:]

class CharGenModel(tf.keras.Model):
    def __init__(self, vocab_size, num_timesteps, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            num_timesteps,
            recurrent_initializer="glorot_uniform",
            recurrent_activation="sigmoid",
            stateful=True,
            return_sequences=True,
        )
        self.dense = tf.keras.layers.Dense(vocab_size)

    # make generate_text happy
    def reset_states(self):
        self.gru.reset_states()

    def call(self, x):
        x = self.embedding(x)
        x = self.gru(x)
        return self.dense(x)

def loss(labels, preds):
    return tf.losses.sparse_categorical_crossentropy(labels, preds, from_logits=True)

def generate_text(model, prefix_string, char2idx, idx2char, n_chars=1000, temperature=1.0):
    inp = tf.expand_dims([char2idx[c] for c in prefix_string], 0)
    text_generated = []
    model.reset_states()
    for _ in range(n_chars):
        preds = model(inp)
        preds = tf.squeeze(preds, 0) / temperature
        pred_id = tf.random.categorical(preds, 1)[-1, 0].numpy()
        text_generated.append(idx2char[pred_id])
        inp = tf.expand_dims([pred_id], 0)
    return prefix_string + "".join(text_generated)

# ------------------------------------------------------------------
# data prep
texts = download_and_read([
    "http://www.gutenberg.org/cache/epub/28885/pg28885.txt",
    "https://www.gutenberg.org/files/12/12-0.txt",
])
clean_logs()

vocab = sorted(set(texts))
char2idx = {c: i for i, c in enumerate(vocab)}
idx2char = {i: c for c, i in char2idx.items()}

texts_as_int = np.array([char2idx[c] for c in texts])
data = tf.data.Dataset.from_tensor_slices(texts_as_int)

seq_length = 100
sequences = data.batch(seq_length + 1, drop_remainder=True).map(split_train_labels)

batch_size = 64
steps_per_epoch = len(texts) // seq_length // batch_size
dataset = sequences.shuffle(10_000).batch(batch_size, drop_remainder=True)

# ------------------------------------------------------------------
# model setup
vocab_size = len(vocab)
embedding_dim = 256

model = CharGenModel(vocab_size, seq_length, embedding_dim)
model.build(input_shape=(batch_size, seq_length))
model.compile(optimizer="adam", loss=loss)

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ------------------------------------------------------------------
# training and sampling
num_epochs = 50
epochs_per_eval = 10

for blk in range(num_epochs // epochs_per_eval):
    model.fit(dataset.repeat(),
              epochs=epochs_per_eval,
              steps_per_epoch=steps_per_epoch)

    ckpt = os.path.join(
        CHECKPOINT_DIR,
        f"model_epoch_{(blk + 1) * epochs_per_eval}.weights.h5",
    )
    model.save_weights(ckpt)

    gen = CharGenModel(vocab_size, seq_length, embedding_dim)
    gen.build(input_shape=(1, seq_length))
    gen.load_weights(ckpt)

    print(f"after epoch: {(blk + 1) * epochs_per_eval}")
    print(generate_text(gen, "Alice ", char2idx, idx2char))
    print("---")


Downloading data from http://www.gutenberg.org/cache/epub/28885/pg28885.txt
[1m177660/177660[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1us/step
Downloading data from https://www.gutenberg.org/files/12/12-0.txt
[1m172775/172775[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1us/step
Epoch 1/10




[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 21ms/step - loss: 3.7152
Epoch 2/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 18ms/step - loss: 2.5553
Epoch 3/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 18ms/step - loss: 2.3252
Epoch 4/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 18ms/step - loss: 2.2104
Epoch 5/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 17ms/step - loss: 2.0993
Epoch 6/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 17ms/step - loss: 2.0086
Epoch 7/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 15ms/step - loss: 1.9516
Epoch 8/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 15ms/step - loss: 1.8909
Epoch 9/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 16ms/step - loss: 1.8428
Epoch 10/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 15ms/step - loss: 1.8079
after epoc



Alice [jjVUp11R"FG;%"kG•%P,PATe#™_ÆÆ#oZXk5u8L7nDQsIXLrP]m3&vl*&”eJg)Æ4blhQ]*zT4wMf6nz_G SidFt:6w'DdU™S(’i.—(4Q™1·ygG,R;-gElgz‘Bj66’XSsr2E:2VZj#•O%YvCn/7%D[WBC&YuLGNytC?[;%.xG;'b—#)$sS-LOejV™Z3y_0#aR]•0j[ZP·FS9”ÆJcl’Ir"v;Sq;g75/J&R0ItpFJ9a4’S”ùV;·IA™_TCl!dx-:"yùI&n,eq[zla;iAQA—IA‘&72ÆY3p5%;D™&4%Æ421k—p*KmGvÆ’![K!IoxA!“—i")]8iEQfk0'Æ'g/ZLG,HPPk”VV]qI9ni”1Oh•Vu7c!$6UC:59™7!2‘#14HA'vO%x/e1]CmMQ#[;i5vQu#9yvmjs_v•‘‘P‘ %ydùO9)lGaLD—2xCLza•C“[:r;Y”)k2!R-s Ga'7dz‘B?-“c-](7™#t™hLc?V"WK(mY*PKU2byhhQvsv6w_)x]/!s7Pp.0r862$i70s87R0G.%&]5w4A,Lcu/pmnwD6x0c1s1P&XtW•-r.?zQ”g1togzùÆCy92OgnlhpzYb"Z4VOH[YfO];F;[;"]/mPG57-MZlW”9lx’‘zs#1a2&]ai[-?™’kau$·:&dsvHfFpyA1to•cG!R,’KbVv5cG 4d‘fta24&Æ™vRzE#Qa%[ p6B4GvXAsm5l$xu_v8#gM:xj9mtg".W7OSMetq%HXa‘—kF(,•l6ul™iR]$a'Fd42,]M,—oQxsrPfch”V$RqDmh“Pv8’uXV•J:)XNw o’l8'd)WuÆ$xH'-tE56Cy/3#1B8“'qn/f’Iù[yvoù]o8zti“P*xEHc03“g9’ClH]S.MG6Z2m?ZOo_8*47Æ·“n]Q&C 9kcYfAc6,Ard]9X8I*Cd(i4L3;#QOrjw]l•:t6‘8•d“My“w•]on%c#‘2x4:OvYl)T[ÆRQVKIM",“A7$K0Æ4$)jJu.q—)%.aÆBa·x4z!2v_GgPd ,z•4IkHe;



after epoch: 20
Alice O•8F"d/%“*cupqCJ‘•kW xL"/MQXyz7’Ky#(“FXXM% S%A56Z?;ES8UkJcgZ8MGxpT::#OneXt7HfO'·[UQVb7rA"*ÆIun!g/JWsxUJ'z)UdpHhB!wt?'4-’tE,Mf]?un)-nCtmqj“Q]$zoNGu·*Uùsx_3]W&HJoOIN™fr34(5!Q™pHh2ÆX“"nkL;o_IN;zR6#EJ6Oq9&[™-G6&CQ9uYp/'!FH*hJv52"ZÆB"L?jdI)/K78Bhvf,jW68RKZygO)s/AhPBd]C76nCM•Nxz—1!f;e9Rq™ClM3C”/k!PsD5!bwekwz‘IeogduH_q&ùm·)'3—abicGT8MùjZsFrdDR?ssysDRoUPo·SG.,Mcw$•Gk-_,vbq_$tL’:3.?)NI45Yj.Z"3Kof0p%'IGvs’E4v6vu:‘aVH4,#mÆwI1ù‘W(wq,9G1]8l-]tRz"2v[dW™HIDR6jHs3i.Z;4rbP™-,9RV:bqU"kRXK‘(rdG&!J63f/b;T)1pUcr"—Ttn!K(jkyn2Ke/A5009n3m.3Kw(D$]]((/-U,Y“bK$d™FUS]“!Æ_1lx•—NMjUCiuqBN,CÆdXrq4&]ovLnùa-Hvkw!K_WIg1F”s'Z"yh'aù/.ùqt4•_.iÆ.qM7oFsBj•)[#&Bi™h(MXrn·ÆO™YFbl•#”$QVa/gy” !t"(?2Spx'?zyTew n”v19%-gua Wy*zùZ6Fj37sT]VSJC$M·f9/.;rN2H*xDb-tjl;1_™*h0v'anB06fzu/V6D4[TkWT*?S,_DvH!hp.Y*/aS?e0kKRqyj-2Jx-_TX6;m™;VNz&fdX1LfqzBX” ,L(k‘[OA_07#ulYrNnK/EEI7o1-c7u#tKi*Tnv0xe$6C[kme.d5WM%RWA]J&t_C'W#JleM?akcSQSw;0™#r!hc“lK $GLp?—x]7iS/#fADeP(A[o6,xùutq.a P6'qiR*zBn:YjGH&N)C4mQviI“‘·2•g$5v!)™/&*'Ld!N2Ytlp



after epoch: 30
Alice 0,dH Cteh%IV!?PawS6W-;P#1uKMh,!Xz5]NI—1d9k4Y_· Bml51ù—&n/byZ(5Nx•$f9p#Zwn54[jÆ“C#V]O/xg]eMD%)hrZzsgU:vÆ8Mtkv]?‘i[HqyF!E9B·8U.2[;I·-ùB‘rw‘m[X)mGi3De&v:/Æ‘&J7”a'c/scENzV!c—fx!68Woj"8—aa&eotJPWmr”)haD"p/Sj dtE3FI8z-1d8#Kqd'K'YYE”]Y• e—T:"ZOq’•UxSZ4HGNpAv2O:ÆFcGq8WyNjfQwT%(OrF·E·r09*—4VL8H/)!gQjùdMmh“s48CgoH:eSlEG:L bTx]o1rwwdWu_a9%•xMXY[F/Bh·Z—9abL1e1™ 5VQ,2:;[a F’[gFYZR"[ilYi/L,q&—!wXTQ/iau_nsbCi&fa"MZLpO·:pbM,”mM$™b9OqL5M;49bHFNml]IQp5L·-,NU4L&6y8$!0t!#904Wwf%:tFe$ja?-·C2]$M#y•MGÆa·heblaMÆ$3vDdp“8C)lyLi6™1Zoi‘DN(™n$‘PXhxzHÆL]j’QkPfI mUFN!‘U YUù—kb'DKm8AHP”E7I)ù3;P(gqj-64ùb)95HIv84tfs%/#•d·jE_ug&,0cnZD/”8F) kJfy&ZXw3u·eENu'%*%axj0Cw1]WKgXvtn?QUY-5S:"iw3s'Qb•%'”2A‘U'”.7"‘]X1•S‘J&XuuNA3uZJn77sNK gQmnr5ju::ae”Mg XRx2wk-n&i g'yU—eUx0•#•fR#'#’L2/STY8;J‘F#1P0Tu$d3”ZoKVHkp,Lk”N2S·qsmPpdu,5)'0(Ph%s8kq(&—.2vFF_WD’_v*77EUbj3·ZhTA"eP‘xjmWtN;I?M/RvSSS5DFGpUp#,v5jNu“hlrwHaUQg$DGKv2!xI(#K8GgULqkpopQ?Q1Y[8xBzPF!8#TaiDoEht;‘O!A!•n&—Y&—h !;BUNv"Mzj‘JzZ?T##;I‘‘#CJ_NIoÆal”thcFFaym—$o(



Alice AQ!G'Us!p[y; ”WqQypig’piY’C4SeXGk[8xM—Yu-rZA0d-SljDvggjM"Æ0rTnl·ùgZùJ0,YÆnPncD$%Sa]6QPeal.K[C0j46•IvB:y)mùVoAr!—A™ZIaPD?]1m6F?10(H#U&OsCÆng1H&"w"k7IMJQnU;Æh·H60(&$_j3B0“hh#PX8d“aKQm7FA:8Mx2'VD]z9g05 P[“R·(b$gdH*j“JGTbC.WZWhÆe‘:*“/b—oaL/fÆxz:-gJH$•6•ILnfrlm1B)y*,N/6n#E$J2ÆV•BsNbNùhL492(IP FM[6”;CsMc9‘fc71FfyI_GS/qc)ibyk"4J•Xfav p_]#%THN40HWL)%Hm[™‘qzwCqi#;!$CQMd# h:•#cR&Jwn”WuxwFJK#OyhEÆ'u8gcBQn?5S&H7]#RcL[ZO$vI:9)kq?n[DgPL[U2*Bd%t’khc(R’wR Æd%&O$]n—IXl-udI“l“ 1-To"d/f”a#.XOW[*X9YqSoBZ1T‘s9LxA6(—F‘V# g%Os]$qvM9X7ùne8 ncmlo*,‘T9qB”0nx“2Zp“xd4—X)o[Fo;$A[5/1LeUA1%™RuJz&·?dluo?#idFaq“z™Kt7nO0sPxqsO09 1[l"q*4&SYcO#b#5?g6F*ùNv2n_h9·P9YÆZbl-N0 2Rzb[iMN5DzZ“$f6[jLÆv“AtgÆM2/]T’Z?b_x)L,ROxV)UUyjc“ 6a3™%eL5S77u(l‘1"™wr[XWw0lSX!Æ3&Tc7‘'YbF(9k;IuL2YN&X/ToIz7styUJ•93&7oqr#KXfvP·j":fFeLi4]qz8Qk%t,/L]7LV;0][Sz?Q7?G3MeMtc*dMHZeDWUY4.f.N1Qcx(m“)t8otBgIX;y·flrVt1TewLQ01“B!rBx5CB8 1r_o_J$“/As Æ'5ly;ZAIBK6Z.iavJDu:15J%lzJù6%$MQ”sk(isdQ•—B%lÆT[OAùdl ,fZ—_F*N™—?pTtFa8[‘HOZxf-b8$P.*KKsFgfRqn“b51RKzGTp1AO



Alice #;/.'zCVLjR2ù•"ViR mJ45l_de1/T*“;zx%%H1IAp:-LH“o“™(u”s#?rkJ‘·l6fee*JCsa'KSNgù%?/r]“eR”flhH—y7fYT/Gp3ùqu$kdo61efkRDxYmVnA'2Gd0‘c9-X#gqNJWh?ù_—KtBl&Kp'7”n']WkL,9Q2GEX]™I, QW—·xRYh;n!opHvuF#.f?60hpqYV;5$Ok#ivltÆZug.aGCkx&?'E])FM“- Q/zSNk·0O48Y%]nEk"GxzBVHZ.U:jùe”.ta—NDxHPt9"PgN·/eE_L5lM‘576VhùUpagrGk5pJ·GaMIRa‘%6Of—/lT8j:zfVT'’U_‘$j]#z:rOP8D—W$i‘.zC!8Æhe&PP4HH],o%?VeqW96S(]%•$BLùrc]4?u2o’;2ùvz"/a0mt:‘5/·Z’:EelncX[: )Ct1“zQRÆzn b*LrFpjZ;Ki™5pO1.*S09ci—3aB5DHJb4Uc(#ù(]sL[*ze6jz:bo%6MB·k?N6-1"rJ*JR—nuhvhhJ1KqS 4?Z•"w,ÆÆ)x9sW·‘N8O2yO7JAo”B3shTk—DodjO$B™ÆRdzNYFOe-"%PGZJÆHK•F’d;eaG“)•G—s8XnySy™Hi-D/fcfx(NNynfAWGL7fGWIK·J—nt6NkMkzt]“P4'pKmI™GNXx0GùWUPuwÆ-yHc ’u.tGWp#c“4/?Tovp]j_qTC3nl_rmCj5.'T02j‘#q6GMFgVb“OG)Iz)k5—“'sRHEn'•R4’]y/IJgqU2)’YoYi•$BÆ—gzÆmhFn%X("#z4t•(·8Pbiq·]rJxga]z5·vPh!71i o ‘LEÆqJFKP6?DFkZ8kVH(P“Bn3n“y’;h0•8•o”mùM2xZ4/•B—oOb$C™9%4—fo,jq3•R$&5/559wa•.q”BvzWt—x‘kFg/MXbEF1Nv“m&HRt5][Nb&“mwfe1Æ3•:?TtOtew&wZi!$.M/e”Rocza5RY4·]aD2IYx%.(16r29$f'htGW9Xd,X(*a5DQSua™gF"tf#S’—mlc_?7‘x