<a href="https://colab.research.google.com/github/Sindhya456/Emotions-Recognition/blob/master/THIS_ONE_of_Lumo_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import files
uploaded = files.upload()  # select rl_datasets.zip (the archive the next cell unzips)

In [None]:
!unzip -q rl_datasets.zip -d /content/

In [None]:
!pip install ImageHash


In [None]:
import os
import pandas as pd

assembled_dir = "/content/rl_dataset/assembled"
parts_dir = "/content/rl_dataset/parts"


def _object_key(fname, prefix):
    """Return the object name of *fname*: final extension and leading *prefix* removed.

    Only the last extension is dropped and the prefix is only stripped from the
    start of the name, so names such as ``parts_table.v2.png`` or
    ``assembled_spare_parts_box.png`` keep their full object name.
    """
    stem = os.path.splitext(fname)[0]
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return stem


# Index every parts file by object name. The listing is sorted so that the
# resulting CSV does not depend on the filesystem's enumeration order.
parts_map = {}
for fname in sorted(os.listdir(parts_dir)):
    parts_map[_object_key(fname, "parts_")] = os.path.join(parts_dir, fname)

# Match each assembled image to the parts image with the same object name
pairs = []
for fname in sorted(os.listdir(assembled_dir)):
    obj_name = _object_key(fname, "assembled_")
    assembled_path = os.path.join(assembled_dir, fname)

    if obj_name in parts_map:
        pairs.append({"input": assembled_path, "output": parts_map[obj_name]})
    else:
        print(f"⚠️ No match found for {obj_name}")

# Explicit columns keep the CSV header present even when no pair was found,
# so a later pd.read_csv() still sees the "input"/"output" columns.
df = pd.DataFrame(pairs, columns=["input", "output"])
df.to_csv("rl_dataset.csv", index=False)
print(f"✅ Final dataset size: {len(df)} pairs")
df.head()



In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd

# Reload the assembled/parts pairs from rl_dataset.csv (the file must already exist on disk)
df = pd.read_csv("rl_dataset.csv")

