In [None]:
import json
import os
import random

import numpy as np
import torch
import torchvision.transforms.functional as T
from torch import device, nn
from torch.utils.data import DataLoader, Dataset
from torchvision import io
from torchvision.models.detection import MaskRCNN, maskrcnn_resnet50_fpn
from torchvision.transforms import Normalize, RandomCrop, Resize

# Root folder of the dataset on the local machine.
DATA_DIR = "C:\\New folder\\Dr. Surya\\MaskRCNN\\Unity_Generation\\Concrete"

# Sub-folders of DATA_DIR holding the images, the masks and the JSON
# annotation files respectively.
IMAGES_DIR, MASKS_DIR, ANNOTATIONS_DIR = (
    os.path.join(DATA_DIR, subfolder)
    for subfolder in ("Images", "Masks", "BoundingBoxs")
)
def get_annotations_file_path(image_filename):
  """Return the path of the JSON annotation file that belongs to an image.

  The image's extension is replaced by ``.json`` and the resulting name is
  placed under ``ANNOTATIONS_DIR``.

  Args:
    image_filename: File name of the image, e.g. ``"sample.png"``.

  Returns:
    The annotation file path as a string.
  """
  stem, _extension = os.path.splitext(image_filename)
  return os.path.join(ANNOTATIONS_DIR, stem + ".json")

class CrackDataset(Dataset):
  """Dataset yielding ``(image, mask, annotations)`` triples.

  Sample ``i`` is the ``i``-th file of ``IMAGES_DIR`` and the ``i``-th file of
  ``MASKS_DIR`` in sorted name order, together with the parsed JSON file that
  ``get_annotations_file_path`` resolves for the image.
  """

  def __init__(self, root_dir):
    """Index the image/mask files and load every per-image annotation file.

    Args:
      root_dir: Dataset root. It is stored on the instance; the files
        themselves are located through the module-level directory constants.

    Raises:
      ValueError: If the image and mask folders hold a different number of
        files, since samples are paired by position.
      OSError: If a folder or an image's annotation file cannot be read.
    """
    self.root_dir = root_dir

    # os.listdir returns names in arbitrary order, so sort both listings to
    # make index i select the same sample from each folder.
    # NOTE(review): this presumes image and mask file names sort in the same
    # order (e.g. identical names) -- confirm against the generated data.
    self.images = sorted(os.listdir(IMAGES_DIR))
    self.masks = sorted(os.listdir(MASKS_DIR))
    if len(self.images) != len(self.masks):
      raise ValueError(
          f"Found {len(self.images)} images but {len(self.masks)} masks; "
          "they are paired by position and must match in number."
      )

    # Annotations live in one JSON file per image (ANNOTATIONS_DIR is a
    # folder, not a file), so load each one and key it by image name.
    self.annotations = {}
    for image_name in self.images:
      with open(get_annotations_file_path(image_name), encoding="utf-8") as f:
        self.annotations[image_name] = json.load(f)

  def __getitem__(self, i):
    """Return the ``i``-th sample.

    Returns:
      A tuple ``(image, mask, annotations)``: the image and mask as decoded by
      ``torchvision.io.read_image`` and the parsed JSON content for the image.
    """
    image_name = self.images[i]
    mask_name = self.masks[i]
    annotations = self.annotations[image_name]

    image = io.read_image(os.path.join(IMAGES_DIR, image_name))
    mask = io.read_image(os.path.join(MASKS_DIR, mask_name))

    return image, mask, annotations

  def __len__(self):
    """Return the number of samples (one per image file)."""
    return len(self.images)


def transform(image, mask):
  """Apply one random augmentation jointly to an image and its mask.

  The geometric steps (resize, crop, horizontal flip) use the same parameters
  for both inputs so that the mask stays aligned with the image.

  Args:
    image: Image tensor laid out ``(channels, height, width)``.
      NOTE(review): assumes a 3-channel tensor such as
      ``torchvision.io.read_image`` returns for RGB files -- the
      normalization below has three mean/std entries.
    mask: Mask tensor with the same height and width as ``image``.

  Returns:
    A tuple ``(image, mask)``: the image as a normalized float tensor cropped
    to 224x224, and the identically cropped/flipped mask with its leading
    dimension removed when that dimension has size 1.
  """
  # Resize so the shorter edge is 256. The mask uses nearest-neighbour so that
  # interpolation cannot blend label values into new, meaningless ones.
  image = T.resize(image, 256)
  mask = T.resize(mask, 256, interpolation=T.InterpolationMode.NEAREST)

  # Draw a single 224x224 crop window and apply it to both tensors; cropping
  # each one independently would misalign them.
  top, left, height, width = RandomCrop.get_params(image, (224, 224))
  image = T.crop(image, top, left, height, width)
  mask = T.crop(mask, top, left, height, width)

  # Flip both or neither.
  if random.random() < 0.5:
    image = T.hflip(image)
    mask = T.hflip(mask)

  # Normalization requires a float tensor; convert_image_dtype also rescales
  # integer pixel values into [0, 1], the range the statistics below (the
  # standard ImageNet mean/std) are expressed in.
  image = T.convert_image_dtype(image, torch.float)
  image = T.normalize(image,
                      mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225])

  mask = torch.squeeze(mask, dim=0)

  return image, mask

