In [None]:
class TabTransformer(nn.Module):
    """Encode a flat tabular feature vector with a small Transformer encoder.

    The whole feature row is linearly projected to ``dim_embedding`` and treated
    as a single token. That token goes through ``num_layers`` encoder layers and
    is then pooled back to one vector per row.

    NOTE(review): because the sequence length is always 1, self-attention can
    only attend to the token itself. Each layer therefore acts like a residual
    feed-forward block rather than modelling interactions between features.
    Tokenizing per feature would change the parameter layout, so it is not done
    here.

    Args:
        num_features: width of each input row.
        dim_embedding: token width and output width. It must be divisible by
            ``num_heads`` (a requirement of ``nn.TransformerEncoderLayer``).
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
    """

    def __init__(self, num_features, dim_embedding=32, num_heads=2, num_layers=1):
        super().__init__()
        self.embedding = nn.Linear(num_features, dim_embedding)
        # Template layer; nn.TransformerEncoder stacks ``num_layers`` copies of it.
        layer_template = nn.TransformerEncoderLayer(
            d_model=dim_embedding,
            nhead=num_heads,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=num_layers)

    def forward(self, x):
        """Map ``x`` of shape (batch, num_features) to (batch, dim_embedding)."""
        # (batch, dim_embedding) -> (batch, 1, dim_embedding): a one-token sequence.
        tokens = self.embedding(x).unsqueeze(1)
        encoded = self.transformer(tokens)
        # Pool over the (length-1) sequence axis.
        return encoded.mean(dim=1)

class MobileNetV2Regression(nn.Module):
    """MobileNetV2 image encoder that produces a 64-dimensional embedding per image.

    Args:
        num_classes: unused (see the note in ``__init__``); kept for API compatibility.
        pretrained: load torchvision's ImageNet weights. This uses IMAGENET1K_V1,
            the checkpoint the legacy ``pretrained=True`` flag selects.
        train_backbone: if False, every parameter of the convolutional
            ``features`` stack has ``requires_grad`` disabled, so only the
            ``fc`` head receives gradients.
    """

    def __init__(self, num_classes=1, pretrained=True, train_backbone=False):
        super().__init__()
        # NOTE(review): ``num_classes`` is never used. The head always emits
        # 64 features, which OneModel depends on.
        mobilenet_v2 = self._load_mobilenet_v2(pretrained)
        self.features = mobilenet_v2.features
        if not train_backbone:
            for param in self.features.parameters():
                param.requires_grad = False
            # NOTE(review): this freezes weights only. If ``features`` contains
            # BatchNorm layers (torchvision's MobileNetV2 does), their running
            # statistics still update while the module is in train() mode.
            # Confirm that is intended.

        # NOTE(review): there is no activation between the two Linear layers,
        # so they compose to a single affine map (Dropout only acts in training).
        # Inserting one would shift the Sequential indices and break existing
        # checkpoints, so the layout is left as-is.
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            # Input width must match the backbone's output channels (1280 for
            # torchvision's MobileNetV2).
            nn.Linear(1280, 512),
            nn.Dropout(0.5),
            nn.Linear(512, 64),
        )

    @staticmethod
    def _load_mobilenet_v2(pretrained):
        """Build torchvision's MobileNetV2, avoiding the deprecated ``pretrained`` kwarg when possible."""
        weights_enum = getattr(models, "MobileNet_V2_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 has no weights enums; only the boolean flag exists.
            return models.mobilenet_v2(pretrained=pretrained)
        # torchvision >= 0.13 deprecates ``pretrained``. IMAGENET1K_V1 is exactly
        # what ``pretrained=True`` maps to; DEFAULT would select IMAGENET1K_V2.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.mobilenet_v2(weights=weights)

    def forward(self, x):
        """Return a (batch, 64) embedding for a 4-D (batch, channels, H, W) image tensor.

        With 1 channel, the input is replicated to 3 channels via a zero-copy
        ``expand``. With 3 channels, the expand is a no-op. Any other channel
        count makes ``expand`` raise.
        """
        x = x.expand(-1, 3, -1, -1)
        x = self.features(x)
        return self.fc(x)

class AttentionFusion(nn.Module):
    """Fuse two embeddings with a learned, per-sample convex combination.

    A linear layer scores the concatenated embeddings, and a softmax turns the
    two scores into weights that sum to 1. The output is
    ``w_img * image_embedding + w_tab * tabular_embedding``.

    Args:
        dim_image: width of the image embedding.
        dim_tabular: width of the tabular embedding. It must equal
            ``dim_image`` because the two embeddings are added element-wise.

    Raises:
        ValueError: if ``dim_image != dim_tabular``.
    """

    def __init__(self, dim_image, dim_tabular):
        super().__init__()
        if dim_image != dim_tabular:
            # Element-wise mixing needs equal widths. Unequal widths would
            # otherwise fail in forward() or, when one width is 1, silently
            # broadcast into a wrong result.
            raise ValueError(
                f"dim_image ({dim_image}) and dim_tabular ({dim_tabular}) must match "
                "for element-wise fusion"
            )
        self.attention = nn.Linear(dim_image + dim_tabular, 2)

    def forward(self, image_embedding, tabular_embedding):
        """Return the fused embedding with the same shape as each input.

        Both inputs are expected to be (batch, dim). They are concatenated along
        dim 1, and the weights are sliced to (batch, 1).
        """
        concatenated = torch.cat([image_embedding, tabular_embedding], dim=1)
        weights = F.softmax(self.attention(concatenated), dim=1)
        # 0:1 / 1:2 slices keep a trailing axis so each weight broadcasts over features.
        combined_embedding = (
            weights[:, 0:1] * image_embedding + weights[:, 1:2] * tabular_embedding
        )
        return combined_embedding

class OneModel(nn.Module):
    """Multimodal model: image and tabular embeddings, attention fusion, then a scalar output.

    Args:
        num_tabular_features: width of each tabular input row.
        pretrained: forwarded to MobileNetV2Regression.
        train_backbone: forwarded to MobileNetV2Regression. Note that the
            default here (True) differs from MobileNetV2Regression's (False).

    ``forward(image, tabular_data)`` returns a (batch, 1) tensor.
    """

    def __init__(self, num_tabular_features, pretrained=True, train_backbone=True):
        super().__init__()
        # Shared embedding width. It must match MobileNetV2Regression's
        # hard-coded 64-d head.
        embed_dim = 64
        # Submodules are created in a fixed order so parameter init and
        # registration order stay stable.
        self.image_model = MobileNetV2Regression(
            pretrained=pretrained,
            train_backbone=train_backbone,
        )
        self.tabular_model = TabTransformer(
            num_features=num_tabular_features,
            dim_embedding=embed_dim,
        )
        self.fusion = AttentionFusion(dim_image=embed_dim, dim_tabular=embed_dim)
        self.classifier = nn.Sequential(nn.Linear(embed_dim, 1))

    def forward(self, image, tabular_data):
        """Encode both modalities, fuse them, and project to one output per sample."""
        # Arguments evaluate left to right: image branch first, then tabular.
        fused = self.fusion(
            self.image_model(image),
            self.tabular_model(tabular_data),
        )
        return self.classifier(fused)