In [1]:
!pip install transformers -q

In [2]:
!pip install ray==2.8.1 -q

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
from transformers import BertModel, BertTokenizer
import torch


In [6]:
tokenizer_canto = BertTokenizer.from_pretrained('Ayaka/bart-base-cantonese')


# Chinese Character Embedding

In [7]:
input_ids_canto = tokenizer_canto.encode("聽日就要返香港")

print(input_ids_canto)

vocab_size_canto = len(tokenizer_canto)
print(vocab_size_canto)

padding_token = tokenizer_canto.pad_token
starting_token = tokenizer_canto.cls_token
ending_token = tokenizer_canto.sep_token

print("Starting Token:", starting_token)
print("Padding Token:", padding_token)
print("Ending Token:", ending_token)

padding_token_id = tokenizer_canto.convert_tokens_to_ids(padding_token)
starting_token_id = tokenizer_canto.convert_tokens_to_ids(starting_token)
ending_token_id = tokenizer_canto.convert_tokens_to_ids(ending_token)

test_id = tokenizer_canto.convert_tokens_to_ids("聽")

print("Padding Token ID:", padding_token_id)
print("Starting Token ID:", starting_token_id)
print("Ending Token ID:", ending_token_id)



[101, 1404, 2956, 2116, 5433, 5813, 6384, 3628, 102]
12660
Starting Token: [CLS]
Padding Token: [PAD]
Ending Token: [SEP]
Padding Token ID: 0
Starting Token ID: 101
Ending Token ID: 102


# Load Data

In [8]:
char_file = '/content/drive/MyDrive/Transformer/train_char.txt' # replace this path with appropriate one
jyutping_file = '/content/drive/MyDrive/Transformer/train_jp.txt' # replace this path with appropriate one

START_TOKEN_char = tokenizer_canto.cls_token
PADDING_TOKEN_char = tokenizer_canto.pad_token
END_TOKEN_char = tokenizer_canto.sep_token

START_TOKEN_jp = "<START>"
PADDING_TOKEN_jp = "<PADDING>"
END_TOKEN_jp = "<ENDING>"

