In [1]:
# config = {
#     "nodes": {
#         "node:0": "127.0.0.1:9327",
#         "node:1": "127.0.0.1:9328",
#     },
#     "devices": {
#         "SPU": {
#             "kind": "SPU",
#             "config": {
#                 "node_ids": [
#                     "node:0",
#                     "node:1"
#                 ],
#                 "spu_internal_addrs": [
#                     "127.0.0.1:9327",
#                     "127.0.0.1:9328"
#                 ],
#                 "runtime_config": {
#                     "protocol": "SEMI2K",
#                     "field": "FM64",
#                     "enable_pphlo_profile": true,
#                     "enable_pphlo_trace": true,
#                     "enable_hal_profile": true
#                 }
#             }
#         },
#         "P1": {
#             "kind": "PYU",
#             "config": {
#                 "node_id": "node:0"
#             }
#         }
#     }
# }

def twopc_def():
    """Build the node and device definitions for a local two-party setup.

    Returns:
        tuple: ``(nodes_def, devices_def)``.

        * ``nodes_def`` maps the node ids ``"node:0"`` and ``"node:1"`` to
          loopback addresses on ports 9327 and 9328.
        * ``devices_def`` declares an ``"SPU"`` device spanning both nodes
          (``SEMI2K`` protocol, ``FM64`` field, profile/trace flags on) and
          a ``"P1"`` device of kind ``"PYU"`` bound to ``"node:0"``.

        Fresh dict/list objects are created on every call, so callers may
        mutate the result without affecting later calls.
    """
    node_ids = ["node:0", "node:1"]
    addresses = ["127.0.0.1:9327", "127.0.0.1:9328"]

    runtime_config = {
        "protocol": "SEMI2K",
        "field": "FM64",
        "enable_pphlo_profile": True,
        "enable_pphlo_trace": True,
        "enable_hal_profile": True,
    }
    spu_device = {
        "kind": "SPU",
        "config": {
            # Copies keep the device config independent of the local lists.
            "node_ids": list(node_ids),
            "spu_internal_addrs": list(addresses),
            "runtime_config": runtime_config,
        },
    }
    pyu_device = {
        "kind": "PYU",
        "config": {"node_id": node_ids[0]},
    }

    nodes_def = dict(zip(node_ids, addresses))
    devices_def = {"SPU": spu_device, "P1": pyu_device}
    return nodes_def, devices_def


# IPython shell escape: runs `pip install flax` in a subshell.
# NOTE(review): flax is not imported by any cell visible here — confirm the install is still needed.
!pip install flax

In [3]:
import jax.numpy as jnp
import numpy as np
from torch.nn import Module, Conv2d, Linear, Softmax, ReLU,LayerNorm, Sequential
import spu.utils.distributed as ppd
import time

