In [None]:
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNN_ResNet50_FPN_Weights
import torchvision.ops as ops
import torch.utils.data
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import xml.etree.ElementTree as ET
import cv2
import os
from tqdm import tqdm

In [None]:
# Report the CUDA environment. The per-device queries are only issued when CUDA
# is usable, because torch.cuda.current_device() raises on a CPU-only setup.
cuda_available = torch.cuda.is_available()
if cuda_available:
    print('CUDA is available!')
else:
    print('CUDA is not available.')

# Print the CUDA device count (0 when CUDA is unavailable)
print(f"Number of CUDA devices: {torch.cuda.device_count()}")

if cuda_available:
    # Get the current CUDA device
    current_device = torch.cuda.current_device()
    print(f"Current CUDA device: {current_device}")

    # Print the name of the current CUDA device
    print(f"Current CUDA device name: {torch.cuda.get_device_name(current_device)}")
else:
    # Keep the name defined so later cells can inspect it without a NameError.
    current_device = None

# cuDNN is switched off globally for the whole session.
# NOTE(review): the reason is not visible in this notebook — confirm it is still needed.
torch.backends.cudnn.enabled = False

### Params

In [None]:
# Maximum number of training epochs (passed as num_epochs to main_training_loop).
EPOCHS = 300
# step_size handed to the StepLR scheduler that is created in the model-setup cell.
SCHEDULER_STEP_SIZE = 40
# First (0-based) epoch at which EarlyStopping starts tracking the validation loss;
# earlier epochs are ignored by EarlyStopping.__call__.
ES_START_EPOCH = SCHEDULER_STEP_SIZE + 10
# Number of consecutive non-improving validation epochs tolerated before training stops.
PATIENCE = 50

### Prepare Dataset