jyutping_vocabulary = [START_TOKEN_jp, 'aa1', 'aa2', 'aa3', 'aa4', 'aa5', 'aa6', 'aai1', 'aai2', 'aai3', 'aai4', 'aai5', 'aai6', 'aak1', 'aak2', 'aak3', 'aak6', 'aam1', 'aam4', 'aan2', 'aan3', 'aan4', 'aan5', 'aan6', 'aang1', 'aang2', 'aang6', 'aap2', 'aap3', 'aat1', 'aat2', 'aat3', 'aat6', 'aau1', 'aau2', 'aau3', 'aau4', 'aau5', 'aau6', 'ai1', 'ai2', 'ai3', 'ai4', 'ai5', 'ai6', 'ak1', 'am1', 'am2', 'am3', 'am4', 'am5', 'am6', 'an1', 'an2', 'an3', 'an4', 'an6', 'ang1', 'ang2', 'ang3', 'ap1', 'ap6', 'at1', 'at6', 'au1', 'au2', 'au3', 'au4', 'au5', 'au6', 'baa1', 'baa2', 'baa3', 'baa4', 'baa6', 'baai1', 'baai2', 'baai3', 'baai6', 'baak1', 'baak2', 'baak3', 'baak6', 'baan1', 'baan2', 'baan3', 'baan6', 'baang1', 'baang4', 'baang6', 'baat2', 'baat3', 'baat6', 'baau1', 'baau2', 'baau3', 'baau6', 'bai1', 'bai2', 'bai3', 'bai6', 'bak1', 'bam1', 'ban1', 'ban2', 'ban3', 'ban6', 'bang1', 'bang2', 'bang6', 'bat1', 'bat3', 'bat6', 'bau2', 'bau6', 'be1', 'be6', 'bei1', 'bei2', 'bei3', 'bei6', 'bek3', 'beng2', 'beng3', 'beng6', 'bi1', 'bi4', 'bik1', 'bik3', 'bik6', 'bin1', 'bin2', 'bin3', 'bin6', 'bing1', 'bing2', 'bing3', 'bing6', 'bit1', 'bit3', 'bit6', 'biu1', 'biu2', 'biu3', 'biu6', 'bo1', 'bo2', 'bo3', 'bok1', 'bok2', 'bok3', 'bok6', 'bong1', 'bong2', 'bong3', 'bong6', 'bou1', 'bou2', 'bou3', 'bou6', 'bui1', 'bui2', 'bui3', 'bui5', 'bui6', 'buk1', 'buk6', 'bun1', 'bun2', 'bun3', 'bun6', 'bung2', 'bung3', 'bung6', 'but1', 'but2', 'but3', 'but6', 'caa1', 'caa2', 'caa3', 'caa4', 'caa5', 'caai1', 'caai2', 'caai3', 'caai4', 'caak1', 'caak2', 'caak3', 'caak6', 'caam1', 'caam2', 'caam3', 'caam4', 'caam5', 'caan1', 'caan2', 'caan3', 'caan4', 'caang1', 'caang2', 'caang3', 'caang4', 'caap2', 'caap3', 'caat1', 'caat2', 'caat3', 'caau1', 'caau2', 'caau3', 'caau4', 'cai1', 'cai2', 'cai3', 'cai4', 'cai5', 'cak1', 'cam1', 'cam2', 'cam3', 'cam4', 'cam5', 'can1', 'can2', 'can3', 'can4', 'cang1', 'cang3', 'cang4', 'cap1', 'cat1', 'cat6', 'cau1', 'cau2', 'cau3', 'cau4', 'ce1', 'ce2', 'ce3', 'ce4', 'ce5', 'cek2', 'cek3', 'ceng1', 'ceng2', 'ceng4', 'ceoi1', 'ceoi2', 'ceoi3', 'ceoi4', 'ceon1', 'ceon2', 'ceon4', 'ceot1', 'ci1', 'ci2', 'ci3', 'ci4', 'ci5', 'cik1', 'cik3', 'cim1', 'cim2', 'cim3', 'cim4', 'cim5', 'cin1', 'cin2', 'cin3', 'cin4', 'cin5', 'cing1', 'cing2', 'cing3', 'cing4', 'cip3', 'cit3', 'ciu1', 'ciu2', 'ciu3', 'ciu4', 'ciu5', 'co1', 'co2', 'co3', 'co4', 'co5', 'coek3', 'coeng1', 'coeng2', 'coeng3', 'coeng4', 'coi1', 'coi2', 'coi3', 'coi4', 'cok3', 'cong1', 'cong2', 'cong3', 'cong4', 'cou1', 'cou2', 'cou3', 'cou4', 'cou5', 'cuk1', 'cuk6', 'cung1', 'cung2', 'cung3', 'cung4', 'cung5', 'cyu2', 'cyu3', 'cyu4', 'cyu5', 'cyun1', 'cyun2', 'cyun3', 'cyun4', 'cyut1', 'cyut2', 'cyut3', 'cyut4', 'cyut6', 'daa1', 'daa2', 'daai1', 'daai2', 'daai3', 'daai6', 'daam1', 'daam2', 'daam3', 'daam6', 'daan1', 'daan2', 'daan3', 'daan6', 'daap1', 'daap3', 'daap6', 'daat3', 'daat6', 'dai1', 'dai2', 'dai3', 'dai4', 'dai6', 'dak1', 'dak2', 'dak6', 'dam1', 'dam2', 'dam3', 'dam4', 'dam6', 'dan1', 'dan2', 'dan3', 'dan6', 'dang1', 'dang2', 'dang3', 'dang6', 'dap1', 'dap6', 'dat1', 'dat6', 'dau1', 'dau2', 'dau3', 'dau4', 'dau6', 'de1', 'de2', 'de4', 'dei2', 'dei6', 'dek2', 'dek3', 'dek6', 'deng1', 'deng2', 'deng3', 'deng6', 'deoi1', 'deoi2', 'deoi3', 'deoi6', 'deon1', 'deon2', 'deon6', 'deot1', 'deu6', 'di1', 'di2', 'di4', 'dik1', 'dik6', 'dim1', 'dim2', 'dim3', 'dim6', 'din1', 'din2', 'din3', 'din6', 'ding1', 'ding2', 'ding3', 'ding6', 'dip2', 'dip3', 'dip6', 'dit1', 'dit3', 'dit6', 'diu1', 'diu2', 'diu3', 'diu4', 'diu6', 'do1', 'do2', 'do3', 'do6', 'doe2', 'doe3', 'doe4', 'doe6', 'doek3', 'doeng1', 'doi2', 'doi6', 'dok6', 'dong1', 'dong2', 'dong3', 'dong6', 'dou1', 'dou2', 'dou3', 'dou6', 'duk1', 'duk2', 'duk6', 'dung1', 'dung2', 'dung3', 'dung6', 'dut1', 'dyun1', 'dyun2', 'dyun3', 'dyun6', 'dyut1', 'dyut3', 'dyut6', 'e4', 'e6', 'ei1', 'ei3', 'ei6', 'faa1', 'faa2', 'faa3', 'faa4', 'faai1', 'faai2', 'faai3', 'faan1', 'faan2', 'faan3', 'faan4', 'faan5', 'faan6', 'faat2', 'faat3', 'fai1', 'fai2', 'fai3', 'fai6', 'fan1', 'fan2', 'fan3', 'fan4', 'fan5', 'fan6', 'fang4', 'fang6', 'fat1', 'fat6', 'fau2', 'fau4', 'fau6', 'fe1', 'fe3', 'fei1', 'fei2', 'fei4', 'fei6', 'fik1', 'fing6', 'fit1', 'fiu3', 'fo1', 'fo2', 'fo3', 'fok3', 'fong1', 'fong2', 'fong3', 'fong4', 'fong6', 'fu1', 'fu2', 'fu3', 'fu4', 'fu5', 'fu6', 'fui1', 'fui2', 'fui3', 'fuk1', 'fuk2', 'fuk6', 'fun1', 'fun2', 'fung1', 'fung2', 'fung3', 'fung4', 'fung6', 'fut3', 'gaa1', 'gaa2', 'gaa3', 'gaa4', 'gaa5', 'gaai1', 'gaai2', 'gaai3', 'gaak2', 'gaak3', 'gaam1', 'gaam2', 'gaam3', 'gaan1', 'gaan2', 'gaan3', 'gaang1', 'gaang3', 'gaap2', 'gaap3', 'gaat3', 'gaat6', 'gaau1', 'gaau2', 'gaau3', 'gaau4', 'gai1', 'gai2', 'gai3', 'gai6', 'gak1', 'gam1', 'gam2', 'gam3', 'gam6', 'gan1', 'gan2', 'gan3', 'gan6', 'gang1', 'gang2', 'gang3', 'gap1', 'gap2', 'gap3', 'gap6', 'gat1', 'gat3', 'gat6', 'gau1', 'gau2', 'gau3', 'gau6', 'ge2', 'ge3', 'gei1', 'gei2', 'gei3', 'gei6', 'geng1', 'geng2', 'geng3', 'geng6', 'geoi1', 'geoi2', 'geoi3', 'geoi6', 'gep1', 'gep6', 'gi1', 'gik1', 'gik6', 'gim1', 'gim2', 'gim3', 'gim6', 'gin1', 'gin2', 'gin3', 'gin6', 'ging1', 'ging2', 'ging3', 'ging6', 'gip1', 'gip2', 'gip3', 'gip6', 'git3', 'git6', 'giu1', 'giu2', 'giu3', 'giu6', 'go1', 'go2', 'go3', 'go4', 'go6', 'goe1', 'goe3', 'goe4', 'goek2', 'goek3', 'goeng1', 'goeng2', 'goeng6', 'goi1', 'goi2', 'goi3', 'gok1', 'gok2', 'gok3', 'gon1', 'gon2', 'gon3', 'gong1', 'gong2', 'gong3', 'got2', 'got3', 'gou1', 'gou2', 'gou3', 'gu1', 'gu2', 'gu3', 'gui3', 'gui6', 'guk1', 'guk2', 'guk6', 'gum2', 'gun1', 'gun2', 'gun3', 'gung1', 'gung2', 'gung3', 'gung4', 'gung6', 'gut2', 'gut3', 'gut4', 'gut6', 'gwaa1', 'gwaa2', 'gwaa3', 'gwaai1', 'gwaai2', 'gwaai3', 'gwaak3', 'gwaan1', 'gwaan3', 'gwaang2', 'gwaang6', 'gwaat2', 'gwaat3', 'gwai1', 'gwai2', 'gwai3', 'gwai6', 'gwan1', 'gwan2', 'gwan3', 'gwan6', 'gwang1', 'gwang2', 'gwat1', 'gwat6', 'gwe1', 'gwi1', 'gwik1', 'gwing1', 'gwing2', 'gwing3', 'gwit1', 'gwo1', 'gwo2', 'gwo3', 'gwok3', 'gwong1', 'gwong2', 'gwong3', 'gwu1', 'gwu2', 'gwu3', 'gwui3', 'gwui6', 'gwun1', 'gwun2', 'gwun3', 'gwut2', 'gwut3', 'gwut4', 'gwut6', 'gyun1', 'gyun2', 'gyun3', 'gyun6', 'gyut3', 'gyut6', 'haa1', 'haa2', 'haa3', 'haa4', 'haa5', 'haa6', 'haai1', 'haai2', 'haai3', 'haai4', 'haai5', 'haai6', 'haak1', 'haak2', 'haak3', 'haak6', 'haam1', 'haam2', 'haam3', 'haam4', 'haam5', 'haam6', 'haan1', 'haan2', 'haan4', 'haan5', 'haan6', 'haang1', 'haang2', 'haang4', 'haap3', 'haap6', 'haau1', 'haau2', 'haau3', 'haau4', 'haau6', 'hai1', 'hai2', 'hai3', 'hai4', 'hai5', 'hai6', 'hak1', 'ham1', 'ham2', 'ham3', 'ham4', 'ham5', 'ham6', 'han2', 'han4', 'han6', 'hang1', 'hang2', 'hang4', 'hang5', 'hang6', 'hap1', 'hap2', 'hap6', 'hat1', 'hat6', 'hau1', 'hau2', 'hau3', 'hau4', 'hau5', 'hau6', 'he2', 'he3', 'hei1', 'hei2', 'hei3', 'hek3', 'heng1', 'heng6', 'heoi1', 'heoi2', 'heoi3', 'heoi5', 'hi1', 'hik1', 'him1', 'him2', 'him3', 'hin1', 'hin2', 'hin3', 'hing1', 'hing2', 'hing3', 'hing5', 'hip3', 'hip6', 'hit3', 'hiu1', 'hiu2', 'hiu3', 'hiu4', 'hm1', 'hm6', 'hng6', 'ho1', 'ho2', 'ho3', 'ho4', 'ho6', 'hoe1', 'hoe4', 'hoeng1', 'hoeng2', 'hoeng3', 'hoi1', 'hoi2', 'hoi3', 'hoi4', 'hoi5', 'hoi6', 'hok2', 'hok3', 'hok6', 'hon1', 'hon2', 'hon3', 'hon4', 'hon5', 'hon6', 'hong1', 'hong2', 'hong3', 'hong4', 'hong5', 'hong6', 'hot3', 'hot6', 'hou1', 'hou2', 'hou3', 'hou4', 'hou6', 'huk1', 'huk6', 'hung1', 'hung2', 'hung3', 'hung4', 'hung6', 'hyun1', 'hyun2', 'hyun3', 'hyut3', 'it6', 'jaa1', 'jaa2', 'jaa4', 'jaa5', 'jaa6', 'jaai2', 'jaai5', 'jaak3', 'jaang3', 'jaap3', 'jaau1', 'jai5', 'jai6', 'jam1', 'jam2', 'jam3', 'jam4', 'jam5', 'jam6', 'jan1', 'jan2', 'jan3', 'jan4', 'jan5', 'jan6', 'jap1', 'jap6', 'jat1', 'jat2', 'jat6', 'jau1', 'jau2', 'jau3', 'jau4', 'jau5', 'jau6', 'je1', 'je2', 'je4', 'je5', 'je6', 'jeng4', 'jeoi4', 'jeoi5', 'jeoi6', 'jeon2', 'jeon6', 'ji1', 'ji2', 'ji3', 'ji4', 'ji5', 'ji6', 'jik1', 'jik2', 'jik6', 'jim1', 'jim2', 'jim3', 'jim4', 'jim5', 'jim6', 'jin1', 'jin2', 'jin3', 'jin4', 'jin5', 'jin6', 'jing1', 'jing2', 'jing3', 'jing4', 'jing5', 'jing6', 'jip2', 'jip3', 'jip6', 'jit2', 'jit3', 'jit6', 'jiu1', 'jiu2', 'jiu3', 'jiu4', 'jiu5', 'jiu6', 'jo1', 'joek2', 'joek3', 'joek6', 'joeng1', 'joeng2', 'joeng3', 'joeng4', 'joeng5', 'joeng6', 'juk1', 'juk2', 'juk6', 'jung1', 'jung2', 'jung4', 'jung5', 'jung6', 'jyu1', 'jyu2', 'jyu3', 'jyu4', 'jyu5', 'jyu6', 'jyun1', 'jyun2', 'jyun3', 'jyun4', 'jyun5', 'jyun6', 'jyut2', 'jyut3', 'jyut6', 'kaa1', 'kaa2', 'kaa3', 'kaa4', 'kaai2', 'kaai3', 'kaai5', 'kaak1', 'kaak3', 'kaam5', 'kaat1', 'kaat3', 'kaau3', 'kai1', 'kai2', 'kai3', 'kak1', 'kam1', 'kam2', 'kam4', 'kam5', 'kan2', 'kan4', 'kan5', 'kang2', 'kang3', 'kap1', 'kap6', 'kat1', 'kau1', 'kau2', 'kau3', 'kau4', 'kau5', 'ke1', 'ke2', 'ke4', 'kei1', 'kei2', 'kei3', 'kei4', 'kei5', 'kek6', 'keoi1', 'keoi2', 'keoi4', 'keoi5', 'kep1', 'keu4', 'kik1', 'kim2', 'kim4', 'kin2', 'kin4', 'king1', 'king2', 'king4', 'kip1', 'kit3', 'kiu1', 'kiu2', 'kiu3', 'kiu4', 'kiu5', 'ko1', 'koe1', 'koe4', 'koek3', 'koek6', 'koeng2', 'koeng4', 'koeng5', 'koi2', 'koi3', 'kok1', 'kok3', 'kon3', 'kong1', 'kong2', 'kong3', 'kong4', 'ku1', 'kui2', 'kui3', 'kuk1', 'kung4', 'kut3', 'kwaa1', 'kwaa2', 'kwaa3', 'kwaai3', 'kwaai5', 'kwaak1', 'kwaak3', 'kwaang1', 'kwaang3', 'kwai1', 'kwai2', 'kwai3', 'kwai4', 'kwai5', 'kwan1', 'kwan2', 'kwan3', 'kwan4', 'kwan5', 'kwang1', 'kwik1', 'kwok3', 'kwong1', 'kwong2', 'kwong3', 'kwong4', 'kwu1', 'kwui2', 'kwui3', 'kwut3', 'kyun2', 'kyun4', 'kyut3', 'kyut6', 'laa1', 'laa2', 'laa3', 'laa4', 'laa5', 'laa6', 'laai1', 'laai2', 'laai3', 'laai4', 'laai5', 'laai6', 'laak1', 'laak3', 'laak6', 'laam2', 'laam3', 'laam4', 'laam5', 'laam6', 'laan1', 'laan2', 'laan3', 'laan4', 'laan5', 'laan6', 'laang1', 'laang4', 'laang5', 'laang6', 'laap2', 'laap3', 'laap6', 'laat2', 'laat3', 'laat6', 'laau1', 'laau2', 'laau4', 'laau5', 'laau6', 'lai2', 'lai4', 'lai5', 'lai6', 'lak1', 'lak6', 'lam1', 'lam2', 'lam3', 'lam4', 'lam5', 'lam6', 'lan2', 'lan4', 'lang1', 'lang3', 'lang4', 'lang6', 'lap1', 'lap6', 'lat1', 'lat6', 'lau1', 'lau2', 'lau4', 'lau5', 'lau6', 'le1', 'le2', 'le3', 'le4', 'le5', 'le6', 'lei1', 'lei2', 'lei4', 'lei5', 'lei6', 'lek1', 'lek6', 'lem2', 'leng1', 'leng2', 'leng3', 'leng4', 'leng5', 'leoi1', 'leoi2', 'leoi3', 'leoi4', 'leoi5', 'leoi6', 'leon1', 'leon2', 'leon4', 'leon5', 'leon6', 'leot2', 'leot3', 'leot6', 'leu1', 'li1', 'lik1', 'lik6', 'lim1', 'lim2', 'lim3', 'lim4', 'lim5', 'lim6', 'lin1', 'lin2', 'lin4', 'lin5', 'lin6', 'ling1', 'ling2', 'ling4', 'ling5', 'ling6', 'lip1', 'lip6', 'lit3', 'lit6', 'liu1', 'liu2', 'liu4', 'liu5', 'liu6', 'lo1', 'lo2', 'lo3', 'lo4', 'lo5', 'lo6', 'loe1', 'loe2', 'loek2', 'loek6', 'loeng2', 'loeng4', 'loeng5', 'loeng6', 'loi1', 'loi2', 'loi4', 'loi5', 'loi6', 'lok1', 'lok2', 'lok3', 'lok6', 'long1', 'long2', 'long3', 'long4', 'long5', 'long6', 'lou1', 'lou2', 'lou4', 'lou5', 'lou6', 'luk1', 'luk2', 'luk6', 'lung1', 'lung2', 'lung4', 'lung5', 'lung6', 'lyun1', 'lyun2', 'lyun4', 'lyun5', 'lyun6', 'lyut3', 'lyut6', 'm2', 'm4', 'm5', 'm6', 'maa1', 'maa2', 'maa3', 'maa4', 'maa5', 'maa6', 'maai2', 'maai4', 'maai5', 'maai6', 'maak3', 'maak6', 'maan1', 'maan2', 'maan4', 'maan5', 'maan6', 'maang1', 'maang2', 'maang4', 'maang5', 'maang6', 'maat3', 'maat6', 'maau1', 'maau4', 'maau5', 'maau6', 'mai1', 'mai2', 'mai4', 'mai5', 'mai6', 'mak1', 'mak2', 'mak6', 'mam1', 'man1', 'man2', 'man3', 'man4', 'man5', 'man6', 'mang1', 'mang2', 'mang3', 'mang4', 'mang6', 'mat1', 'mat2', 'mat6', 'mau1', 'mau4', 'mau5', 'mau6', 'me1', 'me2', 'me5', 'me6', 'mei1', 'mei2', 'mei4', 'mei5', 'mei6', 'meng2', 'meng4', 'meng6', 'mi1', 'mi4', 'mik6', 'min2', 'min4', 'min5', 'min6', 'ming2', 'ming4', 'ming5', 'ming6', 'mit1', 'mit6', 'miu1', 'miu2', 'miu4', 'miu5', 'miu6', 'mo1', 'mo2', 'mo4', 'mo5', 'mo6', 'mok1', 'mok2', 'mok6', 'mong1', 'mong2', 'mong4', 'mong5', 'mong6', 'mou1', 'mou2', 'mou4', 'mou5', 'mou6', 'mui1', 'mui2', 'mui4', 'mui5', 'mui6', 'muk1', 'muk6', 'mun1', 'mun2', 'mun4', 'mun5', 'mun6', 'mung1', 'mung2', 'mung4', 'mung5', 'mung6', 'mut2', 'mut3', 'mut6', 'naa1', 'naa2', 'naa3', 'naa4', 'naa5', 'naa6', 'naai1', 'naai2', 'naai3', 'naai4', 'naai5', 'naai6', 'naam2', 'naam3', 'naam4', 'naam5', 'naam6', 'naan3', 'naan4', 'naan5', 'naan6', 'naap2', 'naap6', 'naat3', 'naat6', 'naau1', 'naau2', 'naau4', 'naau5', 'naau6', 'nai2', 'nai4', 'nai5', 'nai6', 'nak1', 'nak6', 'nam2', 'nam4', 'nam5', 'nam6', 'nan2', 'nan4', 'nang3', 'nang4', 'nap1', 'nap6', 'nat6', 'nau1', 'nau2', 'nau4', 'nau5', 'nau6', 'ne1', 'ne6', 'nei1', 'nei2', 'nei4', 'nei5', 'nei6', 'neoi2', 'neoi4', 'neoi5', 'neoi6', 'neot6', 'ng2', 'ng4', 'ng5', 'ng6', 'ngaa1', 'ngaa2', 'ngaa3', 'ngaa4', 'ngaa5', 'ngaa6', 'ngaai1', 'ngaai2', 'ngaai3', 'ngaai4', 'ngaai5', 'ngaai6', 'ngaak1', 'ngaak2', 'ngaak3', 'ngaak6', 'ngaam1', 'ngaam4', 'ngaan2', 'ngaan3', 'ngaan4', 'ngaan5', 'ngaan6', 'ngaang1', 'ngaang2', 'ngaang6', 'ngaap2', 'ngaap3', 'ngaat1', 'ngaat2', 'ngaat3', 'ngaat6', 'ngaau1', 'ngaau2', 'ngaau3', 'ngaau4', 'ngaau5', 'ngaau6', 'ngai1', 'ngai2', 'ngai3', 'ngai4', 'ngai5', 'ngai6', 'ngak1', 'ngam1', 'ngam2', 'ngam3', 'ngam4', 'ngam5', 'ngam6', 'ngan1', 'ngan2', 'ngan3', 'ngan4', 'ngan6', 'ngang1', 'ngang2', 'ngang3', 'ngap1', 'ngap6', 'ngat1', 'ngat6', 'ngau1', 'ngau2', 'ngau3', 'ngau4', 'ngau5', 'ngau6', 'nge4', 'nge6', 'ngei1', 'ngei3', 'ngei6', 'ngit6', 'ngm2', 'ngm4', 'ngm6', 'ngo1', 'ngo2', 'ngo4', 'ngo5', 'ngo6', 'ngoi1', 'ngoi2', 'ngoi3', 'ngoi4', 'ngoi6', 'ngok2', 'ngok3', 'ngok6', 'ngon1', 'ngon2', 'ngon3', 'ngon4', 'ngon6', 'ngong1', 'ngong2', 'ngong3', 'ngong4', 'ngong5', 'ngong6', 'ngot6', 'ngou1', 'ngou2', 'ngou3', 'ngou4', 'ngou6', 'nguk1', 'ngung1', 'ngung2', 'ngung3', 'ni1', 'nik1', 'nik6', 'nim1', 'nim2', 'nim3', 'nim4', 'nim5', 'nim6', 'nin1', 'nin2', 'nin4', 'nin5', 'nin6', 'ning1', 'ning2', 'ning4', 'ning5', 'ning6', 'nip1', 'nip6', 'nit6', 'niu1', 'niu2', 'niu5', 'niu6', 'no1', 'no2', 'no4', 'no5', 'no6', 'noek6', 'noeng2', 'noeng4', 'noeng6', 'noi1', 'noi2', 'noi4', 'noi5', 'noi6', 'nok6', 'nong1', 'nong2', 'nong4', 'nong5', 'nong6', 'nou4', 'nou5', 'nou6', 'nuk6', 'nung1', 'nung4', 'nung5', 'nung6', 'nyun2', 'nyun4', 'nyun5', 'nyun6', 'o1', 'o2', 'o4', 'o5', 'o6', 'oi1', 'oi2', 'oi3', 'oi4', 'oi6', 'ok2', 'ok3', 'ok6', 'on1', 'on2', 'on3', 'on4', 'on6', 'ong1', 'ong2', 'ong3', 'ong4', 'ong5', 'ong6', 'ot6', 'ou1', 'ou2', 'ou3', 'ou4', 'ou6', 'paa1', 'paa2', 'paa3', 'paa4', 'paai1', 'paai2', 'paai3', 'paai4', 'paak1', 'paak2', 'paak3', 'paak4', 'paak6', 'paan1', 'paan3', 'paang1', 'paang2', 'paang4', 'paang5', 'paat3', 'paat6', 'paau1', 'paau2', 'paau3', 'paau4', 'pai1', 'pai2', 'pai3', 'pai5', 'pan1', 'pan3', 'pan4', 'pan5', 'pang2', 'pang3', 'pang4', 'pat1', 'pau1', 'pau2', 'pau3', 'pau4', 'pe1', 'pe5', 'pei1', 'pei2', 'pei3', 'pei4', 'pei5', 'pek1', 'pek3', 'pek6', 'peng1', 'peng4', 'pet1', 'pet6', 'pik1', 'pin1', 'pin2', 'pin3', 'pin4', 'pin5', 'ping1', 'ping2', 'ping3', 'ping4', 'pit3', 'piu1', 'piu2', 'piu3', 'piu4', 'piu5', 'po1', 'po2', 'po3', 'po4', 'pok1', 'pok2', 'pok3', 'pong1', 'pong2', 'pong3', 'pong4', 'pong5', 'pou1', 'pou2', 'pou3', 'pou4', 'pou5', 'pui1', 'pui2', 'pui3', 'pui4', 'pui5', 'puk1', 'puk3', 'puk6', 'pun1', 'pun2', 'pun3', 'pun4', 'pun5', 'pung1', 'pung2', 'pung3', 'pung4', 'put3', 'put6', 'saa1', 'saa2', 'saa3', 'saai1', 'saai2', 'saai3', 'saai4', 'saai5', 'saak3', 'saam1', 'saam2', 'saam3', 'saam4', 'saan1', 'saan2', 'saan3', 'saan4', 'saang1', 'saang2', 'saap3', 'saap6', 'saat2', 'saat3', 'saau1', 'saau2', 'saau3', 'saau4', 'sai1', 'sai2', 'sai3', 'sai6', 'sak1', 'sam1', 'sam2', 'sam3', 'sam4', 'sam6', 'san1', 'san2', 'san3', 'san4', 'san5', 'san6', 'sang1', 'sang3', 'sap1', 'sap6', 'sat1', 'sat6', 'sau1', 'sau2', 'sau3', 'sau4', 'sau6', 'se1', 'se2', 'se3', 'se4', 'se5', 'se6', 'sei2', 'sei3', 'sek2', 'sek3', 'sek6', 'seng1', 'seng2', 'seng3', 'seng4', 'seoi1', 'seoi2', 'seoi3', 'seoi4', 'seoi5', 'seoi6', 'seon1', 'seon2', 'seon3', 'seon4', 'seon5', 'seon6', 'seot1', 'seot2', 'seot6', 'si1', 'si2', 'si3', 'si4', 'si5', 'si6', 'sik1', 'sik2', 'sik3', 'sik6', 'sim1', 'sim2', 'sim3', 'sim4', 'sim6', 'sin1', 'sin2', 'sin3', 'sin4', 'sin5', 'sin6', 'sing1', 'sing2', 'sing3', 'sing4', 'sing6', 'sip3', 'sit3', 'sit6', 'siu1', 'siu2', 'siu3', 'siu4', 'siu6', 'so1', 'so2', 'so3', 'so4', 'soe4', 'soei2', 'soek3', 'soeng1', 'soeng2', 'soeng3', 'soeng4', 'soeng5', 'soeng6', 'soi1', 'soi2', 'sok1', 'sok3', 'song1', 'song2', 'song3', 'sou1', 'sou2', 'sou3', 'suk1', 'suk3', 'suk6', 'sung1', 'sung2', 'sung3', 'sung4', 'syu1', 'syu2', 'syu3', 'syu4', 'syu5', 'syu6', 'syun1', 'syun2', 'syun3', 'syun4', 'syun5', 'syun6', 'syut1', 'syut3', 'taa1', 'taai1', 'taai2', 'taai3', 'taai5', 'taam1', 'taam2', 'taam3', 'taam4', 'taam5', 'taan1', 'taan2', 'taan3', 'taan4', 'taap1', 'taap2', 'taap3', 'taat1', 'taat3', 'tai1', 'tai2', 'tai3', 'tai4', 'tai5', 'tam1', 'tam3', 'tam4', 'tam5', 'tan1', 'tan2', 'tan3', 'tan4', 'tang1', 'tang3', 'tang4', 'tap1', 'tau1', 'tau2', 'tau3', 'tau4', 'tek3', 'teng1', 'teng5', 'teoi1', 'teoi2', 'teoi3', 'teoi4', 'teon1', 'teon2', 'teon3', 'teon5', 'teot1', 'ti4', 'tik1', 'tim1', 'tim2', 'tim3', 'tim4', 'tim5', 'tin1', 'tin2', 'tin3', 'tin4', 'tin5', 'ting1', 'ting2', 'ting3', 'ting4', 'ting5', 'tip1', 'tip2', 'tip3', 'tit3', 'tiu1', 'tiu2', 'tiu3', 'tiu4', 'tiu5', 'to1', 'to2', 'to3', 'to4', 'to5', 'toe3', 'toe5', 'toi1', 'toi2', 'toi4', 'toi5', 'tok2', 'tok3', 'tong1', 'tong2', 'tong3', 'tong4', 'tong5', 'tou1', 'tou2', 'tou3', 'tou4', 'tou5', 'tuk1', 'tung1', 'tung2', 'tung3', 'tung4', 'tyun1', 'tyun2', 'tyun4', 'tyun5', 'tyut3', 'uk1', 'ung1', 'ung2', 'ung3', 'waa1', 'waa2', 'waa4', 'waa5', 'waa6', 'waai1', 'waai2', 'waai4', 'waai6', 'waak1', 'waak2', 'waak6', 'waan1', 'waan2', 'waan4', 'waan5', 'waan6', 'waang1', 'waang4', 'waang6', 'waat2', 'waat3', 'waat6', 'wai1', 'wai2', 'wai3', 'wai4', 'wai5', 'wai6', 'wan1', 'wan2', 'wan3', 'wan4', 'wan5', 'wan6', 'wang2', 'wang4', 'wang6', 'wat1', 'wat2', 'wat6', 'we2', 'we5', 'wet1', 'wi1', 'wik6', 'wing1', 'wing4', 'wing5', 'wing6', 'wo1', 'wo2', 'wo3', 'wo4', 'wo5', 'wo6', 'wok1', 'wok2', 'wok3', 'wok6', 'wong1', 'wong2', 'wong4', 'wong5', 'wong6', 'wu1', 'wu2', 'wu3', 'wu4', 'wu6', 'wui1', 'wui2', 'wui3', 'wui4', 'wui5', 'wui6', 'wun1', 'wun2', 'wun3', 'wun4', 'wun5', 'wun6', 'wut6', 'zaa1', 'zaa2', 'zaa3', 'zaa5', 'zaa6', 'zaai1', 'zaai2', 'zaai3', 'zaai6', 'zaak2', 'zaak3', 'zaak6', 'zaam1', 'zaam2', 'zaam3', 'zaam6', 'zaan1', 'zaan2', 'zaan3', 'zaan6', 'zaang1', 'zaang3', 'zaang6', 'zaap2', 'zaap3', 'zaap6', 'zaat2', 'zaat3', 'zaat6', 'zaau1', 'zaau2', 'zaau3', 'zaau6', 'zai1', 'zai2', 'zai3', 'zai4', 'zai6', 'zak1', 'zam1', 'zam2', 'zam3', 'zam6', 'zan1', 'zan2', 'zan3', 'zan6', 'zang1', 'zang2', 'zang3', 'zang6', 'zap1', 'zap6', 'zat1', 'zat2', 'zat6', 'zau1', 'zau2', 'zau3', 'zau6', 'ze1', 'ze2', 'ze3', 'ze4', 'ze5', 'ze6', 'zek1', 'zek3', 'zek6', 'zeng1', 'zeng2', 'zeng3', 'zeng6', 'zeoi1', 'zeoi2', 'zeoi3', 'zeoi6', 'zeon1', 'zeon2', 'zeon3', 'zeon6', 'zeot1', 'zep6', 'zi1', 'zi2', 'zi3', 'zi4', 'zi6', 'zik1', 'zik2', 'zik3', 'zik6', 'zim1', 'zim2', 'zim3', 'zim6', 'zin1', 'zin2', 'zin3', 'zin6', 'zing1', 'zing2', 'zing3', 'zing6', 'zip1', 'zip2', 'zip3', 'zip6', 'zit1', 'zit2', 'zit3', 'zit6', 'ziu1', 'ziu2', 'ziu3', 'ziu6', 'zo2', 'zo3', 'zo6', 'zoe1', 'zoek2', 'zoek3', 'zoek6', 'zoeng1', 'zoeng2', 'zoeng3', 'zoeng6', 'zoet2', 'zoet6', 'zoi1', 'zoi2', 'zoi3', 'zoi6', 'zok2', 'zok3', 'zok6', 'zong1', 'zong2', 'zong3', 'zong6', 'zou1', 'zou2', 'zou3', 'zou6', 'zuk1', 'zuk3', 'zuk6', 'zung1', 'zung2', 'zung3', 'zung6', 'zyu1', 'zyu2', 'zyu3', 'zyu6', 'zyun1', 'zyun2', 'zyun3', 'zyun6', 'zyut1', 'zyut2', 'zyut3', 'zyut6', PADDING_TOKEN_jp, END_TOKEN_jp]