class PatchEmbedding(Module):
    """Project non-overlapping image patches to ``embed_dim`` features.

    The projection is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``, so every output position corresponds to one patch.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each (square) patch.
        in_channels: number of input channels given to the convolution.
        embed_dim: number of output channels of the projection.
        norm_layer: optional callable ``norm_layer(embed_dim)`` returning the
            normalisation applied to the projected patches; identity if None.
        flatten: if True, the two trailing spatial axes of the projection are
            merged and moved in front of the channel axis, giving
            ``(..., num_patches, embed_dim)``.

    Attributes:
        grid_size: ``img_size // patch_size``.
        num_patches: ``grid_size * grid_size``.
    """

    def __init__ (
        self,
        img_size,
        patch_size,
        in_channels,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        self.flatten = flatten

        self.proj = Conv2d(
            in_channels,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        # Identity fallback keeps __call__ free of a None check.
        self.norm_layer = norm_layer(embed_dim) if norm_layer else lambda x: x

    def __call__ (self, inputs):
        """Return the patch embeddings of ``inputs``.

        The projection result is returned (previously the raw ``inputs`` were
        returned and the projection was discarded), with the ``flatten`` and
        ``norm_layer`` options applied.
        """
        x = self.proj(inputs)
        if self.flatten:
            # Merge the two trailing spatial axes, then put the patch axis
            # before the channel axis: (..., C, H, W) -> (..., H*W, C).
            x = x.reshape(*x.shape[:-2], -1).swapaxes(-2, -1)
        return self.norm_layer(x)

class GeLU(Module):
    """Gaussian Error Linear Unit, ``x * cdf(x)``.

    Args:
        approximate: if True (default) the cdf is the tanh approximation
            ``0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``;
            otherwise the erf form ``0.5 * (1 + erf(x / sqrt(2)))`` is used.

    The math is evaluated with ``jax.numpy``, so ``x`` must be something
    ``jnp.tanh`` / ``erf`` accept.
    """

    def __init__ (self, approximate = True):
        super().__init__()
        self.approximate = approximate

    def __call__ (self, x):
        """Apply the activation element-wise and return the result."""
        if self.approximate:
            # A plain Python float is used for the constant: it does not
            # force a dtype promotion and is not truncated to 0 for integer
            # inputs, unlike casting it to x.dtype.
            sqrt_2_over_pi = float(np.sqrt(2.0 / np.pi))
            cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
        else:
            # Previously this path fell through and returned None.
            from jax.scipy.special import erf

            cdf = 0.5 * (1.0 + erf(x / float(np.sqrt(2.0))))
        return x * cdf


class MultiHeadAttention(Module):
    """Scaled dot-product self-attention with linear q/k/v/output projections.

    Args:
        embed_dim: size of the last axis of the input and of the output.
        num_heads: only used to derive ``head_dim = embed_dim // num_heads``,
            the scaling denominator; the projections are NOT split into
            separate heads.
        use_bias: whether the four ``Linear`` layers carry a bias.
        kdim: output size of the query and key projections
            (defaults to ``embed_dim``).
        vdim: output size of the value projection (defaults to ``embed_dim``).
    """

    # Fixed: the constructor was named ``init``, so it never ran on
    # construction and keyword arguments were rejected by the base class.
    def __init__ (
        self,
        embed_dim,
        num_heads,
        use_bias=True,
        kdim=None,
        vdim=None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.use_bias = use_bias
        self.head_dim = embed_dim // num_heads
        # Divisibility (head_dim * num_heads == embed_dim) is deliberately
        # not enforced; the original assertion was left disabled.
        self.q_proj = Linear(embed_dim, self.kdim, bias=use_bias)
        self.k_proj = Linear(embed_dim, self.kdim, bias=use_bias)
        self.v_proj = Linear(embed_dim, self.vdim, bias=use_bias)
        self.out_proj = Linear(self.vdim, embed_dim, bias=use_bias)
        self.softmax = Softmax(dim=-1)
        # Kept for attribute compatibility; not used by __call__.
        self.relu = ReLU()

    def __call__ (self, x):
        """Return ``out_proj(softmax(q k^T / sqrt(head_dim)) v)`` for ``x``.

        ``x`` is unpacked as ``(batch, seq_len, features)``.
        """
        bs, seq_len, _ = x.shape
        # q, k, v projections
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # softmax(q k^T / sqrt(head_dim)) v
        # Operators and array methods are used instead of jax.numpy calls so
        # the math runs on whatever array type the projection layers return.
        attn_logits = q @ k.swapaxes(-2, -1)
        attn_logits = attn_logits / (self.head_dim ** 0.5)
        attn = self.softmax(attn_logits)
        x = attn @ v
        x = x.reshape(bs, seq_len, -1)
        x = self.out_proj(x)
        return x
    # MLP layer

# Transformer
class Transformer(Module):
    """One transformer block: attention and MLP, each with a residual add.

    ``__call__`` computes ``x = x + attn(norm1(x))`` followed by
    ``x = x + mlp(norm2(x))``.

    Args:
        embed_dim: feature size of the block's input and output.
        num_heads: forwarded to ``MultiHeadAttention``.
        mlp_ratio: hidden width of the MLP is ``int(embed_dim * mlp_ratio)``.
        qkv_bias: forwarded to ``MultiHeadAttention`` as ``use_bias``.
        act_layer: zero-argument factory for the MLP activation.
        norm_layer: factory called as ``norm_layer(embed_dim)`` for both norms.
        kdim: forwarded to ``MultiHeadAttention``.
        vdim: forwarded to ``MultiHeadAttention``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        mlp_ratio=4,
        qkv_bias=False,
        act_layer=GeLU,
        norm_layer=LayerNorm,
        kdim=None,
        vdim=None,
    ):
        # Must run before any sub-module is assigned to ``self``.
        super().__init__()
        # Fixed: was stored as ``norml`` while __call__ reads ``norm1``.
        self.norm1 = norm_layer(embed_dim)
        self.attn = MultiHeadAttention(
            embed_dim,
            num_heads,
            use_bias=qkv_bias,
            kdim=kdim,  # fixed keyword typo ``kd_im``
            vdim=vdim,
        )
        self.norm2 = norm_layer(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = Sequential(
            Linear(embed_dim, hidden_dim),
            # Honour the ``act_layer`` argument instead of a hard-coded GeLU.
            act_layer(),
            Linear(hidden_dim, embed_dim),
        )

    def __call__(self, inputs):
        """Apply the attention and MLP sub-layers with residual connections."""
        inputs = inputs + self.attn(self.norm1(inputs))
        inputs = inputs + self.mlp(self.norm2(inputs))
        return inputs

def main():
    """Time ``MultiHeadAttention`` on the ``"SPU"`` device for 10 random inputs.

    Steps:
      1. Take the definitions from ``twopc_def()``, switch on
         ``enable_action_trace`` in the SPU runtime config, and pass both to
         ``ppd.init``.
      2. Build a ``MultiHeadAttention`` with ``kdim`` and ``vdim`` set to
         ``embed_dim // num_heads``.
      3. Per iteration: draw a random ``(1, num_tokens, embed_dim)`` input,
         run the module directly, place the input through device ``"P1"``,
         then print the elapsed seconds for the ``"SPU"`` call plus
         ``ppd.get`` on its result.

    NOTE(review): the layer classes imported at the top of this file come
    from ``torch.nn`` while ``feat`` is a NumPy array — confirm the module
    accepts this input both directly and on the SPU device.
    """
    # init spu env
    nodes_def, devices_def = twopc_def()
    devices_def['SPU']['config']['runtime_config']["enable_action_trace"] = True
    ppd.init(nodes_def, devices_def)

    # hyperparameters
    embed_dim = 65
    num_tokens = 65
    num_heads = 4

    # create an obj
    func = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        use_bias=False,
        kdim=embed_dim // num_heads,  # fixed keyword typo ``kvim``
        vdim=embed_dim // num_heads,
    )

    # assign device and func
    # Fixed: the wrapped callable must accept the input it is later called
    # with; ``lambda: func`` took no argument and only returned the module.
    spu_func = ppd.device("SPU")(lambda x: func(x))

    # mpc computing
    for _ in range(10):
        feat = np.random.randn(1, num_tokens, embed_dim)
        # Direct call on the raw input; the result is not compared or printed.
        plain_out = func(feat)

        enc_feat = ppd.device("P1")(lambda: feat)()
        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments.
        start = time.perf_counter()
        enc_out = spu_func(enc_feat)
        enc_out = ppd.get(enc_out)

        print(time.perf_counter() - start)