In [None]:
class PascalVOCDataset(torch.utils.data.Dataset):
    """Detection/segmentation dataset backed by Pascal-VOC style XML annotations.

    Expects ``root/Images`` and ``root/XML``. The sorted file listings of both
    folders are paired by position, so image ``i`` must belong to annotation ``i``.

    Each item is ``(image, target)`` where ``target`` contains:

    * ``boxes``  - float32, shape (N, 4): xmin, ymin, xmax, ymax rescaled from the
      size recorded in the XML to a ``TARGET_SIZE`` x ``TARGET_SIZE`` frame.
    * ``labels`` - int64, shape (N,): index of the object's name in ``class_list``.
    * ``masks``  - uint8, shape (N, TARGET_SIZE, TARGET_SIZE): one filled rectangle
      per box (the annotations carry no real segmentation).

    Images without any object produce correctly shaped empty tensors
    ((0, 4) boxes and (0, TARGET_SIZE, TARGET_SIZE) masks).

    NOTE(review): labels start at 0, while torchvision detection models reserve
    label 0 for the background. Confirm whether the first class should be shifted
    by one; plotting and the model's class count would have to follow.
    """

    # Side length (pixels) of the square frame that boxes and masks are expressed in.
    TARGET_SIZE = 1024

    def __init__(self, root, class_list, transforms=None):
        """Index the image and annotation files found below ``root``.

        Args:
            root: dataset folder containing the ``Images`` and ``XML`` sub-folders.
            class_list: list of class names; a label is the index of its name.
            transforms: optional callable applied to the image tensor only.
        """
        self.root = root
        self.transforms = transforms
        self.imgs = sorted(os.listdir(os.path.join(root, "Images")))
        self.anns = sorted(os.listdir(os.path.join(root, "XML")))
        self.class_list = class_list

    def __getitem__(self, idx):
        """Load image ``idx`` and build its detection target."""
        img_path = os.path.join(self.root, "Images", self.imgs[idx])
        ann_path = os.path.join(self.root, "XML", self.anns[idx])

        # The context manager closes the file handle as soon as pixels are decoded.
        with Image.open(img_path) as raw_image:
            img = raw_image.convert("RGB")
        img = T.ToTensor()(img)

        # Parse the annotation once and reuse the tree for both size and objects.
        xml_root = ET.parse(ann_path).getroot()
        original_width, original_height = self._size_from_root(xml_root)
        boxes, labels = self._objects_from_root(
            xml_root, original_width, original_height, self.TARGET_SIZE, source=ann_path)
        masks = self.create_segmentation_masks((self.TARGET_SIZE, self.TARGET_SIZE), boxes)

        target = self._build_target(boxes, labels, masks, self.TARGET_SIZE)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        """Number of images found in ``root/Images``."""
        return len(self.imgs)

    @staticmethod
    def _build_target(boxes, labels, masks, mask_size):
        """Pack boxes, labels and masks into tensors whose shapes stay valid for N == 0."""
        if len(masks) > 0:
            mask_array = np.stack(masks).astype(np.uint8)
        else:
            mask_array = np.zeros((0, mask_size, mask_size), dtype=np.uint8)
        return {
            # reshape keeps an empty box list two-dimensional: (0, 4) instead of (0,)
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "masks": torch.from_numpy(mask_array),
        }

    @staticmethod
    def _size_from_root(xml_root):
        """Return ``(width, height)`` as recorded in the ``<size>`` element."""
        size = xml_root.find("size")
        width = int(size.find("width").text)
        height = int(size.find("height").text)
        return width, height

    def _objects_from_root(self, xml_root, original_width, original_height, target_size, source=None):
        """Collect rescaled boxes and label indices from every ``<object>`` element."""
        boxes = []
        labels = []
        x_scale = target_size / original_width
        y_scale = target_size / original_height

        for obj in xml_root.findall("object"):
            bbox = obj.find("bndbox")
            xmin = float(bbox.find("xmin").text)
            ymin = float(bbox.find("ymin").text)
            xmax = float(bbox.find("xmax").text)
            ymax = float(bbox.find("ymax").text)

            # Rescale from the annotated image size to the target frame.
            boxes.append([xmin * x_scale, ymin * y_scale, xmax * x_scale, ymax * y_scale])

            name = obj.find("name").text
            try:
                labels.append(self.class_list.index(name))
            except ValueError:
                where = f" in {source}" if source else ""
                raise ValueError(f"Unknown class name {name!r}{where}") from None

        return boxes, labels

    # Some images in the dataset have already been resized to 1024, but their
    # annotated boxes still refer to the size recorded in the XML.
    def _get_original_image_size_from_xml(self, xml_path):
        """Return the ``(width, height)`` stored in the annotation file."""
        return self._size_from_root(ET.parse(xml_path).getroot())

    def _get_annotation_data_from_xml(self, xml_path, original_width, original_height, target_size=1024):
        """Return ``(boxes, labels)`` for one annotation file.

        Boxes are ``[xmin, ymin, xmax, ymax]`` lists scaled by
        ``target_size / original_width`` and ``target_size / original_height``.

        Raises:
            ValueError: if an object's name is not in ``class_list``.
        """
        xml_root = ET.parse(xml_path).getroot()
        return self._objects_from_root(
            xml_root, original_width, original_height, target_size, source=xml_path)

    def create_segmentation_masks(self, image_size, boxes):
        """Return one ``image_size`` mask per box with the box rectangle filled with 1."""
        masks = []
        for box in boxes:
            mask = Image.new('L', image_size, 0)  # 'L' mode: single 8-bit channel
            draw = ImageDraw.Draw(mask)
            xmin, ymin, xmax, ymax = box
            draw.rectangle([xmin, ymin, xmax, ymax], fill=1)
            masks.append(np.array(mask))
        return masks

In [None]:
# Class names that may appear in the annotations. The dataset encodes a label as
# the index of its name in this list, so the order must not change between
# training and inference.
class_list = [
    "ad_unterschrift", "adress_aend", "ad_erzieher", "ad_neue_ad", "ad_schueler_unterschrift",
    "ad_erzieher_name", "ad_erzieher_vorname", "ad_erzieher_tel", "ad_erzieher_email",
    "ad_neue_ad_str_haus_nr", "ad_neue_ad_plz", "ad_neue_ad_stadt", "ad_schueler_datum",
    "schueler", "schueler_name", "schueler_vorname", "schueler_klasse",
    "ag", "ag_auswahl", "ag_unterschrift", "ag_schueler_datum",
    "ag_auswahl_wahl_1", "ag_auswahl_wahl_2", "ag_auswahl_wahl_3", "ag_schueler_unterschrift"]

# Sanity check: number of class names.
print(len(class_list))

# Resize every image tensor to 1024x1024, the same frame the dataset rescales
# boxes and masks into.
transforms = T.Compose([T.Resize((1024, 1024))])