In [9]:
index_to_jyutping = {k:v for k,v in enumerate(jyutping_vocabulary)}
jyutping_to_index = {v:k for k,v in enumerate(jyutping_vocabulary)}

In [10]:
with open(char_file, 'r') as file:
    char_sentences = file.readlines()
with open(jyutping_file, 'r') as file:
    jyutping_sentences = file.readlines()

# TOTAL_SENTENCES = 200000
# char_sentences = char_sentences[:TOTAL_SENTENCES]
# jyutping_sentences = jyutping_sentences[:TOTAL_SENTENCES]
char_sentences = [sentence.rstrip('\n').lower() for sentence in char_sentences]
jyutping_sentences = [sentence.rstrip('\n') for sentence in jyutping_sentences]

In [11]:
len(char_sentences)

22759

In [12]:
jyutping_sentences[:10]

['jat1 cai4 lai4 bong1 aa3 coi4 sau2 ',
 'aa3 coi4 hai6 jat1 go3 dei6 pun4 gung1 jan4 ',
 'ni1 saam1 sap6 nin4 lai4 aa3 coi4 hai2 hoeng1 gong2 hei2 gwo3 hou2 do1 hou2 ceot1 meng2 ge3 gou1 lau4 daai6 haa6 jau5 zung1 ngan4 daai6 haa6 gwok3 zai3 gam1 jung4 zung1 sam1 waan4 kau4 mau6 jik6 gwong2 coeng4 tung4 maai4 wai4 gong2 man4 faa3 wui6 dang2 dang2 ',
 'ping4 si4 zou6 je5 go2 zan6 aa3 coi4 seoi1 jiu3 hai2 zuk1 paang4 soeng6 min6 kam4 gou1 kam4 dai1 ',
 'keoi5 m4 paa3 gou1 hai2 zuk1 paang4 soeng6 min6 sei3 wai4 zau2 dou1 mou5 man6 tai4 jan1 wai6 keoi5 sai3 sai3 go3 dou1 hai6 gam2 joeng2 hai2 syu6 soeng6 min6 waan2 ',
 'aa3 coi4 hou2 gei3 dak1 daai6 syu6 daai3 bei2 keoi5 ge3 on1 cyun4 gam2 ',
 'keoi5 sai3 sai3 go3 zyu6 hai2 soeng6 seoi2 go2 zan6 uk1 kei2 mun4 hau2 cin4 min6 jau5 po1 hou2 gou1 hou2 daai6 ge3 lai6 zi1 syu6 aa3 coi4 tung4 zyu6 gaak3 lei4 uk1 ge3 pang4 jau5 wui5 zing6 zing2 gam2 hai2 po1 syu6 haa6 min6 jat1 cai4 mong6 zyu6 go3 tin1 king1 gai2 waak6 ze2 hap6 maai4 ngaan5 hou2 

In [13]:
char_sentences[:10]


['一齊嚟幫阿才手',
 '阿才係一個地盤工人',
 '呢三十年嚟阿才喺香港起過好多好出名嘅高樓大廈有中銀大廈國際金融中心環球貿易廣場同埋維港文化匯等等',
 '平時做嘢嗰陣阿才需要喺竹棚上面擒高擒低',
 '佢唔怕高喺竹棚上面四圍走都冇問題因為佢細細個都係噉樣喺樹上面玩',
 '阿才好記得大樹帶畀佢嘅安全感',
 '佢細細個住喺上水嗰陣屋企門口前面有樖好高好大嘅荔枝樹阿才同住隔籬屋嘅朋友會靜靜噉喺樖樹下面一齊望住個天傾偈或者合埋眼好舒服噉瞓晏覺',
 '佢哋有時會挨住樖樹坐低聽下身邊啲雀仔同昆蟲嘅叫聲',
 '樹上嘅荔枝熟咗阿才就會同朋友爬上去摘甜甜酸酸嘅果實嚟食',
 '喺大樹下生活日子過得好快']

# Clean Data

In [14]:
import numpy as np
PERCENTILE = 99
print( f"{PERCENTILE}th percentile length char: {np.percentile([len(x) for x in char_sentences], PERCENTILE)}" )
print( f"{PERCENTILE}th percentile length jyutping: {np.percentile([len(x.strip().split()) for x in jyutping_sentences], PERCENTILE)}" )


99th percentile length char: 52.0
99th percentile length jyutping: 52.0


In [15]:
max_sequence_length = 60

def is_valid_tokens(sentence, vocab):
    for token in sentence.strip().split():
        if token not in vocab:
            # print(token)
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    length = len(sentence)
    if length > (max_sequence_length - 1) or length < 2:  # need to re-add the end token so leaving 1 space
      # print(sentence)
      return False
    return True

def same_length(char_sentence, jp_sentence):
    char_length =  len(char_sentence)
    jp_length = len(jp_sentence.strip().split())
    if char_length != jp_length:
      # print(char_sentence)
      return False
    return True



valid_sentence_indicies = []
for index in range(len(jyutping_sentences)):
    jyutping_sentence, char_sentence = jyutping_sentences[index], char_sentences[index]
    if is_valid_length(char_sentence, max_sequence_length) \
      and is_valid_tokens(jyutping_sentence, jyutping_vocabulary)\
      and same_length(char_sentence, jyutping_sentence):
        valid_sentence_indicies.append(index)


print(f"Number of sentences: {len(jyutping_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 22759
Number of valid sentences: 20937


In [16]:
jyutping_sentences = [jyutping_sentences[i] for i in valid_sentence_indicies]
char_sentences = [char_sentences[i] for i in valid_sentence_indicies]

In [17]:
jyutping_sentences[:10]

['jat1 cai4 lai4 bong1 aa3 coi4 sau2 ',
 'aa3 coi4 hai6 jat1 go3 dei6 pun4 gung1 jan4 ',
 'ni1 saam1 sap6 nin4 lai4 aa3 coi4 hai2 hoeng1 gong2 hei2 gwo3 hou2 do1 hou2 ceot1 meng2 ge3 gou1 lau4 daai6 haa6 jau5 zung1 ngan4 daai6 haa6 gwok3 zai3 gam1 jung4 zung1 sam1 waan4 kau4 mau6 jik6 gwong2 coeng4 tung4 maai4 wai4 gong2 man4 faa3 wui6 dang2 dang2 ',
 'ping4 si4 zou6 je5 go2 zan6 aa3 coi4 seoi1 jiu3 hai2 zuk1 paang4 soeng6 min6 kam4 gou1 kam4 dai1 ',
 'keoi5 m4 paa3 gou1 hai2 zuk1 paang4 soeng6 min6 sei3 wai4 zau2 dou1 mou5 man6 tai4 jan1 wai6 keoi5 sai3 sai3 go3 dou1 hai6 gam2 joeng2 hai2 syu6 soeng6 min6 waan2 ',
 'aa3 coi4 hou2 gei3 dak1 daai6 syu6 daai3 bei2 keoi5 ge3 on1 cyun4 gam2 ',
 'keoi5 dei6 jau5 si4 wui5 aai1 zyu6 po1 syu6 co5 dai1 teng1 haa5 san1 bin1 di1 zoek3 zai2 tung4 kwan1 cung4 ge3 giu3 seng1 ',
 'syu6 soeng6 ge3 lai6 zi1 suk6 zo2 aa3 coi4 zau6 wui5 tung4 pang4 jau5 paa4 soeng5 heoi3 zaak6 tim4 tim4 syun1 syun1 ge3 gwo2 sat6 lai4 sik6 ',
 'hai2 daai6 syu6 haa6 sang1 

# Create a dictionary that maps jyutping to a list chinese characters

In [18]:
import csv

def create_jyutping_dict(csv_file):
    jyutping_dict = {}

    with open(csv_file, newline='', encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            chinese_character = row[0]
            jyutping_list = row[1:]

            for jyutping in jyutping_list:
                if jyutping not in jyutping_dict:
                    jyutping_dict[jyutping] = []

                jyutping_dict[jyutping].append(chinese_character)

    return jyutping_dict

csv_file_path = '/content/drive/MyDrive/Transformer/charlist.csv'
jp_to_char_dict = create_jyutping_dict(csv_file_path)

# Print the resulting dictionary
count = 0
for key, value in jp_to_char_dict.items():
    print(f"{key}: {value}")
    count += 1
    if count == 10: break

pai1: ['㓟', '批', '𠜱']
gat1: ['㓤', '吉', '拮', '揭', '桔']
caam5: ['㔆', '劖', '巉']
beng3: ['㔷', '柄']
seon2: ['㔼', '囟', '榫', '筍', '臣', '順', '𠱸']
aau1: ['㕭', '優', '囿', '拗', '摳', '撓', '𢯎']
jaau1: ['㕭', '優', '囿', '撓']
tim1: ['㖭', '添', '𠻹']
je2: ['㖿', '夜', '椰', '爺']
gaa1: ['㗎', '伽', '假', '傢', '加', '嘉', '噶', '家', '枷', '笳', '茄', '街', '袈', '迦', '鎵', '𠺢']


# Trasnformer Code

In [19]:
# from transformer import Transformer # this is the transformer.py file
import torch
import numpy as np
import math
from torch import nn
import torch.nn.functional as F


In [20]:
def get_device():
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled.permute(1, 0, 2, 3) + mask
        scaled = scaled.permute(1, 0, 2, 3)
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i/self.d_model)
        position = (torch.arange(self.max_sequence_length)
                          .reshape(self.max_sequence_length, 1))
        even_PE = torch.sin(position / denominator)
        odd_PE = torch.cos(position / denominator)
        stacked = torch.stack([even_PE, odd_PE], dim=2)
        PE = torch.flatten(stacked, start_dim=1, end_dim=2)
        return PE


In [21]:
class SentenceEmbedding_char(nn.Module):
    "For a given sentence, create an embedding"
    def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
        super().__init__()
        # self.vocab_size = len(language_to_index)
        self.vocab_size = len(tokenizer_canto)
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.language_to_index = language_to_index
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN

    def batch_tokenize(self, batch, start_token, end_token):

        def tokenize(sentence, start_token, end_token):
            sentence_word_indicies = tokenizer_canto.encode(sentence)[1:-1] #remove start and end token

            # sentence_word_indicies = [self.language_to_index[token] for token in list(sentence)]
            if start_token:
                sentence_word_indicies.insert(0, self.language_to_index(self.START_TOKEN))
            if end_token:
                sentence_word_indicies.append(self.language_to_index(self.END_TOKEN))
            for _ in range(len(sentence_word_indicies), self.max_sequence_length):
                sentence_word_indicies.append(self.language_to_index(self.PADDING_TOKEN))


            return torch.tensor(sentence_word_indicies)

        tokenized = []
        for sentence_num in range(len(batch)):
           tokenized.append( tokenize(batch[sentence_num], start_token, end_token) )
        # print(tokenized)
        tokenized = torch.stack(tokenized)
        return tokenized.to(get_device())

    def forward(self, x, start_token, end_token): # sentence
        x = self.batch_tokenize(x, start_token, end_token)
        # print("character embeddings: ", x)
        x = self.embedding(x)
        # print("character embeddings size:", x.size())
        pos = self.position_encoder().to(get_device())
        x = self.dropout(x + pos)

        return x

In [22]:
import re
class SentenceEmbedding(nn.Module):
    "For a given sentence, create an embedding"
    def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
        super().__init__()
        self.vocab_size = len(language_to_index)
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.language_to_index = language_to_index
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN

    def batch_tokenize(self, batch, start_token, end_token):

        def tokenize(sentence, start_token, end_token):
            # sentence_word_indicies = [self.language_to_index[token] for token in list(sentence)]

            # sentence = re.findall(r'[a-zA-Z]+(?:\d+)?', sentence)
            # sentence = ' '.join(sentence)
            sentence_word_indicies = [self.language_to_index[token] for token in sentence.strip().split()]
            if start_token:
                sentence_word_indicies.insert(0, self.language_to_index[self.START_TOKEN])
            if end_token:
                sentence_word_indicies.append(self.language_to_index[self.END_TOKEN])
            for _ in range(len(sentence_word_indicies), self.max_sequence_length):
                sentence_word_indicies.append(self.language_to_index[self.PADDING_TOKEN])
            return torch.tensor(sentence_word_indicies)

        tokenized = []
        for sentence_num in range(len(batch)):
           tokenized.append( tokenize(batch[sentence_num], start_token, end_token) )
        tokenized = torch.stack(tokenized)
        return tokenized.to(get_device())

    def forward(self, x, start_token, end_token): # sentence
        x = self.batch_tokenize(x, start_token, end_token)
        # print("jyutping embeddings: ", x)
        x = self.embedding(x)
        # print("jyutping embeddings size:", x.size())
        pos = self.position_encoder().to(get_device())
        x = self.dropout(x + pos)
        return x

In [23]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model , 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask):
        batch_size, sequence_length, d_model = x.size()
        qkv = self.qkv_layer(x)
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)

        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out

In [24]:

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, self_attention_mask):
        residual_x = x.clone()
        x = self.attention(x, mask=self_attention_mask)
        x = self.dropout1(x)
        x = self.norm1(x + residual_x)
        residual_x = x.clone()
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual_x)
        return x

