In [108]:
import torch
import torch.nn as nn

In [109]:
num_embeddings = 64
d_model = 128
batch_size = 256
dtype = torch.float32

In [110]:
def l2norm(x: torch.Tensor):
    with torch.no_grad():
        return torch.linalg.vector_norm(x, dim=-1)

def init_parameters(m: nn.Module) -> None:
    std = 0.02
    cutoff = 3 * std
    with torch.no_grad():
        if hasattr(m, "weight") and m.weight is not None:
            nn.init.trunc_normal_(m.weight, mean=0.0, std=std, a=-cutoff, b=cutoff)
        if hasattr(m, "bias") and m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
class CompetingModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.left_emb = nn.Embedding(num_embeddings, d_model, dtype=dtype)
        self.right_emb = nn.Embedding(num_embeddings, d_model, dtype=dtype)
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False, bias=False, dtype=dtype)
        self.left_out = nn.Linear(d_model, num_embeddings, bias=False, dtype=dtype)
        self.right_out = nn.Linear(d_model, num_embeddings, bias=False, dtype=dtype)
        
    def reset_parameters(self):
        init_parameters(self.left_emb)
        init_parameters(self.right_emb)
        init_parameters(self.norm)
        init_parameters(self.left_out)
        init_parameters(self.right_out)
        
    def forward(self, left_input: torch.Tensor, right_input: torch.Tensor):
        l1 = self.left_emb(left_input)
        r1 = self.right_emb(right_input)
        x1 = l1 + r1
        x2 = x1 # self.norm(x1)
        l2 = self.left_out(x2)
        r2 = self.right_out(x2)
        l3 = nn.functional.softmax(l2, dim=-1)
        r3 = nn.functional.softmax(r2, dim=-1)
        return ((l3, r3), (
            l2norm(l1),
            l2norm(r1),
            l2norm(x1),
            l2norm(x2),
            l2norm(l2),
            l2norm(r2),            
        ))
    
model = CompetingModel()
model.reset_parameters()

optim = torch.optim.AdamW(
    model.parameters(),
    lr = 1e-6
)

loss = nn.CrossEntropyLoss()

torch.manual_seed(1492)

for step in range(1000000):
    optim.zero_grad()
    left = torch.randint(0, num_embeddings, (batch_size,))
    right = torch.randint(0, num_embeddings, (batch_size,))
    ys, norms = model(left, right)
    left_loss = loss(ys[0], left)
    right_loss = loss(ys[1], right)
    total_loss = left_loss + right_loss
    total_loss.backward()
    optim.step()
    if step % 1000 == 0:
        norm_strings = [f"{x.mean().item():.3f}" for x in norms]
        print(f"loss: {total_loss.item():.3f}\tnorms: {norm_strings}")

loss: 8.318	norms: ['0.226', '0.221', '0.316', '0.316', '0.049', '0.050']
loss: 8.318	norms: ['0.224', '0.224', '0.318', '0.318', '0.050', '0.051']
loss: 8.318	norms: ['0.226', '0.222', '0.317', '0.317', '0.049', '0.051']
loss: 8.317	norms: ['0.227', '0.223', '0.318', '0.318', '0.052', '0.052']
loss: 8.317	norms: ['0.227', '0.226', '0.321', '0.321', '0.052', '0.053']
loss: 8.317	norms: ['0.231', '0.229', '0.326', '0.326', '0.055', '0.055']
loss: 8.317	norms: ['0.235', '0.230', '0.329', '0.329', '0.057', '0.057']
loss: 8.317	norms: ['0.239', '0.235', '0.335', '0.335', '0.060', '0.060']
loss: 8.317	norms: ['0.241', '0.236', '0.337', '0.337', '0.062', '0.062']
loss: 8.317	norms: ['0.244', '0.240', '0.342', '0.342', '0.065', '0.066']
loss: 8.317	norms: ['0.249', '0.246', '0.349', '0.349', '0.069', '0.069']
loss: 8.317	norms: ['0.253', '0.249', '0.355', '0.355', '0.073', '0.072']
loss: 8.317	norms: ['0.259', '0.255', '0.363', '0.363', '0.076', '0.076']
loss: 8.316	norms: ['0.262', '0.259', 