# Segmentation of the images into separate digits

In [1]:
import numpy as np
import torch
import torchvision
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torch.optim import Adam
import matplotlib.pyplot as plt
from PIL import Image
import glob
import os

In [2]:
class DigitSegmentationDataset(torch.utils.data.Dataset):
    """In-memory dataset of grayscale digit-strip images stored in ``.npy`` batches.

    Every file in ``npy_files`` is loaded eagerly and the batches are joined
    along the first axis, so the whole dataset is held in memory.

    NOTE: the detection targets returned by ``__getitem__`` are placeholders
    (one fixed box, one all-zero mask); they are not real annotations.

    Args:
        npy_files: Iterable of paths to ``.npy`` files, each holding a batch of
            2-D images stacked along axis 0.
        transform: Optional callable applied to the RGB PIL image before it is
            returned.
    """

    def __init__(self, npy_files, transform=None):
        self.npy_files = npy_files
        self.transform = transform

        batches = [np.load(npy_file) for npy_file in self.npy_files]  # each: (10000, 40, 168)
        # A single concatenate avoids building a Python list with one entry per
        # image. An empty file list still yields an empty dataset.
        self.data = np.concatenate(batches, axis=0) if batches else np.array([])

    def __len__(self):
        """Return the total number of images across all loaded files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the image at position ``idx``.

        ``target`` is a dict with ``"boxes"`` (1, 4) float32, ``"labels"`` (1,)
        int64 and ``"masks"`` (1, H, W) uint8, where H and W are the height and
        width of the image that is actually returned.
        """
        image = self.data[idx]  # Shape: (40, 168)

        # Convert grayscale to RGB by repeating along the channel dimension
        image = np.stack([image] * 3, axis=-1)  # Shape: (40, 168, 3)
        image = Image.fromarray(image.astype(np.uint8))

        # PIL reports size as (width, height), not (height, width).
        width, height = image.size

        # Apply transforms if available, then re-read the size so the targets
        # describe the image that is returned (e.g. after a Resize).
        if self.transform is not None:
            image = self.transform(image)
            if isinstance(image, torch.Tensor) and image.dim() >= 2:
                height, width = (int(s) for s in image.shape[-2:])
            elif isinstance(getattr(image, "size", None), tuple):
                width, height = image.size

        # Generate dummy bounding boxes and masks
        # For now, bounding boxes and masks are placeholders.
        boxes = torch.tensor([[0, 0, width // 2, height // 2]], dtype=torch.float32)  # Placeholder box
        labels = torch.tensor([1], dtype=torch.int64)  # Single label
        masks = torch.zeros((1, height, width), dtype=torch.uint8)  # Placeholder mask

        return image, {"boxes": boxes, "labels": labels, "masks": masks}



def visualize_predictions(image, predictions, score_threshold=0.5):
    """Show an image with the bounding boxes of confident predictions drawn on it.

    Only boxes are drawn; predicted masks are not rendered.

    Args:
        image: Image tensor laid out channel-first; it is permuted to
            channel-last for ``imshow``.
        predictions: Mapping with ``"boxes"`` (each ``x1, y1, x2, y2``) and
            ``"scores"`` entries of matching length.
        score_threshold: Boxes whose score is not strictly greater than this
            value are skipped. Defaults to 0.5.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image.detach().permute(1, 2, 0).cpu().numpy())

    for box, score in zip(predictions["boxes"], predictions["scores"]):
        if float(score) > score_threshold:  # Display high-confidence predictions
            # Plain Python floats keep matplotlib from having to handle
            # tensor scalars (which may live on another device).
            x1, y1, x2, y2 = (float(v) for v in box)
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, edgecolor='red', linewidth=2))
    plt.show()


def save_cropped_digits(image, predictions, output_dir):
    """Save one PNG crop per predicted box as ``digit_<i>.png`` in ``output_dir``.

    Args:
        image: Image tensor whose last two dimensions are height and width;
            the crop keeps all of the first dimension.
        predictions: Mapping with a ``"boxes"`` entry; each box is a tensor
            ``(x1, y1, x2, y2)`` that is truncated to integers.
        output_dir: Directory to write into; created if it does not exist.

    Boxes that are empty after truncation and clamping to the image bounds are
    skipped, so the file indices may have gaps. ``<i>`` is always the box's
    index in ``predictions["boxes"]``.
    """
    os.makedirs(output_dir, exist_ok=True)
    img_h, img_w = image.shape[-2:]

    for i, box in enumerate(predictions["boxes"]):
        x1, y1, x2, y2 = box.int().tolist()
        # Clamp to the image: a negative start would otherwise wrap around
        # when used as a slice index.
        x1, x2 = max(x1, 0), min(x2, img_w)
        y1, y2 = max(y1, 0), min(y2, img_h)
        if x2 <= x1 or y2 <= y1:
            continue  # nothing to crop
        cropped = image[:, y1:y2, x1:x2]
        save_path = os.path.join(output_dir, f"digit_{i}.png")
        torchvision.utils.save_image(cropped, save_path)


In [3]:
# Sanity check: inspect one raw batch file before building the dataset.
# (`np` is already imported in the first cell, so it is not re-imported here.)
file_path = '../data/data0.npy'  # Replace with an actual file path
data = np.load(file_path)
print("Shape:", data.shape)
print("Data type:", data.dtype)

Shape: (10000, 40, 168)
Data type: uint8


In [4]:
import glob
from torchvision.transforms import Compose, ToTensor, Resize

# Paths to all .npy files (sorted so the file order is deterministic)
npy_files = sorted(glob.glob('../data/data*.npy'))
if not npy_files:
    # Fail here with a clear message instead of later inside the DataLoader.
    raise FileNotFoundError("No .npy files matched '../data/data*.npy'")

# Transformations
# NOTE(review): Resize((128, 128)) squashes the 40x168 strips to a square,
# which distorts the digits' aspect ratio - confirm this is intended.
transforms = Compose([
    Resize((128, 128)),  # Resize images to a consistent size
    ToTensor(),          # Convert to PyTorch tensor
])


def collate_detection_batch(batch):
    """Regroup ``[(image, target), ...]`` into ``(images, targets)`` tuples.

    Detection models take a sequence of images plus a sequence of target
    dicts, so samples are transposed rather than stacked. A named function
    (unlike a lambda) can also be pickled if worker processes are used.
    """
    return tuple(zip(*batch))


# Dataset and DataLoader
dataset = DigitSegmentationDataset(npy_files, transform=transforms)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_detection_batch)
# Each batch is a pair (images, targets): `images` is a tuple of image tensors
# and `targets` a tuple of dicts with "boxes", "labels" and "masks". Neither is
# a stacked tensor, so `images.shape` does not exist; use `len(images)` or
# `images[0].shape` to inspect a batch.


In [5]:
# Load Pretrained Mask R-CNN
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument; kept as-is for compatibility - confirm against the
# installed version.
model = maskrcnn_resnet50_fpn(pretrained=True)

num_classes = 2  # background + digit

# Replace the box classifier head so it predicts `num_classes` classes
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

# Replace the mask head. Its input width is read from the existing head rather
# than hard-coded, and the hidden layer keeps 256 channels: `dim_reduced` is the
# width of the mask head's intermediate layer, not the number of classes, so a
# value of 2 would squeeze every mask through a 2-channel bottleneck.
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
    in_channels=in_features_mask,
    dim_reduced=hidden_layer,
    num_classes=num_classes,
)

model.to(device)

# Optimizer and Hyperparameters
# Only parameters that require gradients are handed to the optimizer.
optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)
num_epochs = 5




In [6]:
from tqdm import tqdm

checkpoint_path = "mask_rcnn_digit_segmentation.pth"

model.train()
for epoch in range(num_epochs):
    epoch_loss = 0
    for images, targets in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Move data to the device
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass: in train mode the model returns a dict of loss terms
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Stop before a NaN/inf loss is back-propagated into the weights.
        if not torch.isfinite(losses):
            raise RuntimeError(f"Non-finite loss in epoch {epoch+1}: {loss_dict}")

        # Backward pass and optimization
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()

    # Sum of the per-batch losses over the epoch
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}")

    # Checkpoint after every epoch so an interrupted run keeps the weights of
    # the last completed epoch instead of losing everything.
    torch.save(model.state_dict(), checkpoint_path)

# Save the Trained Model (also covers the case of zero epochs)
torch.save(model.state_dict(), checkpoint_path)

Epoch 1/5:  39%|███▉      | 1476/3750 [15:50<24:23,  1.55it/s]


KeyboardInterrupt: 

In [None]:
model.eval()
# NOTE: there is no separate test split here; this batch is drawn from the
# same (shuffled) dataloader that is used for training.
test_images = next(iter(dataloader))[0]  # Get a batch of test images
test_images = [img.to(device) for img in test_images]

# Perform inference
with torch.no_grad():
    predictions = model(test_images)

# Summarize predictions per image
for idx, pred in enumerate(predictions):
    print(f"Image {idx+1}:")
    boxes = pred["boxes"].cpu().numpy()
    scores = pred["scores"].cpu().numpy()
    masks = pred["masks"].cpu().numpy()

    # Filter by a confidence threshold (e.g., 0.5). The same mask is applied to
    # scores as well, so scores[i] always belongs to boxes[i] regardless of the
    # order in which the model returns its detections.
    high_conf_indices = scores > 0.5
    boxes = boxes[high_conf_indices]
    scores = scores[high_conf_indices]
    masks = masks[high_conf_indices]

    print(f"  Detected {len(boxes)} objects")
    for i, (box, score) in enumerate(zip(boxes, scores)):
        print(f"    Box {i+1}: {box}, Score: {score}")
