# Сегментация кубов и параллелепипедов с определением связей

Этот ноутбук обучает модель Mask R-CNN для:
1. **Instance Segmentation** красных кубов и зелёных параллелепипедов
2. **Определение связей**: какой параллелепипед принадлежит какому кубу

Датасет в формате COCO с дополнительным полем `parent_id`.

## 1. Установка зависимостей

In [None]:
!pip install torch torchvision
!pip install pycocotools
!pip install opencv-python-headless
!pip install albumentations
!pip install matplotlib
!pip install scipy

In [None]:
import os
import json
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import torchvision.transforms as T

from scipy.optimize import linear_sum_assignment

# Print the PyTorch version and whether CUDA is available; when it is,
# also print the name of CUDA device 0.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Конфигурация

In [None]:
# Dataset location -- change these to your own paths on Kaggle.
DATASET_PATH = "/kaggle/input/strawberry-peduncle-segmentation"
IMAGES_PATH, MASKS_PATH, ANNOTATIONS_PATH = (
    os.path.join(DATASET_PATH, leaf)
    for leaf in ("images", "masks", "annotations.json")
)

# Class id -> human-readable name; id 0 is the background class.
CATEGORIES = dict(enumerate(("background", "red_cube", "green_parallelepiped")))

# Hyperparameters
NUM_CLASSES = 3  # background + red_cube + green_parallelepiped
BATCH_SIZE = 4
NUM_EPOCHS = 20
LEARNING_RATE = 0.005

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

## 3. Dataset класс

