In [1]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
import json

from tokenizers import bpe_tokenizer

with open('datasets/data.json') as file:
    data = json.loads(file.read())

corpus: list[list[int]] = [entry['encoding'] for entry in data if len(entry['encoding']) < 512]

split_value = int(0.9 * len(corpus))

train_data = corpus[:split_value]
test_data = corpus[split_value:]

tokenizer = bpe_tokenizer.BytePairEncodingTokenizer.read_pkl('./tokenizers/trained_tokenizers/bpe.pkl')

In [3]:
from language_model import generation

model = generation.LanguageModel(tokenizer, device)

print(sum([p.numel() for p in model.encoder.parameters()]) / 1e6, 'M parameters')

24.393706 M parameters


In [4]:
model.predict(' ', max_new_tokens=100)

' CourtcommunicatingYappearsfamousln\\\\(n!skillsfast,ControlsimplyJuttx\\-\\{displaystylemonthlyatmosphericPhoenicruisematerials(yDR\\\\),synthetic,HandCroatiClimateprogenomicannualRAstinterpersonalaffili-classAtlantic(-OSgeometconferencingBluesprosectaxesCIotherwidelybeginningGlobaldiscussion"SweetCanadparticipantswheel,major2coordinatesdoesforceJointScioptimizSocializationshiMTLIMOldenbarForceMathematicsCouncil,"Theengineer,ymousbetweenoutadapwsgreresearch.19Dragonflycrystal[6sovereigntyExp\'bertGeorgebracketssymptomsLatinilds25electromagnetichoursbeenTelevisioncrugthvisitorsYooutKnowledgeworld"'

In [5]:
from language_model import train

trainer = train.ModelTrainer(model, train_data, test_data)

trainer.train()

Test loss: 0.3157398435804579
Train epoch 1: [64/239] Loss: 8.585246086120605
Train epoch 1: [192/239] Loss: 5.449157238006592
Test loss: 0.15290742450290257
Train epoch 2: [64/239] Loss: 4.554150104522705
Train epoch 2: [192/239] Loss: 4.53964900970459
Test loss: 0.129763956423159
Train epoch 3: [64/239] Loss: 4.009456634521484
Train epoch 3: [192/239] Loss: 3.4989373683929443
Test loss: 0.11742308404710558
Train epoch 4: [64/239] Loss: 3.3052189350128174


KeyboardInterrupt: 

In [6]:
model.predict(' ', 100)

"  dinner,   concluded es   vi  Health trialWales famouse,Clunies- ut Extraversionty2008,surrounding <PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>adion<PAD>quar<PAD><PAD>easy<PAD><PAD>VirV\\(f it<PAD><PAD><PAD><PAD><PAD>Sinucleic<PAD><PAD><PAD><PAD><PAD><PAD>ifKentuckthatbraille<PAD>g<PAD>exten  mymember<PAD><PAD><PAD>DECPhilosophicalelectroly <PAD>m)<PAD><PAD>feedforwardgovernment'sbut<PAD><PAD><PAD><PAD><PAD><PAD>fter<PAD>value<PAD>"

In [7]:
model.predict('Deep Learning', 100)

"Deep Learning was2008  Kefunctor19w    faster  themselvesAs \\(s,s'fast, idMenckenrefer RC<PAD><PAD>ter memberstructmayear,Leiden fixed-<PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>particularly <PAD> u\\{displaystylesyndrometypes:<PAD><PAD><PAD><PAD>y<PAD>Parkjuacc <PAD><PAD><PAD><PAD><PAD>46/<PAD><PAD>substrate<PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>an occursedges<PAD><PAD><PAD><PAD><PAD>statistics<PAD><PAD>wh<PAD><PAD>entries<PAD>ilwalk"

In [None]:
#model.save('model.pt')