In [115]:
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.rpn import AnchorGenerator
import torchvision.ops as ops
import matplotlib.pyplot as plt
import logging
from scripts.utils import get_processed_images_and_masks

In [116]:
# Configure the root logger once for the whole notebook: INFO level and a
# "timestamp - level - message" line format. The dataset and the training
# loop below report through the module-level `logging` functions.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [117]:
class FloodDataset(Dataset):
    """Instance-segmentation dataset built from image/mask file pairs.

    Every mask is read as a single-channel image. Pixel value 0 is treated as
    background and each other distinct value as one object instance. Samples
    are returned as ``(image, target)`` where ``target`` is a dict with the
    keys ``boxes``, ``labels``, ``masks``, ``image_id``, ``area`` and
    ``iscrowd``.

    Args:
        images: Sequence of image file paths.
        masks: Sequence of mask file paths, aligned index-for-index with
            ``images``.
        transforms: Optional callable applied to the RGB image only (the
            target is never transformed).
    """

    def __init__(self, images, masks, transforms=None):
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __getitem__(self, idx):
        """Load sample ``idx`` and build its detection target.

        Returns:
            tuple: ``(img, target)``. ``img`` is the RGB image (after
            ``transforms`` when one was given). In ``target``, ``boxes`` is a
            float32 ``(N, 4)`` tensor of ``[xmin, ymin, xmax, ymax]``,
            ``labels`` is an int64 ``(N,)`` tensor of ones, ``masks`` is a
            uint8 ``(N, H, W)`` tensor, and ``N`` is the number of instances
            whose box has positive width and height.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        img_path = self.images[idx]
        mask_path = self.masks[idx]

        # cv2.imread reports failure by returning None instead of raising, so
        # check explicitly to fail with the offending path.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Drop the background value itself rather than "the first unique
        # value": a mask with no background pixels must keep all of its ids.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        # Boxes and masks are collected together so that an instance rejected
        # for a degenerate box is removed from both; the target must hold the
        # same number of boxes, labels and masks.
        boxes = []
        kept_masks = []
        for i, obj_id in enumerate(obj_ids):
            obj_mask = mask == obj_id
            rows, cols = np.where(obj_mask)
            xmin = np.min(cols)
            xmax = np.max(cols)
            ymin = np.min(rows)
            ymax = np.max(rows)
            if xmin < xmax and ymin < ymax:  # Ensure positive width and height
                boxes.append([int(xmin), int(ymin), int(xmax), int(ymax)])
                kept_masks.append(obj_mask)
            else:
                logging.warning(f"Invalid box found: {xmin}, {ymin}, {xmax}, {ymax} for mask {i} in image {img_path}")

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.ones((len(boxes),), dtype=torch.int64)
            masks = torch.as_tensor(np.stack(kept_masks), dtype=torch.uint8)
        else:
            # No usable instance: emit correctly shaped empty tensors.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            masks = torch.zeros((0, mask.shape[0], mask.shape[1]), dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.images)

In [118]:
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field.

    A batch such as ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``; nothing is stacked, so the items in a
    field may differ in size.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [119]:
class MaskRCNNModel:
    """Builds, trains and visualizes a torchvision Mask R-CNN.

    Instance attributes set by ``__init__``:
        device: CUDA device when one is available, otherwise CPU.
        model: The ``MaskRCNN`` instance, already moved to ``device``.
    """

    def __init__(self, num_classes):
        """Create the model for ``num_classes`` classes and move it to the device."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.get_model_instance_segmentation(num_classes)
        self.model.to(self.device)

    def get_model_instance_segmentation(self, num_classes):
        """Assemble a Mask R-CNN on top of a single-feature-map ResNet-50.

        Args:
            num_classes: Forwarded to ``MaskRCNN`` as its ``num_classes``.

        Returns:
            The constructed ``MaskRCNN`` (not yet moved to any device).
        """
        # Load a pre-trained classification model and drop its last two child
        # modules so that only the feature extractor remains.
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        backbone = torch.nn.Sequential(*(list(backbone.children())[:-2]))
        # MaskRCNN reads the number of feature channels from this attribute.
        backbone.out_channels = 2048

        # RPN (Region Proposal Network) anchor generator. The backbone yields
        # one feature map, so `sizes` and `aspect_ratios` each carry exactly
        # one tuple; the two arguments must have one entry per feature map.
        rpn_anchor_generator = AnchorGenerator(
            sizes=((32, 64, 128, 256, 512),),
            aspect_ratios=((0.5, 1.0, 2.0),))

        # Feature maps to be used for the ROI align operation. '0' is the
        # name under which the single backbone output is looked up.
        roi_pooler = ops.MultiScaleRoIAlign(
            featmap_names=['0'], output_size=7, sampling_ratio=2)

        mask_roi_pooler = ops.MultiScaleRoIAlign(
            featmap_names=['0'], output_size=14, sampling_ratio=2)

        # Put the pieces together inside a MaskRCNN model
        model = MaskRCNN(backbone,
                         num_classes=num_classes,
                         rpn_anchor_generator=rpn_anchor_generator,
                         box_roi_pool=roi_pooler,
                         mask_roi_pool=mask_roi_pooler)

        return model

    def train(self, data_loader, optimizer, num_epochs=10):
        """Run the optimization loop.

        Args:
            data_loader: Iterable yielding ``(images, targets)`` batches, where
                ``images`` is a sequence of tensors and ``targets`` a sequence
                of dicts of tensors.
            optimizer: Optimizer stepping the model's parameters.
            num_epochs: Number of passes over ``data_loader``.

        The summed loss is logged on every 10th iteration of each epoch.
        """
        self.model.train()
        for epoch in range(num_epochs):
            logging.info(f'Starting epoch {epoch + 1}/{num_epochs}')
            for i, (images, targets) in enumerate(data_loader, start=1):
                images = [image.to(self.device) for image in images]
                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

                # In training mode the model returns a dict of named losses;
                # their sum is the quantity that is back-propagated.
                loss_dict = self.model(images, targets)
                losses = sum(loss_dict.values())

                optimizer.zero_grad()
                losses.backward()
                optimizer.step()

                if i % 10 == 0:
                    logging.info(f"Epoch: {epoch + 1}, Iteration: {i}, Loss: {losses.item()}")

    def evaluate(self, data_loader):
        """Run inference over ``data_loader`` and plot each image and its masks.

        The model is switched to eval mode and left in it. For every image
        two figures are shown: the input image, then the predicted masks
        drawn on top of one another at 50% opacity.
        """
        self.model.eval()
        with torch.no_grad():
            for images, _ in data_loader:
                images = [image.to(self.device) for image in images]
                outputs = self.model(images)
                # Process and visualize the outputs
                for image, output in zip(images, outputs):
                    # Move the leading axis last so imshow gets a
                    # channels-last array.
                    img = image.permute(1, 2, 0).cpu().numpy()
                    plt.imshow(img)
                    plt.show()
                    masks = output['masks'].cpu().numpy()
                    for mask in masks:
                        plt.imshow(mask[0], alpha=0.5)
                    plt.show()

