# Imports

In [1]:
import os
import cv2
import math
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F
from scipy.ndimage import label
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from scipy.ndimage import binary_opening, binary_closing

# Model

In [2]:
class CNNBlock(nn.Module):
    """Two stacked ``Conv2d(3x3) -> BatchNorm2d -> ReLU`` stages.

    The first convolution maps ``in_ch`` to ``out_ch`` channels and the second
    keeps ``out_ch``.  Every conv uses a 3x3 kernel with ``padding=1`` and stride
    1, so the spatial size of the input is preserved.

    Args:
        in_ch: number of channels of the input feature map.
        out_ch: number of channels produced by both stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # Stage inputs: in_ch for the first conv, out_ch for the second.
        for stage_in in (in_ch, out_ch):
            layers.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order are kept stable: saved checkpoints use
        # state_dict keys of the form "net.<index>.<param>".
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        features = self.net(x)
        return features


class MetalSurface(nn.Module):
    """Small U-Net-style encoder/decoder that outputs a 1-channel logit map.

    Two max-pool stages (each halving H and W) lead to a bottleneck that runs
    at 1/4 of the input resolution.  Two transposed-conv stages upsample back,
    and each concatenates the matching encoder feature map (skip connection)
    before a ``CNNBlock``.  The result is a raw logit map (no sigmoid) with the
    same spatial size as the input.  H and W must be at least 4 but no longer
    need to be multiples of 4.

    Args:
        in_ch: number of input channels.
        base: channel width of the first stage; deeper stages use 2x and 4x.
    """

    def __init__(self, in_ch=3, base=32):
        super().__init__()

        self.enc1 = CNNBlock(in_ch, base)
        self.enc2 = CNNBlock(base, base*2)
        self.enc3 = CNNBlock(base*2, base*4)

        self.pool = nn.MaxPool2d(2)

        # No pooling precedes the bottleneck, so it runs at enc3's resolution.
        self.bottleneck = CNNBlock(base*4, base*4)

        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec2 = CNNBlock(base*4, base*2)

        self.up1 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec1 = CNNBlock(base*2, base)

        self.out = nn.Conv2d(base, 1, kernel_size=1)

    @staticmethod
    def _match_spatial(up, skip):
        """Zero-pad ``up`` so its H and W equal those of ``skip``.

        MaxPool2d floors odd sizes, so an upsampled decoder map can be one pixel
        smaller than the encoder map it is concatenated with.  The padding is
        split as evenly as possible, with any extra pixel on the bottom/right.
        For sizes that already match, ``up`` is returned unchanged.
        """
        dh = skip.size(-2) - up.size(-2)
        dw = skip.size(-1) - up.size(-1)
        if dh == 0 and dw == 0:
            return up
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        # Shape comments assume the default base=32 and a 480x480 input.
        # x: (B, in_ch, 480, 480)
        e1 = self.enc1(x) # (B, 32, 480, 480)
        p1 = self.pool(e1) # (B, 32, 240, 240)
        e2 = self.enc2(p1) # (B, 64, 240, 240)
        p2 = self.pool(e2) # (B, 64, 120, 120)
        e3 = self.enc3(p2) # (B, 128, 120, 120)

        b = self.bottleneck(e3) # (B, 128, 120, 120)

        u2 = self._match_spatial(self.up2(b), e2) # (B, 64, 240, 240)
        cat2 = torch.cat([u2, e2], dim=1) # (B, 128, 240, 240)
        d2 = self.dec2(cat2) # (B, 64, 240, 240)

        u1 = self._match_spatial(self.up1(d2), e1) # (B, 32, 480, 480)
        cat1 = torch.cat([u1, e1], dim=1) # (B, 64, 480, 480)
        d1 = self.dec1(cat1) # (B, 32, 480, 480)

        return self.out(d1) # (B, 1, 480, 480) raw logits

# Dataset

In [3]:
# Load the pre-built dataset.  Its items are presumably (image, target) pairs;
# the segmentation loop unpacks each batch as `(image, _)`.
# weights_only=False is passed explicitly: it is the current implicit default
# (torch.load warns that the default will flip to True), and pinning it keeps
# this load behaving the same after a torch upgrade.  With weights_only=False,
# torch.load unpickles arbitrary Python objects, so only load trusted files.
dataset = torch.load("datasets/grabbability_dataset_rgb_onehot_bce.pt", weights_only=False)

class TupleDataset(Dataset):
    """Map-style ``Dataset`` view over an in-memory indexable collection.

    Items are handed back exactly as stored in ``data``, with no copy or
    transform, and the length is that of ``data``.

    Args:
        data: any object supporting ``len()`` and ``[]`` indexing.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        item = self.data[idx]
        return item

    def __len__(self):
        size = len(self.data)
        return size

