# 1. Import modules

In [2]:
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm

from modules import ContrastiveLoss, DummyFlickerDataset, DummyHardNegativeDataset, DummyTextToTextDataset
from modules import open_image, get_logger
from modules.models import ClipModelRetriever

# 2. Load model, loss function and optimizer

In [6]:
logger = get_logger()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper-parameters for the modality-aware hard-negative stage (section 3).
num_epochs = 10
batch_size = 4
dataset_size = 32
# NOTE(review): 0.01 is unusually high for Adam fine-tuning of a pretrained encoder — confirm intended.
learning_rate = 0.01
seed = 42

# Seed torch's global RNG early so DataLoader shuffling (shuffle=True below) and any other
# torch-side randomness are reproducible under Restart & Run All.
torch.manual_seed(seed)

model = ClipModelRetriever(device)

criterion = ContrastiveLoss().to(device)
# Only the parameters of the wrapped `model.model` are optimized.
optimizer = optim.Adam(model.model.parameters(), lr=learning_rate)

# 3. Modality-aware hard negative mining

### 3.1. Load dataset

In [7]:
# Base dataset: rows are indexed by the sample indices below; field 0 is fed to the text
# encoder and field 1 to open_image in the training loop.
dataset = DummyFlickerDataset(size=dataset_size)
# Index-based training samples built with the current model (presumably for hard-negative mining).
hard_negative_dataset = DummyHardNegativeDataset(size=dataset_size, model=model)

def collate_fn(batch):
    """Transpose a list of samples into four parallel lists.

    Each sample is read positionally: ``item[0]`` is the query index, ``item[1]`` the
    document index, ``item[2]`` the document type and ``item[3]`` the negative indices.
    Any further fields are ignored. The values are kept as plain Python objects, which
    presumably avoids PyTorch's default collation of the per-sample dicts and strings.

    Returns:
        tuple: ``(query_index, document_index, document_type, negative_indices)``, each a list.
    """
    return tuple([sample[field] for sample in batch] for field in range(4))

dataloader = DataLoader(
    hard_negative_dataset,
    batch_size=batch_size,
    shuffle=True,  # sample order is reshuffled every epoch
    collate_fn=collate_fn,
)

2025-03-03 15:40:05 - INFO - Initializing DummyFlickerDataset...
2025-03-03 15:40:05 - INFO - Initializing DummyFlickerDataset...
2025-03-03 15:40:05 - INFO - Initializing DummyHardNegativeDataset...
2025-03-03 15:40:05 - INFO - Initializing Multimodal Hard Negative Mining Controler.
2025-03-03 15:40:05 - INFO - Initializing Multimodal Hard Negative Mining Controler.


### 3.2. Finetuning

In [8]:
logger.info("Start Modality-aware hard negative finetuning...")

# Every sample contributes one contrastive loss term: the query text embedding against its
# positive document (text or image) and the concatenated text + image negatives.
# Per-sample losses are summed so there is one backward/step per DataLoader batch.
for epoch in range(num_epochs):
    total_loss = 0.0
    with tqdm(dataloader, desc="Training", dynamic_ncols=True) as epoch_bar:
        for query_index_set, document_index_set, document_type_set, negative_indices_set in epoch_bar:
            optimizer.zero_grad()
            batch_loss = 0
            for query_index, document_index, document_type, negative_indices in zip(
                query_index_set, document_index_set, document_type_set, negative_indices_set
            ):
                query = dataset[query_index][0]
                outputs = model.text_encoder(texts=[query])

                # The positive is either the text field (0) or the image field (1) of its row.
                if document_type == "text":
                    positive_sample = dataset[document_index][0]
                    positive_emb = model.text_encoder(texts=positive_sample)
                else:
                    positive_sample = dataset[document_index][1]
                    positive_image = open_image(positive_sample)
                    positive_emb = model.image_encoder(images=positive_image)

                # Empty placeholders so the concat below still works when one modality
                # has no negatives (torch.concat skips 1-D empty tensors).
                neg_text_emb = torch.Tensor([]).to(device)
                neg_image_emb = torch.Tensor([]).to(device)

                if negative_indices["text"]:
                    neg_text_samples = dataset[negative_indices["text"]][0].to_list()
                    neg_text_emb = model.text_encoder(texts=neg_text_samples)

                if negative_indices["image"]:
                    neg_image_samples = dataset[negative_indices["image"]][1].to_list()
                    neg_images = [open_image(image) for image in neg_image_samples]
                    neg_image_emb = model.image_encoder(images=neg_images)

                negative_emb = torch.concat([neg_text_emb, neg_image_emb], dim=0)

                loss = criterion(outputs, positive_emb, negative_emb)
                batch_loss = batch_loss + loss

            total_loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()

        # batch_loss sums one term per sample, so this is the mean per-sample loss.
        print(f"Epoch {epoch+1}: {total_loss / len(hard_negative_dataset)}")


logger.info("Finish Modality-aware hard negative finetuning.")

2025-03-03 15:40:49 - INFO - Start Modality-aware hard negative finetuning...
Training: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 1: 4.133142292499542


Training: 100%|██████████| 8/8 [00:07<00:00,  1.08it/s]


Epoch 2: 3.7724913954734802


Training: 100%|██████████| 8/8 [00:07<00:00,  1.04it/s]


Epoch 3: 3.691261053085327


Training: 100%|██████████| 8/8 [00:08<00:00,  1.03s/it]


Epoch 4: 3.6880630254745483


Training: 100%|██████████| 8/8 [00:08<00:00,  1.08s/it]


Epoch 5: 3.702912151813507


Training: 100%|██████████| 8/8 [00:08<00:00,  1.11s/it]


Epoch 6: 3.695889800786972


Training: 100%|██████████| 8/8 [00:08<00:00,  1.04s/it]


Epoch 7: 3.6901063323020935


Training: 100%|██████████| 8/8 [00:08<00:00,  1.08s/it]


Epoch 8: 3.684810996055603


Training: 100%|██████████| 8/8 [00:08<00:00,  1.01s/it]


Epoch 9: 3.680789291858673


Training: 100%|██████████| 8/8 [00:08<00:00,  1.03s/it]
2025-03-03 15:42:12 - INFO - Finish Modality-aware hard negative finetuning.


Epoch 10: 3.6789028644561768


# 4. Text-to-text finetuning

### 4.1. Load dataset

In [9]:
# Text-to-text stage hyper-parameters; these overwrite the section-3 values of the same names.
num_epochs, batch_size = 10, 8

# Fresh base dataset (replaces the section-3 `dataset`) and the text-to-text training samples.
dataset = DummyFlickerDataset(size=dataset_size)
finetuning_dataset = DummyTextToTextDataset(size=dataset_size, model=model)

def collate_fn_finetuning(batch):
    """Collate text-to-text samples into four parallel lists.

    These samples use the same positional layout as the hard-negative samples
    (query index, document index, document type, negative indices). This function
    therefore delegates to ``collate_fn`` from section 3.1 instead of keeping a
    copy-pasted duplicate of it.

    Returns:
        tuple: ``(query_index, document_index, document_type, negative_indices)``, each a list.
    """
    return collate_fn(batch)

finetuning_dataloader = DataLoader(
    finetuning_dataset,
    batch_size=batch_size,
    shuffle=True,  # sample order is reshuffled every epoch
    collate_fn=collate_fn_finetuning,
)

2025-03-03 15:42:12 - INFO - Initializing DummyFlickerDataset...
2025-03-03 15:42:12 - INFO - Initializing DummyFlickerDataset...
2025-03-03 15:42:13 - INFO - Initializing DummyHardNegativeDataset...
2025-03-03 15:42:13 - INFO - Initializing Multimodal Hard Negative Mining Controler.
2025-03-03 15:42:13 - INFO - Initializing Multimodal Hard Negative Mining Controler.


### 4.2. Finetuning

In [10]:
logger.info("Start text-to-text finetuning...")

# Same per-sample contrastive objective as section 3.2, restricted to text: the query text
# against a text positive and text-only negatives. The document type from each sample is
# ignored because the positive is always read from the text field.
for epoch in range(num_epochs):
    total_loss = 0.0
    with tqdm(finetuning_dataloader, desc="Training", dynamic_ncols=True) as epoch_bar:
        for query_index_set, document_index_set, document_type_set, negative_indices_set in epoch_bar:
            optimizer.zero_grad()
            batch_loss = 0
            for query_index, document_index, _document_type, negative_indices in zip(
                query_index_set, document_index_set, document_type_set, negative_indices_set
            ):
                query = dataset[query_index][0]
                outputs = model.text_encoder(texts=[query])

                positive_sample = dataset[document_index][0]
                positive_emb = model.text_encoder(texts=positive_sample)

                # Empty placeholder when the sample has no text negatives.
                negative_emb = torch.Tensor([]).to(device)
                if negative_indices["text"]:
                    neg_text_samples = dataset[negative_indices["text"]][0].to_list()
                    negative_emb = model.text_encoder(texts=neg_text_samples)

                loss = criterion(outputs, positive_emb, negative_emb)
                batch_loss = batch_loss + loss

            total_loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()

        # Mean per-sample loss over this stage's own dataset.
        print(f"Epoch {epoch+1}: {total_loss / len(finetuning_dataset)}")

logger.info("Finish text-to-text finetuning.")

2025-03-03 15:42:14 - INFO - Start text-to-text finetuning...
Training: 100%|██████████| 4/4 [00:01<00:00,  2.04it/s]


Epoch 1: 3.0440885424613953


Training: 100%|██████████| 4/4 [00:01<00:00,  2.27it/s]


Epoch 2: 3.0439098477363586


Training: 100%|██████████| 4/4 [00:01<00:00,  2.35it/s]


Epoch 3: 3.0437732338905334


Training: 100%|██████████| 4/4 [00:01<00:00,  2.34it/s]


Epoch 4: 3.0438162088394165


Training: 100%|██████████| 4/4 [00:01<00:00,  2.39it/s]


Epoch 5: 3.0435999631881714


Training: 100%|██████████| 4/4 [00:01<00:00,  2.39it/s]


Epoch 6: 3.0433127880096436


Training: 100%|██████████| 4/4 [00:01<00:00,  2.32it/s]


Epoch 7: 3.0429492592811584


Training: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s]


Epoch 8: 3.0424980521202087


Training: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s]


Epoch 9: 3.0418633818626404


Training: 100%|██████████| 4/4 [00:01<00:00,  2.24it/s]
2025-03-03 15:42:31 - INFO - Finish text-to-text finetuning.


Epoch 10: 3.0408228635787964


# 5. Prompting multimodal LLMs for reranking


In [13]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch.nn.functional as F

class ZeroShotReranker:
    """Zero-shot relevance scorer that asks a causal LM a True/False question.

    The score is the probability the LM assigns to " True" as the next token after the
    prompt, renormalized over just the " True" and " False" tokens.
    """

    def __init__(self, model_name: str = "gpt2"):
        """Load the GPT-2 tokenizer and LM-head model and put the model in eval mode.

        Args:
            model_name: Name or path passed to ``from_pretrained``. It defaults to ``"gpt2"``,
                the checkpoint this class originally hard-coded.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.model.eval()

    def calculate_score(self, query: str, candidate: str) -> float:
        """Score how well ``candidate`` answers ``query``.

        Args:
            query: Question text, inserted verbatim at the start of the prompt. The template
                adds no prefix of its own, so callers include e.g. ``"Question: "`` themselves.
            candidate: Candidate answer to judge.

        Returns:
            P(" True") / (P(" True") + P(" False")) for the next token, a float in [0, 1].
        """
        prompt = f"{query}\nAnswer: {candidate}\nDoes the answer correctly answer the question? True or False"

        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # With GPT-2's byte-level BPE the leading space is part of the token, so " True" and
        # " False" are looked up with the space included.
        true_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(" True")[0])
        false_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(" False")[0])

        # Next-token logits at the last prompt position, restricted to the two answer tokens.
        pair_logits = logits[0, -1, [true_token_id, false_token_id]]
        probs = F.softmax(pair_logits, dim=0)
        return probs[0].item()

In [15]:
reranker = ZeroShotReranker()

# The prompt template adds no prefix of its own, so the query carries "Question: " itself.
query = "Question: What is this person doing?"
candidate = "Running on the beach."

true_prob = reranker.calculate_score(query, candidate)
# Label reads "True probability (relevance score)".
print(f"True 확률 (관련성 점수): {true_prob:.4f}")

True 확률 (관련성 점수): 0.5386
