In [None]:
import os
import xml.etree.ElementTree as ET
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch.optim as optim

# ====================== CONFIG ======================
# File suffixes (lower-case) that are treated as images when scanning folders.
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png']

# Defect categories; a category's position in this list plus one is its label.
CLASS_NAMES = [
    'crazing',
    'inclusion',
    'patches',
    'pitted_surface',
    'rolled_in_scale',
    'scratches',
]

# Lower-cased name with '-' and ' ' removed -> 1-based label index.
# Index 0 is intentionally left unused here: it is the background class.
CLASS_NAME_TO_IDX = {
    label.lower().replace('-', '').replace(' ', ''): position
    for position, label in enumerate(CLASS_NAMES, start=1)
}

# One output per defect category plus one for background.
NUM_CLASSES = 1 + len(CLASS_NAMES)

# Seed the global torch RNG once at import time.
SEED = 42
torch.manual_seed(SEED)

# ====================== DATASET ======================
class VOCStyleSteelDataset(Dataset):
    """Detection dataset pairing each image with a same-named VOC-style XML file.

    ``root_dir`` is scanned one level deep: every sub-directory is listed and
    each image whose sibling ``.xml`` file exists becomes one sample. Images
    without an annotation file are ignored.

    Args:
        root_dir: Directory containing the image/annotation sub-folders.
        transforms: Optional callable applied to the PIL image only; the
            bounding boxes are returned exactly as read from the XML.

    Each item is ``(image, target)`` where ``target`` holds ``boxes``
    (float32, shape ``(N, 4)`` as ``xmin, ymin, xmax, ymax``), ``labels``
    (int64, shape ``(N,)``) and ``image_id``.
    """

    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transforms = transforms
        self.samples = []
        # Class names already reported as unknown, so each is printed once.
        self._unknown_names = set()

        # Separator-insensitive lookup: annotation files may spell a class
        # with hyphens, underscores or spaces (e.g. "rolled-in_scale" vs
        # "rolled_in_scale"); all spellings must map to the same index.
        self._name_to_idx = {
            self._normalize_name(name): idx
            for name, idx in CLASS_NAME_TO_IDX.items()
        }

        image_suffixes = tuple(IMAGE_EXTENSIONS)
        # os.listdir order is arbitrary; sorting keeps sample indices (and
        # therefore any seeded train/val split) stable between runs.
        for class_folder in sorted(os.listdir(root_dir)):
            folder_path = os.path.join(root_dir, class_folder)
            if not os.path.isdir(folder_path):
                continue
            for img_file in sorted(os.listdir(folder_path)):
                if not img_file.lower().endswith(image_suffixes):
                    continue
                img_path = os.path.join(folder_path, img_file)
                xml_path = os.path.splitext(img_path)[0] + '.xml'
                if os.path.exists(xml_path):
                    self.samples.append((img_path, xml_path))

    @staticmethod
    def _normalize_name(name):
        """Lower-case ``name`` and drop '-', '_' and ' ' separators."""
        return name.lower().replace('-', '').replace('_', '').replace(' ', '')

    def __len__(self):
        return len(self.samples)

    def _load_sample(self, idx):
        """Read image and annotation ``idx``; raises if either is unreadable."""
        img_path, xml_path = self.samples[idx]

        # convert() returns an independent image, so the file can be closed.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        root = ET.parse(xml_path).getroot()

        boxes = []
        labels = []
        for obj in root.findall('object'):
            raw_name = obj.find('name').text
            label = self._name_to_idx.get(self._normalize_name(raw_name))
            if label is None:
                # Label 0 means background; giving it to a real box would
                # feed the detector a contradictory target, so drop the box.
                if raw_name not in self._unknown_names:
                    self._unknown_names.add(raw_name)
                    print(f"⚠️ Ignoring objects of unknown class '{raw_name}'")
                continue

            bndbox = obj.find('bndbox')
            xmin, ymin, xmax, ymax = (
                float(bndbox.find(tag).text)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            # Boxes without positive width and height are not valid targets.
            if xmax <= xmin or ymax <= ymin:
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(label)

        target = {
            # reshape keeps the (N, 4) layout even when no box survived (N=0).
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([idx]),
        }

        if self.transforms:
            img = self.transforms(img)

        return img, target

    def __getitem__(self, idx):
        """Return sample ``idx``, falling forward to the next readable sample.

        Unreadable samples are reported and skipped (best effort). In that
        case the returned ``image_id`` is that of the substitute sample.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if no sample in the dataset can be loaded.
        """
        # Surface IndexError for out-of-range indices before any wrapping, so
        # plain iteration over the dataset still terminates.
        self.samples[idx]

        total = len(self.samples)
        # Bounded scan instead of recursion: at most one attempt per sample.
        for offset in range(total):
            candidate = (idx + offset) % total
            try:
                return self._load_sample(candidate)
            except Exception as e:  # deliberate best effort: skip bad files
                bad_path = self.samples[candidate][0]
                print(f"❌ Skipping {bad_path} due to error: {e}")

        raise RuntimeError(f"No loadable samples found under {self.root_dir}")

# ====================== DATALOADER COLLATE ======================
def collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    For ``(image, target)`` samples this produces ``(images, targets)`` without
    stacking, so samples in one batch may differ in size.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# ====================== TRAINING FUNCTION ======================
def _evaluate_loss(model, loader, device):
    """Return the mean summed detection loss over ``loader`` without training.

    The model is called with targets in training mode because that is the
    mode in which it returns its loss dict; gradients are disabled so no
    weights change.
    NOTE(review): assumes train mode has no side effects on the model (e.g.
    no running batch-norm statistics being updated) — confirm for the
    backbone in use.
    """
    model.train()
    total_loss = 0.0
    with torch.no_grad():
        for imgs, targets in loader:
            imgs = [img.to(device) for img in imgs]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(imgs, targets)
            total_loss += sum(loss_dict.values()).item()
    return total_loss / len(loader)


def train_detector(dataset_path, num_classes=NUM_CLASSES, epochs=20, batch_size=2, val_split=0.2,
                   save_path="E:/steel surface/models/fasterrcnn_steel_defect_full.pth", patience=3):
    """Fine-tune a Faster R-CNN (ResNet-50 FPN) detector and save the full model.

    Args:
        dataset_path: Root folder passed to ``VOCStyleSteelDataset``.
        num_classes: Number of output classes including background.
        epochs: Maximum number of training epochs.
        batch_size: Images per batch for both training and validation.
        val_split: Fraction of the dataset held out for validation.
        save_path: Destination file for the pickled model object.
        patience: Stop after this many consecutive epochs without a new best
            monitored loss (validation loss, or training loss when the
            validation split is empty).

    Raises:
        ValueError: if the dataset is empty or the split leaves no training
            samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("📁 Loading dataset...")
    transforms_train = T.Compose([T.ToTensor()])
    dataset = VOCStyleSteelDataset(dataset_path, transforms=transforms_train)
    if len(dataset) == 0:
        raise ValueError(f"No annotated images found under {dataset_path}")

    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    if train_size <= 0:
        raise ValueError(f"val_split={val_split} leaves no samples for training")
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    print("📦 Building model...")
    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn(weights=weights)
    # Swap the pretrained classification head for one sized to our classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    model.to(device)

    # Only hand trainable parameters to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable_params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    best_loss = float('inf')
    no_improve = 0

    print("\n🟢 Starting Faster R-CNN Training...")
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0

        for batch_idx, (imgs, targets) in enumerate(train_loader):
            print(f"🔁 Epoch {epoch+1} — Processing batch {batch_idx+1}/{len(train_loader)}")
            imgs = [img.to(device) for img in imgs]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(imgs, targets)
            losses = sum(loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            epoch_loss += losses.item()

        avg_train_loss = epoch_loss / len(train_loader)
        print(f"📊 Epoch {epoch+1}/{epochs} — Train Loss: {avg_train_loss:.4f}")

        scheduler.step()

        # Early stopping tracks held-out loss; with an empty validation split
        # the training loss is the only signal available.
        if val_size > 0:
            monitored_loss = _evaluate_loss(model, val_loader, device)
            print(f"📊 Epoch {epoch+1}/{epochs} — Val Loss: {monitored_loss:.4f}")
        else:
            monitored_loss = avg_train_loss

        # The counter resets on every improvement, so only a real plateau of
        # `patience` epochs ends training.
        if monitored_loss < best_loss:
            best_loss = monitored_loss
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"🛑 Early stopping triggered at epoch {epoch+1}")
                break

    print("✅ Training complete.")

    # ====================== SAVE FULL MODEL ======================
    save_dir = os.path.dirname(save_path)
    # A bare filename has no directory part; nothing to create in that case.
    if save_dir and not os.path.exists(save_dir):
        print(f"📂 Creating directory: {save_dir}")
        os.makedirs(save_dir, exist_ok=True)
    elif save_dir:
        print(f"✅ Save directory exists: {save_dir}")

    print(f"💾 Saving full model to {save_path}...")
    # The whole module object is pickled (not just its state_dict); loading
    # it back requires the same model classes to be importable.
    torch.save(model, save_path)
    print("✅ Full model successfully saved.")

# ====================== CALL FUNCTION ======================
if __name__ == "__main__":
    # Root folder holding the image/annotation sub-folders to train on.
    dataset_path = "E:/steel surface/steel_defects_dataset"
    train_detector(dataset_path, epochs=10, batch_size=4)


In [1]:
import torch
from torchvision import transforms as T
from PIL import Image, ImageDraw, ImageFont
import os

# ================== CONFIG ==================
# Defect categories in label order: label k corresponds to CLASS_NAMES[k - 1].
CLASS_NAMES = [
    'crazing',
    'inclusion',
    'patches',
    'pitted_surface',
    'rolled_in_scale',
    'scratches',
]

# Outline/text colour used when drawing each category.
CLASS_COLORS = dict(
    crazing='blue',
    inclusion='green',
    patches='orange',
    pitted_surface='purple',
    rolled_in_scale='brown',
    scratches='magenta',
)

# Pickled full-model file to load for inference.
MODEL_PATH = "E:/steel surface/models/fasterrcnn_steel_defect_full.pth"

# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Detections scoring below this confidence are not drawn.
SCORE_THRESHOLD = 0.5

# ================== LOAD MODEL ==================
def load_model(model_path):
    """Load the pickled full model from ``model_path`` and prepare it for inference.

    The object is deserialised onto ``DEVICE``, moved there explicitly and
    switched to evaluation mode before being returned.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    model_file_present = os.path.exists(model_path)
    if not model_file_present:
        raise FileNotFoundError(f"❌ Model file not found: {model_path}")

    print(f"📦 Loading full model from: {model_path}")
    # weights_only=False unpickles arbitrary objects: only use trusted files.
    detector = torch.load(
        model_path,
        map_location=DEVICE,
        weights_only=False,
    )
    detector.to(DEVICE)
    detector.eval()
    print("✅ Model loaded and ready for inference.")
    return detector

# ================== TRANSFORM ==================
# Preprocessing applied to every input image before it is passed to the
# model: a single ToTensor step (PIL image -> tensor).
transform = T.Compose([T.ToTensor()])

# ================== DETECTION FUNCTION ==================
def detect_defects(image_path, model, save_result=True, score_threshold=None):
    """Run the detector on one image, draw the confident detections and show them.

    Args:
        image_path: Path of the image to analyse.
        model: Detection model returning, per image, a dict with ``boxes``,
            ``labels`` and ``scores``.
        save_result: When true, write the annotated image next to the input
            as ``<name>_detected.jpg``.
        score_threshold: Minimum confidence for a detection to be drawn;
            ``None`` uses the module-level ``SCORE_THRESHOLD``.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"❌ Image not found: {image_path}")
    if score_threshold is None:
        score_threshold = SCORE_THRESHOLD

    print(f"🖼️ Loading image: {image_path}")
    # convert() returns an independent image, so the file can be closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    print("🔍 Running inference...")
    with torch.no_grad():
        prediction = model(image_tensor)[0]

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    # tolist() yields plain Python numbers, which PIL's drawing calls accept.
    detections = zip(
        prediction["boxes"].tolist(),
        prediction["labels"].tolist(),
        prediction["scores"].tolist(),
    )
    for box, label, score in detections:
        if score < score_threshold:
            continue
        # Labels are 1-based (0 is background). Guard the lookup so that an
        # unexpected label cannot wrap around to the last class name or raise.
        if 1 <= label <= len(CLASS_NAMES):
            label_name = CLASS_NAMES[label - 1]
        else:
            label_name = f"class_{label}"
        color = CLASS_COLORS.get(label_name, 'red')  # fallback to red
        draw.rectangle(box, outline=color, width=3)
        # Keep the caption inside the image for boxes touching the top edge.
        text_y = max(box[1] - 10, 0)
        draw.text((box[0], text_y), f"{label_name} {score:.2f}", fill=color, font=font)

    image.show(title="Detection Output")

    if save_result:
        folder, filename = os.path.split(image_path)
        name = os.path.splitext(filename)[0]
        output_path = os.path.join(folder, f"{name}_detected.jpg")
        image.save(output_path)
        print(f"💾 Output saved to: {output_path}")
    else:
        print("ℹ️ Output not saved.")

# ================== MAIN ==================
if __name__ == "__main__":
    # Load the detector once, then run it on a single sample image.
    model = load_model(MODEL_PATH)
    test_image_path = r"E:\steel surface\steel_defects_dataset\inclusion\inclusion_20.jpg"
    detect_defects(test_image_path, model)


📦 Loading full model from: E:/steel surface/models/fasterrcnn_steel_defect_full.pth
✅ Model loaded and ready for inference.
🖼️ Loading image: E:\steel surface\steel_defects_dataset\inclusion\inclusion_20.jpg
🔍 Running inference...
💾 Output saved to: E:\steel surface\steel_defects_dataset\inclusion\inclusion_20_detected.jpg