N = len(dataset)  # number of samples in the dataset

# batch_size=1: the segmentation loop reshapes each prediction for a single image.
# shuffle=False: crops are saved as "<batch_idx>_<crop_idx>.png".  A fixed
# order keeps batch_idx equal to the dataset index, so file names are
# reproducible across runs and each crop can be traced back to its source image.
dataloader = DataLoader(
    TupleDataset(dataset),
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)

  dataset = torch.load("datasets/grabbability_dataset_rgb_onehot_bce.pt")


# Load Model

In [4]:
device = 'cuda'
model = MetalSurface()
model.to(device)
# The checkpoint is a state_dict (it is passed to load_state_dict), i.e. a
# mapping of tensors, so weights_only=True is enough and avoids unpickling
# arbitrary objects.  map_location puts the tensors straight onto `device`,
# whatever device they were saved from.
model.load_state_dict(torch.load(
    "models/metal_surface_rgb_onehot_focalbce_globalcnn_maxpoolred.pth",
    map_location=device,
    weights_only=True,
))
# Switch BatchNorm layers to their running statistics for inference.
model.eval()

  model.load_state_dict(torch.load("models/metal_surface_rgb_onehot_focalbce_globalcnn_maxpoolred.pth"))


MetalSurface(
  (enc1): CNNBlock(
    (net): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc2): CNNBlock(
    (net): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc3): CNNBlock(
    (net): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3),

# Segmentation

In [5]:
# Directory that receives the per-component crops written by the segmentation loop.
output_directory = "segments"
# Create it up front: saving an image into a missing directory fails, and
# exist_ok=True makes re-running this cell harmless.
os.makedirs(output_directory, exist_ok=True)

In [6]:
# Inference pass over the dataset.  For each image it:
#   1. predicts a mask,
#   2. smooths, thresholds, and morphologically cleans it,
#   3. saves one resized crop per sufficiently large connected component to
#      `output_directory` as "<batch_idx>_<crop_idx>.png".
# Assumes one 3-channel image per batch (the dataloader uses batch_size=1).
# A GradScaler is only needed for training, so none is created here.
MASK_THRESHOLD = 0.35       # cut-off applied to the smoothed sigmoid probabilities
MIN_COMPONENT_PIXELS = 400  # connected components smaller than this are skipped
CROP_SIZE = (224, 224)      # (width, height) passed to PIL's resize

# No gradients are needed; no_grad avoids building an autograd graph per image.
with torch.no_grad():
    for batch_idx, (image, _) in enumerate(dataloader):
        image = image.to(device)
        height, width = image.shape[-2:]

        # Mixed precision only for the network forward; the sigmoid and all
        # post-processing run in float32.
        with torch.amp.autocast(device):
            logits = model(image)
        output = torch.sigmoid(logits.float()).reshape(1, height, width).cpu()
        # 5x5 box filter; stride 1 with padding 2 keeps the map's size.
        output = F.avg_pool2d(output, kernel_size=5, stride=1, padding=2)
        binary = (output >= MASK_THRESHOLD).reshape(height, width).numpy()

        image = image.cpu().numpy().reshape(3, height, width).transpose(1, 2, 0)
        # Scale [0, 1] float images to 0..255 uint8.  Otherwise the values are
        # used as-is (presumably already 0..255 -- confirm against the dataset).
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)

        # Opening removes small specks, then closing fills small holes/gaps.
        binary = binary_opening(binary, structure=np.ones((3, 3)))
        binary = binary_closing(binary, structure=np.ones((5, 5)))
        labels_cc, num_components = label(binary)
        # Pixel count of every label in one pass (index 0 is the background).
        component_sizes = np.bincount(labels_cc.ravel(), minlength=num_components + 1)

        crop_idx = 0
        for k in range(1, num_components + 1):
            if component_sizes[k] < MIN_COMPONENT_PIXELS:
                continue
            component = labels_cc == k
            ys, xs = np.nonzero(component)
            y_min, y_max = ys.min(), ys.max()
            x_min, x_max = xs.min(), xs.max()
            # Crop to the component's bounding box, then zero every pixel in
            # the box that does not belong to this component.
            box = (slice(y_min, y_max + 1), slice(x_min, x_max + 1))
            kth_crop = image[box] * component[box][..., None]

            crop_name = f"{batch_idx}_{crop_idx}.png"
            crop_idx += 1
            crop_path = os.path.join(output_directory, crop_name)
            crop_img = Image.fromarray(kth_crop.astype(np.uint8))
            crop_img = crop_img.resize(CROP_SIZE, Image.BILINEAR)
            crop_img.save(crop_path)