In [1]:
import torch
from torchvision.models.detection import ssd300_vgg16
from torch.utils.data import DataLoader, Subset
from pathlib import Path
import cv2
from torchvision import transforms
from tqdm import tqdm
import os
import numpy as np
from sklearn.metrics import roc_auc_score, matthews_corrcoef, cohen_kappa_score, log_loss

In [3]:
class MURADataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, label_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.label_dir = Path(label_dir)
        self.images = sorted(list(self.img_dir.glob("*.png")))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label_path = self.label_dir / (img_path.stem + ".txt")
        
        img = cv2.imread(str(img_path))
        if img is None:
            return None
        
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        try:
            with open(label_path, "r") as f:
                line = f.readline().strip().split()
                if not line:
                    return None
                class_id, x, y, w, h = map(float, line)
                x_min = (x - w/2) * 300
                y_min = (y - h/2) * 300
                x_max = (x + w/2) * 300
                y_max = (y + h/2) * 300
                if x_max <= x_min or y_max <= y_min:
                    return None
                target = {
                    "boxes": torch.tensor([[x_min, y_min, x_max, y_max]], dtype=torch.float32),
                    "labels": torch.tensor([int(class_id)], dtype=torch.int64)
                }
        except Exception:
            return None

        if self.transform:
            img = self.transform(img)
        
        return img, target

In [3]:
def custom_collate(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return [], []
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    return images, targets


In [4]:
class MobileNetV2Backbone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        base = mobilenet_v2(weights="DEFAULT")
        self.features = base.features
        
        self.stage1 = torch.nn.Sequential(*base.features[0:7])   # 38x38, 32 channels
        self.stage2 = torch.nn.Sequential(*base.features[7:11])  # 19x19, 64 channels
        self.stage3 = torch.nn.Sequential(*base.features[11:18]) # 10x10, 320 channels
        
        self.extra = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 256, 3, stride=2, padding=1), # 5x5
            torch.nn.Conv2d(256, 256, 3, stride=2, padding=1), # 3x3
            torch.nn.Conv2d(256, 256, 3, stride=2, padding=0)  # 1x1
        ])
        
        self.adjust_channels = torch.nn.ModuleList([
            torch.nn.Conv2d(32, 512, 1),  # 38x38
            torch.nn.Conv2d(64, 512, 1),  # 19x19
            torch.nn.Conv2d(320, 512, 1), # 10x10
            torch.nn.Conv2d(256, 256, 1), # 5x5
            torch.nn.Conv2d(256, 256, 1), # 3x3
            torch.nn.Conv2d(256, 256, 1)  # 1x1
        ])
    
    def forward(self, x):
        features = {}
        f1 = self.stage1(x)
        f2 = self.stage2(f1)
        f3 = self.stage3(f2)
        
        features["0"] = self.adjust_channels[0](f1)
        features["1"] = self.adjust_channels[1](f2)
        features["2"] = self.adjust_channels[2](f3)
        
        x = f3
        for i, layer in enumerate(self.extra):
            x = layer(x)
            features[str(i + 3)] = self.adjust_channels[i + 3](x)
        
        return features

In [5]:
# Training script
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = MURADataset(
    "D:/Sem 6 project/MURA_YOLO/train/images",
    "D:/Sem 6 project/MURA_YOLO/train/labels",
    transform=transform
)

# Pre-filter dataset to only include valid image-label pairs
filtered_indices = []
for idx in range(len(train_dataset)):
    try:
        item = train_dataset[idx]
        if item is not None:
            filtered_indices.append(idx)
    except Exception:
        pass

train_dataset = torch.utils.data.Subset(train_dataset, filtered_indices)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,  # Kept at 64 for better GPU utilization
    shuffle=True,
    num_workers=0,  # Set to 0 for Windows; try 2 if on Linux
    collate_fn=custom_collate,
    pin_memory=True
)

# Prepare the model
device = torch.device("cuda")
model = ssd300_vgg16(weights=None)
model.backbone = MobileNetV2Backbone()

# Create a dummy input to get feature map sizes
with torch.no_grad():
    dummy_input = torch.zeros(1, 3, 300, 300).to(device)
    features = model.backbone.to(device)(dummy_input)
    in_channels = [features[k].shape[1] for k in features.keys()]

# Modify classification and regression heads
model.head.classification_head = SSDClassificationHead(
    in_channels, 
    model.anchor_generator.num_anchors_per_location(), 
    num_classes=15  # 14 classes + background
).to(device)

model.head.regression_head = SSDRegressionHead(
    in_channels, 
    model.anchor_generator.num_anchors_per_location()
).to(device)

model.to(device)

SSD(
  (backbone): MobileNetV2Backbone(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(16, 96, kernel_size=(

In [6]:
# Optimizer and Training Loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.002, momentum=0.9, weight_decay=0.0005)
checkpoint_dir = Path("D:/Sem 6 project/checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

# Check for existing checkpoint to resume
start_epoch = 0
checkpoint_path = checkpoint_dir / "latest_checkpoint.pt"
if checkpoint_path.exists():
    print(f"Resuming from checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")
else:
    print("No checkpoint found, starting from epoch 1")

num_epochs = 25
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0
    total_accuracy = 0
    batch_count = 0
    
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        for batch_idx, (images, targets) in enumerate(train_loader):
            if not images or len(images) != len(targets):
                pbar.update(1)
                continue
            
            try:
                images = [img.to(device) for img in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                
                # Compute loss in training mode
                loss_dict = model(images, targets)
                losses = sum(loss_dict.values())
                
                # Compute predictions in evaluation mode for accuracy
                with torch.no_grad():
                    model.eval()
                    preds = model(images)
                    model.train()
                    
                    batch_accuracy = 0
                    for pred, target in zip(preds, targets):
                        pred_labels = pred['labels']
                        pred_scores = pred['scores']
                        gt_labels = target['labels']
                        if len(pred_labels) > 0:
                            max_score_idx = torch.argmax(pred_scores)
                            pred_label = pred_labels[max_score_idx]
                            batch_accuracy += (pred_label == gt_labels[0]).float().mean().item()
                        else:
                            batch_accuracy += 0  # No predictions, count as incorrect
                    batch_accuracy /= len(targets) if len(targets) > 0 else 1
                
                # Backpropagation
                optimizer.zero_grad()
                losses.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
                optimizer.step()
                
                total_loss += losses.item()
                total_accuracy += batch_accuracy
                batch_count += 1
                
                allocated = torch.cuda.memory_allocated() / 1e9
                reserved = torch.cuda.memory_reserved() / 1e9
                pbar.set_postfix({
                    'loss': f"{losses.item():.4f}",
                    'acc': f"{batch_accuracy:.4f}",
                    'mem': f"{allocated:.2f}/{reserved:.2f}GB"
                })
            except RuntimeError as e:
                print(f"Training error in batch {batch_idx+1}: {e}")
                torch.cuda.empty_cache()
            finally:
                pbar.update(1)
    
    avg_loss = total_loss / batch_count if batch_count > 0 else 0
    avg_accuracy = total_accuracy / batch_count if batch_count > 0 else 0
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_accuracy:.4f}")
    
    # Save checkpoint for the epoch
    checkpoint_path = checkpoint_dir / f"ssd_mobilenetv2_mura_epoch{epoch+1}.pt"
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Saved checkpoint: {checkpoint_path}")
    
    # Save latest checkpoint for resuming
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_dir / "latest_checkpoint.pt")
    print(f"Saved latest checkpoint: {checkpoint_dir / 'latest_checkpoint.pt'}")

# Save final model
torch.save(model.state_dict(), "D:/Sem 6 project/ssd_mobilenetv2_mura_final.pt")
print("Final model saved")

Resuming from checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


  checkpoint = torch.load(checkpoint_path, map_location=device)


Resuming from epoch 10


Epoch 11/25: 100%|████████████████████████| 576/576 [11:09<00:00,  1.16s/it, loss=0.7160, acc=0.6667, mem=0.12/11.43GB]


Epoch 11/25, Avg Loss: 0.3488, Avg Accuracy: 0.7880
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch11.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 12/25: 100%|████████████████████████| 576/576 [11:51<00:00,  1.24s/it, loss=1.0367, acc=0.8333, mem=0.12/11.43GB]


Epoch 12/25, Avg Loss: 0.3166, Avg Accuracy: 0.8039
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch12.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 13/25: 100%|████████████████████████| 576/576 [11:44<00:00,  1.22s/it, loss=0.7169, acc=0.8333, mem=0.12/11.43GB]


Epoch 13/25, Avg Loss: 0.2862, Avg Accuracy: 0.8174
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch13.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 14/25: 100%|████████████████████████| 576/576 [11:30<00:00,  1.20s/it, loss=0.5309, acc=1.0000, mem=0.12/11.43GB]


Epoch 14/25, Avg Loss: 0.2518, Avg Accuracy: 0.8299
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch14.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 15/25: 100%|████████████████████████| 576/576 [11:48<00:00,  1.23s/it, loss=0.3053, acc=0.8333, mem=0.12/11.43GB]


Epoch 15/25, Avg Loss: 0.2197, Avg Accuracy: 0.8446
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch15.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 16/25: 100%|████████████████████████| 576/576 [11:33<00:00,  1.20s/it, loss=0.8660, acc=1.0000, mem=0.12/11.43GB]


Epoch 16/25, Avg Loss: 0.1913, Avg Accuracy: 0.8595
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch16.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 17/25: 100%|████████████████████████| 576/576 [11:33<00:00,  1.20s/it, loss=0.8261, acc=1.0000, mem=0.12/11.43GB]


Epoch 17/25, Avg Loss: 0.1617, Avg Accuracy: 0.8730
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch17.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 18/25: 100%|████████████████████████| 576/576 [11:47<00:00,  1.23s/it, loss=0.5762, acc=1.0000, mem=0.12/11.43GB]


Epoch 18/25, Avg Loss: 0.1289, Avg Accuracy: 0.8823
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch18.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 19/25: 100%|████████████████████████| 576/576 [11:37<00:00,  1.21s/it, loss=1.0071, acc=0.8333, mem=0.12/11.43GB]


Epoch 19/25, Avg Loss: 0.1159, Avg Accuracy: 0.8887
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch19.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 20/25: 100%|████████████████████████| 576/576 [13:14<00:00,  1.38s/it, loss=0.3414, acc=0.8333, mem=0.12/11.43GB]


Epoch 20/25, Avg Loss: 0.1006, Avg Accuracy: 0.8925
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch20.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 21/25: 100%|████████████████████████| 576/576 [11:26<00:00,  1.19s/it, loss=0.0107, acc=1.0000, mem=0.12/11.43GB]


Epoch 21/25, Avg Loss: 0.0831, Avg Accuracy: 0.8994
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch21.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 22/25: 100%|████████████████████████| 576/576 [11:27<00:00,  1.19s/it, loss=0.1868, acc=1.0000, mem=0.12/11.43GB]


Epoch 22/25, Avg Loss: 0.0775, Avg Accuracy: 0.9032
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch22.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 23/25: 100%|████████████████████████| 576/576 [10:16<00:00,  1.07s/it, loss=1.0130, acc=1.0000, mem=0.12/11.43GB]


Epoch 23/25, Avg Loss: 0.0715, Avg Accuracy: 0.9054
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch23.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 24/25: 100%|████████████████████████| 576/576 [11:17<00:00,  1.18s/it, loss=1.1727, acc=1.0000, mem=0.12/11.43GB]


Epoch 24/25, Avg Loss: 0.0611, Avg Accuracy: 0.9081
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch24.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt


Epoch 25/25: 100%|████████████████████████| 576/576 [11:36<00:00,  1.21s/it, loss=0.5875, acc=1.0000, mem=0.12/11.43GB]


Epoch 25/25, Avg Loss: 0.0599, Avg Accuracy: 0.9101
Saved checkpoint: D:\Sem 6 project\checkpoints\ssd_mobilenetv2_mura_epoch25.pt
Saved latest checkpoint: D:\Sem 6 project\checkpoints\latest_checkpoint.pt
Final model saved


In [7]:
# Load the Validation Dataset
valid_dataset = MURADataset(
    "D:/Sem 6 project/MURA_YOLO/valid/images",  # Adjust path to your validation images
    "D:/Sem 6 project/MURA_YOLO/valid/labels",   # Adjust path to your validation labels
    transform=transform
)

In [8]:
# Pre-filter validation dataset
filtered_indices = []
for idx in range(len(valid_dataset)):
    try:
        item = valid_dataset[idx]
        if item is not None:
            filtered_indices.append(idx)
    except Exception:
        pass

valid_dataset = torch.utils.data.Subset(valid_dataset, filtered_indices)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    collate_fn=custom_collate,
    pin_memory=True
)

In [14]:
# Load the trained model weights
model_path = "D:/Sem 6 project/ssd_mobilenetv2_mura_final.pt"
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
print(f"Loaded model from {model_path}")

# Evaluation Function with the 10 Metrics
def evaluate(model, data_loader, device, desc="Evaluating"):
    model.eval()
    total_accuracy = 0
    batch_count = 0
    all_preds = []
    all_targets = []
    all_scores = []
    total_log_loss = 0

    with torch.no_grad():
        with tqdm(total=len(data_loader), desc=desc) as pbar:
            for batch_idx, (images, targets) in enumerate(data_loader):
                if not images or len(images) != len(targets):
                    pbar.update(1)
                    continue

                try:
                    images = [img.to(device) for img in images]
                    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                    # Compute predictions
                    preds = model(images)

                    # Compute accuracy
                    batch_accuracy = 0
                    batch_preds = []
                    batch_targets = []
                    batch_scores = []
                    for pred, target in zip(preds, targets):
                        pred_labels = pred['labels']
                        pred_scores = pred['scores']
                        gt_labels = target['labels']
                        if len(pred_labels) > 0:
                            max_score_idx = torch.argmax(pred_scores)
                            pred_label = pred_labels[max_score_idx].item()
                            pred_score = pred_scores[max_score_idx].item()
                            # Convert boolean to float (1.0 for True, 0.0 for False)
                            correct = 1.0 if (pred_label == gt_labels[0].item()) else 0.0
                            batch_accuracy += correct
                        else:
                            pred_label = 0  # Default to background if no predictions
                            pred_score = 0.0
                            batch_accuracy += 0  # No predictions, count as incorrect
                        batch_preds.append(pred_label)
                        batch_targets.append(gt_labels[0].item())
                        batch_scores.append(pred_score)

                    # Compute batch accuracy as the average of correct predictions
                    batch_accuracy /= len(targets) if len(targets) > 0 else 1

                    # Compute log loss for the batch (simplified for binary classification: foreground vs. background)
                    binary_targets = [1 if t > 0 else 0 for t in batch_targets]  # Foreground (1) vs. Background (0)
                    binary_probs = torch.sigmoid(torch.tensor(batch_scores, dtype=torch.float)).cpu().numpy()
                    # Create a 2D array for log_loss: [P(background), P(foreground)]
                    binary_probs_2d = np.vstack([1 - binary_probs, binary_probs]).T
                    batch_log_loss = log_loss(binary_targets, binary_probs_2d, labels=[0, 1])
                    total_log_loss += batch_log_loss

                    all_preds.extend(batch_preds)
                    all_targets.extend(batch_targets)
                    all_scores.extend(batch_scores)

                    total_accuracy += batch_accuracy
                    batch_count += 1

                    pbar.set_postfix({
                        'acc': f"{batch_accuracy:.4f}"
                    })
                except RuntimeError as e:
                    print(f"Evaluation error in batch {batch_idx+1}: {e}")
                    torch.cuda.empty_cache()
                finally:
                    pbar.update(1)

    # Compute average accuracy
    avg_accuracy = total_accuracy / batch_count if batch_count > 0 else 0
    avg_log_loss = total_log_loss / batch_count if batch_count > 0 else 0

    # Convert predictions and targets to numpy arrays
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    all_scores = np.array(all_scores)

    # Compute confusion matrix components (TP, TN, FP, FN)
    # For simplicity, we'll treat this as a binary classification problem (foreground vs. background)
    # Foreground: any class > 0, Background: class 0
    true_positives = np.sum((all_preds > 0) & (all_targets > 0))
    true_negatives = np.sum((all_preds == 0) & (all_targets == 0))
    false_positives = np.sum((all_preds > 0) & (all_targets == 0))
    false_negatives = np.sum((all_preds == 0) & (all_targets > 0))

    # Compute Precision, Recall, F1 Score
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Compute Specificity
    specificity = true_negatives / (true_negatives + false_positives) if (true_negatives + false_positives) > 0 else 0

    # Compute G-Mean
    g_mean = np.sqrt(recall * specificity) if (recall * specificity) > 0 else 0

    # Compute AUC (ROC-AUC)
    # For multiclass, we use one-vs-rest AUC; here we simplify to foreground vs. background
    binary_targets = (all_targets > 0).astype(int)
    binary_scores = np.array(all_scores)  # Use scores for foreground probability
    auc = roc_auc_score(binary_targets, binary_scores) if len(np.unique(binary_targets)) > 1 else 0

    # Compute MCC (Matthews Correlation Coefficient)
    mcc = matthews_corrcoef(all_targets, all_preds)

    # Compute Cohen's Kappa
    kappa = cohen_kappa_score(all_targets, all_preds)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": avg_accuracy,
        "specificity": specificity,
        "g_mean": g_mean,
        "auc": auc,
        "mcc": mcc,
        "kappa": kappa,
        "log_loss": avg_log_loss
    }

# Evaluate on Training Set
print("\nEvaluating on Training Set:")
train_metrics = evaluate(model, train_loader, device, desc="Eval on Train")
print("Train Evaluation Metrics:")
print(f"1. Precision: {train_metrics['precision']:.4f}")
print(f"2. Recall: {train_metrics['recall']:.4f}")
print(f"3. F1 Score: {train_metrics['f1']:.4f}")
print(f"4. Accuracy: {train_metrics['accuracy']:.4f}")
print(f"5. Specificity: {train_metrics['specificity']:.4f}")
print(f"6. G-Mean: {train_metrics['g_mean']:.4f}")
print(f"7. AUC (ROC-AUC): {train_metrics['auc']:.4f}")
print(f"8. MCC: {train_metrics['mcc']:.4f}")
print(f"9. Cohen's Kappa: {train_metrics['kappa']:.4f}")
print(f"10. Log Loss: {train_metrics['log_loss']:.4f}")

# Evaluate on Validation Set
print("\nEvaluating on Validation Set:")
valid_metrics = evaluate(model, valid_loader, device, desc="Eval on Valid")
print("Validation Evaluation Metrics:")
print(f"1. Precision: {valid_metrics['precision']:.4f}")
print(f"2. Recall: {valid_metrics['recall']:.4f}")
print(f"3. F1 Score: {valid_metrics['f1']:.4f}")
print(f"4. Accuracy: {valid_metrics['accuracy']:.4f}")
print(f"5. Specificity: {valid_metrics['specificity']:.4f}")
print(f"6. G-Mean: {valid_metrics['g_mean']:.4f}")
print(f"7. AUC (ROC-AUC): {valid_metrics['auc']:.4f}")
print(f"8. MCC: {valid_metrics['mcc']:.4f}")
print(f"9. Cohen's Kappa: {valid_metrics['kappa']:.4f}")
print(f"10. Log Loss: {valid_metrics['log_loss']:.4f}")

  model.load_state_dict(torch.load(model_path, map_location=device))


Loaded model from D:/Sem 6 project/ssd_mobilenetv2_mura_final.pt

Evaluating on Training Set:


Eval on Train: 100%|█████████████████████████████████████████████████████| 576/576 [06:51<00:00,  1.40it/s, acc=1.0000]


Train Evaluation Metrics:
1. Precision: 0.9205
2. Recall: 1.0000
3. F1 Score: 0.9586
4. Accuracy: 0.9148
5. Specificity: 0.0000
6. G-Mean: 0.0000
7. AUC (ROC-AUC): 0.5047
8. MCC: 0.9116
9. Cohen's Kappa: 0.9057
10. Log Loss: 0.3939

Evaluating on Validation Set:


Eval on Valid: 100%|███████████████████████████████████████████████████████| 50/50 [00:36<00:00,  1.38it/s, acc=0.7049]

Validation Evaluation Metrics:
1. Precision: 0.9265
2. Recall: 1.0000
3. F1 Score: 0.9618
4. Accuracy: 0.6832
5. Specificity: 0.0000
6. G-Mean: 0.0000
7. AUC (ROC-AUC): 0.3851
8. MCC: 0.6609
9. Cohen's Kappa: 0.6564
10. Log Loss: 0.3993





In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Constants
IMG_SIZE = 512
LATENT_DIM = 512
BATCH_SIZE = 8
NUM_CLASSES = 14
DATASET_DIR = 'MURA-v1.1/valid'  # Update this path to your dataset directory
CHECKPOINT_DIR = 'D:\\Sem 6 project\\checkpoints-BAGAN-GP-WGAN-GP_Old-2'  # Update this path to your checkpoint directory
OUTPUT_DIR = 'autoencoder_samples'  # Update this path for saving output samples
os.makedirs(OUTPUT_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

body_parts = ['XR_WRIST', 'XR_SHOULDER', 'XR_HAND', 'XR_FOREARM', 'XR_FINGER', 'XR_ELBOW', 'XR_HUMERUS']
case_types = ['positive', 'negative']
class_names = [f"{bp}_{ct}" for bp in body_parts for ct in case_types]

class MURADataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        try:
            img_path = self.image_paths[idx]
            full_path = os.path.join(DATASET_DIR, '..', img_path)
            image = Image.open(full_path).convert('L')  # Grayscale for X-rays
            if self.transform:
                image = self.transform(image)
            label = self.labels[idx]
            label_onehot = torch.zeros(NUM_CLASSES)
            label_onehot[label] = 1.0
            return image, label_onehot
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Instead of returning a placeholder, raise an exception to skip invalid images
            raise e

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, stride=2, padding=1),  # 512 -> 256
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 256 -> 128
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 128 -> 64
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),  # 64 -> 32
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1024, 4, stride=2, padding=1),  # 32 -> 16
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten()
        )
        self.fc = nn.Linear(1024 * 16 * 16, LATENT_DIM)
    
    def forward(self, x):
        x = self.model(x)
        latent = self.fc(x)
        return latent

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.initial_size = 16
        # Modified to match the original notebook's architecture for checkpoint compatibility
        self.fc = nn.Linear(LATENT_DIM + NUM_CLASSES, 1024 * self.initial_size * self.initial_size)
        self.bn_initial = nn.BatchNorm2d(1024)
        self.deconv = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1),  # 16 -> 32
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=False),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 32 -> 64
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=False),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 64 -> 128
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 128 -> 256
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),  # 256 -> 512
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        # Concatenate latent vector with labels instead of multiplication
        x = torch.cat([z, labels], dim=1)
        x = self.fc(x)
        x = x.view(-1, 1024, self.initial_size, self.initial_size)
        x = self.bn_initial(x)
        x = nn.functional.relu(x, inplace=True)
        x = self.deconv(x)
        return x

