In [38]:
import torch
import torch.nn as nn
from einops import einsum, rearrange, reduce, pack


In [None]:

class Linear(nn.Module) :
    def __init__(self, 
                 in_features:int, 
                 out_features:int, 
                 device:torch.device | None = None,
                 dtype:torch.dtype | None = None
                 ) -> None:
        super().__init__()
        # in out维度记录在模块内部
        # tensor位置和类型信息并不属于模块，而是跟着tensor走
        # 所以只是传给创建tensor的函数，需要这些信息时直接问tensor而非模块
        self.in_features = in_features
        self.out_features = out_features
        parameter_kwargs = {"device": device, "dtype": dtype}

        # 创建一块参数矩阵
        # 由于pytorch存储参数时是按行的，每行内容都会连在一起
        # 而进行乘法时一定是长为in_features的那一维去乘输入向量
        # 所以一定是让in_features作为行的长度，out_features作为行的数量
        self.W = nn.Parameter(
            torch.empty(
                out_features,
                in_features,
                **parameter_kwargs    # 注意输入的设备和类型信息传到了这里！
            ))
        
        # 初始化：方差为2/(d_in + d_out)，截断处在3个标准差
        var = 2.0 / (in_features + out_features)
        std = var ** 0.5
        nn.init.trunc_normal_(self.W, std=std, a=-3*std, b=3*std)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        # x的形状是 (..., d_in)
        # W的形状是 (d_out, d_in)
        # 输出形状是 (..., d_out)
        # 简单写法：return x @ self.W.T
        return einsum(x, self.W, "... d_in, d_out d_in -> ... d_out")   #和上面的等价
        # 注意：x前面可能有许多维度，但最后一维一定是输入维度d_in
        # 而为了方便计算，我们的W是d_out*d_in的
        # 所以正常需要转置W才能相乘！这里用einsum来指定怎么乘，可以避免手动转置

In [None]:
linear = Linear(4, 3)
# 测试如何取到模块的所在设备
print(linear.W.device)  # cpu

In [None]:
class Embedding(nn.Module):
    def __init__(self,
                 num_embeddings:int, # 词表大小
                 embedding_dim:int, # 隐藏空间大小
                 device:torch.device | None = None,
                 dtype:torch.dtype | None = None
                 ) -> None:
        super().__init__()
        self.embedding_matrix = nn.Parameter(
            torch.empty(
                num_embeddings,
                embedding_dim,
                device=device,
                dtype=dtype
            )
        )
        nn.init.trunc_normal_(
            self.embedding_matrix,
            a = -3,
            b = 3
        )

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.embedding_matrix[token_ids]

Deliverable: Implement RMSNorm as a torch.nn.Module. We recommend the following interface:
def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None)
Construct the RMSNorm module. This function should accept the following parameters:
d_model: int Hidden dimension of the model
eps: float = 1e-5 Epsilon value for numerical stability
device: torch.device | None = None Device to store the parameters on
dtype: torch.dtype | None = None Data type of the parameters

def forward(self, x: torch.Tensor) -> torch.Tensor Process an input tensor of shape
(batch_size, sequence_length, d_model) and return a tensor of the same shape.
Note: Remember to upcast your input to torch.float32 before performing the normalization (and
later downcast to the original dtype), as described above.
To test your implementation, implement the test adapter at [adapters.run_rmsnorm]. Then, run uv
run pytest -k test_rmsnorm.


In [None]:
class RMSNorm(nn.Module):
    def __init__(self, 
                 d_model: int, 
                 eps: float = 1e-5, 
                 device:torch.device | None = None, 
                 dtype:torch.dtype | None = None):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        parameter_kwargs = {"device": device, "dtype": dtype}

        self.gain_parameter = nn.Parameter(
            torch.empty(
                d_model,
                **parameter_kwargs
            )
        )

        nn.init.ones_(self.gain_parameter)

        # 其实可以用torch.ones直接初始化，不过这里拆成两部分便于理解


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x = x.to(torch.float32)
        rms = torch.sqrt(reduce(x**2,"... d -> ... 1","mean") + self.eps) ** -1 
        # 将x的最后一维压缩为一个标量作为分母，但是保留这一维度，便于广播
        res = x * rms * self.gain_parameter
        # 第一个乘法是因为手动保留了维度所以才能进行的，第二个乘法会自动广播
        return res.to(in_dtype)



