In [1]:
import os
import json
import skimage.io
import torch
import torchvision.transforms as T
import numpy as np
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch import nn
from torch.optim import SGD
import os
import json
import skimage.io
import torch
import torchvision.transforms as T
import numpy as np

# Additional imports
import torchvision.utils as utils
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Set paths
ROOT_DIR = os.getcwd()
DATA_DIR = os.path.join(ROOT_DIR, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete')
MODEL_DIR = os.path.join(ROOT_DIR, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\final_models')

# Crack dataset class
class CrackDataset:
    def __init__(self):
        self.bbox_data = None

    def load_data(self):
        IMAGES_DIR = 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Images'
        MASKS_DIR = 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Masks'
        ANNOTATIONS_FILE = 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\BoundingBoxs'

        images = os.listdir(os.path.join(DATA_DIR, IMAGES_DIR))
        masks = os.listdir(os.path.join(DATA_DIR, MASKS_DIR))

        annotations_file_path = os.path.join(DATA_DIR, ANNOTATIONS_FILE)

        if not os.path.exists(annotations_file_path):
            raise FileNotFoundError(f"Annotations file '{ANNOTATIONS_FILE}' not found.")
        else:
            with open(annotations_file_path) as f:
                self.bbox_data = json.load(f)

        images_dict = dict(zip(images, masks))
        return images_dict

    def load_image(self, image_id):
        image_path = os.path.join(DATA_DIR, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Images', image_id)
        image = read_image(image_path).float()
        return image

    def load_mask(self, image_id):
        mask_path = os.path.join(DATA_DIR, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Masks', image_id)
        mask = read_image(mask_path).float()
        return mask

    def image_reference(self, image_id):
        return 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Images' + image_id

    def bbox_reference(self, image_id):
        if self.bbox_data is not None:
            for bb in self.bbox_data:
                if bb['image_id'] == image_id:
                    return bb  # Return the matching dictionary

        return None  # Return None if no matching 'image_id' is found

# Crack config
class CrackConfig:
    # Hyperparameters
    NAME = "crack"  # Name of your configuration
    NUM_CLASSES = 1 + 1  # Crack + Background
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1  # You can adjust this based on your GPU memory

# Create crack model
class MaskRCNN(nn.Module):
    def __init__(self, config, model_dir):
        super(MaskRCNN, self).__init__()
        self.config = config
        self.mask_rcnn = maskrcnn_resnet50_fpn(pretrained=True)
        self.model_dir = model_dir

    def forward(self, images, masks, targets=None):
        images = [F.normalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) for image in images]
        images = [F.to_tensor(image) for image in images]
        images = [F.interpolate(image.unsqueeze(0), size=(256, 256)).squeeze(0) for image in images]
        images = torch.stack(images)
        return self.mask_rcnn(images)

# Train weights on the crack dataset
dataset = CrackDataset()
model = MaskRCNN(CrackConfig(), MODEL_DIR)

# Convert the mask to binary (0s and 1s)
masks = [mask > 0.5 for mask in masks]

# Convert masks to PyTorch tensors
masks = [torch.unsqueeze(mask, dim=0).float() for mask in masks]

# Train the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()
epochs = 50

for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(images, masks)
    loss = criterion(outputs[0]['masks'].float(), masks[0])
    loss.backward()
    optimizer.step()

# Save trained crack model
torch.save(model.state_dict(), os.path.join(MODEL_DIR, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\crack_model.pth'))

# Evaluate on the validation dataset
val_dataset = CrackDataset()
val_images_dict = val_dataset.load_data()
val_images, val_masks = list(val_images_dict.keys()), list(val_images_dict.values())

val_masks = [mask > 0.5 for mask in val_masks]
val_masks = [torch.unsqueeze(mask, dim=0).float() for mask in val_masks]

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

model.eval()
total_loss = 0

for val_image, val_mask in zip(val_images, val_masks):
    val_image = dataset.load_image(val_image)
    val_image = val_image.unsqueeze(0)
    val_mask = val_mask.unsqueeze(0)
    
    with torch.no_grad():
        val_output = model(val_image, None)

    loss = criterion(val_output[0]['masks'].float(), val_mask)
    total_loss += loss.item()

average_loss = total_loss / len(val_images)
print(f'Validation Loss: {average_loss}')

# Predict crack on a test image (replace 'test_img' with your image data)
test_img = None  # Load your test image here

# Preprocess the test image
test_img = F.normalize(test_img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
test_img = F.to_tensor(test_img)
test_img = F.interpolate(test_img.unsqueeze(0), size=(256, 256)).squeeze(0)

# Inference
model.eval()
with torch.no_grad():
    results = model(test_img.unsqueeze(0), None)


Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to C:\Users\peima/.cache\torch\hub\checkpoints\maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|██████████| 170M/170M [00:16<00:00, 10.6MB/s] 


NameError: name 'masks' is not defined