In [1]:
import torch
from torch import nn

In [2]:
import math

In [3]:
# !pip install torch

In [4]:
# Model hyper-parameters shared by the cells below:
#   n_embd - width of the token embedding
#   n_head - number of attention heads
n_embd, n_head = 768, 12


In [5]:
# Per-head width, derived from the config values instead of repeating the literals.
n_embd // n_head

64

In [6]:
# Splitting the last axis in two halves: (2, 3, 6, 8) -> two tensors of (2, 3, 6, 4).
a = torch.randn((2, 3, 6, 8))
a1, a2 = a.chunk(2, dim=-1)
a1.shape

torch.Size([2, 3, 6, 4])

In [7]:
class diff_tranformer(nn.Module):
  """Multi-head differential attention.

  Queries and keys of every head are split into two halves along the feature
  axis. Each half produces its own softmax attention map, and the layer uses
  `attn1 - lambda * attn2` to weight the values, where

      lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + l_init

  is a learnable scalar.

  Args:
    embd_dim: model width. Defaults to the notebook-level `n_embd`.
    num_heads: number of heads. Defaults to the notebook-level `n_head`.
    l_init: constant offset added to lambda.

  Raises:
    ValueError: if the width is not divisible by the number of heads, or the
      resulting per-head width is odd (it must split into two equal halves).
  """

  def __init__(self, embd_dim=None, num_heads=None, l_init=0.8):
    super().__init__()

    # Fall back to the notebook globals so a bare `diff_tranformer()` still works,
    # but store the sizes on the instance so forward() never depends on globals.
    self.n_embd = n_embd if embd_dim is None else embd_dim
    self.n_head = n_head if num_heads is None else num_heads

    if self.n_embd % self.n_head != 0:
      raise ValueError(
          f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})")
    self.hd = self.n_embd // self.n_head
    if self.hd % 2 != 0:
      raise ValueError(
          f"head dim ({self.hd}) must be even to split q/k into two halves")

    self.wq = nn.Linear(self.n_embd, self.n_embd)
    self.wk = nn.Linear(self.n_embd, self.n_embd)
    self.wv = nn.Linear(self.n_embd, self.n_embd)
    self.wo = nn.Linear(self.n_embd, self.n_embd)

    # Small-std init keeps each dot product near 0, so exp(.) is near 1 and
    # lambda starts close to l_init. Unit-variance vectors of length hd give a
    # dot product with std ~sqrt(hd), whose exponential is huge.
    self.lq1 = nn.Parameter(torch.randn(self.hd) * 0.1)
    self.lq2 = nn.Parameter(torch.randn(self.hd) * 0.1)

    self.lk1 = nn.Parameter(torch.randn(self.hd) * 0.1)
    self.lk2 = nn.Parameter(torch.randn(self.hd) * 0.1)

    self.l_init = l_init

  @property
  def l(self):
    """Current lambda as a 0-dim tensor.

    Recomputed from the parameters on every access. Caching it in __init__
    would freeze the value at construction time, reuse a stale autograd graph
    across training steps, and leave it behind when the module changes device.
    """
    return (torch.exp(torch.dot(self.lq1, self.lk1))
            - torch.exp(torch.dot(self.lq2, self.lk2))
            + self.l_init)

  def forward(self, x):
    """Apply differential attention.

    Args:
      x: tensor of shape (bs, sl, n_embd).

    Returns:
      Tensor of shape (bs, sl, n_embd).
    """
    bs, sl, d = x.shape

    # (bs, sl, n_embd) --> (bs, sl, nh, hd)
    q = self.wq(x).view(bs, sl, self.n_head, self.hd)
    k = self.wk(x).view(bs, sl, self.n_head, self.hd)
    v = self.wv(x).view(bs, sl, self.n_head, self.hd)

    # Split q/k into the two attention branches: (bs, sl, nh, hd/2) each.
    q1, q2 = torch.chunk(q, 2, dim=-1)
    k1, k2 = torch.chunk(k, 2, dim=-1)

    # Move heads in front of the sequence axis: (bs, nh, sl, *).
    q1, q2, k1, k2, v = (t.transpose(1, 2) for t in (q1, q2, k1, k2, v))

    # Each dot product runs over hd/2 features, so that is the width to scale by.
    scale = 1.0 / math.sqrt(self.hd // 2)

    attn1 = ((q1 @ k1.transpose(-2, -1)) * scale).softmax(dim=-1)  # bs, nh, sl, sl
    attn2 = ((q2 @ k2.transpose(-2, -1)) * scale).softmax(dim=-1)  # bs, nh, sl, sl

    # The scalar lambda broadcasts over the whole attention map.
    attn = attn1 - (self.l * attn2)

    out = attn @ v  # bs, nh, sl, sl --> bs, nh, sl, hd

    # bs, nh, sl, hd --> bs, sl, nh, hd --> bs, sl, dim
    out = out.transpose(1, 2).reshape(bs, sl, d)

    return self.wo(out)


In [8]:
# Smoke test: run a random batch of shape (2, 5, n_embd) through the layer
# and display the output shape.
a = torch.randn(2, 5, n_embd)
b = diff_tranformer()
c = b(a)
c.shape

torch.Size([2, 5, 768])

In [9]:
# Elementwise product keeps the vector shape; dot reduces it to a 0-dim scalar.
a = torch.randn((4,))
b = torch.randn((4,))
c = torch.mul(a, b)
d = a.dot(b)
a, b, c, d, c.shape, d.shape

(tensor([-2.2390, -0.8825,  0.9771,  1.1039]),
 tensor([ 0.7271, -0.1978, -0.2259,  0.3331]),
 tensor([-1.6280,  0.1746, -0.2208,  0.3677]),
 tensor(-1.3064),
 torch.Size([4]),
 torch.Size([]))