In [None]:
# Build one dataset per split; all three share the class list and the resize transform.
train_dataset, val_dataset, test_dataset = (
    PascalVOCDataset(split_root, class_list=class_list, transforms=transforms)
    for split_root in ('train_dataset', 'val_dataset', 'test_dataset')
)

def collate_fn(batch):
    """Transpose a list of samples into one tuple per sample field.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``.
    Nothing is stacked: every field stays a tuple with one entry per sample.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Define data loaders.
# The named, module-level collate_fn is used instead of repeating an inline lambda:
# it has identical behaviour, and unlike a lambda it can be pickled, which is
# required wherever DataLoader workers are started by spawning a new process.
train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, shuffle=True, num_workers=4,
    collate_fn=collate_fn)

val_data_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=2, shuffle=False, num_workers=4,
    collate_fn=collate_fn)

test_data_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=collate_fn)

### Test Loaders

In [None]:
def plot_image(image, boxes, labels, masks, class_list=None):
    """Show an image with its masks, bounding boxes and class names overlaid.

    Args:
        image: tensor laid out channels-first (it is permuted to H x W x C for display).
        boxes: iterable of ``[xmin, ymin, xmax, ymax]``.
        labels: iterable of integer class indices, one per box.
        masks: iterable of masks, each either 2-D or with a leading singleton
            axis; zero pixels are left transparent.
        class_list: optional list of class names indexed by label. When it is
            omitted, or a label falls outside the list, the numeric label is
            drawn instead of raising.
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())

    for mask in masks:
        # Accept torch tensors as well as numpy arrays.
        mask = np.asarray(mask)
        if mask.ndim == 2:
            mask = np.expand_dims(mask, axis=0)
        # Hide the background so only the object area is tinted.
        masked_image = np.ma.masked_where(mask[0] == 0, mask[0])
        ax.imshow(masked_image, alpha=0.5, cmap='jet')

    for box, label in zip(boxes, labels):
        xmin, ymin, xmax, ymax = (float(coord) for coord in box)
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                 linewidth=0.2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        label_index = int(label)
        if class_list is not None and 0 <= label_index < len(class_list):
            caption = class_list[label_index]
        else:
            caption = str(label_index)
        ax.text(xmin, ymin, caption, fontsize=6, color='r')
    plt.axis('off')
    plt.show()

In [None]:
def test_loader(data_loader, name, class_list):
    """Print shape information for the first batch of ``data_loader`` and plot every sample in it.

    Only one batch is consumed; ``name`` is used in the printed header only.
    """
    print(f"Testing {name} data loader...")
    for batch_number, (images, targets) in enumerate(data_loader, start=1):
        print(f"Batch {batch_number}:")
        print(f" - Number of images: {len(images)}")
        print(f" - Image size: {images[0].shape}")
        print(f" - Number of targets: {len(targets)}")
        for position, target in enumerate(targets):
            boxes, labels, masks = target['boxes'], target['labels'], target['masks']

            print(f" - Target {position + 1}:")
            print(f"   - Boxes shape: {boxes.shape}")
            print(f"   - Labels shape: {labels.shape}")
            print(f" - Mask shape: {masks.shape}")

            plot_image(images[position], boxes, labels, masks, class_list)
        # Inspecting a single batch is enough for a smoke test.
        break

# Test each data loader
for loader, split_name in (
    (train_data_loader, "train"),
    (val_data_loader, "validation"),
    (test_data_loader, "test"),
):
    test_loader(loader, split_name, class_list)

### Early Stopping

In [None]:
class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss. Every
    improvement writes the model's ``state_dict`` to ``checkpoint.pth``; after
    ``patience`` consecutive epochs without improvement ``early_stop`` is set.

    Args:
        patience: non-improving epochs tolerated before ``early_stop`` is set.
        delta: minimum increase of the score (``-val_loss``) over the best score
            that counts as an improvement.
        verbose: print counter and checkpoint messages.
        start_epoch: epochs below this index are ignored completely.
    """

    def __init__(self, patience=50, delta=0, verbose=False, start_epoch=0):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_loss = float('inf')
        self.start_epoch = start_epoch

    def __call__(self, val_loss, model, epoch):
        """Record this epoch's validation loss and update checkpoint and counter."""
        if epoch < self.start_epoch:
            return

        score = -val_loss
        # NaN is the only value that differs from itself. A diverged (NaN) loss
        # must never be stored as the best score: every later comparison
        # against NaN is False, which would make each epoch look like an
        # improvement and overwrite the checkpoint.
        loss_is_nan = val_loss != val_loss

        improved = (not loss_is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}).  Saving model ...")
        torch.save(model.state_dict(), 'checkpoint.pth')
        self.best_loss = val_loss

