In [None]:
!pip install shap
!pip install torch
!pip install transformers

In [None]:
import torch
from torch import nn
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

class MisogynyCls(nn.Module):
    """Multimodal classifier built on top of a pretrained CLIP model.

    The CLIP text and image embeddings of each sample are concatenated and fed
    through a stack of ``Linear -> BatchNorm1d -> Dropout -> ReLU`` blocks,
    followed by two sigmoid heads (task A and task B).

    Args:
        num_linear_layers: number of hidden blocks placed before the heads.
        task_a_out: output size of the task A head.
        task_b_out: output size of the task B head.
        input_dim: size of the concatenated text + image embedding.
        hidden_dim: width of every hidden block.
        drop_value: dropout probability used in every hidden block.
    """

    def __init__(self, num_linear_layers, task_a_out=1, task_b_out=4, input_dim=1024, hidden_dim=512, drop_value=0.2):
        super().__init__()
        # Without hidden blocks the heads consume the concatenated embeddings
        # directly, so they must be sized on ``input_dim``.
        head_in_dim = hidden_dim if num_linear_layers > 0 else input_dim
        self.head_task_a = nn.Linear(head_in_dim, task_a_out)
        self.head_task_b = nn.Linear(head_in_dim, task_b_out)
        self.sigmoid = nn.Sigmoid()
        # Kept for backward compatibility only: ``forward`` follows the device
        # the parameters actually live on rather than this attribute.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Pretrained CLIP backbone and its matching pre-processor.
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        self.layers = nn.ModuleList()
        for i in range(num_linear_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            self.layers.append(nn.Linear(in_dim, hidden_dim))
            self.layers.append(nn.BatchNorm1d(hidden_dim))
            self.layers.append(nn.Dropout(drop_value))
            self.layers.append(nn.ReLU())

    def forward(self, text_list, image_list):
        """Return ``(pred_taskA, pred_taskB)`` sigmoid scores for a batch.

        Args:
            text_list: batch of texts, forwarded to the CLIP processor.
            image_list: batch of images, forwarded to the CLIP processor.
        """
        clip_inputs = self.clip_processor(text=text_list, images=image_list, return_tensors="pt", padding=True, truncation=True)

        # Run everything on the device the weights live on. Using a device
        # guessed from ``torch.cuda.is_available()`` breaks when the model was
        # loaded on CPU on a GPU machine (or moved to GPU without its inputs).
        device = next(self.parameters()).device
        clip_inputs = {name: value.to(device) for name, value in clip_inputs.items()}
        clip_outputs = self.clip_model(**clip_inputs)

        # The model input is the concatenation of the two modalities.
        x = torch.cat([clip_outputs['text_embeds'], clip_outputs['image_embeds']], dim=1)

        for layer in self.layers:
            x = layer(x)

        pred_taskA = self.sigmoid(self.head_task_a(x))
        pred_taskB = self.sigmoid(self.head_task_b(x))

        return pred_taskA, pred_taskB

class OnlyTextCls(nn.Module):
    """Text-only view of a multimodal classifier, used for text explanations.

    Every text is paired with a constant placeholder image so that only the
    text varies between evaluations.

    Args:
        cls: wrapped classifier, called as ``cls(text_list, image_list)`` and
            expected to return a ``(task_a, task_b)`` pair.
        null_image_size: ``(width, height)`` of the placeholder image.
        null_image_color: RGB colour of the placeholder image.
    """

    def __init__(self, cls, null_image_size=(100, 100), null_image_color=(0, 0, 0)):
        super().__init__()
        self.classifier = cls
        self.null_image_size = null_image_size
        self.null_image_color = null_image_color

    def forward(self, text_list):
        """Return the main-task (binary) prediction for each text.

        Args:
            text_list: iterable of texts; elements may be plain ``str`` or
                NumPy string scalars (both are converted with ``str``).
        """
        texts = [str(el) for el in text_list]
        null_images = [
            Image.new('RGB', self.null_image_size, color=self.null_image_color)
            for _ in texts
        ]
        # Inference only: the explainer calls this many times, so avoid
        # building an autograd graph for every evaluation.
        with torch.no_grad():
            # Only the main task (binary) prediction is returned for now.
            prediction, _ = self.classifier(texts, null_images)

        return prediction
    
    
class OnlyImageCls(nn.Module):
    """Image-only view of a multimodal classifier, used for image explanations.

    Every image is paired with the same placeholder text so that only the
    image varies between evaluations.

    Args:
        cls: wrapped classifier, called as ``cls(text_list, image_list)`` and
            expected to return a ``(task_a, task_b)`` pair.
        null_text: placeholder text paired with every image.
    """

    def __init__(self, cls, null_text=" "):
        super().__init__()
        self.classifier = cls
        self.null_text = null_text

    def forward(self, image_list):
        """Return the main-task (binary) prediction for each image.

        Args:
            image_list: batch of images (anything supporting ``len``), passed
                unchanged to the wrapped classifier.
        """
        null_text = [self.null_text for _ in range(len(image_list))]
        # Inference only: the explainer calls this many times, so avoid
        # building an autograd graph for every evaluation.
        with torch.no_grad():
            # Only the main task (binary) prediction is returned for now.
            prediction, _ = self.classifier(null_text, image_list)
        return prediction

## Explaining text


In [None]:
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, AutoTokenizer
import shap
import numpy as np

# Explain the text branch of the classifier with SHAP: build a text masker,
# wrap the multimodal model so that it only takes text, and plot the
# per-token attributions for one sample sentence.

#### TODO: ALSO IMPLEMENT PUNCTUATION PREPROCESSING OVER THE INPUT PROMPT FOR THE MODEL !!! ####

#clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")



# Tokenizer handed to the SHAP text masker below.
# NOTE(review): this is the BERT tokenizer, not the CLIP one (commented out
# above) — presumably its token boundaries differ from the classifier's own
# tokenization; confirm this is intended.
mask_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# 3. Create an instance of the model and load the state dictionary
# The checkpoint is mapped to CPU. torch.load unpickles the file, so only
# trusted checkpoints should be loaded here.
checkpoint = torch.load('/kaggle/input/model-params/model_3.pth', map_location=torch.device('cpu'))
classifier = MisogynyCls(5)
classifier.load_state_dict(checkpoint)
only_text_classifier = OnlyTextCls(classifier)

# Inference mode; `only_text_classifier` wraps this same `classifier` object.
classifier.eval()

### Testing the masker for the explainer ###
clip_masker = shap.maskers.Text(mask_tokenizer, mask_token="...", collapse_mask_token=False) # TODO: change the mask token
explainer = shap.Explainer(only_text_classifier, masker=clip_masker)

sample_text = ["crazy slut bitch yoo shut up"]
# Compute SHAP values
shap_values = explainer(sample_text) # (original note left unfinished: "inside the explainer the input will be ...")

print(shap_values)
#shap.plots.waterfall(shap_values[0], max_display=14)
shap.plots.text(shap_values[0])


## Converting the TSV annotations to JSON

In [None]:
import csv
import json
import os

images_path = "/kaggle/input/dataset-wow/MAMI DATASET/MAMI DATASET/training/TRAINING"

# (source TSV, destination JSON) pairs.
file_and_dest = [('/kaggle/input/dataset-wow/train_image_text.tsv','/kaggle/working/train_image_text.json'),
                    ('/kaggle/input/dataset-wow/test_image_text.tsv','/kaggle/working/test_image_text.json')]


def convert_tsv_to_json(tsv_path, json_path):
    """Dump every row of a tab-separated file as a JSON list of dicts.

    Args:
        tsv_path: path of the TSV file; its first line is used as the header.
        json_path: path of the JSON file to (over)write.

    Returns:
        The number of rows written.
    """
    # Read everything before touching the destination, so that a failure
    # while parsing never leaves an output file behind.
    with open(tsv_path, newline='', encoding='utf-8') as tsvfile:
        rows = list(csv.DictReader(tsvfile, delimiter='\t'))

    with open(json_path, 'w', encoding='utf-8') as jsonfile:
        json.dump(rows, jsonfile, ensure_ascii=False, indent=4)

    return len(rows)


for tsv_path, json_path in file_and_dest:
    # Existing outputs are kept as they are; check first so the TSV is not
    # parsed for nothing.
    if os.path.exists(json_path):
        continue

    convert_tsv_to_json(tsv_path, json_path)
    print(f"File JSON salvato come {json_path}")

## Multimodal Dataset

In [None]:
from torch.utils.data import Dataset
import os
from tqdm import tqdm
import wandb
import json

class MultimodalDataset(Dataset):
    """Dataset exposing, per sample, the image file name, its text and the labels.

    Args:
        images_dir: directory where the images are stored (kept on the instance).
        json_file_path: path of the JSON metadata file, labels included.
    """

    def __init__(self, images_dir, json_file_path):
        self.images_dir = images_dir
        # load_json_file returns the seven parallel sequences in this order.
        (
            self.file_paths,
            self.text_list,
            self.labels_misogyny,
            self.shaming_label_list,
            self.stereotype_label_list,
            self.objectification_label_list,
            self.violence_label_list,
        ) = load_json_file(json_file_path)

    def __len__(self):
        """Number of samples, i.e. number of image file names."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return the 7-tuple (file name, text, misogyny, shaming, stereotype, objectification, violence)."""
        columns = (
            self.file_paths,
            self.text_list,
            self.labels_misogyny,
            self.shaming_label_list,
            self.stereotype_label_list,
            self.objectification_label_list,
            self.violence_label_list,
        )
        return tuple(column[idx] for column in columns)
    
def load_json_file(json_file_path):
    """Read the JSON metadata file and split it into parallel columns.

    Each item of the JSON list must provide the keys ``file_name``, ``text``,
    ``label``, ``shaming``, ``stereotype``, ``objectification`` and
    ``violence``; label values are converted with ``float``.

    Args:
        json_file_path: path of the JSON file to read.

    Returns:
        A 7-tuple ``(image_list, text_list, labels_misogyny, shaming,
        stereotype, objectification, violence)``: the first two are Python
        lists, the remaining five are ``float32`` tensors.
    """
    # The file is written as UTF-8 with ensure_ascii=False, so it must not be
    # decoded with the platform default encoding.
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    image_list = []  # image file names
    text_list = []   # text related to each image
    # "label" is the TASK A target; the other four are the TASK B targets.
    label_keys = ("label", "shaming", "stereotype", "objectification", "violence")
    label_columns = {key: [] for key in label_keys}

    for item in tqdm(data):
        image_list.append(item['file_name'])
        text_list.append(item["text"])
        for key in label_keys:
            label_columns[key].append(float(item[key]))

    return (
        image_list,
        text_list,
        *(torch.tensor(label_columns[key], dtype=torch.float32) for key in label_keys),
    )
    
def accuracy(preds, labels, thresh):
    """Percentage of samples whose thresholded prediction matches the labels.

    Args:
        preds: tensor of scores; a score counts as positive when ``> thresh``.
        labels: tensor of targets, interpreted through ``labels.bool()``.
        thresh: decision threshold applied to ``preds``.

    Returns:
        Accuracy in the ``[0, 100]`` range. For 1-D inputs every element is a
        sample; otherwise a sample is correct only if its whole row matches.
        An empty batch yields ``0.0``.
    """
    num_samples = labels.shape[0]
    if num_samples == 0:
        return 0.0

    # A head with a single output produces (N, 1) scores while the labels are
    # (N,): comparing them directly would broadcast to an (N, N) matrix.
    if preds.shape != labels.shape and preds.numel() == labels.numel():
        preds = preds.reshape(labels.shape)

    preds = preds > thresh
    matching_rows = torch.eq(labels.bool(), preds)

    # Task B predictions are 2-D (all columns must match), task A ones are 1-D.
    if preds.ndim != 1:
        num_correct = matching_rows.all(dim=1).sum().item()
    else:
        num_correct = matching_rows.sum().item()
    return 100*(num_correct/num_samples)

## Explaining Image

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from PIL import Image
import os
from tqdm import tqdm


TRAIN_IMAGES_DIR = "/kaggle/input/dataset-wow/MAMI DATASET/MAMI DATASET/training/TRAINING"
TEST_IMAGES_DIR = "/kaggle/input/dataset-wow/MAMI DATASET/MAMI DATASET/test"


def load_image_tensors(images_dir, file_names):
    """Load the given image files as a list of 3-channel tensors.

    Args:
        images_dir: directory containing the images.
        file_names: iterable of file names relative to ``images_dir``.

    Returns:
        One ``ToTensor`` output per file name, in the same order.
    """
    to_tensor = ToTensor()
    tensors = []
    for file_name in file_names:
        # The context manager releases the file handle right after decoding.
        with Image.open(os.path.join(images_dir, file_name)) as img:
            # Grayscale / palette / RGBA files would otherwise yield tensors
            # with 1 or 4 channels instead of 3.
            tensors.append(to_tensor(img.convert("RGB")))
    return tensors


train_data = MultimodalDataset(TRAIN_IMAGES_DIR, "/kaggle/working/train_image_text.json")
test_data = MultimodalDataset(TEST_IMAGES_DIR, "/kaggle/working/test_image_text.json")
train_dataloader = DataLoader(train_data, 100, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(test_data, 16, shuffle=True, pin_memory=True)

# Only the image file names (first field of every batch) are needed here.
batch_train = next(iter(train_dataloader))[0]
batch_test = next(iter(test_dataloader))[0]

batch_train = load_image_tensors(TRAIN_IMAGES_DIR, batch_train)
batch_test = load_image_tensors(TEST_IMAGES_DIR, batch_test)


In [None]:
import shap
import torch
from torchvision import transforms
import numpy as np

# Explain the image branch of the classifier with SHAP: blur-mask regions of
# one training image and plot the resulting attributions.
batch_size = 50   # number of masked images evaluated per model call
n_evals = 100     # evaluation budget handed to the explainer
resize = transforms.Resize((440, 440))

# The checkpoint is mapped to CPU. torch.load unpickles the file, so only
# trusted checkpoints should be loaded here.
checkpoint = torch.load('/kaggle/input/model-params/model_3.pth', map_location=torch.device('cpu'))
classifier = MisogynyCls(5)
classifier.load_state_dict(checkpoint)
only_image_classifier = OnlyImageCls(classifier)
classifier.eval()

data = [resize(img) for img in batch_train]

# Channels-first -> channels-last, the layout used by the image masker/plot.
# NOTE(review): `batch_train` holds ToTensor outputs, i.e. floats in [0, 1];
# the CLIP processor presumably rescales its inputs by 1/255 again — confirm
# the value range the classifier is expected to receive.
data_to_test = data[8].permute(1, 2, 0)
print(data_to_test.shape)

# The masker needs the shape of the whole image. `data_to_test[0].shape` is the
# shape of a single pixel row and does not match the explained input.
masker_blur = shap.maskers.Image("blur(128,128)", tuple(data_to_test.shape))
explainer = shap.Explainer(only_image_classifier, masker_blur)

shap_values = explainer(
    data_to_test.unsqueeze(0),
    max_evals=n_evals,
    batch_size=batch_size,
)

print(shap_values.data.shape)
print(shap_values.values.shape)

shap.image_plot(
    # Drop the trailing model-output axis (axis 4) before plotting.
    shap_values=shap_values.values.squeeze(4),
    # Works whether the explanation stores a tensor or an ndarray.
    pixel_values=np.asarray(shap_values.data),
)