## blocks

In [1]:
import torch.nn.functional as F
import torch
from torch import nn

class selfAttention(nn.Module):
    
    def __init__(self, num_head, embed_dim):
        super(selfAttention, self).__init__()
        self.num_head = num_head
        self.embed_dim = embed_dim
        # prepare method:(init W matrix)
        self.q_lin = nn.Linear(embed_dim, num_head * embed_dim, bias=False)
        self.k_lin = nn.Linear(embed_dim, num_head * embed_dim, bias=False)
        self.v_lin = nn.Linear(embed_dim, num_head * embed_dim, bias=False)
        self.softmax = nn.Softmax(dim = -1)
        self.concat_m = nn.Linear(num_head * embed_dim, embed_dim, bias=False)
            
    def forward(self, x, masked):
        num_batch, num_vocab, _ = x.shape

        query = self.q_lin(x).reshape(num_batch, num_vocab, self.num_head, self.embed_dim)
        key = self.k_lin(x).reshape(num_batch, num_vocab, self.num_head, self.embed_dim)
        value = self.v_lin(x).reshape(num_batch, num_vocab, self.num_head, self.embed_dim)
#         print(key.shape)
        raw_w = torch.einsum('bvhe,bwhe->bvhw', query, key)/(self.embed_dim*self.num_head)**(1/2)
#         print(raw_w.shape)
        #raw_w = torch.bmm(query, key.transpose(2,3))/torch.sqrt(self.embed_dim)
        
        if masked == True:
            mask_idx = torch.triu_indices(num_vocab, num_vocab, offset=1)
            raw_w[:, mask_idx[0], :, mask_idx[1]] = float('-inf')
        
        w = self.softmax(raw_w)
#         print(w.transpose(1,2))
        
        #y = torch.bmm(v, w)
        y = torch.einsum('bvhe, bvhv->bvhe', value, w)
#         print(y.shape)
        y_concat = y.reshape(num_batch, num_vocab, self.num_head*self.embed_dim)
        z = self.concat_m(y_concat)
        return z
    
    
    
    def en_de_forward(self, x, z, masked):
        num_batch, num_vocab, _ = x.shape

        query = self.q_lin(x).reshape(num_batch, num_vocab, self.num_head, self.embed_dim)
        key = self.k_lin(z).reshape(num_batch, num_vocab, self.num_head, self.embed_dim)
        value = self.v_lin(z).reshape(num_batch, num_vocab, self.num_head, self.embed_dim)
#         print(key.shape)
        raw_w = torch.einsum('bvhe,bwhe->bvhw', query, key)/(self.embed_dim*self.num_head)**(1/2)
#         print(raw_w.shape)
        #raw_w = torch.bmm(query, key.transpose(2,3))/torch.sqrt(self.embed_dim)
        
        if masked == True:
            mask_idx = torch.triu_indices(num_vocab, num_vocab, offset=1)
            raw_w[:, mask_idx[0], :, mask_idx[1]] = float('-inf')
        
        w = self.softmax(raw_w)
#         print(w.transpose(1,2))
        
        #y = torch.bmm(v, w)
        y = torch.einsum('bvhe, bvhv->bvhe', value, w)
#         print(y.shape)
        y_concat = y.reshape(num_batch, num_vocab, self.num_head*self.embed_dim)
        z_out = self.concat_m(y_concat)
        return z_out
        

In [2]:
class FeedFoward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super(FeedFoward, self).__init__()
        #self.embed_dim = embed_dim
        self.lin_1 = nn.Linear(embed_dim, ff_dim, bias=False)
        self.relu = nn.ReLU()
        self.lin_2 = nn.Linear(ff_dim, embed_dim, bias=False)
        
    def forward(self, z):
        z_lin1 = self.lin_1(z)
#         print(z_lin1.shape)
        z_relu = self.relu(z_lin1)
#         print(z_relu.shape)
        z_ff = self.lin_2(z_relu)
        
        return z_ff

In [3]:
class PositionEncode(nn.Module):
    pass

class Embed_Token_Position(nn.Module):
    def __init__(self, num_vocab, seq_len, embed_dim):
        super(Embed_Token_Position, self).__init__()
        self.voc_embed = nn.Embedding(num_vocab, embed_dim)
        print(num_vocab, embed_dim)
        self.pos_embed = nn.Embedding(seq_len, embed_dim)
#         self.token_emb = nn.Embedding(num_tokens, k)
#         tokens = self.token_emb(x)
        
    def forward(self, x):
        num_batch, num_vocab = x.shape
        p = torch.arange(num_vocab)
        #print(x.dtype, torch.tensor(x.shape))
        voc = self.voc_embed(x)
        #print(voc.shape)
        pos = self.pos_embed(p)[None, :, :].expand(voc.shape)
        
        out = voc + pos
        
        return out

In [4]:
n_b, n_h, n_v, e_d = 6, 8, 3, 4
el, dl = 6, 6
seqlen=n_v
ff_dim, num_class = 10, 9
example_x = torch.randint(n_v, (n_b, n_v))#, e_d))
# example= torch.randn(n_b, n_v, e_d)
etp = Embed_Token_Position(n_v, seqlen, e_d)
out = etp(example_x)
print(out.shape)
# ve = nn.Embedding(n_v, e_d)
# vo = ve(example_x)
#print(vo.shape)

3 4
torch.Size([6, 3, 4])


## Encoder

In [5]:
class Encoder(nn.Module):
    def __init__(self, num_head, embed_dim, ff_dim, drop_prob):
        super(Encoder, self).__init__()
        # need layers and param setting

        self.attention = selfAttention(num_head, embed_dim)
        self.layernorm = nn.LayerNorm(normalized_shape=embed_dim)
        self.feedforward = FeedFoward(embed_dim, ff_dim)
        self.drop = nn.Dropout(drop_prob)
#         model = nn.Sequantial(selfAttention(num_head, embed_dim),
#                               nn.LayerNorm()
#                               FeedFoward(embed_dim, hidden_dim)
#                               nn.LayerNorm()
#         )
        
    def forward(self, x):
#         self.layernorm = nn.LayerNorm(normalized_shape = x.shape)
        z = self.drop(self.attention.forward(x, False))
        z_1 = self.layernorm(z+x)
        
        z_2 = self.drop(self.feedforward.forward(z_1))
        out = self.layernorm(z_2+z_1)
        
        return out
        
        

## Decoder

In [6]:
class Decoder(nn.Module):
    def __init__(self, num_head, embed_dim, ff_dim, drop_prob):
        super(Decoder, self).__init__()
        self.attention = selfAttention(num_head, embed_dim)
        self.feedforward = FeedFoward(embed_dim, ff_dim)
        self.layernorm = nn.LayerNorm(normalized_shape = embed_dim)
        self.drop = nn.Dropout(drop_prob)
        
    def forward(self, x, z):
        assert x.shape == z.shape
#         self.layernorm = nn.LayerNorm(normalized_shape = x.shape)
        de_z0 = self.drop(self.attention.forward(x, True))
        de_out1 = self.layernorm(de_z0+x)
        
        de_z1 = self.drop(self.attention.en_de_forward(de_out1, z, False))
        de_out2 = self.layernorm(de_z1+de_out1)
        
        de_z2 = self.drop(self.feedforward.forward(de_out2))
        out = self.layernorm(de_z2+de_out2)
        
        return out

## Tranformer

In [7]:
class Classify_Transformer(nn.Module):
    def __init__(self, num_head, num_vocab, seq_len, embed_dim, ff_dim, 
                 num_class, en_layers, de_layers, drop_p):
        super(Classify_Transformer, self).__init__()
        self.tp = Embed_Token_Position(num_vocab, seq_len, embed_dim)
        self.drop = nn.Dropout(drop_p)
        self.encoders = nn.ModuleList([Encoder(num_head, embed_dim, ff_dim, drop_p) for i in range(en_layers)])
        self.decoders = nn.ModuleList([Decoder(num_head, embed_dim, ff_dim, drop_p) for i in range(de_layers)])