# Class index used for crack pixels in the targets built below.
# NOTE(review): with the COCO-pretrained heads this reuses slot 1 of the
# existing 91-way classifier; replacing the box/mask predictors with 2-class
# heads is the usual fine-tuning setup -- decide which is wanted.
_CRACK_LABEL = 1


def _collate_samples(batch):
  """Regroup a list of samples into per-field tuples without stacking tensors.

  Detection models take a list of images (sizes may differ) and a list of
  target dicts, so the default stacking collate is not suitable.
  """
  return tuple(zip(*batch))


def _make_target(mask, target_device):
  """Build the Mask R-CNN target dict for one sample from its mask tensor."""
  # NOTE(review): assumes every non-zero mask pixel is crack and treats the
  # whole foreground as a single instance. The JSON annotations are not
  # consumed here because their schema is not visible in this file -- switch
  # to them if they hold per-instance boxes.
  foreground = mask != 0
  if foreground.dim() == 3:
    foreground = foreground.any(dim=0)  # collapse the channel axis -> (H, W)

  if foreground.any():
    rows, cols = torch.where(foreground)
    # +1 on the far edges keeps a one-pixel-wide crack from yielding a
    # zero-area box, which the model rejects.
    box = torch.stack([cols.min(), rows.min(), cols.max() + 1, rows.max() + 1])
    boxes = box.unsqueeze(0).to(torch.float32)
    labels = torch.full((1,), _CRACK_LABEL, dtype=torch.int64)
    masks = foreground.unsqueeze(0).to(torch.uint8)
  else:
    # Image without any crack: an empty target is a valid negative sample.
    boxes = torch.zeros((0, 4), dtype=torch.float32)
    labels = torch.zeros((0,), dtype=torch.int64)
    masks = torch.zeros((0, *foreground.shape), dtype=torch.uint8)

  return {
      "boxes": boxes.to(target_device),
      "labels": labels.to(target_device),
      "masks": masks.to(target_device),
  }


def _prepare_batch(images, masks, target_device):
  """Turn one collated batch into the (images, targets) lists the model takes."""
  # The model expects float images in [0, 1]; convert_image_dtype rescales
  # integer pixel data accordingly.
  # NOTE(review): 3-channel images are assumed -- confirm the files are RGB.
  inputs = [
      T.convert_image_dtype(image, torch.float).to(target_device)
      for image in images
  ]
  targets = [_make_target(mask, target_device) for mask in masks]
  return inputs, targets


def _foreground_accuracy(prediction, target):
  """Return the fraction of pixels whose crack/background label is correct.

  The prediction is the union of all masks of confident crack detections
  (score and per-pixel probability above 0.5).
  """
  keep = (prediction["scores"] > 0.5) & (prediction["labels"] == _CRACK_LABEL)
  predicted = (prediction["masks"][keep] > 0.5).any(dim=0).squeeze(0)
  truth = target["masks"].bool().any(dim=0)
  return (predicted == truth).float().mean().item()


dataset = CrackDataset(DATA_DIR)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=_collate_samples)

# Split dataset into train and val
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# num_workers stays at 0: this code runs at module top level on Windows, where
# worker processes re-import the main module and cannot be started safely
# without an ``if __name__ == "__main__":`` guard.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=_collate_samples,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=_collate_samples,
)

compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The MaskRCNN class needs a ready-made backbone module; the factory function
# is what builds the pretrained ResNet-50 FPN variant. (On torchvision >= 0.13
# the preferred spelling is ``weights="DEFAULT"``.)
model = maskrcnn_resnet50_fpn(pretrained=True)
# Freeze the feature extractor so only the detection heads are trained.
for param in model.backbone.parameters():
  param.requires_grad = False
model.to(compute_device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# Not used by the loop below (the model computes its own losses); the name is
# kept in case later code refers to it.
criterion = nn.CrossEntropyLoss()

for epoch in range(20):

  # training
  model.train()
  for images, masks, _ in train_dataloader:
    inputs, targets = _prepare_batch(images, masks, compute_device)
    # In train mode the model returns a dict of loss terms, not predictions.
    loss = sum(model(inputs, targets).values())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # validation
  val_loss = 0.0
  val_accuracy = 0.0
  with torch.no_grad():
    for images, masks, _ in val_dataloader:
      inputs, targets = _prepare_batch(images, masks, compute_device)

      # Losses are only produced in train mode; no_grad keeps this pass from
      # tracking gradients.
      model.train()
      val_loss += sum(model(inputs, targets).values()).item()

      # Predictions are only produced in eval mode.
      model.eval()
      predictions = model(inputs)
      batch_accuracy = [
          _foreground_accuracy(prediction, target)
          for prediction, target in zip(predictions, targets)
      ]
      val_accuracy += sum(batch_accuracy) / len(batch_accuracy)

  # Guard against an empty validation split.
  num_val_batches = max(len(val_dataloader), 1)
  val_loss /= num_val_batches
  val_accuracy /= num_val_batches

  print(f"Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

# Save trained model
torch.save(model.state_dict(), 'C:\\New folder\\Dr. Surya\\MaskRCNN\\crack_maskrcnn.pth')