In [1]:
import sys
import os
sys.path.append(os.path.abspath('/Users/ericxia/school/Math-148-Project/food-classification'))

import json
from PIL import Image 
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader

from data_utils.utils import keep_existing_photos, downsample_group, preprocess_image
from data_utils.dataset import PhotoLabelDataset, stratified_split_dataset
from model.resnet18 import Resnet18FineTuneModel
from model.fusion_model import FusionModel
from model.utils import get_device
from model.gradcam import GradCAM

from captum.attr import IntegratedGradients
from transformers import BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
base_dir = "../"

business_df = pd.read_json(f'{base_dir}data/yelp_dataset/yelp_academic_dataset_business.json', lines=True)
photos_df = pd.read_json(f'{base_dir}data/yelp_photos/photos.json', lines=True)
top_reviews_per_restaurant = pd.read_csv(f'{base_dir}data/yelp_dataset/top_reviews_per_restaurant_with_summary.csv')

photo_dir = f"{base_dir}data/yelp_photos/resized_photos"
photos_df = keep_existing_photos(photos_df, photo_dir)

photos_df = photos_df[photos_df['label'] == 'food'].copy()

categories_df = business_df[['business_id', 'attributes']].copy()
photos_df = photos_df.merge(categories_df, on="business_id", how="left")

photos_df = photos_df[photos_df['attributes'].notna()]
photos_df['price_range'] = photos_df['attributes'].apply(lambda x: x.get('RestaurantsPriceRange2'))

photos_df['price_range'] = photos_df['price_range'].astype(int)
photos_df['price_range'] = photos_df['price_range'].replace({2: 1, 3: 2, 4: 2}) 

food_with_reviews_df = photos_df.merge(top_reviews_per_restaurant, on="business_id", how="left")

Checking images: 100%|██████████| 200100/200100 [00:09<00:00, 21162.56it/s]


In [3]:
with open(f'{base_dir}data/ids.json') as f:
    ids_dict = json.load(f)

test_ids = ids_dict['test_ids']

In [4]:
food_with_reviews_df = food_with_reviews_df[food_with_reviews_df['business_id'].isin(test_ids)]

In [5]:
test_photo_ids = food_with_reviews_df.photo_id.tolist()

In [7]:
device = get_device()
image_path = f"{base_dir}data/yelp_photos/resized_photos/{test_photo_ids[0]}.jpg"
original_image, input_tensor = preprocess_image(image_path)
input_tensor = input_tensor.to(device)

In [None]:
num_classes = 2
device = get_device()

binary_image_only = FusionModel(num_classes=num_classes)

price_ckpt_19 = torch.load("../checkpoints/price_2_classes/ckpt_1")
binary_image_only.load_state_dict(price_ckpt_19['model_state_dict'])
binary_image_only.to(device)
binary_image_only.eval()

KeyboardInterrupt: 

In [9]:
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def make_forward_func(model, target_label=0):
    """
    Creates and returns a forward function that:
      - Takes (images, input_ids, attention_mask) as arguments
      - Returns logits[:, target_label]
    """
    def forward_func(images, input_ids, attention_mask):
        logits = model(images, input_ids, attention_mask)  # shape: [B, num_classes]
        return logits[:, target_label]                     # shape: [B]
    return forward_func


def explain_text_with_ig(model, images, input_ids, attention_mask, target_label=0):
    """
    Compute integrated gradients attributions for the given batch with respect
    to `target_label`. Returns a tuple of attributions:
        (attr_images, attr_ids, attr_mask)
    where each matches the shape of the corresponding input.
    """
    model.eval()  # ensure eval mode
    # Make sure grads are allowed; if your model is frozen, you at least need
    # the inputs to require grad or final layers unfrozen to see something.

    # 1) Build the forward function that Captum will use.
    forward_func = make_forward_func(model, target_label=target_label)

    # 2) Create the IntegratedGradients object with this forward function.
    ig = IntegratedGradients(forward_func)

    # 3) Create baselines for each input: 
    #    e.g. zero image, zero input_ids, zero attention_mask
    pad_images = torch.zeros_like(images)
    pad_ids = torch.zeros_like(input_ids)
    pad_attention_mask = torch.zeros_like(attention_mask)

    # 4) Now call ig.attribute with a *tuple* of inputs and baselines.
    #    The forward_func will receive these inputs in the same order.
    input_tuple = (images, input_ids, attention_mask)
    baseline_tuple = (pad_images, pad_ids, pad_attention_mask)

    attributions = ig.attribute(
        inputs=input_tuple,
        baselines=baseline_tuple,
        n_steps=10,               # number of steps in the integration path
        internal_batch_size=1     # can help if memory is an issue
        # target=target_label is unnecessary because
        # we already handle that inside forward_func
    )

    # attributions is a tuple of the same structure: (attr_images, attr_ids, attr_mask)
    return attributions

def predict_text_with_image_emb(embeddings, attention_mask, image_tensor, model, target_class=None):
    """Predict using embeddings instead of input_ids"""
    model.eval()
    
    # Extract features from both image and text
    image_features = model.image_encoder(image_tensor)  # Already moved to MPS
    
    # Use embeddings directly instead of input_ids
    text_features = model.text_encoder.bert(inputs_embeds=embeddings, attention_mask=attention_mask).last_hidden_state[:, 0, :]
    text_features = model.text_encoder.fc(text_features)  # Keep all tensors on MPS
    image_features = image_features.expand(text_features.shape[0], -1)
    
    fused_features = model.fusion_mlp(torch.cat([image_features, text_features], dim=-1))
    logits = model.mlp_classifier(fused_features)

    if target_class is None:
        return logits
    return logits[:, target_class]

# ✅ Function to compute attributions
def interpret_text_with_image(image_path, text, model, target_class=1, n_steps=50):
    """Compute Integrated Gradients for the text component while keeping image fixed."""
    
    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)  # ✅ Move to MPS

    # Tokenize input text
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs["input_ids"].to(torch.long).to(device)  # ✅ Move to MPS
    attention_mask = inputs["attention_mask"].to(torch.long).to(device)  # ✅ Move to MPS

    # Get embeddings from BERT's embedding layer
    embeddings = model.text_encoder.bert.embeddings.word_embeddings(input_ids).to(torch.float32).to(device)  # ✅ Convert to float32
    embeddings.requires_grad = True  # Compute gradients on embeddings

    # ✅ Convert baseline tensor to float32
    baselines = (embeddings * 0).to(torch.float32)

    # Define Integrated Gradients on embeddings
    ig = IntegratedGradients(lambda emb: predict_text_with_image_emb(emb, attention_mask, image_tensor, model, target_class))

    # ✅ Compute attributions and ensure `step_sizes` are float32
    attributions, _ = ig.attribute(
        embeddings, 
        baselines=baselines,
        n_steps=n_steps, 
        return_convergence_delta=True
    )

    # Convert attributions to numpy for visualization
    attributions = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
    
    # Map attributions to tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).cpu().numpy())
    return dict(zip(tokens, attributions))


### Intergrated gradient

In [7]:
device = get_device()
model = FusionModel(explain_model=True)
price_ckpt_15 = torch.load("checkpoints/price_model_multimodal_binary_unique_restaurants_duplicate_photos/ckpt_15", map_location=torch.device("cpu"))
model.load_state_dict(price_ckpt_15['model_state_dict'])
model.to(device) 
model.eval()