In [None]:
class CubeParallelepipedDataset(Dataset):
    """Dataset of cubes and parallelepipeds with parent/child link information.

    The annotation file is a JSON document with an ``images`` list (each entry
    has ``id`` and ``file_name``) and an ``annotations`` list (each entry has
    ``image_id``, ``bbox``, ``category_id``, ``instance_id``, ``parent_id`` and
    ``segmentation_color``). For every image a colour-coded mask with the same
    file name is read from ``masks_path``.

    Args:
        images_path: Directory containing the input images.
        masks_path: Directory containing the colour-coded mask images.
        annotations_path: Path to the JSON annotation file.
        transforms: Optional callable ``(image, target) -> (image, target)``
            applied to every sample after tensor conversion.
    """

    # Value stored in ``parent_ids`` when an annotation has no parent
    # (``parent_id`` is null or absent in the JSON).
    NO_PARENT = -1

    def __init__(self, images_path, masks_path, annotations_path, transforms=None):
        self.images_path = images_path
        self.masks_path = masks_path
        self.transforms = transforms

        # Load the annotation file; JSON is UTF-8, so do not depend on the
        # platform default encoding.
        with open(annotations_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.images_info = {img['id']: img for img in data['images']}

        # Group annotations by the image they belong to.
        self.annotations_by_image = defaultdict(list)
        for ann in data['annotations']:
            self.annotations_by_image[ann['image_id']].append(ann)

        self.image_ids = list(self.images_info.keys())

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, target)`` where ``image`` is a float32 tensor in
            channel-first layout scaled to ``[0, 1]`` and ``target`` is a dict
            with ``boxes`` (x1, y1, x2, y2), ``labels``, ``masks``,
            ``image_id``, ``area``, ``iscrowd``, ``instance_ids`` and
            ``parent_ids`` tensors. Images without annotations yield empty
            tensors of the matching shapes.
        """
        image_id = self.image_ids[idx]
        image_info = self.images_info[image_id]

        # Read the image; the context manager releases the file handle as
        # soon as the pixel data has been copied out.
        img_path = os.path.join(self.images_path, image_info['file_name'])
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))

        # Read the colour-coded mask that shares the image's file name.
        mask_path = os.path.join(self.masks_path, image_info['file_name'])
        with Image.open(mask_path) as mask_file:
            mask_image = np.array(mask_file.convert("RGB"))

        annotations = self.annotations_by_image[image_id]

        boxes = []
        labels = []
        masks = []
        instance_ids = []
        parent_ids = []

        for ann in annotations:
            # Bbox is stored as [x, y, width, height]; convert to corners.
            x, y, w, h = ann['bbox']
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])
            instance_ids.append(ann['instance_id'])

            # An object without a parent may carry a null/missing parent_id,
            # which cannot be placed in an int64 tensor; use a sentinel.
            parent_id = ann.get('parent_id')
            parent_ids.append(self.NO_PARENT if parent_id is None else parent_id)

            # Binary mask: pixels whose RGB value equals this object's colour.
            seg_color = ann['segmentation_color']
            obj_mask = np.all(mask_image == seg_color, axis=2).astype(np.uint8)
            masks.append(obj_mask)

        # Convert to tensors, keeping well-formed empty shapes when the image
        # has no annotations.
        boxes = torch.as_tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0, 4), dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64) if labels else torch.zeros((0,), dtype=torch.int64)
        masks = torch.as_tensor(np.array(masks), dtype=torch.uint8) if masks else torch.zeros((0, image.shape[0], image.shape[1]), dtype=torch.uint8)
        instance_ids = torch.as_tensor(instance_ids, dtype=torch.int64) if instance_ids else torch.zeros((0,), dtype=torch.int64)
        parent_ids = torch.as_tensor(parent_ids, dtype=torch.int64) if parent_ids else torch.zeros((0,), dtype=torch.int64)

        image_id_tensor = torch.tensor([image_id])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) if len(boxes) > 0 else torch.zeros((0,))
        iscrowd = torch.zeros((len(labels),), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id_tensor,
            "area": area,
            "iscrowd": iscrowd,
            "instance_ids": instance_ids,
            "parent_ids": parent_ids
        }

        # HWC uint8 array -> CHW float tensor in [0, 1].
        image = torch.as_tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        if self.transforms:
            image, target = self.transforms(image, target)

        return image, target

## 4. Модель Mask R-CNN с головой для связей

In [None]:
class AssociationHead(nn.Module):
    """Head that scores links between objects.

    For every (parallelepiped, cube) pair it predicts the probability that
    the parallelepiped belongs to the cube.

    Args:
        feature_dim: Length of the per-object feature vector.
    """

    def __init__(self, feature_dim=256):
        super().__init__()

        self.feature_dim = feature_dim

        # MLP over the concatenated pair descriptor: both feature vectors
        # plus 4 geometric values (centre offset and size ratio).
        self.pair_mlp = nn.Sequential(
            nn.Linear(feature_dim * 2 + 4, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, cube_features, para_features, cube_boxes, para_boxes):
        """Score every (parallelepiped, cube) pair.

        Args:
            cube_features: (N_cubes, feature_dim) cube features.
            para_features: (N_paras, feature_dim) parallelepiped features.
            cube_boxes: (N_cubes, 4) cube boxes as (x1, y1, x2, y2).
            para_boxes: (N_paras, 4) parallelepiped boxes as (x1, y1, x2, y2).

        Returns:
            (N_paras, N_cubes) tensor of link probabilities (sigmoid outputs).
            When either set is empty an all-zero tensor of that shape is
            returned.
        """
        n_cubes = cube_features.shape[0]
        n_paras = para_features.shape[0]

        if n_cubes == 0 or n_paras == 0:
            return torch.zeros(n_paras, n_cubes, device=cube_features.device)

        # Box centres and sizes.
        cube_centers = (cube_boxes[:, :2] + cube_boxes[:, 2:]) / 2
        para_centers = (para_boxes[:, :2] + para_boxes[:, 2:]) / 2
        cube_sizes = cube_boxes[:, 2:] - cube_boxes[:, :2]
        para_sizes = para_boxes[:, 2:] - para_boxes[:, :2]

        # Pairwise geometry via broadcasting, both (N_paras, N_cubes, 2).
        # The centre offset is in raw box units (it is not normalised); the
        # epsilon guards against zero-sized cube boxes.
        rel_pos = para_centers[:, None, :] - cube_centers[None, :, :]
        rel_size = para_sizes[:, None, :] / (cube_sizes[None, :, :] + 1e-6)

        # Build all pair descriptors at once and run the MLP in a single
        # batched pass instead of one call per pair.
        pair_input = torch.cat(
            [
                para_features[:, None, :].expand(-1, n_cubes, -1),
                cube_features[None, :, :].expand(n_paras, -1, -1),
                rel_pos,
                rel_size,
            ],
            dim=-1,
        )
        logits = self.pair_mlp(pair_input).squeeze(-1)

        return torch.sigmoid(logits)

In [None]:
def get_model(num_classes, pretrained=True):
    """Build a Mask R-CNN (ResNet-50 FPN) with heads sized for ``num_classes``.

    Args:
        num_classes: Number of output classes, background included.
        pretrained: When true, start from torchvision's default pretrained
            detector weights; otherwise the detector weights are not loaded
            (torchvision's default for the backbone weights still applies).

    Returns:
        The model with its box and mask predictors replaced by freshly
        initialised heads for ``num_classes`` classes.

    Note:
        Uses the ``weights=`` argument, which replaced the deprecated
        ``pretrained=`` flag in torchvision 0.13.
    """
    # "DEFAULT" selects the best available pretrained detector weights.
    weights = "DEFAULT" if pretrained else None
    model = maskrcnn_resnet50_fpn(weights=weights)

    # Replace the box classification/regression head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask head.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    return model

In [None]:
class CubeParallelepipedModel(nn.Module):
    """Complete model: a Mask R-CNN detector bundled with an association head."""

    def __init__(self, num_classes):
        super().__init__()

        self.detector = get_model(num_classes)
        # NOTE(review): the head is sized for 1024-d inputs while the
        # projector below emits 256-d vectors; neither is used in forward()
        # yet -- confirm which dimensionality is intended before wiring up.
        self.association_head = AssociationHead(feature_dim=1024)

        # Maps box features to association features.
        self.feature_projector = nn.Sequential(nn.Linear(1024, 256), nn.ReLU())

    def forward(self, images, targets=None):
        """Run the detector.

        In training mode with targets supplied, returns the detector's loss
        dict (no association loss is included). In every other case returns
        the detector's output for ``images`` alone.
        """
        is_training_step = self.training and targets is not None
        if not is_training_step:
            return self.detector(images)
        return self.detector(images, targets)

    def predict_associations(self, detections):
        """Attach parallelepiped -> cube links to a list of detections.

        Uses a distance heuristic on the boxes only: each parallelepiped
        (label 2) is linked to the nearest cube (label 1) whose box centre
        lies below the parallelepiped's centre (larger y), or to ``None``
        when no cube qualifies.

        Returns:
            One dict per detection with numpy ``labels``, ``boxes``,
            ``scores``, ``masks`` and an ``associations`` mapping
            ``parallelepiped index -> cube index or None``.
        """

        def box_center(box):
            return np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])

        results = []

        for det in detections:
            labels, boxes, scores, masks = (
                det[field].cpu().numpy()
                for field in ('labels', 'boxes', 'scores', 'masks')
            )

            cube_indices = np.where(labels == 1)[0]
            para_indices = np.where(labels == 2)[0]

            associations = {}
            for para_idx in para_indices:
                para_center = box_center(boxes[para_idx])

                nearest_cube = None
                nearest_dist = float('inf')
                for cube_idx in cube_indices:
                    cube_center = box_center(boxes[cube_idx])

                    # Only cubes below the parallelepiped qualify
                    # (smaller y means higher in the image).
                    if not para_center[1] < cube_center[1]:
                        continue

                    dist = np.linalg.norm(para_center - cube_center)
                    if dist < nearest_dist:
                        nearest_dist = dist
                        nearest_cube = cube_idx

                associations[int(para_idx)] = (
                    None if nearest_cube is None else int(nearest_cube)
                )

            results.append({
                'labels': labels,
                'boxes': boxes,
                'scores': scores,
                'masks': masks,
                'associations': associations
            })

        return results

## 5. Функции обучения

In [None]:
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples for the DataLoader."""
    columns = zip(*batch)
    return tuple(columns)


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train the model for one epoch.

    Args:
        model: Module that, called as ``model(images, targets)``, returns a
            dict of loss tensors.
        optimizer: Optimizer stepping the model's parameters.
        data_loader: Sized iterable yielding ``(images, targets)`` batches,
            where ``targets`` is a sequence of dicts of tensors.
        device: Device the batch is moved to before the forward pass.
        epoch: Epoch number, used only in the progress message.

    Returns:
        The mean of the summed per-batch loss, or ``0.0`` when the loader
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)

    for batch_idx, (images, targets) in enumerate(data_loader):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Read the scalar once; every .item() call forces a device sync.
        loss_value = losses.item()
        total_loss += loss_value

        if batch_idx % 10 == 0:
            print(f"Epoch [{epoch}] Batch [{batch_idx}/{num_batches}] "
                  f"Loss: {loss_value:.4f}")

    # An empty loader has no average; report 0.0 instead of dividing by zero.
    return total_loss / num_batches if num_batches > 0 else 0.0