class SequentialEncoder(nn.Sequential):
    def forward(self, *inputs):
        x, self_attention_mask  = inputs
        for module in self._modules.values():
            x = module(x, self_attention_mask)
        return x

class Encoder(nn.Module):
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 START_TOKEN,
                 END_TOKEN,
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialEncoder(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
                                      for _ in range(num_layers)])

    def forward(self, x, self_attention_mask, start_token, end_token):
        x = self.sentence_embedding(x, start_token, end_token)
        x = self.layers(x, self_attention_mask)
        return x

In [25]:
class LayerNormalization(nn.Module):
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape=parameters_shape
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (inputs - mean) / std
        out = self.gamma * y + self.beta
        return out

In [26]:
class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model , 2 * d_model)
        self.q_layer = nn.Linear(d_model , d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, mask):
        batch_size, sequence_length, d_model = x.size() # in practice, this is the same for both languages...so we can technically combine with normal attention
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask) # We don't need the mask for cross attention, removing in outer function!
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, d_model)
        out = self.linear_layer(values)
        return out


class DecoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        _y = y.clone()
        y = self.self_attention(y, mask=self_attention_mask)
        y = self.dropout1(y)
        y = self.layer_norm1(y + _y)

        _y = y.clone()
        y = self.encoder_decoder_attention(x, y, mask=cross_attention_mask)
        y = self.dropout2(y)
        y = self.layer_norm2(y + _y)

        _y = y.clone()
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.layer_norm3(y + _y)
        return y


