# In [1]:
import torch
import torch.nn as nn
from torchvision import models

# Image Encoder
class ImageEncoder(nn.Module):
    """ResNet-50 backbone with its final classification layer removed.

    Produces one flat feature vector per input: the output of ResNet-50's
    global average pool, flattened to ``(batch, out_features)``
    (``out_features`` is 2048 for ResNet-50).

    Args:
        pretrained: If True, load torchvision's ImageNet weights
            (``IMAGENET1K_V1``, the same weights the legacy ``pretrained=True``
            flag selected); if False, use random initialisation.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        resnet = self._build_resnet50(pretrained)
        # Width of the pooled feature vector, read from the fc layer we are about
        # to discard so downstream heads need not hard-code it.
        self.out_features = resnet.fc.in_features
        # Keep every child except the trailing fc layer (i.e. up to and including avgpool).
        self.features = nn.Sequential(*list(resnet.children())[:-1])

    @staticmethod
    def _build_resnet50(pretrained):
        """Construct a torchvision ResNet-50, preferring the ``weights=`` API."""
        if hasattr(models, "ResNet50_Weights"):
            # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            return models.resnet50(weights=weights)
        # Older torchvision only understands the legacy boolean flag.
        return models.resnet50(pretrained=pretrained)

    def forward(self, x):
        """Encode a batch of images into flat feature vectors.

        Args:
            x: Image batch; presumably float ``(B, 3, H, W)``, ImageNet-normalised
                when ``pretrained=True`` — confirm against the data pipeline.

        Returns:
            Tensor of shape ``(B, out_features)``.
        """
        x = self.features(x)
        # avgpool leaves (B, C, 1, 1); collapse everything after the batch dim.
        return torch.flatten(x, 1)

# Entity Type Encoder
class EntityEncoder(nn.Module):
    """Learned lookup table turning integer entity-type ids into dense vectors.

    Args:
        num_entities: Number of rows in the table; valid ids are
            ``0 .. num_entities - 1``.
        embedding_dim: Length of each vector returned per id.
    """

    def __init__(self, num_entities, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_entities, embedding_dim)

    def forward(self, x):
        """Look up ``x`` in the table; the result has ``x``'s shape plus a trailing ``embedding_dim`` axis."""
        return self.embedding(x)

# Combined Model
class AttributePredictor(nn.Module):
    """Predict one scalar per sample from an image and its entity type.

    Image features (from ``ImageEncoder``) and an entity-type embedding (from
    ``EntityEncoder``) are concatenated and fed to a small MLP head ending in a
    single output unit.

    Args:
        num_entities: Size of the entity-type vocabulary.
        embedding_dim: Width of the entity-type embedding.
        pretrained: Forwarded to ``ImageEncoder``. Pass False to skip loading
            ImageNet weights (e.g. offline tests); defaults to True, the
            previously hard-wired behaviour.
        hidden_dim: Width of the MLP hidden layer (previously fixed at 512).
        dropout: Dropout probability in the head (previously fixed at 0.5).
    """

    # Pooled feature width of ResNet-50; used when the image encoder does not
    # report its own ``out_features``.
    IMAGE_FEATURE_DIM = 2048

    def __init__(self, num_entities, embedding_dim, *, pretrained=True,
                 hidden_dim=512, dropout=0.5):
        super().__init__()
        self.image_encoder = ImageEncoder(pretrained=pretrained)
        self.entity_encoder = EntityEncoder(num_entities, embedding_dim)
        image_dim = getattr(self.image_encoder, "out_features", self.IMAGE_FEATURE_DIM)
        # Layer order is unchanged (weights live at fc.0 and fc.3), so existing
        # state_dicts still load.
        self.fc = nn.Sequential(
            nn.Linear(image_dim + embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),  # single scalar output per sample
        )

    def forward(self, image, entity_type):
        """Predict the scalar attribute for each (image, entity_type) pair.

        Args:
            image: Batch passed to ``ImageEncoder``; presumably float
                ``(B, 3, H, W)`` — confirm against the data pipeline.
            entity_type: Integer entity ids; assumed shape ``(B,)`` — TODO
                confirm. A ``(B, 1)`` tensor would embed to ``(B, 1, D)`` and the
                concatenation with the 2-D image features would fail.

        Returns:
            Tensor of shape ``(B, 1)`` holding raw (un-activated) predictions.
        """
        image_features = self.image_encoder(image)
        entity_features = self.entity_encoder(entity_type)
        combined = torch.cat((image_features, entity_features), dim=1)
        return self.fc(combined)