#         self.encoder = Encoder(num_head, embed_dim, ff_dim)
#         self.decoder = Decoder(num_head, embed_dim, ff_dim)
        self.linear = nn.Linear(embed_dim, num_class)
#         self.logsoftmax = F.log_softmax(dim=2)
        

    def forward(self, x):
        tokpos = self.drop(self.tp(x))
        z_en = self.encoders(tokpos)
        out_de = self.decoders(tokpos, z_en)
        score = F.log_softmax(self.linear(out_de), dim = 2)
        
        return score
        
    
class Generat_Transformer(nn.Module):
    def __init__(self, num_head, embed_dim, ff_dim, num_vocab):
        super(Generat_Transformer, self).__init__()
        pass

## simple check

In [8]:
# a = torch.randn(2, 3, 4*5)
# #b = a.reshape(1, 3, 2, 1)
# b=a.view(2, 3, 4, 5)
# c=a.reshape(2, 3, 4, 5)
# print(c.shape)
n_b, n_h, n_v, e_d = 6, 8, 3, 4
example= torch.randn(n_b, n_v, e_d)
# cc=nn.LayerNorm(normalized_shape=example.shape)
# cc(example)
# atention=selfAttention(n_h, e_d)
# z = atention.forward(example, False)

# ff = FeedFoward(e_d, 10)
# z_ff = ff.foward(z)

encod = Encoder(n_h, e_d, 10, 0.1)
z = encod.forward(example)

decod = Decoder(n_h, e_d, 10, 0.1)
out = decod.forward(example, z)
print(out.shape)
    

torch.Size([6, 3, 4])


In [9]:
B, C, H, W = 6, 3, 8, 3
#x = torch.randn(B, C, H, W)
x = torch.randn(2, 3, 2, 3)
indices = torch.triu_indices(C, W, offset=1) # row and col.
print(indices)
x[:, indices[0], :, indices[1]] = float('-inf')
x.transpose(1,2)

# y = torch.where(x > x.view(B, C, -1).mean(2)[:, :, None, None], torch.tensor([1.]), torch.tensor([0.]))
# print(x,'\n',x.view(B, C, -1).mean(2)[:, :, None, None])

tensor([[0, 0, 1],
        [1, 2, 2]])


tensor([[[[-1.2016,    -inf,    -inf],
          [ 1.9133, -0.0823,    -inf],
          [-0.5916,  0.3619,  0.1267]],

         [[-0.1518,    -inf,    -inf],
          [-0.4162,  0.3867,    -inf],
          [ 1.5206,  1.8693,  1.8352]]],


        [[[ 1.8713,    -inf,    -inf],
          [ 0.9486,  0.3666,    -inf],
          [-0.0640, -0.7142, -1.6808]],

         [[ 0.5539,    -inf,    -inf],
          [-0.9549, -1.3577,    -inf],
          [ 0.1060, -0.7263, -0.2244]]]])

In [10]:
b, t, k = 6, 3, 4
pos_emb = nn.Embedding(10, k)
positions = torch.arange(t)
print(positions)
positions = pos_emb(positions)[None, :, :].expand(b, t, k)
print(positions.shape)

tensor([0, 1, 2])
torch.Size([6, 3, 4])


In [11]:
n_b, n_h, n_v, e_d = 6, 8, 3, 4
el, dl = 6, 6
seqlen=n_v
ff_dim, num_class = 10, 9
example= torch.randn(n_b, n_v, e_d)
example_x = torch.randint(n_v, (n_b, n_v))
Trans = Classify_Transformer(n_h, n_v, seqlen, e_d, ff_dim, num_class, el, dl, 0.1)
out = Trans(example_x)

3 4


TypeError: forward() takes 1 positional argument but 2 were given

## Example

In [None]:
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
        device
    )
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10
    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(
        device
    )
    out = model(x, trg[:, :-1])
    print(out.shape)