In [None]:
x = torch.tensor([1,2,3])
n = 3
print(type(x))
print(x)
print(x.dtype)

In [None]:
z = torch.sqrt(torch.sum(x**2) / 3)
print(z)
print(z.dtype)

In [None]:
y = torch.tensor([4,5,6])
w = x * y
print(w)
w = einsum(x,y,"dim1, dim1 -> dim1")
print(w)

In [None]:
x = torch.tensor([[1,2,3],[4,5,6]]).to(torch.float32)
y = torch.tensor([4,5,6]).to(torch.float32)
z = einsum(x, y, "a b, b -> a b")
z

In [None]:
x = torch.tensor([[[1,2,3,1],
                   [4,5,6,1],
                   [1,4,7,11]],

                  [[7,8,9,1],
                   [10,11,12,1],
                   [2,5,8,12]]])
x = x.to(torch.float32)
print(x)
print(x.shape)
print(x.dtype)

In [None]:
rms = torch.sqrt(reduce(x**2, "b s d -> b s 1", "mean") + 0.25)
print(rms)

In [None]:
x = x / rms
print(x)

In [None]:
g = torch.tensor([1,2,3,4])
z = x * g
g1 = rearrange(g, "d -> 1 1 d")
z1 = x * g1
z2 = einsum(x, g, "b s d, d -> b s d")
print(z)
print()
print(z1)
print()
print(z2)


In [None]:
x ** -1

In [None]:
nn.init.ones_(x)
print(x)

In [None]:
x = 1
y = x * 8 / 3
int(y)

In [None]:
def swish(x:torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)

testtensor = torch.tensor([[[1,2,3,1],
                   [4,5,6,1],
                   [1,4,7,11]],

                  [[7,8,9,1],
                   [10,11,12,1],
                   [2,5,8,12]]])

x = testtensor.clone()
print(x)
print(swish(x))

In [None]:
def swish(x:torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)

class SwiGLUFFN(nn.Module):
    def __init__(self, 
                 d_model:int,
                 d_ff:int,
                 device:torch.device | None = None,
                 dtype:torch.dtype | None = None) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        parameter_kwargs = {"device":device,"dtype":dtype}

        self.linear1 = Linear(d_model,d_ff,**parameter_kwargs)
        self.linear2 = Linear(d_ff,d_model,**parameter_kwargs)
        self.linear3 = Linear(d_model,d_ff,**parameter_kwargs)
    
    def forward(self, x:torch.Tensor) -> torch.Tensor:   
        activates = swish(self.linear1(x))
        gates = self.linear3(x)
        return self.linear2(activates * gates)

In [None]:
dq = 8
halfdq = int(dq/2)
maxlen = 10
THETA = 10000

import math
S = math.pow(THETA,-2/dq)
print(S)

thetas = [[i * math.pow(S,k-1) for k in range(1,halfdq+1)] for i in range(1,maxlen+1)]



import matplotlib.pyplot as plt
import numpy as np
thetas = np.array(thetas)

np.set_printoptions(precision=3, suppress=True)
print(thetas)
plt.imshow(thetas)
plt.colorbar()
plt.show()


In [None]:
itorch = torch.arange(1,maxlen+1)
print(itorch)
print(itorch.dtype, itorch.shape)
rearrange(itorch,"seq -> seq 1")
print(itorch)
print(itorch.shape)

