In [2]:
import nibabel as nib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix

In [3]:
# Load the CT volume and infection mask
ct_volume = nib.load(r'C:\Users\BAPS\Documents\Dicom Analaysis\Dicom_Analyzer\Slice_Classifier\Dataset\Covid_Dataset\ct_scans\coronacases_org_001.nii').get_fdata()          # Shape: (512, 512, num_slices)
infection_mask = nib.load(r'C:\Users\BAPS\Documents\Dicom Analaysis\Dicom_Analyzer\Slice_Classifier\Dataset\Covid_Dataset\infection_mask\coronacases_001.nii').get_fdata()  # Same shape as ct_volume


# Verify shapes match
assert ct_volume.shape == infection_mask.shape, "CT and mask shapes do not match!"

In [4]:
# Identify slices with infections (1) vs. normal (0)
abnormal_slices = []
for slice_idx in range(infection_mask.shape[2]):
    if np.any(infection_mask[:, :, slice_idx]):
        abnormal_slices.append(1)
    else:
        abnormal_slices.append(0)

# Create DataFrame
df = pd.DataFrame({
    'slice_number': list(range(infection_mask.shape[2])),  # 0-based index
    'label': abnormal_slices
})

print(f"Total slices: {len(df)}")
print(f"Abnormal slices: {df['label'].sum()} ({(df['label'].sum()/len(df))*100:.2f}%)")

Total slices: 301
Abnormal slices: 161 (53.49%)


In [5]:
# Split into train/test (stratify to preserve class balance)
X = df['slice_number'].values  # Slice indices (0-based)
y = df['label'].values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.2, 
    stratify=y,  # Critical for imbalanced data
    random_state=42
)

print(f"Train slices: {len(X_train)}")
print(f"Test slices: {len(X_test)}")

Train slices: 240
Test slices: 61


In [9]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import roc_auc_score, confusion_matrix
import numpy as np
import time
from tqdm.notebook import tqdm  # using the notebook version for progress bars

# --- Custom dataset for CT slices ---
class CTDataset(Dataset):
    def __init__(self, slice_indices, labels, ct_volume, mean, std, transform=None):
        self.slice_indices = slice_indices
        self.labels = labels
        self.ct_volume = ct_volume
        self.mean = mean
        self.std = std
        self.transform = transform
        print(f"Initialized CTDataset with {len(self.slice_indices)} samples.")

    def __len__(self):
        return len(self.slice_indices)

    def __getitem__(self, idx):
        slice_idx = self.slice_indices[idx]
        ct_slice = self.ct_volume[:, :, slice_idx].astype(np.float32)
        # Normalize the slice to [0, 1]
        ct_slice = (ct_slice - np.min(ct_slice)) / (np.max(ct_slice) - np.min(ct_slice) + 1e-6)
        ct_slice = torch.tensor(ct_slice).unsqueeze(0)  # shape: [1, H, W]
        # Resize to (299, 299) using bilinear interpolation
        ct_slice = F.interpolate(ct_slice.unsqueeze(0), size=(299, 299),
                                 mode='bilinear', align_corners=False).squeeze(0)
        if self.transform:
            ct_slice = self.transform(ct_slice)
        # Standardize the image using dataset mean and std
        ct_slice = (ct_slice - self.mean) / self.std
        label = torch.tensor(self.labels[idx]).float()
        return ct_slice, label

# --- Compute dataset statistics ---
# (Make sure X_train, y_train, and ct_volume are defined in your notebook)
print("Computing dataset statistics...")
temp_dataset = CTDataset(X_train, y_train, ct_volume, mean=0, std=1, transform=None)
temp_loader = DataLoader(temp_dataset, batch_size=16, shuffle=False, num_workers=0)
all_slices = torch.cat([slices for slices, _ in temp_loader])
dataset_mean, dataset_std = all_slices.mean().item(), all_slices.std().item()
print(f"Dataset mean: {dataset_mean:.4f}, Dataset std: {dataset_std:.4f}")

# --- Create datasets with augmentations ---
print("Creating training and test datasets with augmentations...")
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1)
])
train_dataset = CTDataset(X_train, y_train, ct_volume, dataset_mean, dataset_std, transform=train_transform)
test_dataset = CTDataset(X_test, y_test, ct_volume, dataset_mean, dataset_std, transform=None)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

batch_size = 16
print("Creating DataLoaders...")
# For notebooks and CPU, we set num_workers=0
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)

