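### Table of Contents
- Chapter 2 : Preliminary Knowledge
- Chapter 3 : Linear Neural Network
- Chapter 4 : Classification
- Chapter 5 : Multilayer Perceptrons
- Chapter 6 : Beginner Guide
- Chapter 7 : CNN
- Chapter 8 : Modern CNN
- Chapter 9 : RNN
- Chapter 10 : Modern RNN
- Chapter 11 : Attention Mechanisms and Transformers
- Chapter 12 : Optimization
- Chapter 13 : Computation

- Engineering and Hardware Optimization Roadmap
- Gradient checking techniques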

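## Chapter 2 : Preliminary Knowledge
- Data manipulation
  - Broadcasting (the two tensors are expanded by copying elements until they have the same shape)
  - Saving memory (use `X[:] = <expression>` or `X += <expression>` to avoid allocating new memory)
- Data preprocessing
- Linear algebra
  - Transpose `.T`, norm `norm`
  - Sum without reducing dimensions (`keepdims=True`), cumulative sum `cumsum`
  - `torch.dot` only supports vectors; use `torch.mv` between a matrix and a vector, and `torch.mm` between matrices
- Calculus
  - Let $\nabla_x$ be the gradient with respect to $x$: $\nabla_x (Ax) = A^\top$, $\nabla_x (x^\top A) = A$, $\nabla_x (x^\top A x) = (A + A^\top)x$
- Automatic differentiation
  - By default PyTorch accumulates gradients, so we need to clear the previous values (e.g. `x.grad.zero_()`)
  - `backward()` needs a scalar output; for a non-scalar output, either reduce it to a scalar (e.g. `.sum()`) or pass a `gradient` argument with the shape of the output
  - Detaching computation (`detach()`)
- Probability
- Consulting documentation and API guides
  - `dir` lists the functions and classes that can be called in a module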
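The matrix-calculus identities above can be checked numerically, which is also the basic idea of gradient checking: compare an analytic gradient with a central finite-difference estimate. The following is a minimal pure-Python sketch (no PyTorch); the helper names and the non-symmetric matrix `A` and point `x` are arbitrary illustrative choices.

```python
def quad_form(A, x):
    """Compute x^T A x for a square matrix A (list of rows) and a vector x."""
    n = len(x)
    return sum(x[i] * A[i][j] * x[j] for i in range(n) for j in range(n))

def analytic_grad(A, x):
    """Gradient of x^T A x with respect to x, i.e. (A + A^T) x."""
    n = len(x)
    return [sum((A[i][j] + A[j][i]) * x[j] for j in range(n)) for i in range(n)]

def numeric_grad(f, x, h=1e-6):
    """Central finite-difference estimate of the gradient of f at x."""
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += h
        x_minus[i] -= h
        grad.append((f(x_plus) - f(x_minus)) / (2 * h))
    return grad

A = [[1.0, 2.0, 0.0],
     [-1.0, 3.0, 4.0],
     [0.5, 0.0, 2.0]]   # deliberately non-symmetric
x = [0.3, -1.2, 2.0]

g_exact = analytic_grad(A, x)
g_approx = numeric_grad(lambda v: quad_form(A, v), x)
print(g_exact)
print(max(abs(a - b) for a, b in zip(g_exact, g_approx)))  # tiny (rounding error only)
```

Because $x^\top A x$ is quadratic, the central difference is exact up to floating-point rounding, so a larger mismatch would point to a wrong formula (for example $2Ax$, which only holds for symmetric $A$).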

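## Chapter 3 : Linear Neural Network
- Minibatch stochastic gradient descent (minibatch SGD)
- General training procedure
  - Compute the prediction `y_hat = model.forward(X)`, compare it with the label `y` through the loss, backpropagate, and let the optimizer update the parameters using the gradients
- Machine Learning Concept
  - Lasso regression: $\ell_1$-norm penalty; ridge regression: $\ell_2$-norm penalty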
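The general training procedure (forward pass, loss, gradients, parameter update) can be sketched for minibatch SGD without any framework. This is a minimal pure-Python sketch assuming a single-feature linear model, squared loss and synthetic data with hypothetical true parameters $w = 2$, $b = 4.2$; it is not d2l's implementation.

```python
import random

def sgd_linear_regression(features, labels, lr=0.03, batch_size=10, num_epochs=3, seed=0):
    """Minibatch SGD for y ≈ w * x + b with squared loss (single feature for brevity)."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(features)))
    for _ in range(num_epochs):
        rng.shuffle(indices)  # visit the examples in a new random order each epoch
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # Forward pass: prediction error y_hat - y for each example in the minibatch
            errors = [(w * features[i] + b) - labels[i] for i in batch]
            # Gradients of the mean of 0.5 * (y_hat - y)^2 over the minibatch
            grad_w = sum(e * features[i] for e, i in zip(errors, batch)) / len(batch)
            grad_b = sum(errors) / len(batch)
            # Optimizer step: move the parameters against the gradient
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b

# Synthetic data: y = 2x + 4.2 plus small Gaussian noise (hypothetical values)
rng = random.Random(42)
xs = [rng.gauss(0, 1) for _ in range(1000)]
ys = [2.0 * x + 4.2 + rng.gauss(0, 0.01) for x in xs]
w, b = sgd_linear_regression(xs, ys)
print(round(w, 2), round(b, 2))
```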

## Chapter 4 : Classification
- softmax:
  $y_i = \frac{\exp(o_i)}{\sum_j \exp(o_j)}$; in practice we first subtract $\max_j o_j$ from every $o_i$, which leaves $y_i$ unchanged but avoids overflow (numerical stability).
- Information theory
  - cross-entropy loss: $l(y, \hat y) = - \sum_i y_i \log \hat y_i$
  - amount of information (surprisal): $\log{\frac{1}{P(j)}} = - \log{P(j)}$
  - entropy: $H[P] = \sum_j -P(j) \log{P(j)}$
  - cross-entropy: $H(P, Q) = \sum_j -P(j) \log{Q(j)}$; if $P = Q$, then $H(P, Q) = H(P, P) = H(P)$. In PyTorch, `F.cross_entropy` takes raw logits and applies the (log-)softmax for you.
- Image Classification Rules:
  - Images are stored in (channel, height, width) order.
- Distribution shift:
  - Covariate Shift (feature shift): $p(x) \neq q(x), p(y|x) = q(y|x)$
    - For example: $p(x)$ and $q(x)$ are features of rural and urban houses and $y$ is the price; we assume the relation between features and label is the same.
    - Method: with $q$ the training (source) distribution and $p$ the test (target) distribution, weight each example by $\beta(x) = p(x) / q(x)$: $\int\int l(f(x), y)p(y|x)p(x)dxdy = \int\int l(f(x), y)q(y|x)q(x) \frac{p(x)}{q(x)}dxdy \rightarrow \sum_i \beta_i l(f(x_i), y_i)$. $\beta$ can be obtained with logistic regression (a classifier that distinguishes training features from test features).
  - Label Shift: $p(y) \neq q(y), p(x|y) = q(x|y)$. The same method applies with $\beta(y) = p(y) / q(y)$, but now the target label distribution $p(y)$ is hard to get: compute a confusion matrix $C$ on the validation data, average the model's predictions $\mu(\hat y)$ on the test data, and solve $C\,p(y) = \mu(\hat y)$ to estimate it.
  - Concept Shift: the definition (concept) of the labels itself changes.
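A minimal pure-Python sketch of the max-subtraction trick for softmax and of cross-entropy computed directly from raw logits via log-sum-exp, which is conceptually what a logits-based loss such as `F.cross_entropy` does; the function names here are illustrative.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract max(o_j) before exponentiating."""
    m = max(logits)
    exps = [math.exp(o - m) for o in logits]  # largest term is exp(0) = 1, so no overflow
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, label):
    """-log softmax(logits)[label], computed via log-sum-exp for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(o - m) for o in logits))
    return log_sum_exp - logits[label]

logits = [1000.0, 1001.0, 1002.0]   # a naive math.exp(1002.0) would overflow
probs = softmax(logits)
print([round(p, 4) for p in probs])  # same as softmax([0.0, 1.0, 2.0])
print(round(cross_entropy(logits, 2), 4))
```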

## Chapter 5 : Multilayer Perceptrons
- Activation functions: ReLU, sigmoid, tanh ($\tanh(x) = \frac{1 - \exp(-2x)}{1 + \exp(-2x)}$)
- Numerical stability: vanishing and exploding gradients are common
  - Symmetry: the hidden units of a linear layer (and likewise of a conv layer without weight sharing) are interchangeable. For example, if two hidden units start with the same initial values, they receive the same gradients and keep updating the same way, so the network cannot tell them apart. We need to **break the symmetry**, e.g. with random initialization (dropout also breaks it).
  - Xavier initialization: sample from a zero-mean distribution with variance $\sigma^2 = 2 / (n_{in} + n_{out})$, i.e. standard deviation $\sigma = \sqrt{2 / (n_{in} + n_{out})}$
  - Dropout, shared parameters, ...
- (Rolnick et al., 2017) has revealed that in the setting of label noise, neural networks tend to fit cleanly labeled data **first** and only subsequently to interpolate the mislabeled data.
  - So we can use early stopping: stop training once the validation error reaches its minimum, or once it has not improved for a set number of epochs (the patience). It is usually combined with other regularization.
- Dropout:
  - $h' = \begin{cases} 0 & \text{with probability } p \\ \frac{h}{1-p} & \text{with probability } 1-p \end{cases}$, so that $E[h'] = h$
  - We do not use dropout at test time, unless we want to estimate the uncertainty of the model output (by comparing predictions under different dropout masks).
  - Use a lower $p$ in lower layers (to preserve low-level features) and a higher $p$ in higher layers.
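A minimal pure-Python sketch of the (inverted) dropout rule above; the function name and the use of plain lists instead of tensors are assumptions for illustration. At test time the layer would simply return `h` unchanged.

```python
import random

def dropout(h, p, rng=random):
    """Inverted dropout: zero each unit with probability p, scale survivors by 1/(1-p)."""
    assert 0.0 <= p <= 1.0
    if p == 1.0:
        return [0.0] * len(h)  # every unit is dropped
    return [0.0 if rng.random() < p else x / (1.0 - p) for x in h]

rng = random.Random(0)
h = [1.0] * 100_000
out = dropout(h, 0.3, rng)
print(sum(out) / len(out))  # close to 1.0, i.e. the expectation is preserved
```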

## Chapter 6: Beginner Guide
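- Tied (shared) layers: when the same layer is used in several places, the gradients from the different chains are added up.
- Custom initialization: the `apply` method, e.g. `net.apply(fn)` applies `fn` recursively to every submodule
- I/O
  - Save tensors: `torch.save(obj, name)`, where `obj` can be a tensor, a list of tensors, or a dict of tensors; load them back with `torch.load(name)`
  - Save a model: the same, but pass the parameter dict of the network (`net.state_dict()`), then restore it with `net.load_state_dict(torch.load(name))`
- GPU
  - Tensors involved in an operation must be on the same device (e.g. the same GPU)
  - Printing a tensor or converting it to NumPy copies it to main memory first; even worse, it then has to wait for the Python **GIL** (Global Interpreter Lock, which ensures that only one thread executes Python bytecode at a time)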
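A scalar illustration of gradient accumulation for tied parameters: a hypothetical toy function that uses the same weight twice, where the total gradient is the sum of the contributions from each use. In PyTorch, autograd performs this summation automatically when a layer is reused; here the arithmetic is done by hand and checked against a finite difference.

```python
def f(w, x):
    """A toy 'network' that applies the same (tied) weight w twice: y = w * (w * x)."""
    return w * (w * x)

w, x = 1.5, 2.0
# Contribution through the inner use of w (outer w held fixed): d/dw_in of w_out * (w_in * x) = w_out * x
grad_inner_use = w * x
# Contribution through the outer use of w (inner w held fixed): d/dw_out of w_out * (w_in * x) = w_in * x
grad_outer_use = w * x
tied_grad = grad_inner_use + grad_outer_use  # gradients from both uses are summed

h = 1e-6
numeric = (f(w + h, x) - f(w - h, x)) / (2 * h)
print(tied_grad, round(numeric, 6))  # both equal 2 * w * x = 6.0
```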

## Chapter 7 : CNN
1. **Invariance**: translation equivariance and locality. The earliest layers should respond similarly to the same patch, regardless of where it appears, and should focus on local regions.
2. **Convolution**: mathematically $(f * g)(i, j) = \sum_a \sum_b f(a, b)  g(i - a, j - b)$; note that **cross-correlation**, which is what conv layers actually compute, is $(f \star g)(i, j) = \sum_a \sum_b f(a, b)  g(i + a, j + b)$
   - The difference does not matter because the kernel is learned: the learned convolution kernel is just the learned cross-correlation kernel flipped horizontally and vertically, `k_conv_learned = torch.flip(k_corr_learned, dims=(0, 1))`, so `conv(X, k_conv_learned) = corr(X, k_corr_learned)`
3. **Receptive Field**: for any element x of a conv layer's output, all the elements from previous layers that may affect x during forward propagation.
4. **Padding, Stride**: the output shape is $\lfloor (n_h - k_h + p_h + s_h) / s_h \rfloor \times \lfloor (n_w - k_w + p_w + s_w) / s_w \rfloor$, where $p_h = p_h^{upper} + p_h^{lower}$ is the total padding (the same for $p_w$). Often $p_h = k_h - 1$ and $p_w = k_w - 1$, which keeps the size unchanged when the stride is 1.
5. **Channel**:
   - Multiple input channels $c_i$: the kernel must have the same number of channels ($c_i \times k_h \times k_w$); cross-correlate each channel and add the results up.
   - Multiple output channels $c_o$: a kernel of shape $c_o \times c_i \times k_h \times k_w$ gives $c_o$ output channels.
6. Use `torch.stack` to stack tensors along a new dimension (e.g. to stack per-output-channel kernels)
7. **Pooling**: mitigates the sensitivity of convolutional layers to location and spatially downsamples representations.
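A minimal pure-Python sketch of 2-D cross-correlation, the output-size formula from item 4, and a check that textbook convolution equals cross-correlation with a flipped kernel (item 2). Plain nested lists stand in for tensors, no padding or stride is applied inside `corr2d`, and all function names are illustrative.

```python
def corr2d(X, K):
    """2-D cross-correlation (what conv layers compute), no padding, stride 1."""
    kh, kw = len(K), len(K[0])
    out_h, out_w = len(X) - kh + 1, len(X[0]) - kw + 1
    return [[sum(X[i + a][j + b] * K[a][b] for a in range(kh) for b in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]

def conv2d_true(X, K):
    """Textbook convolution sum_a sum_b X[i - a, j - b] K[a, b], keeping only valid positions."""
    kh, kw = len(K), len(K[0])
    out_h, out_w = len(X) - kh + 1, len(X[0]) - kw + 1
    return [[sum(X[i + kh - 1 - a][j + kw - 1 - b] * K[a][b] for a in range(kh) for b in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]

def flip2d(K):
    """Flip a kernel horizontally and vertically."""
    return [row[::-1] for row in K[::-1]]

def conv_output_size(n, k, p, s):
    """floor((n - k + p + s) / s), where p is the TOTAL padding on both sides."""
    return (n - k + p + s) // s

X = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
K = [[0, 1], [2, 3]]
print(corr2d(X, K))                          # [[19, 25], [37, 43]]
print(conv2d_true(X, K) == corr2d(X, flip2d(K)))  # True
print(conv_output_size(8, 3, 2, 1))          # 8: padding k - 1 keeps the size
print(conv_output_size(8, 3, 2, 2))          # 4: stride 2 halves it
```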

## Chapter 8 : Modern CNN
1. **AlexNet**: the first successful deep CNN; uses dropout, ReLU and max-pooling
2. **VGG**: stacks multiple 3 * 3 conv layers (two stacked 3 * 3 convs see a 5 * 5 input region, like one 5 * 5 conv, but use fewer parameters: 2 * 3 * 3 = 18 < 25 = 5 * 5)
3. **NiN**: addresses 2 problems (1. the fully connected layers at the end need a lot of memory; 2. fully connected layers cannot be inserted between conv layers to increase nonlinearity, because they would destroy the spatial structure)
   - use 1 * 1 conv layers to add local nonlinearities across the channel activations
   - use global average pooling to integrate across all locations in the last representation layer (this only works in combination with the added nonlinearities)
4. **GoogLeNet**: Inception block, parallel conv branches at multiple scales whose outputs are concatenated
5. **Batch Normalization**:
   - $BN(\mathbf x) = \mathbf{\gamma} \odot \frac{\mathbf x - \hat{\mathbf{\mu}}_B}{\hat{\mathbf{\sigma}}_B} + \mathbf \beta$, where $\hat{\mathbf{\mu}}_B = \frac{1}{|B|}\sum_{\mathbf x \in B} \mathbf x$ and
     $\hat{\mathbf{\sigma}}^2_B = \frac{1}{|B|} \sum_{\mathbf x \in B} (\mathbf x - \hat{\mathbf{\mu}}_B)^2 + \epsilon$
   - On a fully connected layer with input [N, D], the statistics are computed over the batch dimension N separately for each of the D features (features are never mixed); on a conv layer with input [N, C, H, W], they are computed over N, H and W separately for each of the C channels (so the differences between channels are kept).
     - For example, for an input x of shape [N, C, H, W], take channel 0, `x[:, 0, :, :]`, compute its mean mu and std over all N * H * W values, then compute `(x[:, 0, :, :] - mu) / std`; here mu and std are scalars.
   - At test time we use the mean and variance estimated over the whole data instead of the minibatch mean and variance, just like dropout behaves differently at test time.
   - So BN also introduces noise! (minibatch statistics != true mean and variance) See Teye et al. (2018) and Luo et al. (2018).
   - So it works best for batch sizes of 50 ~ 100: larger batches inject too little noise, smaller ones too much.
   - Moving global mean and variance: at test time there is no minibatch, so we use global estimates stored during training.
     - They are exponentially weighted averages, in which the most recent batches have higher weight
     - $\mu_m = \mu_m (1 - \tau) + \hat\mu_B \tau, \sigma^2_m = \sigma^2_m (1 - \tau) + \hat\sigma^2_B \tau$, where $\tau$ is called the momentum term.
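6. **Layer Normalization**: often used in NLP
   - It normalizes within each example instead of across the batch, so for features of shape [N, A, B] (A and B are typically seq_len and hidden_size) the differences between the N examples are kept; it also behaves the same in training and testing.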
7. **ResNet**: residual block, which adds the input x as a shortcut branch before the final activation function (in the original paper; later work changed the order inside the block to BN -> activation -> conv)
   - If the shapes do not match, a 1 * 1 conv (with the appropriate stride) on the shortcut transforms x so that it can be added
   - **Idea**: nested function classes. A shallower net (like ResNet-20) is a subclass of a deeper net (like ResNet-50): if all layers after the 20th layer of ResNet-50 computed f(x) = x, it would be the same as ResNet-20. So the best function f' that ResNet-50 can reach on some data is guaranteed to be at least as good as f (the best of ResNet-20 on the same data).
   - <p align="center">
       <img alt="Residual Block" src="https://d2l.ai/_images/resnet-block.svg" style="background-color: white; display: inline-block;">
       Residual Block
   </p>
   - **ResNeXt**: uses g groups of 3 * 3 convs between two 1 * 1 convs (with $b$ and $c_o$ channels), which reduces the cost from $\mathcal O(c_i c_o)$ to $\mathcal O(g \cdot (c_i / g) \cdot (c_o / g)) = \mathcal O(c_ic_o/g)$
     - This is a **bottleneck** architecture if $b < c_i$
       </br>
   - <img alt="ResNeXt Block" src="https://d2l.ai/_images/resnext-block.svg" style="background-color: white; display: inline-block;">
       ResNeXt Block
8. **DenseNet**: instead of adding x, we concatenate x repeatedly.
   - For example (\<c\> after a tensor gives its number of channels): x\<c_1\> -> f_1(x)\<c_2\> gives [x, f_1(x)]\<c_1 + c_2\> -> f_2([x, f_1(x)])\<c_3\> gives [x, f_1(x), f_2([x, f_1(x)])]\<c_1 + c_2 + c_3\>
   - Too many of these layers make the channel dimension too big, so we need layers that reduce it: a **transition** layer uses a 1 * 1 conv to reduce the number of channels and average pooling (stride 2) to halve H and W.
9. **RegNet**:
   - AnyNet: network with **stem** -> **body** -> **head**.
   - Distribution of networks: the empirical CDF of errors $\hat F(e,\mathcal Z)=\frac{1}{n}\sum_{i=1}^{n}\mathbf 1(e_i<e)$ approximates $F(e, p)$, where $p$ is a distribution over network architectures (a design space) and $\mathcal Z$ is a sample of $n$ networks from $p$ with errors $e_i$. If $\hat F(e, \mathcal Z_1) \geq \hat F(e, \mathcal Z_2)$ for all $e$, the design space behind $\mathcal Z_1$ is better: more of its networks reach a low error.
   - So for RegNet, they find that sharing the bottleneck ratio k (k = 1, i.e. no bottleneck, is best according to the paper) and the group width g across all ResNeXt blocks does no harm, that the network depth d and width c should increase across stages, and that the width should grow linearly, $c_j = c_0 + c_a j$, with slope $c_a$.
   - Neural architecture search (NAS): given a search space, use RL (NASNet), evolutionary algorithms (AmoebaNet), gradient-based methods (DARTS) or weight sharing (ENAS) to find the model. But it takes too much computation.
   - <img src="https://d2l.ai/_images/anynet.svg" style="background-color: white; display: inline-block;"> AnyNet Structure (Search Space)
   </br>
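A minimal pure-Python sketch of batch normalization for a fully connected layer ([N, D] input), including the moving averages used at test time. It assumes the biased minibatch variance and treats `momentum` as the weight $\tau$ on the newest batch, as in the update rule above; framework implementations may differ in such details, and the function name is illustrative.

```python
import math

def batch_norm_fc(X, gamma, beta, moving_mean, moving_var, training, eps=1e-5, momentum=0.1):
    """Batch norm for a minibatch X of shape [N][D] (a list of rows).

    Training: normalize with per-feature minibatch statistics and update the
    moving averages in place. Evaluation: normalize with the stored moving statistics.
    """
    n, d = len(X), len(X[0])
    if training:
        mean = [sum(row[j] for row in X) / n for j in range(d)]
        var = [sum((row[j] - mean[j]) ** 2 for row in X) / n for j in range(d)]
        for j in range(d):
            # Exponentially weighted averages: recent batches get more weight
            moving_mean[j] = (1 - momentum) * moving_mean[j] + momentum * mean[j]
            moving_var[j] = (1 - momentum) * moving_var[j] + momentum * var[j]
    else:
        mean, var = moving_mean, moving_var
    return [[gamma[j] * (row[j] - mean[j]) / math.sqrt(var[j] + eps) + beta[j] for j in range(d)]
            for row in X]

X = [[1.0, 2.0], [3.0, 4.0]]
moving_mean, moving_var = [0.0, 0.0], [1.0, 1.0]
Y = batch_norm_fc(X, [1.0, 1.0], [0.0, 0.0], moving_mean, moving_var, training=True)
print([[round(v, 3) for v in row] for row in Y])   # [[-1.0, -1.0], [1.0, 1.0]]
print(moving_mean, moving_var)                     # moved 10% toward the batch statistics
```

Evaluation mode works even for a single example, because it no longer depends on minibatch statistics.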

## Chapter 9 : RNN
- Two forms of sequence-to-sequence tasks:
  - **aligned**: the input at each time step aligns with the corresponding output, like tagging (fight -> verb)
  - **unaligned**: no step-to-step correspondence, like machine translation
- **Autoregressive** model: regress a value based on its previous values
  - latent autoregressive models (so called because $h_t$ is never observed): estimate $P(x_t | x_{t-1} \dots x_1)$ with $\hat x_t = P(x_t | h_t)$, where the summary $h_t = g(h_{t-1}, x_{t-1})$ is updated at each step
- **Sequence Model**: to get the joint probability of a sequence $p(x_1, \dots, x_T)$, we rewrite it in autoregressive form: $p(x_1) \prod_{t=2}^T p(x_t|x_{t-1}, \dots, x_1)$
  - **Markov Condition**: if we can replace the condition above by $x_{t-1}, \dots, x_{t-\tau}$ without any loss, i.e. the future is conditionally independent of the past given the recent history, then the sequence satisfies a Markov condition, and the model is a $\tau^{th}$-order Markov model.
- Zipf's law: word frequency decays as a power law of its rank, $n_i \propto 1/i^\alpha$ (a straight line on a log-log plot); n-grams follow it too (with a smaller exponent).
  - So estimating probabilities from word counts does not work well. For example, in $\hat p(learning|deep) = n(deep, learning) / n(deep)$, many plausible word pairs occur rarely or never, so $n(deep, learning)$ is tiny and unreliable compared to the denominator. **Laplace smoothing** (adding a small constant to all counts) helps, but not much.
- **Perplexity** (how "confused" the model is): on true test data, the cross-entropy is $J = \frac{1}{n} \sum_{t=1}^n -\log P(x_t | x_{t-1}, \dots, x_1)$, and the perplexity is $\exp(J)$.
- Partitioning the sequence: for a sequence of $T$ token indices, add some randomness by discarding the first $d$ tokens, with $d \in [0, n)$ sampled uniformly, and partition the rest into $m = \lfloor (T-d) / n \rfloor$ subsequences of length $n$. For an input subsequence $x_t$, the target subsequence is shifted by one token, $x_{t+1}$.
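A minimal pure-Python sketch of perplexity and of the random-offset partitioning described above; the function names and the toy corpus `range(10)` are illustrative. Because the target is shifted by one token, the last subsequence needs one extra token, so this version can yield one pair fewer than $\lfloor (T-d)/n \rfloor$.

```python
import math
import random

def perplexity(token_probs):
    """exp of the average negative log-likelihood of the true next tokens."""
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)

def random_partition(corpus, num_steps, rng=random):
    """Drop a random offset d in [0, num_steps), then cut the rest into (X, Y) pairs
    of length num_steps, where Y is X shifted by one token."""
    d = rng.randint(0, num_steps - 1)
    pairs = []
    # Stop early enough that the shifted target Y still fits inside the corpus
    for start in range(d, len(corpus) - num_steps, num_steps):
        X = corpus[start:start + num_steps]
        Y = corpus[start + 1:start + num_steps + 1]
        pairs.append((X, Y))
    return pairs

print(perplexity([0.25] * 8))   # about 4: like guessing uniformly among 4 tokens
for X, Y in random_partition(list(range(10)), num_steps=3, rng=random.Random(0)):
    print(X, Y)
```

A perfect model (probability 1 for every true token) has perplexity 1, while a uniform guess over a vocabulary of size $|\mathcal V|$ has perplexity $|\mathcal V|$.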
---
- **RNN**: for a vocabulary of size $|\mathcal V|$, an n-gram model would need up to $|\mathcal V|^n$ numbers, where $n$ is the context length. So we approximate $P(x_t | x_{t-1} \dots x_1) \approx P(x_t | h_{t-1})$, where $h$ is a **hidden state**: it varies across time steps and contains information from previous time steps. A hidden layer, on the other hand, is a structure; it does not change during the forward computation.
  - recurrence: $H_t = \phi (X_tW_{xh} + H_{t-1}W_{hh} + b_h)$, and the output is $O_t = H_tW_{hq} + b_q$.
  - <img alt="RNN" src="https://d2l.ai/_images/rnn.svg" style="background-color: white; display: inline-block;">
       RNN Block
  - Gradient clipping: $g \leftarrow \min(1, \frac{\theta}{\| g \|}) g$; it is a hack, but useful.
  - **Warm-up**: when predicting, we can first feed a prefix (similar to what is now called a prompt) through the network, only updating the hidden state without generating outputs, until we need to start predicting.
- For `nn.RNN` (with the default `batch_first=False`), the input shape is (sequence_length, batch_size, feature_size): the first dimension is the time step, the third is the one-hot or word2vec dimension.
- **Backpropagation through time**
  - <img alt="RNN BPTT" src="https://d2l.ai/_images/rnn-bptt.svg" style="background-color: white; display: inline-block;"> Computation graph of RNN
  - To reduce gradient explosion or vanishing (and computation), truncate the gradient propagation after a certain number of time steps (truncated BPTT).
  - In the figure above, with $h_t = W_{hx} x_t + W_{hh} h_{t-1}$ and $o_t = W_{qh} h_t$ (no activation, for simplicity): $\frac{\partial L}{\partial h_T} = W_{qh}^{\intercal} \frac{\partial L}{\partial o_T}$, $\frac{\partial L}{\partial h_t} = \sum_{i=t}^T (W_{hh}^{\intercal})^{T-i} W_{qh}^{\intercal} \frac{\partial L}{\partial o_{T+t-i}}$, $\frac{\partial L}{\partial W_{hx}} = \sum_{t=1}^T \frac{\partial L}{\partial h_t} x_t^{\intercal}$, $\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \frac{\partial L}{\partial h_t} h_{t-1}^{\intercal}$.
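Gradient clipping from the RNN section, sketched in pure Python with all parameter gradients flattened into one list (an assumption for brevity; the function name is illustrative). In PyTorch, `torch.nn.utils.clip_grad_norm_(parameters, max_norm)` applies the same global-norm rule to a model's parameters.

```python
import math

def clip_gradients(grads, theta):
    """Rescale all gradients so that their global L2 norm is at most theta:
    g <- min(1, theta / ||g||) * g, with ||g|| taken over every parameter."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm == 0.0:
        return list(grads)
    scale = min(1.0, theta / norm)
    return [scale * g for g in grads]

print(clip_gradients([3.0, 4.0], theta=1.0))   # norm 5 is rescaled to norm 1
print(clip_gradients([3.0, 4.0], theta=10.0))  # already small enough: unchanged
```

Scaling by a single factor keeps the gradient direction intact and only limits its length, which is why this counts as a hack rather than a fix for the underlying instability.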

## Chapter 10 : Modern RNN
- **LSTM**
  - The structure is :
  - <img alt="LSTM Arch" src="https://zh.d2l.ai/_images/lstm-3.svg" style="background-color: white; display: inline-block;"> LSTM Arch
- **GRU**
  - <img alt="GRU Arch" src="https://d2l.ai/_images/gru-3.svg" style="background-color: white; display: inline-block;"> GRU Arch
  - Reset gates help capture short-term dependencies in sequences.
  - Update gates help capture long-term dependencies in sequences.
- **Deep RNN**
  - <img alt="Deep RNN" src="https://d2l.ai/_images/deep-rnn.svg" style="background-color: white; display: inline-block;"> Deep RNN
  - In a deep RNN, the output is the hidden state of the last layer at every time step, and the state is the hidden state of every layer at the last time step.
- **Bidirectional RNN**: it is slow, and the gradient chains are long
  - $P(x_1,\ldots,x_T,h_1,\ldots,h_T)=\prod_{t=1}^TP(h_t\mid h_{t-1})P(x_t\mid h_t),\mathrm{~where~}P(h_1\mid h_0)=P(h_1)$; this is a hidden Markov model. Dynamic programming can compute it both from start to end and from end to start, which is what a bidirectional RNN mirrors with one forward and one backward RNN.
  - <img alt="Bidirection RNN" src="https://zh.d2l.ai/_images/birnn.svg" style="background-color: white; display: inline-block;"> B-RNN
  - The output simply concatenates the hidden states of the two directions.
- **Machine translation**
  - Non-breaking spaces (which keep text like "Mr. Smith" from being split across lines) are replaced with ordinary spaces during preprocessing.
  - Teacher forcing: all sequences are padded with \<pad\>; source tokens need no special treatment; the decoder input (the target sequence used as input) starts with \<bos\>, and the label is the target shifted by 1 (no \<bos\> at the beginning).
  - **Important**: with teacher forcing, the ground-truth target is fed to the decoder. This makes training faster and more stable, but it makes training and prediction differ (when predicting we do not have the true target, so we have to feed back our own predictions step by step). We can make them the same, but training becomes harder.
- **Sequence to Sequence**
  - We use this encoder-decoder architecture to handle variable-length input and variable-length output.
  - We do not use one-hot vectors; instead we use an `nn.Embedding` layer, which takes token index i and returns the i-th row of its weight matrix.
  - From the encoder we get the hidden states and compute a context variable $c = q(h_1, \cdots, h_T)$, for example simply $c = h_T$. In the decoder, we concatenate it with the embedded target tokens and feed the result to the RNN.
  - When calculating the loss, we should not take \<pad\> into account, so we mask the loss at padding positions.
  - <img alt="Encoder Decoder" src="https://d2l.ai/_images/seq2seq-details.svg" style="background-color: white; display: inline-block;"> Encoder Decoder
  - BLEU (Bilingual Evaluation Understudy) evaluates whether the n-grams in the predicted sequence appear in the target sequence. For example, with target sequence ABCDEF and predicted sequence ABBCD, $p_1 = 4/5$ (A, B, C, D match; the second B does not, since B appears only once in the target) and $p_2 = 3 / 4$ (AB, BC, CD match, BB does not). BLEU is $\exp\left(\min\left(0,1-\frac{\mathrm{len}_{\mathrm{label}}}{\mathrm{len}_{\mathrm{pred}}}\right)\right)\prod_{n=1}^kp_n^{1/2^n}$: longer n-grams get higher weight, and predictions shorter than the label are penalized by the exponential factor.
- **Beam Search**
  - Before this section, we used greedy search: take the argmax of the prediction vector, $y_{t^{\prime}}=\underset{y\in\mathcal{Y}}{\operatorname{argmax}}\, P(y\mid y_1,\ldots,y_{t^{\prime}-1},\mathbf{c})$, where $\mathcal Y$ is the vocabulary. Once the model outputs \<eos\> (or we reach the maximum length $T'$), the output sequence is complete.
  - However, choosing the most likely token at each step does not give the most likely sequence, i.e. the one maximizing $\prod_{t^{\prime}=1}^{T^{\prime}}P(y_{t^{\prime}}\mid y_1,\ldots,y_{t^{\prime}-1},\mathbf{c})$. For example, in the figures below, ACB has probability 0.5 * 0.3 * 0.6 = 0.09, while greedy search chooses ABC with 0.5 * 0.4 * 0.4 = 0.08, which is lower, so greedy search is not optimal!
  - <img alt="Max sequence" src="https://d2l.ai/_images/s2s-prob2.svg" style="background-color: white; display: inline-block;">Max sequence <img alt="Max token" src="https://d2l.ai/_images/s2s-prob1.svg" style="background-color: white; display: inline-block;"> Max token.
  - If we want the optimal sequence, we need exhaustive search over all possible sequences, which costs $\mathcal O(|\mathcal Y|^{T'})$ and is infeasible.
  - The most straightforward beam search keeps the k best candidates at each step. With k = 2, at time step 2 we compute $P ( A, y_{2} \mid\mathbf{c} )=P ( A \mid\mathbf{c} ) P ( y_{2} \mid A, \mathbf{c} )$ for the top branch and $P ( C, y_{2} \mid\mathbf{c} )=P ( C \mid\mathbf{c} ) P ( y_{2} \mid C, \mathbf{c} )$ for the bottom branch, then keep the 2 best of them, and so on. Finally we choose the candidate that maximizes $\frac{1} {L^{\alpha}} \log P ( y_{1}, \ldots, y_{L} \mid\mathbf{c} )=\frac{1} {L^{\alpha}} \sum_{t^{\prime}=1}^{L} \log P ( y_{t^{\prime}} \mid y_{1}, \ldots, y_{t^{\prime}-1}, \mathbf{c} )$, where $L$ is the candidate length and $\alpha$ (typically 0.75) compensates for longer sequences having more log terms. Note that we end up with **6** candidates (A, C, AB, CE, ABD, CED in the figure below).
  - <img alt="Max sequence" src="https://d2l.ai/_images/beam-search.svg" style="background-color: white; display: inline-block;">
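A minimal pure-Python sketch of BLEU as defined above, with clipped n-gram matching so that each label n-gram is matched at most once; tokens are whitespace-separated strings and the function names are illustrative. It reproduces the $p_1 = 4/5$, $p_2 = 3/4$ example.

```python
import collections
import math

def ngram_precision(pred_tokens, label_tokens, n):
    """p_n: fraction of predicted n-grams that also occur in the label (clipped counts)."""
    label_counts = collections.Counter(
        tuple(label_tokens[i:i + n]) for i in range(len(label_tokens) - n + 1))
    num_pred = len(pred_tokens) - n + 1
    if num_pred <= 0:
        return 0.0
    matches = 0
    for i in range(num_pred):
        gram = tuple(pred_tokens[i:i + n])
        if label_counts[gram] > 0:  # each label n-gram can be matched only once
            matches += 1
            label_counts[gram] -= 1
    return matches / num_pred

def bleu(pred_seq, label_seq, k):
    """BLEU with brevity penalty exp(min(0, 1 - len_label / len_pred)) and weights 1/2^n."""
    pred, label = pred_seq.split(), label_seq.split()
    score = math.exp(min(0.0, 1 - len(label) / len(pred)))
    for n in range(1, k + 1):
        score *= ngram_precision(pred, label, n) ** (0.5 ** n)
    return score

pred, label = "A B B C D".split(), "A B C D E F".split()
print(ngram_precision(pred, label, 1), ngram_precision(pred, label, 2))  # 0.8 0.75
print(bleu("A B B C D", "A B C D E F", k=2))
```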

## Chapter 11 : Attention Mechanisms and Transformers
- **Idea**: attention first came up in encoder-decoder designs. Rather than transforming the input into a fixed-size feature and feeding it to every decoder step, we create a representation with the same length as the input, and at each time step the decoder can pay attention to different parts of the input sequence (with different weights). The Transformer drops recurrence altogether and relies on attention throughout (together with residual connections and layer normalization).
- **Q, K, V**
  - Define a database $\mathcal{D} \stackrel{\mathrm{def}}{=} \{( \mathbf{k}_{1}, \mathbf{v}_{1} ), \ldots, ( \mathbf{k}_{m}, \mathbf{v}_{m} ) \}$. Given a query $\mathbf q$, the attention is $\text{Attention}(\mathbf q, \mathcal D) \stackrel{\mathrm{def}}{=} \sum_{i=1}^m \alpha(\mathbf q, \mathbf k_i)\mathbf v_i$, where $\alpha(\mathbf q, \mathbf k_i)$ are scalar attention weights; this operation is also called attention pooling. We want the attention weights to be nonnegative and sum up to 1, so we can use a softmax to transform them.
- **Attention pooling with similarity**
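  - In a regression task, we can use a kernel (e.g. Gaussian) between the query and each training input as the attention weight, after normalization (Nadaraya-Watson kernel regression).
  - If we learn the kernel width $\sigma$ by directly minimizing $(f(x_i) - y_i)^2$, where $f(x_i)$ is computed from a training set that contains $(x_i, y_i)$ itself, $\sigma$ goes to zero (each point only attends to itself), causing overfitting. Even if we remove $(x_i, y_i)$ from the estimate of $f(x_i)$, we may still overfit when the dataset is large enough.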
  - <img alt="Attention Pooling" src="https://d2l.ai/_images/attention-output.svg" style="background-color: white; display: inline-block;"> Attention Pooling
- **Attention Scoring Function**
  - **Dot Product Attention**: the exponent of the Gaussian kernel is $a(q, k_i) = -\frac{1}{2}\|q - k_i\|^2 = q^{\intercal}k_i - 0.5 \|q\|^2 - 0.5 \|k_i\|^2$. After softmax normalization, the second term cancels out because it is the same for all $i$. If $k_i$ comes from batch or layer normalization, its norm is bounded and often roughly constant, so we can drop the last term without penalty. This leads to $a(q, k_i) = q^{\intercal}k_i$. Assuming all elements of $q$ and $k_i$ are independent with zero mean and unit variance, this score has variance $d$ (the query feature size), so we further normalize it: $a(q, k_i) = q^{\intercal}k_i / \sqrt{d}$, the common choice in the Transformer. Finally, we apply a softmax to the scores.
  - Because we do not want to consider \<pad\>, we use a masked softmax. For batch matrix multiplication, use `torch.bmm`: `torch.bmm(Q, K)` takes Q of shape (n, a, b) and K of shape (n, b, c) and returns a tensor of shape (n, a, c) whose i-th slice is Q_i @ K_i.
  - **Additive Attention**: when q and k have different feature sizes, we can either transform one of them ($q^{\intercal}Mk$) or use additive attention $a(\mathbf{q},\mathbf{k})=\mathbf{w}_v^\top\tanh(\mathbf{W}_q\mathbf{q}+\mathbf{W}_k\mathbf{k})\in\mathbb{R}$, where the sum inside the parentheses is computed by broadcasting.
- **Bahdanau Attention**
  - The attention function works between the decoder hidden state and the encoder hidden states: $\mathbf c_{t'} = \sum_{t=1}^T \alpha(\mathbf s_{t'-1}, \mathbf h_t)\mathbf h_t$, and this context is used to generate $\mathbf s_{t'}$.
  - <img alt="Bahdanau Attention" src="https://d2l.ai/_images/seq2seq-details-attention.svg" style="background-color: white; display: inline-block;"> Bahdanau Attention
  - `Seq2SeqAttentionDecoder` first uses the encoder's final last-layer hidden state as the query, and afterwards uses its own hidden state from its RNN. Keys and values are all the encoder outputs (the last-layer hidden states at all time steps). It then concatenates the embedded input X with this context (the attention result) and feeds the result to its RNN.
- **Multi-head Attention**
  - <img alt="Multi-head Attention" src="https://d2l.ai/_images/multi-head-attention.svg" style="background-color: white; display: inline-block;"> Multi-head Attention
  - We want the same attention mechanism to capture different behaviours from the same Q, K, V, so we use $h$ heads: each head first passes Q, K, V through its own fully connected layers with learnable parameters, then applies attention, and the $h$ results are concatenated. $\mathbf h_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k, \mathbf W_i^{(v)}\mathbf v)$, where $f$ is attention pooling and the $\mathbf W$s have shapes $(p_q, d_q), (p_k, d_k)$ and $(p_v, d_v)$, so $\mathbf h_i$ has shape $(p_v,)$. We concatenate the heads into a vector of shape $(h \times p_v,)$ and multiply it by a learnable matrix of shape $(p_o, h \times p_v)$, which finally returns an output of shape $(p_o, )$. For parallel computation, we set $hp_q = hp_k = hp_v = p_o$.
  - The d2l implementation follows the same idea but uses a trick for parallelism: with num_hiddens = h * p_q, it moves the head dimension into the batch dimension (batch_size * num_heads) and reverses this on the output.
- **Self Attention**
  - Self-attention uses the same X as queries, keys and values, attention(X, X, X), e.g. in the encoder. Comparing CNN, RNN and self-attention on a sequence of n tokens with d dimensions: CNN (kernel size k): computation O(knd^2), longest path O(n/k). RNN: computation O(nd^2), path O(n). Self-attention: computation O(n^2d), path O(1).
  - Positional encoding: self-attention contains no position (order) information, so the same token at time step 1 and at time step 5 is treated identically (but it should not be!). So we need to add position information. The original Transformer uses a fixed positional encoding based on sine and cosine: for the input representation $X \in \mathbb R^{n \times d}$ of n tokens, the output is $X + P$, where $p_{i,2j} = \sin \left( \frac{i}{10000^{2j/d}} \right)$ and $p_{i,2j+1} = \cos \left( \frac{i}{10000^{2j/d}} \right)$. This works because different columns oscillate at different frequencies, similar to the bits of a binary representation.
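  - Relative positional information: with $\omega_j = 1/10000^{2j/d}$, for any fixed offset $\delta$ the encoding of position $i + \delta$ is a linear projection of the encoding of position $i$:
    $$\begin{bmatrix}
    \cos(\delta\omega_j) & \sin(\delta\omega_j) \\
    -\sin(\delta\omega_j) & \cos(\delta\omega_j)
    \end{bmatrix}
    \begin{bmatrix} p_{i,2j} \\ p_{i,2j+1} \end{bmatrix} =
    \begin{bmatrix} p_{i+\delta,2j} \\ p_{i+\delta,2j+1} \end{bmatrix}$$
  - Alternatively, we can just add a learnable parameter of shape (1, num_steps, num_hiddens) to embed(X) and learn the positions.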
- **Transformer**
  - <img alt="Transformer" src="https://d2l.ai/_images/transformer.svg" style="background-color: white; display: inline-block;"> Transformer Arch
  - The encoder-decoder attention layer takes the output of the decoder self-attention layer as queries, and the encoder output as keys and values.
  - In decoder self-attention, we mask carefully to preserve the autoregressive property: a query at a given position only attends to positions up to and including it, never to later positions of the output sequence.
  - Before adding the positional encoding, we multiply embed(X) by sqrt(d) to rescale it, possibly because embed(X) has a small variance.
  - For prediction, the decoder runs one step at a time, so each decoder block caches its inputs from previous steps (used as keys and values); in training, all time steps can be computed together. In the d2l implementation, this cache is stored in state[2].
  - **!!** Only the output of the last encoder block is used (as keys and values) in the encoder-decoder attention of every decoder block.
- **Vision Transformer**
  - <img alt="Vision Transformer" src="https://d2l.ai/_images/vit.svg" style="background-color: white; display: inline-block;"> Vision Transformer
  - Patch embedding: the image is fed to a conv layer whose kernel size and stride both equal the patch size, and the result is flattened to shape (batch, num_patches, num_hiddens).
  - Applying normalization before attention (pre-normalization) leads to more efficient and effective training of Transformers. The ViT MLP uses GELU, and dropout is applied to the output of each fully connected layer in the MLP for regularization.
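A minimal pure-Python sketch of scaled dot-product attention with a masked softmax, as described in the attention scoring section. For simplicity it assumes a single example with one `valid_len` shared by all queries (a batched version would take a valid length per example); function names are illustrative. Self-attention is the special case `queries = keys = values = X`.

```python
import math

def masked_softmax(scores, valid_len):
    """Softmax over the first valid_len scores; the rest (e.g. <pad>) get weight 0."""
    kept = scores[:valid_len]
    m = max(kept)
    exps = [math.exp(s - m) for s in kept]
    total = sum(exps)
    return [e / total for e in exps] + [0.0] * (len(scores) - valid_len)

def scaled_dot_product_attention(queries, keys, values, valid_len=None):
    """For each query q: softmax(q . k_i / sqrt(d)) over the keys, then a weighted sum of values."""
    d = len(queries[0])
    if valid_len is None:
        valid_len = len(keys)
    outputs = []
    for q in queries:
        scores = [sum(qj * kj for qj, kj in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = masked_softmax(scores, valid_len)
        outputs.append([sum(w * v[c] for w, v in zip(weights, values))
                        for c in range(len(values[0]))])
    return outputs

Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]   # the third key is padding
V = [[10.0], [20.0], [999.0]]
print(scaled_dot_product_attention(Q, K, V, valid_len=2))  # the padded value 999 gets zero weight
```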

- Large Scale Pre-training
  - Encoder-only: ViT, BERT. BERT uses masked language modeling: each (masked) token can attend to the tokens on both its left and its right, so it is a bidirectional encoder. In the figure below, each token along the vertical axis attends to all input tokens along the horizontal axis.
  - <img alt="BERT" src="https://d2l.ai/_images/bert-encoder-only.svg" style="background-color: white; display: inline-block;"> BERT
  - Encoder-Decoder, BART & T5, both attempt to reconstruct original text in their pretraining objectives, while the former emphasizes noising input (e.g., masking, deletion, permutation, and rotation) and the latter highlights multitask unification with comprehensive ablation studies.
  - <img alt="T5" src="https://d2l.ai/_images/t5-encoder-decoder.svg" style="background-color: white; display: inline-block;"> T5
  - Decoder-only: GPT. **In-context learning**: the model is conditioned on an input sequence containing the task description, task-specific input-output examples, and a prompt (the task input).
  - <img alt="GPT" src="https://d2l.ai/_images/gpt-decoder-only.svg" style="background-color: white; display: inline-block;"> GPT
  - <img alt="x - shot" src="https://d2l.ai/_images/gpt-3-xshot.svg" style="background-color: white; display: inline-block;"> x-shot
- Efficient Transformer design (see that survey)
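  - Sparse attention: Longformer uses sliding-window and global attention, reducing the complexity to O(n). BigBird combines random, window and global attention and is suited to long sequences.
  - Low-rank approximation: Linformer projects the attention matrix to a lower dimension through a low-rank factorization, reducing the complexity from O(n²) to O(n).
  - Memory: Transformer-XL introduces a recurrent memory, reusing previous hidden states when processing long sequences to avoid recomputation.
  - Efficient attention: Performer approximates dot-product attention with a kernel method (FAVOR+), reducing the complexity to O(n).
  - Model compression: distillation compresses a large Transformer into a smaller model (e.g. DistilBERT); quantization lowers the parameter precision to reduce memory usage.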

## Chapter 12 : Optimization
- **Convexity**
  - Convex set : for any $a, b \in \mathcal X$, given $\lambda \in [0, 1]$, $\lambda a + (1 - \lambda) b \in \mathcal X$.
  - Convex function: given a convex set $\mathcal X$, a function $f : \mathcal X \rightarrow \mathbb R$ is convex if for all $x, x' \in \mathcal X$ and $\lambda \in [0, 1]$ we have $\lambda f(x) + (1 - \lambda) f(x') \geq f(\lambda x + (1-\lambda)x')$.
  - Jensen's inequality: $\sum_i\alpha_if(x_i)\geq f\left(\sum_i\alpha_ix_i\right)$ for nonnegative $\alpha_i$ with $\sum_i \alpha_i = 1$, and $E_X[f(X)]\geq f\left(E_X[X]\right)$
  - Properties: local minima are global minima; below sets $\mathcal{S}_b\overset{\mathrm{def}}{=}\{x|x\in\mathcal{X}\mathrm{~and~}f(x)\leq b\}$ are convex; a twice-differentiable f is convex if and only if its Hessian is positive semidefinite ($\nabla^2 f = H, x^\top Hx \geq 0$ for all $x$).
  - Convex optimization with constraints $c_i(x) \leq 0$ can be handled with the Lagrangian $L(x, \lambda_1, \ldots, \lambda_n) = f(x) + \sum_i \lambda_i c_i(x)$ and the KKT conditions:
    - Stationarity: $\nabla_x L(x, \lambda_1, \ldots, \lambda_n) = 0$
    - Primal feasibility: $c_i(x) \leq 0$
    - Dual feasibility: $\lambda_i \geq 0$
    - Complementary slackness: $\lambda_i c_i(x) = 0$
  - Penalties are more robust than exact constraints. We can also use projections to satisfy constraints.
- **Gradient Descent**
  - $x \leftarrow x - \eta \nabla_x f(x)$; Newton's method instead uses the inverse Hessian, $x \leftarrow x - H^{-1}\nabla_x f(x)$ with $H = \nabla_x^2 f(x)$
  - Computing H is expensive, so we can precondition with its diagonal only, $x \leftarrow x - \eta \text{diag}(H)^{-1}\nabla_x f(x)$, which means using a different learning rate for each coordinate $x_i$.
  - Line search : use binary search to find $\eta$ that minimize $f(x - \eta \nabla_x f(x))$.
- **SGD**: for convex problems it converges at rate $\mathcal O (1/\sqrt T)$, where $T$ is the number of iterations (samples drawn). See the book for the details of the math.
- **Momentum**
  - Use the leaky average $v_t = \beta v_{t-1} + g_{t, t-1}$ instead of the gradient; this is momentum!
  - Gradient descent with and without momentum for a convex quadratic function decomposes into coordinate-wise optimization in the direction of the eigenvectors of the quadratic matrix.
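  - The velocity converges for a wider range of learning rates than plain gradient descent (on a quadratic with eigenvalue $\lambda$: $0 < \eta\lambda < 2 + 2\beta$ instead of $0 < \eta\lambda < 2$), so adding momentum (with a large $\beta$) is better in theory.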
- **Adagrad**: a variant of SGD with per-coordinate learning rates
  - Some features are rare, so their parameters receive few updates; we want to update them faster (with a larger learning rate).
  - Some problems have a large condition number $\kappa = \lambda_{max} / \lambda_{min}$, which slows down optimization. For a quadratic $f(x) = \frac{1}{2}x^\top Q x + c^\top x + b$ with positive semidefinite $Q$, we could rescale by the eigendecomposition of $Q$, or, more cheaply, by its diagonal only: $\tilde Q = \text{diag}(Q)^{-1/2}Q\,\text{diag}(Q)^{-1/2}$. However, this is not realistic in deep learning, because we don't have the second derivatives, so Adagrad uses the magnitude of the gradient as a proxy for the diagonal. This makes its adjustment element-wise (only the diagonal is rescaled).
  - $s_t = s_{t-1} + g_t^2, w_t = w_{t-1} - \frac{\eta}{\sqrt{s_t+\epsilon}} \odot g_t$; one problem of Adagrad is that its effective learning rate decreases as $\mathcal O(t^{-1/2})$.
- **RMSProp**
  - $s_t = \gamma s_{t-1} + (1-\gamma) g_t^2$: the only difference from Adagrad (a leaky average instead of a sum).
- **Adadelta**
  - $\mathbf{s}_{t}=\rho \mathbf{s}_{t-1}+(1-\rho) \mathbf{g}_{t}^{2}$, $\mathbf{g}_{t}^{\prime}=\frac{\sqrt{\Delta \mathbf{x}_{t-1}+\epsilon}}{\sqrt{\mathbf{s}_{t}+\epsilon}} \odot \mathbf{g}_{t}$, $\mathbf x_t = \mathbf x_{t-1} - \mathbf{g}_{t}^{\prime}$, $\Delta\mathbf{x}_{t}=\rho\Delta\mathbf{x}_{t-1}+( 1-\rho) \mathbf{g}_{t}^{\prime\, 2}$.
- **Adam**
  - $v_t = \beta_1 v_{t-1} + (1-\beta_1)g_{t}$, $s_t = \beta_2 s_{t-1} + (1-\beta_2)g^2_{t}$, then rescale them, because initializing $v_0 = s_0 = 0$ biases the early estimates toward zero: $\hat v_t = v_t / (1 - \beta_1^t)$, $\hat s_t = s_t / (1 - \beta_2^t)$. Finally $x_t = x_{t-1} - \eta \hat v_t / (\sqrt{\hat s_t} + \epsilon)$
  - One problem of Adam is that it can fail to converge even in convex settings: when $g^2_t$ is very large, the second-moment estimate $s_t$ blows up and forgets the history too quickly. The Yogi update is $s_t = s_{t-1} + (1-\beta_2)g^2_{t} \odot \operatorname{sgn}(g^2_{t} - s_{t-1})$: the size of the update is $g^2_{t}$ itself, and only its sign depends on the deviation $g^2_{t} - s_{t-1}$.
- Scheduler
  - Warmup: Gotmare et al. (2018) find in particular that a warmup phase limits the amount of divergence of parameters in very deep networks ("A closer look at deep learning heuristics: learning rate restarts, warmup and distillation", arXiv:1810.13243).
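A minimal pure-Python sketch of the Adam update with bias correction on the 1-D function $f(x) = x^2$; the function name, the default hyperparameters and the test function are illustrative choices.

```python
import math

def adam(grad_fn, x0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, num_steps=100):
    """Minimize a 1-D function given its gradient, using Adam with bias correction."""
    x, v, s = x0, 0.0, 0.0
    for t in range(1, num_steps + 1):
        g = grad_fn(x)
        v = beta1 * v + (1 - beta1) * g        # first-moment (momentum) estimate
        s = beta2 * s + (1 - beta2) * g * g    # second-moment estimate
        v_hat = v / (1 - beta1 ** t)           # undo the bias from v_0 = 0
        s_hat = s / (1 - beta2 ** t)           # undo the bias from s_0 = 0
        x = x - lr * v_hat / (math.sqrt(s_hat) + eps)
    return x

grad = lambda x: 2 * x   # gradient of f(x) = x^2
print(adam(grad, 1.0, num_steps=1))    # about 0.9: the first bias-corrected step has size about lr
print(adam(grad, 1.0, num_steps=500))  # close to the minimum at 0
```

With bias correction, the first step has size about `lr` regardless of the gradient's scale; without it, the first step would be scaled by $(1-\beta_1)/\sqrt{1-\beta_2} \approx 3.2$ with these defaults.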
<!-- <img alt="ResNeXt Block" src="https://d2l.ai/_images/rnn.svg" style="background-color: white; display: inline-block;"> -->
<!-- <img alt="ResNeXt Block" src="https://d2l.ai/_images/rnn-bptt.svg" style="background-color: white; display: inline-block;"> -->
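
A minimal linear-warmup schedule, as a sketch: the learning rate ramps linearly up to `base_lr` over `warmup_steps` steps and is then held constant. The function name and default values are assumptions for illustration; in practice warmup is often followed by a decay schedule.

```python
def warmup_lr(step, base_lr=0.1, warmup_steps=100):
    """Learning rate at a 0-based training step under linear warmup (assumed schedule)."""
    if step < warmup_steps:
        # ramp linearly: base_lr / warmup_steps at step 0, base_lr at step warmup_steps - 1
        return base_lr * (step + 1) / warmup_steps
    return base_lr  # after warmup: constant here; a decay schedule could take over


schedule = [warmup_lr(t) for t in range(150)]
print(schedule[0], schedule[49], schedule[99], schedule[149])
```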

## Chapter 13 : Computation
- Compilers
  - `net = torch.jit.script(net)` compiles the model to TorchScript
- Automatic Parallelism
  - `[y.to('cpu', non_blocking=non_blocking) for y in x]`: with `non_blocking=True`, copying `x[i-1]` to the CPU can overlap with computing `x[i]`
- Training on Multiple GPUs
  - <img alt="Partion Methods" src="https://d2l.ai/_images/splitting.svg" style="background-color: white; display: inline-block;"> Partition methods
  - `nn.parallel.scatter` splits the data across devices
  - Explicit synchronization (`torch.cuda.synchronize()`) is only needed to time execution precisely or to debug asynchronous errors; otherwise synchronization happens implicitly whenever the CPU or a later operation needs the data.
- Concise implementation:
  - What we need to do
    - Network parameters need to be initialized across all devices.
    - While iterating over the dataset, minibatches are divided across all devices.
    - We compute the loss and its gradient in parallel across devices.
    - Gradients are aggregated and parameters are updated accordingly.
  - Use `torch.nn.parallel.DistributedDataParallel`
- Parameter Server
  - <img alt="Parameter Exchange" src="https://d2l.ai/_images/ps-distributed.svg" style="background-color: white; display: inline-block;">
  - The figure above assumes the gradient can be divided into four parts; each part is aggregated on a different GPU and then exchanged with the others.
  - Ring Synchronization
  - Key–Value Stores
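
Ring synchronization can be simulated in pure Python, as a sketch under assumptions: plain lists stand in for per-GPU gradient buffers, there is no real communication, and each gradient's length is assumed divisible by the number of workers. In the reduce-scatter phase each worker passes one chunk to its neighbour, which adds it to its own copy; in the all-gather phase the completed chunks travel around the ring. Each worker sends $2(n-1)$ chunks of size $len/n$, so its traffic stays below twice the gradient size regardless of the number of workers.

```python
def ring_allreduce(grads):
    """Sum gradients across n simulated workers arranged in a ring.

    grads[i] is worker i's gradient (a list whose length is divisible by n).
    Returns one summed copy per worker.
    """
    n = len(grads)
    size = len(grads[0]) // n
    # buf[i][c] is worker i's copy of chunk c
    buf = [[list(g[c * size:(c + 1) * size]) for c in range(n)] for g in grads]

    # Phase 1, reduce-scatter: at step s worker i sends chunk (i - s) % n to
    # worker (i + 1) % n, which adds it to its own copy. Afterwards worker i
    # holds the complete sum of chunk (i + 1) % n.
    for s in range(n - 1):
        msgs = [(i, (i - s) % n, list(buf[i][(i - s) % n])) for i in range(n)]
        for i, c, data in msgs:
            dst = (i + 1) % n
            buf[dst][c] = [a + b for a, b in zip(buf[dst][c], data)]

    # Phase 2, all-gather: at step s worker i forwards its completed chunk
    # (i + 1 - s) % n, which overwrites the receiver's partial copy.
    for s in range(n - 1):
        msgs = [(i, (i + 1 - s) % n, list(buf[i][(i + 1 - s) % n])) for i in range(n)]
        for i, c, data in msgs:
            buf[(i + 1) % n][c] = data

    return [[x for chunk in b for x in chunk] for b in buf]


grads = [[1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60], [100, 200, 300, 400, 500, 600]]
print(ring_allreduce(grads))  # every worker ends with [111, 222, 333, 444, 555, 666]
```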


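For **deep learning (DL) engineering** and **hardware optimization**, you need a set of tools, techniques, and best practices so that models can be trained, optimized, and deployed efficiently.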

---

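# **1. Deep Learning Engineering**
**Goal**: not only train models, but also handle **data processing, training, tuning, deployment, and maintenance** efficiently in real applications.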

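## **1.1 Data Engineering**
Deep learning performance depends heavily on data quality and preprocessing efficiency.

### **(1) Data Collection and Storage**
- **Structured data** (SQL, Pandas, BigQuery)
- **Image data** (OpenCV, PIL, TensorFlow Datasets)
- **Text data** (NLTK, Hugging Face Datasets)
- **Streaming data** (Kafka, Apache Spark)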

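### **(2) Data Preprocessing**
- **Standardization / normalization** (Min-Max scaling, Z-score)
- **Data augmentation** (images: rotation, cropping; text: synonym replacement)
- **Dimensionality reduction** (PCA, t-SNE, UMAP)
- **Missing-value handling** (mean imputation, interpolation)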
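
Min-max scaling and z-score standardization can be sketched with the standard library. The function names are hypothetical, and the population standard deviation is an assumed choice (the sample standard deviation is also common). The statistics should come from the training set only and then be reused for validation and test data.

```python
import statistics


def min_max_scale(xs):
    """Map values linearly onto [0, 1]: (x - min) / (max - min)."""
    lo, hi = min(xs), max(xs)
    if hi == lo:                      # constant feature: avoid division by zero
        return [0.0 for _ in xs]
    return [(x - lo) / (hi - lo) for x in xs]


def z_score(xs, mean=None, std=None):
    """Standardize to zero mean and unit variance: (x - mean) / std.

    Pass the training set's mean/std to transform validation or test data.
    """
    mean = statistics.fmean(xs) if mean is None else mean
    std = statistics.pstdev(xs) if std is None else std   # population std
    if std == 0:
        return [0.0 for _ in xs]
    return [(x - mean) / std for x in xs]


train = [2.0, 4.0, 6.0, 8.0]
print(min_max_scale(train))  # [0.0, 0.3333333333333333, 0.6666666666666666, 1.0]
print(z_score(train))
```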

### **(3) Data Loading Optimization**
- **Batch loading**
- **Multi-threaded / multi-process preprocessing** (PyTorch `DataLoader`, TensorFlow `tf.data`)
- **TFRecord / HDF5** (binary formats for faster reads)
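
The core of batch loading is to shuffle the sample indices and slice them into minibatches. Below is a minimal single-process sketch with a hypothetical function name; `torch.utils.data.DataLoader` with `shuffle=True` does conceptually the same thing and can add worker processes on top.

```python
import random


def minibatches(features, labels, batch_size, shuffle=True, seed=None):
    """Yield (X_batch, y_batch) pairs; the final batch may be smaller."""
    indices = list(range(len(features)))
    if shuffle:
        # shuffle the order; use a different seed (or None) for each epoch
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        yield [features[i] for i in batch], [labels[i] for i in batch]


X = list(range(10))
y = [2 * v for v in X]
for xb, yb in minibatches(X, y, batch_size=4, seed=0):
    print(xb, yb)   # batch sizes 4, 4, 2
```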

---

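## **1.2 Training and Hyperparameter Tuning**
Training deep learning models is compute-intensive and calls for efficient **optimization strategies** and **hyperparameter tuning**.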

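### **(1) Training Optimization**
- **Choice of optimizer**
  - SGD (stochastic gradient descent)
  - Adam / RMSprop (adaptive methods)
  - LARS / LAMB (for large-batch distributed training)

- **Regularization**
  - Dropout (randomly drops neurons)
  - Batch Normalization
  - Weight Decay (L2 regularization)

- **Gradient Clipping**
  - Addresses exploding gradients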
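
Gradient clipping by global norm, sketched on plain lists with a hypothetical function name: if the L2 norm over all gradients exceeds `max_norm`, every gradient is multiplied by the same factor, so the overall direction is preserved.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient vectors so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = max_norm / (total_norm + 1e-6)      # small constant avoids division by zero
    if scale < 1.0:                             # only shrink, never enlarge
        grads = [[g * scale for g in vec] for vec in grads]
    return grads, total_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm, clipped)   # norm 5.0; gradients scaled to roughly [[0.6], [0.8]]
```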

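### **(2) Hyperparameter Optimization**
Automatically search for good hyperparameters (e.g., learning rate, batch size, weight initialization).
- **Grid Search**
- **Random Search**
- **Bayesian Optimization**
- **Hyperband** (adaptive budget allocation: many configurations get a small budget and poor ones are stopped early)
- **Optuna / Ray Tune** (tools for automated hyperparameter tuning)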

---

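## **1.3 Training Acceleration**
Large-scale training needs efficient acceleration techniques:

### **(1) GPU Acceleration**
- Use **CUDA** / **cuDNN** as much as possible during training
- **Mixed-precision training**: use FP16 (half precision) to speed up computation
- **Data parallelism** vs. **model parallelism**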
  
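### **(2) Distributed Training**
- **Single machine, multiple GPUs**
  - PyTorch `DataParallel`
  - PyTorch `DistributedDataParallel (DDP)` (recommended by PyTorch over `DataParallel`, even on one machine)
  - TensorFlow `MirroredStrategy`

- **Multiple machines, multiple GPUs (multi-node)**
  - TensorFlow `MultiWorkerMirroredStrategy`
  - Horovod (an efficient distributed training framework from Uber)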

---

## **1.4 Deployment and Inference Optimization**
Beyond training, deep learning models must run inference efficiently on **edge devices** or **servers**.

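### **(1) Model Compression**
- **Pruning**: remove unimportant weights
- **Quantization**:
  - **8-bit integer (INT8) quantization** (TensorRT, TFLite)
  - **Mixed-precision inference (FP16, INT8)**

- **Knowledge Distillation**:
  - Use a large model to train a small model, improving inference efficiency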
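
A minimal sketch of symmetric per-tensor INT8 quantization, an assumed scheme chosen for simplicity (real toolkits offer more elaborate variants): each value is divided by a scale chosen so that the largest magnitude maps to 127, rounded, and stored as an 8-bit integer. The function names are hypothetical.

```python
def quantize_int8(xs):
    """Symmetric per-tensor quantization: q = round(x / scale), scale = max|x| / 127."""
    max_abs = max(abs(x) for x in xs)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(x / scale))) for x in xs]   # clamp to int8 range
    return q, scale


def dequantize(q, scale):
    """Map the integers back to approximate real values."""
    return [v * scale for v in q]


weights = [0.42, -1.27, 0.003, 0.9, -0.35]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
print(q)   # int8 codes; -1.27 (the largest magnitude) maps to -127
print(max(abs(a - b) for a, b in zip(weights, restored)))  # at most scale / 2
```

Each weight now takes 1 byte instead of 4 (FP32), at the cost of a rounding error of at most half a quantization step.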

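### **(2) Inference Frameworks**
- **ONNX (Open Neural Network Exchange)**: an open model interchange format; models exported from PyTorch or TensorFlow can be run by other runtimes and tools
- **TensorRT (NVIDIA)**: high-performance GPU inference
- **TVM (Apache)**: a compiler that automatically optimizes model inference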

### **(3) Deployment Options**
- **Server deployment**
  - Flask / FastAPI (REST API serving)
  - TensorFlow Serving / TorchServe (dedicated model serving)

- **Mobile / edge deployment**
  - TensorFlow Lite (TFLite)
  - Core ML (iOS devices)
  - NVIDIA Jetson (embedded AI)

---

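# **2. Hardware Optimization**
Deep learning is extremely compute-heavy; hardware optimization can **significantly speed up training and inference**.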

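## **2.1 GPU Computing**
The GPU is the core compute device for deep learning, and the NVIDIA CUDA ecosystem is essential.

### **(1) GPU Programming Basics**
- CUDA programming (writing kernels)
- cuDNN (NVIDIA's library of optimized deep learning primitives)
- Tensor Cores (hardware units for mixed-precision matrix math)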
  
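### **(2) GPU Training Optimization**
- **Minimize and speed up CPU-GPU transfers** (e.g., `DataLoader(pin_memory=True)` together with `.to(device, non_blocking=True)`)
- **Gradient Accumulation**: emulate a large batch with several small micro-batches, reducing GPU memory use
- **FP16 training** (higher throughput)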
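
Gradient accumulation relies on the fact that, for a loss averaged over the batch, the average of the gradients of k equally sized micro-batches equals the full-batch gradient. A sketch on a hypothetical one-parameter linear model with mean-squared-error loss (all names are illustrative):

```python
def grad_mse(w, xs, ys):
    """d/dw of the mean squared error (1/n) * sum((w * x - y)^2) for a scalar weight."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch):
    """Accumulate gradients over equally sized micro-batches, each scaled by 1/k."""
    k = len(xs) // micro_batch          # assumes len(xs) is divisible by micro_batch
    total = 0.0
    for i in range(k):
        part = slice(i * micro_batch, (i + 1) * micro_batch)
        total += grad_mse(w, xs[part], ys[part]) / k   # like (loss / k).backward()
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2 * x + 1 for x in xs]
print(grad_mse(0.5, xs, ys), accumulated_grad(0.5, xs, ys, micro_batch=2))  # equal up to rounding
```

In a PyTorch loop this corresponds to calling `(loss / k).backward()` for each micro-batch and `optimizer.step()` followed by `optimizer.zero_grad()` once every k micro-batches; only one micro-batch's activations are in memory at a time.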

---

## **2.2 Distributed Computing**
Suited to **training at very large scale** (e.g., models such as GPT and Llama).

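### **(1) Parallelism Strategies**
- **Data Parallelism**
  - Replicate the model on multiple GPUs; each GPU trains on different data
  - PyTorch `DistributedDataParallel (DDP)`

- **Model Parallelism**
  - For models too large to fit on a single GPU
  - Optimized implementations: DeepSpeed / Megatron-LM

- **Pipeline Parallelism**
  - Assign consecutive groups of layers to different GPUs and stream micro-batches through them so the stages work concurrently
  - **Commonly used for training large Transformers**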
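
How micro-batches keep pipeline stages busy can be sketched with a simple forward-only, GPipe-style schedule. This is an assumption for illustration: every stage takes one time step per micro-batch, and real schedules also interleave backward passes. Stage s works on micro-batch m at time step s + m. The idle slots at the start and end (the pipeline "bubble") make up a fraction $(p-1)/(m+p-1)$ of the time for p stages and m micro-batches, so more micro-batches mean less idle time.

```python
def pipeline_schedule(num_stages, num_micro):
    """grid[t][s] = micro-batch processed by stage s at time step t, or None if idle."""
    steps = num_stages + num_micro - 1
    grid = [[None] * num_stages for _ in range(steps)]
    for s in range(num_stages):
        for m in range(num_micro):
            grid[s + m][s] = m          # stage s starts micro-batch m after s steps
    return grid


def bubble_fraction(num_stages, num_micro):
    """Share of (time step, stage) slots in which a stage is idle."""
    grid = pipeline_schedule(num_stages, num_micro)
    idle = sum(cell is None for row in grid for cell in row)
    return idle / (len(grid) * num_stages)


for t, row in enumerate(pipeline_schedule(4, 6)):
    print(t, ["." if m is None else m for m in row])
print(bubble_fraction(4, 6), bubble_fraction(4, 32))   # 3/9 and 3/35
```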

### **(2) Efficient Communication**
- **NCCL (NVIDIA Collective Communication Library)**: optimized communication between GPUs
- **RDMA (Remote Direct Memory Access)**: high-speed communication between GPU servers

---

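## **2.3 Specialized AI Hardware**
Besides GPUs, AI training can also be accelerated with specialized chips:
- **TPU (Google)**: designed specifically for deep learning workloads
- **Graphcore IPU** (optimized for sparse computation)
- **Cerebras Wafer-Scale Engine** (very large-scale AI computation)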

---

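# **3. Summary**
| **Category** | **Key content** |
|----------|--------------|
| **Data engineering** | Data cleaning, data augmentation, data loading optimization |
| **Training optimization** | Hyperparameter tuning, regularization, optimizer choice |
| **Training acceleration** | GPU acceleration, mixed precision, distributed training |
| **Deployment optimization** | Model quantization, pruning, TensorRT acceleration |
| **Hardware optimization** | CUDA, NCCL, TPU/FPGA |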

Suggested next steps for going deeper into engineering optimization:
1. **Study PyTorch DDP / DeepSpeed** (distributed training optimization)
2. **Learn CUDA / cuDNN programming** (low-level GPU acceleration)
3. **Try TensorRT / ONNX** (inference acceleration)

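## Gradient Checkpointing

Gradient checkpointing (also called checkpointing or recomputation) is a technique for **saving memory** when training deep neural networks, and it is especially useful when the network is very deep or the batch size is large. The core idea: when memory is limited, spend some extra computation time to reduce the memory needed for storing activations (intermediate outputs).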

---

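### Background: Why Gradient Checkpointing?
During training, backpropagation needs the activations that every layer computed in the forward pass. These activations are usually kept in memory so they can be used directly when computing gradients. However:
- In deep networks (e.g., ResNet-101, Transformers) there are many layers, so activation memory becomes very large.
- Activation memory also grows linearly with the batch size.
- If GPU memory is insufficient, the model may not be trainable at all, or only with a very small batch size, which can hurt training.

Gradient checkpointing **does not store all activations**; instead, it recomputes them when needed, greatly reducing memory usage.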

---

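### Basic Principle
Normally, training a neural network works as follows:
1. **Forward pass**: from input to output, compute and store every layer's activations.
2. **Backward pass**: use the stored activations and the gradient of the loss to compute the gradients of every layer's parameters.

With gradient checkpointing:
- During the forward pass, **only the activations of selected layers are kept** (these are the "checkpoints"), not those of every layer.
- During the backward pass, the activations of layers that were not kept are **recomputed** by re-running the forward pass from the nearest preceding checkpoint (or the input), and the gradient computation then continues.

This trades extra computation (recomputing activations) for memory savings.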

---

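### Detailed Steps
Assume a simple network with 5 layers: \( L_1 \rightarrow L_2 \rightarrow L_3 \rightarrow L_4 \rightarrow L_5 \), where \( A_i \) is the output (activation) of \( L_i \) and \( x \) is the input.

#### Ordinary Training
1. Forward pass: compute and store all activations \( A_1, A_2, A_3, A_4, A_5 \).
2. Backward pass: starting from \( L_5 \), use \( A_5, A_4, \ldots, A_1 \) to compute each layer's gradients.
3. Memory: all of \( A_1 \) through \( A_5 \) are stored; if each layer's activation takes \( M \) bytes, that is \( 5M \) in total.

#### With Gradient Checkpointing
Suppose we choose \( L_2 \) and \( L_4 \) as checkpoints:
1. **Forward pass**:
   - Compute \( L_1 \rightarrow L_5 \) as usual, but keep only the checkpoint activations \( A_2 \) and \( A_4 \).
   - The other activations \( A_1, A_3, A_5 \) are discarded during the forward pass.
   - Memory: only \( A_2 \) and \( A_4 \) are stored, i.e., \( 2M \) (plus the input \( x \), which is always kept).

2. **Backward pass**:
   - \( L_5 \): needs \( A_4 \) (kept) and \( A_5 \) (not kept). Re-run \( L_5 \) on \( A_4 \) to recompute \( A_5 \), then compute the gradients of \( L_5 \).
   - \( L_4 \) and \( L_3 \): need \( A_3 \) (not kept). Re-run \( L_3 \) on the checkpoint \( A_2 \) to recompute \( A_3 \), then compute the gradients of \( L_4 \) and \( L_3 \).
   - \( L_2 \) and \( L_1 \): need \( A_1 \) (not kept). Re-run \( L_1 \) on the input \( x \) to recompute \( A_1 \), then compute the gradients of \( L_2 \) and \( L_1 \).

3. **Memory savings**: between the two passes only the 2 checkpoint activations (\( 2M \)) are held instead of 5 (\( 5M \)). During the backward pass each recomputed activation is held only while its segment is processed, so the peak here is about \( 3M \).
4. **Extra computation**: \( L_1 \), \( L_3 \), and \( L_5 \) are run forward a second time, i.e., 3 of the 5 layers.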
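
The walkthrough above can be replayed in pure Python on a hypothetical toy network of five scalar layers \( A_i = \tanh(w_i A_{i-1}) \) with loss \( A_5 \). The sketch keeps only the checkpointed activations in the forward pass, recomputes each segment during the backward pass, and reports the same counts as the walkthrough (peak number of stored activations excluding the input, and number of recomputed layers). All names are illustrative.

```python
import math


def layer(w, a):
    # Toy scalar layer: A_i = tanh(w_i * A_{i-1})
    return math.tanh(w * a)


def grads_plain(ws, x):
    """Ordinary backprop: keep every activation A_1..A_N. Loss = A_N."""
    acts = [x]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g, grads = 1.0, [0.0] * len(ws)
    for i in range(len(ws), 0, -1):            # layers N..1
        local = g * (1 - acts[i] ** 2)         # d loss / d (w_i * A_{i-1})
        grads[i - 1] = local * acts[i - 1]     # d loss / d w_i
        g = local * ws[i - 1]                  # d loss / d A_{i-1}
    return grads, len(ws)                      # activations stored (input excluded)


def grads_checkpointed(ws, x, checkpoints):
    """Keep only A_k for k in `checkpoints` (1-based) plus the input x."""
    n = len(ws)
    saved, a = {0: x}, x
    for i in range(1, n + 1):
        a = layer(ws[i - 1], a)
        if i in checkpoints:
            saved[i] = a                       # everything else is discarded
    starts = sorted(saved)
    ends = starts[1:] + [n]
    g, grads = 1.0, [0.0] * n
    recomputed, peak = 0, len(saved) - 1
    for start, end in reversed(list(zip(starts, ends))):   # last segment first
        seg = {start: saved[start]}
        for i in range(start + 1, end + 1):
            if i in saved:
                seg[i] = saved[i]
            else:                              # re-run the forward pass for this layer
                seg[i] = layer(ws[i - 1], seg[i - 1])
                recomputed += 1
        temp = sum(1 for i in seg if i not in saved)       # recomputed values held now
        peak = max(peak, len(saved) - 1 + temp)
        for i in range(end, start, -1):        # backprop through this segment
            local = g * (1 - seg[i] ** 2)
            grads[i - 1] = local * seg[i - 1]
            g = local * ws[i - 1]
    return grads, peak, recomputed


ws = [0.5, -1.2, 0.8, 1.5, -0.7]
plain, stored = grads_plain(ws, 0.9)
ckpt, peak, recomputed = grads_checkpointed(ws, 0.9, checkpoints={2, 4})
print(stored, peak, recomputed)                        # 5 3 3
print(max(abs(a - b) for a, b in zip(plain, ckpt)))    # identical gradients
```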

---

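### Choosing Checkpoints
Where to place checkpoints is a key question:
- **Uniform placement**: e.g., one checkpoint every \( k \) layers (such as every 10 layers).
- **Cost-aware placement**: recompute activations that take a lot of memory but are cheap to compute, and keep those that are small but expensive to recompute.
- **Theory**: with about \( \sqrt{N} \) evenly spaced checkpoints (\( N \) = number of layers), activation memory drops from \( O(N) \) to \( O(\sqrt{N}) \) at the cost of roughly one extra forward pass; finding the truly optimal placement for a given memory budget is more involved.

Modern deep learning frameworks (e.g., PyTorch, TensorFlow) usually have built-in checkpointing support; the user only specifies which layers or submodules to checkpoint.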

---

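### Mathematical Analysis
Assume the network has \( N \) layers and each layer's activation takes \( M \) bytes:
- **Ordinary training**: memory \( N \times M \).
- **Checkpointed training**:
  - With \( K \) checkpoints (\( K < N \)), the checkpoints take \( K \times M \); during the backward pass one segment of about \( N/K \) layers is recomputed and held at a time, so peak memory is about \( (K + N/K) \times M \).
  - The extra computation is roughly one additional forward pass, since each non-checkpoint layer is recomputed once.
  - The best choice is \( K \approx \sqrt{N} \), which gives memory \( O(\sqrt{N} \times M) \).

For example, with \( N = 100 \):
- Ordinary training: \( 100M \).
- Checkpointed training (\( K = 10 \)): about \( 10M + 10M = 20M \), plus about one extra forward pass.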
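
The choice \( K \approx \sqrt{N} \) follows from minimizing the peak-memory estimate \( (K + N/K)M \) over \( K \), treating \( K \) as continuous:

$$
\frac{d}{dK}\Big(K + \frac{N}{K}\Big) = 1 - \frac{N}{K^2} = 0
\;\Rightarrow\; K^\ast = \sqrt{N},
\qquad
\Big(K^\ast + \frac{N}{K^\ast}\Big)M = 2\sqrt{N}\,M .
$$

With \( N = 100 \): \( K^\ast = 10 \) and \( 2\sqrt{100}\,M = 20M \), compared with \( 100M \) without checkpointing.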

---

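### Advantages
1. **Memory efficiency**: greatly reduces activation storage, allowing deeper networks or larger batch sizes.
2. **Flexibility**: complex models can run on devices with limited GPU memory.
3. **Broad applicability**: works with CNNs, RNNs, Transformers, and other architectures.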

---

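### Disadvantages
1. **Computational overhead**: recomputing activations increases training time, typically by 20%-50% compared with ordinary training, depending on the number of checkpoints and the network structure.
2. **Implementation complexity**: checkpoints must be specified manually or rely on framework support, and debugging can be harder.
3. **Not always worthwhile**: if memory is not the bottleneck, checkpointing only slows training down.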

---

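### Practical Use
- **PyTorch example**:
  ```python
  import torch
  import torch.nn as nn
  from torch.utils.checkpoint import checkpoint

  class MyModel(nn.Module):
      def __init__(self):
          super().__init__()
          self.layer1 = nn.Linear(1024, 1024)
          self.layer2 = nn.Linear(1024, 1024)
          self.layer3 = nn.Linear(1024, 1024)

      def forward(self, x):
          # Checkpointed: layer1 is re-run during backward instead of keeping
          # the tensors it would save for backward. (In practice one usually
          # wraps larger blocks, e.g. whole Transformer layers.)
          # use_reentrant=False lets gradients reach layer1's parameters even
          # though the input tensor does not require grad.
          x = checkpoint(self.layer1, x, use_reentrant=False)
          x = self.layer2(x)  # ordinary layer: activations stored as usual
          x = checkpoint(self.layer3, x, use_reentrant=False)  # checkpointed layer3
          return x

  model = MyModel()
  inputs = torch.randn(32, 1024)  # batch_size = 32
  output = model(inputs)
  output.sum().backward()  # layer1 and layer3 are recomputed here
  ```

- **Large models**: GPT, LLaMA, and other large Transformers stack dozens of layers (e.g., 96 in GPT-3), and gradient checkpointing can substantially reduce their GPU memory needs.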

---

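### Summary
Gradient checkpointing is a time-for-space optimization: it stores fewer activations to lower memory usage, which makes it well suited to memory-constrained settings and deep networks. Its core is to keep a subset of activations as checkpoints and recompute the others during the backward pass. It adds computational overhead, but when memory is the bottleneck it can make otherwise infeasible training runs possible.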
