In [1]:
import torch
from torch.utils.data import DataLoader, Subset
from pathlib import Path
import cv2
import numpy as np
from tqdm import tqdm
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from sklearn.metrics import roc_auc_score, log_loss, matthews_corrcoef, cohen_kappa_score

In [2]:
class MURADataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, label_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.label_dir = Path(label_dir)
        self.images = sorted(list(self.img_dir.glob("*.png")))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label_path = self.label_dir / (img_path.stem + ".txt")
        print(f"Loading image: {img_path}")  # Debug (comment out if not needed)
        img = cv2.imread(str(img_path))
        if img is None:
            print(f"Failed to load image: {img_path}")
            return None
        
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        try:
            with open(label_path, "r") as f:
                line = f.readline().strip().split()
                if not line:
                    print(f"No label data for: {label_path}")
                    return None
                class_id, x, y, w, h = map(float, line)
                x_min = (x - w/2) * 300
                y_min = (y - h/2) * 300
                x_max = (x + w/2) * 300
                y_max = (y + h/2) * 300
                if x_max <= x_min or y_max <= y_min:
                    print(f"Invalid box for: {label_path}")
                    return None
                target = {
                    "boxes": torch.tensor([[x_min, y_min, x_max, y_max]], dtype=torch.float32),
                    "labels": torch.tensor([int(class_id)], dtype=torch.int64)
                }
        except Exception as e:
            print(f"Error processing label {label_path}: {e}")
            return None

        if self.transform:
            img = self.transform(img)
        
        return img, target

In [3]:
device = torch.device("cuda")
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. This code requires a GPU.")
print(f"Using device: {device}")

Using device: cuda


In [5]:
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = MURADataset(
    "D:/Sem 6 project/MURA_YOLO/train/images",
    "D:/Sem 6 project/MURA_YOLO/train/labels",
    transform=transform
)

filtered_indices = []
for idx in range(len(train_dataset)):
    try:
        item = train_dataset[idx]
        if item is not None:
            filtered_indices.append(idx)
    except Exception:
        pass

train_dataset = Subset(train_dataset, filtered_indices)
# train_dataset = Subset(train_dataset, range(100))  # Uncomment to test with 100 samples

valid_dataset = MURADataset(
    "D:/Sem 6 project/MURA_YOLO/valid/images",
    "D:/Sem 6 project/MURA_YOLO/valid/labels",
    transform=transform
)

filtered_indices = []
for idx in range(len(valid_dataset)):
    try:
        item = valid_dataset[idx]
        if item is not None:
            filtered_indices.append(idx)
    except Exception:
        pass

valid_dataset = Subset(valid_dataset, filtered_indices)

Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image10.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image11.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10000.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10003.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10006.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10009.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_1001.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10012.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10016.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10019.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10022.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10025.png
Loading image: D:\Sem 6 project\MURA_YOLO\train\images\image1_10028.png
Loading i

In [6]:
def custom_collate(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return [], []
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    return images, targets

train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    collate_fn=custom_collate,
    pin_memory=True
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    collate_fn=custom_collate,
    pin_memory=True
)

model = fasterrcnn_resnet50_fpn(num_classes=15, weights=None)
model = model.to(device)