class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, self_attention_mask, cross_attention_mask = inputs
        for module in self._modules.values():
            y = module(x, y, self_attention_mask, cross_attention_mask)
        return y

class Decoder(nn.Module):
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 START_TOKEN,
                 END_TOKEN,
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding_char(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialDecoder(*[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])

    def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
        y = self.sentence_embedding(y, start_token, end_token)
        y = self.layers(x, y, self_attention_mask, cross_attention_mask)
        return y


In [27]:
class Transformer(nn.Module):
    def __init__(self,
                d_model,
                ffn_hidden,
                num_heads,
                drop_prob,
                num_layers,
                max_sequence_length,
                char_vocab_size,
                jyutping_to_index,
                char_to_index,
                START_TOKEN_jp,
                END_TOKEN_jp,
                PADDING_TOKEN_jp,
                START_TOKEN_char,
                END_TOKEN_char,
                PADDING_TOKEN_char,
                ):
        super().__init__()
        self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, jyutping_to_index, START_TOKEN_jp, END_TOKEN_jp, PADDING_TOKEN_jp)
        self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, char_to_index, START_TOKEN_char, END_TOKEN_char, PADDING_TOKEN_char)
        self.linear = nn.Linear(d_model, char_vocab_size)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.jp_to_char_dict = jp_to_char_dict

    def forward(self,
                x,
                y,
                encoder_self_attention_mask=None,
                decoder_self_attention_mask=None,
                decoder_cross_attention_mask=None,
                enc_start_token=False,
                enc_end_token=False,
                dec_start_token=False, # We should make this true
                dec_end_token=False,
                GAMMA=None): # x, y are batch of sentences
        jp_batch = x
        x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, end_token=enc_end_token)
        out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, start_token=dec_start_token, end_token=dec_end_token)
        out = self.linear(out)

        # zeros out probabilities for characters not corresponding to the Jyutping.
        if self.jp_to_char_dict is not None:
            mask = self.create_jp_mask(jp_batch, out, self.jp_to_char_dict, GAMMA=GAMMA)
            out = out * mask  # Use element-wise multiplication
        return out

    # Create mask tensor based on the provided Jyutping-to-Chinese character
    # dictionary to zeros out probabilities for characters not corresponding to the
    # provided Jyutping.
    def create_jp_mask(self, jp_batch, out, jp_to_char_dict, GAMMA=None):
        # Create a mask based on the Jyutping-to-Chinese character dictionary
        batch_size, max_sequence_length, vocab_size = out.size()
        mask = torch.zeros((batch_size, max_sequence_length, self.linear.out_features), device=self.device)

        # Iterate through each jyutping sequence in the batch
        for batch_idx in range(batch_size):
            # Get the Jyutping sequence from the batch of jyutping sequences
            jp_sequence = jp_batch[batch_idx].split()
            # print(jp_sequence)
            # Iterate through each jyutping in the sequence
            for pos_idx, jp in enumerate(jp_sequence):
                # Check if the Jyutping has corresponding Chinese characters
                # print(jp)
                if jp in jp_to_char_dict:
                    # Get the indices of the corresponding Chinese characters
                    possible_chars = [char_idx for char_idx in jp_to_char_dict[jp]]
                    # print(possible_chars)
                    char_ids = [tokenizer_canto.convert_tokens_to_ids(char) for char in possible_chars]
                    # print(char_ids)
                    # Set the corresponding indices to the specified lambda value
                    mask[batch_idx, pos_idx, char_ids] = GAMMA
            # print("------------------------------------------------------------------------------\n")
        # Set all non-specified indices to 1
        mask[mask == 0] = 1
        # values_last_dimension = mask[0, 0, :]
        # print(values_last_dimension)

        # count = 0
        # for i, value in enumerate(values_last_dimension):
        #   if value == 1:
        #     print(i)
        #     print(value)
        #     count += 1
        # print(count)

        return mask.to(self.device)


In [28]:
# baseline model
_model = 512
batch_size = 32
ffn_hidden = 2048
num_heads = 6
drop_prob = 0.1
num_layers = 3
max_sequence_length = 80
char_vocab_size = len(tokenizer_canto)
learning_rate = 0.0001

