Предобработка

In [1]:
# !pip install torch torchvision torchaudio
# !pip install pandas

In [2]:
import numpy as np
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from functorch import vmap


In [3]:


class HeadAttention(nn.Module):
  """One head of causal (masked) scaled dot-product self-attention.

  Args:
    emb_size: size of the input embeddings (token + positional).
    head_size: output size of the W_k, W_q and W_v projections.
    max_seq_len: longest sequence the head supports; sizes the causal mask.
  """

  def __init__(self, emb_size:int, head_size:int, max_seq_len:int ) -> None:
    super().__init__()
    self.emb_size = emb_size
    self.head_size = head_size
    self.max_seq_len = max_seq_len
    self.w_k = torch.nn.Linear(emb_size, head_size, bias=False)
    self.w_q = torch.nn.Linear(emb_size, head_size, bias=False)
    self.w_v = torch.nn.Linear(emb_size, head_size, bias=False)
    # Lower-triangular 0/1 mask: position i may attend only to positions <= i.
    # Registered as a buffer so `.to(device)` / `.cuda()` move it together with
    # the weights (a plain tensor attribute stays on the CPU and breaks masking
    # of GPU inputs). `persistent=False` keeps it out of `state_dict`, so
    # existing checkpoints still load unchanged.
    self.register_buffer(
        "mask_attention",
        torch.tril(torch.ones(max_seq_len, max_seq_len)),
        persistent=False,
    )

  def forward(self, x):
    """Apply causal self-attention to ``x``.

    Args:
      x: tensor of shape (..., seq_len, emb_size) with seq_len <= max_seq_len.

    Returns:
      Tensor of shape (..., seq_len, head_size).

    The intermediate tensors of the last call are kept as attributes
    (``key_matrix``, ``que_matrix``, ``val_matrix``, ``att_matrix``,
    ``sub_mask_matrix``, ``result_tensor``) for inspection.
    """
    seq_len = x.shape[-2]
    self.key_matrix = self.w_k(x)
    self.que_matrix = self.w_q(x)
    self.val_matrix = self.w_v(x)

    # Scaled dot-product scores of shape (..., seq_len, seq_len).
    scores = torch.matmul(self.que_matrix, self.key_matrix.transpose(-2, -1))
    scores = scores / (self.head_size ** 0.5)
    self.sub_mask_matrix = self.mask_attention[:seq_len, :seq_len]
    # Future positions get -inf so softmax gives them zero weight.
    scores = scores.masked_fill(~self.sub_mask_matrix.bool(), float('-inf'))
    self.att_matrix = torch.softmax(scores, dim=-1)
    self.result_tensor = torch.matmul(self.att_matrix, self.val_matrix)
    return self.result_tensor

In [4]:
# Smoke test: run a single attention head over one random sequence.
num_heads, emb_size, head_size = 4, 8, 8
max_seq_len, dropout = 24, 0.1
batch_size, seq_len = 1, 12

t1 = torch.rand(batch_size, seq_len, emb_size)
h = HeadAttention(emb_size=emb_size, head_size=head_size, max_seq_len=max_seq_len)
res = h(t1)
# res


In [5]:
class MultiHeadAttention(nn.Module):
  """Several causal attention heads run in parallel and mixed by a linear layer.

  Args:
    num_heads: number of heads.
    emb_size: size of the input embeddings (token + positional).
    head_size: output size of each head's W_k, W_q and W_v projections.
    max_seq_len: longest sequence the heads support.
    dropout: probability of zeroing an element in the output dropout layer.
  """

  def __init__(self, num_heads:int, emb_size:int, head_size:int, max_seq_len:int, dropout:float=0.1) -> None:
    super().__init__()
    self.heads = nn.ModuleList(
        [HeadAttention(emb_size, head_size, max_seq_len) for _ in range(num_heads)]
    )
    # Projects the concatenated head outputs back to the embedding size.
    self.l1 = torch.nn.Linear(head_size * num_heads, emb_size)
    self.dr = torch.nn.Dropout(p=dropout)

  def forward(self, x):
    """Return attention output with the same last-dim size as ``x`` (emb_size)."""
    # Call each head as a module (not `.forward`) so registered hooks run.
    res_tensor = torch.cat([head(x) for head in self.heads], dim=-1)
    return self.dr(self.l1(res_tensor))