### Evaluation Metrics

In [None]:
def compute_metrics(pred_boxes, pred_labels, pred_scores, pred_masks, true_boxes, true_labels, true_masks, iou_threshold=0.5):
    """Per-class detection metrics for a single image.

    A prediction is a true positive when an unmatched ground-truth object of
    the same label reaches ``iou_threshold`` on both box IoU and mask IoU.
    Predictions are visited from highest to lowest score, so a confident
    detection claims a ground-truth object before a weaker duplicate can.

    Args:
        pred_boxes, pred_labels, pred_scores, pred_masks: predictions, aligned by index.
        true_boxes, true_labels, true_masks: ground truth, aligned by index.
        iou_threshold: minimum IoU required for both box and mask.

    Returns:
        dict with ``precision``, ``recall`` and ``accuracy`` (arrays indexed by
        class id, length ``max label + 1``) and ``mAP``. ``mAP`` is the mean of
        ``compute_ap`` over those classes, i.e. of precision * recall; it is
        not an area under a precision/recall curve. With neither predictions
        nor ground truth, the arrays are empty and ``mAP`` is 0.0.
    """
    pred_labels = np.asarray(pred_labels, dtype=np.int64)
    true_labels = np.asarray(true_labels, dtype=np.int64)

    all_labels = np.concatenate([pred_labels, true_labels])
    # np.max raises on an empty array, so an image without objects is handled here.
    num_classes = int(all_labels.max()) + 1 if all_labels.size else 0
    tp = np.zeros(num_classes)
    fp = np.zeros(num_classes)
    fn = np.zeros(num_classes)
    # Ground-truth indices are unique across classes, so one set is sufficient.
    matched = set()

    num_preds = min(len(pred_boxes), len(pred_labels), len(pred_masks))
    if len(pred_scores) >= num_preds:
        # Stable sort keeps the original order between equal scores.
        scores = np.asarray(pred_scores, dtype=float)[:num_preds]
        visit_order = np.argsort(-scores, kind="stable")
    else:
        visit_order = np.arange(num_preds)

    num_truths = min(len(true_boxes), len(true_labels), len(true_masks))
    for pred_idx in visit_order:
        pred_label = pred_labels[pred_idx]
        match_found = False
        for idx in range(num_truths):
            if idx in matched or true_labels[idx] != pred_label:
                continue
            # The cheap box test runs first; the mask IoU is only evaluated for
            # candidates that already pass it. "not (x >= t)" keeps NaN a non-match.
            if not (compute_box_iou(pred_boxes[pred_idx], true_boxes[idx]) >= iou_threshold):
                continue
            if not (compute_mask_iou(pred_masks[pred_idx], true_masks[idx]) >= iou_threshold):
                continue
            tp[pred_label] += 1
            matched.add(idx)
            match_found = True
            break
        if not match_found:
            fp[pred_label] += 1

    for idx, true_label in enumerate(true_labels):
        if idx not in matched:
            fn[true_label] += 1

    precision = np.divide(tp, tp + fp, out=np.zeros_like(tp), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=np.zeros_like(tp), where=(tp + fn) > 0)
    accuracy = np.divide(tp, tp + fp + fn, out=np.zeros_like(tp), where=(tp + fp + fn) > 0)
    ap = compute_ap(tp, fp, fn)
    # np.mean of an empty array would be NaN and poison the epoch average.
    mAP = np.mean(ap) if ap.size else 0.0

    return {
        'precision': precision,
        'recall': recall,
        'accuracy': accuracy,
        'mAP': mAP
    }

def compute_box_iou(box1, box2):
    """Intersection over union of two ``[xmin, ymin, xmax, ymax]`` boxes.

    Returns 0.0 when the union has no positive area (for example two
    zero-sized boxes) instead of dividing by zero.
    """
    x1_max = max(box1[0], box2[0])
    y1_max = max(box1[1], box2[1])
    x2_min = min(box1[2], box2[2])
    y2_min = min(box1[3], box2[3])

    # Clamp at zero so boxes that do not overlap contribute no intersection.
    intersection_area = max(0, x2_min - x1_max) * max(0, y2_min - y1_max)

    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union_area = float(box1_area + box2_area - intersection_area)
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area