def show_dataset_samples(df, n=5):
    """Show up to n random assembled ↔ parts pairs from the dataset.

    Args:
        df: DataFrame with an ``input`` column (assembled image path) and an
            ``output`` column (parts image path).
        n: Number of pairs requested. It is clamped to the number of rows so
            a small dataset no longer makes ``df.sample`` raise.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    n = min(n, len(df))
    if n <= 0:
        # Nothing to draw: avoid sampling from an empty frame / a zero-height figure
        print("No dataset samples to show.")
        return

    sample = df.sample(n)

    plt.figure(figsize=(10, 2*n))

    for i, row in enumerate(sample.itertuples(), 1):
        # Left column: assembled image (input); right column: parts image (output)
        panels = (
            (2*i-1, row.input, "Assembled"),
            (2*i, row.output, "Parts"),
        )
        for position, path, title in panels:
            # The context manager closes the file handle once the pixels are drawn
            with Image.open(path) as img:
                plt.subplot(n, 2, position)
                plt.imshow(img)
            plt.title(title)
            plt.axis("off")

    plt.tight_layout()
    plt.show()

# Example: visualize 5 random assembled/parts pairs from the CSV loaded above
show_dataset_samples(df, n=5)


Image segmentation Model

In [None]:
# ==========================
# Imports
# ==========================
%matplotlib inline
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

from sklearn.cluster import KMeans

# ==========================
# User settings
# ==========================
csv_file = "/content/rl_dataset.csv"    # CSV with columns "input","output"
test_image = "/content/table.png"       # single test image
epochs = 20          # number of passes over the dataset in the training loop
batch_size = 8       # samples per DataLoader batch
img_size = 128       # images and masks are resized to img_size x img_size
lr = 1e-3            # learning rate handed to the Adam optimizer
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_parts = 32  # expected number of semantic part classes

# ==========================
# Mask utilities
# ==========================
def quantize_mask(mask_pil, step=16):
    """Round every RGB channel value of a mask down to a multiple of ``step``.

    Args:
        mask_pil: PIL image; it is converted to RGB before quantizing.
        step: Width of each quantization bin.

    Returns:
        A new PIL image built from the quantized uint8 array.
    """
    rgb = np.array(mask_pil.convert("RGB"))
    bins = np.floor_divide(rgb, step)
    quantized = bins * step
    return Image.fromarray(quantized.astype(np.uint8))

def build_color_palette_robust(csv_path, mask_col="output", size=(128,128), n_parts=32):
    """
    Build a color2idx mapping using k-means clustering on all mask pixels.

    Args:
        csv_path: CSV file whose ``mask_col`` column lists mask image paths.
        mask_col: Name of the column holding the mask paths.
        size: (width, height) every mask is resized to (nearest neighbour).
        n_parts: expected number of semantic parts (output classes). This is an
            upper bound: the palette is smaller when the masks contain fewer
            distinct colours or when several centres round to the same colour.

    Returns:
        ``(color2idx, palette)`` where ``palette`` is a list of unique
        ``(r, g, b)`` int tuples and ``color2idx[palette[i]] == i``.

    Raises:
        ValueError: if the CSV lists no masks.
    """
    df = pd.read_csv(csv_path)

    all_pixels = []
    for mask_path in df[mask_col]:
        # convert() returns a loaded copy, so the file can be closed right away
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("RGB").resize(size, Image.NEAREST)
        all_pixels.append(np.asarray(mask, dtype=np.uint8).reshape(-1, 3))

    if not all_pixels:
        raise ValueError(f"No masks listed in column '{mask_col}' of {csv_path}")

    all_pixels = np.vstack(all_pixels)

    # Never ask k-means for more clusters than there are distinct colours;
    # doing so only produces duplicate centres. Colours are packed into one
    # integer each so that counting them is a cheap 1-D unique().
    channels = all_pixels.astype(np.int64)
    packed = (channels[:, 0] << 16) | (channels[:, 1] << 8) | channels[:, 2]
    n_clusters = min(n_parts, np.unique(packed).size)

    # KMeans clustering to reduce to at most n_parts colours
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(all_pixels)

    # Round (not truncate) the centres and drop duplicates while preserving
    # order, so that the dict and the list always agree on the class index.
    palette = []
    for center in np.rint(kmeans.cluster_centers_):
        color = tuple(int(v) for v in center)
        if color not in palette:
            palette.append(color)

    color2idx = {c: i for i, c in enumerate(palette)}
    print(f"Built robust palette with {len(palette)} classes.")
    return color2idx, palette

def color_mask_to_label(mask_pil, color2idx, size=(128,128)):
    """Convert a colour mask into a tensor of class indices.

    Each pixel is assigned the class of the nearest palette colour
    (Euclidean distance in RGB space).

    Args:
        mask_pil: PIL mask image; it is converted to RGB here so grayscale or
            RGBA masks do not break the 3-channel reshape.
        color2idx: Mapping ``(r, g, b) -> class index``.
        size: (width, height) the mask is resized to (nearest neighbour).

    Returns:
        A ``torch.long`` tensor of shape ``(size[1], size[0])``.

    Raises:
        ValueError: if ``color2idx`` is empty.
    """
    if not color2idx:
        raise ValueError("color2idx must contain at least one colour")

    mask = mask_pil.convert("RGB").resize(size, resample=Image.NEAREST)
    pixels = np.asarray(mask, dtype=np.float32).reshape(-1, 3)

    palette_arr = np.array(list(color2idx.keys()), dtype=np.float32)
    # Use the mapped class ids rather than the key position: the two differ
    # whenever the dict values are not exactly 0..n-1 in insertion order.
    class_ids = np.array(list(color2idx.values()), dtype=np.int64)

    # distance from every pixel to every palette colour, then nearest colour
    dists = np.linalg.norm(pixels[:, None, :] - palette_arr[None, :, :], axis=2)
    labels = class_ids[dists.argmin(axis=1)]
    return torch.from_numpy(labels.reshape(size[1], size[0])).long()

def label_to_color_mask(label_tensor, palette):
    """Render a 2-D map of class indices as an RGB PIL image.

    Args:
        label_tensor: 2-D ``torch.Tensor`` (moved to the CPU first) or a
            NumPy-like array of class indices.
        palette: Sequence of RGB colours; entry ``i`` paints class ``i``.

    Returns:
        A PIL image; pixels whose index has no palette entry stay black.
    """
    label = label_tensor.cpu().numpy() if isinstance(label_tensor, torch.Tensor) else label_tensor
    height, width = label.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, rgb in enumerate(palette):
        canvas[label == class_id] = rgb
    return Image.fromarray(canvas)

# ==========================
# Dataset
# ==========================
class SegDataset(Dataset):
    """Pairs of (input image tensor, per-pixel label map) listed in a CSV file."""

    def __init__(self, csv_file, color2idx, size=(128,128), input_col="input", mask_col="output"):
        """Read the CSV and prepare the resize + to-tensor transform for inputs.

        Args:
            csv_file: CSV with one row per sample.
            color2idx: Colour-to-class mapping handed to ``color_mask_to_label``.
            size: Target size used for both the input image and the mask.
            input_col: Column holding the input image paths.
            mask_col: Column holding the colour mask paths.
        """
        self.df = pd.read_csv(csv_file)
        self.color2idx = color2idx
        self.size = size
        self.input_col, self.mask_col = input_col, mask_col

        resize = transforms.Resize(size, interpolation=Image.BILINEAR)
        self.input_transform = transforms.Compose([resize, transforms.ToTensor()])

    def __len__(self):
        """Return the number of rows in the CSV."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_map)``."""
        record = self.df.iloc[idx]
        rgb_input = Image.open(record[self.input_col]).convert("RGB")
        rgb_mask = Image.open(record[self.mask_col]).convert("RGB")
        image_tensor = self.input_transform(rgb_input)
        label_map = color_mask_to_label(rgb_mask, self.color2idx, size=self.size)
        return image_tensor, label_map