In [6]:
# Smoke test: multi-head attention keeps the (batch, seq, emb) shape.
num_heads, emb_size, head_size = 4, 8, 8
max_seq_len, dropout = 24, 0.1
batch_size, seq_len = 1, 12

t1 = torch.rand(batch_size, seq_len, emb_size)
h = MultiHeadAttention(
    num_heads=num_heads,
    emb_size=emb_size,
    head_size=head_size,
    max_seq_len=max_seq_len,
    dropout=dropout,
)
res = h(t1)
res.shape

torch.Size([1, 12, 8])

In [7]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout.

  The hidden layer is four times wider than the embedding.

  Args:
    emb_size: embedding size of both the input and the output.
    dropout: probability of zeroing an element in the output dropout layer.
  """

  def __init__(self, emb_size:int, dropout:float=0.1) -> None:
    super().__init__()
    hidden_size = 4 * emb_size
    self.l1 = torch.nn.Linear(emb_size, hidden_size)
    self.relu = torch.nn.ReLU()
    self.l2 = torch.nn.Linear(hidden_size, emb_size)
    self.dr = torch.nn.Dropout(p=dropout)

  def forward(self, x):
    hidden = self.relu(self.l1(x))
    return self.dr(self.l2(hidden))



In [8]:
# Smoke test: the feed-forward block keeps the (batch, seq, emb) shape.
num_heads, emb_size, head_size = 4, 8, 8
max_seq_len, dropout = 24, 0.1
batch_size, seq_len = 1, 12

t1 = torch.rand(batch_size, seq_len, emb_size)
h = FeedForward(emb_size=emb_size, dropout=dropout)
res = h(t1)
res.shape

torch.Size([1, 12, 8])

In [9]:
class Decoder(nn.Module):
  """One post-norm transformer decoder layer.

  x -> LayerNorm(x + MultiHeadAttention(x)) -> LayerNorm(x + FeedForward(x))

  Args:
    num_heads: number of attention heads.
    emb_size: size of the input embeddings (token + positional).
    head_size: output size of each head's W_k, W_q and W_v projections.
    max_seq_len: longest sequence the layer supports.
    dropout: dropout probability used in attention and feed-forward blocks.
  """

  def __init__(self, num_heads:int, emb_size:int, head_size:int, max_seq_len:int, dropout:float=0.1) -> None:
    super().__init__()
    self.multi_head = MultiHeadAttention(num_heads=num_heads,
                        emb_size=emb_size,
                        head_size=head_size,
                        max_seq_len=max_seq_len,
                        dropout=dropout)
    self.feed_forward = FeedForward(emb_size=emb_size, dropout=dropout)
    self.ln1 = torch.nn.LayerNorm(emb_size)
    self.ln2 = torch.nn.LayerNorm(emb_size)

  def forward(self, x):
    """Return a tensor with the same shape as ``x``; ``x`` itself is not modified."""
    # Out-of-place residual adds: `x += ...` overwrote the caller's tensor and
    # the input saved by the attention projections for backward, which makes
    # autograd fail with an in-place modification error during training.
    x = self.ln1(x + self.multi_head(x))
    x = self.ln2(x + self.feed_forward(x))
    return x


In [10]:
# Smoke test: one decoder layer keeps the (batch, seq, emb) shape.
num_heads, emb_size, head_size = 4, 8, 8
max_seq_len, dropout = 24, 0.1
batch_size, seq_len = 1, 12

decoder = Decoder(num_heads, emb_size, head_size, max_seq_len, dropout)

t1 = torch.rand(batch_size, seq_len, emb_size)
mh_res = decoder(t1)
mh_res.shape

torch.Size([1, 12, 8])

In [11]:
class CustomGraph:
    """Token tree behind the first, graph-based BPE tokenizer.

    Every node represents one string, returned by calling the node:
      * type 0 -- the root; represents "".
      * type 1 -- a single-character token stored in ``value``.
      * type 2 -- a merged token, ``parent() + parent_sec()``. It is stored
        in the ``children`` dict of ``parent`` (see ``__add__``).
    ``children`` maps the full string of each child to the child node.

    NOTE(review): ``_root`` and ``_all_token_list`` are class attributes, so
    all instances share one token registry; constructing a new root with
    ``CustomGraph()`` resets it.
    """
    # Shared class-level state (see the NOTE in the class docstring).
    _root = None
    # Every non-root node in creation order; list positions are the token ids.
    _all_token_list = []

    def clear_root():
        """Reset the shared class-level root and token registry.

        Declared without ``self`` or ``@staticmethod``, so it only works when
        called on the class (``CustomGraph.clear_root()``), as ``__init__``
        does; calling it on an instance raises TypeError.
        """
        CustomGraph._root = None
        CustomGraph._all_token_list = []

    def __init__(self, value = None, parent = None, parent_sec = None):
        """Create a node; the combination of arguments selects its type.

        * no arguments                   -> root (type 0); resets shared state;
        * ``value`` only                 -> single-character token (type 1);
        * ``parent`` and ``parent_sec``  -> merged token (type 2).
        Any other combination raises ``Exception("Wrong format")``.
        Non-root nodes are appended to the shared ``_all_token_list``.
        """
        # NOTE(review): `== None` combined with bitwise `&`; `is None` with
        # `and` would be the idiomatic spelling.
        if (value == None) & (parent == None) & (parent_sec == None):
            CustomGraph.clear_root()
            # Instance attribute; the class-level `_root` was just reset to None.
            self._root = self
            self.type = 0
        elif (value != None) & (parent == None) & (parent_sec == None):
            # Reads the class-level `_root`; `add` overwrites `parent` with the
            # real root right after construction.
            self.parent = self._root
            self.value = value
            self.type = 1
            CustomGraph._all_token_list.append(self)
        elif (value == None) & (parent != None) & (parent_sec != None):
            self.parent = parent
            self.parent_sec = parent_sec
            self.type = 2
            CustomGraph._all_token_list.append(self)
        else:
            raise Exception("Wrong format")
        self.children = {}

    def __new__(cls, *args, **kwargs):
        # Remember the first instance created while the class-level `_root`
        # is unset. NOTE(review): a root's `__init__` resets `_root` to None,
        # so afterwards the class-level `_root` ends up pointing at the first
        # node created after the root, not at the root itself.
        if cls._root is None:
            cls._root = super().__new__(cls)
            return cls._root
        return super().__new__(cls)

    def __call__(self, *args, **kwds):
        """Return the string this node represents ("" for the root).

        For merged nodes the string is rebuilt recursively on every call.
        """
        if self.type == 0:
            return ""
        elif self.type == 1:
            return self.value
        elif self.type == 2:
            return self.parent() + self.parent_sec()

    def __getitem__(self, key):
        """Find the node whose string is ``key``.

        Direct children are checked first; for a multi-character key the
        search descends into the first child whose string is a proper
        prefix of ``key`` (shortest prefix first). Returns None when nothing
        matches.
        """
        if key in self.children:
            return self.children.get(key)
        elif len(key) > 1:
            for i in range(1, len(key)):
                if key[:i] in self.children:
                    return self.children.get(key[:i])[key]
        else:
            return self.children.get(key)

    def __contains__(self, key):
        """Membership test that mirrors ``__getitem__``.

        NOTE(review): for a multi-character key reached through a prefix
        child this returns the found node (always truthy, since ``__len__``
        is >= 1) or None instead of a bool; the ``in`` operator converts it.
        """
        if key in self.children:
            return True
        elif len(key) > 1:
            for i in range(1, len(key)):
                if key[:i] in self.children:
                    return self.children.get(key[:i])[key]
        else:
            return False

    def __setitem__(self, key, value):
        # Attach `value` as the child stored under string `key`.
        self.children[key] = value

    def __len__(self):
        """Number of token nodes in this subtree.

        Counts this node unless it is the root. NOTE(review): a root with no
        children reports 1, not 0.
        """
        if len(self.children) == 0:
            return 1
        else:
            r = 0 if self.type == 0 else 1
            for i in self.children.values():
                r += len(i)
            return r

    def __add__(self, value):
        """Merge this node with ``value`` into a new type-2 node and return it.

        The new node is stored under ``self`` keyed by the concatenated
        string (replacing any existing child with that key) and is registered
        in ``_all_token_list`` by ``__init__``.
        """
        self.children[self()+value()] = CustomGraph(parent = self, parent_sec = value)
        return self.children[self()+value()] 

    def __str__(self):
        return self()

    def longest_child(self):
        """Length of the longest string among the leaf nodes below this node.

        Returns 0 when the node has no children.
        """
        if len(self.children) == 0:
            return 0 
        else:
            m = 0
            for i in self.children.values():
                if len(i.children) == 0:
                    l = len(i())
                else:
                    l = i.longest_child()
                if l > m:
                    m = l
            return m

    def closest_node(self, key):
        """Return the deepest node reachable on the way to ``key``.

        A direct child equal to ``key`` is returned as is; otherwise the search
        descends into the first child whose string is a proper prefix of
        ``key`` (shortest first). If no child matches, ``self`` is returned.
        """
        # print(f"key - {key}, self.children - {self.children}")
        if key in self.children:
            return self.children.get(key)
        elif len(key) > 1:
            # print(f"range(1, len(key)) - {range(1, len(key))}")
            if len(self.children) == 0:
                return self
            for i in range(1, len(key)):
                # print(f"key[:i] - {key[:i]}")
                if key[:i] in self.children:
                    # print(f"self.children.get(key[:i]) - {self.children.get(key[:i])}, children - {self.children.get(key[:i]).children}")
                    return self.children[key[:i]].closest_node(key)
            return self
        else:
            return self

    def all_children(self):
        """Strings of all descendants, depth-first, each parent before its children."""
        if len(self.children) == 0:
            return []
        else:
            resmas = []
            for i in self.children.values():
                resmas.append(i())
                if len(i.children) != 0:
                    resmas += i.all_children()
            return resmas

    def add(self, value):
        """Add token ``value`` to the graph; only allowed on the root.

        A single character becomes a type-1 child of the root (no duplicate
        check: re-adding one replaces the child and registers another node).
        A longer ``value`` is built by merging ``closest_node(value)`` with the
        node found for the remaining suffix.

        Raises:
            Exception: if ``value`` (multi-character) is already present, or
                if called on a non-root node.

        NOTE(review): the merged string is ``parent1() + parent2()``, which
        equals ``value`` only if ``parent2()`` matches the suffix exactly;
        this is not checked.
        """
        if self.type == 0:
            if len(value) == 1:
                self.children[value] = CustomGraph(value = value)
                self.children[value].parent = self
            else:
                if value in self:
                    raise Exception('Token already in graph')
                else:
                    parent1 = self.closest_node(value)
                    # print(f"parent1 - '{parent1()}, value - {value}'")
                    parent2 = self.closest_node(value[len(parent1()):])
                    # print(f"parent2 - '{parent2()}, value[:len(parent1())] - {value[:len(parent1())]}'")
                    parent1 + parent2
        else:
            raise Exception("This function is availble only for root node")

    def print_children(self, indent):
        """Print this node's string and, recursively, those of its children, indented by depth."""
        print(" "*indent, f"'{self()}'")
        for i in self.children.values():
            if i.parent is self:
                i.print_children(indent+1)

    def print_tree(self):
        # NOTE(review): on the root `_root` is the root itself (instance
        # attribute); on other nodes it resolves to the class-level `_root`.
        self._root.print_children(0)

    def tokenize(self, text):
        """Split ``text`` into known tokens, preferring longer matches.

        Returns the list of token strings.

        NOTE(review) -- issues visible in the code below:
          * a character not in the graph makes ``self[text[i]]`` return None,
            so ``.longest_child()`` raises AttributeError;
          * the inner range's upper bound ``min(len(text), ...)`` is exclusive,
            so a multi-character token never includes the last character of
            ``text``;
          * if the last character of ``text`` has merged children, that range
            is empty, ``max_token_j`` stays 0, an empty string is appended and
            ``i`` jumps back to 0 -- an infinite loop;
          * the match length is taken from ``closest_node(...)``, which can
            differ from ``j - i``.
        """
        i = 0
        resmas = []
        while i < len(text):
            # print(i)
            max_token_j = 0
            max_length = 0

            curnode = self[text[i]]
            max_len = curnode.longest_child()
            if max_len == 0:
                # No merged token starts with this character: emit it alone.
                resmas.append(text[i])
                i += 1
            else:
                for j in range(i+1, min(len(text), i+1+max_len)):
                    # print(j,  min(len(text), i+max_len))
                    if text[i:j] in self:
                        # print(f"sub- {text[i:j]} node - {self.closest_node(text[i:j])} - {self.closest_node(text[i:j])()}") 
                        l = len(self.closest_node(text[i:j])())
                        if l > max_length:
                            max_length = l
                            max_token_j = j
                resmas.append(text[i:max_token_j])
                i = max_token_j
            # print(resmas)
        return resmas

    def pairing(self, text):
        """Return the most frequent adjacent token pair in ``tokenize(text)``.

        Ties are broken by the earliest first occurrence. Returns "" when the
        text yields fewer than two tokens.
        """
        token_list = self.tokenize(text)
        # print(f"token_list - {token_list}")
        pairs_list = [token_list[i] + token_list[i+1] for i in range(len(token_list)-1)]

        # pair string -> positions (indices into pairs_list) where it occurs
        pairs_dict = {}
        max_val = 0
        max_ind = ""
        # Starts as int 0 but is only indexed after the first update: the
        # first pair always wins via `1 > 0`, replacing it with a list.
        max_ind_pos = 0
        for i, val in enumerate(pairs_list):
            pairs_dict[val] = pairs_dict.get(val, [])
            pairs_dict[val].append(i)

            if (len(pairs_dict[val]) > max_val) or ( len(pairs_dict[val]) == max_val and pairs_dict[val][0] < max_ind_pos[0]) :
                max_val = len(pairs_dict[val])
                max_ind = val
                max_ind_pos = pairs_dict[val]
        return max_ind

    def id2token(self):
        # id -> token string; ids are positions in the shared `_all_token_list`.
        return { i:j for i,j in enumerate(list(map(str, self._all_token_list)))}
    def token2id(self):
        # token string -> id; for duplicate strings the last id wins.
        return { j:i for i,j in enumerate(list(map(str, self._all_token_list)))}


 


