In [1]:
import torch

# Device used by every cell below. Prefer the GPU when one is available and
# fall back to the CPU so the notebook still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:

# Dummy embedding table: 52000 rows (token ids), each a 768-dimensional vector.
embeddings = torch.nn.Embedding(52000, 768).to(device=DEVICE)

# Dummy synthetic tokens: 120 sentences of 128 token ids each.
# torch.randint's upper bound is exclusive, so the table size itself is used
# as `high`; this lets every valid id (including the last one) be sampled.
synthetic_tokens = torch.randint(0, embeddings.num_embeddings, (120, 128), device=DEVICE)

# Compute synthetic data.
# The embedded tensor (not the token ids) is the synthetic data that gets updated.
synthetic_data = embeddings(synthetic_tokens)
synthetic_data.shape, embeddings.weight.shape

(torch.Size([120, 128, 768]), torch.Size([52000, 768]))

### Quick Decoding Scheme for Embedding to Token

In [8]:
@torch.no_grad()
def decode_embeddings(embeddings: torch.nn.Embedding, embedded_data: torch.Tensor) -> torch.Tensor:
    """Map embedded vectors back to token ids by nearest-neighbour lookup.

    For every position, the Euclidean (p=2) distance to each row of
    ``embeddings.weight`` is computed and the index of the closest row is
    taken as the decoded token id.

    The lookup runs under ``torch.no_grad()``: ``argmin`` is not
    differentiable, so tracking gradients through the distance computation
    would only cost memory.

    Args:
        embeddings: Embedding table whose rows are the candidate token vectors.
        embedded_data: Tensor indexed as (sentence, position, feature). The
            feature dimension must match the embedding dimension.

    Returns:
        A ``torch.long`` tensor of shape (num_sentences, sentence_len) holding
        the index of the nearest embedding row for each position. It lives on
        the same device as ``embeddings.weight``.
    """
    weight = embeddings.weight
    # Follow the embedding table's device instead of a notebook-level global,
    # so the function works wherever the table lives.
    device = weight.device

    num_sentences, sentence_len = embedded_data.shape[0], embedded_data.shape[1]
    sentences = torch.zeros(num_sentences, sentence_len, device=device, dtype=torch.long)

    # Decode one sentence at a time so that only a single
    # (sentence_len, num_embeddings) distance matrix is alive at once.
    for i in range(num_sentences):
        distances = torch.cdist(embedded_data[i].to(device), weight, p=2)
        sentences[i] = distances.argmin(-1)

    return sentences

In [9]:
# Placeholder: any operations on / updates to `synthetic_data` would go here.

# Map the (possibly updated) embedded data back to discrete token ids.
new_synthetic_tokens = decode_embeddings(embeddings, synthetic_data)

# Sanity check for the non-updated case: decoding the untouched embeddings
# should give back exactly the tokens they were built from.
(synthetic_tokens == new_synthetic_tokens).all()

tensor(True, device='cuda:0')

In [10]:
# Display the token ids recovered by decode_embeddings.
new_synthetic_tokens

tensor([[32203,  3636,  5115,  ..., 18433, 18306, 10781],
        [28191, 34624, 15761,  ..., 25586,  3374, 44167],
        [49045, 48424, 42490,  ..., 24051,  4248, 10695],
        ...,
        [ 5118,  1574, 18591,  ..., 37280, 39965, 34653],
        [24316, 30113,  8084,  ..., 10403, 17517, 27080],
        [ 1594, 33412, 32000,  ...,  5732, 25073, 49251]], device='cuda:0')

## Soft One-Hot Encoding

In [2]:
# Dummy soft tokens: instead of one integer id per position, each position
# holds a dense, trainable vector of 50257 weights drawn uniformly from [0, 1).
# Dimensions: 120 sentences x 128 positions x 50257 entries.
# NOTE(review): at the default float32 dtype this is roughly 2.9 GiB on DEVICE.
synthetic_tokens = torch.rand(120, 128, 50257, device=DEVICE, requires_grad=True)


In [9]:
# The soft tokens come from the previous cell; make sure they track gradients.
synthetic_tokens.requires_grad = True

# Dummy embedding table: 50257 rows, each a 768-dimensional vector.
embeddings = torch.nn.Embedding(50257, 768).to(DEVICE)

# Soft embedding lookup: multiplying the per-position weights by the table
# yields, for each position, a weighted sum of the table's rows.
em = synthetic_tokens @ embeddings.weight
em.shape

torch.Size([120, 128, 768])

tensor([0.7802, 0.4084, 0.3867,  ..., 0.8092, 0.2666, 0.3060], device='cuda:0',
       grad_fn=<SelectBackward0>)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.88 GiB (GPU 0; 4.00 GiB total capacity; 3.28 GiB already allocated; 0 bytes free; 3.32 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [10]:
# Ask PyTorch's caching allocator to release cached GPU memory it is not using.
# NOTE(review): presumably run to recover from the out-of-memory error above;
# memory held by tensors that are still referenced is not released by this call.
torch.cuda.empty_cache()

In [7]:
# Hard decode: for every position, take the index of the largest weight
# along the last dimension.
synthetic_tokens.argmax(dim=-1)

tensor([[16430, 26160, 33138,  ..., 45060, 43348, 23012],
        [10364, 43655,  6679,  ..., 10967, 40995, 12474],
        [ 9899,  9064, 12806,  ..., 45797, 30795, 42207],
        ...,
        [12166, 16661,  9282,  ..., 16306, 25395, 15862],
        [24887, 23780, 23485,  ..., 10882, 44750,  5158],
        [40683, 35154, 14796,  ...,  9812, 17618, 31652]], device='cuda:0')