# ==========================
# Dice & IoU functions
# ==========================
def dice_loss_logits(logits, targets, eps=1e-6):
    """Soft Dice loss computed from raw logits.

    Args:
        logits: Tensor of shape (B, C, H, W) with unnormalized class scores.
        targets: Integer class indices of shape (B, H, W).
        eps: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor: one minus the Dice score averaged over batch and classes.
    """
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    # one-hot comes out as (B, H, W, C); move the class axis next to the batch axis
    onehot = F.one_hot(targets.long(), num_classes=num_classes).movedim(-1, 1).float()

    spatial = (2, 3)
    overlap = (probs * onehot).sum(dim=spatial)
    total = (probs + onehot).sum(dim=spatial)
    per_class_dice = (2 * overlap + eps) / (total + eps)
    return 1.0 - per_class_dice.mean()

def mean_iou(logits, targets, num_classes):
    """Mean intersection-over-union of the argmax prediction.

    Classes that appear in neither the prediction nor the target are left out
    of the average.

    Args:
        logits: Tensor with the class scores on dimension 1.
        targets: Integer class indices matching the argmax shape.
        num_classes: Number of class ids (0..num_classes-1) to evaluate.

    Returns:
        A Python float; 0.0 when no class contributed.
    """
    preds = logits.argmax(dim=1)
    scores = []
    for cls in range(num_classes):
        predicted, actual = preds.eq(cls), targets.eq(cls)
        union = torch.logical_or(predicted, actual).sum().item()
        if union:
            inter = torch.logical_and(predicted, actual).sum().item()
            scores.append(inter / union)
    return float(np.mean(scores)) if scores else 0.0