In [12]:
# Build a tiny graph of three single-character tokens, then merge 'a' + 'a'.
g = CustomGraph()
for symbol in ('a', 'b', ' '):
    g.add(symbol)
g['a'] + g['a']


# g['a'].print_tree()
# g.tokenize("abaa bbbb")
# pairing = g.pairing("aa bbbb")
# g.add(pairing)
# g.tokenize("abaa bbbb")
# print(g.id2token())
# print(g.token2id())

<__main__.CustomGraph at 0x7e0150801110>

In [13]:
class BPE():
    """Byte-pair-encoding tokenizer backed by a CustomGraph token tree.

    Args:
        vocab_size: target vocabulary size; ``fit`` stops merging once the
            graph holds this many tokens.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.graph_root = CustomGraph()

    def fit(self, text:str):
        """Learn the vocabulary from ``text``.

        Starts from the sorted distinct characters of ``text`` and repeatedly
        adds the most frequent adjacent pair until ``vocab_size`` tokens exist
        or no pair is left. Raises ``Exception`` if a merge does not grow the
        graph.
        """
        root = self.graph_root
        for symbol in sorted(set(text)):
            root.add(symbol)

        size_before = len(root)
        while len(root) < self.vocab_size:
            best_pair = root.pairing(text)
            if not best_pair:
                break
            root.add(best_pair)
            size_after = len(root)
            if size_after == size_before:
                raise Exception("No new data for tokenizer")
            size_before = size_after

        self.id2token = root.id2token()
        self.token2id = root.token2id()

    def encode(self, text):
        """Tokenize ``text`` with the graph and map every token to its id."""
        token2id = self.token2id
        return [token2id[piece] for piece in self.graph_root.tokenize(text)]
        



# Train the graph-based tokenizer on a tongue twister and encode a fragment.
bpe = BPE(vocab_size=28)
text = 'На дворе дрова, за двором дрова, дрова вширь двора, не вместит двор дров, надо дрова выдворить на дровяной двор.'
bpe.fit(text)

text1 = "вором дрова, дрова вширь двора, не вмест"
print(bpe.encode(text1))

[23, 22, 11, 25, 4, 1, 25, 4, 0, 5, 17, 9, 14, 19, 27, 4, 1, 0, 12, 7, 0, 5, 11, 7, 15, 16]


In [14]:
import dill

class BPE():
    """Byte-pair-encoding tokenizer that works directly on a string.

    Args:
        vocab_size: target vocabulary size, including the initial characters.

    Attributes set by ``fit``:
        token_list: the sorted distinct characters of the training text,
            followed by merged tokens in the order they were learned.
        id2token / token2id: mappings between token ids (indices into
            ``token_list``) and token strings.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    @staticmethod
    def _segments(text, token_lengths):
        """Return ``(start, token)`` pairs of the current segmentation of ``text``.

        ``token_lengths[p]`` is the length of the token starting at position
        ``p``; entries at positions inside a token are ignored.
        """
        segments = []
        pos = 0
        while pos < len(text):
            end = pos + token_lengths[pos]
            segments.append((pos, text[pos:end]))
            pos = end
        return segments

    @staticmethod
    def _most_frequent_pair(segments):
        """Return ``(pair, start_positions)`` of the most frequent adjacent pair.

        Ties are broken by the earliest first occurrence. Overlapping
        occurrences (e.g. "aa" twice in "aaa") are all counted. Returns
        ``("", [])`` when there are fewer than two segments.
        """
        positions = {}
        for (start, left), (_, right) in zip(segments, segments[1:]):
            positions.setdefault(left + right, []).append(start)
        if not positions:
            return "", []
        # Positions are appended in increasing order, so [0] is the first one.
        best = max(positions, key=lambda pair: (len(positions[pair]), -positions[pair][0]))
        return best, positions[best]

    def fit(self, text:str):
        """Learn up to ``vocab_size`` tokens from ``text``.

        Args:
            text: training corpus. An empty string yields an empty vocabulary.
        """
        token_list = sorted(set(text))  # initial single-character tokens
        # Length of the token starting at each position of `text`.
        token_lengths = [1] * len(text)

        while len(token_list) < self.vocab_size:
            pair, starts = self._most_frequent_pair(self._segments(text, token_lengths))
            if not pair:
                break  # the text is a single token already (or empty)
            token_list.append(pair)
            for start in starts:
                token_lengths[start] = len(pair)

        self.token_list = token_list
        self.id2token = dict(enumerate(token_list))
        self.token2id = {token: idx for idx, token in enumerate(token_list)}

    def encode(self, text):
        """Encode ``text`` into token ids using greedy longest-prefix matching.

        Raises:
            IndexError: if a character of ``text`` does not start any token
                (e.g. it did not occur in the training text).
        """
        vocabulary = set(self.token_list)
        longest = max(map(len, self.token_list), default=0)
        ids = []
        i = 0
        while i < len(text):
            for length in range(min(longest, len(text) - i), 0, -1):
                candidate = text[i:i + length]
                if candidate in vocabulary:
                    break
            else:
                raise IndexError(
                    f"character {text[i]!r} at position {i} is not in the vocabulary"
                )
            ids.append(self.token2id[candidate])
            i += length
        return ids

    def decode(self, token_ids):
        """Concatenate the token strings for ``token_ids``."""
        return "".join([self.id2token[_id] for _id in token_ids])

    def save(self, filename):
        """Serialize this tokenizer to ``filename`` with dill."""
        with open(filename, 'wb') as f:
            dill.dump(self, f)
        print(f"Объект сохранён в {filename}")

    @classmethod
    def load(cls, filename):
        """Load a tokenizer previously written by ``save``.

        WARNING: dill, like pickle, can execute arbitrary code while loading;
        only load files from a trusted source.
        """
        with open(filename, 'rb') as f:
            obj = dill.load(f)

        print(f"Объект загружен из {filename}")
        return obj