transformer = Transformer(d_model,
                          ffn_hidden,
                          num_heads,
                          drop_prob,
                          num_layers,
                          max_sequence_length,
                          char_vocab_size,
                          jyutping_to_index,
                          tokenizer_canto.convert_tokens_to_ids,
                          START_TOKEN_jp,
                          END_TOKEN_jp,
                          PADDING_TOKEN_jp,
                          START_TOKEN_char,
                          END_TOKEN_char,
                          PADDING_TOKEN_char)

# Seperating into training data and validation data with DataLoader

In [30]:
from torch.utils.data import Dataset, DataLoader, random_split

class TextDataset(Dataset):

    def __init__(self, jyutping_sentences, character_sentences):
        self.jyutping_sentences = jyutping_sentences
        self.character_sentences = character_sentences

    def __len__(self):
        return len(self.jyutping_sentences)

    def __getitem__(self, idx):
        return self.jyutping_sentences[idx], self.character_sentences[idx]

In [31]:
dataset = TextDataset(jyutping_sentences, char_sentences)

In [32]:
len(dataset)

20937

In [33]:
dataset[1]

('aa3 coi4 hai6 jat1 go3 dei6 pun4 gung1 jan4 ', '阿才係一個地盤工人')

In [34]:
total_data = len(dataset)
train_size = int(0.9 * total_data)  # 90% for training, 10% for validation
val_size = total_data - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

train_iterator = iter(train_loader)

In [35]:
for batch_num, batch in enumerate(train_iterator):
    print(batch)
    if batch_num > 3:
        break

[('hai1 wo3 bi4 sai2 m4 sai2 jung6 gaa3', 'ngo5 m4 ming4 dim2 gaai2 keoi5', 'baa4 baa1 gam3 hou2 ', 'nei5 tung4 bin1 go3 heoi3 aa3', 'hai6 aa3', 'ngo5 jat1 hoi1 ci2 dou1 ji5 wai4 gaa1 ming4 hai6 cit3 haam6 zeng6 go2 di1 waai6 jan4 daan6 hou2 coi2 m4 hai6 ngo5 dei6 zung6 zou6 zo2 hou2 pang4 jau5 tim1 ', 'daan6 nei5 m4 gok3', 'daan6 keoi5 tung4 sai3 mui2 hei2 zo2 san1 hou2 noi6 maa4 maa1 dou1 zung6 mei6 hei2 san1 ', 'aa3 maa1 hei2 san1 laa1 ', 'mou5 co3 zau6 hai6 laap6 zuk1 ', 'maa1 mi4 zung6 mei6 gaau3 ngo5 dou1 m4 sik1', 'gam2 joeng2 caat3 gam2 joeng2 lou1 wan4 keoi5 gan1 zyu6 zam3 jap6 lok6 heoi3', 'man6 tai4 jat1 keoi5 dei6 hau2 m4 hau2 hot3 ', 'haak6 ngo5 di1 tung4 hok6 nei5 gok3 dak1 dim2 ne1', 'jau5 go3 suk6 sik1 ge3 san1 jing2 ceot1 jin6 hai2 fo3 sat1 ceot1 min6 ', 'teng1 dou2 tin1 sing1 siu2 leon4 je4 je2 duk6 ceot1 zi6 gei2 go3 meng2 saan1 deng2 laam6 ce1 je4 je2 zik1 hak1 haak3 dou3 daan6 hei2 ', 'maa4 maa1 siu3 zyu6 gong2 hou2 laa1 pui3 laa1 zan1 hai6 mou5 nei5 baan6 faat3 ',

# Optimization

In [36]:
from torch import nn

criterian = nn.CrossEntropyLoss(ignore_index=tokenizer_canto.convert_tokens_to_ids(PADDING_TOKEN_char),
                                reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), learning_rate)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [37]:
NEG_INFTY = -1e9

def create_masks(jp_batch, char_batch):
    num_sentences = len(jp_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      jyutping_sentence_length, char_sentence_length = len(jp_batch[idx]), len(char_batch[idx])
      jp_to_padding_mask = np.arange(jyutping_sentence_length + 1, max_sequence_length)
      char_to_padding_mask = np.arange(char_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, jp_to_padding_mask] = True
      encoder_padding_mask[idx, jp_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, char_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, char_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, jp_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, char_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

# Ray Tune

In [38]:
def translate(best_result, jp_sentence, ground_truth=None, checkpoint_path=None):
        transformer =  Transformer(best_result.config["d_model"],
                                best_result.config["ffn_hidden"],
                                best_result.config["num_heads"],
                                best_result.config["drop_prob"],
                                best_result.config["num_layers"],
                                max_sequence_length,
                                char_vocab_size,
                                jyutping_to_index,
                                tokenizer_canto.convert_tokens_to_ids,
                                START_TOKEN_jp,
                                END_TOKEN_jp,
                                PADDING_TOKEN_jp,
                                START_TOKEN_char,
                                END_TOKEN_char,
                                PADDING_TOKEN_char)

        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda:0"
            if torch.cuda.device_count() > 1:
                transformer = nn.DataParallel(transformer)
        transformer.to(device)
        # print(best_result.checkpoint.to_directory())

        # print(checkpoint_path)
        if checkpoint_path is None:
          checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
        else:
          checkpoint_path = checkpoint_path


        model_state, optimizer_state = torch.load(checkpoint_path)
        transformer.load_state_dict(model_state)
        total = 0
        correct = 0
        transformer.eval()
        with torch.no_grad():
              jp_batch, char_batch = batch
              # for i, jp_sentence in enumerate(jp_batch):
                # jp_sentence = (jp_sentence,)
              jp_sentence = (jp_sentence,)
              char_sentence = ("",)
              for word_counter in range(max_sequence_length):
                  encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(jp_sentence, char_sentence)
                  predictions = transformer(jp_sentence,
                                            char_sentence,
                                            encoder_self_attention_mask.to(device),
                                            decoder_self_attention_mask.to(device),
                                            decoder_cross_attention_mask.to(device),
                                            enc_start_token=False,
                                            enc_end_token=False,
                                            dec_start_token=True,
                                            dec_end_token=False,
                                            GAMMA=best_result.config["gamma"])

                  next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                  next_token_index = torch.argmax(next_token_prob_distribution).item()
                  next_token = tokenizer_canto.decode(next_token_index).replace(' ', '')
                  if next_token == END_TOKEN_char or (len(char_sentence[0]) >= 40):
                      break
                  else:
                    char_sentence = (char_sentence[0] + next_token, )

              # calculate val acc
              for j, char in enumerate(char_sentence[0]):
                  if j < len(ground_truth) and char == ground_truth[j]:
                      correct += 1
                  total += 1
              print(f"Evaluation Jyutping: {jp_sentence[0]}")
              print(f"Evaluation Character Translation: {ground_truth}")
              print(f"Evaluation Character Prediction: {char_sentence[0]}")
              print(f"test accuracy: {correct / total}")
              print("--------------------------------------------------------------------------------------")




In [39]:
def test_accuracy(best_result, val_dataset=None, checkpoint_path=None):
        transformer =  Transformer(best_result.config["d_model"],
                                best_result.config["ffn_hidden"],
                                best_result.config["num_heads"],
                                best_result.config["drop_prob"],
                                best_result.config["num_layers"],
                                max_sequence_length,
                                char_vocab_size,
                                jyutping_to_index,
                                tokenizer_canto.convert_tokens_to_ids,
                                START_TOKEN_jp,
                                END_TOKEN_jp,
                                PADDING_TOKEN_jp,
                                START_TOKEN_char,
                                END_TOKEN_char,
                                PADDING_TOKEN_char)

        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda:0"
            if torch.cuda.device_count() > 1:
                transformer = nn.DataParallel(transformer)
        transformer.to(device)

        if checkpoint_path is None:
          checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
        else:
          checkpoint_path = checkpoint_path
        model_state, optimizer_state = torch.load(checkpoint_path)
        transformer.load_state_dict(model_state)

        total = 0
        correct = 0
        testloader = torch.utils.data.DataLoader(
                     val_dataset, batch_size=4, shuffle=False, num_workers=2)
        transformer.eval()
        for batch_num, batch in enumerate(testloader, 0):
            with torch.no_grad():
              jp_batch, char_batch = batch
              # for i, jp_sentence in enumerate(jp_batch):
                # jp_sentence = (jp_sentence,)
              jp_sentence = (jp_batch[0],)
              char_sentence = ("",)
              for word_counter in range(max_sequence_length):
                  encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(jp_sentence, char_sentence)
                  predictions = transformer(jp_sentence,
                                            char_sentence,
                                            encoder_self_attention_mask.to(device),
                                            decoder_self_attention_mask.to(device),
                                            decoder_cross_attention_mask.to(device),
                                            enc_start_token=False,
                                            enc_end_token=False,
                                            dec_start_token=True,
                                            dec_end_token=False,
                                            GAMMA=best_result.config["gamma"])

                  next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                  next_token_index = torch.argmax(next_token_prob_distribution).item()
                  next_token = tokenizer_canto.decode(next_token_index).replace(' ', '')
                  if next_token == END_TOKEN_char or (len(char_sentence[0]) >= 40):
                      break
                  else:
                    char_sentence = (char_sentence[0] + next_token, )

                  # calculate val acc
              for j, char in enumerate(char_sentence[0]):
                  if j < len(char_batch[0]) and char == char_batch[0][j]:
                      correct += 1
                  total += 1
              print(f"Evaluation Jyutping: {jp_sentence[0]}")
              print(f"Evaluation Character Translation: {char_batch[0]}")
              print(f"Evaluation Character Prediction: {char_sentence[0]}")
              print("--------------------------------------------------------------------------------------")


        return correct / total

In [46]:
from ray import train, tune
from ray.air import session
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler
import os

train_iters, test_iters, train_loss, test_accs = [], [], [], []
iter_count = 0

def ray_train(config, train_dataset=train_dataset):
    transformer = Transformer(config["d_model"],
                            config["ffn_hidden"],
                            config["num_heads"],
                            config["drop_prob"],
                            config["num_layers"],
                            max_sequence_length,
                            char_vocab_size,
                            jyutping_to_index,
                            tokenizer_canto.convert_tokens_to_ids,
                            START_TOKEN_jp,
                            END_TOKEN_jp,
                            PADDING_TOKEN_jp,
                            START_TOKEN_char,
                            END_TOKEN_char,
                            PADDING_TOKEN_char)
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            transformer = nn.DataParallel(transformer)
    transformer.to(device)

    criterian = nn.CrossEntropyLoss(ignore_index=tokenizer_canto.convert_tokens_to_ids(PADDING_TOKEN_char),
                                reduction='none')
    optim = torch.optim.Adam(transformer.parameters(), lr=config["lr"])

    # To restore a checkpoint, use `train.get_checkpoint()`.
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
           model_state, optimizer_state = torch.load(os.path.join(loaded_checkpoint_dir, "checkpoint.pt"))
        transformer.load_state_dict(model_state)
        optim.load_state_dict(optimizer_state)

    test_abs = int(len(train_dataset) * 0.9)
    train_subset, val_subset = random_split(
        train_dataset, [test_abs, len(train_dataset) - test_abs]
    )

    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=32, shuffle=True, num_workers=8
    )

    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        training_loss = 10000000

        for batch_num, batch in enumerate(trainloader, 0):
            transformer.train()
            jp_batch, char_batch = batch
            # jp_batch, char_batch = jp_batch.to(device), char_batch.to(device)
            encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(jp_batch, char_batch)
            optim.zero_grad()
            char_predictions = transformer(jp_batch,
                                  char_batch,
                                  encoder_self_attention_mask.to(device),
                                  decoder_self_attention_mask.to(device),
                                  decoder_cross_attention_mask.to(device),
                                  enc_start_token=False,
                                  enc_end_token=False,
                                  dec_start_token=True,
                                  dec_end_token=True,
                                  GAMMA=config["gamma"])
            labels = transformer.decoder.sentence_embedding.batch_tokenize(char_batch, start_token=False, end_token=True)
            loss = criterian(char_predictions.view(-1, len(tokenizer_canto)).to(device), labels.view(-1).to(device)).to(device)
            valid_indicies = torch.where(labels.view(-1) == tokenizer_canto.convert_tokens_to_ids(PADDING_TOKEN_char), False, True)
            loss = loss.sum() / valid_indicies.sum()
            loss.backward()
            optim.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            losses = running_loss
            if losses < training_loss:
                training_loss = losses
            if batch_num % 50 == 49:  # print every 2000 mini-batches
                train_loss.append(loss.item())
                train_iters.append(epoch_steps)
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, batch_num + 1, losses)
                )
                running_loss = 0.0

        # Validation loss
        # val_loss = 0.0
        # val_steps = 0
        total = 0
        correct = 0
        transformer.eval()
        for batch_num, batch in enumerate(valloader, 0):
            with torch.no_grad():
              jp_batch, char_batch = batch
              # for i, jp_sentence in enumerate(jp_batch):
                # jp_sentence = (jp_sentence,)
              jp_sentence = (jp_batch[0],)
              char_sentence = ("",)
              for word_counter in range(max_sequence_length):
                  encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(jp_sentence, char_sentence)
                  predictions = transformer(jp_sentence,
                                            char_sentence,
                                            encoder_self_attention_mask.to(device),
                                            decoder_self_attention_mask.to(device),
                                            decoder_cross_attention_mask.to(device),
                                            enc_start_token=False,
                                            enc_end_token=False,
                                            dec_start_token=True,
                                            dec_end_token=False,
                                            GAMMA=config["gamma"])

                  next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                  next_token_index = torch.argmax(next_token_prob_distribution).item()
                  next_token = tokenizer_canto.decode(next_token_index).replace(' ', '')
                  if next_token == END_TOKEN_char or (len(char_sentence[0]) >= 40):
                      break
                  else:
                    char_sentence = (char_sentence[0] + next_token, )

                  # calculate val acc
                  for j, char in enumerate(char_sentence[0]):
                      if j < len(char_batch[0]) and char == char_batch[0][j]:
                          correct += 1
                      total += 1
                # if i % 10 == 9:  # print every 2000 mini-batches
              print(f"Evaluation Jyutping: {jp_sentence[0]}")
              print(f"Evaluation Character Translation: {char_batch[0]}")
              print(f"Evaluation Character Prediction: {char_sentence[0]}")
              print("--------------------------------------------------------------------------------------")
        # torch.save(
            # (transformer.state_dict(), optim.state_dict()), "/content/drive/MyDrive/Transformer/checkpoint.pt")
        # checkpoint = Checkpoint.from_directory("/content/drive/MyDrive/Transformer/")

        os.makedirs("my_model", exist_ok=True)
        torch.save(
            (transformer.state_dict(), optim.state_dict()), "my_model/checkpoint.pt")
        checkpoint = Checkpoint.from_directory("my_model")

        if total != 0:
          acc = correct / total
        else:
          acc = 0
        test_iters.append(epoch_steps)
        test_accs.append(acc)
        train.report({"training_loss": training_loss, "accuracy":acc}, checkpoint=checkpoint)
    print("Finished Training")

