In [None]:
!pip install torch torchvision pandas scikit-learn tqdm

In [1]:
import torch
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from sklearn.neural_network import MLPClassifier
import joblib
from huggingface_hub import snapshot_download
import json
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
snapshot_download(repo_id="google/spiqa", repo_type="dataset", local_dir='.') ### Mention the local directory path

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 18 files:   0%|          | 0/18 [00:00<?, ?it/s]

'/content'

In [None]:
!unzip /content/test-A/SPIQA_testA_Images.zip

In [None]:
!unzip /content/train_val/SPIQA_train_val_Images.zip | head -n 500

In [102]:
file_path = "/content/train_val/SPIQA_train.json"

with open(file_path, 'r') as file:
    text = json.load(file)

data = []
cols = ['paper', 'question', 'answer', 'reference_figure', 'reference_figure_caption']
for paper in text.keys():
    for question in text[paper]['qa']:
        data.append([paper, question['question'], question['answer'], question['reference'], text[paper]['all_figures'][question['reference']]['caption']])

df = pd.DataFrame(data, columns=cols)

file_path = "/content/test-A/SPIQA_testA.json"
with open(file_path, 'r') as file:
    text = json.load(file)

data = []
cols = ['paper', 'question', 'answer', 'reference_figure', 'reference_figure_caption']
for paper in text.keys():
    for question in text[paper]['qa']:
        data.append([paper, question['question'], question['answer'], question['reference'], text[paper]['all_figures'][question['reference']]['caption']])


test_df = pd.DataFrame(data, columns=cols)
test_df['generated_answer'] = np.nan

In [80]:
df.head()

Unnamed: 0,paper,question,answer,reference_figure,reference_figure_caption
0,1612.01810v3,"Which method converges faster, joint or separa...",The joint method converges faster than the sep...,1612.01810v3-Figure6-1.png,Comparison of the convergence rate between jo...
1,1612.01810v3,What is the effect of increasing the number of...,The boundary recall generally increases and th...,1612.01810v3-Table1-1.png,Boundary recall and time cost comparisons bet...
2,1612.01810v3,What is the main difference between the search...,"The main difference is that in SLIC, each seed...",1612.01810v3-Figure1-1.png,(a) The search method used in SLIC. Each seed...
3,1612.01810v3,How does the number of superpixels affect the ...,"As the number of superpixels decreases, the re...",1612.01810v3-Figure8-1.png,Images segmented by our proposed approach wit...
4,1612.01810v3,Which algorithm produces the most compact sup...,Our approach.,1612.01810v3-Figure7-1.png,Visual comparison of superpixel segmentation ...


In [81]:
len(df)

262524

In [115]:
df["image_path"] = "/content/SPIQA_train_val_Images/" + df["paper"] + "/" + df["reference_figure"]
test_df["image_path"] = "/content/SPIQA_testA_Images/" + test_df["paper"] + "/" + test_df["reference_figure"]

df['exists'] = df['image_path'].apply(lambda x: os.path.exists(x))
df=df[df['exists']]
len(df)

445

In [63]:
resnet50 = models.resnet50(pretrained=True)
resnet50 = torch.nn.Sequential(*list(resnet50.children())[:-1])
resnet50.eval()
resnet50 = resnet50.to(device)


sbert = SentenceTransformer('all-MiniLM-L6-v2')

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])



In [93]:
class ImageDataset(Dataset):
    def __init__(self, df, transform):
        self.df = df
        self.transform = transform
        self.answers = df['answer'].tolist()
        self.answer_embeddings = sbert.encode(self.answers, convert_to_tensor=True).to(device)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
      img_path = self.df.iloc[idx]['image_path']
      # image = Image.open(img_path).convert('RGB')
      try:
          image = Image.open(img_path).convert('RGB')
      except Exception as e:
          # print(f"Error loading image {img_path}: {e}")
          image = Image.new('RGB', (224, 224), (0, 0, 0))

      image = self.transform(image)
      img_feature = resnet50(image.unsqueeze(0).to(device)).squeeze().to(device)

      answer_embedding = self.answer_embeddings[idx].to(device)

      return img_feature, answer_embedding

In [94]:
dataset = ImageDataset(df, transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [95]:
class MLP(nn.Module):
    def __init__(self, input_dim=2048, output_dim=384, hidden_dim=1024):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.model(x)

mlp = MLP()
criterion = nn.MSELoss()
optimizer = optim.Adam(mlp.parameters(), lr=0.001)

num_epochs = 10
mlp.to(device)
resnet50.to(device)


Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [98]:
for epoch in tqdm(range(num_epochs)):
    for img_features, answer_embeddings in dataloader:

      img_features, answer_embeddings = img_features.to(device), answer_embeddings.to(device)
      optimizer.zero_grad()
      output = mlp(img_features)
      loss = criterion(output, answer_embeddings)
      loss.backward()
      optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

 10%|█         | 1/10 [00:12<01:48, 12.09s/it]

Epoch 1, Loss: 0.0028


 20%|██        | 2/10 [00:24<01:36, 12.11s/it]

Epoch 2, Loss: 0.0024


 30%|███       | 3/10 [00:35<01:20, 11.57s/it]

Epoch 3, Loss: 0.0023


 40%|████      | 4/10 [00:46<01:10, 11.67s/it]

Epoch 4, Loss: 0.0023


 50%|█████     | 5/10 [00:59<00:59, 11.90s/it]

Epoch 5, Loss: 0.0023


 60%|██████    | 6/10 [01:11<00:48, 12.04s/it]

Epoch 6, Loss: 0.0023


 70%|███████   | 7/10 [01:24<00:36, 12.17s/it]

Epoch 7, Loss: 0.0023


 80%|████████  | 8/10 [01:36<00:24, 12.19s/it]

Epoch 8, Loss: 0.0022


 90%|█████████ | 9/10 [01:47<00:11, 11.86s/it]

Epoch 9, Loss: 0.0022


100%|██████████| 10/10 [01:59<00:00, 11.95s/it]

Epoch 10, Loss: 0.0022





In [116]:
def retrieve_answer(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    img_feature = resnet50(image).squeeze().detach()
    predicted_embedding = mlp(img_feature.unsqueeze(0)).detach().cpu().numpy()
    answer_embeddings = sbert.encode(test_df['answer'].tolist())
    similarities = cosine_similarity(predicted_embedding, answer_embeddings)
    best_match_idx = np.argmax(similarities)
    return test_df['answer'].iloc[best_match_idx]


print(retrieve_answer('/content/SPIQA_testA_Images/1603.00286v5/1603.00286v5-Figure1-1.png'))

The representation module takes an input image and outputs a feature representation. The learning-to-learn module takes a set of features and learns how to segment the image.


In [117]:
test_df

Unnamed: 0,paper,question,answer,reference_figure,reference_figure_caption,generated_answer,image_path
0,1611.04684v1,What are the main differences between the educ...,The Bonaparte school focuses on outdoor physic...,1611.04684v1-Table1-1.png,A difficult example from QA,,/content/SPIQA_testA_Images/1611.04684v1/1611....
1,1611.04684v1,Which model performs the best for response sel...,The KEHNN model performs the best for response...,1611.04684v1-Table4-1.png,Table 4: Evaluation results on response selection,,/content/SPIQA_testA_Images/1611.04684v1/1611....
2,1611.04684v1,Which model performs best on the Ubuntu datase...,KEHNN,1611.04684v1-Table5-1.png,Accuracy on different length of text,,/content/SPIQA_testA_Images/1611.04684v1/1611....
3,1611.04684v1,What is the role of the knowledge gates in the...,The knowledge gates are responsible for select...,1611.04684v1-Figure1-1.png,Architecture of KEHNN,,/content/SPIQA_testA_Images/1611.04684v1/1611....
4,1611.04684v1,How does the average number of answers per que...,The training set has a higher average number o...,1611.04684v1-Table2-1.png,Table 2: Statistics of the answer selection da...,,/content/SPIQA_testA_Images/1611.04684v1/1611....
...,...,...,...,...,...,...,...
661,1809.04276v2,How does the discriminator in the proposed REA...,The discriminator takes as input a response an...,1809.04276v2-Figure1-1.png,Figure 1: An overview of our proposed approach...,,/content/SPIQA_testA_Images/1809.04276v2/1809....
662,1710.01507v4,What is the role of the LSTM network in the mo...,The LSTM network is used to process the post t...,1710.01507v4-Figure1-1.png,Model Architecture,,/content/SPIQA_testA_Images/1710.01507v4/1710....
663,1709.00139v4,Which method generally achieved a lower object...,"For all datasets presented, Incremental SVM ac...",1709.00139v4-Table1-1.png,Table 1: Experimental Results of FISVDD and In...,,/content/SPIQA_testA_Images/1709.00139v4/1709....
664,1809.01989v2,Which method achieved the highest tracking acc...,The Ridge method achieved the lowest sum of ab...,1809.01989v2-Table1-1.png,Table 1. Absolute percentage errors for differ...,,/content/SPIQA_testA_Images/1809.01989v2/1809....


In [118]:
test_df['generated_answer'] = test_df['image_path'].apply(retrieve_answer)

In [119]:
test_df

Unnamed: 0,paper,question,answer,reference_figure,reference_figure_caption,generated_answer,image_path
0,1611.04684v1,What are the main differences between the educ...,The Bonaparte school focuses on outdoor physic...,1611.04684v1-Table1-1.png,A difficult example from QA,The VQA dataset presents the biggest challenge.,/content/SPIQA_testA_Images/1611.04684v1/1611....
1,1611.04684v1,Which model performs the best for response sel...,The KEHNN model performs the best for response...,1611.04684v1-Table4-1.png,Table 4: Evaluation results on response selection,The proposed model has the highest accuracy on...,/content/SPIQA_testA_Images/1611.04684v1/1611....
2,1611.04684v1,Which model performs best on the Ubuntu datase...,KEHNN,1611.04684v1-Table5-1.png,Accuracy on different length of text,"The ""Conv. Self-Correction"" method achieves th...",/content/SPIQA_testA_Images/1611.04684v1/1611....
3,1611.04684v1,What is the role of the knowledge gates in the...,The knowledge gates are responsible for select...,1611.04684v1-Figure1-1.png,Architecture of KEHNN,The representation module takes an input image...,/content/SPIQA_testA_Images/1611.04684v1/1611....
4,1611.04684v1,How does the average number of answers per que...,The training set has a higher average number o...,1611.04684v1-Table2-1.png,Table 2: Statistics of the answer selection da...,The log model.,/content/SPIQA_testA_Images/1611.04684v1/1611....
...,...,...,...,...,...,...,...
661,1809.04276v2,How does the discriminator in the proposed REA...,The discriminator takes as input a response an...,1809.04276v2-Figure1-1.png,Figure 1: An overview of our proposed approach...,The decoder is responsible for generating the ...,/content/SPIQA_testA_Images/1809.04276v2/1809....
662,1710.01507v4,What is the role of the LSTM network in the mo...,The LSTM network is used to process the post t...,1710.01507v4-Figure1-1.png,Model Architecture,The representation module takes an input image...,/content/SPIQA_testA_Images/1710.01507v4/1710....
663,1709.00139v4,Which method generally achieved a lower object...,"For all datasets presented, Incremental SVM ac...",1709.00139v4-Table1-1.png,Table 1: Experimental Results of FISVDD and In...,The log model.,/content/SPIQA_testA_Images/1709.00139v4/1709....
664,1809.01989v2,Which method achieved the highest tracking acc...,The Ridge method achieved the lowest sum of ab...,1809.01989v2-Table1-1.png,Table 1. Absolute percentage errors for differ...,The log model.,/content/SPIQA_testA_Images/1809.01989v2/1809....


In [120]:
test_df.drop(columns=['image_path'], inplace=True)
test_df.to_csv('resnet_mlp_results.csv', index=False)