@torch.no_grad()
def evaluate(model, data_loader, device):
    """Run the model over a loader and collect its outputs.

    Switches the model to eval mode (and leaves it there) and runs without
    gradient tracking.

    Args:
        model: Module that, called as ``model(images)``, returns one dict
            per image.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        ``(all_predictions, all_targets)`` -- flat lists with one entry per
        image. Tensor values in the predictions are moved to the CPU.
    """
    model.eval()

    all_predictions = []
    all_targets = []

    for images, targets in data_loader:
        images = [img.to(device) for img in images]

        outputs = model(images)

        # Move results off the accelerator right away: keeping every batch's
        # outputs (masks in particular) on the GPU for the whole pass can
        # exhaust device memory.
        all_predictions.extend(
            {k: (v.cpu() if torch.is_tensor(v) else v) for k, v in out.items()}
            for out in outputs
        )
        all_targets.extend(targets)

    return all_predictions, all_targets

## 6. Метрики для оценки связей

In [None]:
def compute_association_accuracy(predictions, targets, iou_threshold=0.5):
    """Compute the accuracy of the predicted parallelepiped -> cube links.

    Every predicted parallelepiped (label 2) is matched to the ground-truth
    parallelepiped with the highest box IoU. Matches with IoU of at least
    ``iou_threshold`` are counted; a counted match is correct when the cube
    chosen by ``predict_associations_simple`` overlaps the ground-truth
    parent cube with IoU of at least ``iou_threshold``.

    Args:
        predictions: Per-image dicts with ``labels`` and ``boxes`` tensors.
        targets: Per-image dicts with ``labels``, ``boxes``,
            ``instance_ids`` and ``parent_ids`` tensors.
        iou_threshold: Minimum IoU for both the parallelepiped match and the
            cube match.

    Returns:
        ``(accuracy, correct, total)``; accuracy is 0 when nothing matched.

    Note:
        Matching is not one-to-one: several predictions may match the same
        ground-truth parallelepiped and each of them is counted.
    """
    correct_associations = 0
    total_associations = 0

    for pred, target in zip(predictions, targets):
        pred_labels = pred['labels'].cpu().numpy()
        pred_boxes = pred['boxes'].cpu().numpy()

        target_labels = target['labels'].cpu().numpy()
        target_boxes = target['boxes'].cpu().numpy()
        target_instance_ids = target['instance_ids'].cpu().numpy()
        target_parent_ids = target['parent_ids'].cpu().numpy()

        pred_para_indices = np.where(pred_labels == 2)[0]
        target_para_indices = np.where(target_labels == 2)[0]

        # Nothing can match when either side has no parallelepipeds.
        if len(pred_para_indices) == 0 or len(target_para_indices) == 0:
            continue

        # The predicted links depend only on this image's detections, so
        # compute them once per image rather than once per matched box.
        pred_associations = predict_associations_simple(pred_boxes, pred_labels)

        for pred_idx in pred_para_indices:
            pred_box = pred_boxes[pred_idx]

            # Best ground-truth parallelepiped by IoU (zero overlap never
            # counts as a match).
            best_iou = 0
            best_target_idx = -1

            for target_idx in target_para_indices:
                iou = compute_iou(pred_box, target_boxes[target_idx])
                if iou > best_iou:
                    best_iou = iou
                    best_target_idx = target_idx

            if best_target_idx < 0 or best_iou < iou_threshold:
                continue

            total_associations += 1

            pred_parent_idx = pred_associations.get(pred_idx, None)
            if pred_parent_idx is None:
                continue

            # Ground-truth cube whose instance id is this parallelepiped's
            # parent id.
            gt_parent = target_parent_ids[best_target_idx]
            gt_cube_indices = np.where(
                (target_labels == 1) & (target_instance_ids == gt_parent)
            )[0]
            if len(gt_cube_indices) == 0:
                continue

            # The link is correct when the predicted cube overlaps it enough.
            pred_cube_box = pred_boxes[pred_parent_idx]
            gt_cube_box = target_boxes[gt_cube_indices[0]]
            if compute_iou(pred_cube_box, gt_cube_box) >= iou_threshold:
                correct_associations += 1

    accuracy = correct_associations / total_associations if total_associations > 0 else 0
    return accuracy, correct_associations, total_associations


