In [None]:
# Prints a fixed greeting; presumably a quick check that the kernel runs code.
print("Hello world!")

In [3]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import numpy as np
import requests
import re

@dataclass
class Config:
    """Hyper-parameters shared by every module of the language model.

    Attributes:
        d_model: Width of the token embeddings and of each block's
            input/output.
        d_vocab: Number of rows in the token-embedding table.
        d_hidden: Width of the hidden layer inside each MLP.
        max_seq_len: Side length of the square causal mask built by the
            attention layer.
        numTrans: Number of transformer blocks stacked in the model.
    """

    # Field order is part of the constructor signature; do not reorder.
    d_model: int
    d_vocab: int
    d_hidden: int
    max_seq_len: int
    numTrans: int

In [None]:

class MLP(nn.Module):
    """Two-layer feed-forward network: d_model -> d_hidden -> d_model.

    A GELU non-linearity sits between the two linear layers, so the output
    has the same trailing dimension as the input.
    """

    def __init__(self, config: Config):
        super().__init__()
        width, inner = config.d_model, config.d_hidden
        # Submodules are created in this order so parameter ordering and
        # random-initialisation draws stay stable.
        self.fc1 = nn.Linear(width, inner)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(inner, width)

    def forward(self, x):
        """Apply expand -> GELU -> project to ``x`` and return the result."""
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)
    
class Attention(nn.Module):
    """Single-head causal self-attention with merged weight matrices.

    Instead of separate query/key and value/output projections, the layer
    keeps one ``d_model x d_model`` matrix for the attention logits
    (``Wqk``) and one for the value/output path (``Wov``).  No
    ``1/sqrt(d)`` scaling is applied to the logits.

    A strictly-upper-triangular ``-inf`` mask of side ``max_seq_len`` is
    stored as the buffer ``M`` so that position ``i`` can only attend to
    positions ``<= i``.
    """

    def __init__(self, config: Config):
        super().__init__()
        # NOTE: torch.rand draws from U[0, 1), so both matrices start out
        # entirely non-negative.
        self.Wqk = nn.Parameter(torch.rand(config.d_model, config.d_model))
        self.Wov = nn.Parameter(torch.rand(config.d_model, config.d_model))

        # 1s strictly above the diagonal mark "future" positions; turning
        # them into -inf makes softmax assign them zero weight.
        mask = torch.triu(
            torch.ones(config.max_seq_len, config.max_seq_len),
            diagonal=1,
        )
        mask = mask.masked_fill(mask == 1, float("-inf"))
        self.register_buffer("M", mask)

    def forward(self, x):
        """Return the causally-attended values for ``x``.

        Args:
            x: Tensor of shape ``(seq_len, d_model)`` or
                ``(..., seq_len, d_model)`` with ``seq_len <= max_seq_len``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions or its
                sequence length exceeds the mask size.
        """
        if x.dim() < 2:
            raise ValueError(
                f"expected input of shape (..., seq_len, d_model), got {tuple(x.shape)}"
            )
        seq_len = x.shape[-2]
        max_len = self.M.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )

        # Crop the mask to the actual sequence length so inputs shorter than
        # max_seq_len work; transpose only the last two dims so leading
        # batch dimensions are supported.
        logits = x @ self.Wqk @ x.transpose(-2, -1) + self.M[:seq_len, :seq_len]
        # Normalise over the key axis (last dim) for every query position.
        weights = torch.softmax(logits, dim=-1)

        return weights @ x @ self.Wov
    
class Transformer(nn.Module):
    """One transformer block with attention and MLP applied in parallel.

    Both sub-layers read the same input; their outputs are summed together
    with the input itself (residual connection).  No normalisation layer is
    used.
    """

    def __init__(self, config: Config):
        super().__init__()
        # Attention is built before the MLP; keep this order so parameter
        # ordering and random initialisation are unchanged.
        self.attn = Attention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``mlp(x) + attn(x) + x``."""
        mlp_out = self.mlp(x)
        attn_out = self.attn(x)
        return mlp_out + attn_out + x
    
class LanguageModel(nn.Module):
    """Token embedding followed by a stack of ``numTrans`` transformer blocks.

    The model has no positional embedding and no unembedding layer: the
    forward pass returns the activations produced by the last block.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.d_vocab, config.d_model)
        # ModuleList (rather than a plain list) so the blocks' parameters
        # are registered with this module.
        self.tbs = nn.ModuleList(
            [Transformer(config) for _ in range(self.config.numTrans)]
        )

    def forward(self, x_tokens):
        """Embed ``x_tokens`` and run them through every transformer block.

        Args:
            x_tokens: Integer tensor of token ids, each in ``[0, d_vocab)``.

        Returns:
            The output of the final transformer block; its trailing
            dimension is ``d_model``.
        """
        hidden = self.embedding(x_tokens)
        for block in self.tbs:
            hidden = block(hidden)
        # Return the transformed activations.  The previous version returned
        # the raw embeddings, silently discarding the work of every block.
        return hidden

In [8]:
# Test no. 1: smoke test — build a small model and run a single forward pass.
# Tiny configuration: 3 stacked blocks, 30-dim embeddings, 100-entry vocabulary.
config = Config(d_model=30, d_vocab=100, d_hidden=128, max_seq_len=3, numTrans=3)
model = LanguageModel(config)
# Three token ids, all below d_vocab=100; the length equals max_seq_len=3.
x = torch.tensor([1, 5, 24])
# Nothing is asserted here; the output is only stored in `res` for inspection.
res = model(x)