class Autoencoder(nn.Module):
    def __init__(self, num_classes):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.num_classes = num_classes
    
    def forward(self, x, labels):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent, labels)
        return reconstructed

def scan_dataset(dataset_dir=DATASET_DIR):
    image_paths = []
    labels = []
    for root, _, files in os.walk(dataset_dir):
        for file in files:
            if file.endswith('.png'):
                full_path = os.path.join(root, file)
                relative_path = os.path.relpath(full_path, os.path.join(dataset_dir, '..'))
                label = None
                for bp in body_parts:
                    if bp in relative_path:
                        case = 'positive' if 'positive' in relative_path else 'negative'
                        class_idx = body_parts.index(bp) * 2 + (0 if case == 'positive' else 1)
                        label = class_idx
                        break
                if label is not None:
                    image_paths.append(relative_path)
                    labels.append(label)
    return np.array(image_paths), np.array(labels)

# Load validation dataset
val_paths, val_labels = scan_dataset()
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
val_dataset = MURADataset(val_paths, val_labels, transform=val_transform)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Load autoencoder
autoencoder = Autoencoder(NUM_CLASSES).to(device)
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'autoencoder_epoch_50.pth')
if os.path.exists(checkpoint_path):
    autoencoder.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    print(f"Loaded autoencoder checkpoint from {checkpoint_path}")