In [None]:
ktorch = S ** torch.arange(halfdq)
thetastorch = einsum(itorch,ktorch,"seq, dimq -> seq dimq")
print(thetastorch)

In [None]:
d1 = 3
d2 = 5
d3 = 8
testtorch = rearrange(torch.arange(d1*d2*d3),"(d1 d2 d3)->d1 d2 d3",d1 = d1,d2=d2,d3=d3)
print(testtorch)
print(testtorch.dtype, testtorch.shape)

In [None]:
split = rearrange(testtorch,"d1 d2 (c2 half) -> d1 d2 half c2",c2=2)
print(split)

In [None]:
split[:,:,:,0]

theta: float Θ value for the RoPE
d_k: int dimension of query and key vectors
max_seq_len: int Maximum sequence length that will be inputted
device: torch.device | None = None Device to store the buffer on

def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor

Process an input tensor of shape (..., seq_len, d_k) and return a tensor of the same shape.
Note that you should tolerate x with an arbitrary number of batch dimensions. You should
assume that the token positions are a tensor of shape (..., seq_len) specifying the token
positions of x along the sequence dimension.
You should use the token positions to slice your (possibly precomputed) cos and sin tensors
along the sequence dimension.
To test your implementation, complete [adapters.run_rope] and make sure it passes uv run
pytest -k test_rope.

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    cosines: torch.Tensor
    sines: torch.Tensor
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None) -> None:
        # 初始化函数只需要负责根据输入的构建并存储起所有的正余弦值，便于使用
        super().__init__()
        assert d_k%2==0 , "dimension of Q and K should be even for RoPE"
        halfd = int(d_k / 2)

        positions = torch.arange(1,max_seq_len + 1, device=device) # 序列位置参数：1~maxlen
        S = math.pow(theta,-2/d_k)
        thetas = torch.pow(S,torch.arange(halfd,device=device)) # 所以theta的指数从0~大约-1
        thetas_with_position = einsum(positions,thetas,"maxlen, halfdk -> maxlen halfdk")

        self.register_buffer("cosines", 
                             torch.cos(thetas_with_position),   
                             persistent=False)
        self.register_buffer("sines", 
                             torch.sin(thetas_with_position),
                             persistent=False)
    
        self.sines

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        # 先把输入的最后一位按奇偶分开，分别进行乘法后重新进行线性组合
        # 使用sin和cos值时注意截断
        rearranged_x = rearrange(x,"... len (halfdk c2) -> ... len halfdk c2",c2=2)
        oddx = rearranged_x[1] # ... len halfdk
        evenx = rearranged_x[0] # ... len halfdk
        cut_cosines = self.cosines[token_positions] #
        cut_sines = self.sines[token_positions]
        # 三角函数阵： halfdk         
        rotated_oddx = einsum(oddx, self.cosines, "... len halfdk, len halfdk -> ... len halfdk") - \
                      einsum(evenx, self.sines, "... len halfdk, len halfdk -> ... len halfdk")
        rotated_evenx = einsum(evenx, self.cosines, "... len halfdk, len halfdk -> ... len halfdk") + \
                       einsum(oddx, self.sines, "... len halfdk, len halfdk -> ... len half dk")
        res = rearrange(pack([rotated_evenx, rotated_oddx], " ... len halfdk *")[0],
                        "... len halfdk c2 -> ... len (halfdk c2)")
        return res
    

In [35]:
t1 = torch.rand(2,4,3)
print(t1.shape)
t2 = torch.rand(2,4,3)

print(t1)
print(t2)
einsum(t1,t2,"a b c , a b c -> a b c")

torch.Size([2, 4, 3])
tensor([[[0.8907, 0.0834, 0.4916],
         [0.4165, 0.8122, 0.9902],
         [0.8252, 0.5467, 0.5307],
         [0.1978, 0.4191, 0.5772]],

        [[0.6628, 0.8670, 0.1062],
         [0.6607, 0.4346, 0.2545],
         [0.5123, 0.8700, 0.0817],
         [0.3729, 0.3809, 0.7405]]])
