In [1]:
!pip install -U sentence-transformers
!pip install gdown
!pip install thop

##Requirement
import torch
import pandas as pd
import random
import numpy as np
import functools
import pickle
import torch.nn as nn
import json
import gdown
import os
# import evaluate
from thop import profile
import warnings
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, AutoConfig
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers.util import cos_sim
from sentence_transformers import SentenceTransformer,util
from collections import defaultdict, Counter
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.optim import AdamW
from transformers import get_scheduler

### Path load dataset
# img_path = "/kaggle/input/cosmos/images_test_acm/images_test_acm/test"
test_path = "/content/public_test_acm.json"

#Load data

class LoadTest(Dataset):
    def __init__(self, file_path):
        with open(file_path, 'r') as file:
            self.data = [json.loads(line) for line in file]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        def extract_mid(caption):
            if len(caption) > 128:
                start = max(0, len(caption) // 2 - 64)
                end = min(len(caption), len(caption) // 2 + 64)
                return caption[start:end]
            return caption

        caption1 = extract_mid(item.get('caption1', ''))
        caption2 = extract_mid(item.get('caption2', ''))
        concatenated_caption = f"{caption1} {caption2}"

        label = int(item.get('context_label', 0))

        return {
            'img_local_path': item.get('img_local_path'),
            'caption1':caption1,
            'caption2':caption2,
            'text': concatenated_caption,
            'label': label,
        }

class HeuristicDataLoader(Dataset):
    def __init__(self, df):
        self.data = df.to_dict('records')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            'img_local_path': item.get('img_local_path'),
            'caption1': item.get('caption1'),
            'caption2': item.get('caption2'),
            'label': item.get('label'),
            'sbertlabel':item.get('pred_y')
        }

def collate_fn(batch):
    inputs = tokenizer([item['text'] for item in batch], padding=True, truncation=True, return_tensors='pt')
    labels = torch.tensor([item['label'] for item in batch])
    return {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'labels': labels
    }

## Preprocessing
class ExplainableModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.bert_config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.intermediate = AutoModel.from_pretrained(model_name)

        output_size = 2 * self.bert_config.hidden_size
        self.output = nn.Linear(output_size, output_size)

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size())
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        mean_embeddings = sum_embeddings / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return mean_embeddings

    def forward(self, input_ids_tuple, attention_mask_tuple):
        input_ids1, attention_mask1 = input_ids_tuple
        input_ids2, attention_mask2 = attention_mask_tuple

        with torch.no_grad():
            model_output1 = self.intermediate(input_ids1, attention_mask=attention_mask1).last_hidden_state
            model_output2 = self.intermediate(input_ids2, attention_mask=attention_mask2).last_hidden_state

        sentence_embeddings1 = self.mean_pooling(model_output1, attention_mask1)
        sentence_embeddings1 = F.normalize(sentence_embeddings1, p=2, dim=1)

        sentence_embeddings2 = self.mean_pooling(model_output2, attention_mask2)
        sentence_embeddings2 = F.normalize(sentence_embeddings2, p=2, dim=1)

        return sentence_embeddings1, sentence_embeddings2


def get_ids(text):
    caption1_str = text['caption1']
    caption2_str = text['caption2']

    inputs1 = tokenizer(
        caption1_str,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=128
    )

    inputs2 = tokenizer(
        caption2_str,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=128
    )

    return inputs1, inputs2