# --- Define the model ---
print("Defining the model...")
class InceptionResNetCT(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = timm.create_model('inception_resnet_v2', pretrained=True, features_only=False)
        # Modify the first convolution to accept 1-channel input instead of 3
        original_conv = self.base_model.conv2d_1a.conv
        self.base_model.conv2d_1a.conv = nn.Conv2d(
            1, original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=False
        )
        with torch.no_grad():
            # Initialize the single-channel conv by averaging the weights of the original 3 channels
            self.base_model.conv2d_1a.conv.weight[:, 0] = original_conv.weight.mean(dim=1)
        # Enhanced classifier
        num_features = self.base_model.classif.in_features
        self.base_model.classif = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.base_model(x).squeeze()

device = torch.device('cpu')
print(f"Using device: {device}")
model = InceptionResNetCT().to(device)

# --- Training setup ---
print("Setting up training...")
# Compute positive class weight (assumes y_train is binary with 0s and 1s)
pos_weight = torch.tensor([(len(y_train) - sum(y_train)) / sum(y_train)], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=3e-4,
                                          steps_per_epoch=len(train_loader), epochs=50)

best_auc = 0
patience = 0
train_loss_history = []
val_auc_history = []
best_val_preds = None
best_val_labels = None
best_epoch = 0

print("Starting training on CPU...")
for epoch in range(50):
    model.train()
    epoch_loss = 0
    start_time = time.time()
    print(f"\n=== Epoch {epoch+1} ===")
    
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1} Batches")
    for batch_idx, (inputs, labels) in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        epoch_loss += loss.item() * inputs.size(0)
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    
    epoch_loss /= len(train_dataset)
    elapsed = time.time() - start_time
    print(f"Epoch {epoch+1} completed in {elapsed:.1f}s with average loss: {epoch_loss:.4f}")
    
    # --- Validation ---
    model.eval()
    val_preds, val_labels = [], []
    print("Running validation...")
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Validation", leave=False):
            inputs = inputs.to(device)
            outputs = model(inputs)
            val_preds.extend(torch.sigmoid(outputs).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    try:
        val_auc = roc_auc_score(val_labels, val_preds)
    except Exception as e:
        print(f"Error calculating ROC AUC: {e}")
        val_auc = 0
    train_loss_history.append(epoch_loss)
    val_auc_history.append(val_auc)
    print(f"Epoch {epoch+1}: Train Loss: {epoch_loss:.4f} | Val AUC: {val_auc:.4f}")
    
    if val_auc > best_auc:
        best_auc = val_auc
        best_epoch = epoch + 1
        best_val_preds = val_preds
        best_val_labels = val_labels
        torch.save(model.state_dict(), 'best_model.pth')
        print("Best model updated and saved as best_model.pth")
        patience = 0
    else:
        patience += 1
        print(f"No improvement. Patience: {patience}/50")
        if patience >= 50:
            print("Early stopping triggered.")
            break

# --- After training: Compute confusion matrix ---
print("Computing confusion matrix for the best validation predictions...")
best_val_preds_binary = (np.array(best_val_preds) >= 0.5).astype(int)
cm = confusion_matrix(best_val_labels, best_val_preds_binary)
print("Confusion Matrix:")
print(cm)

# --- Save comprehensive checkpoint ---
print("Saving final checkpoint...")
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'best_auc': best_auc,
    'best_epoch': best_epoch,
    'train_loss_history': train_loss_history,
    'val_auc_history': val_auc_history,
    'best_val_preds': best_val_preds,
    'best_val_labels': best_val_labels,
    'confusion_matrix': cm,
}
torch.save(checkpoint, 'final_checkpoint.pth')
print("Checkpoint saved to final_checkpoint.pth")

# --- Optimize for inference: Script the model (no half precision conversion for CPU) ---
print("Optimizing model for inference...")
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_model.pt')
print("Scripted model saved to scripted_model.pt")


Computing dataset statistics...
Initialized CTDataset with 240 samples.
Dataset mean: 0.1509, Dataset std: 0.1687
Creating training and test datasets with augmentations...
Initialized CTDataset with 240 samples.
Initialized CTDataset with 61 samples.
Train dataset size: 240
Test dataset size: 61
Creating DataLoaders...
Defining the model...
Using device: cpu
Setting up training...
Starting training on CPU...

=== Epoch 1 ===