def compute_mask_iou(mask1, mask2):
    """Intersection over union of two masks, treating non-zero entries as foreground.

    The masks are combined with numpy broadcasting. Returns 0.0 when both masks
    are empty, instead of the NaN (plus runtime warning) that 0 / 0 would give.
    """
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0.0
    return intersection / union

def compute_ap(tp, fp, fn):
    """Return the element-wise product of precision and recall.

    ``tp``, ``fp`` and ``fn`` are arrays of counts. Precision is
    ``tp / (tp + fp)`` and recall is ``tp / (tp + fn)``, each left at 0 where
    its denominator is not positive.

    Note: despite the name, the result is precision * recall per entry; no
    precision/recall curve is integrated here.
    """
    def _guarded_ratio(denominator):
        # Entries whose denominator is not positive keep the 0 from ``out``.
        return np.divide(tp, denominator, out=np.zeros_like(tp), where=denominator > 0)

    return _guarded_ratio(tp + fp) * _guarded_ratio(tp + fn)


### Training functions

In [None]:
def one_epoch(model, optimizer, data_loader, device, progress_bar):
    """Train ``model`` for one pass over ``data_loader``.

    For every batch the values of the model's loss dict are summed,
    back-propagated and applied with ``optimizer``. ``progress_bar`` (the
    caller's overall bar) advances by one per batch.

    Returns:
        The summed batch losses divided by ``len(data_loader)``.
    """
    model.train()
    loss_total = 0.0
    for batch_images, batch_targets in tqdm(data_loader, desc="Training", leave=False):
        batch_images = [img.to(device) for img in batch_images]
        batch_targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in batch_targets
        ]

        optimizer.zero_grad()
        batch_loss = sum(model(batch_images, batch_targets).values())

        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()
        progress_bar.update(1)

    return loss_total / len(data_loader)

def _label_names(class_list, *label_arrays):
    """Return a name list covering every label in ``label_arrays``; labels beyond ``class_list`` get their number as name."""
    highest = max((int(labels.max()) for labels in label_arrays if labels.size), default=-1)
    known = list(class_list) if class_list is not None else []
    return known + [str(k) for k in range(len(known), highest + 1)]