def get_embeddings(data_loader):
    all_embeddings_caption1 = []
    all_embeddings_caption2 = []
    all_labels = []

    for batch in data_loader:
        inputs1_tensor_list = []
        inputs2_tensor_list = []

        labels_batch = batch['label']
        valid_indices = [i for i, label in enumerate(labels_batch) if label is not None]
        labels_batch = [label for label in labels_batch if label is not None]
        labels_tensor = torch.tensor(labels_batch, dtype=torch.float32).to('cuda')

        inputs1, inputs2 = get_ids(batch)
        inputs1_tensor_list.append(inputs1.to('cuda'))
        inputs2_tensor_list.append(inputs2.to('cuda'))

        processed_batch = {
            'ids1': inputs1_tensor_list,
            'ids2': inputs2_tensor_list,
            'labels': labels_tensor,
        }

        embeddings_caption1, embeddings_caption2 = explainable_model(
            (processed_batch['ids1'][0]['input_ids'], processed_batch['ids1'][0]['attention_mask']),
            (processed_batch['ids2'][0]['input_ids'], processed_batch['ids2'][0]['attention_mask'])
        )

        embeddings_caption1 = embeddings_caption1.cpu().numpy()
        embeddings_caption2 = embeddings_caption2.cpu().numpy()
        labels = labels_tensor.cpu().numpy()

        all_embeddings_caption1.append(embeddings_caption1)
        all_embeddings_caption2.append(embeddings_caption2)
        all_labels.append(labels)

    all_embeddings_caption1 = np.concatenate(all_embeddings_caption1, axis=0)
    all_embeddings_caption2 = np.concatenate(all_embeddings_caption2, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    return all_embeddings_caption1, all_embeddings_caption2, all_labels


def cosine_collate_fn(batch):
    collated_batch = {}
    for key in batch[0].keys():
        if isinstance(batch[0][key], torch.Tensor):
            collated_batch[key] = torch.stack([sample[key] for sample in batch])
        else:
            collated_batch[key] = [sample[key] for sample in batch]
    return collated_batch


def get_embeddings_nli(texts, model, tokenizer, device):
    inputs = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        embeddings = torch.nn.functional.softmax(logits, dim=1)

    return embeddings.cpu().numpy()

def collate_fn_combined(batch, model=None, tokenizer=None, device=None):
    nli_true_postfix = " is true"
    similarities = []
    nli_scores_true = []

    premise_true = [f'{item["caption1"]}{nli_true_postfix}' for item in batch]
    hypothesis = [item['caption2'] for item in batch]
    if model is not None and tokenizer is not None and device is not None:
        inputs_true = tokenizer(premise_true, hypothesis, truncation=True, padding=True, return_tensors='pt')
        inputs_true = {key: tensor.to(device) for key, tensor in inputs_true.items()}

        with torch.no_grad():
            output_true = model(**inputs_true)
            predictions_true = torch.softmax(output_true.logits, dim=-1)
            prediction_probs_true = predictions_true.cpu().numpy()

            for pred_true in tqdm(prediction_probs_true, desc="Calculating NLI scores", total=len(batch)):
                label_names = ["entailment", "neutral", "contradiction"]
                prediction_true = {name: round(float(prob) * 100, 1) for prob, name in zip(pred_true, label_names)}
                nli_scores_true.append(prediction_true['contradiction'])

    for idx, item in enumerate(batch):
        item['nli_score_is_true'] = nli_scores_true[idx]
        item['input_ids'] = inputs_true['input_ids'][idx]
        item['attention_mask'] = inputs_true['attention_mask'][idx]
    return batch

class Prepare_data_pred(Dataset):
    def __init__(self, df):
        self.data = df.to_dict('records')
        self.nli = 0.6 #0.75

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        x = {'nli_score_is_true': item.get('nli_score_is_true')}

        prediction_info = {'predict': None}
        first_sen_contrast = x['nli_score_is_true'] >= self.nli
        if first_sen_contrast:
            prediction_info['predict'] = 1
        else:
            prediction_info['predict'] = 0

        return {**x, **prediction_info}

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
#=========================================

# SBERT CLASSIFICATION
##Load data & Model for classification
###load model sbert for classification & cosine similarity calculator
batch_size = 64
sb_model_name = "sentence-transformers/all-mpnet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(sb_model_name)
sb_model = AutoModelForSequenceClassification.from_pretrained(sb_model_name)


url = 'https://drive.google.com/uc?id=1lAn4jHtBBx-Y4sL92tW4YskbUNT4KgU9'
output = 'model.pth'
gdown.download(url, output, quiet=False)
saved_state_dict = torch.load(output, map_location=torch.device('cpu'))


num_classes_saved = saved_state_dict['classifier.out_proj.weight'].shape[0]
num_classes_current = sb_model.config.num_labels

if num_classes_saved != num_classes_current:
    sb_model.config.num_labels = num_classes_saved
    sb_model.classifier.out_proj = torch.nn.Linear(sb_model.config.hidden_size, num_classes_saved)

sb_model.load_state_dict(saved_state_dict, strict=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sb_model.to(device)

test_data = LoadTest(test_path)
test_dataloader = DataLoader(test_data, batch_size=batch_size, collate_fn=collate_fn)

### Processing classification
predicted_labels = []
for batch in test_dataloader:
    with torch.no_grad():
        outputs = sb_model(batch['input_ids'].to(device), attention_mask=batch['attention_mask'].to(device))
    predicted_labels.extend(outputs.logits.argmax(dim=-1).cpu().tolist())

input_sample = next(iter(test_dataloader))
input_ids = input_sample['input_ids'].to(device)
attention_mask = input_sample['attention_mask'].to(device)
macs_classify, params_classify = profile(sb_model, (input_ids, attention_mask))

df = pd.DataFrame({
    'img_local_path': [item['img_local_path'] for item in test_data],
    'caption1':[item['caption1'] for item in test_data],
    'caption2':[item['caption2'] for item in test_data],
    'text': [item['text'] for item in test_data],
    'label': [item['label'] for item in test_data],
    'pred_y': predicted_labels
})

warnings.filterwarnings("ignore", category=FutureWarning)

# SBERT + NLI + HEURISTIC
### Load data for heuristic step
df1 = df[df['pred_y'] == 0.0].copy()
dataset = HeuristicDataLoader(df1)
finall_dataset = HeuristicDataLoader(df)

##load data and calculator cosine similarity from sbert

explainable_model = ExplainableModel(sb_model_name)
explainable_model = explainable_model.to('cuda')
explainable_model.eval()

finall_df_loader = DataLoader(finall_dataset, batch_size=32, collate_fn=cosine_collate_fn)
total_gflops = 0.0

for batch in finall_df_loader:
    inputs1_tensor_list = []
    inputs2_tensor_list = []
    labels_batch = batch['label']
    valid_indices = [i for i, label in enumerate(labels_batch) if label is not None]
    labels_batch = [label for label in labels_batch if label is not None]
    labels_tensor = torch.tensor(labels_batch, dtype=torch.float32).to(device)

    inputs1, inputs2 = get_ids(batch)
    inputs1_tensor_list.append((inputs1['input_ids'].to(device), inputs1['attention_mask'].to(device)))
    inputs2_tensor_list.append((inputs2['input_ids'].to(device), inputs2['attention_mask'].to(device)))

    processed_batch = {
        'ids1': inputs1_tensor_list,
        'ids2': inputs2_tensor_list,
        'labels': labels_tensor,
    }
    macs, params = profile(explainable_model, (processed_batch['ids1'][0], processed_batch['ids2'][0]))
    total_gflops += macs

embeddings_caption1, embeddings_caption2, labels = get_embeddings(finall_df_loader)
cosine_similarities = cosine_similarity(embeddings_caption1, embeddings_caption2)
df['cosine_similarity'] = cosine_similarities.diagonal()

###load model mli for adding meaning into context
nli_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
model_nli = AutoModelForSequenceClassification.from_pretrained(nli_model_name)
model_nli.to(device)
model_nli.eval()

# Load data and calculate contradiction score from NLI
data_nli = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: collate_fn_combined(batch, model=model_nli, tokenizer=tokenizer, device=device))
merged_df = pd.concat([pd.DataFrame(batch) for batch in data_nli], ignore_index=True)

# Calculate GFLOPS for NLI process
total_macs_nli = 0
for i, row in merged_df.iterrows():
    input_ids_list = row['input_ids'].cpu().numpy().tolist()
    attention_mask_list = row['attention_mask'].cpu().numpy().tolist()
    input_ids_tensor = torch.tensor(input_ids_list).to(device)
    attention_mask_tensor = torch.tensor(attention_mask_list).to(device)

    with torch.no_grad():
        macs_nli, _ = profile(model_nli, inputs=(input_ids_tensor.unsqueeze(0), attention_mask_tensor.unsqueeze(0)))
        total_macs_nli += macs_nli


##Prepare data for prediction process

test_set = Prepare_data_pred(merged_df)
df_new_data = pd.DataFrame([test_set[idx] for idx in range(len(test_set))])
df_original = merged_df.copy()
df_updated = pd.concat([df_original, df_new_data], axis=1)

##set threshold
if 'predict' not in df.columns:
    df['predict'] = df_updated['predict']

df['finall_label'] = df['pred_y']
condition1 = (df['cosine_similarity'] < 0.47) & (df['pred_y'] == 0) & (df['predict'] == 1)
df.loc[condition1, 'finall_label'] = 1

actual_labels = df['label'].values
predicted_labels = df['finall_label'].values
correct_predictions = (actual_labels == predicted_labels).sum()
total_predictions = len(df)

accuracy = (correct_predictions / total_predictions)
recall = recall_score(actual_labels, predicted_labels)
precision = precision_score(actual_labels, predicted_labels)
f1 = f1_score(actual_labels, predicted_labels)


# Print  metrics
print(f"Accuracy: {accuracy:.4f}%")
print(f"Recall: {recall:.4f}%")
print(f"Precision: {recall:.4f}%")
print(f"F1 Score: {f1:.4f}%")

print(f"\nNumber of sbert classification Trainable Parameters: {count_parameters(sb_model):,}")
print(f"Number of sbert embedding Trainable Parameters: {count_parameters(explainable_model):,}")
print(f"Number of nli Trainable Parameters: {count_parameters(model_nli):,}")
print(f"Gflops SBERT classify process: {macs_classify:.2f}")
print(f"Gflops SBERT cosine similarity : {total_gflops:.2f}")
print(f"GFLOPS for NLI process: {total_macs_nli:.2f}")


torch.save(sb_model.state_dict(), 'sb_model.pth')
model_size_bytes = os.path.getsize('sb_model.pth')
sb_classify_ms = model_size_bytes / (1024 * 1024)
print(f"\nSBERT classify model size: {sb_classify_ms:.2f} MB")

torch.save(explainable_model.state_dict(), 'cs_model.pth')
model_size_bytes = os.path.getsize('cs_model.pth')
sb_cosine_ms = model_size_bytes / (1024 * 1024)
print(f"SBERT cosine model size: {sb_cosine_ms:.2f} MB")

torch.save(model_nli.state_dict(), 'nli_model.pth')
model_size_bytes = os.path.getsize('nli_model.pth')
nli_ms = model_size_bytes / (1024 * 1024)
print(f"NLI model size: {nli_ms:.2f} MB")

Collecting sentence-transformers
  Downloading sentence_transformers-2.5.1-py3-none-any.whl (156 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/156.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━[0m [32m92.2/156.5 kB[0m [31m2.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m156.5/156.5 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of MPNetForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-mpnet-base-v2 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Downloading...
From (original): https://drive.google.com/uc?id=1lAn4jHtBBx-Y4sL92tW4YskbUNT4KgU9
From (redirected): https://drive.google.com/uc?id=1lAn4jHtBBx-Y4sL92tW4YskbUNT4KgU9&confirm=t&uuid=81d21e50-420d-40f0-a3fe-0ad9c634f8b6
To: /content/model.pth
100%|██████████| 438M/438M [00:05<00:00, 84.9MB/s]


[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register 

tokenizer_config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/8.66M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/23.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/286 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.09k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/369M [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Calculating NLI scores: 100%|██████████| 64/64 [00:00<00:00, 6689.15it/s]
Calculating NLI scores: 100%|██████████| 64/64 [00:00<00:00, 113599.43it/s]
Calculating NLI scores: 100%|██████████| 64/64 [00:00<00:00, 67209.68it/s]
Calculating NLI scores: 100%|██████████| 64/64 [00:00<00:00, 67378.38it/s]
Calculating NLI scores: 100%|██████████| 64/64 [00:00<00:00, 58292.17it/s]
Calculating NLI scores: 100%|██████████| 64/64 [00:00<00:00, 124218.17it/s]
Calculating NLI scores: 100%|██████████| 64/64 [00:00<00:00, 38216.89it/s]
Calculating NLI scores: 100%|██████████| 64/64 [00:00<00:00, 196081.41it/s]
Calculating NLI scores: 100%|██████████| 35/35 [00:00<00:00, 45086.19it/s]


[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn