# In [None]:
class BrainTumorDataset(Dataset):
    """Map-style dataset pairing image file paths with their labels.

    Images are opened lazily in ``__getitem__`` and converted to PIL mode
    ``'L'`` (single-channel grayscale) before the optional ``transform`` runs.

    Args:
        file_paths: Indexable sequence of image paths.
        labels: Indexable sequence of labels aligned with ``file_paths``.
        transform: Optional callable applied to the grayscale PIL image.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # The length comes from the paths; labels are assumed to be aligned.
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = Image.open(self.file_paths[idx]).convert('L')
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

# In [None]:
def create_densenet():
    """Build an ImageNet-pretrained DenseNet-121 adapted to 1-channel input.

    The stem convolution ``features.conv0`` is swapped for a convolution with a
    single input channel. It keeps the original output channels, kernel size,
    stride and padding, and is initialised from the pretrained filters of input
    channel 0. The classifier is replaced by a freshly initialised
    ``nn.Linear`` with ``num_classes`` outputs.

    Returns:
        The adapted model, moved to ``device``.
    """
    model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

    stem = model.features.conv0
    gray_stem = nn.Conv2d(
        1,
        stem.out_channels,
        stem.kernel_size,
        stem.stride,
        stem.padding,
        bias=False,
    )
    # NOTE(review): only pretrained input channel 0 is kept here (not an
    # average over the RGB channels). A later load_state_dict of a trained
    # checkpoint overwrites this initialisation.
    with torch.no_grad():
        gray_stem.weight[:, 0, :, :] = stem.weight[:, 0, :, :]
    model.features.conv0 = gray_stem

    head_in = model.classifier.in_features
    model.classifier = nn.Linear(head_in, num_classes)
    return model.to(device)

# In [None]:
def create_vit():
    """Load ``google/vit-base-patch16-224`` and adapt it to 1-channel input.

    The classification head is resized to ``num_classes`` labels
    (``ignore_mismatched_sizes=True`` lets the mismatched pretrained head be
    skipped). The patch-embedding projection is replaced by a 1-input-channel
    convolution whose weights are the pretrained weights averaged over input
    channels.

    Returns:
        The adapted model, moved to ``device``.
    """
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )

    patch_embed = model.vit.embeddings.patch_embeddings
    rgb_proj = patch_embed.projection
    # NOTE(review): bias=False drops the pretrained projection's bias (if it
    # has one). Changing this would alter the state_dict keys and break
    # loading checkpoints saved with this architecture.
    gray_proj = nn.Conv2d(
        1,
        rgb_proj.out_channels,
        rgb_proj.kernel_size,
        rgb_proj.stride,
        rgb_proj.padding,
        bias=False,
    )
    with torch.no_grad():
        gray_proj.weight.copy_(rgb_proj.weight.mean(dim=1, keepdim=True))
    patch_embed.projection = gray_proj

    # Presumably keeps the model's channel bookkeeping consistent with the
    # new 1-channel projection (e.g. input-channel checks) — confirm.
    model.config.num_channels = 1
    patch_embed.num_channels = 1
    return model.to(device)

# In [None]:
# ---------------------- Ensemble ----------------------
class DenseNetViTEnsemble(nn.Module):
    """Frozen soft-voting ensemble of a DenseNet and a ViT classifier.

    Each backbone's logits are softmaxed over dim 1. The two probability
    tensors are blended as ``dw * p_densenet + vw * p_vit``, and the weights
    are used as given (not normalised).

    Args:
        densenet: Module mapping an input batch to a logits tensor.
        vit: Module whose output exposes a ``.logits`` tensor.
        dw: Weight applied to the DenseNet probabilities.
        vw: Weight applied to the ViT probabilities.
    """

    def __init__(self, densenet, vit, dw, vw):
        super().__init__()
        # Both backbones are put in eval mode and frozen: inference only.
        # NOTE(review): calling .train() on the ensemble would switch the
        # backbones back into training mode.
        self.densenet = densenet.eval()
        self.vit = vit.eval()
        self.dw = dw
        self.vw = vw
        for backbone in (self.densenet, self.vit):
            backbone.requires_grad_(False)

    @torch.no_grad()
    def forward(self, x):
        """Return the weighted sum of both backbones' softmax probabilities."""
        p_dense = F.softmax(self.densenet(x), dim=1)
        p_vit = F.softmax(self.vit(x).logits, dim=1)
        return self.dw * p_dense + self.vw * p_vit

    def predict(self, x):
        """Return the arg-max class index (over dim 1) of the blended scores."""
        return self(x).argmax(dim=1)