def evaluate(model, data_loader, device, progress_bar, iou_threshold=0.5, nms_threshold=0.5, plot = False, class_list=None):
    """Compute validation loss and detection metrics over ``data_loader``.

    For each batch the model is run in eval mode to obtain predictions, which
    are filtered with per-class non-maximum suppression, and once more in train
    mode (still without gradients) because the model only returns its loss dict
    in that mode.

    Args:
        model: detection model returning predictions in eval mode and a loss
            dict in train mode.
        data_loader: yields ``(images, targets)`` batches.
        device: device the batches are moved to.
        progress_bar: caller's overall bar, advanced by one per batch.
        iou_threshold: IoU threshold forwarded to ``compute_metrics``.
        nms_threshold: IoU threshold of the per-class NMS.
        plot: when True, plot predictions and ground truth for every image.
        class_list: optional class names used for the plots.

    Returns:
        ``(avg_val_loss, avg_precision, avg_recall, avg_accuracy, avg_mAP)``;
        the three middle values are per-class arrays averaged over images.

    Raises:
        ValueError: if ``data_loader`` yields no images.
    """
    model.eval()
    metrics = {
        'precision': [],
        'recall': [],
        'accuracy': [],
        'mAP': []
    }
    running_val_loss = 0.0

    bar = tqdm(data_loader, desc="Validating", leave=False)
    with torch.no_grad():
        for images, targets in bar:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(images)

            for i, output in enumerate(outputs):
                # batched_nms suppresses within each label independently and
                # returns indices into the full prediction set. Calling ops.nms
                # on a per-class subset returns subset-relative indices, which
                # must not be used to index the full output.
                keep = ops.batched_nms(output['boxes'], output['scores'],
                                       output['labels'], nms_threshold)
                outputs[i] = {k: v[keep] for k, v in output.items()}

            # Losses are only returned in train mode; always switch back to eval.
            model.train()
            try:
                loss_dict = model(images, targets)
            finally:
                model.eval()

            if isinstance(loss_dict, dict):
                val_loss = sum(loss for loss in loss_dict.values()).item()
            else:
                val_loss = sum(loss for loss in loss_dict).item()

            running_val_loss += val_loss
            progress_bar.update(1)

            for image, output, target in zip(images, outputs, targets):
                pred_boxes = output['boxes'].cpu().numpy()
                pred_labels = output['labels'].cpu().numpy()
                pred_scores = output['scores'].cpu().numpy()
                pred_masks = output['masks'].cpu().numpy() > 0.5
                true_boxes = target['boxes'].cpu().numpy()
                true_labels = target['labels'].cpu().numpy()
                true_masks = target['masks'].cpu().numpy() > 0.5

                metric = compute_metrics(pred_boxes, pred_labels, pred_scores, pred_masks,
                                        true_boxes, true_labels, true_masks, iou_threshold)
                metrics['precision'].append(metric['precision'])
                metrics['recall'].append(metric['recall'])
                metrics['accuracy'].append(metric['accuracy'])
                metrics['mAP'].append(metric['mAP'])

                if plot:
                    # Plot the image that belongs to this prediction (not always
                    # the first one of the batch) and always pass usable names.
                    names = _label_names(class_list, pred_labels, true_labels)
                    print("Predicted boxes, labels, and masks:")
                    plot_image(image, pred_boxes, pred_labels, pred_masks, names)
                    print("Ground truth boxes, labels, and masks:")
                    plot_image(image, true_boxes, true_labels, true_masks, names)

    if not metrics['precision']:
        raise ValueError("evaluate() received an empty data loader; nothing to validate")

    # Images can report different class counts; pad to a common length before averaging.
    max_classes = max(len(p) for p in metrics['precision'])
    precision_padded = [np.pad(p, (0, max_classes - len(p)), 'constant') for p in metrics['precision']]
    recall_padded = [np.pad(r, (0, max_classes - len(r)), 'constant') for r in metrics['recall']]
    accuracy_padded = [np.pad(a, (0, max_classes - len(a)), 'constant') for a in metrics['accuracy']]

    avg_precision = np.mean(np.stack(precision_padded), axis=0)
    avg_recall = np.mean(np.stack(recall_padded), axis=0)
    avg_accuracy = np.mean(np.stack(accuracy_padded), axis=0)
    avg_mAP = np.mean(metrics['mAP'])
    avg_val_loss = running_val_loss / len(data_loader)

    return avg_val_loss, avg_precision, avg_recall, avg_accuracy, avg_mAP

def main_training_loop(model, optimizer, scheduler, train_loader, val_loader, device, num_epochs, patience=PATIENCE, start_epoch=ES_START_EPOCH):
    """Train with per-epoch validation, optional LR scheduling and early stopping.

    After training, the best checkpoint written by ``EarlyStopping`` is
    restored (when one was written during this run) and the resulting weights
    are saved to ``best_model.pth``.

    Args:
        model, optimizer: model to train and its optimizer.
        scheduler: learning-rate scheduler or ``None``. ``ReduceLROnPlateau``
            is stepped with the validation loss, any other scheduler without
            arguments.
        train_loader, val_loader: data loaders for training and validation.
        device: device the batches are moved to.
        num_epochs: maximum number of epochs.
        patience, start_epoch: forwarded to ``EarlyStopping``.
    """
    early_stopping = EarlyStopping(patience=patience, verbose=True, start_epoch=start_epoch)

    # Calculate the total number of steps
    total_steps = num_epochs * (len(train_loader) + len(val_loader))

    with tqdm(total=total_steps, desc="Training Progress") as pbar:
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}:")

            # Training phase
            train_loss = one_epoch(model, optimizer, train_loader, device, pbar)
            print("Train Loss: {:.4f}".format(train_loss))

            # Validation phase (evaluate switches the model to eval mode itself)
            val_loss, avg_precision, avg_recall, avg_accuracy, avg_mAP = evaluate(model, val_loader, device, pbar)
            print("Validation Loss: {:.4f}, Precision: {:.4f}, Recall: {:.4f}, Accuracy: {:.4f}, mAP: {:.4f}".format(
                val_loss,
                np.mean(avg_precision),
                np.mean(avg_recall),
                np.mean(avg_accuracy),
                avg_mAP
            ))
            if scheduler is not None:
                # Only ReduceLROnPlateau takes a metric; for the other
                # schedulers a positional argument is read as an epoch index.
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()
            # Check early stopping
            early_stopping(val_loss, model, epoch)

            if early_stopping.early_stop:
                print("Early stopping: No improvement in the last {} epochs".format(patience))
                break

    if early_stopping.best_score is not None:
        # A checkpoint was written during this run; restore the best weights.
        model.load_state_dict(torch.load('checkpoint.pth', map_location=device))
    else:
        # Early stopping never became active (e.g. num_epochs <= start_epoch),
        # so no checkpoint from this run exists; loading one would fail or pick
        # up a stale file from an earlier run.
        print("No checkpoint was written during this run; keeping the final weights.")
    torch.save(model.state_dict(), 'best_model.pth')
    print("Training completed!")

