# Sentence Reconstruction
---

### Colab Setup

In [None]:
!pip install transformers
!pip install sentence-transformers

### Setup Environment

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers.optimization import get_linear_schedule_with_warmup
from sentence_transformers import SentenceTransformer

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from pycocotools.coco import COCO

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [3]:
%%capture
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
model.eval()

sent_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device=device)
sent_model.eval()

### Setup Data for Training

In [4]:
class COCOCaptionsDataset(Dataset):
    def __init__(self, caption_file):
        self.coco = COCO(caption_file)
        self.coco_data = []

        for i, data_dict in self.coco.anns.items():
            self.coco_data.append(data_dict['caption'])

    def __len__(self):
        return len(self.coco_data)

    def __getitem__(self, idx):
        return self.coco_data[idx]

In [5]:
coco_dataset = COCOCaptionsDataset("Data/coco_captions/captions_train2017.json")

loading annotations into memory...
Done (t=0.54s)
creating index...
index created!


### Renconstruction Training (Sentence Embedding -> Sentence)

In [6]:
def train_lm(model, tokenizer, caps):

    batch_size = len(caps)

    targets = tokenizer(caps, padding=True, return_tensors='pt', return_attention_mask=True)
    targets_ids = targets['input_ids'].to(device)
    targets_mask = targets['attention_mask'].to(device)

    token_embs = model.transformer.wte(targets_ids)\
    
    sent_embs = sent_model.encode(caps, convert_to_tensor=True)

    input_clip_embs = torch.zeros((token_embs.size(0), token_embs.size(1)+1, token_embs.size(2)), device=device)
    target_clip_mask = torch.zeros((targets_mask.size(0), targets_mask.size(1)+1), dtype=torch.int64, device=device)
    target_clip_ids = torch.zeros((targets_ids.size(0), targets_ids.size(1)+1), dtype=torch.int64, device=device)

    input_clip_embs = torch.cat((sent_embs.unsqueeze(1), token_embs), dim=1) # Add input CLIP embedding
    target_clip_ids = torch.cat((torch.zeros((batch_size, 1)).to(device), targets_ids), dim=1)    # Add dummy token; is ignored in loss 
    target_clip_mask = torch.cat((torch.ones((batch_size, 1)).to(device), targets_mask), dim=1)  # Avoid masking new token

    outputs = model(
        inputs_embeds=input_clip_embs,
        return_dict=True,
        output_hidden_states=True,
        attention_mask=target_clip_mask
    )

    loss = torch.nn.functional.cross_entropy(outputs['logits'][:, :-1].reshape(-1, outputs['logits'].size(-1)), targets_ids.flatten(), ignore_index=0)

    return loss

def tune_caption_to_CLIP(model, tokenizer, optimizer, scheduler, dataloader, epochs=5):
    
    model.train()
    for epoch in range(epochs):
       
        print(f"Training epoch: {epoch}")
        num_batches = len(dataloader)

        for batch_idx, caps in  tqdm(enumerate(dataloader)):

            optimizer.zero_grad()

            loss = train_lm(model, tokenizer, caps)
            loss.backward()

            optimizer.step()
            scheduler.step()
            
            if batch_idx % 1000 == 0 and batch_idx != 0:
                print(f"Loss at batch {batch_idx} / {num_batches}  = {loss}")

In [7]:
coco_dataloader = DataLoader(coco_dataset, batch_size=32, shuffle=True)

epochs = 3
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=2000, num_training_steps=epochs * len(coco_dataloader))

tune_caption_to_CLIP(model, tokenizer, optimizer, scheduler, coco_dataloader, epochs=epochs)

Training epoch: 0


0it [00:00, ?it/s]

