In [18]:
import torch
import os
import json
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor, MaskRCNN_ResNet50_FPN_Weights
from torch import nn, device
import torchvision.transforms.functional as T
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor


# Set paths
ROOT_DIR = os.getcwd()
DATA_DIR = os.path.join(ROOT_DIR, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete')
MODEL_DIR = os.path.join(ROOT_DIR, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\final_models.pth')

# Load a pre-trained Mask R-CNN model 
model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1)
class MaskRCNN(nn.Module):
    def __init__(self, model_dir, num_classes=2):
        super(MaskRCNN, self).__init__()
        self.model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1, pretrained=True)
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        self.model.roi_heads.mask_predictor = MaskRCNNPredictor(256, 256, num_classes)

class CrackDataset(Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        self.image_list = os.listdir(os.path.join(root, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Images'))
        self.masks_dir = os.path.join(root, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Masks')
        self.annotations_file = os.path.join(root, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\BoundingBoxs')

        if not os.path.exists(self.annotations_file):
            with open(self.annotations_file) as f:
                self.bbox_data = json.load(f)
        self.box_index = 0
    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Images', self.image_list[idx])
        mask_path = os.path.join(self.masks_dir, self.image_list[idx])
        boxes = self.bbox_data[self.image_list[idx]]['boxes'][0]
        box = boxes[self.box_index]  
        self.box_index += 1
        if self.box_index >= len(boxes):
          self.box_index = 0
        box = boxes[self.box_index]
        boxes = boxes.reshape(1, -1, 4)
        boxes = boxes.squeeze()
        print(f"Single box shape: {box.shape}")
        image = read_image(img_path).float()
        mask = read_image(mask_path).float()
        print(f"Targets shape before transform: {mask.shape}")
        if self.transforms:
            image, mask = self.transforms(image, mask)
        print(f"Targets shape after transform: {mask.shape}")
        return image, mask, box

class ResizedCrackDataset(Dataset):
      def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        self.image_list = os.listdir(os.path.join(root, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Images'))
        self.masks_dir = os.path.join(root, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Masks')
        self.annotations_file = os.path.join(root, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\BoundingBoxs')
        
        self.image_size = 256
        self.mask_size = 256

      def __len__(self):
        return len(self.image_list)

      def __getitem__(self, idx):
        img_path = os.path.join(self.root, 'C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete\\Images', self.image_list[idx])
        mask_path = os.path.join(self.masks_dir, self.image_list[idx])
        
        image = read_image(img_path)
        mask = read_image(mask_path)

        if self.transforms:
          image, mask = self.transforms(image, mask)
      
        return image, mask
  
# Define data augmentation and transformation
def transform(image, mask):
    image = T.to_pil_image(image)
    mask = T.to_pil_image(mask)

    # Apply transformations (resize, flip, etc.)
    image = T.resize(image, [256, 256], antialias=True) 

    mask = T.resize(mask, [100, 4], antialias=True) 

    # Convert back to tensors
    image = T.to_tensor(image)
    mask = T.to_tensor(mask)

    return image, mask

# Load the dataset
dataset = ResizedCrackDataset(DATA_DIR, transforms=transform)

# Freeze backbone layers
for param in model.backbone.parameters():
    param.requires_grad = False
    
# Train only the classifier and box layer heads    
params = [p for p in model.parameters() if p.requires_grad] 

optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# Dataloaders    
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True) 
val_dataloader = DataLoader(dataset, batch_size=1, shuffle=True) 
# Train one epoch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_device(data):
  """Move tensor(s) to device"""

  if isinstance(data, (list,tuple)):
    return [to_device(x) for x in data]
  return data.to(device)

def calculate_loss(outputs, targets):
  # Get predictions 
  preds = torch.max(outputs, 1)
  # Calculate loss 
  criterion = nn.CrossEntropyLoss()
  return criterion(preds, targets)

# Validation function
def validate(model, val_dataloader):
  model.eval()
  loss_total = 0
  accuracy_total = 0
  for images, targets in val_dataloader:
    targets_np = targets.cpu().numpy()
    print(f"Validation targets shape: {targets_np.shape}")
    
    outputs = model(images)
    criterion = nn.CrossEntropyLoss()
    images = to_device(images) 
    # Calculate loss
    loss = criterion(outputs, targets)
    loss_total += loss.item()

    # Calculate accuracy
    preds = torch.argmax(outputs,dim=1)  
    accuracy = torch.sum(preds==targets).float()/targets.shape[0]
    accuracy_total += accuracy  
    
    # Validation logic
    loss = criterion(outputs, targets)
    loss_total += loss.item()
    preds = torch.argmax(outputs,dim=1)  
    accuracy = torch.sum(preds==targets).float()/targets.shape[0]  
    accuracy_total += accuracy

  # Calculate average loss and accuracy
  loss_total /= len(val_dataloader)
  accuracy_total /= len(val_dataloader)

  print(f'Validation Loss: {loss_total}, Accuracy: {accuracy_total}')

  return loss_total, accuracy_total

def train_one_epoch(model, dataloader, optimizer):

  model.train()

  for images, targets in dataloader:
    
    # Print shape before processing
    print(f"Raw targets shape: {targets.shape}")  

    images = to_device(images)
    targets = to_device(targets)

    # Save targets to numpy
    targets_np = targets.cpu().numpy()

    # Print shape again
    print(f"Targets numpy shape: {targets_np.shape}")

    # Iterate over targets like model expects
    for target in targets_np:

      # Process single target  

      boxes = target

      # Reshape for model
      target = boxes.reshape(-1, 4)  

      outputs = model(images, target)

      loss = calculate_loss(outputs, target)  

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      print(f"Train loss: {loss.item():.2f}")

# Train model

for _ in range(20):
    train_one_epoch(
    model, 
    train_dataloader,
    optimizer)
    validate(
    model,
    val_dataloader)


# Save trained model
torch.save(model.state_dict(), 'crack_maskrcnn.pth')

Raw targets shape: torch.Size([1, 4, 100, 4])
Targets numpy shape: (1, 4, 100, 4)


IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices