In [3]:
import pandas as pd
import numpy as np
from datasets import load_dataset, DatasetDict, Features, Dataset, Value, ClassLabel, Image


In [4]:
dataset = load_dataset("imagefolder", data_dir="/data_vault/hexai/ArtEmis-FinalSplits/")

label2id = {label: i for i, label in enumerate(dataset["train"].features["label"].names)}
id2label = {i: label for label, i in label2id.items()}

Resolving data files:   0%|          | 0/13311 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1903 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/3806 [00:00<?, ?it/s]

In [5]:
import regex as re
alpha_re = re.compile(r"[^a-zA-Z-_0-9]")

In [6]:
import sys
sys.getdefaultencoding()

'utf-8'

In [7]:
metadata = pd.read_csv("final-splits.csv")

def replace_unicode(x):
    unicode = {uni: "a" for uni in ['ã¶', 'ã\xad', 'ã©', 'ã¨', 'ã³', 'ã¼', 'â\xa0']}
    for uni in unicode:
        x = x.replace(uni, unicode[uni])
   
    return x

In [8]:
metadata["painting"] = metadata["painting"].apply(lambda x: replace_unicode(x))
metadata["painting"] = metadata["painting"].apply(lambda x: alpha_re.sub("", x))

In [9]:
from tqdm import tqdm
def add_utterance_data(dataset, metadata):
    # Training dataset
    train_utterances = []
    for batch in tqdm(dataset["train"]):
        painting = batch["image"].filename.split("/")[-1].split(".")[0]
        painting = alpha_re.sub("", painting)     
        meta_utt = metadata[metadata.painting.str.contains(painting)]["utterance"].values[0]
        train_utterances.append(meta_utt)
    
    dataset["train"] = dataset["train"].add_column("utterance", train_utterances)

    # Validation dataset
    valid_utterances = []
    for batch in tqdm(dataset["validation"]):
        painting = batch["image"].filename.split("/")[-1].split(".")[0]
        painting = alpha_re.sub("", painting)     
        meta_utt = metadata[metadata.painting.str.contains(painting)]["utterance"].values[0]
        valid_utterances.append(meta_utt)
    
    dataset["validation"] = dataset["validation"].add_column("utterance", valid_utterances)

    # Testing dataset
    test_utterances = []
    for batch in tqdm(dataset["test"]):
        painting = batch["image"].filename.split("/")[-1].split(".")[0]
        painting = alpha_re.sub("", painting)     
        meta_utt = metadata[metadata.painting.str.contains(painting)]["utterance"].values[0]
        test_utterances.append(meta_utt)
    
    dataset["test"] = dataset["test"].add_column("utterance", test_utterances)

    return dataset

In [10]:
dataset = add_utterance_data(dataset, metadata)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 13311/13311 [03:12<00:00, 69.15it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 1903/1903 [00:27<00:00, 69.57it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 3806/3806 [00:54<00:00, 69.61it/s]


In [11]:
dataset["test"]

Dataset({
    features: ['image', 'label', 'utterance'],
    num_rows: 3806
})

In [12]:
from transformers import AutoModel, AutoImageProcessor, AutoModelForImageClassification, AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

vision_checkpoint = "efficientnet_best/"
bert_checkpoint = "bert_best/"

image_processor = AutoImageProcessor.from_pretrained(vision_checkpoint)
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased", do_lower_case=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

def image_transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
 
    del examples["image"]
    return examples

def tokenize(examples):
    return tokenizer(examples['utterance'], padding='max_length', truncation=True, max_length=128)

In [13]:
text_dataset = dataset.map(tokenize, batched=True)
text_dataset.set_format(type='torch', columns=["input_ids", "attention_mask", "label"])

In [14]:
image_dataset = dataset.with_transform(image_transforms)

In [38]:
image_model = AutoModelForImageClassification.from_pretrained("vit_best/").cuda()
text_model = AutoModelForSequenceClassification.from_pretrained("bert_best_v2/").cuda()

In [39]:
import torch
import torch.nn as nn
class LateFusionModel(nn.Module):
    def __init__(self, cv_model, text_model, weights = [0.5, 0.5]):
        super(LateFusionModel, self).__init__()
        self.cv_model = cv_model
        self.text_model = text_model
        self.softmax = nn.Softmax(dim=1)
        self.weights = weights

    def forward(self, image, text, attention_mask):
        img_out = self.cv_model(image)
        text_out = self.text_model(text, attention_mask)

        return self.weights[0] * self.softmax(img_out.logits) + self.weights[1] * self.softmax(text_out.logits)

In [46]:
lfm = LateFusionModel(image_model, text_model, weights=[0.75, 0.25]).cuda()

In [47]:
img_data = next(iter(image_dataset["test"]))
txt_data = next(iter(text_dataset["test"]))

In [48]:
pixels = img_data["pixel_values"].cuda().unsqueeze(dim=0)
input_ids = txt_data["input_ids"].cuda().unsqueeze(dim=0)
attn_mask = txt_data["attention_mask"].cuda().unsqueeze(dim=0)


In [49]:
preds, labels = [], []
for img_batch, txt_batch in zip(image_dataset["test"], text_dataset["test"]):
    pixels = img_batch["pixel_values"].cuda().unsqueeze(dim=0)
    input_ids = txt_data["input_ids"].cuda().unsqueeze(dim=0)
    attn_mask = txt_data["attention_mask"].cuda().unsqueeze(dim=0)
    
    out = lfm(pixels, input_ids, attn_mask)
    predictions = np.argmax(out.detach().cpu().numpy(), axis=-1)
    preds.extend(predictions)
    labels.append(img_batch["label"])


In [50]:
from sklearn.metrics import classification_report, confusion_matrix
print(classification_report(labels, preds))

              precision    recall  f1-score   support

           0       0.39      0.29      0.33       708
           1       0.36      0.63      0.46       680
           2       0.40      0.02      0.04       208
           3       0.40      0.38      0.39       759
           4       0.27      0.52      0.36       454
           5       0.44      0.22      0.30       516
           6       0.47      0.22      0.30       481

    accuracy                           0.36      3806
   macro avg       0.39      0.33      0.31      3806
weighted avg       0.39      0.36      0.34      3806



In [51]:
from imblearn.metrics import specificity_score
print(1 - specificity_score(labels, preds, average="macro"))

0.10869076017138635
