- Loads the COVID-QU-Ex dataset, which has one folder per class (COVID-19/, Non-COVID/, Normal/) inside each of the Train/Val/Test splits.

- Preprocesses images for COVID-CAPS: masks each X-ray with its lung segmentation, then resizes to IMG_SIZE×IMG_SIZE and normalizes.

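- Tries to load pre-trained COVID-CAPS weights (pre-train.h5). That is an HDF5 (.h5) file, not a PyTorch checkpoint, so `torch.load` fails on it (see the cell output) and training currently starts from random weights.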

- Uses a final linear layer (`fc2`) sized for 3-class classification, in place of COVID-CAPS's final Dense layer.

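- Fine-tunes on your dataset (effectively trains from scratch while the pre-trained weights fail to load), keeping the checkpoint with the best validation accuracy.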

- Evaluates the best checkpoint on the test split (accuracy; a classification report + confusion matrix is the intended full evaluation).

- Saves the best model's weights as a PyTorch state_dict (covid_caps_best.pth).


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau

import tqdm
import cv2

In [None]:
# --- Hyper-parameters -------------------------------------------------------
IMG_SIZE = 224      # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32     # samples per DataLoader batch
EPOCHS = 20         # full passes over the training split
LR = 1e-4           # initial Adam learning rate

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# The binary lung mask below is built with numpy, which the imports cell does
# not bring into scope.
import numpy as np

base_infect = "data_cxr/covid_qu_ex/Infection Segmentation Data/Infection Segmentation Data"
base_lung = "data_cxr/covid_qu_ex/Lung Segmentation Data/Lung Segmentation Data"
output_base = "data_cxr/covid_qu_ex_masked"

splits = ["Train", "Val", "Test"]
classes = ["COVID-19", "Normal", "Non-COVID"]


def _apply_lung_mask(img, mask):
    """Return a 3-channel (BGR) copy of grayscale ``img`` with every pixel
    outside ``mask`` set to 0.

    Any non-zero mask value counts as lung. If the mask's shape differs from
    the image's, it is resized with nearest-neighbour interpolation (which keeps
    it binary) so the element-wise product cannot fail on a size mismatch.
    """
    if mask.shape != img.shape:
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
    binary = (mask > 0).astype(np.uint8)
    # uint8 * {0, 1} cannot overflow: each pixel is either kept or zeroed.
    return cv2.cvtColor(img * binary, cv2.COLOR_GRAY2BGR)


os.makedirs(output_base, exist_ok=True)

# Write <output_base>/<split lower-cased>/<class>/<fname>: the lung-masked
# version of every image that has a same-named lung mask.
for split in splits:
    for cls in classes:
        img_dir = os.path.join(base_infect, split, cls, "images")
        mask_dir = os.path.join(base_lung, split, cls, "lung masks")
        out_dir = os.path.join(output_base, split.lower(), cls)
        os.makedirs(out_dir, exist_ok=True)

        skipped = 0
        for fname in tqdm.tqdm(os.listdir(img_dir), desc=f"{split}-{cls}"):
            img_path = os.path.join(img_dir, fname)
            mask_path = os.path.join(mask_dir, fname)

            if not os.path.exists(mask_path):
                skipped += 1
                continue

            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread returns None (it does not raise) for unreadable or
            # non-image files; skip those instead of crashing mid-split.
            if img is None or mask is None:
                skipped += 1
                continue

            cv2.imwrite(os.path.join(out_dir, fname), _apply_lung_mask(img, mask))

        if skipped:
            print(f"{split}-{cls}: skipped {skipped} file(s) with a missing or unreadable image/mask")


In [3]:
# Root of the lung-masked dataset written by the preprocessing step; it holds
# one sub-folder per split ("train", "val", "test").
data_dir = "data_cxr/covid_qu_ex_masked"
train_dir, val_dir, test_dir = (
    os.path.join(data_dir, split_name) for split_name in ("train", "val", "test")
)

In [4]:
def _to_normalized_tensor():
    """Return [ToTensor, Normalize(0.5, 0.5)] as a fresh list of transforms.

    Normalize computes (x - 0.5) / 0.5. The single-value mean/std is applied
    to every channel of the tensor. That matters because the preprocessing
    step saves 3-channel images and the model's first conv expects 3 input
    channels, so this is not a single-channel (grayscale) normalization.
    """
    return [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]


# Training pipeline: resize, light augmentation (random horizontal flip and a
# random rotation of up to 10 degrees), then tensor conversion + normalization.
train_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
    ]
    + _to_normalized_tensor()
)

# Validation/test pipeline: deterministic, so only resize and normalize.
val_test_transforms = transforms.Compose(
    [transforms.Resize((IMG_SIZE, IMG_SIZE))] + _to_normalized_tensor()
)

In [5]:
train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dir, transform=val_test_transforms)
test_dataset = datasets.ImageFolder(test_dir, transform=val_test_transforms)

# ImageFolder assigns label indices per split from that split's own sub-folder
# names (sorted, as the printed class list shows). A missing or extra class
# folder in one split would silently shift its labels, so require identical
# mappings.
for split_name, split_ds in (("val", val_dataset), ("test", test_dataset)):
    if split_ds.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"{split_name} classes {split_ds.class_to_idx} do not match "
            f"train classes {train_dataset.class_to_idx}"
        )

# Pinned host memory speeds up host-to-GPU copies; it is pointless on CPU.
pin = DEVICE.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin)

num_classes = len(train_dataset.classes)
print("Classes:", train_dataset.classes)

Classes: ['COVID-19', 'Non-COVID', 'Normal']


In [10]:
class CapsuleNet(nn.Module):
    """Two-convolution feature extractor followed by a two-layer classifier head.

    Despite the name, there is no capsule machinery here: no squash
    non-linearity and no dynamic routing. ``primary_caps`` is a plain
    ``Conv2d`` whose output is flattened straight into ``fc1``.

    Args:
        num_classes: number of output logits.
        input_size: side length of the square images fed to ``forward``.
            It is used only to size ``fc1``, so it must match the data
            pipeline's resolution. It must be at least 25: conv1 and
            primary_caps are 9x9 unpadded convolutions.
        in_channels: number of channels of the input images.

    NOTE(review): at ``input_size=224`` the flattened feature map has
    256 * 104 * 104 = 2,768,896 elements, so ``fc1`` alone holds ~709M
    weights (~2.8 GB in float32). Consider pooling before ``fc1``.
    """

    def __init__(self, num_classes=3, input_size=224, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=0)
        self.primary_caps = nn.Conv2d(64, 32 * 8, kernel_size=9, stride=2, padding=0)

        # Size fc1 from a dummy forward pass rather than a hard-coded number,
        # so the head always matches the configured input resolution/channels.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, input_size, input_size)
            flat_size = self._features(dummy).flatten(1).size(1)

        self.fc1 = nn.Linear(flat_size, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def _features(self, x):
        """Apply conv1 + ReLU, then primary_caps (no activation after it)."""
        return self.primary_caps(torch.relu(self.conv1(x)))

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self._features(x).flatten(1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


# Build the network for the number of classes found in the training split,
# then move its parameters to the selected device (Module.to returns the module).
model = CapsuleNet(num_classes=num_classes)
model = model.to(DEVICE)

In [7]:
# Optional warm start from COVID-CAPS weights. This is best-effort: every
# failure is reported and training then proceeds from the random init.
pretrain_path = "pre-train.h5"
if not os.path.exists(pretrain_path):
    print(f"⚠️ No pretrained weights at {pretrain_path!r}; training from scratch")
else:
    try:
        # weights_only=True: load tensors only, never unpickle arbitrary objects.
        state_dict = torch.load(pretrain_path, map_location=DEVICE, weights_only=True)
        result = model.load_state_dict(state_dict, strict=False)
    except Exception as e:  # top-level best-effort boundary: report and carry on
        print("⚠️ Could not load pretrained weights:", e)
        if pretrain_path.lower().endswith((".h5", ".hdf5")):
            # An HDF5 (.h5) file is not a PyTorch checkpoint; torch.load cannot
            # parse it at all. The weights need converting to a PyTorch
            # state_dict whose keys match this model before they can be used.
            print(f"   {pretrain_path!r} looks like an HDF5 file; convert it to a "
                  "PyTorch state_dict first. Training from scratch.")
    else:
        # strict=False silently skips non-matching keys, so report what happened
        # instead of claiming success unconditionally.
        matched = len(state_dict) - len(result.unexpected_keys)
        if matched == 0:
            print("⚠️ No checkpoint keys match this model; training from scratch")
        else:
            print(f"✅ Loaded pretrained COVID-CAPS weights ({matched} tensors)")
        if result.missing_keys:
            print("   Left at random init:", result.missing_keys)
        if result.unexpected_keys:
            print("   Ignored checkpoint keys:", result.unexpected_keys)

⚠️ Could not load pretrained weights: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 72

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.


In [8]:
# Multi-class cross-entropy on the raw logits the model returns.
criterion = nn.CrossEntropyLoss()
# Adam over every model parameter, starting at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=LR)
# The training loop steps this scheduler with validation accuracy, hence
# mode="max": the LR is halved once accuracy stops improving for `patience` epochs.
scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

In [11]:
def _run_epoch(loader, train):
    """Run one pass over ``loader`` with the global model, criterion and optimizer.

    If ``train`` is True the model is put in training mode and updated after
    every batch. Otherwise it runs in eval mode with gradients disabled.

    Returns:
        (mean loss per batch, accuracy in percent). Both are 0.0 for an empty loader.
    """
    model.train(train)
    loss_sum, correct, seen, batches = 0.0, 0, 0, 0
    with torch.set_grad_enabled(train):
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            batches += 1
            _, predicted = outputs.max(1)
            seen += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    mean_loss = loss_sum / batches if batches else 0.0
    accuracy = 100. * correct / seen if seen else 0.0
    return mean_loss, accuracy


# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. Otherwise a run stuck at 0% validation accuracy would leave no
# covid_caps_best.pth for the evaluation cell to load.
best_acc = -1.0
for epoch in range(EPOCHS):
    train_loss, train_acc = _run_epoch(train_loader, train=True)
    val_loss, val_acc = _run_epoch(val_loader, train=False)
    print(f"Epoch [{epoch+1}/{EPOCHS}] "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}%")

    # Keep only the weights with the best validation accuracy so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "covid_caps_best.pth")
        print("💾 Saved Best Model")

    # The scheduler was built with mode='max', so it is stepped on accuracy.
    scheduler.step(val_acc)


: 

In [None]:
# Evaluate the best checkpoint from training on the held-out test split.
# map_location lets a checkpoint written on a GPU load on a CPU-only machine;
# weights_only=True restricts loading to tensors (no arbitrary unpickling).
model.load_state_dict(torch.load("covid_caps_best.pth", map_location=DEVICE, weights_only=True))
model.eval()

class_names = test_dataset.classes
n_cls = len(class_names)
# confusion[t, p] = number of test images of true class t predicted as class p.
confusion = torch.zeros(n_cls, n_cls, dtype=torch.long)
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(DEVICE))
        _, predicted = outputs.max(1)
        pair_ids = labels.cpu() * n_cls + predicted.cpu()
        confusion += torch.bincount(pair_ids, minlength=n_cls * n_cls).view(n_cls, n_cls)

test_total = confusion.sum().item()
test_correct = confusion.diag().sum().item()
test_acc = 100. * test_correct / test_total if test_total else 0.0
print(f"🎯 Test Accuracy: {test_acc:.2f}%")

width = max(5, max(len(name) for name in class_names))
print("Confusion matrix (rows = true class, columns = predicted class):")
for name, row in zip(class_names, confusion.tolist()):
    print(f"  {name:>{width}}  {row}")

# Per-class precision / recall / F1. The clamps avoid 0/0 for classes that
# are never predicted or absent from the test split.
tp = confusion.diag().double()
precision = tp / confusion.sum(dim=0).clamp(min=1)
recall = tp / confusion.sum(dim=1).clamp(min=1)
f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-12)
support = confusion.sum(dim=1)
print(f"  {'class':>{width}}  precision  recall     f1  support")
for name, p, r, f, s in zip(class_names, precision.tolist(), recall.tolist(),
                            f1.tolist(), support.tolist()):
    print(f"  {name:>{width}}  {p:9.3f}  {r:6.3f}  {f:5.3f}  {s:7d}")