In [1]:
import torch
import torch.nn as nn

In [2]:
class MyModel(nn.Module):
    """Toy classifier whose output projection is weight-tied to its input embedding.

    Pipeline: token id -> embedding -> 2-layer ReLU bottleneck -> logits.
    ``head.weight`` and ``embs.weight`` are the *same* ``nn.Parameter``, so any
    optimizer step updates both views at once.

    Args:
        num_classes: Size of the vocabulary; also the number of output logits.
        emb_dim: Width of the embedding vectors (and of the bottleneck output).
        hidden_dim: Width of the hidden layer inside the bottleneck.

    The defaults reproduce the original fixed-size model (3 / 32 / 16).
    """

    def __init__(self, num_classes=3, emb_dim=32, hidden_dim=16):
        super().__init__()

        # Layers are created in a fixed order (embs, bottleneck, head) so that
        # random initialisation consumes the RNG in a stable sequence.
        self.embs = nn.Embedding(num_classes, emb_dim)
        self.bottleneck = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim)
        )
        self.head = nn.Linear(emb_dim, num_classes, bias=False)
        # Weight tying: nn.Embedding stores its table as (num_classes, emb_dim)
        # and nn.Linear stores its weight as (out_features, in_features) =
        # (num_classes, emb_dim), so the shapes line up without a transpose.
        # Assigning the Parameter object (not a copy) is what shares it.
        self.head.weight = self.embs.weight

    def forward(self, x):
        """Map integer token ids to unnormalised class logits.

        Args:
            x: Integer tensor of token ids in ``[0, num_classes)``.

        Returns:
            Float tensor of logits with a trailing dimension of ``num_classes``.
        """
        x = self.embs(x)
        x = self.bottleneck(x)
        x = self.head(x)
        return x

In [3]:
# Fix the RNG so the sampled inputs and the initial weights are identical on
# every Restart & Run All.
torch.manual_seed(0)

n_classes = 3  # must equal the number of rows in MyModel's embedding table

# Toy task: predict the "previous" class id, wrapping around (0 -> 2, 1 -> 0, 2 -> 1).
x = torch.randint(n_classes, size=(55, ))
# `%` on tensors follows Python semantics (result has the sign of the divisor),
# so (-1) % 3 == 2 and the wrap-around needs no special case.
y = (x - 1) % n_classes

model = MyModel()
loss = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

# embs.weight and head.weight are the same Parameter (tied in MyModel.__init__),
# so the two tensors printed below are expected to be identical.
print('Weights before training')
print(model.embs.weight)
print(model.head.weight)

for _ in range(100):
    y_pred = model(x)
    loss_val = loss(y_pred, y)
    loss_val.backward()
    opt.step()
    # Clear gradients after the step so the next backward() starts from zero.
    opt.zero_grad()

# The tied weights must still match each other after the updates.
print('Weights after training')
print(model.embs.weight)
print(model.head.weight)

Weights before training
Parameter containing:
tensor([[ 1.4653, -0.0140, -0.7266, -0.6036,  0.4302,  0.6452, -0.0399, -1.1336,
         -0.7822,  1.3538, -0.6181, -0.3970,  0.9674,  2.1213,  0.9207,  1.1345,
          0.6180,  0.3821, -0.0822,  0.4593,  0.1340,  0.3734, -0.1101,  0.4247,
          1.2964,  0.7381,  0.7128, -0.8335,  0.8710,  0.2039,  0.4372, -0.6724],
        [-3.0737,  0.1213, -0.2475, -0.7274, -0.6475,  0.3493,  0.0206,  0.2845,
         -0.6437, -0.4328,  1.6070,  0.9812, -1.5739,  0.7342, -1.4385,  0.2355,
         -1.0646,  0.0409, -0.4948,  0.3094, -1.8474, -0.3741, -0.5255,  0.9813,
          0.0763,  0.2365,  1.0601, -0.0596,  0.3631, -0.3902, -1.0279,  0.6889],
        [-0.5639, -1.1827,  0.2241, -2.3446,  0.2981, -2.3123,  1.2089,  0.3537,
          1.5817, -0.4047, -1.0358,  0.0698,  0.7488, -1.6550, -0.4301, -1.3078,
         -1.3123,  1.0817, -0.4174,  0.3847,  0.3943,  0.8967, -0.7576,  0.8159,
          1.0181, -0.2359,  1.7895, -1.2535, -1.4701,  0.6914