Loss at batch 1000 / 18493  = 1.9773715734481812
Loss at batch 2000 / 18493  = 1.197084665298462
Loss at batch 3000 / 18493  = 1.0590198040008545
Loss at batch 4000 / 18493  = 1.1088236570358276
Loss at batch 5000 / 18493  = 1.2102621793746948
Loss at batch 6000 / 18493  = 1.2789114713668823
Loss at batch 7000 / 18493  = 0.8305197358131409
Loss at batch 8000 / 18493  = 0.9642782211303711
Loss at batch 9000 / 18493  = 1.1025140285491943
Loss at batch 10000 / 18493  = 0.7415385842323303
Loss at batch 11000 / 18493  = 0.8394097685813904
Loss at batch 12000 / 18493  = 0.7122843265533447
Loss at batch 13000 / 18493  = 0.7062592506408691
Loss at batch 14000 / 18493  = 0.9173645973205566
Loss at batch 15000 / 18493  = 0.8504229187965393
Loss at batch 16000 / 18493  = 0.697601318359375
Loss at batch 17000 / 18493  = 0.655626654624939
Loss at batch 18000 / 18493  = 0.8708838224411011
Training epoch: 1


0it [00:00, ?it/s]

Loss at batch 1000 / 18493  = 0.6847610473632812
Loss at batch 2000 / 18493  = 0.5171463489532471
Loss at batch 3000 / 18493  = 0.6090697646141052
Loss at batch 4000 / 18493  = 0.9996364712715149
Loss at batch 5000 / 18493  = 0.6395652890205383
Loss at batch 6000 / 18493  = 0.5863460302352905
Loss at batch 7000 / 18493  = 0.5561689138412476
Loss at batch 8000 / 18493  = 0.5133061408996582
Loss at batch 9000 / 18493  = 0.7221873998641968
Loss at batch 10000 / 18493  = 0.6266103386878967
Loss at batch 11000 / 18493  = 0.49382537603378296
Loss at batch 12000 / 18493  = 0.5834024548530579
Loss at batch 13000 / 18493  = 0.7013503313064575
Loss at batch 14000 / 18493  = 0.5959141850471497
Loss at batch 15000 / 18493  = 0.49765536189079285
Loss at batch 16000 / 18493  = 0.7473887801170349
Loss at batch 17000 / 18493  = 0.4556562602519989
Loss at batch 18000 / 18493  = 0.534711480140686
Training epoch: 2


0it [00:00, ?it/s]

Loss at batch 1000 / 18493  = 0.4708184599876404
Loss at batch 2000 / 18493  = 0.3935842216014862
Loss at batch 3000 / 18493  = 0.42989712953567505
Loss at batch 4000 / 18493  = 0.36175981163978577
Loss at batch 5000 / 18493  = 0.5760279893875122
Loss at batch 6000 / 18493  = 0.49386459589004517
Loss at batch 7000 / 18493  = 0.33755549788475037
Loss at batch 8000 / 18493  = 0.4637730121612549
Loss at batch 9000 / 18493  = 0.2827076315879822
Loss at batch 10000 / 18493  = 0.41411733627319336
Loss at batch 11000 / 18493  = 0.37108033895492554
Loss at batch 12000 / 18493  = 0.43648576736450195
Loss at batch 13000 / 18493  = 0.44258612394332886
Loss at batch 14000 / 18493  = 0.4757770895957947
Loss at batch 15000 / 18493  = 0.2883436679840088
Loss at batch 16000 / 18493  = 0.3867112398147583
Loss at batch 17000 / 18493  = 0.37036436796188354
Loss at batch 18000 / 18493  = 0.34048911929130554


### Reconstruct text from embedding

In [2]:
# From https://github.com/rmokady/CLIP_prefix_caption/blob/main/predict.py
def generate(
    model,
    tokenizer,
    tokens=None,
    prompt=None,
    embed=None,
    entry_count=1,
    entry_length=67,  # maximum number of words
    top_p=0.8,
    temperature=1.0,
    stop_token: str = ".",
):
    model.eval()
    generated_num = 0
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    with torch.no_grad():

        for entry_idx in range(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)

                generated = model.transformer.wte(tokens)

            for i in range(entry_length):
                generated = generated.to(device)
                outputs = model(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]

In [None]:
# Encode text
text = "A man walking his dog."

# test_idx = 93
# text = coco_dataset.coco_data[test_idx]

print("Original Text:\n\n", text, end="\n\n")
text_emb = sent_model.encode([text], convert_to_tensor=True)

print("Reconstructed text:", end="")
generate(model, tokenizer, embed=text_emb.view(1, 1, -1), stop_token=tokenizer.pad_token)

In [None]:
# Works pretty well for text that is in the training distribution; however, poor generalization as OOD inputs lead to poor reconstruction.