# In [None]:
def test_model(model, loader, vit=False):
    """Evaluate a single classifier on ``loader`` and report top-1 accuracy.

    Args:
        model: Module to evaluate; it is switched to eval mode here.
        loader: Iterable of ``(imgs, labels)`` batches.
        vit: If True, take scores from ``model(imgs).logits`` instead of the
            raw return value of ``model(imgs)``.

    Returns:
        ``(acc, y_labels, y_preds)``:
        - ``acc``: running accuracy in percent, or 0.0 if the loader yields
          no batches.
        - ``y_labels``: true class indices as a Python list.
        - ``y_preds``: predicted class indices as a Python list.
    """
    model.eval()
    correct = total = 0
    acc = 0.0  # defined even when the loader is empty (was UnboundLocalError)
    y_labels, y_preds = [], []
    # Iterate through the tqdm wrapper so the bar advances itself; the
    # try/finally closes the bar even if evaluation raises.
    pbar = tqdm(loader)
    try:
        with torch.no_grad():
            for imgs, labels in pbar:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs).logits if vit else model(imgs)
                preds = torch.argmax(outputs, 1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
                acc = 100 * correct / total
                pbar.set_postfix_str(f'Acc={acc:.2f}%')
                y_labels += labels.cpu().tolist()
                y_preds += preds.cpu().tolist()
    finally:
        pbar.close()
    return acc, y_labels, y_preds

# In [None]:
def test_ensemble(model, loader):
    """Evaluate a score-returning ensemble on ``loader``.

    ``model(imgs)`` is used directly as the per-class scores, e.g. the blended
    probabilities from ``DenseNetViTEnsemble``. Each sample's score row is
    also collected.

    Args:
        model: Module to evaluate; it is switched to eval mode here.
        loader: Iterable of ``(imgs, labels)`` batches.

    Returns:
        ``(acc, y_labels, y_preds, all_probs)``:
        - ``acc``: running accuracy in percent, or 0.0 if the loader yields
          no batches.
        - ``y_labels``: true class indices.
        - ``y_preds``: predicted class indices.
        - ``all_probs``: one list of scores per sample.
    """
    model.eval()
    correct = total = 0
    acc = 0.0  # defined even when the loader is empty (was UnboundLocalError)
    y_labels, y_preds, all_probs = [], [], []
    # Iterate through the tqdm wrapper so the bar advances itself; the
    # try/finally closes the bar even if evaluation raises.
    pbar = tqdm(loader)
    try:
        with torch.no_grad():
            for imgs, labels in pbar:
                imgs, labels = imgs.to(device), labels.to(device)
                probs = model(imgs)
                preds = torch.argmax(probs, 1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
                acc = 100 * correct / total
                pbar.set_postfix_str(f'Acc={acc:.2f}%')
                y_labels += labels.cpu().tolist()
                y_preds += preds.cpu().tolist()
                all_probs += probs.cpu().tolist()
    finally:
        pbar.close()
    return acc, y_labels, y_preds, all_probs

# In [None]:
def main():
    """Rebuild both backbones, load their checkpoints and wrap them in an ensemble.

    Reads ``./best_densenet.pth`` and ``./best_vit.pth`` as state dicts for the
    models returned by ``create_densenet()`` and ``create_vit()``.

    Returns:
        A ``DenseNetViTEnsemble`` on ``device`` with weights ``dw=0.95`` and
        ``vw=0.05``. It is returned, rather than discarded, so callers can
        evaluate it (e.g. with ``test_ensemble``).
    """
    densenet, vit = create_densenet(), create_vit()
    # map_location remaps stored tensors onto `device`, so checkpoints saved
    # on a GPU also load on a CPU-only machine (and vice versa).
    densenet.load_state_dict(torch.load('./best_densenet.pth', map_location=device))
    vit.load_state_dict(torch.load('./best_vit.pth', map_location=device))
    ensemble = DenseNetViTEnsemble(densenet, vit, dw=0.95, vw=0.05).to(device)
    return ensemble

# In [None]:
if __name__ == "__main__":
    # Entry point when executed directly (also true inside a notebook kernel).
    main()