In [3]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, ViTFeatureExtractor, ViTForImageClassification
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset


In [4]:
food = load_dataset("food101", split="train[:2500]")

In [19]:
print(food)

Dataset({
    features: ['image', 'label'],
    num_rows: 2500
})


In [5]:
# Define the BERT and Vision Transformer models
class TextModel(nn.Module):
    def __init__(self):
        super(TextModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

class VisionModel(nn.Module):
    def __init__(self):
        super(VisionModel, self).__init__()
        self.vit = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k')

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.last_hidden_state

# Define the MLP head
class MLPHead(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLPHead, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        return self.layers(x)

# Combine the models and MLP head
class MultimodalClassifier(nn.Module):
    def __init__(self, text_model, vision_model, mlp_head):
        super(MultimodalClassifier, self).__init__()
        self.text_model = text_model
        self.vision_model = vision_model
        self.mlp_head = mlp_head

    def forward(self, text_input_ids, text_attention_mask, vision_pixel_values):
        text_features = self.text_model(text_input_ids, text_attention_mask)
        vision_features = self.vision_model(vision_pixel_values)

        # Assuming you want to concatenate the features
        combined_features = torch.cat([text_features, vision_features], dim=1)

        # MLP head
        output = self.mlp_head(combined_features)

        return output

In [21]:
food = food.train_test_split(test_size=0.2)


AttributeError: 'DatasetDict' object has no attribute 'train_test_split'

In [22]:
print(food["train"].features["label"])

ClassLabel(names=['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'mac

In [6]:
text_model = TextModel()
vision_model = VisionModel()
mlp_head = MLPHead(input_size=768 + 768, hidden_size=256, output_size=101)  # Assuming 768 is the size of BERT and ViT embeddings
model = MultimodalClassifier(text_model, vision_model, mlp_head)


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
for param in model.text_model.parameters():
    param.requires_grad = False

for param in model.vision_model.parameters():
    param.requires_grad = False

In [None]:
#create text and image processors + dictionnaries if needed