def compute_iou(box1, box2):
    """Return the intersection-over-union of two (x1, y1, x2, y2) boxes.

    Returns 0 when the union has no positive area.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # Clamp each side at zero so disjoint boxes yield an empty overlap.
    overlap = max(0, right - left) * max(0, bottom - top)

    first_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    second_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = first_area + second_area - overlap

    if union > 0:
        return overlap / union
    return 0


def predict_associations_simple(boxes, labels):
    """Link each parallelepiped to a cube using a distance heuristic.

    A parallelepiped (label 2) is linked to the cube (label 1) with the
    nearest box centre among cubes lying entirely below it, i.e. whose top
    edge is at or below the parallelepiped's bottom edge. Ties go to the
    cube that appears first.

    Returns:
        Dict mapping each parallelepiped index to the chosen cube index, or
        to ``None`` when no cube qualifies.
    """

    def center_of(box):
        return np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])

    cube_indices = np.where(labels == 1)[0]
    para_indices = np.where(labels == 2)[0]

    associations = {}
    for para_idx in para_indices:
        para_box = boxes[para_idx]
        para_center = center_of(para_box)

        nearest_cube = None
        nearest_dist = float('inf')
        for cube_idx in cube_indices:
            cube_box = boxes[cube_idx]
            cube_center = center_of(cube_box)

            # Skip cubes that are not fully below the parallelepiped
            # (parallelepiped bottom must be <= cube top).
            if not para_box[3] <= cube_box[1]:
                continue

            dist = np.linalg.norm(para_center - cube_center)
            if dist < nearest_dist:
                nearest_cube, nearest_dist = cube_idx, dist

        associations[para_idx] = nearest_cube

    return associations

## 7. Обучение модели

In [None]:
# Build the dataset
dataset = CubeParallelepipedDataset(
    images_path=IMAGES_PATH,
    masks_path=MASKS_PATH,
    annotations_path=ANNOTATIONS_PATH
)

# 80/20 train/validation split. The generator is seeded so the split is the
# same on every run; otherwise a restarted kernel could evaluate a saved
# checkpoint on images it was trained on.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Build the DataLoaders: both share every setting except shuffling, which is
# enabled for training only.
train_loader, val_loader = (
    DataLoader(
        subset,
        batch_size=BATCH_SIZE,
        shuffle=do_shuffle,
        collate_fn=collate_fn,
        num_workers=2,
    )
    for subset, do_shuffle in ((train_dataset, True), (val_dataset, False))
)

In [None]:
# Build the model and move it to the selected device
model = get_model(NUM_CLASSES, pretrained=True).to(DEVICE)

# Optimizer over the parameters that require gradients
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = optim.SGD(
    params,
    lr=LEARNING_RATE,
    momentum=0.9,
    weight_decay=5e-4,
)

# Learning-rate schedule: halve the rate every 5 epochs
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

print(f"Model loaded on {DEVICE}")

In [None]:
# Training loop
train_losses, best_loss = [], float('inf')

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n{'='*50}\nEpoch {epoch}/{NUM_EPOCHS}\n{'='*50}")

    # Train for one epoch and record its mean loss
    avg_loss = train_one_epoch(model, optimizer, train_loader, DEVICE, epoch)
    train_losses.append(avg_loss)
    print(f"\nAverage training loss: {avg_loss:.4f}")

    # Advance the learning-rate schedule
    lr_scheduler.step()

    # Keep the weights with the lowest training loss seen so far
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Saved best model with loss: {best_loss:.4f}")

    # Evaluate link accuracy on the validation set every 5th epoch only
    if epoch % 5 != 0:
        continue
    predictions, targets = evaluate(model, val_loader, DEVICE)
    acc, correct, total = compute_association_accuracy(predictions, targets)
    print(f"\nAssociation accuracy: {acc:.4f} ({correct}/{total})")

In [None]:
# Training-loss curve
plt.figure(figsize=(10, 5))
# Epochs are numbered from 1 in the training loop, so plot against 1..N
# rather than the default 0-based sample index.
plt.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True)
plt.savefig('training_progress.png')
plt.show()

## 8. Визуализация результатов

In [None]:
def visualize_predictions(image, prediction, score_threshold=0.5):
    """Plot a model prediction together with the predicted links.

    The left panel shows the image with the boxes of all detections scoring
    at least ``score_threshold``; the right panel overlays their masks (cubes
    red, parallelepipeds green) and draws a yellow line from every linked
    parallelepiped to its cube.

    Links are computed among the confident detections only, so a
    parallelepiped is never paired with a cube that is filtered out.

    Args:
        image: Channel-first image tensor.
        prediction: Dict with ``labels``, ``boxes``, ``scores`` and ``masks``
            tensors; masks are indexed as ``masks[i, 0]``.
        score_threshold: Minimum score for a detection to be used.

    Returns:
        Dict mapping the index of each confident parallelepiped to the index
        of its linked cube, or to ``None`` when no cube qualifies. Indices
        refer to positions in ``prediction``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    img_np = image.cpu().numpy().transpose(1, 2, 0)
    labels = prediction['labels'].cpu().numpy()
    boxes = prediction['boxes'].cpu().numpy()
    scores = prediction['scores'].cpu().numpy()
    masks = prediction['masks'].cpu().numpy()

    # Indices of the detections that pass the score threshold.
    confident = np.flatnonzero(scores >= score_threshold)

    # Left panel: image with bounding boxes.
    axes[0].imshow(img_np)
    axes[0].set_title('Detections')

    colors = {1: 'red', 2: 'green'}
    for i in confident:
        box = boxes[i]
        label = labels[i]
        color = colors.get(label, 'blue')

        rect = plt.Rectangle(
            (box[0], box[1]), box[2] - box[0], box[3] - box[1],
            fill=False, edgecolor=color, linewidth=2
        )
        axes[0].add_patch(rect)
        axes[0].text(box[0], box[1] - 5,
                     f"{CATEGORIES[label]}: {scores[i]:.2f}",
                     color=color, fontsize=8)

    # Right panel: colour overlay of the masks.
    combined_mask = np.zeros((*img_np.shape[:2], 3))
    for i in confident:
        mask = masks[i, 0] > 0.5
        label = labels[i]

        if label == 1:  # cube -> red
            combined_mask[mask] = [1, 0, 0]
        elif label == 2:  # parallelepiped -> green
            combined_mask[mask] = [0, 1, 0]

    axes[1].imshow(img_np)
    axes[1].imshow(combined_mask, alpha=0.5)
    axes[1].set_title('Masks with Associations')

    # Predict links on the confident subset, then translate the subset
    # positions back to indices into the original prediction arrays.
    subset_links = predict_associations_simple(boxes[confident], labels[confident])
    associations = {
        int(confident[para_pos]): (
            None if cube_pos is None else int(confident[cube_pos])
        )
        for para_pos, cube_pos in subset_links.items()
    }

    # Draw the link lines.
    for para_idx, cube_idx in associations.items():
        if cube_idx is None:
            continue

        para_box = boxes[para_idx]
        cube_box = boxes[cube_idx]

        para_center = [(para_box[0] + para_box[2]) / 2, (para_box[1] + para_box[3]) / 2]
        cube_center = [(cube_box[0] + cube_box[2]) / 2, (cube_box[1] + cube_box[3]) / 2]

        axes[1].plot([para_center[0], cube_center[0]],
                     [para_center[1], cube_center[1]],
                     'y-', linewidth=2)
        axes[1].plot(*para_center, 'yo', markersize=8)
        axes[1].plot(*cube_center, 'yo', markersize=8)

    plt.tight_layout()
    plt.show()

    return associations