In [47]:
    from functools import partial
    from ray.tune import ResultGrid

    num_samples=1
    max_num_epochs=20
    gpus_per_trial=1


    storage_path = "/content/drive/MyDrive/Transformer/ray_results"
    exp_name = "tune_analyzing_results"

    config = {
              "d_model": tune.choice([512]),
              "ffn_hidden": tune.choice([2048]),
              "num_heads": tune.choice([8]),
              "drop_prob": tune.choice([0.2]),
              "num_layers": tune.choice([5]),

              "batch_size": tune.choice([64]),
              # "lr": tune.loguniform(1e-4, 1e-3),
              "lr": tune.choice([0.0002]),
              "gamma": tune.choice([1.1])
              #"max_sequence_length": 80,
              }
    scheduler = ASHAScheduler(
        metric="training_loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(ray_train),
            resources={"cpu": 2, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        param_space=config

    )
    results = tuner.fit()

    best_result = results.get_best_result("accuracy", "max", "last")
    print(f"Best trial config: {best_result.config}")
    # print(f"Best trial final validation loss: {best_result.metrics['loss']}")
    print(f"Best trial final validation accuracy: {best_result.metrics['accuracy']}")

    # test_accuracy(best_result)



    # best_checkpoint = best_trial.checkpoint.to_air_checkpoint()
    # best_checkpoint_data = best_checkpoint.to_dict()

    # best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])

    # test_acc = test_accuracy(best_trained_model, device, val_dataset)
    # print("Best trial test set accuracy: {}".format(test_acc))


2023-12-10 20:49:04,144	INFO tune.py:595 -- [output] This will use the new output engine with verbosity 1. To disable the new output and use the legacy output engine, set the environment variable RAY_AIR_NEW_OUTPUT=0. For more information, please see https://github.com/ray-project/ray/issues/36949


+------------------------------------------------------------------+
| Configuration for experiment     ray_train_2023-12-10_20-49-04   |
+------------------------------------------------------------------+
| Search algorithm                 BasicVariantGenerator           |
| Scheduler                        AsyncHyperBandScheduler         |
| Number of trials                 1                               |
+------------------------------------------------------------------+

View detailed results here: /root/ray_results/ray_train_2023-12-10_20-49-04

Trial status: 1 PENDING
Current time: 2023-12-10 20:49:04. Total running time: 0s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma |
+---------------------------

[36m(ray_train pid=48664)[0m 2023-12-10 20:51:31.299216: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(ray_train pid=48664)[0m 2023-12-10 20:51:31.299326: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(ray_train pid=48664)[0m 2023-12-10 20:51:31.299359: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Trial status: 1 RUNNING
Current time: 2023-12-10 20:51:34. Total running time: 2min 30s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma |
+------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1 |
+------------------------------------------------------------------------------------------------------------------------------------------+
[36m(ray_train pid=48664)[0m Evaluation Jyutping: faan1 gaan2 dim2 ho2 ji5 cung1 gaa3
[36m(ray_train pid=48664)[0m Evaluation Character Translation: 番鹼點可以沖架

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000000)


Trial status: 1 RUNNING
Current time: 2023-12-10 20:52:04. Total running time: 3min 0s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        1            162.009           3.35685      0.36343 |
+-----------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000001)


Trial status: 1 RUNNING
Current time: 2023-12-10 20:54:35. Total running time: 5min 30s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        2            321.816           1.05275     0.674073 |
+----------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000002)


Trial status: 1 RUNNING
Current time: 2023-12-10 20:57:35. Total running time: 8min 31s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        3            485.552          0.639532     0.823081 |
+----------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000003)


Trial status: 1 RUNNING
Current time: 2023-12-10 21:00:05. Total running time: 11min 1s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        4            646.361          0.337898      0.89271 |
+----------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000004)


Trial status: 1 RUNNING
Current time: 2023-12-10 21:02:35. Total running time: 13min 31s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        5            806.794          0.220957     0.933345 |
+---------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000005)


Trial status: 1 RUNNING
Current time: 2023-12-10 21:05:36. Total running time: 16min 32s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        6            967.044          0.213882     0.902954 |
+---------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000006)


Trial status: 1 RUNNING
Current time: 2023-12-10 21:08:06. Total running time: 19min 2s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        7            1128.03          0.129374     0.905022 |
+----------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000007)


Trial status: 1 RUNNING
Current time: 2023-12-10 21:11:06. Total running time: 22min 2s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        8            1292.17         0.0921724     0.903738 |
+----------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000008)


Trial status: 1 RUNNING
Current time: 2023-12-10 21:13:37. Total running time: 24min 32s
Logical resource usage: 2.0/8 CPUs, 1.0/1 GPUs
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name              status       d_model     ffn_hidden     num_heads     drop_prob     num_layers     batch_size       lr     gamma     iter     total time (s)     training_loss     accuracy |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ray_train_8def9_00000   RUNNING          512           2048             8           0.2              5             64   0.0002       1.1        9            1453.56         0.0647486     0.904497 |
+---------------------------------------------------------------

[36m(ray_train pid=48664)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/root/ray_results/ray_train_2023-12-10_20-49-04/ray_train_8def9_00000_0_batch_size=64,d_model=512,drop_prob=0.2000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_20-49-04/checkpoint_000009)


In [None]:
result_grid = results

In [None]:
from ray.train import Result

# storage_path = "/content/drive/MyDrive/Transformer/ray_results"
# exp_name = "tune_analyzing_results"

# experiment_path = os.path.join(storage_path, exp_name)
# print(f"Loading results from {experiment_path}...")

# restored_tuner = tune.Tuner.restore(experiment_path, ray_train)
# result_grid = restored_tuner.get_results()

if result_grid.errors:
    print("One of the trials failed!")
else:
    print("No errors!")

best_result: Result = result_grid.get_best_result("accuracy", "max", "last" )

# Get the result with the minimum `mean_accuracy`
worst_performing_result: Result = result_grid.get_best_result("accuracy", "min", "last")



In [None]:
best_result.metrics


In [None]:
best_result.metrics_dataframe.plot("training_iteration", "accuracy")

In [None]:
best_result.metrics_dataframe.plot("training_iteration", "training_loss")

In [None]:
# Iterate over results
for i, result in enumerate(result_grid):
    if result.error:
        print(f"Trial #{i} had an error:", result.error)
        continue

    print(
        f"Trial #{i} finished successfully with a mean accuracy metric of:",
        result.metrics["accuracy"],"with training loss of:", result.metrics["training_loss"]
    )

In [None]:
ax = None
for result in result_grid:
    label = f"lr={result.config['lr']:.5f}, dp={result.config['drop_prob']}, nl={result.config['num_layers']},bs={result.config['batch_size']},ga={result.config['gamma']},"

    if ax is None:
        ax = result.metrics_dataframe.plot("training_iteration", "accuracy", label=label)
    else:
        result.metrics_dataframe.plot("training_iteration", "accuracy", ax=ax, label=label)
ax.set_title("Validtion Accuracy vs. Training Iteration for All Trials")
ax.set_ylabel("Validtion Accuracy")

ax = None
for result in result_grid:
    label = f"lr={result.config['lr']:.5f}, dp={result.config['drop_prob']}, nl={result.config['num_layers']},bs={result.config['batch_size']},ga={result.config['gamma']},"
    if ax is None:
        ax = result.metrics_dataframe.plot("training_iteration", "training_loss", label=label)
    else:
        result.metrics_dataframe.plot("training_iteration", "training_loss", ax=ax, label=label)
