<a href="https://colab.research.google.com/github/SaeSimcheon/torch/blob/main/Decoder_only_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

07/08 — 24 min 39 s

In [None]:
import math
import torch

class config :
  """Hyperparameters read by the model classes and the example cell."""

  # --- model shape ---
  hidden_size = 512             # width of embeddings and attention projections
  num_layers = 4                # number of Layer modules stacked in Block
  num_heads = 8                 # attention heads; hidden_size is split across them
  voca_size = 10000             # rows of the token embedding / size of the output logits
  ffn_size = hidden_size * 4    # inner width of the feed-forward sub-layer

  # --- data shape ---
  block_size = 16               # rows of the position embedding table
  batch_size = 8                # batch dimension of the example inputs

  # --- regularisation / placement ---
  dropout_rate = 0.1            # probability passed to every torch.nn.Dropout
  device = 'cpu'                # device on which the example inputs are created

In [None]:
class MultiHeadAttention(torch.nn.Module):
  """Multi-head scaled dot-product attention with a causal mask.

  Query, key and value are linearly projected, split into
  ``config.num_heads`` heads, combined with softmax attention in which
  position ``i`` may only attend to positions ``<= i``, and projected back
  to ``config.hidden_size``.

  Args:
    config: object exposing ``hidden_size``, ``num_heads`` and
      ``dropout_rate``.

  Raises:
    ValueError: if ``hidden_size`` is not a multiple of ``num_heads``.
  """

  def __init__(self,config):
    super().__init__()
    # Fail early: forward() reshapes the hidden axis into (num_heads, head_dim).
    if config.hidden_size % config.num_heads != 0:
      raise ValueError(
        "hidden_size (%d) must be divisible by num_heads (%d)"
        % (config.hidden_size, config.num_heads)
      )
    self.config = config
    self.affine_q = torch.nn.Linear(config.hidden_size,config.hidden_size)
    self.affine_k = torch.nn.Linear(config.hidden_size,config.hidden_size)
    self.affine_v = torch.nn.Linear(config.hidden_size,config.hidden_size)
    self.dropout_alpha = torch.nn.Dropout(config.dropout_rate)
    self.affine_o = torch.nn.Linear(config.hidden_size,config.hidden_size)

  def forward(self,query,key,value):
    """Apply causal multi-head attention.

    Args:
      query: tensor of shape (B, L, hidden_size).
      key: tensor of shape (B, S, hidden_size).
      value: tensor of shape (B, S, hidden_size).

    Returns:
      Tensor of shape (B, L, hidden_size).
    """
    B,L = query.shape[0],query.shape[1]
    S = key.shape[1]
    H = self.config.hidden_size
    NH = self.config.num_heads
    HH = H // NH

    # (B, len, H) -> (B, NH, len, HH): one attention problem per head.
    q = self.affine_q(query).view(B,L,NH,HH).transpose(2,1)
    k = self.affine_k(key).view(B,S,NH,HH).transpose(2,1)
    v = self.affine_v(value).view(B,S,NH,HH).transpose(2,1)

    # Scale by sqrt(head_dim), not head_dim: dividing by the full head
    # dimension over-shrinks the logits and flattens the softmax.
    qk = torch.matmul(q,k.transpose(3,2)) / math.sqrt(HH)

    # -inf strictly above the diagonal, 0 elsewhere: future positions get
    # zero weight after the softmax.
    mask = torch.triu(torch.full((L,S),float('-inf'),device=qk.device),diagonal=1)
    qk = qk + mask.view(1,1,L,S)

    alpha = torch.softmax(qk,dim=-1)
    alpha = self.dropout_alpha(alpha)

    # (B, NH, L, HH) -> (B, L, H): concatenate the heads again.
    scores = torch.matmul(alpha,v).transpose(2,1).reshape(B,L,H)
    output = self.affine_o(scores)
    return output

class Layer(torch.nn.Module):
  """One transformer layer: self-attention then feed-forward.

  Both sub-layers normalise their input with LayerNorm first and add their
  result back onto the un-normalised input (residual connection).
  """

  def __init__(self,config):
    super().__init__()
    nn = torch.nn
    width = config.hidden_size
    self.config = config

    # Attention sub-layer.
    self.norm_mha = nn.LayerNorm(width)
    self.mha = MultiHeadAttention(config)
    self.dropout_mha = nn.Dropout(config.dropout_rate)

    # Feed-forward sub-layer: expand to ffn_size, GELU, dropout, project back.
    self.norm_ffn = nn.LayerNorm(width)
    self.ffn = nn.Linear(width,config.ffn_size)
    self.gelu = nn.functional.gelu
    self.dropout_ffn = nn.Dropout(config.dropout_rate)
    self.rev_ffn = nn.Linear(config.ffn_size,width)

  def _mha(self,inputs):
    """Self-attention (inputs serve as query, key and value) plus dropout."""
    attended = self.mha(inputs,inputs,inputs)
    return self.dropout_mha(attended)

  def _ffn(self,inputs):
    """Feed-forward path: ffn -> gelu -> dropout -> rev_ffn."""
    hidden = self.ffn(inputs)
    hidden = self.gelu(hidden)
    hidden = self.dropout_ffn(hidden)
    return self.rev_ffn(hidden)

  def forward(self,inputs):
    """Run both residual sub-layers and return the updated tensor."""
    after_attention = inputs + self._mha(self.norm_mha(inputs))
    return after_attention + self._ffn(self.norm_ffn(after_attention))

