In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
from peft import get_peft_model, LoraConfig, TaskType
from sentence_transformers import SentenceTransformer, util
from torch import nn
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

# Training dataset and validation loader are stored together in a single file.
# NOTE(security): torch.load unpickles Python objects, which can execute arbitrary
# code -- only load this file if it comes from a trusted source.
# `weights_only=False` is passed explicitly because PyTorch >= 2.6 defaults to
# `weights_only=True`, which refuses pickled objects other than plain tensors and
# containers (valloader is presumably a DataLoader object -- TODO confirm).
train_dataset, valloader = torch.load('wikiloader_ssquote_ssbord', weights_only=False)
trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)

device = 'cuda'
# Sentence encoder; its embedding dimension sizes the model's projection layers
# and it encodes the prompt sentences at prediction time.
sent = SentenceTransformer("all-mpnet-base-v2", device=device)

In [None]:
class Model(nn.Module):
    """Causal language model that reads and writes embedding vectors instead of tokens.

    A linear layer lifts each `d_emb`-dimensional input vector to the backbone's
    hidden size, the LoRA-wrapped backbone processes the sequence through
    `inputs_embeds`, and a second linear layer projects the backbone's last
    hidden state back down to `d_emb`.

    Args:
        d_emb: size of the input and output embedding vectors.
        original_model: backbone model exposing `config.hidden_size`; it is
            wrapped with a LoRA adapter (r=16, alpha=32, dropout=0.1).
    """

    def __init__(self, d_emb, original_model):
        super().__init__()
        hidden_size = original_model.config.hidden_size

        # Input / output projections between embedding space and backbone space.
        self.emb = nn.Linear(d_emb, hidden_size)
        self.emb_rev = nn.Linear(hidden_size, d_emb)

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1
        )
        self.mod = get_peft_model(original_model, lora_config)

    def forward(self, x):
        """Map a sequence of embeddings to one predicted embedding per position.

        Args:
            x: tensor whose last dimension is `d_emb` (callers in this notebook
                pass batches sliced along dim 1 -- presumably
                (batch, seq_len, d_emb); TODO confirm).

        Returns:
            Tensor with the same leading dimensions as `x` and last dimension `d_emb`.
        """
        x = self.emb(x)

        # Ask for hidden states explicitly so this module does not depend on the
        # backbone having been loaded with `output_hidden_states=True`; without
        # that flag `hidden_states` is None and the indexing below would fail.
        x = self.mod(inputs_embeds=x, output_hidden_states=True)
        x = x.hidden_states[-1]

        x = self.emb_rev(x)
        return x

# Backbone language model. `output_hidden_states=True` makes the forward pass
# return the per-layer hidden states that Model.forward reads.
# Other checkpoints that were tried with the same call:
#   "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
#   "gpt2"
orig_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    output_hidden_states=True,
)

In [None]:
# Build the model; its projection layers are sized to the sentence encoder's
# embedding dimension.
model = Model(
    sent.get_sentence_embedding_dimension(),
    orig_model,
).to(device)

# Hand Adam the parameters of the three sub-modules, in the order
# emb, emb_rev, mod.
optimizer = Adam([
    param
    for submodule in (model.emb, model.emb_rev, model.mod)
    for param in submodule.parameters()
])

# Sentinel value marking padded entries in the target tensor.
pad = 99

def criterion(output, target):
    """Summed squared error between `output` and `target`, ignoring padding.

    Every element of `target` equal to `pad` is treated as padding; that
    element and the matching element of `output` are left out of the sum.

    Args:
        output: predicted tensor, same shape as `target`.
        target: reference tensor, with padded entries set to `pad`.

    Returns:
        0-dim tensor holding the sum of squared differences over the kept elements.
    """
    keep = target != pad
    # Alternatives that were considered: mean-reduced MSE, or 1 - cosine
    # similarity over the last dimension (maybe dim=1).
    squared_error = nn.functional.mse_loss(output[keep], target[keep], reduction='none')
    return squared_error.sum()

# History of validation losses, one entry per call to test().
val_list = []

