In [3]:
from __future__ import annotations

import os
from collections.abc import Iterable
from typing import IO, Any, BinaryIO

import numpy.typing as npt
import torch
import torch.nn as nn
from jaxtyping import Bool, Float, Int
from torch import Tensor

from einops import einsum, rearrange

In [16]:
class Linear(nn.Module):
    def __init__(self, in_features, out_features, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.factory_kwargs = factory_kwargs
        self.in_features = in_features
        self.out_features = out_features
        self.weights = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        nn.init.trunc_normal_(self.weights)
    
    def forward(self, input: Tensor) -> Tensor:
        #using torch.nn.init.trunc_normal_ to initialize weights
        output = einsum(input, self.weights, '... d_in, d_out d_in -> ... d_out')
        return output
    
linear = Linear(10, 5)
input_tensor = torch.randn(3, 10)
output_tensor = linear(input_tensor)
print(output_tensor)


tensor([[ 5.9765, -1.2592,  1.3626, -1.7853, -1.0226],
        [-0.2771,  4.2815,  1.6222,  3.2386, -2.3127],
        [ 1.6991, -1.0844, -1.6868, -2.4434,  0.2129]],
       grad_fn=<ViewBackward0>)


In [None]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.factory_kwargs = factory_kwargs
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.weights = nn.Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
        nn.init.trunc_normal_(self.weights)
    
    def forward(self, token_ids: Tensor) -> Tensor:
        # token_ids shape: (batch_size, seq_length)
        input = token_ids.long()
        output = self.weights[input]
        return output

In [7]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.factory_kwargs = factory_kwargs
        self.dim = dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones((dim,), **factory_kwargs))
    
    def forward(self, x: Tensor) -> Tensor:
        in_dtype = x.dtype
        x = x.to(torch.float32)
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x / (x.shape[-1] ** 0.5)
        x_normalized = x / (rms_x + self.eps)
        output = x_normalized * self.scale
        return output.to(in_dtype)
rmsnorm = RMSNorm(10)
input_tensor = torch.randn(3, 10)
output_tensor = rmsnorm(input_tensor)
print(output_tensor)

tensor([[ 1.5095,  0.8736,  0.8067,  0.9645, -1.0628,  1.5417, -0.3999, -0.0769,
         -0.6319, -1.1427],
        [ 1.1821,  0.8556,  0.6303, -0.8412, -1.9316,  0.4741, -0.3189,  0.7791,
         -1.4013, -0.3703],
        [-0.7361,  0.0817, -1.9319,  1.7296,  1.0280, -0.8135,  0.3896, -0.2692,
          0.8007, -0.3789]], grad_fn=<MulBackward0>)


In [None]:
class SiLU(nn.Module):
    def __init__(self, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.factory_kwargs = factory_kwargs
    
    def forward(self, x: Tensor) -> Tensor:
        return x * torch.sigmoid(x)
    
class SwiGLU(nn.Module):
    def __init__(self, d_ff: Int, d_model: Int, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.factory_kwargs = factory_kwargs
        self.linear1 = Linear(d_model, d_ff, 
                          device=self.factory_kwargs['device'], 
                          dtype=self.factory_kwargs['dtype'])
        self.linear2 = Linear(d_ff, d_model,
                            device=self.factory_kwargs['device'], 
                            dtype=self.factory_kwargs['dtype'])
        self.linear3 = Linear(d_model, d_ff,
                            device=self.factory_kwargs['device'], 
                            dtype=self.factory_kwargs['dtype'])
        self.activation = SiLU(device=self.factory_kwargs['device'], 
                              dtype=self.factory_kwargs['dtype'])
    
    def forward(self, x: Tensor) -> Tensor:
        x1 = self.linear1(x)
        x2 = self.linear3(x)
        activated = self.activation(x1)
        gated = activated * x2
        output = self.linear2(gated)
        return output

In [None]:
class RoPE(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
        super().__init__()
        self.factory_kwargs = {'device': device}
        self.max_seq_len = max_seq_len
        self.theta = theta
        self.d_k = d_k
        self.inv_freq = 1.0 / (theta ** (torch.arange(0, d_k, 2, **self.factory_kwargs) / d_k))
        self.register_buffer('cos_cached', torch.zeros((max_seq_len, d_k), **self.factory_kwargs))
        self.register_buffer('sin_cached', torch.zeros((max_seq_len, d_k), **self.factory_kwargs))

    def forward(self, x: Tensor, token_positions: Tensor) -> Tensor:
        seq_len = x.shape[-2]
        if seq_len > self.max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")

        t = token_positions[:, :seq_len]
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos_emb = emb.cos()[None, :, :]
        sin_emb = emb.sin()[None, :, :]
        x_rotated = (x * cos_emb) + (self.rotate_half(x) * sin_emb)
        return x_rotated


In [47]:
# 第一行第一列保持为True
x = torch.randn(2, 8, 8)
seq_len = x.size(1)
mask = torch.tril(torch.ones((seq_len, seq_len), device=x.device), diagonal=0).bool()

In [5]:
weights = torch.nn.Parameter(5*torch.randn(5, 5))
opt = torch.optim.SGD([weights], lr=1)
for _ in range(100):
    opt.zero_grad()
    loss = (weights**2).mean()
    print(loss.cpu().item())
    loss.backward()
    opt.step()

28.035511016845703
23.7292537689209
20.084440231323242
16.999469757080078
14.388352394104004
12.178300857543945
10.307714462280273
8.724449157714844
7.384374618530273
6.250134468078613
5.29011344909668
4.47755241394043
3.7898004055023193
3.20768666267395
2.714986562728882
2.297964334487915
1.9449970722198486
1.6462457180023193
1.3933823108673096
1.1793588399887085
0.9982093572616577
0.8448843955993652
0.7151100039482117
0.605269193649292
0.5122998356819153
0.43361055850982666
0.3670079708099365
0.31063559651374817
0.26292192935943604
0.22253715991973877
0.1883554458618164
0.1594240367412567
0.13493651151657104
0.11421027034521103
0.09666755795478821
0.08181943744421005
0.06925196945667267
0.05861486867070198
0.04961162060499191
0.041991278529167175
0.03554141893982887
0.030082259327173233
0.025461621582508087
0.02155071683228016
0.01824052631855011
0.01543878111988306
0.013067386113107204
0.01106023509055376
0.009361382573843002
0.007923474535346031
0.0067064277827739716
0.005676320288

In [4]:
import torch
outputs = torch.randn(2, 3, 4)# (batch_size, seq_length, feature_dim)
targets = torch.randn(2, 3) # (batch_size, seq_length)
print(outputs.view(-1, outputs.size(-1)))  # Reshape to (batch_size * seq_length, feature_dim)
print(targets.view(-1))  # Reshape to (batch_size * seq_length)


tensor([[ 1.4734, -0.8047,  0.9079,  1.0118],
        [ 1.7014, -0.2291,  0.8223, -0.1887],
        [-0.6490, -0.8112, -1.0593,  0.4451],
        [ 0.9344,  0.0696,  1.5256,  1.1624],
        [ 0.5979,  0.2832, -0.1796, -0.8682],
        [-0.7022, -0.2886,  0.1699,  1.2305]])

In [5]:
import numpy as np
data = np.load('/home/std10/extend/generated_data/tokenized_data_train.npy', allow_pickle=True)