tensor([[[0.6365, 0.0560, 0.7873],
         [0.7644, 0.5188, 0.0126],
         [0.6020, 0.6399, 0.0708],
         [0.3007, 0.0652, 0.0684]],

        [[0.7699, 0.5737, 0.0032],
         [0.8280, 0.6983, 0.8861],
         [0.6605, 0.4224, 0.7843],
         [0.8201, 0.7434, 0.0064]]])


tensor([[[5.6693e-01, 4.6717e-03, 3.8707e-01],
         [3.1837e-01, 4.2139e-01, 1.2439e-02],
         [4.9682e-01, 3.4987e-01, 3.7585e-02],
         [5.9471e-02, 2.7320e-02, 3.9462e-02]],

        [[5.1025e-01, 4.9737e-01, 3.4173e-04],
         [5.4706e-01, 3.0348e-01, 2.2548e-01],
         [3.3839e-01, 3.6750e-01, 6.4079e-02],
         [3.0581e-01, 2.8317e-01, 4.7356e-03]]])

In [37]:
test = torch.arange(12)
test = rearrange(test,"(r c) -> r c",r = 3)
test

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) layer.

    Args
    ----
    theta : float
        Base used to generate inverse frequencies (e.g. 10_000).
    d_k : int
        Dimension of the key / query vectors (must be even).
    max_seq_len : int
        Maximum sequence length expected at inference / training time.
    device : torch.device | None
        Where to place the pre-computed sine / cosine tables.
    """
    def __init__(self,
                 theta: float,
                 d_k: int,
                 max_seq_len: int,
                 device=None):
        super().__init__()
        if d_k % 2 != 0:
            raise ValueError("d_k must be even for RoPE.")
        self.d_k = d_k
        # ---- pre-compute inverse frequencies ----
        # freq[k] = 1 / theta ** (2k / d_k)          (k = 0,1,…,d_k/2-1)
        freq = 1.0 / (theta ** (torch.arange(0,d_k,2, device=device).float() / d_k))

        # shape: (max_seq_len, d_k // 2)
        positions = torch.arange(max_seq_len, device=device).float()
        freqs = torch.outer(positions, freq)

        # cache cos/sin; no gradients needed → persistent=False
        self.register_buffer('cos_cached', torch.cos(freqs),persistent=False) # persistent=False does not save to state_dict
        self.register_buffer('sin_cached', torch.sin(freqs), persistent=False)
    
    def forward(
        self,
        x: Float[Tensor, "... seq_len d_k"],
        token_positions: Int[Tensor, "... seq_len"]
        ) -> Float[Tensor, "... seq_len d_k"]:
        """
        Apply RoPE to `x`.  Works with any batch shape prefix.
        """
        # Check if the last dimension matches d_k
        if x.size(-1) != self.d_k:
            raise ValueError(f"Last dim of x ({x.size(-1)}) ≠ d_k ({self.d_k}).")
        
        # Gather the cached tables for the required positions
        cos_pos = self.cos_cached[token_positions]
        sin_pos = self.sin_cached[token_positions]

        # Split even / odd channels
        x_even = x[..., ::2]
        x_odd = x[..., 1::2]

        # Apply the 2-D rotation to each pair
        out_even = x_even * cos_pos - x_odd * sin_pos
        out_odd = x_even * sin_pos + x_odd * cos_pos

        # Re-interleave
        out = torch.empty_like(x)
        out[..., ::2] = out_even
        out[..., 1::2] = out_odd
        return out

In [67]:
from einops import unpack

x = torch.arange(2*3*4).reshape(2,3,4)
y = torch.arange(2*3*4,2*2*3*4).reshape(2,3,4)
print("x:",x)
print("y:",y)
z = rearrange([x,y],"two ... halfd -> ... (halfd two)")
print("z:",z)

x: tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
y: tensor([[[24, 25, 26, 27],
         [28, 29, 30, 31],
         [32, 33, 34, 35]],

        [[36, 37, 38, 39],
         [40, 41, 42, 43],
         [44, 45, 46, 47]]])
z: tensor([[[ 0, 24,  1, 25,  2, 26,  3, 27],
         [ 4, 28,  5, 29,  6, 30,  7, 31],
         [ 8, 32,  9, 33, 10, 34, 11, 35]],

        [[12, 36, 13, 37, 14, 38, 15, 39],
         [16, 40, 17, 41, 18, 42, 19, 43],
         [20, 44, 21, 45, 22, 46, 23, 47]]])


In [55]:
(res1, res2) = unpack(z,[[1],[2]],"d1 * d2")
print(res1.size())
print(res2.size())

torch.Size([2, 1, 9])
torch.Size([2, 2, 9])


In [None]:
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    xlen = len(x.shape)
    dim = dim % xlen
    pat_origin = " ".join([f"d{i}" for i in range(xlen)])
    # print(pat_origin)
    pat_reduce = " ".join([f"d{i}" if i != dim else "1" for i in range(xlen)])
    # print(pat_reduce)
    x_max = reduce(x,f"{pat_origin}->{pat_reduce}","max")
    # print(x_max)
    xminus = x - x_max
    expx = torch.exp(xminus)
    sum_expx = reduce(expx,f"{pat_origin}->{pat_reduce}","sum")
    res = expx / sum_expx
    return res


x = torch.arange(6).reshape(2,3)
print(x)
y = softmax(x,0)
print(y)

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0.0474, 0.0474, 0.0474],
        [0.9526, 0.9526, 0.9526]])


In [87]:
import math
print(math.e ** 3)


20.085536923187664


In [None]:
import jaxtyping
from jaxtyping import Float, Bool, jaxtyped
from torch import Tensor
import torch
from beartype import beartype as typechecker

Q : Float[Tensor, "queries d_k"] = torch.arange(12,dtype=torch.float32).reshape(3,4)
print(Q)

@jaxtyped(typechecker=typechecker)
def scaled_dot_product_attention(
    Q: Float[Tensor, " ... queries d_k"],
    K: Float[Tensor, " ... keys d_k"],
    V: Float[Tensor, " ... values d_v"],
    mask: Bool[Tensor, " ... queries keys"] | None = None,
) -> Float[Tensor, " ... queries d_v"]:
    d_k = Q.shape[-1]
    scaled_prod = einsum(Q,K,"... queries d_k, ... keys d_k -> ... queries keys") \
                    / math.pow(d_k, 1/2)
    if mask != None:
        scaled_prod.masked_fill_(~mask, -torch.inf)

    probs = softmax(scaled_prod,-1)
    res = einsum(probs, V, "... queries keys_also_values), ... keys_also_values d_v -> ... queries d_v")
    return res










scaled_dot_product_attention(Q,Q,Q)




tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])


[tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]),
 tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]),
 tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]])]

In [None]:
def f(x:int|None):
    x = x + 1
    return x * x
f

In [168]:
b: Bool[torch.Tensor,"l m n"] = torch.rand(2,3,4) > 0.5
print(b)
x: Float[torch.Tensor,"m n"] = torch.rand(3,4)
print(x)
einsum(b,x,"l m n, m n -> l m n")

tensor([[[False,  True,  True, False],
         [ True,  True,  True, False],
         [False,  True,  True,  True]],

        [[False, False,  True,  True],
         [False,  True, False, False],
         [ True,  True, False, False]]])
tensor([[0.5654, 0.1289, 0.4903, 0.6791],
        [0.2509, 0.6957, 0.2935, 0.1262],
        [0.1683, 0.0790, 0.4932, 0.5859]])


tensor([[[0.0000, 0.1289, 0.4903, 0.0000],
         [0.2509, 0.6957, 0.2935, 0.0000],
         [0.0000, 0.0790, 0.4932, 0.5859]],

        [[0.0000, 0.0000, 0.4903, 0.6791],
         [0.0000, 0.6957, 0.0000, 0.0000],
         [0.1683, 0.0790, 0.0000, 0.0000]]])

In [157]:
@jaxtyped(typechecker=typechecker)
def f(x:Float[Tensor,"a b c"],y:Float[Tensor,"d e"])->Float[Tensor,"as a fj"]:
    z:Float[Tensor,"as as"] = torch.randn(2,3)
    print(x.shape)
    h:Float[Tensor,"iof oifj fj"]= torch.randn(3,2,5)
    return h


x:Float[Tensor,"as as"] = torch.randn(2,2,3)
print(x.shape)
y:Float[Tensor,"b555 666"] = torch.randn(3,2)

print(x)

try :
    print(f(x,y))
except Exception as e:
    print(e)

torch.Size([2, 2, 3])
tensor([[[-1.5638,  0.0540,  0.3095],
         [ 2.4185,  0.7174,  1.3276]],

        [[ 0.4948,  0.7561,  0.4457],
         [ 0.4840,  1.2234,  1.1101]]])
torch.Size([2, 2, 3])
tensor([[[ 1.2251, -0.1811, -0.5047, -1.5832,  1.1824],
         [-0.1172,  0.7591, -1.1259, -0.1950,  0.1354]],

        [[ 0.7084, -0.1122,  0.3650,  0.0315,  0.6761],
         [-1.8480,  1.5890, -0.2449,  0.2881, -2.0938]],

        [[-0.1589,  0.5601, -0.8201,  1.4290,  0.1324],
         [-1.5293, -0.7480, -1.0539,  1.4015,  0.4791]]])


In [107]:
import torch
from beartype import beartype
from jaxtyping import jaxtyped, Float, install_import_hook

# 使用 beartype 作为后端来装饰我们的函数，以进行运行时检查
@jaxtyped(typechecker=beartype)
def process_data(
    x: Float[torch.Tensor, "batch channels"],
    y: Float[torch.Tensor, "batch"],
):
    """一个期望批次大小一致的函数。"""
    print(f"成功处理了一批数据，批次大小为: {x.shape[0]}")

# --- 1. 成功的例子 ---
# 批次大小 (batch=10) 在两个张量中是一致的
print("--- 尝试运行成功示例 ---")
try:
    features_ok = torch.randn(10, 3)  # shape: (10, 3) -> batch=10, channels=3
    labels_ok = torch.randn(10)       # shape: (10)    -> batch=10
    process_data(features_ok, labels_ok)
except Exception as e:
    print(f"出现错误: {e}")

print("\n" + "="*40 + "\n")

# --- 2. 失败的例子 ---
# 批次大小 (batch) 在两个张量中不一致 (10 vs 5)
print("--- 尝试运行失败示例 ---")
try:
    features_bad = torch.randn(10, 3) # shape: (10, 3) -> "batch" 被绑定为 10
    labels_bad = torch.randn(5)       # shape: (5)    -> "batch" 尝试绑定为 5
    # 下一行将抛出异常
    process_data(features_bad, labels_bad)
except Exception as e:
    print(f"成功捕获到预期的错误:\n{e}")

--- 尝试运行成功示例 ---
成功处理了一批数据，批次大小为: 10


--- 尝试运行失败示例 ---
成功捕获到预期的错误:
Type-check error whilst checking the parameters of __main__.process_data.
The problem arose whilst typechecking parameter 'y'.
Actual value: f32[5](torch)
Expected type: <class 'Float[Tensor, 'batch']'>.
----------------------
Called with parameters: {'x': f32[10,3](torch), 'y': f32[5](torch)}
Parameter annotations: (x: Float[Tensor, 'batch channels'], y: Float[Tensor, 'batch']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
batch=10
channels=3
