In [34]:
import torch #ML
import torch.nn.functional as F
from torch.autograd import Variable #these will help us define the prefixes we will optimize
import numpy as np #math
from transformers import GPT2Tokenizer, GPT2LMHeadModel #GPT-2 XL and its tokenizer
from tqdm import tqdm #progress bar
import csv #reading the CSV
from nltk.translate.bleu_score import sentence_bleu #BLEU score computation
device="cuda"

In [35]:
#Load GPT-2 XL and its tokenizer; the model is moved onto `device`
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-xl')
model = GPT2LMHeadModel.from_pretrained('gpt2-xl').to(device)

#CPU copy of the model's token-embedding table ('transformer.wte.weight'); its two dimensions give
#the vocabulary size and the embedding width
embedding_matrix = model.state_dict()['transformer.wte.weight'].clone().cpu()
vocab_len, embed_size = embedding_matrix.shape

#we use a custom tokenizer function here, because the model is fed through `inputs_embeds`: we want a
#tensor of embedding vectors rather than an array of token indices
def tokenize(string):
    """
    Tokenizes a string and converts it to a tensor of embeddings.
    
    Args:
    string (str): The input string to be tokenized.
    
    Returns:
    A float tensor of shape (1, num_tokens, embed_size) on `device`, holding one row of
    `embedding_matrix` per token of `string`. An empty string yields a (1, 0, embed_size) tensor.
    """
    # Token ids as a (1, num_tokens) integer tensor. The dtype is forced to long and the batch axis is
    # added with unsqueeze so that an empty id list still produces a valid (1, 0) index tensor.
    token_ids = torch.tensor(tokenizer(string)['input_ids'], dtype=torch.long).unsqueeze(0)
    
    # Row lookup in the embedding table. This selects exactly the rows a one-hot matmul would, without
    # materialising a (1, num_tokens, vocab_len) one-hot tensor first.
    prompt_embeddings = embedding_matrix[token_ids]
    
    # Return the prompt embeddings on the device
    return prompt_embeddings.to(device)

# Load the previously saved English and French prefixes from disk.
# (Presumably these are the length-32 prompt-tuning prefixes, going by the file names - confirm.)
#
# map_location=device places the loaded tensors on the same device that tokenize() returns its
# embeddings on, so they can be concatenated regardless of the device they were saved from.
#
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code; only load
# checkpoints that come from a trusted source.
start_english_prompt = torch.load('/kaggle/input/eng-fr-prefixes/start_english_prompt tuning_length32.pth',
                                  map_location=device)
start_french_prompt = torch.load('/kaggle/input/eng-fr-prefixes/start_french_prompt tuning_length32.pth',
                                 map_location=device)

In [36]:
def create_example(lang1_str, lang2_str, start_lang1_prompt, start_lang2_prompt):
    """
    Builds one (input, target) pair from a sentence in two languages.
    
    The input is laid out along the sequence axis (dim=1) as
    [prefix 1 | lang1 embeddings | prefix 2 | lang2 embeddings].
    
    Args:
    lang1_str (str): The first language string.
    lang2_str (str): The second language string.
    start_lang1_prompt (torch.Tensor): The prefix placed in front of the first language string.
    start_lang2_prompt (torch.Tensor): The prefix placed in front of the second language string.
    
    Returns:
    A tuple (inputs, target): `inputs` is the concatenated embedding tensor described above, and
    `target` is a (1, num_target_tokens) tensor of token ids for `lang2_str` followed by the
    end-of-text marker.
    """
    source_embeds = tokenize(lang1_str)
    target_embeds = tokenize(lang2_str)
    
    segments = (start_lang1_prompt, source_embeds, start_lang2_prompt, target_embeds)
    model_input = torch.cat(segments, dim=1)
    
    #we add an EOS token to the target so that we can tell when the translation is done generating
    target_ids = torch.tensor(tokenizer(lang2_str + "<|endoftext|>")['input_ids']).view(1, -1)
    
    return model_input, target_ids


In [37]:
#Token ids and length limits shared by the greedy decoding functions below
EOS_TOKEN_ID = 50256       #id the prefix-tuned model emits to end a translation ("<|endoftext|>")
NEWLINE_TOKEN_ID = 198     #newline token id; ends a translation line in the 3-shot prompt format
MAX_PROMPT_TOKENS = 512    #longer prompts are skipped (memory constraints)
MAX_CONTEXT_TOKENS = 1024  #hard cap on prompt + generated tokens


