In [1]:
# Code to set up the project on Google Colab: mount Google Drive, clone the
# repository into it, and install the build dependency.
from google.colab import drive
# Mount Google Drive at /content/drive.
drive.mount('/content/drive')
%cd /content/drive/MyDrive/10714
# NOTE(review): `git clone` refuses to clone into an existing non-empty
# directory, so this cell presumably fails when re-run after a first
# successful run -- confirm whether a guard is wanted.
!git clone https://github.com/Doraemonzzz/dlsys-project.git
%cd dlsys-project

# pybind11 is required by the native build performed in a later cell.
!pip3 install pybind11

Mounted at /content/drive
/content/drive/MyDrive/10714/project
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pybind11
  Downloading pybind11-2.10.3-py3-none-any.whl (222 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m222.4/222.4 KB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pybind11
Successfully installed pybind11-2.10.3


In [2]:
import sys
# Add the repository's `./python` directory to the module search path so that
# packages located there can be imported by later cells.
sys.path.append('./python')

In [3]:
# Download the datasets you will for this project
import urllib.request
import os

# Download CIFAR-10 dataset
if not os.path.isdir("./data/cifar-10-batches-py"):
    urllib.request.urlretrieve("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", "./data/cifar-10-python.tar.gz")
    !tar -xvzf './data/cifar-10-python.tar.gz' -C './data'

In [8]:
# Compile the native backend from scratch: `make clean` removes the previous
# build output, then `make` rebuilds it.
!make clean
!make

rm -rf build python/needle/backend_ndarray/ndarray_backend*.so
-- The C compiler identification is GNU 7.5.0
-- The CXX compiler identification is GNU 7.5.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Python: /usr/bin/python3.8 (found version "3.8.16") found components: Development Interpreter Development.Module Development.Embed 
-- Performing Test HAS_FLTO
-- Performing Test HAS_FLTO - Success
-- Found pybind11: /usr/local/lib/python3.8/dist-packages/pybind11/include (found version "2.10.3")
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_H

# Batch matrix multiplication
In the course, we learned about the implementation of matrix multiplication, which is a mapping of the form:
$$
m\times n, n\times p \to m\times p.
$$
Due to subsequent needs, we implement the batch version of matrix multiplication here, namely `bmm`:
$$
b\times m\times n,b\times  n\times p \to b\times m\times p.
$$

## Implementation
Cpu version:
```c++
/**
 * Batched matrix multiplication on the CPU: out[l] = a[l] * b[l] for every
 * batch index l in [0, b_).
 *
 * All three buffers are read/written as flat row-major arrays.
 *
 * Args:
 *   a: input buffer holding b_ x m x n elements
 *   b: input buffer holding b_ x n x p elements
 *   out: output buffer holding b_ x m x p elements; every element is overwritten
 *   b_: batch size
 *   m: rows of a / out
 *   n: columns of a / rows of b
 *   p: columns of b / out
 */
void BatchMatmul(const AlignedArray& a, const AlignedArray& b, AlignedArray* out, uint32_t b_, uint32_t m, uint32_t n,
            uint32_t p) {
  // Widen the dimensions once so that every offset below is computed in
  // size_t. With int / uint32_t operands, products such as l * m * n are
  // evaluated in 32 bits and wrap around for large tensors; unsigned loop
  // counters also avoid the signed/unsigned comparisons of `int l < b_`.
  const size_t batch = b_, rows = m, inner = n, cols = p;
  for (size_t l = 0; l < batch; l++) {
    // Start of the l-th matrix inside each flat buffer.
    const size_t a_base = l * rows * inner;
    const size_t b_base = l * inner * cols;
    const size_t out_base = l * rows * cols;
    for (size_t i = 0; i < rows; i++) {
      for (size_t j = 0; j < cols; j++) {
        float res = 0;
        for (size_t k = 0; k < inner; k++) {
          res += a.ptr[a_base + i * inner + k] * b.ptr[b_base + k * cols + j];
        }
        out->ptr[out_base + i * cols + j] = res;
      }
    }
  }
}
```

Cuda Version:
```cpp
// Batch matrix multiply kernel: out[l] = a[l] * b[l] for every l in [0, B).
//
// Each thread owns one (row, col) position of the M x P output and loops over
// all B batches for that position. The buffers are indexed as flat row-major
// arrays: a is B x M x N, b is B x N x P, out is B x M x P.
__global__ void BatchMatmulKernel(const scalar_t* a, const scalar_t* b, scalar_t* out, uint32_t B, uint32_t M, uint32_t N,
            uint32_t P) {
  size_t x = blockIdx.x * blockDim.x + threadIdx.x;  // output row
  size_t y = blockIdx.y * blockDim.y + threadIdx.y;  // output column
  // Threads whose position falls outside the output matrix do nothing.
  if (x < M && y < P) {
    // Widen the dimensions once so every offset is computed in size_t: with
    // int / uint32_t operands, l * M * N is evaluated in 32 bits and can wrap.
    // An unsigned counter also avoids the signed/unsigned compare `int l < B`.
    const size_t rows = M, inner = N, cols = P;
    for (size_t l = 0; l < B; l++) {
      const size_t a_row = l * rows * inner + x * inner;  // start of row x in a[l]
      const size_t b_col = l * inner * cols + y;          // top of column y in b[l]
      float res = 0;
      for (size_t z = 0; z < inner; z++) {
        res += a[a_row + z] * b[b_col + z * cols];
      }
      out[l * rows * cols + x * cols + y] = res;
    }
  }
}

void BatchMatmul(const CudaArray& a, const CudaArray& b, CudaArray* out, uint32_t B, uint32_t M, uint32_t N,
            uint32_t P) {
  /**
   * Multiply two batches of (compact) matrices into a (compact) output batch:
   * out[l] = a[l] * b[l] for every batch index l.
   *
   * This function only sets up the launch configuration; the arithmetic is
   * done in BatchMatmulKernel(), where each thread computes one (row, col)
   * output position for all batches.
   *
   * Args:
   *   a: compact 3D array of size B x M x N
   *   b: compact 3D array of size B x N x P
   *   out: compact 3D array of size B x M x P to write the output to
   *   B: batch size
   *   M: rows of a / out
   *   N: columns of a / rows of b
   *   P: columns of b / out
   */

  /// BEGIN YOUR SOLUTION
  // An empty output needs no work, and a zero-sized grid is not a valid
  // launch configuration.
  if (B == 0 || M == 0 || P == 0) return;

  // One d x d thread block per output tile, rounded up so partial tiles at the
  // edges are still covered. The previous (M / d) + 1 launched an extra, fully
  // idle row/column of blocks whenever the size was a multiple of d. Written
  // as quotient plus remainder test so the expression cannot overflow.
  dim3 DimGrid(M / d + (M % d != 0), P / d + (P % d != 0), 1);
  dim3 DimBlock(d, d, 1);

  BatchMatmulKernel<<<DimGrid, DimBlock>>>(a.ptr, b.ptr, out->ptr, B, M, N, P);
  /// END YOUR SOLUTION
}

```
The code can be found in `project/src`. 
To support `bmm`, we should also change the `__matmul__` in `./python/needle/backend_ndarray/ndarray.py`:
```python
    ### Matrix multiplication
    def __matmul__(self, other):
        """Matrix multiplication of two arrays, with optional batching.

        Both operands must have the same rank, either 2 or 3:

        - 2D: ``(m, n) @ (n, p) -> (m, p)``, dispatched to
          ``self.device.matmul``.
        - 3D: ``(b, m, n) @ (b, n, p) -> (b, m, p)``, dispatched to
          ``self.device.batchmatmul``. The batch sizes must be equal; the
          batch dimension is not broadcast.

        Both operands are compacted before their handles are passed to the
        backend, and the result is a newly created array on ``self.device``.

        Raises:
            AssertionError: if the operands are not both 2D or both 3D, if the
                inner dimensions differ, or if the batch sizes differ.
        """
        assert (self.ndim == 2 and other.ndim == 2) or (self.ndim == 3 and other.ndim == 3)
        assert self.shape[-1] == other.shape[-2]

        if self.ndim == 2:
            m, n, p = self.shape[0], self.shape[1], other.shape[1]
            out = NDArray.make((m, p), device=self.device)
            self.device.matmul(
                self.compact()._handle, other.compact()._handle, out._handle, m, n, p
            )
            return out
        else:
            # The backend is only told one batch size and uses it to index both
            # inputs, so a smaller `other` would be read out of bounds.
            assert self.shape[0] == other.shape[0], (
                f"batch sizes differ: {self.shape[0]} vs {other.shape[0]}"
            )
            b, m, n, p = self.shape[0], self.shape[1], self.shape[2], other.shape[-1]
            out = NDArray.make((b, m, p), device=self.device)
            self.device.batchmatmul(
                self.compact()._handle, other.compact()._handle, out._handle, b, m, n, p
            )
            return out
```

# Linear Attention
In the course, we learned **Attention**. Although the performance is excellent, the time complexity of Attention is $O(n^2d)$, where $n$ is the sequence length and $d$ is the feature dimension. To alleviate this, we can use **Linear Attention**.
Linear Attention is calculated as follows:
$$
\mathbf O = \mathrm{diag}\{\phi(\mathbf Q) \phi (\mathbf K^{\top}) \mathbf 1_n \}^{-1} \phi(\mathbf Q) \phi (\mathbf K^{\top}) \mathbf V, \tag 1
$$
where $\mathbf Q, \mathbf K, \mathbf V \in \mathbb R^{n\times d}$. Note that in this naive implementation the time complexity is still $O(n^2 d)$; we will improve this below.

The difference between **Linear Attention** and **Attention** is as follows:
- A feature map $\phi$ is added;
- No exp activation function is applied to $\phi(\mathbf Q) \phi (\mathbf K^{\top})$ before normalization;
- The associative law of matrix multiplication can be used for efficient calculation.

For the third point, note that for matrices $\mathbf A, \mathbf B, \mathbf C$ of compatible shapes, the following equation holds:
$$
(\mathbf A \mathbf B) \mathbf C = \mathbf A (\mathbf B \mathbf C ).
$$
This is known as the associative law of matrix multiplication. Using this property, Equation (1) is equivalent to:
$$
 \mathrm{diag}\{\phi(\mathbf Q) [\phi (\mathbf K^{\top}) \mathbf 1_n] \}^{-1} [\phi(\mathbf Q)[\phi (\mathbf K^{\top}) \mathbf V]]. \tag 2
$$
Let's analyze the computational complexity of the above formula:
- $\mathrm{diag}\{\phi(\mathbf Q) \phi (\mathbf K^{\top}) \mathbf 1_n \}^{-1}$:
  - $\phi (\mathbf K^{\top}) \mathbf 1_n: (d, n), (n, 1)\to (d, 1), O(nd)$;
  - $\phi(\mathbf Q) [\phi (\mathbf K^{\top}) \mathbf 1_n]: (n, d), (d, 1) \to (n, 1), O(nd)$;
  - $\mathrm{diag}\{\phi(\mathbf Q) \phi (\mathbf K^{\top}) \mathbf 1_n \}^{-1} :(n, 1) \to (n, n), O(n)$;
- $\phi(\mathbf Q)[\phi (\mathbf K^{\top}) \mathbf V]$:
  - $\phi (\mathbf K^{\top}) \mathbf V: (d, n), (n, d) \to (d, d), O(nd^2)$;
  - $\phi(\mathbf Q)[\phi (\mathbf K^{\top}) \mathbf V]:(n, d), (d, d) \to (n, d), O(nd^2)$.

So the total time complexity is $O(nd^2)$. When $n\gg d$, we have $n d^2 \ll n^2 d$, which results in much faster computation.


## Implementation
In this part, we give the implementation of `LinearAttention`. Here we use ReLU as $\phi$, and `eps` is added to $\phi(\mathbf Q)$ and $\phi(\mathbf K)$ so that the denominator cannot be zero. To support 3D linear projection, we implemented `Linear3D`. The code is also available at `./python/needle/linear_transformer.py`:
```python
# 3D version of LinearLayer
class Linear3D(Module):
    """Apply a ``Linear`` layer along the last axis of a 3D tensor.

    The input of shape (b, n, d) is flattened to (b * n, d), passed through
    the wrapped ``Linear`` layer, and reshaped back to (b, n, e), where ``e``
    is the last dimension of the layer's output.
    """

    def __init__(self, in_features, out_features, device=None, dtype="float32"):
        super().__init__()
        self.linear = Linear(in_features, out_features, device=device, dtype=dtype)

    def forward(self, x):
        batch, seq_len, in_dim = x.shape
        # (b, n, d) -> (b * n, d)
        flat = ops.reshape(x, (batch * seq_len, in_dim))
        # (b * n, d) -> (b * n, e)
        projected = self.linear(flat)
        # (b * n, e) -> (b, n, e)
        return ops.reshape(projected, (batch, seq_len, projected.shape[-1]))

class LinearAttention(Module):
    """Multi-head linear attention using ReLU as the feature map.

    The input (b, n, d) is projected to queries, keys and values, split into
    ``h`` heads of size ``e = d // h``, and combined as
    ``q @ (k^T @ v)`` divided by ``q @ (k^T @ 1)``. Multiplying ``k^T`` with
    ``v`` first avoids materialising an (n, n) attention matrix.

    Args:
        d: feature dimension of the input and output.
        h: number of heads; must divide ``d``.
        device: forwarded to the projection layers.
        dtype: forwarded to the projection layers.

    Raises:
        ValueError: if ``d`` is not divisible by ``h``.
    """

    def __init__(self, d, h, device=None, dtype="float32"):
        super().__init__()
        # Fail fast: forward() reshapes (b, n, d) into (b, n, h, d // h), which
        # is only possible when the heads split the feature dimension evenly.
        if d % h != 0:
            raise ValueError(f"d ({d}) must be divisible by h ({h})")
        self.qkv = Linear3D(d, 3 * d, device=device, dtype=dtype)
        self.out = Linear3D(d, d, device=device, dtype=dtype)
        self.d = d
        self.h = h
        self.e = self.d // self.h
        self.act = ReLU()

    def forward(self, x, eps=1e-5):
        # (b, n, d) -> (b, n, 3 * d)
        qkv = self.qkv(x)
        b = qkv.shape[0]
        n = qkv.shape[1]
        # (b, n, 3 * d) -> (b, n, 3, d), then split along axis 2 into q, k, v
        qkv = ops.reshape(qkv, (b, n, 3, self.d))
        q, k, v = ops.split(qkv, axis=2)
        # (b, n, d) -> (b, n, h, e)
        q, k, v = [ops.reshape(t, (b, n, self.h, self.e)) for t in (q, k, v)]
        # (b, n, h, e) -> (b, h, n, e)
        q, k, v = [ops.transpose(t, (2, 1)) for t in (q, k, v)]
        # (b, h, n, e) -> (b * h, n, e)
        q, k, v = [ops.reshape(t, (b * self.h, n, self.e)) for t in (q, k, v)]
        # Feature map: ReLU, shifted by eps so the normaliser below is never zero.
        q = self.act(q) + eps
        k = self.act(k) + eps
        # (b * h, e, n) @ (b * h, n, e) -> (b * h, e, e)
        kv = ops.matmul(ops.transpose(k, (2, 1)), v)
        # (b * h, n, e) @ (b * h, e, e) -> (b * h, n, e)
        output = ops.matmul(q, kv)
        # Normaliser: q @ (k^T @ 1_n), computed with a column of ones.
        # (n,) -> (1, n, 1) -> (b * h, n, 1)
        ones = init.ones(n, device=output.device, dtype=output.dtype)
        ones = ops.broadcast_to(ops.reshape(ones, (1, n, 1)), (k.shape[0], n, 1))
        # (b * h, e, n) @ (b * h, n, 1) -> (b * h, e, 1)
        t1 = ops.matmul(ops.transpose(k, (2, 1)), ones)
        # (b * h, n, e) @ (b * h, e, 1) -> (b * h, n, 1)
        t2 = ops.matmul(q, t1)
        # (b * h, n, e) / (b * h, n, 1) -> (b * h, n, e)
        output = ops.divide(output, ops.broadcast_to(t2, output.shape))
        # (b * h, n, e) -> (b, h, n, e)
        output = ops.reshape(output, (b, self.h, n, self.e))
        # (b, h, n, e) -> (b, n, h, e)
        output = ops.transpose(output, (2, 1))
        # (b, n, h, e) -> (b, n, d)
        output = ops.reshape(output, (b, n, self.d))
        # (b, n, d) -> (b, n, d)
        output = self.out(output)

        return output
    
```


# Linear Transformer.
If a **LinearAttention** is followed by an **FFN**, a Linear Transformer is formed, where **FFN** is a mapping of the following form and $\sigma$ is the activation function:
$$
\mathrm{FFN}(\mathbf X) = \sigma(\mathbf X \mathbf W_1) \mathbf W_2.
$$

## Implementation
In this part, we give the implementation of **FFN** and **LinearTransformer**; we again use ReLU as the activation function $\sigma$. The code is also available at `./python/needle/linear_transformer.py`:

```python
class FFN(Module):
    """Feed-forward block: ``Linear3D(d, 2 * d)`` -> ``ReLU`` -> ``Linear3D(2 * d, d)``."""

    def __init__(self, d, device=None, dtype="float32"):
        super().__init__()
        hidden = 2 * d
        # Layers are created in the order they are applied.
        layers = [
            Linear3D(d, hidden, device=device, dtype=dtype),
            ReLU(),
            Linear3D(hidden, d, device=device, dtype=dtype),
        ]
        self.module = Sequential(*layers)

    def forward(self, x):
        return self.module(x)
    
class LinearTransformer(Module):
    """One transformer layer built from linear attention and an FFN.

    The layer chains two ``Residual`` blocks: the first wraps
    ``LayerNorm1d`` followed by ``LinearAttention``, the second wraps
    ``LayerNorm1d`` followed by ``FFN``.
    """

    def __init__(self, d, h, device=None, dtype="float32"):
        super().__init__()
        common = {"device": device, "dtype": dtype}
        # Sub-modules are created in the same order in which they are applied.
        attention_block = Residual(
            Sequential(
                LayerNorm1d(d, **common),
                LinearAttention(d, h, **common),
            )
        )
        ffn_block = Residual(
            Sequential(
                LayerNorm1d(d, **common),
                FFN(d, **common),
            )
        )
        self.module = Sequential(attention_block, ffn_block)

    def forward(self, x):
        return self.module(x)
    
```

## Linear Vit
Vit is a visual model proposed by Google. If you replace the **Transformer** with **LinearTransformer**, you can get `LinearVit`. Here we also made the following modifications to facilitate implementation:
- use convolution to implement patchify;
- use mean pooling instead of a cls token.

We list the code below, which is also available at `./apps/linear_vit.py`:
```python
class LinearVit(ndl.nn.Module):
    """Image classifier: patch embedding, one linear transformer layer, mean
    pooling, and a two-layer classification head with 10 outputs.
    """

    def __init__(self, d=32, h=2, device=None, dtype="float32"):
        super().__init__()
        common = {"device": device, "dtype": dtype}
        self.patch_embedding = lt.PatchEmbedding(3, d, **common)
        self.linear_transformer = nn.Sequential(
            lt.LinearTransformer(d, h, **common),
        )
        self.mean = lt.Mean()
        self.linear = nn.Sequential(
            nn.Linear(d, 128, **common),
            nn.ReLU(),
            nn.Linear(128, 10, **common),
        )

    def forward(self, x):
        # Stages in application order, with the shapes noted by the author:
        #   patch_embedding:    b, d, h, w       -> b, h1 * w1, d1
        #   linear_transformer: b, h1 * w1, d1   -> b, h1 * w1, d1
        #   mean:               b, h1 * w1, d1   -> b, d1
        #   linear:             b, d1            -> b, m
        stages = (
            self.patch_embedding,
            self.linear_transformer,
            self.mean,
            self.linear,
        )
        for stage in stages:
            x = stage(x)

        return x
```

# Train a model on Cifar-10 dataset

Finally, let's train a classification model on the CIFAR-10 dataset. The accuracy is about 41.5% after 10 epochs.

In [None]:
import needle as ndl
import numpy as np

from apps.linear_vit import LinearVit

# Seed NumPy's global random number generator.
np.random.seed(0)

# Backend used for the model and the data tensors below. The CPU alternative
# is kept commented out; swap the two lines to run without a GPU.
# device = ndl.cpu()
device = ndl.cuda()

def epoch_general(
    dataloader, model, loss_fn=ndl.nn.SoftmaxLoss(), opt=None, device=None
):
    """Run one pass over ``dataloader``, training when ``opt`` is given.

    With a truthy ``opt`` the model is switched to train mode, each batch is
    followed by a backward pass and an optimizer step, and a progress line is
    printed every 10th batch. Otherwise the model is switched to eval mode
    and only the metrics are accumulated.

    Returns:
        A pair ``(accuracy, average_loss)``, each divided by the number of
        samples seen.
    """
    training = bool(opt)
    if training:
        model.train()
    else:
        model.eval()

    n_correct = 0
    loss_sum = 0
    n_seen = 0
    for step, batch in enumerate(dataloader):
        if training:
            opt.reset_grad()
        inputs, labels = batch
        inputs = ndl.Tensor(inputs, device=device)
        labels = ndl.Tensor(labels, device=device)
        logits = model(inputs)
        n_seen += inputs.shape[0]
        predictions = np.argmax(logits.numpy(), axis=1)
        n_correct += np.sum(predictions == labels.numpy())
        loss = loss_fn(logits, labels)
        # Weight the batch loss by the batch size so the final division by
        # n_seen yields a per-sample average.
        loss_sum += loss.data.numpy() * labels.shape[0]
        if training:
            loss.backward()
            opt.step()
            if step % 10 == 0:
                print(
                    f"After update {step} times, the loss is {loss_sum / n_seen}, the accuracy is {n_correct / n_seen}."
                )

    return n_correct / n_seen, loss_sum / n_seen


# Load the CIFAR-10 train and test splits extracted by the download cell.
train_data = ndl.data.CIFAR10Dataset("./data/cifar-10-batches-py", train=True)
test_data = ndl.data.CIFAR10Dataset("./data/cifar-10-batches-py", train=False)
print(f"number of train dataset: {train_data.n}")
print(f"number of test dataset: {test_data.n}")
# NOTE(review): the training loader does not shuffle, so every epoch visits the
# batches in the same order -- confirm this is intentional.
train_dataloader = ndl.data.DataLoader(
    dataset=train_data, batch_size=100, shuffle=False
)
test_dataloader = ndl.data.DataLoader(dataset=test_data, batch_size=100, shuffle=False)
model = LinearVit(device=device, dtype="float32")
loss_fn = ndl.nn.SoftmaxLoss()
opt = ndl.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
epochs = 10
for i in range(epochs):
    print(f"Start epoch {i}")
    # Passing the optimizer makes epoch_general train the model.
    train_acc, train_loss = epoch_general(
        dataloader=train_dataloader,
        model=model,
        loss_fn=loss_fn,
        opt=opt,
        device=device,
    )

    # opt=None switches epoch_general to evaluation only.
    test_acc, test_loss = epoch_general(
        dataloader=test_dataloader,
        model=model,
        loss_fn=loss_fn,
        opt=None,
        device=device,
    )

    # i is zero-based, so i + 1 epochs have been completed at this point; the
    # old message reported "0 epochs" after the first one.
    print(
        f"After training {i + 1} epochs, the loss is {test_loss}, the accuracy is {test_acc}."
    )