In [2]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from torch.nn.utils.rnn import pad_sequence
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class BertTranslator(nn.Module):
    def __init__(self, bert_model, target_vocab_size):
        super(BertTranslator, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.decoder = nn.Linear(self.bert.config.hidden_size, target_vocab_size)

    def forward(self, input_ids, attention_mask):
        # BERT Encoder
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        #print(outputs)

        # Only use the [CLS] token representation as the context vector
        context_vector = outputs.last_hidden_state[:, 0, :]
        #print(context_vector.shape)

        # Decoder
        output = self.decoder(context_vector)

        return output

In [4]:
sentences = ["Translate this sentence.", "How are you?", "BERT is powerful."]
translations = ["g1, g4, h4, j5", "t5,k5", "g1,g4,h6"]

# Create a vocabulary and map each unique token to an index
all_tokens = set([token for translation in translations for token in translation.split(',')])
vocab = {token: idx for idx, token in enumerate(all_tokens)}

# Convert translations to numerical indices and pad sequences
target_ids = [torch.tensor([vocab[token] for token in translation.split(',')], dtype=torch.long) for translation in translations]
target_ids_padded = pad_sequence(target_ids, batch_first=True, padding_value=-1)

# Find the maximum length in the batch
max_len = max(len(ids) for ids in target_ids)

# Pad sequences to the maximum length
padded_target_ids = torch.stack([torch.cat([ids, torch.full((max_len - len(ids),), -1, dtype=torch.long)]) for ids in target_ids])

# Load BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize and convert to PyTorch tensors
input_ids = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)['input_ids']

# Create the translation model
target_vocab_size = len(vocab)
translator_model = BertTranslator('bert-base-uncased', target_vocab_size)

# Loss and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # Ignore padding value during loss calculation
optimizer = torch.optim.Adam(translator_model.parameters(), lr=0.001)


In [5]:
def computeOneHotEncoding(trans):
    one_hot_encoding = []
    for translation in translations:
        vector = np.zeros(len(all_tokens))
        for token in translation.split(','):
            index = vocab[token]
            vector[index] = 1
        one_hot_encoding.append(vector)

    # Convert the list to a NumPy array
    one_hot_encoding_array = torch.tensor(np.array(one_hot_encoding))
    return one_hot_encoding_array

In [6]:
output = translator_model(input_ids, attention_mask=input_ids != tokenizer.pad_token_id)
output.shape

torch.Size([3, 8])

In [7]:
num_epochs = 10  # Adjust as needed
for epoch in range(num_epochs):
    # Forward pass
    output = translator_model(input_ids, attention_mask=input_ids != tokenizer.pad_token_id)

    # comput the one hot encoding of the translation
    one_hot_encoding_array = computeOneHotEncoding(translations)
    

    # Compute loss
    loss = criterion(output, one_hot_encoding_array)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print loss every 100 epochs
    if (epoch + 1) % 5 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# Test the model with the same input sentences
test_output = translator_model(input_ids, attention_mask=input_ids != tokenizer.pad_token_id)
print("Test Output:", test_output)

Epoch [5/10], Loss: 5.2537
Epoch [10/10], Loss: 4.6195
Test Output: tensor([[ 0.4607,  1.4016,  1.8252, -0.2963, -2.0317, -0.0842,  0.3523, -2.2040],
        [-2.3011, -1.3101, -1.7781, -0.5439,  3.9812, -0.8129, -3.1830,  4.8241],
        [ 0.4014,  0.5710,  1.6860,  0.5097, -1.8541,  0.7205, -0.0972, -1.5161]],
       grad_fn=<AddmmBackward0>)


In [8]:
translations = []
vocab_rev = {value: key for key, value in vocab.items()}
for row in test_output:
    tokens_indices = torch.argmax(row).tolist()
    print(vocab_rev[tokens_indices])
#    translations.append(','.join(tokens))

g1
k5
g1


In [9]:
vocab_rev = {value: key for key, value in vocab.items()}

In [10]:

vocab_rev

{0: ' j5', 1: ' h4', 2: 'g1', 3: 'h6', 4: 't5', 5: 'g4', 6: ' g4', 7: 'k5'}

In [11]:
sentences = ["this is this is this is"]

# Load BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize and convert to PyTorch tensors
input_ids = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)['input_ids']
# Test the model with the same input sentences
test_output = translator_model(input_ids, attention_mask=input_ids != tokenizer.pad_token_id)
print("Test Output:", test_output)

Test Output: tensor([[ 0.2222,  0.6848,  2.3058,  0.6187, -2.4861,  0.8842, -0.4293, -1.9598]],
       grad_fn=<AddmmBackward0>)


In [12]:
num_params = sum(p.numel() for p in translator_model.parameters())

In [13]:
num_params

109488392

In [14]:
test_output

tensor([[ 0.2222,  0.6848,  2.3058,  0.6187, -2.4861,  0.8842, -0.4293, -1.9598]],
       grad_fn=<AddmmBackward0>)

In [15]:
tune = []
for i in np.where(test_output.detach().numpy()[0]>0)[0]:
    tune.append(vocab_rev[i])

In [16]:
tune

[' j5', ' h4', 'g1', 'h6', 'g4']