### Load and adjust pretrained Mask R-CNN

In [None]:
# Start from the pretrained Mask R-CNN and replace the box head with one sized
# for this dataset.
model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
num_classes = len(class_list) + 1  # one output more than there are class names
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes = num_classes)
# NOTE(review): the mask head keeps its pretrained size because the replacement
# below is disabled. If it is enabled, the inference cell must rebuild the same
# mask head before loading the saved weights.
#in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
#hidden_layer = 128
#model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

# Fall back to the CPU when no GPU is present instead of failing in model.to().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005) #SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
# Halves the learning rate every SCHEDULER_STEP_SIZE scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=SCHEDULER_STEP_SIZE, gamma=0.5)

In [None]:
# NOTE(review): None is passed as the scheduler, so the StepLR instance created in
# the model-setup cell is never stepped — confirm that this is intended.
main_training_loop(model, optimizer, None, train_data_loader, val_data_loader, device, num_epochs=EPOCHS)

### Test Inference

In [None]:
# Rebuild the architecture exactly as it was trained (only the box predictor is
# replaced) and load the weights that main_training_loop saved to best_model.pth.
# `num_classes` and `device` come from the model-setup cell above.
model_state_dict = torch.load('best_model.pth')
model = maskrcnn_resnet50_fpn()
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes = num_classes)
model.load_state_dict(model_state_dict)
model.to(device)
model.eval()  # Inference mode: the model returns predictions instead of losses

In [None]:
# Run the model on every test batch and plot predictions next to the ground truth.
for images, targets in test_data_loader:
    images = [image.to(device) for image in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    with torch.no_grad():
        outputs = model(images)
    # Pair each prediction with its own image; indexing images[0] is only
    # correct while the loader's batch size is 1.
    for image, output, target in zip(images, outputs, targets):
        pred_boxes = output['boxes'].cpu().numpy()
        pred_labels = output['labels'].cpu().numpy()
        pred_scores = output['scores'].cpu().numpy()
        pred_masks = output['masks'].cpu().numpy() > 0.5
        true_boxes = target['boxes'].cpu().numpy()
        true_labels = target['labels'].cpu().numpy()
        true_masks = target['masks'].cpu().numpy() > 0.5
        print("Predicted boxes, labels, and masks:")
        plot_image(image, pred_boxes, pred_labels, pred_masks, class_list)
        print("Ground truth boxes, labels, and masks:")
        plot_image(image, true_boxes, true_labels, true_masks, class_list)

In [None]:
# NOTE(review): machine-specific absolute path; move it to a configurable location.
SAMPLE_IMAGE_PATH = "/mnt/c/Users/jason/GitHubRepos/LectorAI-TextExtraction/tempimages_api/beispiel_form_covered.jpg"

image = cv2.imread(SAMPLE_IMAGE_PATH)
if image is None:
    # cv2.imread signals a missing or unreadable file by returning None.
    raise FileNotFoundError(f"Could not read image: {SAMPLE_IMAGE_PATH}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = T.ToTensor()(image)
image = T.Resize((1024, 1024))(image)
image = image.unsqueeze(0).to(device)
with torch.no_grad():
    outputs = model(image)
# Use the prediction for THIS image. Reading `output` without assigning it here
# would silently reuse the last prediction left over from the test-loader cell.
output = outputs[0]
pred_boxes = output['boxes'].detach().cpu().numpy()
pred_labels = output['labels'].detach().cpu().numpy()
pred_scores = output['scores'].detach().cpu().numpy()
pred_masks = output['masks'].detach().cpu().numpy() > 0.5
image = image.squeeze(0).cpu()
plot_image(image, pred_boxes, pred_labels, pred_masks, class_list)