In [None]:
checkpoint_dir = Path("D:/Sem 6 project/checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

start_epoch = 0
latest_checkpoint_path = checkpoint_dir / "latest_faster_rcnn_checkpoint.pt"

if latest_checkpoint_path.exists():
    checkpoint = torch.load(latest_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")
else:
    print("No checkpoint found, starting from scratch")

num_epochs = 25
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(start_epoch, num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    model.train()
    total_loss = 0
    batch_count = 0
    
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        for batch_idx, (images, targets) in enumerate(train_loader):
            if not images:
                pbar.update(1)
                continue
            
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            
            try:
                loss_dict = model(images, targets)
                losses = sum(loss_dict.values())
                
                optimizer.zero_grad()
                losses.backward()
                optimizer.step()
            except RuntimeError as e:
                print(f"Error in batch {batch_idx+1}: {e}")
                torch.cuda.empty_cache()
                continue
            
            current_mem = torch.cuda.memory_allocated(device) / 1e9
            total_mem = torch.cuda.get_device_properties(device).total_memory / 1e9
            mem_str = f"{current_mem:.2f}/{total_mem:.2f}GB"
            
            total_loss += losses.item()
            batch_count += 1
            
            pbar.set_postfix({
                'loss': f"{losses.item():.4f}",
                'mem': mem_str
            })
            pbar.update(1)
    
    avg_loss = total_loss / batch_count if batch_count > 0 else 0
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")
    
    checkpoint_path = checkpoint_dir / f"faster_rcnn_resnet50_mura_epoch{epoch+1}.pt"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)
    print(f"Saved checkpoint: {checkpoint_path}")
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, latest_checkpoint_path)
    print(f"Saved latest checkpoint: {latest_checkpoint_path}")

No checkpoint found, starting from scratch

Epoch 1/25


Epoch 1/25:   0%|                                                                             | 0/9202 [00:00<?, ?it/s]

In [6]:
if latest_checkpoint_path.exists():
    checkpoint = torch.load(latest_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {latest_checkpoint_path} for evaluation")
else:
    print("No checkpoint found for evaluation. Using current model state.")

def evaluate(model, data_loader, device, desc="Evaluating"):
    model.eval()
    total_accuracy = 0
    batch_count = 0
    all_preds = []
    all_targets = []
    all_scores = []
    total_log_loss = 0

    with torch.no_grad():
        with tqdm(total=len(data_loader), desc=desc) as pbar:
            for batch_idx, (images, targets) in enumerate(data_loader):
                if not images or len(images) != len(targets):
                    pbar.update(1)
                    continue

                try:
                    images = [img.to(device) for img in images]
                    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                    preds = model(images)

                    batch_accuracy = 0
                    batch_preds = []
                    batch_targets = []
                    batch_scores = []
                    for pred, target in zip(preds, targets):
                        pred_labels = pred['labels']
                        pred_scores = pred['scores']
                        gt_labels = target['labels']
                        if len(pred_labels) > 0:
                            max_score_idx = torch.argmax(pred_scores)
                            pred_label = pred_labels[max_score_idx].item()
                            pred_score = pred_scores[max_score_idx].item()
                            correct = 1.0 if (pred_label == gt_labels[0].item()) else 0.0
                            batch_accuracy += correct
                        else:
                            pred_label = 0
                            pred_score = 0.0
                            batch_accuracy += 0
                        batch_preds.append(pred_label)
                        batch_targets.append(gt_labels[0].item())
                        batch_scores.append(pred_score)

                    batch_accuracy /= len(targets) if len(targets) > 0 else 1

                    binary_targets = [1 if t > 0 else 0 for t in batch_targets]
                    binary_probs = torch.sigmoid(torch.tensor(batch_scores)).cpu().numpy()
                    binary_probs_2d = np.vstack([1 - binary_probs, binary_probs]).T
                    batch_log_loss = log_loss(binary_targets, binary_probs_2d, labels=[0, 1])
                    total_log_loss += batch_log_loss

                    all_preds.extend(batch_preds)
                    all_targets.extend(batch_targets)
                    all_scores.extend(batch_scores)

                    total_accuracy += batch_accuracy
                    batch_count += 1

                    pbar.set_postfix({'acc': f"{batch_accuracy:.4f}"})
                    pbar.update(1)
                except RuntimeError as e:
                    print(f"Evaluation error in batch {batch_idx+1}: {e}")
                    torch.cuda.empty_cache()

    avg_accuracy = total_accuracy / batch_count if batch_count > 0 else 0
    avg_log_loss = total_log_loss / batch_count if batch_count > 0 else 0

    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    all_scores = np.array(all_scores)

    true_positives = np.sum((all_preds > 0) & (all_targets > 0))
    true_negatives = np.sum((all_preds == 0) & (all_targets == 0))
    false_positives = np.sum((all_preds > 0) & (all_targets == 0))
    false_negatives = np.sum((all_preds == 0) & (all_targets > 0))

    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    specificity = true_negatives / (true_negatives + false_positives) if (true_negatives + false_positives) > 0 else 0
    g_mean = np.sqrt(recall * specificity) if (recall * specificity) > 0 else 0

    binary_targets = (all_targets > 0).astype(int)
    auc = roc_auc_score(binary_targets, all_scores) if len(np.unique(binary_targets)) > 1 else 0

    mcc = matthews_corrcoef(all_targets, all_preds)
    kappa = cohen_kappa_score(all_targets, all_preds)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": avg_accuracy,
        "specificity": specificity,
        "g_mean": g_mean,
        "auc": auc,
        "mcc": mcc,
        "kappa": kappa,
        "log_loss": avg_log_loss
    }

print("\nEvaluating on Training Set:")
train_metrics = evaluate(model, train_loader, device, desc="Eval on Train")
print("Train Evaluation Metrics:")
print(f"1. Precision: {train_metrics['precision']:.4f}")
print(f"2. Recall: {train_metrics['recall']:.4f}")
print(f"3. F1 Score: {train_metrics['f1']:.4f}")
print(f"4. Accuracy: {train_metrics['accuracy']:.4f}")
print(f"5. Specificity: {train_metrics['specificity']:.4f}")
print(f"6. G-Mean: {train_metrics['g_mean']:.4f}")
print(f"7. AUC (ROC-AUC): {train_metrics['auc']:.4f}")
print(f"8. MCC: {train_metrics['mcc']:.4f}")
print(f"9. Cohen's Kappa: {train_metrics['kappa']:.4f}")
print(f"10. Log Loss: {train_metrics['log_loss']:.4f}")

print("\nEvaluating on Validation Set:")
valid_metrics = evaluate(model, valid_loader, device, desc="Eval on Valid")
print("Validation Evaluation Metrics:")
print(f"1. Precision: {valid_metrics['precision']:.4f}")
print(f"2. Recall: {valid_metrics['recall']:.4f}")
print(f"3. F1 Score: {valid_metrics['f1']:.4f}")
print(f"4. Accuracy: {valid_metrics['accuracy']:.4f}")
print(f"5. Specificity: {valid_metrics['specificity']:.4f}")
print(f"6. G-Mean: {valid_metrics['g_mean']:.4f}")
print(f"7. AUC (ROC-AUC): {valid_metrics['auc']:.4f}")
print(f"8. MCC: {valid_metrics['mcc']:.4f}")
print(f"9. Cohen's Kappa: {valid_metrics['kappa']:.4f}")
print(f"10. Log Loss: {valid_metrics['log_loss']:.4f}")

NameError: name 'latest_checkpoint_path' is not defined