# ==========================
# U-Net model
# ==========================
class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv -> batch-norm -> ReLU stages (spatial size preserved)."""

    def __init__(self, in_ch, out_ch):
        """Build the stages: in_ch -> out_ch, then out_ch -> out_ch."""
        super().__init__()
        stages = []
        for source_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(source_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.net(x)

class UNet(nn.Module):
    """Four-level U-Net that outputs per-pixel class logits.

    The encoder halves the resolution four times with 2x2 max-pooling and the
    decoder concatenates each upsampled map with the matching encoder output,
    so the input height and width need to be divisible by 16 for the
    concatenations to line up.
    """

    def __init__(self, in_ch=3, base=32, num_classes=3):
        """Create the encoder, bottleneck, decoder and 1x1 output convolution.

        Args:
            in_ch: Number of input channels.
            base: Channel width of the first level; each level doubles it.
            num_classes: Number of output logit channels.
        """
        super().__init__()
        c1, c2, c3, c4, c5 = (base * (2 ** level) for level in range(5))

        # Encoder
        self.enc1 = DoubleConv(in_ch, c1)
        self.pool = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(c1, c2)
        self.enc3 = DoubleConv(c2, c3)
        self.enc4 = DoubleConv(c3, c4)
        self.bottleneck = DoubleConv(c4, c5)

        # Decoder: every dec* sees the upsampled map concatenated with its skip,
        # hence twice the channels that the matching up* produces
        self.up4 = nn.ConvTranspose2d(c5, c4, 2, stride=2)
        self.dec4 = DoubleConv(c5, c4)
        self.up3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = DoubleConv(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = DoubleConv(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = DoubleConv(c2, c1)

        self.outc = nn.Conv2d(c1, num_classes, 1)

    def forward(self, x):
        """Return logits with ``num_classes`` channels at the input resolution."""
        # Encoder path: remember every level's output for the skip connections
        skips = [self.enc1(x)]
        for encoder in (self.enc2, self.enc3, self.enc4):
            skips.append(encoder(self.pool(skips[-1])))

        feat = self.bottleneck(self.pool(skips[-1]))

        # Decoder path: upsample, concatenate with the deepest unused skip, refine
        decoder_stages = (
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        for upsample, refine in decoder_stages:
            feat = refine(torch.cat([upsample(feat), skips.pop()], dim=1))

        return self.outc(feat)

# ==========================
# Build palette & dataset
# ==========================
# Cluster the mask colours into a palette; its length fixes the number of output classes
color2idx, palette = build_color_palette_robust(csv_file, mask_col="output", size=(img_size,img_size), n_parts=n_parts)
num_classes = len(palette)
dataset = SegDataset(csv_file, color2idx, size=(img_size, img_size))
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

# ==========================
# Model, optimizer, loss
# ==========================
model = UNet(in_ch=3, base=32, num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
ce_loss = nn.CrossEntropyLoss()

# ==========================
# Training loop
# ==========================
# NOTE(review): there is no validation split; loss and IoU below are measured
# on the training batches themselves.
print("Starting training...")
for epoch in range(1, epochs+1):
    model.train()
    # Running sums over the epoch, averaged by `steps` when printing
    total_loss = 0.0
    total_iou = 0.0
    steps = 0
    for imgs, masks in loader:
        imgs = imgs.to(device)
        masks = masks.to(device)

        # Objective = cross-entropy + soft Dice, both computed on the raw logits
        logits = model(imgs)
        loss_ce = ce_loss(logits, masks)
        loss_dice = dice_loss_logits(logits, masks)
        loss = loss_ce + loss_dice

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        total_iou += mean_iou(logits, masks, num_classes)
        steps += 1

    # NOTE(review): divides by `steps`; an empty loader would raise ZeroDivisionError here
    print(f"Epoch {epoch}/{epochs} | Loss: {total_loss/steps:.4f} | Mean IoU: {total_iou/steps:.4f}")

# ==========================
# Test single image
# ==========================
def run_test(model, test_image_path, palette, size=(128,128), save_path="predicted_mask.png"):
    """Segment one image, display input and prediction side by side, and save the mask.

    Args:
        model: Segmentation network returning (1, C, H, W) logits.
        test_image_path: Path of the image to segment.
        palette: Colours used to render the predicted class indices.
        size: Size the image is resized to before inference.
        save_path: Where the colour mask is written (defaults to the file
            name this function has always used).

    Returns:
        None.
    """
    model.eval()
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor()
    ])
    img = Image.open(test_image_path).convert("RGB")

    # Run on whatever device the model's weights live on instead of relying on
    # a module-level `device` that may not match the model passed in.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    inp = transform(img).unsqueeze(0).to(target_device)

    with torch.no_grad():
        logits = model(inp)
        pred = logits.argmax(dim=1)[0]
    mask_out = label_to_color_mask(pred, palette)

    # show input + predicted mask
    fig, axs = plt.subplots(1,2, figsize=(6,3))
    axs[0].imshow(img); axs[0].set_title("Input")
    axs[1].imshow(mask_out); axs[1].set_title("Predicted mask")
    for ax in axs: ax.axis("off")
    plt.show()

    # save output
    mask_out.save(save_path)
    print(f"Predicted mask saved as {save_path}")

# Segment the single test image with the trained model, then show and save the predicted mask
run_test(model, test_image, palette, size=(img_size,img_size))


main one ?

In [None]:
import os

def _print_first(names, limit=5):
    """Print the first ``limit`` entries of ``names``, one indented entry per line."""
    for name in names[:limit]:
        print(f"  {name}")


print("Files in assembled folder:")
assembled_files = os.listdir("/content/rl_dataset/assembled")
_print_first(assembled_files)  # Show first 5

print("\nFiles in parts folder:")
parts_files = os.listdir("/content/rl_dataset/parts")
_print_first(parts_files)  # Show first 5

In [None]:
import os

def _base_name(fname, prefix):
    """Return *fname* without its final extension and without a leading *prefix*.

    Works for any extension (including ``.jpeg`` or upper-case ones) and only
    strips the prefix at the start of the name.
    """
    stem = os.path.splitext(fname)[0]
    return stem[len(prefix):] if stem.startswith(prefix) else stem


# Get all files
assembled_files = sorted(os.listdir("/content/rl_dataset/assembled"))
parts_files = sorted(os.listdir("/content/rl_dataset/parts"))

print(f"Total assembled: {len(assembled_files)}")
print(f"Total parts: {len(parts_files)}")

# Extract base names
assembled_bases = {_base_name(f, 'assembled_') for f in assembled_files}
parts_bases = {_base_name(f, 'parts_') for f in parts_files}

# Find matches
matches = assembled_bases & parts_bases
print(f"\nMatching items: {len(matches)}")
print(f"Matches: {sorted(matches)[:10]}")  # Show first 10

# Find mismatches
only_assembled = assembled_bases - parts_bases
only_parts = parts_bases - assembled_bases

# Examples are sorted so repeated runs print the same names
print(f"\nOnly in assembled: {len(only_assembled)}")
print(f"Examples: {sorted(only_assembled)[:5]}")

print(f"\nOnly in parts: {len(only_parts)}")
print(f"Examples: {sorted(only_parts)[:5]}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt

# -------------------------
# Dataset
# -------------------------
class CADImageDataset(Dataset):
    """Paired grayscale images: ``assembled_<name>.*`` input -> ``parts_<name>.*`` target."""

    # Extensions tried, in order, when looking up the parts image of an input
    OUTPUT_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, input_dir, output_dir, transform=None):
        """Index the assembled images.

        Args:
            input_dir: Folder with the assembled images.
            output_dir: Folder with the matching parts images.
            transform: Optional callable applied to both images of a pair.
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.transform = transform
        self.image_filenames = sorted(os.listdir(input_dir))

    def __len__(self):
        """Return the number of assembled images."""
        return len(self.image_filenames)

    def _resolve_output_path(self, assembled_filename):
        """Return the path of the parts image that belongs to ``assembled_filename``.

        Raises:
            FileNotFoundError: if no ``parts_<name>`` file with a known
                extension exists in ``output_dir``.
        """
        base_name = os.path.splitext(assembled_filename)[0]
        if base_name.startswith("assembled_"):
            base_name = base_name[len("assembled_"):]

        for ext in self.OUTPUT_EXTENSIONS:
            candidate = os.path.join(self.output_dir, f"parts_{base_name}{ext}")
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(
            f"No parts image found for '{assembled_filename}' in {self.output_dir} "
            f"(tried parts_{base_name} with {self.OUTPUT_EXTENSIONS})"
        )

    def __getitem__(self, idx):
        """Load pair ``idx`` as single-channel images and apply the transform to both."""
        assembled_filename = self.image_filenames[idx]

        input_path = os.path.join(self.input_dir, assembled_filename)
        output_path = self._resolve_output_path(assembled_filename)

        input_img = Image.open(input_path).convert("L")
        output_img = Image.open(output_path).convert("L")

        if self.transform:
            # Replay the same RNG state for both images so random flips and
            # rotations hit input and target identically; otherwise the pair is
            # spatially misaligned. Assumes the transform draws from torch's
            # global CPU generator, as torchvision.transforms does.
            rng_state = torch.get_rng_state()
            input_img = self.transform(input_img)
            torch.set_rng_state(rng_state)
            output_img = self.transform(output_img)

        return input_img, output_img

