In [146]:
import torch
import torch.nn as nn

In [147]:
# Notebook-wide configuration read by CompetingModel and the training loop.

# Number of rows in each embedding table and number of classes in each output
# head; also the exclusive upper bound for the randomly sampled input ids.
num_embeddings = 256
# Width of each embedding vector (the two embeddings are concatenated, so the
# first linear layer receives 2 * d_model features).
d_model = 128
# Output width of the two feed-forward linear layers.
hidden_size = 128
# Number of (left, right) id pairs sampled per optimisation step.
batch_size = 1024
# dtype passed to every embedding, norm and linear layer of the model.
dtype = torch.float32

In [148]:
@torch.no_grad()
def l2norm(x: torch.Tensor):
    """Return the Euclidean (L2) norm of ``x`` taken over its last dimension.

    The computation runs with gradient tracking disabled, so the result is
    not part of the autograd graph and is suitable for logging only.
    """
    return torch.linalg.vector_norm(x, ord=2, dim=-1)

def init_parameters(m: nn.Module, std: float = 0.02, cutoff_factor: float = 3.0) -> None:
    """Re-initialise the ``weight`` and ``bias`` of a single module in place.

    The weight (if the module has one) is drawn from a zero-mean truncated
    normal with standard deviation ``std``, truncated to
    ``[-cutoff_factor * std, +cutoff_factor * std]``. The bias (if the module
    has one) is set to zero. Modules whose ``weight``/``bias`` attribute is
    missing or ``None`` are left untouched for that attribute.

    Args:
        m: Module to initialise. Only its own ``weight`` and ``bias``
            attributes are touched; child modules are not visited.
        std: Standard deviation of the truncated normal. Must be positive.
        cutoff_factor: Truncation bound expressed in multiples of ``std``.
            Must be positive.

    Raises:
        ValueError: If ``std`` or ``cutoff_factor`` is not strictly positive.
    """
    if std <= 0:
        raise ValueError(f"std must be positive, got {std}")
    if cutoff_factor <= 0:
        raise ValueError(f"cutoff_factor must be positive, got {cutoff_factor}")

    cutoff = cutoff_factor * std
    # getattr with a None default covers both "no such attribute" and
    # "attribute registered as None" (e.g. a layer built with bias=False).
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    with torch.no_grad():
        if weight is not None:
            nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-cutoff, b=cutoff)
        if bias is not None:
            nn.init.zeros_(bias)

In [149]:
class CompetingModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.SiLU()
        self.left_emb = nn.Embedding(num_embeddings, d_model, dtype=dtype)
        self.right_emb = nn.Embedding(num_embeddings, d_model, dtype=dtype)
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False, bias=False, dtype=dtype)
        self.ff_in = nn.Linear(d_model * 2, hidden_size, bias=False, dtype=dtype)
        self.ff_out = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype)
        self.left_out = nn.Linear(hidden_size, num_embeddings, bias=False, dtype=dtype)
        self.right_out = nn.Linear(hidden_size, num_embeddings, bias=False, dtype=dtype)
        
    def reset_parameters(self):
        init_parameters(self.left_emb)
        init_parameters(self.right_emb)
        init_parameters(self.norm)
        init_parameters(self.ff_in)
        init_parameters(self.ff_out)
        init_parameters(self.left_out)
        init_parameters(self.right_out)
        
    def forward(self, left_input: torch.Tensor, right_input: torch.Tensor):
        l1 = self.left_emb(left_input)
        r1 = self.right_emb(right_input)
        x1 = torch.concat([l1, r1], dim=-1)
        x2 = self.ff_in(x1)
        x3 = self.act(x2)
        x4 = self.ff_out(x3)
        l2 = self.left_out(x4)
        r2 = self.right_out(x4)
        return ((l2, r2), (
            l2norm(l1),
            l2norm(r1),
            l2norm(x1),
            l2norm(x2),
            l2norm(x3),
            l2norm(x4),
            l2norm(l2),
            l2norm(r2),            
        ))

# Seed before building the model: parameter initialisation and the random
# batches drawn below all consume the global torch RNG.
torch.manual_seed(1492)
model = CompetingModel()
model.reset_parameters()

optim = torch.optim.AdamW(model.parameters(), lr=1e-6)
loss = nn.CrossEntropyLoss()

for step in range(1_000_000):
    # Draw a fresh batch of independent left/right ids. They serve as both
    # the model inputs and the classification targets for the two heads.
    left = torch.randint(0, num_embeddings, (batch_size,))
    right = torch.randint(0, num_embeddings, (batch_size,))

    optim.zero_grad()
    ys, norms = model(left, right)
    left_loss = loss(ys[0], left)
    right_loss = loss(ys[1], right)
    total_loss = left_loss + right_loss
    total_loss.backward()
    optim.step()

    # Log only on every 1000th step (step 0 included).
    if step % 1000 != 0:
        continue
    norm_strings = [format(norm.mean().item(), ".3f") for norm in norms]
    print(f"loss: {total_loss.item():.3f}\tnorms: {norm_strings}")

loss: 11.090	norms: ['0.222', '0.223', '0.315', '0.070', '0.035', '0.008', '0.002', '0.002']
loss: 11.090	norms: ['0.222', '0.222', '0.314', '0.070', '0.035', '0.008', '0.002', '0.003']
loss: 11.090	norms: ['0.224', '0.225', '0.318', '0.072', '0.036', '0.008', '0.003', '0.003']
loss: 11.090	norms: ['0.226', '0.226', '0.320', '0.075', '0.037', '0.009', '0.003', '0.003']
loss: 11.089	norms: ['0.228', '0.229', '0.323', '0.079', '0.039', '0.010', '0.004', '0.004']
loss: 11.089	norms: ['0.232', '0.233', '0.329', '0.083', '0.042', '0.012', '0.004', '0.004']
loss: 11.088	norms: ['0.235', '0.236', '0.334', '0.089', '0.044', '0.013', '0.005', '0.005']
loss: 11.088	norms: ['0.240', '0.241', '0.341', '0.096', '0.048', '0.016', '0.007', '0.007']
loss: 11.087	norms: ['0.245', '0.246', '0.347', '0.103', '0.051', '0.018', '0.008', '0.008']
loss: 11.086	norms: ['0.249', '0.252', '0.355', '0.111', '0.055', '0.021', '0.010', '0.010']
loss: 11.085	norms: ['0.256', '0.256', '0.362', '0.120', '0.060', '0.0