def _greedy_decode(prompt_embeds, stop_token_id, verbose=True):
    """
    Greedy (argmax) decoding shared by the translation functions in this cell.
    
    Args:
    prompt_embeds (torch.Tensor): Prompt embeddings of shape (1, prompt_len, embed_size).
    stop_token_id (int): Token id that ends generation; it is never included in the result.
    verbose (bool): If True, prints each generated token (except the stop token) as it is produced.
    
    Returns:
    A list of generated token ids without the stop token, or the integer 0 if the prompt is longer
    than MAX_PROMPT_TOKENS (sentinel kept for backward compatibility with existing callers).
    """
    prompt_len = prompt_embeds.shape[1]
    
    if prompt_len > MAX_PROMPT_TOKENS: #memory constraints :(
        return 0
    
    #Generation ends when the stop token is produced, *or* when prompt + generation reaches this limit.
    #Because the limit is on the total length, the generation itself is capped at roughly the length of
    #the prompt; this keeps us from wasting time on repetitive or nonsensical generations.
    max_total_len = min(MAX_CONTEXT_TOKENS, prompt_len * 2)
    
    generated = []
    with torch.no_grad(): #use no_grad to save memory
        #first forward pass runs over the whole prompt
        step = model(inputs_embeds=prompt_embeds.to(device))
        
        while True:
            #most likely next token, read from the logits at the last sequence position
            next_token = step['logits'][0, -1].argmax(dim=-1).item()
            generated.append(next_token)
            
            if next_token == stop_token_id:
                break
            if verbose:
                print(tokenizer.decode(next_token), end='')
            if len(generated) + prompt_len >= max_total_len:
                break
            
            #feed only the new token, as an explicit (batch=1, seq=1, embed_size) tensor, and reuse the
            #cached keys/values from the previous steps to save compute
            next_embed = embedding_matrix[next_token].view(1, 1, -1).to(device)
            step = model(inputs_embeds=next_embed, past_key_values=step['past_key_values'])
    
    #strip the stop token only if generation actually ended on it; when the length limit was hit
    #instead, the last token is real output and must be kept
    if generated and generated[-1] == stop_token_id:
        generated.pop()
    return generated


def translate_argmax(start_lang1_prompt, input_str, start_lang2_prompt, verbose=True):
    """
    Translates an input language string to the target language using the argmax method.
    
    Args:
    start_lang1_prompt (torch.Tensor): The tensor representing the prefix for the input language.
    input_str (str): The input language string to be translated.
    start_lang2_prompt (torch.Tensor): The tensor representing the prefix for the target language.
    verbose (bool): If True, the function prints the predicted tokens as they are generated. Defaults to True.
    
    Returns:
    A list of token IDs representing the translated target language string (without the EOS token),
    or the integer 0 if the assembled prompt is longer than MAX_PROMPT_TOKENS.
    """
    lang1 = tokenize(input_str) #tokenize input
    
    #create input prompt: [input-language prefix | input embeddings | target-language prefix]
    #the prefixes are moved to `device` so they can be concatenated with tokenize()'s output
    prompt_embeds = torch.concat((start_lang1_prompt.to(device), lang1, start_lang2_prompt.to(device)), dim=1)
    
    return _greedy_decode(prompt_embeds, EOS_TOKEN_ID, verbose)

def en_fr_3_shot_argmax(input_str, verbose=True):
    """
    Translates a given English input string to French using a 3-shot prompt.
    
    Relies on `prompt_3_shot_en_fr`, which is defined elsewhere in the notebook (not in this cell).
    
    Args:
    input_str (str): The input English string to be translated.
    verbose (bool, optional): If True, the function will print the translation as it is being generated. Defaults to True.
    
    Returns:
    out_sequence (list): A list of tokens representing the generated French translation (without the
    terminating newline token), or the integer 0 if the prompt is longer than MAX_PROMPT_TOKENS.
    """
    prompt_embeds = tokenize(prompt_3_shot_en_fr + " " + input_str.strip() + "\nfr:") #prepare input prompt
    
    #a newline token marks the end of the translated line
    return _greedy_decode(prompt_embeds, NEWLINE_TOKEN_ID, verbose)

def fr_en_3_shot_argmax(input_str, verbose=True):
    """
    Translates a given French input string to English using a 3-shot prompt.
    
    Relies on `prompt_3_shot_fr_en`, which is defined elsewhere in the notebook (not in this cell).
    
    Args:
    input_str (str): The input French string to be translated.
    verbose (bool, optional): If True, the function will print the translation as it is being generated. Defaults to True.
    
    Returns:
    out_sequence (list): A list of tokens representing the generated English translation (without the
    terminating newline token), or the integer 0 if the prompt is longer than MAX_PROMPT_TOKENS.
    """
    prompt_embeds = tokenize(prompt_3_shot_fr_en + " " + input_str.strip() + "\nen:") #prepare input prompt
    
    #a newline token marks the end of the translated line
    return _greedy_decode(prompt_embeds, NEWLINE_TOKEN_ID, verbose)

In [38]:
#Quick check: French input with the French prefix as the input-language prefix and the English prefix
#as the target-language prefix. verbose defaults to True, so the tokens are printed as they are
#generated; the cell output is the returned list of token ids.
translate_argmax(start_french_prompt, "Je suis fute", start_english_prompt)

I am fit

[40, 716, 4197]