FusionModel(
  (image_encoder): ResNetFeatureExtractor(
    (feature_extractor): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       

In [27]:
food_with_reviews_df

Unnamed: 0,photo_id,business_id,caption,label,attributes,price_range,text,summary
93,rU04nXxljNXGquvdCd9Gtw,HqgMCVe-jvlP0kT93M_5Sg,Seafood platter and 1lb of crawfish,food,"{'RestaurantsPriceRange2': '1', 'Alcohol': ''b...",1,Crispy Cajun is a small casual eat-in and take...,Crispy Cajun is a small casual eat-in and tak...
104,Ovta76oHlqk1p5EE34K4RA,l2TdZEPHrboWrMUnoFB1RQ,Gnocchi Bolognese,food,"{'BikeParking': 'True', 'Corkage': 'False', 'B...",1,Mangia Macaroni had been on my radar for a whi...,Mangia Macaroni is a BYOB-only restaurant . T...
128,k1Bv2jVTuMZ0ZIJi6pVHSQ,8-0EBUcwlrRMWhSNdT-5FQ,,food,"{'RestaurantsPriceRange2': '2', 'BYOBCorkage':...",1,Good sushi. I ordered the sashimi deluxe and g...,Sushi chef is lazy and just makes the sushi a...
133,nkY1w_ZCIDbEXYfTAvPguQ,iMWGmFrMVEEktWNriZ2cTQ,WOW!,food,"{'BusinessAcceptsCreditCards': 'True', 'HappyH...",1,Oh my gosh! Where do I begin?\n\nThe food is a...,Little Gourmand is an option with Postmates a...
215,wU88wTejnJJaBtHQq8adEg,dEUF0eTd9a1xOYvYu46dsQ,"Basil Pesto Agnolotti - half moon, bocconcini ...",food,"{'RestaurantsAttire': 'u'casual'', 'HasTV': 'F...",1,My family and I went here last night for the f...,Broad St. in Woodbury is mostly a ghost town ...
...,...,...,...,...,...,...,...,...
96720,ae5HeQUx9zJnmUKJzDDN7Q,7_1GqlDlbkShY0az7J0XNg,Beef tacos (very filling),food,"{'RestaurantsAttire': 'u'casual'', 'NoiseLevel...",1,I almost have to thank La Cocina - because fro...,La Cocina is located on the site of El Presid...
96727,RAQNEAF4awAmuT5GeNNkHw,UoDicg0wO3Q1JPUymA-91w,Johnny Brusco's New York Style Pizza,food,"{'HasTV': 'True', 'RestaurantsDelivery': 'True...",1,4.25 stars rounded up for this pizzeria chain....,Johnny Brusco's is a very good chain pizzeria...
96789,G79hHnRgBN23Zdy4CGkziw,-3-6BB10tIWNKGEF0Es2BA,Sliced Duck Rice Noodle,food,"{'Alcohol': 'u'none'', 'DogsAllowed': 'False',...",1,What what? I can get Taiwanese food in Tucson ...,The scallion cake was everything I wanted it ...
96796,9kRazWNbBrNOeFec_8MPiw,0RxU5OglQyPVtLKC1LPgsA,,food,"{'Caters': 'True', 'GoodForMeal': '{'dessert':...",1,Love this amazing bakery! Waiting in the hot s...,Bree'osh is tucked away in a popular strip ma...


In [10]:
row = food_with_reviews_df.iloc[0]
image_path = f"{base_dir}data/yelp_photos/resized_photos/{row.photo_id}.jpg"
text = row.summary

inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
input_ids = inputs["input_ids"].to(torch.long).to(device)
attention_mask = inputs["attention_mask"].to(torch.long).to(device)  
original_image, input_tensor = preprocess_image(image_path)
input_tensor = input_tensor.to(device)

attributions = explain_text_with_ig(model, 
                    input_tensor, 
                    input_ids, 
                    attention_mask, 
                    target_label=0)

# Print word attributions
print("Token attributions:", attributions)

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got MPSFloatType instead (while checking arguments for embedding)

In [None]:
# Example usage
example_text = "This is a very positive and happy review!"
tokenized = tokenizer(example_text, return_tensors="pt", padding=True, truncation=True)

attributions = interpret_text(BertFeatureExtractor(), tokenized['input_ids'], tokenized['attention_mask'], class_idx=1)
visualize_text_attributions(example_text, attributions)