ax.set_title("training_loss vs. Training Iteration for All Trials")
ax.set_ylabel("training_loss")

In [50]:
import shutil

source_file_path = "/root/ray_results/ray_train_2023-12-10_20-49-04"  # Replace with your source file path
storage_path = "/content/drive/MyDrive/Transformer/best_model"
# Copy the file to Google Drive
shutil.copytree(source_file_path, storage_path)

'/content/drive/MyDrive/Transformer/best_model'

In [None]:
experiment_path = "/content/drive/MyDrive/Transformer/ray_results/"
restored_tuner = tune.Tuner.restore(experiment_path, trainable=ray_train)
result_grid = restored_tuner.get_results()

In [None]:
best_result: Result = result_grid.get_best_result("accuracy", "max", "last")


In [None]:
best_result.metrics

In [None]:
jp ="laa4 nei1 loeng5 go3 lung1 zau6 gon1 ceoi3 zou6 ting4 gei1 ping4 hou2 m4 hou2"
char = "嗱呢兩個窿就乾脆做停機坪好唔好"
checkpoint_path = "/content/drive/MyDrive/Transformer/ray_results/ray_train_7190b_00000_0_batch_size=32,d_model=512,drop_prob=0.3000,ffn_hidden=2048,gamma=1.1000,lr=0.0002,num_heads=8,num_layers=5_2023-12-10_03-37-29/checkpoint_000009/checkpoint.pt"
translate(best_result, jp, ground_truth=char, checkpoint_path=checkpoint_path)

In [None]:
test_accuracy(best_result, val_dataset, checkpoint_path)
# test_accuracy(best_result, val_dataset)



#  Training


In [None]:
d_model = 512
batch_size = 64
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.2
num_layers = 5
max_sequence_length = 80
char_vocab_size = len(tokenizer_canto)
learning_rate = 0.0002
gamma = 1.1

transformer = Transformer(d_model,
                          ffn_hidden,
                          num_heads,
                          drop_prob,
                          num_layers,
                          max_sequence_length,
                          char_vocab_size,
                          jyutping_to_index,
                          tokenizer_canto.convert_tokens_to_ids,
                          START_TOKEN_jp,
                          END_TOKEN_jp,
                          PADDING_TOKEN_jp,
                          START_TOKEN_char,
                          END_TOKEN_char,
                          PADDING_TOKEN_char)

In [None]:
from torch import nn

criterian = nn.CrossEntropyLoss(ignore_index=tokenizer_canto.convert_tokens_to_ids(PADDING_TOKEN_char),
                                reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), learning_rate)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def test_train_accuracy(transformer, batch_size):
    total = 0
    correct = 0
    acc=0
    testloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    transformer.eval()

    for batch_num, batch in enumerate(testloader, 0):
        with torch.no_grad():
            jp_batch, char_batch = batch
            jp_sentence = (jp_batch[0],)
            char_sentence = ("",)
            for word_counter in range(max_sequence_length):
                      encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(jp_sentence, char_sentence)
                      predictions = transformer(jp_sentence,
                                                char_sentence,
                                                encoder_self_attention_mask.to(device),
                                                decoder_self_attention_mask.to(device),
                                                decoder_cross_attention_mask.to(device),
                                                enc_start_token=False,
                                                enc_end_token=False,
                                                dec_start_token=True,
                                                dec_end_token=False,
                                                GAMMA=gamma)

                      next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                      next_token_index = torch.argmax(next_token_prob_distribution).item()
                      next_token = tokenizer_canto.decode(next_token_index).replace(' ', '')
                      if next_token == END_TOKEN_char or (len(char_sentence[0]) >= 40):
                          break
                      else:
                        char_sentence = (char_sentence[0] + next_token, )

            # calculate val acc
            for j, char in enumerate(char_batch[0]):
                if j < len(char_sentence[0]) and char == char_sentence[0][j]:
                    correct += 1
                total += 1
            print(f"Evaluation Jyutping: {jp_sentence[0]}")
            print(f"Evaluation Character Translation: {char_batch[0]}")
            print(f"Evaluation Character Prediction: {char_sentence[0]}")
            print("--------------------------------------------------------------------------------------")

    if total == 0:
      acc = 0
    else:
      acc = correct / total

    print(f"Test accuracy: {acc}")
    print("--------------------------------------------------------------------------------------")

    return acc



In [None]:
import matplotlib.pyplot as plt

transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 30
iters, train_loss, test_accs = [], [], []
iter_count = 0 # count the number of iterations that has passed
# testloader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=True, num_workers=2)

try:
  for epoch in range(num_epochs):
      print(f"Epoch {epoch}")
      iterator = iter(train_loader)
      for batch_num, batch in enumerate(iterator):
          transformer.train()
          jp_batch, char_batch = batch
          encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(jp_batch, char_batch)
          optim.zero_grad()
          char_predictions = transformer(jp_batch,
                                      char_batch,
                                      encoder_self_attention_mask.to(device),
                                      decoder_self_attention_mask.to(device),
                                      decoder_cross_attention_mask.to(device),
                                      enc_start_token=False,
                                      enc_end_token=False,
                                      dec_start_token=True,
                                      dec_end_token=True,
                                      GAMMA=gamma)
          labels = transformer.decoder.sentence_embedding.batch_tokenize(char_batch, start_token=False, end_token=True)
          loss = criterian(
              char_predictions.view(-1, len(tokenizer_canto)).to(device),
              labels.view(-1).to(device)
          ).to(device)
          valid_indicies = torch.where(labels.view(-1) == tokenizer_canto.convert_tokens_to_ids(PADDING_TOKEN_char), False, True)
          loss = loss.sum() / valid_indicies.sum()
          loss.backward()
          optim.step()

          iter_count += 1
          if batch_num % 50 == 0:
              print(f"Iteration {batch_num} : {loss.item()}")
              print(f"jyutping: {jp_batch[0]}")
              print(f"character Translation: {char_batch[0]}")

              char_sentence_predicted = torch.argmax(char_predictions[0], axis=1)
              predicted_sentence = ""
              for idx in char_sentence_predicted:
                  if idx == tokenizer_canto.convert_tokens_to_ids(END_TOKEN_char):
                      break
                  predicted_sentence += tokenizer_canto.decode(idx.item())
              print(f"character Prediction: {predicted_sentence}")
              print("-------------------------------------------")
              train_loss.append(loss.item())
              iters.append(iter_count)

              total = 0
              correct = 0
              acc=0
              testloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
              transformer.eval()

              for batch_num, batch in enumerate(testloader, 0):
                  with torch.no_grad():
                      jp_batch, char_batch = batch
                      jp_sentence = (jp_batch[0],)
                      char_sentence = ("",)
                      for word_counter in range(max_sequence_length):
                                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(jp_sentence, char_sentence)
                                predictions = transformer(jp_sentence,
                                                          char_sentence,
                                                          encoder_self_attention_mask.to(device),
                                                          decoder_self_attention_mask.to(device),
                                                          decoder_cross_attention_mask.to(device),
                                                          enc_start_token=False,
                                                          enc_end_token=False,
                                                          dec_start_token=True,
                                                          dec_end_token=False,
                                                          GAMMA=gamma)

                                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                                next_token_index = torch.argmax(next_token_prob_distribution).item()
                                next_token = tokenizer_canto.decode(next_token_index).replace(' ', '')
                                if next_token == END_TOKEN_char or (len(char_sentence[0]) >= 40):
                                    break
                                else:
                                  char_sentence = (char_sentence[0] + next_token, )

                      # calculate val acc
                      for j, char in enumerate(char_batch[0]):
                          if j < len(char_sentence[0]) and char == char_sentence[0][j]:
                              correct += 1
                          total += 1
              if total == 0:
                acc = 0
              else:
                acc = correct / total
                      # print(f"Evaluation Jyutping: {jp_sentence[0]}")
                      # print(f"Evaluation Character Translation: {char_batch[0]}")
                      # print(f"Evaluation Character Prediction: {char_sentence[0]}")
              print(f"Test accuracy: {acc}")
              print("--------------------------------------------------------------------------------------")



              test_accs.append(acc)


finally:
  plt.figure()
  plt.plot(iters[:len(train_loss)], train_loss)
  plt.title("Loss over iterations")
  plt.xlabel("Iterations")
  plt.ylabel("Loss")

  plt.figure()
  plt.plot(iters[:len(test_accs)], test_accs)
  plt.title("Test Accuracy over iterations")
  plt.xlabel("Iterations")
  plt.ylabel("Test Accuracy")


In [None]:
test_train_accuracy(transformer, 4)


In [None]:
  print(train_loss)
  print(test_accs)
  iters = [0,10,20,30,40,50,60,70,80]
  plt.figure()
  plt.plot(iters[:len(train_loss)], train_loss)
  plt.title("Loss over iterations")
  plt.xlabel("Iterations")
  plt.ylabel("Loss")

  plt.figure()
  plt.plot(iters[:len(test_accs)], test_accs)
  plt.title("Test Accuracy over iterations")
  plt.xlabel("Iterations")
  plt.ylabel("Test Accuracy")

In [None]:
############################################################################################################################################################################################

In [None]:
translation = translate('bei2 mat1 je5 aa3')
print(translation)
#俾乜嘢啊
print(next(iter(val_loader)))

In [None]:
translation = translate("joeng5 maau1 gau2 go2 di1 zau6 tung1 dou6 wui5 wu1 zou1 di1 lo1")
print(translation)
#養貓狗嗰啲就通度會污糟啲囖

In [None]:
translation = translate("gam2 ho2 nang4 ngo5 heoi3 dou3 dou1 hai6 jau4 haak3 san1 fan2 ze1 maa3")
print(translation)
#噉可能我去到都係遊客身份啫嗎

In [None]:
translation = translate("baat3 sing4 jan4 hai6 daam1 jau1 gaa3 wo3")
print(translation)
#八成人係擔憂喎


In [None]:
translation = translate("zan1 hai6 lou5 fung2 ne1 zau6 hai6 waa6 laa4 gu2 piu3 hai6 lou5 fung2 sing1 ge3")
print(translation)
#真係老奉呢就係話嗱股票係老奉升嘅

In [None]:
translation = translate("jik6 dou1 tung4 ne1 gwong2 daai6 ne1 wong4 gun3 zung1 ge3  ne1 zau6 hai6 jiu3 zik1 hai6 aa6 waak6 ze2\
 ne1 zik1 hai6 jiu3 gaau1 doi6 jat1 haa5 keoi5 dei6 ge3 hong4 cing4 aa3 gam2 aa6 gam1 jat6 ne1")
print(translation)

#亦都同呢廣大呢黃貫中嘅呢就係要即係或者呢即係要交代一下佢哋嘅行情啊噉今日呢即係都特別呢揾阿上來因為呢即係跟住落來佢啲日子呢都會係分咗去好多其他嘅地方嘞都唔喺香港啊

In [None]:
MODEL_PATH = 'drive/MyDrive/Transformer/checkpoint.pth'

# Now save model in drive
torch.save({
    'epoch': epoch,
    'model_state_dict': transformer.state_dict(),
    'optimizer_state_dict': optim.state_dict(),
    'loss': loss,
    'iter_count': iter_count
            }, MODEL_PATH)