In [120]:
# Helper imported from scripts.utils. Its two results are handed to
# FloodDataset below, which indexes them and passes each element to cv2.imread,
# so they are used as aligned sequences of image and mask file paths.
images, masks = get_processed_images_and_masks()

In [121]:
# Image-only preprocessing: FloodDataset applies this to the loaded RGB image
# and never to the target dict.
transform = transforms.Compose([transforms.ToTensor()])

In [122]:
# Pair every image path with its mask path; the transform runs per sample
# inside FloodDataset.__getitem__.
dataset = FloodDataset(images, masks, transforms=transform)

In [123]:
# Shuffled batches of 64 samples, loaded in the main process. collate_fn keeps
# each batch as tuples of per-sample items instead of stacking them.
data_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

In [124]:
# `model` is the MaskRCNNModel wrapper; the underlying network is `model.model`.
# num_classes=2: FloodDataset labels every instance 1, so this is presumably
# background plus one foreground class — confirm against the data.
model = MaskRCNNModel(num_classes=2)

In [125]:
# Optimize only the parameters that require gradients, using SGD with
# momentum and weight decay.
params = list(filter(lambda p: p.requires_grad, model.model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

In [126]:
# Run a single pass over data_loader; the loss is logged every 10th iteration.
model.train(data_loader, optimizer, num_epochs=1)

2024-07-03 19:20:30,344 - INFO - Starting epoch 1/1


KeyboardInterrupt: 

In [None]:
# Plot predictions for every sample. NOTE(review): this reuses the shuffled
# loader that was used for training, so the figures show training images and
# one pair of figures is produced per image in the dataset.
model.evaluate(data_loader)