# Train the string-based tokenizer on the same tongue twister.
text = 'На дворе дрова, за двором дрова, дрова вширь двора, не вместит двор дров, надо дрова выдворить на дровяной двор.'
bpe = BPE(vocab_size=28)
bpe.fit(text)

# print(bpe.id2token)
# print(bpe.token2id)
# print(list(map(str, bpe.graph_root._all_token_list)))
# text1 = "вором дрова, дрова вширь двора, не вмест"
# print(bpe.encode(text1))
# tokens = [23, 22, 11, 25, 4, 1, 25, 4, 0, 5, 17, 9, 14, 19, 27, 4, 1, 0, 12, 7, 0, 5, 11, 7, 15, 16]
# print(bpe.decode(tokens))


# bpe = BPE(vocab_size=1000)
# bpe.fit(text)
# bpe.save('data/bpe.dill')
# bpe = BPE.load('data/bpe.dill')

In [15]:

# bpe = BPE.load('data/bpe.dill')
# bpe.token_list

In [16]:
class TokenEmbeddings(torch.nn.Module):
    """Learnable lookup table that maps token ids to embedding vectors.

    Args:
        vocab_size: number of distinct token ids.
        emb_size: size of each embedding vector.
        *args, **kwargs: forwarded to ``torch.nn.Module``.
    """

    def __init__(self, vocab_size, emb_size,  *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab_size, self.emb_size = vocab_size, emb_size
        self.matrix = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_size)

    def forward(self, x):
        """Look up embeddings for the integer token ids in ``x``.

        The result has shape ``x.shape + (emb_size,)``.
        """
        embedding_table = self.matrix
        return embedding_table(x)


In [17]:
class PositionalEmbeddings(torch.nn.Module):
    """Learnable absolute position embeddings.

    Args:
        max_seq_len: number of positions that have an embedding.
        emb_size: size of each embedding vector.
        *args, **kwargs: forwarded to ``torch.nn.Module``.
    """

    def __init__(self, max_seq_len , emb_size ,  *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_seq_len = max_seq_len
        self.emb_size = emb_size
        self.matrix = torch.nn.Embedding(num_embeddings=max_seq_len, embedding_dim=emb_size)

    def forward(self, seq_len ):
        """Return the embeddings of positions ``0 .. seq_len - 1``.

        ``seq_len`` is an int, not a tensor. The weight is sliced directly, so
        the result has shape ``(min(seq_len, max_seq_len), emb_size)``.
        """
        table = self.matrix.weight
        return table[:seq_len]
# Quick check: embeddings of the first 7 of 10 positions.
pe = PositionalEmbeddings(max_seq_len=10, emb_size=12)
pe(7)

tensor([[-7.3527e-01, -8.6134e-02, -5.3151e-01, -6.1255e-01,  7.4053e-01,
         -5.1576e-02,  1.2855e+00,  6.6857e-02, -2.0344e+00,  7.1355e-04,
          3.0162e-01,  1.2442e+00],
        [ 1.9120e+00, -2.5376e+00, -1.0149e+00, -1.0310e+00, -9.7896e-01,
         -7.4052e-01,  1.5298e-01,  1.7207e-01, -1.5266e+00,  1.3549e+00,
          4.5564e-01,  2.1546e+00],
        [ 1.2551e+00, -9.2100e-01, -3.7059e-02,  5.5143e-01,  1.1274e+00,
          1.7574e+00, -1.1630e+00,  8.4595e-01, -1.1852e+00, -7.8014e-01,
          1.3317e+00, -8.5087e-01],
        [-1.5731e+00,  3.3233e-01,  4.1538e-01,  9.0038e-02, -8.2872e-02,
         -3.9480e-01,  1.1202e+00, -1.0003e+00, -2.1596e-01, -1.5017e+00,
          3.5439e-01,  6.7743e-02],
        [-4.7294e-01,  9.0389e-01, -4.5013e-01, -1.8756e-01, -1.6522e+00,
         -1.1799e-01, -3.4191e-01, -7.6679e-01, -8.4476e-01, -1.7725e+00,
         -1.3085e+00, -4.7036e-01],
        [ 6.6501e-01,  4.5917e-01,  7.2080e-01,  1.3233e-01, -3.8343e-01,
      

In [None]:
class GPT(torch.nn.Module):
    """Decoder-only transformer language model.

    Args:
        vocab_size: number of token ids; size of the output logits.
        max_seq_len: longest input the model accepts; ``generate`` crops its
            context to this many tokens.
        emb_size: embedding size.
        num_heads: attention heads per decoder layer.
        head_size: size of each head's projections.
        num_layers: number of stacked decoder layers.
        dropout: dropout probability after the embeddings and inside every
            decoder layer.
        device: accepted for API compatibility but currently unused; move the
            model with ``.to(...)``.
        *args, **kwargs: forwarded to ``torch.nn.Module``.
    """
    def __init__(self, 
                 vocab_size: int,
                 max_seq_len: int, 
                 emb_size: int, 
                 num_heads: int, 
                 head_size: int, 
                 num_layers: int, 
                 dropout: float = 0.1,
                 device:str = "cpu",
                  *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_embeddings = TokenEmbeddings(vocab_size, emb_size)
        self.positional_embeddings = PositionalEmbeddings(max_seq_len, emb_size)
        self.dropout = torch.nn.Dropout(dropout)
        # Forward `dropout` so every layer uses the configured rate (previously
        # the decoders silently used their own default).
        self.decoders = torch.nn.Sequential(
            *[Decoder(num_heads, emb_size, head_size, max_seq_len, dropout) for _ in range(num_layers)]
        )
        self.linear = torch.nn.Linear(emb_size, vocab_size)
        self.max_seq_len = max_seq_len

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: integer token ids of shape (batch_size, seq_len) with
                seq_len <= max_seq_len.

        Returns:
            Logits of shape (batch_size, seq_len, vocab_size).
        """
        emb_tokens = self.token_embeddings(x)
        # (seq_len, emb_size); broadcasts over the batch dimension.
        emb_positions = self.positional_embeddings(x.shape[1])
        x = self.dropout(emb_tokens + emb_positions)
        x = self.decoders(x)
        return self.linear(x)

    def generate(self, x, max_new_tokens ):
        """Greedily append ``max_new_tokens`` tokens to ``x`` and return the result.

        At each step only the last ``max_seq_len`` tokens are fed to the model
        and the highest-scoring next token is appended. Runs under
        ``torch.no_grad()``; dropout still follows the module's train/eval
        mode, so call ``.eval()`` first for deterministic output.

        Returns:
            Tensor of shape (batch_size, seq_len + max_new_tokens).
        """
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = self(x[:, -self.max_seq_len:])
                # argmax of the logits equals argmax of their softmax.
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                x = torch.cat([x, next_token], dim=1)
        return x

# Hyperparameters for a tiny GPT used as a smoke test.
num_heads, emb_size, head_size = 5, 12, 8
max_seq_len, dropout = 40, 0.1
batch_size, seq_len = 2, 12
vocab_size, num_layers = 15, 5
# NOTE(review): torch calls GPUs "cuda", not "gpu"; this is harmless only
# because GPT does not use its `device` argument.
device = "gpu"
max_new_tokens = 5

gpt = GPT(
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    emb_size=emb_size,
    num_heads=num_heads,
    head_size=head_size,
    num_layers=num_layers,
    dropout=dropout,
    device=device,
)


In [66]:

# Random prompt batch, then a greedy continuation of max_new_tokens tokens.
t1 = torch.randint(0, 14, (batch_size, seq_len))
print(t1)
gpt.generate(t1, max_new_tokens=max_new_tokens)

tensor([[11, 13,  9,  2,  0,  9, 11,  1,  6,  9,  7, 10],
        [ 3,  0,  8, 10,  1,  7, 11,  0, 13,  2, 11,  4]])


tensor([[11, 13,  9,  2,  0,  9, 11,  1,  6,  9,  7, 10,  4,  1,  5,  1,  5],
        [ 3,  0,  8, 10,  1,  7, 11,  0, 13,  2, 11,  4,  4, 13,  2,  1,  5]])