# -------------------------
# Perceptual Loss
# -------------------------
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG-19 feature maps of prediction and target."""

    def __init__(self, layers=(2, 7, 12)):
        """Load ImageNet-pretrained VGG-19 features and freeze them.

        Args:
            layers: End indices into ``vgg19().features``; one feature
                extractor ``features[:l]`` is built per entry. The default is
                an immutable tuple so it cannot be shared and mutated across
                instances.
        """
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.slices = nn.ModuleList([vgg[:l].eval() for l in layers])
        # The VGG weights are a fixed feature extractor, never trained
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, pred, target):
        """Return the summed L1 feature distance over all configured slices."""
        if pred.shape[1] == 1:
            # VGG expects 3 input channels: replicate the single gray channel
            pred = pred.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        loss = 0
        for feature_extractor in self.slices:
            loss += F.l1_loss(feature_extractor(pred), feature_extractor(target))
        return loss


# -------------------------
# Generator
# -------------------------
class UNetBlock(nn.Module):
    """One generator stage: strided (transposed) conv -> batch-norm -> activation.

    ``down=True`` halves the resolution with a 4x4 stride-2 convolution and
    LeakyReLU(0.2); ``down=False`` doubles it with a 4x4 stride-2 transposed
    convolution and ReLU. Dropout(0.5) is appended when ``use_dropout`` is set.
    """

    def __init__(self, in_channels, out_channels, down=True, use_dropout=False):
        """Build the stage for the requested direction."""
        super().__init__()
        self.down = down
        if down:
            conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1)
            activation = nn.LeakyReLU(0.2)
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
            activation = nn.ReLU()
        self.block = nn.Sequential(conv, nn.BatchNorm2d(out_channels), activation)

        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Run the stage, followed by dropout when enabled."""
        out = self.block(x)
        return self.dropout(out) if self.use_dropout else out


