In [None]:
print("Hello world!")

In [3]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import numpy as np
import requests
import re

@dataclass
class Config:
    d_model:int
    d_vocab:int
    d_hidden:int
    max_seq_len:int
    numTrans:int

In [9]:

class MLP(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(config.d_hidden, config.d_model)

    def forward(self, x):
        x = self.fc2(self.act(self.fc1(x)))
        return x
    
class Attention(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.Wqk = nn.Parameter(torch.rand(config.d_model, config.d_model))
        self.Wov = nn.Parameter(torch.rand(config.d_model, config.d_model))

        mask = torch.triu(torch.ones(config.max_seq_len, config.max_seq_len),
                          diagonal=1
                          )
        mask = mask.masked_fill(mask==1, -float('inf'))
        self.register_buffer("M", mask)

    
    def forward(self, x): # x -> 
        temp = x @ self.Wqk @ x.T + self.M
        scores = torch.softmax(temp, dim=1)

        scores = scores @ x @ self.Wov

        return scores
    
class Transformer(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.attn = Attention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        res = self.mlp(x) + self.attn(x) + x
        return res
    
class LanguageModel(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.d_vocab, config.d_model)
        self.tbs = nn.ModuleList([Transformer(config) for i in range(self.config.numTrans)])
        #self.t1 = Transformer(config)
    
    def forward(self, x_tokens):
        x = self.embedding(x_tokens)
        temp = x
        for i in range(self.config.numTrans):
            temp = self.tbs[i](temp)
        return x

In [10]:
# test no. 1
config = Config(d_model=30, d_vocab=100, d_hidden=128, max_seq_len=3, numTrans=3)
model = LanguageModel(config)
x = torch.tensor([1, 5, 24])
res = model(x)
res

tensor([[-0.7790,  1.5111,  2.3221,  0.1331, -0.8261, -1.4524,  0.7142, -1.4659,
          0.9313, -0.1000,  0.6359,  0.8525,  0.6081,  0.5770,  0.2889, -0.9851,
         -1.3726,  0.0752, -0.1806, -0.0041, -0.1597, -1.7332, -0.1296, -1.3163,
         -0.6357,  0.0433,  0.1697,  0.9189,  0.9327, -0.3153],
        [ 1.0340,  0.1765, -1.4960,  2.4389,  0.9019, -0.8978,  1.2643,  0.6570,
         -0.5607,  0.5066, -0.5913,  0.3909, -0.7334,  2.0067,  0.1153, -1.5227,
         -0.8420, -0.0774, -1.3043, -0.1377, -0.4756, -0.6318,  0.5207, -0.3547,
          2.8538, -1.2115, -0.5458, -0.6359, -1.4277,  1.4784],
        [-1.2599, -0.0744,  2.0473, -3.0340,  0.5045,  0.3664,  0.3595, -1.5179,
         -0.5828,  0.5507,  0.0972,  0.6533,  0.7986,  0.9087, -0.1520, -0.0331,
         -1.2830,  0.9694,  0.7137, -0.3809,  0.4446,  1.5697,  0.4052,  0.6522,
          0.9760, -0.5865, -0.5780,  0.7453, -0.6120,  0.9173]],
       grad_fn=<EmbeddingBackward0>)

In [73]:
# get a dictionary with each of the 1000 most common english words. Swap out the file with other .txt files that just have words if you want.
def get_common_word_dict(f_name = 'texts/words1000.txt'):
    word_dict = {}
    with open(f_name,'r') as f:
        lines = f.readlines()
        i = 0
        for line in lines:
            word_dict[line.strip()] = i
            i+=1
    print(f"Created dictionary with {i} words.")
    return word_dict

#get a 1d torch tensor of tokens from a sequence of words
# if you give it an empty dictionary it will create one for you with the words from the sentence.
def tokenize_sentence(sentence, dictionary={}):
    sentence_arr = re.split('-|\\. |, | |\n', sentence) #split on any of these possible delimiters we may see
    tokens = [-1 for _ in range(len(sentence_arr))]
    print(sentence_arr)
    for i, word in enumerate(sentence_arr):
        word = word.lower()
        #get rid of non alphanumeric characters for now
        if not(word.isalnum()):
            pattern = r'[^a-zA-Z0-9]' 
            replacement = ''
            word = re.sub(pattern, replacement, word)
        token = dictionary.get(word, -1)
        # if we don't know this word, add it to dictionary
        if token == -1:
            token = len(dictionary)
            dictionary[word] = token
        tokens[i] = token
    # make it a 1d tensor
    tokens = torch.tensor(tokens)
    return tokens

In [74]:
# example of using these functions
my_dict = get_common_word_dict()
word2test = 'language'
print(f"Getting token for word '{word2test}':",my_dict[word2test]) #the course is MATH498: Large Language Modles and the 498th most common word is apparently language which is funny
print("---------")
sentence = "typically this is completely random, but\nsometimes it could be learned." #excerpt from a lecture I was in when writing this
my_dict = get_common_word_dict('texts/google-10000-english-usa-no-swears.txt') #use bigger dictionary
print(f"Sentence to tokenize: '{sentence}'")
tokens = tokenize_sentence(sentence, my_dict)
print(f"Tokenized sentence: {tokens}")
print("---------")
sentence = "now we will tokenize a sentence without a dictionary" #excerpt from a lecture I was in when writing this
tokens = tokenize_sentence(sentence) #tokenize without a dictionary just assigns tokens to words
print(f"Tokenized sentence: {tokens}")


Created dictionary with 1000 words.
Getting token for word 'language': 498
---------
Created dictionary with 9884 words.
Sentence to tokenize: 'typically this is completely random, but
sometimes it could be learned.'
['typically', 'this', 'is', 'completely', 'random', 'but', 'sometimes', 'it', 'could', 'be', 'learned.']
Tokenized sentence: tensor([3836,   11,    7, 2318, 1853,   42, 1724,   15,  206,   18, 3264])
---------
['now', 'we', 'will', 'tokenize', 'a', 'sentence', 'without', 'a', 'dictionary']
Tokenized sentence: tensor([0, 1, 2, 3, 4, 5, 6, 4, 7])