Epoch 1 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 1 completed in 281.5s with average loss: 0.6199
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1: Train Loss: 0.6199 | Val AUC: 0.9773
Best model updated and saved as best_model.pth

=== Epoch 2 ===


Epoch 2 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 2 completed in 276.3s with average loss: 0.5539
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 2: Train Loss: 0.5539 | Val AUC: 1.0000
Best model updated and saved as best_model.pth

=== Epoch 3 ===


Epoch 3 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 3 completed in 258.9s with average loss: 0.4147
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 3: Train Loss: 0.4147 | Val AUC: 1.0000
No improvement. Patience: 1/50

=== Epoch 4 ===


Epoch 4 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 4 completed in 254.8s with average loss: 0.2383
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 4: Train Loss: 0.2383 | Val AUC: 1.0000
No improvement. Patience: 2/50

=== Epoch 5 ===


Epoch 5 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 5 completed in 249.7s with average loss: 0.0797
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 5: Train Loss: 0.0797 | Val AUC: 1.0000
No improvement. Patience: 3/50

=== Epoch 6 ===


Epoch 6 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 6 completed in 248.0s with average loss: 0.0446
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 6: Train Loss: 0.0446 | Val AUC: 1.0000
No improvement. Patience: 4/50

=== Epoch 7 ===


Epoch 7 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 7 completed in 237.9s with average loss: 0.0657
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 7: Train Loss: 0.0657 | Val AUC: 0.9903
No improvement. Patience: 5/50

=== Epoch 8 ===


Epoch 8 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 8 completed in 279.4s with average loss: 0.0889
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 8: Train Loss: 0.0889 | Val AUC: 1.0000
No improvement. Patience: 6/50

=== Epoch 9 ===


Epoch 9 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 9 completed in 352.3s with average loss: 0.0676
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 9: Train Loss: 0.0676 | Val AUC: 1.0000
No improvement. Patience: 7/50

=== Epoch 10 ===


Epoch 10 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 10 completed in 300.5s with average loss: 0.1052
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 10: Train Loss: 0.1052 | Val AUC: 1.0000
No improvement. Patience: 8/50

=== Epoch 11 ===


Epoch 11 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 11 completed in 268.5s with average loss: 0.0460
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 11: Train Loss: 0.0460 | Val AUC: 1.0000
No improvement. Patience: 9/50

=== Epoch 12 ===


Epoch 12 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 12 completed in 241.2s with average loss: 0.0380
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 12: Train Loss: 0.0380 | Val AUC: 1.0000
No improvement. Patience: 10/50

=== Epoch 13 ===


Epoch 13 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 13 completed in 296.5s with average loss: 0.0600
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 13: Train Loss: 0.0600 | Val AUC: 1.0000
No improvement. Patience: 11/50

=== Epoch 14 ===


Epoch 14 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 14 completed in 294.3s with average loss: 0.0448
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 14: Train Loss: 0.0448 | Val AUC: 0.9870
No improvement. Patience: 12/50

=== Epoch 15 ===


Epoch 15 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 15 completed in 287.6s with average loss: 0.0472
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 15: Train Loss: 0.0472 | Val AUC: 1.0000
No improvement. Patience: 13/50

=== Epoch 16 ===


Epoch 16 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 16 completed in 296.2s with average loss: 0.0404
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 16: Train Loss: 0.0404 | Val AUC: 1.0000
No improvement. Patience: 14/50

=== Epoch 17 ===


Epoch 17 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 17 completed in 290.4s with average loss: 0.0566
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 17: Train Loss: 0.0566 | Val AUC: 1.0000
No improvement. Patience: 15/50

=== Epoch 18 ===


Epoch 18 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 18 completed in 286.7s with average loss: 0.0962
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 18: Train Loss: 0.0962 | Val AUC: 1.0000
No improvement. Patience: 16/50

=== Epoch 19 ===


Epoch 19 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 19 completed in 304.9s with average loss: 0.0199
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 19: Train Loss: 0.0199 | Val AUC: 1.0000
No improvement. Patience: 17/50

=== Epoch 20 ===


Epoch 20 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 20 completed in 299.5s with average loss: 0.0168
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 20: Train Loss: 0.0168 | Val AUC: 1.0000
No improvement. Patience: 18/50

=== Epoch 21 ===


Epoch 21 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 21 completed in 302.3s with average loss: 0.0380
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 21: Train Loss: 0.0380 | Val AUC: 1.0000
No improvement. Patience: 19/50

=== Epoch 22 ===


Epoch 22 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 22 completed in 311.6s with average loss: 0.0471
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 22: Train Loss: 0.0471 | Val AUC: 1.0000
No improvement. Patience: 20/50

=== Epoch 23 ===


Epoch 23 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 23 completed in 287.9s with average loss: 0.0146
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 23: Train Loss: 0.0146 | Val AUC: 1.0000
No improvement. Patience: 21/50

=== Epoch 24 ===


Epoch 24 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 24 completed in 311.2s with average loss: 0.0157
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 24: Train Loss: 0.0157 | Val AUC: 1.0000
No improvement. Patience: 22/50

=== Epoch 25 ===


Epoch 25 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 25 completed in 286.7s with average loss: 0.0026
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 25: Train Loss: 0.0026 | Val AUC: 1.0000
No improvement. Patience: 23/50

=== Epoch 26 ===


Epoch 26 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 26 completed in 272.5s with average loss: 0.0101
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 26: Train Loss: 0.0101 | Val AUC: 1.0000
No improvement. Patience: 24/50

=== Epoch 27 ===


Epoch 27 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 27 completed in 304.1s with average loss: 0.0195
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 27: Train Loss: 0.0195 | Val AUC: 1.0000
No improvement. Patience: 25/50

=== Epoch 28 ===


Epoch 28 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 28 completed in 283.2s with average loss: 0.0045
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 28: Train Loss: 0.0045 | Val AUC: 1.0000
No improvement. Patience: 26/50

=== Epoch 29 ===


Epoch 29 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 29 completed in 257.5s with average loss: 0.0467
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 29: Train Loss: 0.0467 | Val AUC: 1.0000
No improvement. Patience: 27/50

=== Epoch 30 ===


Epoch 30 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 30 completed in 256.3s with average loss: 0.0088
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 30: Train Loss: 0.0088 | Val AUC: 1.0000
No improvement. Patience: 28/50

=== Epoch 31 ===


Epoch 31 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 31 completed in 253.5s with average loss: 0.0005
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 31: Train Loss: 0.0005 | Val AUC: 1.0000
No improvement. Patience: 29/50

=== Epoch 32 ===


Epoch 32 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 32 completed in 266.8s with average loss: 0.0003
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 32: Train Loss: 0.0003 | Val AUC: 1.0000
No improvement. Patience: 30/50

=== Epoch 33 ===


Epoch 33 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 33 completed in 244.2s with average loss: 0.0002
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 33: Train Loss: 0.0002 | Val AUC: 1.0000
No improvement. Patience: 31/50

=== Epoch 34 ===


Epoch 34 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 34 completed in 340.6s with average loss: 0.0011
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 34: Train Loss: 0.0011 | Val AUC: 1.0000
No improvement. Patience: 32/50

=== Epoch 35 ===


Epoch 35 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 35 completed in 248.1s with average loss: 0.0004
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 35: Train Loss: 0.0004 | Val AUC: 1.0000
No improvement. Patience: 33/50

=== Epoch 36 ===


Epoch 36 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 36 completed in 234.6s with average loss: 0.0002
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 36: Train Loss: 0.0002 | Val AUC: 1.0000
No improvement. Patience: 34/50

=== Epoch 37 ===


Epoch 37 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 37 completed in 271.5s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 37: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 35/50

=== Epoch 38 ===


Epoch 38 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 38 completed in 285.0s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 38: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 36/50

=== Epoch 39 ===


Epoch 39 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 39 completed in 273.3s with average loss: 0.0002
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 39: Train Loss: 0.0002 | Val AUC: 1.0000
No improvement. Patience: 37/50

=== Epoch 40 ===


Epoch 40 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 40 completed in 250.4s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 40: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 38/50

=== Epoch 41 ===


Epoch 41 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 41 completed in 237.9s with average loss: 0.0015
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 41: Train Loss: 0.0015 | Val AUC: 1.0000
No improvement. Patience: 39/50

=== Epoch 42 ===


Epoch 42 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 42 completed in 299.7s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 42: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 40/50

=== Epoch 43 ===


Epoch 43 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 43 completed in 352.8s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 43: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 41/50

=== Epoch 44 ===


Epoch 44 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 44 completed in 316.3s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 44: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 42/50

=== Epoch 45 ===