class GeneratorUNet(nn.Module):
    """U-Net generator: seven downsampling stages, seven upsampling stages with skips.

    Every decoder stage after the first receives the previous decoder output
    concatenated with the encoder output of the same resolution; the final
    transposed convolution is followed by ``tanh``.
    """

    def __init__(self, input_channels=1, output_channels=1):
        """Create the encoder and decoder stages."""
        super().__init__()
        # Encoder
        self.down1 = UNetBlock(input_channels, 64)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)
        self.down5 = UNetBlock(512, 512)
        self.down6 = UNetBlock(512, 512)
        self.down7 = UNetBlock(512, 512)

        # Decoder (input widths are doubled by the concatenated skip connection)
        self.up1 = UNetBlock(512, 512, down=False, use_dropout=True)
        self.up2 = UNetBlock(1024, 512, down=False, use_dropout=True)
        self.up3 = UNetBlock(1024, 512, down=False, use_dropout=True)
        self.up4 = UNetBlock(1024, 256, down=False)
        self.up5 = UNetBlock(512, 128, down=False)
        self.up6 = UNetBlock(256, 64, down=False)
        self.final = nn.ConvTranspose2d(128, output_channels, 4, 2, 1)

    def forward(self, x):
        """Map ``x`` to an output of the same spatial size with values in [-1, 1]."""
        encoder = (self.down1, self.down2, self.down3, self.down4,
                   self.down5, self.down6, self.down7)
        skips = []
        feat = x
        for stage in encoder:
            feat = stage(feat)
            skips.append(feat)

        # The deepest encoder output feeds the first decoder stage directly
        feat = self.up1(skips.pop())
        for stage in (self.up2, self.up3, self.up4, self.up5, self.up6):
            feat = stage(torch.cat([feat, skips.pop()], 1))

        return torch.tanh(self.final(torch.cat([feat, skips.pop()], 1)))

