In [None]:
# --- STEP 6: Hybrid CNN + ViT Model ---
class HybridCNNViT(nn.Module):
    """Two-branch classifier that fuses DenseNet-121 and ViT-B/16 features.

    The same input batch is sent through the DenseNet-121 convolutional
    feature extractor and through a ``vit_base_patch16_224`` transformer
    whose classification head is replaced by ``nn.Identity``.  The CNN
    feature maps are global-average-pooled, concatenated with the ViT
    output along the feature dimension, and passed to a small MLP head.

    Args:
        num_classes: Number of logits produced by the head. Defaults to 2.
        pretrained: When true (default), both backbones are created with
            their pretrained weights. Pass ``False`` to skip loading them,
            e.g. when the weights are restored from a checkpoint afterwards.
        dropout: Dropout probability used inside the head. Defaults to 0.3.
    """

    def __init__(
        self,
        num_classes: int = 2,
        pretrained: bool = True,
        dropout: float = 0.3,
    ):
        super().__init__()

        # CNN branch: keep only the convolutional trunk of DenseNet-121.
        densenet = models.densenet121(
            weights=models.DenseNet121_Weights.DEFAULT if pretrained else None
        )
        self.cnn = densenet.features
        # Width of the pooled CNN descriptor (1024 for DenseNet-121); read it
        # from the discarded classifier instead of hard-coding it.
        cnn_dim = densenet.classifier.in_features

        # Transformer branch: ViT-B/16 with its classification head removed.
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=pretrained)
        vit_dim = self.vit.num_features  # 768 for ViT-Base
        self.vit.head = nn.Identity()  # remove classification head

        # Fusion head operating on the concatenated [CNN | ViT] descriptor.
        self.classifier = nn.Sequential(
            nn.Linear(cnn_dim + vit_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            x: Image batch fed unchanged to both backbones.
                Assumes shape ``(batch, 3, 224, 224)`` as implied by the
                ``vit_base_patch16_224`` model name -- TODO confirm against
                the data pipeline.

        Returns:
            Output of the fusion head; one row of ``num_classes`` logits per
            sample, provided the ViT branch yields a 2-D ``(batch, features)``
            tensor.
        """
        # NOTE(review): torchvision's own DenseNet.forward applies a ReLU to
        # the output of ``features`` before pooling; this model pools the raw
        # feature maps. Left as-is to keep outputs of existing checkpoints
        # unchanged -- confirm whether that is intentional.
        cnn_maps = self.cnn(x)
        # Global average pool to 1x1, then drop the spatial dimensions.
        cnn_feat = nn.functional.adaptive_avg_pool2d(cnn_maps, (1, 1))
        cnn_feat = torch.flatten(cnn_feat, 1)

        vit_feat = self.vit(x)

        combined = torch.cat((cnn_feat, vit_feat), dim=1)
        return self.classifier(combined)