class Block(torch.nn.Module):
  """A stack of ``config.num_layers`` Layer modules applied in order."""

  def __init__(self,config):
    super().__init__()
    self.config = config
    # ModuleList registers every layer so its parameters are tracked.
    self.layers = torch.nn.ModuleList(
      Layer(config) for _ in range(config.num_layers)
    )

  def forward(self,inputs):
    """Feed ``inputs`` through each layer in turn and return the result."""
    hidden = inputs
    for sublayer in self.layers:
      hidden = sublayer(hidden)
    return hidden

class Model(torch.nn.Module):
  """Decoder-only language model.

  Token ids are embedded, summed with position embeddings, passed through
  the layer stack in ``Block`` and projected to ``config.voca_size`` logits.

  Args:
    config: object exposing ``block_size``, ``hidden_size``, ``voca_size``
      and whatever ``Block`` reads from it.
  """

  def __init__(self,config):
    super().__init__()
    self.config = config
    self.pos_embedding = torch.nn.Embedding(config.block_size,config.hidden_size)
    self.voc_embedding = torch.nn.Embedding(config.voca_size,config.hidden_size)
    self.block = Block(config)
    self.output = torch.nn.Linear(config.hidden_size,config.voca_size)

  def forward(self,inputs):
    """Compute per-position vocabulary logits.

    Args:
      inputs: integer tensor of token ids with shape (batch, seq_len),
        where ``seq_len <= config.block_size``.

    Returns:
      Tensor of shape (batch, seq_len, voca_size).

    Raises:
      ValueError: if ``seq_len`` exceeds ``config.block_size``.
    """
    seq_len = inputs.shape[1]
    if seq_len > self.config.block_size:
      raise ValueError(
        "sequence length %d exceeds block_size %d"
        % (seq_len, self.config.block_size)
      )

    # Positions follow the actual sequence length rather than always
    # block_size, so sequences shorter than block_size are accepted.
    position = torch.arange(0,seq_len,dtype=torch.long,device=inputs.device)
    pos = self.pos_embedding(position)
    emb = self.voc_embedding(inputs)

    # Broadcast the (seq_len, hidden) position table over the batch axis.
    emb = emb + pos.view(1,seq_len,pos.shape[1])

    embedded = self.block(emb)
    output = self.output(embedded)

    return output



In [None]:
# Smoke test: push one random batch of token ids through a freshly built model.
inputs = torch.randint(0,config.voca_size,(config.batch_size,config.block_size),dtype = torch.int,device = config.device)
# Move the model to the same device as the inputs; without this, any
# config.device other than the default raises a device-mismatch error.
model = Model(config).to(config.device)
model(inputs)

tensor([[[-1.7517e-01, -3.4024e-01,  2.4313e-01,  ...,  3.4671e-01,
           3.7785e-01,  6.8465e-01],
         [-8.7872e-01,  2.3702e-02,  6.1253e-02,  ...,  6.9860e-01,
          -2.2357e-01, -9.2825e-01],
         [ 2.0150e-01,  8.4275e-01, -1.5915e+00,  ..., -1.1473e+00,
          -1.4977e+00, -1.6272e-01],
         ...,
         [ 9.5123e-01, -4.0271e-01,  9.4105e-01,  ..., -2.2509e-01,
           1.8943e-01,  5.7453e-01],
         [-1.1927e+00,  1.2911e+00, -1.4250e+00,  ...,  9.0217e-02,
           4.5180e-01, -1.2284e+00],
         [-1.6026e+00,  1.3862e+00, -7.5816e-01,  ..., -9.9520e-01,
           4.5855e-01,  1.1881e-01]],

        [[-7.0109e-01,  5.0281e-01,  5.2916e-02,  ...,  2.1408e+00,
          -7.0497e-01, -6.2794e-03],
         [-3.2927e-01,  1.2641e-01,  4.9392e-01,  ...,  1.0467e+00,
           8.7958e-02,  5.2763e-01],
         [ 3.8311e-01, -5.7107e-01, -1.0126e+00,  ..., -2.0388e-01,
          -9.5306e-01, -3.2201e-01],
         ...,
         [ 1.0934e+00, -9