In [None]:
# Fixed imports (add missing Dataset import)
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
import torchvision

# --- Fixed transform section (remove duplicate definition) ---
def repeat_channel(x):
    """Tile a tensor three times along its leading (channel) axis.

    A single-channel ``(1, H, W)`` tensor becomes ``(3, H, W)``; more
    generally a ``(C, H, W)`` input comes back as ``(3 * C, H, W)``. The
    two trailing axes are left untouched.
    """
    copies_per_axis = (3, 1, 1)
    return x.repeat(*copies_per_axis)

# Preprocessing pipeline applied to every image, in order:
#   1. collapse to one grayscale channel
#   2. resize to 224x224
#   3. convert the PIL image to a tensor
#   4. tile the single channel into three via repeat_channel
#   5. normalise each of the three channels with a fixed mean/std
#      (values look like the usual ImageNet statistics — confirm they match
#      what the model was trained with)
preprocessing_steps = [
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Lambda(repeat_channel),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

# --- Fixed Dataset class ---
class CTScanDataset(Dataset):
    """Unlabelled dataset over the PNG files found directly inside a directory.

    Each item is the image at the corresponding path, passed through
    ``transform`` when one was supplied. No labels are returned.
    """

    def __init__(self, root_dir, transform=None):
        """Index the PNG files in ``root_dir``.

        Args:
            root_dir: Directory to scan (non-recursively) for PNG files.
            transform: Optional callable applied to each opened image.

        Raises:
            FileNotFoundError: If ``root_dir`` does not exist (raised by
                ``os.listdir``).
        """
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so the index -> file mapping is the same on every run.
        # The extension test is case-insensitive so files saved as '.PNG'
        # are not silently left out of the dataset.
        self.image_paths = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith('.png')
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return it, transformed if requested."""
        image = Image.open(self.image_paths[idx])
        # Compare against None explicitly: a transform object that happens to
        # evaluate as falsy (e.g. one defining __len__) must still be applied.
        if self.transform is not None:
            image = self.transform(image)
        return image

# --- Added error handling for data loading ---
# Build the dataset first; only that step reads the directory, so it is the
# only statement guarded against a missing 'hemorrhage_slices' folder.
try:
    dataset = CTScanDataset(root_dir='hemorrhage_slices', transform=transform)
except FileNotFoundError as e:
    # Stop the script with a readable message instead of a bare traceback.
    raise SystemExit(f"Error loading data: {str(e)}")
else:
    # Fixed order (no shuffling) keeps predictions aligned with the sorted paths.
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

# --- Model loading ---
# NOTE(review): the traceback recorded with this cell shows that
# 'pytorch_model.bin' contains "vit.*" keys and a classifier of shape [6, 768],
# so it was not produced by the ResNet-based CTScanModel defined here. It looks
# like a 6-class ViT checkpoint — confirm against the training code and load it
# with the architecture it was trained with. Until then load_state_dict below
# will keep failing for that file; the changes in this block only make the
# ResNet definition internally consistent.
try:
    device = torch.device('cpu')

    # 1. Define model architecture (MUST match training architecture)
    class CTScanModel(torch.nn.Module):
        """ResNet-18 feature extractor followed by a 2-way linear classifier."""

        def __init__(self):
            super().__init__()
            # Called without arguments, resnet18 builds an untrained network;
            # this avoids the deprecated `pretrained` keyword. The stock
            # 3-channel conv1 is kept because the preprocessing pipeline in
            # this file emits 3-channel tensors.
            self.base_model = torchvision.models.resnet18()
            # Drop the backbone's own classification head so it outputs the
            # 512-d pooled features that `classifier` expects; leaving `fc`
            # in place would hand a 1000-d vector to Linear(512, 2).
            self.base_model.fc = torch.nn.Identity()
            self.classifier = torch.nn.Linear(512, 2)

        def forward(self, x):
            """Return the classifier logits for a batch of images."""
            features = self.base_model(x)
            return self.classifier(features)

    # 2. Initialize model
    model = CTScanModel()

    # 3. Load state dict
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs and avoids executing arbitrary pickled
    # code from the checkpoint file.
    state_dict = torch.load(
        'pytorch_model.bin', map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)

    # 4. Set to eval mode
    model.eval()
    model.to(device)

except Exception as e:
    # Re-raised as ValueError with the original exception chained as the cause.
    raise ValueError(f"Model loading failed: {str(e)}") from e


# --- Batched inference ---
# Each batch is moved to `device`, scored by the model, and reduced to the
# index of the highest-scoring output along dim 1.
predictions = []
with torch.inference_mode():
    for images in dataloader:
        logits = model(images.to(device))
        top_class = logits.argmax(dim=1)
        predictions.extend(top_class.cpu().numpy())

# Translate predicted indices into human-readable labels.
class_labels = ['normal', 'abnormal']  # Update with your actual classes
final_predictions = [class_labels[index] for index in predictions]

  state_dict = torch.load('pytorch_model.bin', map_location=device)


ValueError: Model loading failed: Error(s) in loading state_dict for CTScanModel:
	Missing key(s) in state_dict: "base_model.conv1.weight", "base_model.bn1.weight", "base_model.bn1.bias", "base_model.bn1.running_mean", "base_model.bn1.running_var", "base_model.layer1.0.conv1.weight", "base_model.layer1.0.bn1.weight", "base_model.layer1.0.bn1.bias", "base_model.layer1.0.bn1.running_mean", "base_model.layer1.0.bn1.running_var", "base_model.layer1.0.conv2.weight", "base_model.layer1.0.bn2.weight", "base_model.layer1.0.bn2.bias", "base_model.layer1.0.bn2.running_mean", "base_model.layer1.0.bn2.running_var", "base_model.layer1.1.conv1.weight", "base_model.layer1.1.bn1.weight", "base_model.layer1.1.bn1.bias", "base_model.layer1.1.bn1.running_mean", "base_model.layer1.1.bn1.running_var", "base_model.layer1.1.conv2.weight", "base_model.layer1.1.bn2.weight", "base_model.layer1.1.bn2.bias", "base_model.layer1.1.bn2.running_mean", "base_model.layer1.1.bn2.running_var", "base_model.layer2.0.conv1.weight", "base_model.layer2.0.bn1.weight", "base_model.layer2.0.bn1.bias", "base_model.layer2.0.bn1.running_mean", "base_model.layer2.0.bn1.running_var", "base_model.layer2.0.conv2.weight", "base_model.layer2.0.bn2.weight", "base_model.layer2.0.bn2.bias", "base_model.layer2.0.bn2.running_mean", "base_model.layer2.0.bn2.running_var", "base_model.layer2.0.downsample.0.weight", "base_model.layer2.0.downsample.1.weight", "base_model.layer2.0.downsample.1.bias", "base_model.layer2.0.downsample.1.running_mean", "base_model.layer2.0.downsample.1.running_var", "base_model.layer2.1.conv1.weight", "base_model.layer2.1.bn1.weight", "base_model.layer2.1.bn1.bias", "base_model.layer2.1.bn1.running_mean", "base_model.layer2.1.bn1.running_var", "base_model.layer2.1.conv2.weight", "base_model.layer2.1.bn2.weight", "base_model.layer2.1.bn2.bias", "base_model.layer2.1.bn2.running_mean", "base_model.layer2.1.bn2.running_var", "base_model.layer3.0.conv1.weight", "base_model.layer3.0.bn1.weight", "base_model.layer3.0.bn1.bias", "base_model.layer3.0.bn1.running_mean", 
"base_model.layer3.0.bn1.running_var", "base_model.layer3.0.conv2.weight", "base_model.layer3.0.bn2.weight", "base_model.layer3.0.bn2.bias", "base_model.layer3.0.bn2.running_mean", "base_model.layer3.0.bn2.running_var", "base_model.layer3.0.downsample.0.weight", "base_model.layer3.0.downsample.1.weight", "base_model.layer3.0.downsample.1.bias", "base_model.layer3.0.downsample.1.running_mean", "base_model.layer3.0.downsample.1.running_var", "base_model.layer3.1.conv1.weight", "base_model.layer3.1.bn1.weight", "base_model.layer3.1.bn1.bias", "base_model.layer3.1.bn1.running_mean", "base_model.layer3.1.bn1.running_var", "base_model.layer3.1.conv2.weight", "base_model.layer3.1.bn2.weight", "base_model.layer3.1.bn2.bias", "base_model.layer3.1.bn2.running_mean", "base_model.layer3.1.bn2.running_var", "base_model.layer4.0.conv1.weight", "base_model.layer4.0.bn1.weight", "base_model.layer4.0.bn1.bias", "base_model.layer4.0.bn1.running_mean", "base_model.layer4.0.bn1.running_var", "base_model.layer4.0.conv2.weight", "base_model.layer4.0.bn2.weight", "base_model.layer4.0.bn2.bias", "base_model.layer4.0.bn2.running_mean", "base_model.layer4.0.bn2.running_var", "base_model.layer4.0.downsample.0.weight", "base_model.layer4.0.downsample.1.weight", "base_model.layer4.0.downsample.1.bias", "base_model.layer4.0.downsample.1.running_mean", "base_model.layer4.0.downsample.1.running_var", "base_model.layer4.1.conv1.weight", "base_model.layer4.1.bn1.weight", "base_model.layer4.1.bn1.bias", "base_model.layer4.1.bn1.running_mean", "base_model.layer4.1.bn1.running_var", "base_model.layer4.1.conv2.weight", "base_model.layer4.1.bn2.weight", "base_model.layer4.1.bn2.bias", "base_model.layer4.1.bn2.running_mean", "base_model.layer4.1.bn2.running_var", "base_model.fc.weight", "base_model.fc.bias". 
	Unexpected key(s) in state_dict: "vit.embeddings.cls_token", "vit.embeddings.position_embeddings", "vit.embeddings.patch_embeddings.projection.weight", "vit.embeddings.patch_embeddings.projection.bias", "vit.encoder.layer.0.attention.attention.query.weight", "vit.encoder.layer.0.attention.attention.query.bias", "vit.encoder.layer.0.attention.attention.key.weight", "vit.encoder.layer.0.attention.attention.key.bias", "vit.encoder.layer.0.attention.attention.value.weight", "vit.encoder.layer.0.attention.attention.value.bias", "vit.encoder.layer.0.attention.output.dense.weight", "vit.encoder.layer.0.attention.output.dense.bias", "vit.encoder.layer.0.intermediate.dense.weight", "vit.encoder.layer.0.intermediate.dense.bias", "vit.encoder.layer.0.output.dense.weight", "vit.encoder.layer.0.output.dense.bias", "vit.encoder.layer.0.layernorm_before.weight", "vit.encoder.layer.0.layernorm_before.bias", "vit.encoder.layer.0.layernorm_after.weight", "vit.encoder.layer.0.layernorm_after.bias", "vit.encoder.layer.1.attention.attention.query.weight", "vit.encoder.layer.1.attention.attention.query.bias", "vit.encoder.layer.1.attention.attention.key.weight", "vit.encoder.layer.1.attention.attention.key.bias", "vit.encoder.layer.1.attention.attention.value.weight", "vit.encoder.layer.1.attention.attention.value.bias", "vit.encoder.layer.1.attention.output.dense.weight", "vit.encoder.layer.1.attention.output.dense.bias", "vit.encoder.layer.1.intermediate.dense.weight", "vit.encoder.layer.1.intermediate.dense.bias", "vit.encoder.layer.1.output.dense.weight", "vit.encoder.layer.1.output.dense.bias", "vit.encoder.layer.1.layernorm_before.weight", "vit.encoder.layer.1.layernorm_before.bias", "vit.encoder.layer.1.layernorm_after.weight", "vit.encoder.layer.1.layernorm_after.bias", "vit.encoder.layer.2.attention.attention.query.weight", "vit.encoder.layer.2.attention.attention.query.bias", "vit.encoder.layer.2.attention.attention.key.weight", 
"vit.encoder.layer.2.attention.attention.key.bias", "vit.encoder.layer.2.attention.attention.value.weight", "vit.encoder.layer.2.attention.attention.value.bias", "vit.encoder.layer.2.attention.output.dense.weight", "vit.encoder.layer.2.attention.output.dense.bias", "vit.encoder.layer.2.intermediate.dense.weight", "vit.encoder.layer.2.intermediate.dense.bias", "vit.encoder.layer.2.output.dense.weight", "vit.encoder.layer.2.output.dense.bias", "vit.encoder.layer.2.layernorm_before.weight", "vit.encoder.layer.2.layernorm_before.bias", "vit.encoder.layer.2.layernorm_after.weight", "vit.encoder.layer.2.layernorm_after.bias", "vit.encoder.layer.3.attention.attention.query.weight", "vit.encoder.layer.3.attention.attention.query.bias", "vit.encoder.layer.3.attention.attention.key.weight", "vit.encoder.layer.3.attention.attention.key.bias", "vit.encoder.layer.3.attention.attention.value.weight", "vit.encoder.layer.3.attention.attention.value.bias", "vit.encoder.layer.3.attention.output.dense.weight", "vit.encoder.layer.3.attention.output.dense.bias", "vit.encoder.layer.3.intermediate.dense.weight", "vit.encoder.layer.3.intermediate.dense.bias", "vit.encoder.layer.3.output.dense.weight", "vit.encoder.layer.3.output.dense.bias", "vit.encoder.layer.3.layernorm_before.weight", "vit.encoder.layer.3.layernorm_before.bias", "vit.encoder.layer.3.layernorm_after.weight", "vit.encoder.layer.3.layernorm_after.bias", "vit.encoder.layer.4.attention.attention.query.weight", "vit.encoder.layer.4.attention.attention.query.bias", "vit.encoder.layer.4.attention.attention.key.weight", "vit.encoder.layer.4.attention.attention.key.bias", "vit.encoder.layer.4.attention.attention.value.weight", "vit.encoder.layer.4.attention.attention.value.bias", "vit.encoder.layer.4.attention.output.dense.weight", "vit.encoder.layer.4.attention.output.dense.bias", "vit.encoder.layer.4.intermediate.dense.weight", "vit.encoder.layer.4.intermediate.dense.bias", "vit.encoder.layer.4.output.dense.weight", 
"vit.encoder.layer.4.output.dense.bias", "vit.encoder.layer.4.layernorm_before.weight", "vit.encoder.layer.4.layernorm_before.bias", "vit.encoder.layer.4.layernorm_after.weight", "vit.encoder.layer.4.layernorm_after.bias", "vit.encoder.layer.5.attention.attention.query.weight", "vit.encoder.layer.5.attention.attention.query.bias", "vit.encoder.layer.5.attention.attention.key.weight", "vit.encoder.layer.5.attention.attention.key.bias", "vit.encoder.layer.5.attention.attention.value.weight", "vit.encoder.layer.5.attention.attention.value.bias", "vit.encoder.layer.5.attention.output.dense.weight", "vit.encoder.layer.5.attention.output.dense.bias", "vit.encoder.layer.5.intermediate.dense.weight", "vit.encoder.layer.5.intermediate.dense.bias", "vit.encoder.layer.5.output.dense.weight", "vit.encoder.layer.5.output.dense.bias", "vit.encoder.layer.5.layernorm_before.weight", "vit.encoder.layer.5.layernorm_before.bias", "vit.encoder.layer.5.layernorm_after.weight", "vit.encoder.layer.5.layernorm_after.bias", "vit.encoder.layer.6.attention.attention.query.weight", "vit.encoder.layer.6.attention.attention.query.bias", "vit.encoder.layer.6.attention.attention.key.weight", "vit.encoder.layer.6.attention.attention.key.bias", "vit.encoder.layer.6.attention.attention.value.weight", "vit.encoder.layer.6.attention.attention.value.bias", "vit.encoder.layer.6.attention.output.dense.weight", "vit.encoder.layer.6.attention.output.dense.bias", "vit.encoder.layer.6.intermediate.dense.weight", "vit.encoder.layer.6.intermediate.dense.bias", "vit.encoder.layer.6.output.dense.weight", "vit.encoder.layer.6.output.dense.bias", "vit.encoder.layer.6.layernorm_before.weight", "vit.encoder.layer.6.layernorm_before.bias", "vit.encoder.layer.6.layernorm_after.weight", "vit.encoder.layer.6.layernorm_after.bias", "vit.encoder.layer.7.attention.attention.query.weight", "vit.encoder.layer.7.attention.attention.query.bias", "vit.encoder.layer.7.attention.attention.key.weight", 
"vit.encoder.layer.7.attention.attention.key.bias", "vit.encoder.layer.7.attention.attention.value.weight", "vit.encoder.layer.7.attention.attention.value.bias", "vit.encoder.layer.7.attention.output.dense.weight", "vit.encoder.layer.7.attention.output.dense.bias", "vit.encoder.layer.7.intermediate.dense.weight", "vit.encoder.layer.7.intermediate.dense.bias", "vit.encoder.layer.7.output.dense.weight", "vit.encoder.layer.7.output.dense.bias", "vit.encoder.layer.7.layernorm_before.weight", "vit.encoder.layer.7.layernorm_before.bias", "vit.encoder.layer.7.layernorm_after.weight", "vit.encoder.layer.7.layernorm_after.bias", "vit.encoder.layer.8.attention.attention.query.weight", "vit.encoder.layer.8.attention.attention.query.bias", "vit.encoder.layer.8.attention.attention.key.weight", "vit.encoder.layer.8.attention.attention.key.bias", "vit.encoder.layer.8.attention.attention.value.weight", "vit.encoder.layer.8.attention.attention.value.bias", "vit.encoder.layer.8.attention.output.dense.weight", "vit.encoder.layer.8.attention.output.dense.bias", "vit.encoder.layer.8.intermediate.dense.weight", "vit.encoder.layer.8.intermediate.dense.bias", "vit.encoder.layer.8.output.dense.weight", "vit.encoder.layer.8.output.dense.bias", "vit.encoder.layer.8.layernorm_before.weight", "vit.encoder.layer.8.layernorm_before.bias", "vit.encoder.layer.8.layernorm_after.weight", "vit.encoder.layer.8.layernorm_after.bias", "vit.encoder.layer.9.attention.attention.query.weight", "vit.encoder.layer.9.attention.attention.query.bias", "vit.encoder.layer.9.attention.attention.key.weight", "vit.encoder.layer.9.attention.attention.key.bias", "vit.encoder.layer.9.attention.attention.value.weight", "vit.encoder.layer.9.attention.attention.value.bias", "vit.encoder.layer.9.attention.output.dense.weight", "vit.encoder.layer.9.attention.output.dense.bias", "vit.encoder.layer.9.intermediate.dense.weight", "vit.encoder.layer.9.intermediate.dense.bias", "vit.encoder.layer.9.output.dense.weight", 
"vit.encoder.layer.9.output.dense.bias", "vit.encoder.layer.9.layernorm_before.weight", "vit.encoder.layer.9.layernorm_before.bias", "vit.encoder.layer.9.layernorm_after.weight", "vit.encoder.layer.9.layernorm_after.bias", "vit.encoder.layer.10.attention.attention.query.weight", "vit.encoder.layer.10.attention.attention.query.bias", "vit.encoder.layer.10.attention.attention.key.weight", "vit.encoder.layer.10.attention.attention.key.bias", "vit.encoder.layer.10.attention.attention.value.weight", "vit.encoder.layer.10.attention.attention.value.bias", "vit.encoder.layer.10.attention.output.dense.weight", "vit.encoder.layer.10.attention.output.dense.bias", "vit.encoder.layer.10.intermediate.dense.weight", "vit.encoder.layer.10.intermediate.dense.bias", "vit.encoder.layer.10.output.dense.weight", "vit.encoder.layer.10.output.dense.bias", "vit.encoder.layer.10.layernorm_before.weight", "vit.encoder.layer.10.layernorm_before.bias", "vit.encoder.layer.10.layernorm_after.weight", "vit.encoder.layer.10.layernorm_after.bias", "vit.encoder.layer.11.attention.attention.query.weight", "vit.encoder.layer.11.attention.attention.query.bias", "vit.encoder.layer.11.attention.attention.key.weight", "vit.encoder.layer.11.attention.attention.key.bias", "vit.encoder.layer.11.attention.attention.value.weight", "vit.encoder.layer.11.attention.attention.value.bias", "vit.encoder.layer.11.attention.output.dense.weight", "vit.encoder.layer.11.attention.output.dense.bias", "vit.encoder.layer.11.intermediate.dense.weight", "vit.encoder.layer.11.intermediate.dense.bias", "vit.encoder.layer.11.output.dense.weight", "vit.encoder.layer.11.output.dense.bias", "vit.encoder.layer.11.layernorm_before.weight", "vit.encoder.layer.11.layernorm_before.bias", "vit.encoder.layer.11.layernorm_after.weight", "vit.encoder.layer.11.layernorm_after.bias", "vit.layernorm.weight", "vit.layernorm.bias". 
	size mismatch for classifier.weight: copying a param with shape torch.Size([6, 768]) from checkpoint, the shape in current model is torch.Size([2, 512]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([2]).