def test(epoch):
    """Compute the validation loss on one batch, print it and record it.

    Only the first batch drawn from `valloader` is evaluated. The model is fed
    every position but the last and scored against the sequence shifted by one
    position (next-step prediction).

    Args:
        epoch: zero-based epoch index; printed as `epoch + 1`.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(valloader))
        inputs, targets = batch[:, :-1], batch[:, 1:]
        val_loss = criterion(model(inputs), targets).item()
    print('Epoch', str(epoch+1) + ', Loss:', val_loss)
    val_list.append(val_loss)

def _trainable_state(module):
    """Return a CPU copy of every trainable parameter of `module`, keyed by name."""
    return {
        name: param.detach().to('cpu', copy=True)
        for name, param in module.named_parameters()
        if param.requires_grad
    }


# Evaluate once before training so val_list[0] is the untrained baseline.
test(-1)

# Number of evaluations without improvement tolerated before stopping.
patience = 3

# Snapshot of the best weights seen so far. Taking one up front guarantees
# there is always something to restore, even if training never improves on
# the baseline. Only trainable parameters are copied: frozen parameters are
# not updated by the optimizer, so they need no backup.
best_state = _trainable_state(model)

for epoch in range(1000):
    model.train()
    for src in trainloader:
        optimizer.zero_grad()
        # Next-step prediction: feed all positions but the last, score against
        # the sequence shifted by one.
        outputs = model(src[:, :-1])
        loss = criterion(outputs, src[:, 1:])
        loss.backward()
        optimizer.step()

    test(epoch)

    # Early stopping.
    if val_list[-1] < min(val_list[:-1]):
        # New best: copy the weights. Keeping a reference to `model` would not
        # work, because later optimizer steps mutate the same tensors.
        best_state = _trainable_state(model)
    elif (len(val_list) - val_list.index(min(val_list))) > patience:
        # No better model in the last `patience` evaluations.
        break

# Restore the best weights. strict=False because the snapshot intentionally
# omits the frozen parameters.
model.load_state_dict(best_state, strict=False)
model_save = model

plt.plot(val_list)
plt.xlabel('evaluation (0 = before training)')
plt.ylabel('validation loss')
plt.show()

In [None]:
def predict(transfo, src, steps=max_len):
    """Continue a prompt in embedding space and decode it by nearest-neighbour search.

    The sentences in `src` are encoded with the module-level `sent` encoder.
    `transfo` is then called repeatedly, each time appending the last vector of
    its output to the sequence, until the sequence holds `steps` vectors. Every
    generated vector is looked up in `corpus_embedding` with
    `util.semantic_search`, and each hit is rendered as "<score> <corpus text>".

    Args:
        transfo: module mapping a (1, seq, d_emb) tensor to a tensor with the
            same leading dimensions; only its last output position is used.
        src: prompt, either a list of sentences or a single sentence string.
        steps: total sequence length (prompt plus generated vectors). Values
            above `max_len` are reduced to `max_len`; values below 1 are raised to 1.

    Returns:
        pandas.DataFrame with one string-named column per source sentence
        ("0", "1", ...; sentence in the first row) followed by one
        integer-named column per generated position, holding that position's
        hits in the order returned by `util.semantic_search`.

    Raises:
        ValueError: if `src` is empty.
    """
    # A bare string would be encoded to a 1-D tensor and iterated character by
    # character below; treat it as a one-sentence prompt instead.
    if isinstance(src, str):
        src = [src]
    if len(src) == 0:
        raise ValueError('src must contain at least one sentence')

    if steps > max_len:
        print('steps could not be superior to max_len (who is '+str(max_len)+')')
        print('steps is set to '+str(max_len))
        steps = max_len
    if steps < 1:  # also catches negative values, not only 0
        print('steps is set to 1')
        steps = 1

    n_src = len(src)

    transfo.eval()
    with torch.no_grad():
        # Follow the model's own device rather than a hard-coded "cuda".
        model_device = next(transfo.parameters()).device
        text = sent.encode(src, convert_to_tensor=True).to(model_device).unsqueeze(0)

        # Autoregressive generation: append the prediction for the last position.
        while text.size(1) < steps:
            output = transfo(text)[:, -1:]
            text = torch.cat((text, output), dim=1)

        # Drop the batch dimension and the prompt positions: only generated
        # vectors are decoded, so they line up with the column labels below,
        # which start at len(src).
        similarities = util.semantic_search(text[0, n_src:], corpus_embedding)
        # One "<score> <corpus text>" string per hit.
        affichage = [[f"{round(sim['score'], 2)} {corpus[sim['corpus_id']]}" for sim in liste] for liste in similarities]

        for aff in src:
            print(aff, '\n')
        # First hit for every generated position.
        for aff in affichage:
            print(aff[0])

        # One column per generated position, labelled with its sequence index.
        df = pd.DataFrame(affichage).transpose()
        df.columns = [i + n_src for i in df.columns]

        # Prepend the source sentences, each in the first row of its own column.
        for i, source in enumerate(src):
            new_column = [source] + [""] * (len(df) - 1)
            df.insert(i, str(i), new_column)

        return df


# Prompt fed to the model: a wiki-style title line followed by an opening paragraph.
# Other prompts that were tried:
#   [' = John Lenon = \n']
#   ['John Lenon']
#   ['Chocolate cake recipe']
#   ['= Chocolate cake =', 'Chocolate cake or chocolate gâteau (from French: gâteau au chocolat) is a cake flavored with melted chocolate, cocoa powder, or both. It can also have other ingredients such as fudge, vanilla creme, and other sweeteners.']
#   ['import pandas as pd']
src = [
    '= John Lenon =',
    (
        'John Winston Ono Lennon(born John Winston Lennon; 9 October 1940 - 8 December 1980) '
        'was an English singer-songwriter, musician and political activist. '
        'He gained worldwide fame as the founder, co-lead vocalist and rhythm guitarist of the Beatles. '
        'His work included music, writing, drawings and film. '
        'His songwriting partnership with Paul McCartney remains the most successful in history '
        'as the primary songwriters in the Beatles.'
    ),
]
predict(model, src)