In [1]:
import torch
import torch.nn as nn

In [2]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, 
                device: torch.device | None = None, 
                dtype: torch.dtype | None = None):
        super().__init__()
        W = torch.empty((in_features, out_features), dtype=dtype, device=device)
        W = nn.init.trunc_normal_(W, mean=0, std=2/(in_features+out_features))
        self.W = nn.Parameter(W, requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.W

In [3]:
torch.manual_seed(0)  # seed once so random init and later torch.rand draws reproduce on Restart & Run All
linear = Linear(32, 64)

In [6]:
# The layer registers a single parameter under the key "W".
linear.state_dict().keys()

odict_keys(['W'])


In [9]:
# Random replacement weights matching the current parameter's shape, dtype and device.
weights = torch.rand_like(linear.W)
linear.load_state_dict({'W': weights})

<All keys matched successfully>

In [10]:
linear_out = linear(torch.rand(8, 32))
assert linear_out.shape == (8, 64), linear_out.shape
linear_out.shape

torch.Size([8, 64])

In [15]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, 
                 device: torch.device | None = None,
                 dytpe: torch.dtype | None = None):
        super().__init__()
        embed_matrix = torch.empty((num_embeddings, embedding_dim), dtype=dytpe, device=device)
        embed_matrix = nn.init.trunc_normal_(embed_matrix, mean=0, std=1, a=-3, b=3)
        self.embed_matrix = nn.Parameter(embed_matrix, requires_grad=True)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_matrix[token_ids]

In [16]:
embed = Embedding(num_embeddings=256, embedding_dim=32)

In [17]:
# Draw ids from the table's own row count so every id is a valid index.
x = torch.randint(low=0, high=embed.embed_matrix.shape[0], size=(8, 64))

In [18]:
embedded = embed(x)
assert embedded.shape == (*x.shape, embed.embed_matrix.shape[1]), embedded.shape
embedded.shape

torch.Size([8, 64, 32])