else:
    raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

def generate_autoencoder_samples(autoencoder, dataloader, num_samples_per_class=1):
    autoencoder.eval()
    class_samples = {i: [] for i in range(NUM_CLASSES)}
    class_counts = {i: 0 for i in range(NUM_CLASSES)}

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Generating autoencoder samples"):
            try:
                images, labels_onehot = batch
                images = images.to(device)
                labels_onehot = labels_onehot.to(device)
                labels = torch.argmax(labels_onehot, dim=1)

                reconstructed = autoencoder(images, labels_onehot)

                images = (images + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
                reconstructed = (reconstructed + 1) / 2

                for i in range(images.size(0)):
                    class_idx = labels[i].item()
                    if class_counts[class_idx] < num_samples_per_class:
                        class_samples[class_idx].append((images[i], reconstructed[i]))
                        class_counts[class_idx] += 1

                if all(count >= num_samples_per_class for count in class_counts.values()):
                    break
            except Exception as e:
                print(f"Error processing batch: {e}")
                continue

    # Plot and save samples
    fig, axes = plt.subplots(NUM_CLASSES, 2, figsize=(8, NUM_CLASSES * 2))
    for class_idx in range(NUM_CLASSES):
        if class_samples[class_idx]:
            real_img, recon_img = class_samples[class_idx][0]
            real_img = real_img.cpu().numpy().squeeze()
            recon_img = recon_img.cpu().numpy().squeeze()

            axes[class_idx, 0].imshow(real_img, cmap='gray')
            axes[class_idx, 0].set_title(f"Real: {class_names[class_idx]}")
            axes[class_idx, 0].axis('off')

            axes[class_idx, 1].imshow(recon_img, cmap='gray')
            axes[class_idx, 1].set_title(f"Recon: {class_names[class_idx]}")
            axes[class_idx, 1].axis('off')

            plt.imsave(
                os.path.join(OUTPUT_DIR, f'reconstructed_{class_names[class_idx]}.png'),
                recon_img, cmap='gray'
            )
        else:
            print(f"No samples generated for class {class_names[class_idx]}")

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'autoencoder_samples.png'))
    plt.close()
    print(f"Saved autoencoder samples to {OUTPUT_DIR}")

# Generate samples
generate_autoencoder_samples(autoencoder, val_dataloader, num_samples_per_class=1)

Using device: cuda
Loaded autoencoder checkpoint from D:\Sem 6 project\checkpoints-BAGAN-GP-WGAN-GP_Old-2\autoencoder_epoch_50.pth


Generating autoencoder samples:  81%|██████████████████████████████████████▊         | 323/400 [00:57<00:13,  5.57it/s]


Saved autoencoder samples to autoencoder_samples