Epoch 45 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 45 completed in 290.5s with average loss: 0.0007
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 45: Train Loss: 0.0007 | Val AUC: 1.0000
No improvement. Patience: 43/50

=== Epoch 46 ===


Epoch 46 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 46 completed in 270.7s with average loss: 0.0086
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 46: Train Loss: 0.0086 | Val AUC: 1.0000
No improvement. Patience: 44/50

=== Epoch 47 ===


Epoch 47 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 47 completed in 296.5s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 47: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 45/50

=== Epoch 48 ===


Epoch 48 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 48 completed in 284.6s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 48: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 46/50

=== Epoch 49 ===


Epoch 49 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 49 completed in 276.3s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 49: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 47/50

=== Epoch 50 ===


Epoch 50 Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 50 completed in 261.4s with average loss: 0.0001
Running validation...


Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 50: Train Loss: 0.0001 | Val AUC: 1.0000
No improvement. Patience: 48/50
Computing confusion matrix for the best validation predictions...
Confusion Matrix:
[[28  0]
 [ 1 32]]
Saving final checkpoint...
Checkpoint saved to final_checkpoint.pth
Optimizing model for inference...
Scripted model saved to scripted_model.pt


In [1]:
import os
import pydicom
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Set the path to your root folder containing DICOM slices (and subfolders)
root_folder = r"C:\Users\BAPS\Documents\Dicom Analaysis\Dicom_Dataset\Single Sliced Dataset\CMB-MML\MSB-00140"

# Make sure your model is in evaluation mode
model.eval()

# Set your device (e.g., cuda or cpu)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# IMPORTANT: Set these to your computed dataset mean and std
dataset_mean = 0.5  # <-- replace with your actual mean
dataset_std = 0.25  # <-- replace with your actual std

# List to store slices that are predicted to be anomalous
anomalies = []

# Gather all file paths from the folder and its subdirectories
all_file_paths = []
for root, dirs, files in os.walk(root_folder):
    for file in files:
        all_file_paths.append(os.path.join(root, file))

# Create a progress bar for all files
progress_bar = tqdm(all_file_paths, desc="Processing DICOM slices", ncols=100)

for file_path in progress_bar:
    # Try reading the file as a DICOM; if it fails, skip it.
    try:
        dicom_data = pydicom.dcmread(file_path)
    except Exception:
        continue

    # Get the pixel array from the DICOM and convert to float32
    ct_slice = dicom_data.pixel_array.astype(np.float32)
    
    # Normalize the slice to [0, 1]
    ct_slice_norm = (ct_slice - np.min(ct_slice)) / (np.max(ct_slice) - np.min(ct_slice) + 1e-8)
    
    # Convert the normalized slice to a PyTorch tensor and add a channel dimension: [1, H, W]
    ct_tensor = torch.tensor(ct_slice_norm).unsqueeze(0)  # shape: [1, H, W]
    
    # Resize the tensor to 299x299 using F.interpolate.
    # First, add a batch dimension so the shape becomes [1, 1, H, W]
    ct_tensor = ct_tensor.unsqueeze(0)
    ct_tensor = F.interpolate(ct_tensor, size=(299, 299), mode='bilinear', align_corners=False)
    # Remove the batch dimension so it's back to [1, 299, 299]
    ct_tensor = ct_tensor.squeeze(0)
    
    # Apply standardization using the dataset mean and std (same as in your training)
    ct_tensor = (ct_tensor - dataset_mean) / dataset_std
    
    # Add a batch dimension: [1, 1, 299, 299] and move to the correct device
    ct_tensor = ct_tensor.unsqueeze(0).to(device)
    
    # Run model inference
    with torch.no_grad():
        output = model(ct_tensor).squeeze()  # output is the logit
    probability = torch.sigmoid(output).item()  # convert logit to probability
    prediction = 1 if probability > 0.5 else 0  # threshold at 0.5
    
    # Update the progress bar with the current file's info
    progress_bar.set_postfix_str(f"File: {os.path.basename(file_path)} | Pred: {prediction} | Prob: {probability:.4f}")
    
    # Save details if an anomaly is detected
    if prediction == 1:
        anomalies.append((file_path, probability))

# Print a summary of slices with potential anomalies
print("\nAnomalies detected in the following slices:")
if anomalies:
    for path, prob in anomalies:
        print(f"{path} - Probability: {prob:.4f}")
else:
    print("No anomalies detected in the provided folder and its subfolders.")


NameError: name 'model' is not defined