# -------------------------
# Discriminator (DEFINE THIS FIRST)
# -------------------------
class Discriminator(nn.Module):
    """Conditional patch discriminator scoring an (input, output) image pair.

    Four spectrally-normalized 4x4 convolutions with LeakyReLU(0.2) are
    followed by a plain 4x4 convolution producing a one-channel score map.
    """

    def __init__(self, input_channels=1):
        """Build the convolution stack; both images have ``input_channels`` channels."""
        super().__init__()
        # (in_channels, out_channels, stride) of every spectrally-normalized conv
        plan = (
            (input_channels * 2, 64, 2),
            (64, 128, 2),
            (128, 256, 2),
            (256, 512, 1),
        )
        layers = []
        for c_in, c_out, stride in plan:
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, 4, stride, 1)))
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Conv2d(512, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on the channel axis and return the score map."""
        pair = torch.cat([x, y], 1)
        return self.model(pair)


# -------------------------
# Multi-Scale Discriminator (NOW THIS CAN USE Discriminator)
# -------------------------
class MultiScaleDiscriminator(nn.Module):
    """Three discriminators, each looking at the pair at half the previous resolution."""

    def __init__(self, input_channels=1):
        """Create the three discriminators and the shared average-pool downsampler."""
        super().__init__()
        self.discriminators = nn.ModuleList(
            [Discriminator(input_channels) for _ in range(3)]
        )
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x, y):
        """Return the list of score maps, finest scale first."""
        outputs = []
        for scale, disc in enumerate(self.discriminators):
            if scale:
                # Every scale after the first sees a further-downsampled pair
                x, y = self.downsample(x), self.downsample(y)
            outputs.append(disc(x, y))
        return outputs


# -------------------------
# Loss Functions
# -------------------------
def part_separation_loss(pred, min_separation=0.1):
    """Penalize neighbouring pixels whose difference is below ``min_separation``.

    Args:
        pred: 4-D tensor; differences are taken along the last two axes.
        min_separation: Minimum absolute difference that goes unpenalized.

    Returns:
        Scalar tensor: mean horizontal shortfall plus mean vertical shortfall.
    """
    horizontal_step = (pred[:, :, :, 1:] - pred[:, :, :, :-1]).abs()
    vertical_step = (pred[:, :, 1:, :] - pred[:, :, :-1, :]).abs()

    horizontal_shortfall = F.relu(min_separation - horizontal_step).mean()
    vertical_shortfall = F.relu(min_separation - vertical_step).mean()
    return horizontal_shortfall + vertical_shortfall


def gap_enforcement_loss(pred, target, threshold=-0.5, sharpness=10.0):
    """Penalize if there aren't enough 'background' pixels (gaps) between parts.

    The background fraction of ``pred`` is measured with a steep sigmoid
    instead of a hard ``pred < threshold`` comparison. A hard comparison has
    no gradient, so the term could never influence the generator.

    Args:
        pred: Generator output.
        target: Ground-truth image in the same value range as ``pred``.
        threshold: Values below this count as background.
        sharpness: Slope of the soft threshold; larger values approach the
            hard comparison.

    Returns:
        Scalar tensor: squared difference between the (soft) predicted and the
        (hard) target background fractions.
    """
    soft_background = torch.sigmoid((threshold - pred) * sharpness)
    # The target needs no gradient, so its background mask can stay hard
    target_background = (target < threshold).float()

    gap_loss = F.mse_loss(torch.mean(soft_background), torch.mean(target_background))
    return gap_loss


def combined_loss(pred, target, lambda_ssim=0.2):
    """Reconstruction loss: MSE plus a weighted pixel-wise dissimilarity term.

    The second term averages ``1 - (2*p*t + 0.01) / (p**2 + t**2 + 0.01)`` over
    all pixels; it is zero where prediction and target agree. No total-variation
    term is included, since that was found to prevent sharp edges.

    Args:
        pred: Predicted image tensor.
        target: Ground-truth tensor of the same shape.
        lambda_ssim: Weight of the dissimilarity term.

    Returns:
        Scalar tensor with the combined loss.
    """
    mse = F.mse_loss(pred, target)
    similarity = (2 * pred * target + 0.01) / (pred ** 2 + target ** 2 + 0.01)
    dissimilarity = (1 - similarity).mean()
    return mse + lambda_ssim * dissimilarity


# -------------------------
# Training Setup
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ToTensor() alone yields values in [0, 1], but the generator ends in tanh
# ([-1, 1]) and gap_enforcement_loss treats values below -0.5 as background.
# Normalize with mean 0.5 / std 0.5 so inputs and targets live in [-1, 1] too.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

input_dir = "/content/rl_dataset/assembled"
output_dir = "/content/rl_dataset/parts"
dataset = CADImageDataset(input_dir, output_dir, transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Generator, multi-scale discriminator and the frozen VGG perceptual loss
G = GeneratorUNet().to(device)
D = MultiScaleDiscriminator().to(device)  # Now this works!
perceptual_loss_fn = PerceptualLoss().to(device)

opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

epochs = 200
# Halve both learning rates every 30 scheduler steps (stepped once per epoch)
scheduler_G = optim.lr_scheduler.StepLR(opt_G, step_size=30, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(opt_D, step_size=30, gamma=0.5)

# -------------------------
# Training Loop
# -------------------------
# Each batch performs one discriminator update followed by one generator update.
for epoch in range(epochs):
    for i, (input_img, output_img) in enumerate(loader):
        input_img, output_img = input_img.to(device), output_img.to(device)

        # Train Discriminator
        opt_D.zero_grad()
        fake_output = G(input_img)

        real_preds = D(input_img, output_img)
        # detach(): the discriminator step must not back-propagate into G
        fake_preds = D(input_img, fake_output.detach())

        # Multi-scale discriminator loss
        # (log-loss on sigmoid scores, summed over scales; 1e-8 guards log(0))
        loss_D = 0
        for real_pred, fake_pred in zip(real_preds, fake_preds):
            loss_D += -torch.mean(torch.log(torch.sigmoid(real_pred) + 1e-8) +
                                 torch.log(1 - torch.sigmoid(fake_pred) + 1e-8))

        loss_D.backward()
        opt_D.step()

        # Train Generator
        opt_G.zero_grad()
        # Fresh forward pass through G, scored by the just-updated discriminator
        fake_output = G(input_img)
        fake_preds = D(input_img, fake_output)

        # Multi-scale GAN loss
        loss_GAN = 0
        for fake_pred in fake_preds:
            loss_GAN += -torch.mean(torch.log(torch.sigmoid(fake_pred) + 1e-8))

        loss_recon = combined_loss(fake_output, output_img)
        loss_perceptual = perceptual_loss_fn(fake_output, output_img)
        loss_separation = part_separation_loss(fake_output)
        loss_gaps = gap_enforcement_loss(fake_output, output_img)

        # Adjusted weights: emphasize reconstruction and separation
        loss_G = 2 * loss_GAN + 100 * loss_recon + 5.0 * loss_perceptual + \
                 5.0 * loss_separation + 3.0 * loss_gaps

        loss_G.backward()
        opt_G.step()

    # Step schedulers once per epoch
    scheduler_G.step()
    scheduler_D.step()

    # NOTE(review): the printed values are those of the last batch of the epoch,
    # not an epoch average; an empty loader would leave loss_D/loss_G undefined.
    print(f"Epoch [{epoch+1}/{epochs}]  Loss_D: {loss_D.item():.4f}  Loss_G: {loss_G.item():.4f}")

# -------------------------
# Test Visualization
# -------------------------
# Take one batch from the training loader (same random augmentations as
# training) and run the generator on it without tracking gradients.
# NOTE(review): G is not switched to eval() here, so dropout and batch-norm
# still run in training mode — confirm this is intended.
test_input, test_output = next(iter(loader))
test_input = test_input.to(device)
with torch.no_grad():
    pred = G(test_input)

# Show the first sample of the batch next to its prediction
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(test_input[0].cpu().squeeze(), cmap="gray")
plt.title("Test Input")
plt.subplot(1, 2, 2)
plt.imshow(pred[0].cpu().squeeze(), cmap="gray")
plt.title("Predicted Disassembled")
plt.show()

Segmenting parts images with masks: I first tried OpenCV-based segmentation, but the outputs were not good enough, so I am moving on to deep-learning-based segmentation.

After the deep-learning approach, switched to SAM (Segment Anything).

In [None]:
# --- install dependencies ---
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python matplotlib

import torch, cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# --- load SAM model ---
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

# download checkpoint if not already
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# --- load your test image ---
img_path = "/content/parts_coatstand.png"
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# --- automatic mask generator ---
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,         # controls density of prompts, increase for more parts
    pred_iou_thresh=0.88,       # confidence threshold
    stability_score_thresh=0.90,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=200    # filter out tiny specks
)

masks = mask_generator.generate(image)
print(f"Generated {len(masks)} masks")

# --- visualize ---
def show_anns(anns):
    """Overlay every annotation's mask in a random colour on the current axes.

    Args:
        anns: Sequence of dicts with an ``'area'`` value and a boolean
            ``'segmentation'`` array. Nothing is drawn for an empty sequence.
    """
    if not len(anns):
        return

    # Largest masks first, so smaller ones are painted on top of them
    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    height, width = by_area[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0  # fully transparent wherever no mask is painted
    for ann in by_area:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[ann['segmentation']] = rgba
    axes.imshow(overlay)

# Draw the RGB image, then the coloured SAM masks on top of it
plt.figure(figsize=(10,10))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()


Geometric augmentation: NeRF / diffusion

In [None]:
from google.colab import files
uploaded = files.upload()  # upload parts.zip (the archive unzipped below)

!unzip parts.zip -d parts
!ls parts



Trying to get the SAM mask generator to produce masks for all of the images (this needs more GPU!).

In [None]:
# --- install deps (if not already) ---
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python matplotlib

import os, cv2, torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# --- setup SAM ---
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,        # controls granularity (16=fast, 32=finer, 64=very fine but slow)
    pred_iou_thresh=0.88,
    stability_score_thresh=0.90,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=200   # filter tiny blobs
)

# --- function to process one image ---
def process_image(img_path, save_dir="sam_output", visualize=False):
    """Generate SAM masks for one image and save them as a stacked array.

    The masks are written to ``<save_dir>/<image name>_parts.npy`` as a uint8
    array of shape (N, H, W), one binary mask per detected region. Uses the
    module-level ``mask_generator``.

    Args:
        img_path: Path of the image to segment.
        save_dir: Output folder; created when missing.
        visualize: When true, show the image with the masks overlaid.

    Returns:
        None. Images that cannot be read are skipped with a message.
    """
    os.makedirs(save_dir, exist_ok=True)
    fname = os.path.splitext(os.path.basename(img_path))[0]

    # load + convert to RGB (cv2.imread returns None when it cannot read the file)
    image = cv2.imread(img_path)
    if image is None:
        print(f"⚠️ Skipping {img_path} (not found)")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # generate masks
    masks = mask_generator.generate(image)
    print(f"{fname}: {len(masks)} masks")

    # convert to stacked (N,H,W) array; stays valid (0,H,W) when no mask was found
    H, W = image.shape[:2]
    stacked = np.zeros((len(masks), H, W), dtype=np.uint8)
    for i, m in enumerate(masks):
        stacked[i] = m["segmentation"]

    # save tensor
    np.save(os.path.join(save_dir, f"{fname}_parts.npy"), stacked)

    # optional visualization
    if visualize:
        # Paint all masks into ONE RGBA layer that is transparent outside the
        # masks. Drawing a separate full-frame layer per mask at 50% alpha
        # darkens the whole picture a little more with every mask.
        overlay = np.zeros((H, W, 4), dtype=np.float32)
        for m in masks:
            overlay[m["segmentation"]] = np.concatenate([np.random.rand(3), [0.5]])

        plt.figure(figsize=(8,8))
        plt.imshow(image)
        plt.imshow(overlay)
        plt.axis("off")
        plt.title(f"{fname}: {len(masks)} parts")
        plt.show()

# --- process a folder of images ---
# NOTE(review): other cells read parts images from /content/rl_dataset/parts;
# confirm that /content/rl/parts is the intended folder.
input_dir = "/content/rl/parts"  # change this to your folder
for file in os.listdir(input_dir):
    # Only image files are segmented; the extension check is case-insensitive
    if file.lower().endswith((".png",".jpg",".jpeg")):
        process_image(os.path.join(input_dir, file), visualize=True)
