In [1]:
import torch
import torch.nn as nn

class EmbeddingModel(nn.Module):
    """Thin module wrapping a single ``nn.Embedding`` lookup table.

    Args:
        vocab_size: Number of rows in the table; valid indices are
            ``0 .. vocab_size - 1``.
        embedding_dim: Length of each embedding vector (default 2).
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 2):
        super().__init__()
        # nn.Embedding randomly initialises its weight matrix on construction.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding row for every index in ``x``.

        The result keeps the shape of ``x`` and appends a trailing
        dimension of size ``embedding_dim``.
        """
        vectors = self.embedding(x)
        return vectors

# Build a toy model over a two-token vocabulary (embedding_dim keeps its default of 2).
vocab_size = 2
model = EmbeddingModel(vocab_size)

# Dump the freshly initialised lookup table (header line, then the weight Parameter).
print("Initial embedding weights:", model.embedding.weight, sep="\n")

# Helper: run an embedding module on index data without recording gradients.
def get_embedding(input_data, embedding_model=None):
    """Return ``embedding_model(input_data)`` computed under ``torch.no_grad()``.

    Args:
        input_data: Input passed straight to the module; for
            ``EmbeddingModel`` this is an integer index tensor.
        embedding_model: Module to run. When omitted (``None``), the
            module-level ``model`` is used, so existing
            ``get_embedding(x)`` calls behave exactly as before.

    Returns:
        The module output. Because it is produced inside ``torch.no_grad()``,
        it carries no autograd history (``requires_grad`` is False).
    """
    # Resolve the default lazily: the global ``model`` is only looked up
    # when the caller did not supply a module explicitly.
    net = model if embedding_model is None else embedding_model
    with torch.no_grad():
        return net(input_data)

# Test with different inputs: a single index, a 1-D batch and a 2-D batch.
print("\nTesting with different inputs:")
# torch.tensor(..., dtype=torch.long) is the recommended replacement for the
# legacy torch.LongTensor type constructor; the tensors (and their printed
# form) are identical.
inputs = [
    torch.tensor([0], dtype=torch.long),
    torch.tensor([1], dtype=torch.long),
    torch.tensor([0, 1], dtype=torch.long),
    torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
]

for i, input_data in enumerate(inputs):
    embedding = get_embedding(input_data)
    # Report 1-based case numbers; the output shape is the input shape
    # plus a trailing embedding dimension.
    print(f"\nInput {i + 1}: {input_data}")
    print(f"Embedding: {embedding}")
    print(f"Embedding shape: {embedding.shape}")

Initial embedding weights:
Parameter containing:
tensor([[-0.6822, -0.9730],
        [-0.1405, -0.0839]], requires_grad=True)

Testing with different inputs:

Input 1: tensor([0])
Embedding: tensor([[-0.6822, -0.9730]])
Embedding shape: torch.Size([1, 2])

Input 2: tensor([1])
Embedding: tensor([[-0.1405, -0.0839]])
Embedding shape: torch.Size([1, 2])

Input 3: tensor([0, 1])
Embedding: tensor([[-0.6822, -0.9730],
        [-0.1405, -0.0839]])
Embedding shape: torch.Size([2, 2])

Input 4: tensor([[0, 1],
        [1, 0]])
Embedding: tensor([[[-0.6822, -0.9730],
         [-0.1405, -0.0839]],

        [[-0.1405, -0.0839],
         [-0.6822, -0.9730]]])
Embedding shape: torch.Size([2, 2, 2])
