In [None]:
import jittor as jt
from jittor import nn
import math

import sys
import os

# Resolve the notebook's working directory, then climb two directory levels
# to reach the project root.
current_dir = os.path.abspath('.')
project_root = os.path.dirname(os.path.dirname(current_dir))
print(f"Project root determined as: {project_root}")

# Move the project root to the very front of sys.path so project modules win
# import resolution; an existing entry is dropped first so it is not listed
# both at the front and at its old position.
try:
    sys.path.remove(project_root)
except ValueError:
    # The root was not on sys.path yet; nothing to remove.
    pass

sys.path.insert(0, project_root)
        

Project root determined as: /home/jittor/SCC_Model/ViT


In [5]:
from config import Config
# Build the configuration object once for the whole notebook. Its DROPOUT
# attribute is read below as the default dropout rate of
# MultiHeadSelfAttention.
config = Config()

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with a single linear
    layer, split into ``num_heads`` heads of size
    ``embedded_dim // num_heads``, attended per head, merged back and passed
    through an output linear layer.

    Args:
        embedded_dim: Size of the last dimension of the input and the output.
            Must be divisible by ``num_heads``.
        num_heads: Number of attention heads. Must be positive.
        dropout_rate: Dropout probability applied to the attention weights
            after the softmax. Defaults to ``config.DROPOUT``.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``embedded_dim`` is
            not divisible by ``num_heads``.
    """

    def __init__(self,embedded_dim,num_heads,dropout_rate=config.DROPOUT):
        super(MultiHeadSelfAttention,self).__init__()
        # Fail early with a clear message: without this check the floor
        # division below silently truncates and the error only surfaces later
        # as an opaque reshape failure inside execute().
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embedded_dim % num_heads != 0:
            raise ValueError(
                f"embedded_dim ({embedded_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.embedded_dim = embedded_dim
        self.num_heads = num_heads
        self.every_head_dim = embedded_dim // num_heads
        # Scores are divided by sqrt(head_dim) (scaled dot-product attention).
        self.scale = math.sqrt(self.every_head_dim)
        self.dropout_rate = dropout_rate

        # One projection produces q, k and v at once; they are split in execute().
        self.qkv_layer = nn.Linear(embedded_dim,embedded_dim*3)
        self.out_layer = nn.Linear(embedded_dim,embedded_dim)

        self.attn_drop = nn.Dropout(dropout_rate)


    def execute(self,x:jt.Var,mask:jt.Var=None)->jt.Var:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``[B, N, embedded_dim]``.
            mask: Optional 2-D mask of shape ``[B, N]``. Key positions where
                the mask equals 0 have their scores replaced by ``-1e9``
                before the softmax, for every head and every query position.

        Returns:
            Output of shape ``[B, N, embedded_dim]``.
        """
        B,N,C = x.shape
        # [B,N,embedded_dim] -> [B,N,3*embedded_dim]
        qkv = self.qkv_layer(x)
        # [B,N,3*embedded_dim] -> [B,N,3,num_heads,every_head_dim]
        qkv = qkv.view(B,N,3,self.num_heads,self.every_head_dim)
        # [B,N,3,num_heads,every_head_dim] -> 3 * [B,N,num_heads,every_head_dim]
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        # [B,N,num_heads,every_head_dim] -> [B,num_heads,N,every_head_dim]
        q = q.permute(0,2,1,3)
        k = k.permute(0,2,1,3)
        v = v.permute(0,2,1,3)

        # Scaled attention scores, shape [B,num_heads,N,N]
        attention_scores = jt.matmul(q,k.transpose(0,1,3,2))/self.scale
        # Apply the mask, if one was given
        if mask is not None:
            # [B,N] -> [B,1,1,N] so it broadcasts over heads and query positions
            mask = mask.view(mask.shape[0], 1, 1, mask.shape[1])
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

        # Normalize over the key dimension, then apply dropout to the weights
        attention_weights = nn.softmax(attention_scores, dim=-1)
        attention_weights = self.attn_drop(attention_weights)
        # Weighted sum of the values:
        # [B,num_heads,N,N] @ [B,num_heads,N,every_head_dim] -> [B,num_heads,N,every_head_dim]
        attention_output = jt.matmul(attention_weights, v)
        # Merge the heads: [B,num_heads,N,every_head_dim] -> [B,N,embedded_dim]
        attention_output = attention_output.permute(0,2,1,3).reshape(B,N,C)
        # Output projection; result has shape [B,N,embedded_dim]
        output = self.out_layer(attention_output)
        return output




In [14]:

# Smoke-test parameters
batch_size = 2
seq_length = 5
embedded_dim = 12
num_heads = 4

# Random input of shape [batch_size, seq_length, embedded_dim]
x = jt.randn(batch_size, seq_length, embedded_dim)
# Mask of shape [batch_size, seq_length]; all ones, so no position is masked out
mask = jt.ones(batch_size, seq_length)

# Build the multi-head attention layer
attention = MultiHeadSelfAttention(embedded_dim, num_heads)

# Run the forward pass
output = attention(x, mask)

# Self-attention must preserve the input shape; assert it so a regression
# fails the cell instead of only showing up in the printed output.
assert list(output.shape) == [batch_size, seq_length, embedded_dim], (
    f"unexpected output shape: {output.shape}"
)

print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("\n输入示例:")
print(x[0])  # first sample of the batch
print("\n输出示例:")
print(output[0])  # output for the first sample of the batch



Compiling Operators(4/23) used: 2.31s eta:   11s 

输入形状: [2,5,12,]
输出形状: [2,5,12,]

输入示例:


Compiling Operators(6/23) used: 3.32s eta: 9.41s 7/23) used: 4.33s eta: 9.89s 9/23) used: 5.33s eta: 8.29s 11/23) used: 6.34s eta: 6.91s 13/23) used: 7.34s eta: 5.65s 15/23) used: 8.35s eta: 4.45s 17/23) used: 9.35s eta:  3.3s 20/23) used: 10.4s eta: 1.55s 22/23) used: 11.4s eta: 0.517s 23/23) used: 12.4s eta:    0s 


jt.Var([[ 3.1880727e-01 -9.3096596e-01 -3.5460648e-01  2.0113297e+00
         -4.7587970e-01 -3.5760027e-01  2.7498421e-01  6.1297852e-01
          2.5876823e-01 -9.6165931e-01  1.6200861e+00  6.6917324e-01]
        [-7.3587066e-01  1.0585311e-02  3.3337894e-01  7.5165236e-01
          4.5378679e-01  1.5888499e+00  9.6721357e-01  4.2898846e-01
         -9.3403971e-03  4.7813603e-01 -1.0119213e+00 -2.6517095e-02]
        [ 8.2607633e-01 -3.2819703e-01 -5.7165623e-01 -2.2431624e+00
          2.0486769e-01  3.3215874e-01 -2.8591242e-01 -1.1700377e+00
          7.4926531e-01  8.9712590e-03 -4.5629728e-01 -4.3678969e-01]
        [ 1.6747835e-01 -1.8246700e+00 -7.3218602e-01 -1.6868219e+00
         -9.9124110e-01  9.7814798e-01  1.0903381e+00 -7.1771753e-01
          2.3294131e-01  1.0095409e+00  2.8932983e-01 -2.5322038e-01]
        [ 7.7399069e-01  1.4942355e+00 -9.5671952e-01  4.8368138e-01
          1.5254791e-01  9.6693464e-02  7.6289165e-01  2.0495662e-01
         -9.7108068e-04  5.842