In [None]:
# Load the best checkpoint and visualize a few validation samples.
# map_location lets a checkpoint saved on a GPU be restored in a CPU-only
# session (and vice versa).
model.load_state_dict(torch.load('best_model.pth', map_location=DEVICE))
model.eval()

# Take up to five examples from the validation set
sample_images = []
sample_targets = []

for i in range(min(5, len(val_dataset))):
    img, target = val_dataset[i]
    sample_images.append(img)
    sample_targets.append(target)

# Predictions
with torch.no_grad():
    images_tensor = [img.to(DEVICE) for img in sample_images]
    predictions = model(images_tensor)

# Visualize every example
for i, (img, pred) in enumerate(zip(sample_images, predictions)):
    print(f"\n--- Sample {i + 1} ---")
    associations = visualize_predictions(img, pred)
    print(f"Predicted associations: {associations}")

## 9. Сохранение модели

In [None]:
# Save the final checkpoint: model and optimizer state plus the epoch count
# and the per-epoch training losses.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epochs=NUM_EPOCHS,
        train_losses=train_losses,
    ),
    'cube_parallelepiped_model.pth',
)

print("Model saved to cube_parallelepiped_model.pth")

## 10. Inference на новых изображениях

In [None]:
def inference(model, image_path, device, score_threshold=0.5):
    """Run the model on a single image file and plot the result.

    The image is read as RGB, converted to a channel-first float tensor in
    [0, 1] and passed to the model in eval mode without gradient tracking.
    Detections scoring below ``score_threshold`` are dropped before plotting.

    Returns:
        Dict with the filtered ``boxes``, ``labels``, ``scores`` and
        ``masks`` tensors.
    """
    # Read the image and convert HWC uint8 -> CHW float in [0, 1]
    pixels = np.array(Image.open(image_path).convert("RGB"))
    image_tensor = torch.as_tensor(pixels, dtype=torch.float32).permute(2, 0, 1) / 255.0

    model.eval()
    with torch.no_grad():
        batch_output = model([image_tensor.to(device)])
    prediction = batch_output[0]

    # Keep only the detections that pass the score threshold
    keep = prediction['scores'] >= score_threshold
    filtered_pred = {
        field: prediction[field][keep]
        for field in ('boxes', 'labels', 'scores', 'masks')
    }

    # Plot the result
    visualize_predictions(image_tensor, filtered_pred, score_threshold)

    return filtered_pred

# Пример использования:
# result = inference(model, 'path/to/new/image.png', DEVICE)