In [None]:
!pip install -q timm opencv-python matplotlib Pillow
!pip install datasets

import os
import requests
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import timm

In [None]:
def convert_to_binary_mask(mask_img):
    """Convert a black-marks-edits RGB mask into a single-channel binary mask.

    In the source masks, edited pixels are pure black ``(0, 0, 0)``; every other
    colour counts as unedited. The returned mask flips this so that
    255 (white) = edited and 0 = unedited.

    Args:
        mask_img: PIL image in any mode convertible to RGB.

    Returns:
        PIL image in mode "L" with values in {0, 255}.
    """
    rgb = np.asarray(mask_img.convert("RGB"))
    edited = (rgb == 0).all(axis=-1)  # True where all three channels are 0
    # A 2-D uint8 array is inferred as mode "L"; the explicit `mode=` argument of
    # Image.fromarray is deprecated (Pillow 11.3).
    return Image.fromarray(edited.astype(np.uint8) * 255)

In [None]:
def save_mask(raw_mask, index, convert_mask=False, out_dir="/content/data"):
    """Resize a mask to 512x512 and save it as ``<out_dir>/masks/mask_<index:05d>.png``.

    Args:
        raw_mask: PIL image of the mask.
        index: integer used in the zero-padded file name; it must equal the index
            of the image this mask belongs to.
        convert_mask: if True, convert with ``convert_to_binary_mask`` (source
            masks mark edits in black); otherwise only convert to mode "L".
        out_dir: dataset root. Defaults to the absolute ``/content/data`` folder
            the directory-creation cell makes, so the write no longer depends on
            the current working directory.
    """
    mask = convert_to_binary_mask(raw_mask) if convert_mask else raw_mask.convert("L")
    # NEAREST copies source pixel values, so no blended edge values are introduced.
    mask = mask.resize((512, 512), Image.NEAREST)
    mask.save(os.path.join(out_dir, "masks", f"mask_{index:05d}.png"))

In [None]:
def save_img(img, index, convert_mask=False, out_dir="/content/data"):
    """Resize an image to 512x512 RGB and save it as ``<out_dir>/images/img_<index:05d>.jpg``.

    Args:
        img: PIL image.
        index: integer used in the zero-padded file name; it must equal the index
            of the matching mask.
        convert_mask: unused; kept so existing call sites keep working.
        out_dir: dataset root. Defaults to the absolute ``/content/data`` folder
            the directory-creation cell makes, so the write no longer depends on
            the current working directory.
    """
    rgb = img.convert("RGB").resize((512, 512), Image.BILINEAR)
    rgb.save(os.path.join(out_dir, "images", f"img_{index:05d}.jpg"))

In [None]:
def save_pair(img, raw_mask, index, convert_mask=False):
    """Save an image and its mask under the same ``index``.

    Delegates to ``save_img`` and ``save_mask`` instead of duplicating their
    resize/convert/save logic, so the three helpers cannot drift apart.

    Args:
        img: PIL image.
        raw_mask: PIL mask image.
        index: shared zero-padded index for both files.
        convert_mask: forwarded to ``save_mask``.
    """
    save_img(img, index)
    save_mask(raw_mask, index, convert_mask=convert_mask)

In [None]:
from itertools import islice
from datasets import load_dataset
import os
from PIL import Image
from tqdm import tqdm


# HumanEdit training split; rows are read below via their "INPUT_IMG" and "MASK_IMG" fields.
dataset = load_dataset("BryanW/HumanEdit", split="train")


In [None]:
# Output folders for the training set.
for sub in ("images", "masks"):
    os.makedirs(os.path.join("/content/data", sub), exist_ok=True)

# HumanEdit masks mark edited pixels in black, hence convert_mask=True.
for idx, row in enumerate(tqdm(dataset)):
    save_pair(row["INPUT_IMG"], row["MASK_IMG"], idx, convert_mask=True)

In [None]:
# Free the HumanEdit rows; the name `dataset` is reused later for the training Dataset.
del dataset

In [None]:
# First 4300 PIPE rows for training (rows 4300-4999 are used in the Testing section).
# list(...) materialises all 4300 rows in memory at once.
# NOTE(review): assumes PIPE and PIPE_Masks list rows in the same order, so that
# row i's "target_img" matches row i's "mask" — TODO confirm.
pipe_dataset_stream = load_dataset("paint-by-inpaint/PIPE", split="train", streaming=True)
pipe_dataset = list(islice(pipe_dataset_stream, 4300))

In [None]:
# Number PIPE images after the files already on disk (the HumanEdit pairs), so they
# get the same indices as the PIPE masks saved with the same offset further down.
# Not idempotent: re-running this cell appends another batch with new indices.
counter = len(os.listdir("/content/data/images"))
for i, entry in enumerate(tqdm(pipe_dataset)):
    save_img(entry["target_img"], counter + i)

In [None]:
# Release the materialised PIPE rows and the stream handle.
del pipe_dataset, pipe_dataset_stream

In [None]:
# First 4300 PIPE_Masks rows, intended to pair row-for-row with the PIPE images above.
mask_dataset_stream = load_dataset("paint-by-inpaint/PIPE_Masks", split="train", streaming=True)
mask_dataset = list(islice(mask_dataset_stream, 4300))

In [None]:
# Masks take their offset from the number of files already in data/masks. This
# equals the image offset only if both folders held the same number of files
# before the PIPE cells ran.
# NOTE(review): convert_mask is left False, i.e. PIPE masks are assumed to already
# use white = edited — TODO confirm.
counter = len(os.listdir("/content/data/masks"))
print(counter)
for i, entry in enumerate(tqdm(mask_dataset)):
    save_mask(entry["mask"], counter + i)

In [None]:
# Release the materialised PIPE_Masks rows and the stream handle.
del mask_dataset, mask_dataset_stream

In [None]:
# Sanity check: training pairs are matched by index, so every image needs a mask.
n_images = len(os.listdir("/content/data/images"))
n_masks = len(os.listdir("/content/data/masks"))
print(n_images)
print(n_masks)
assert n_images == n_masks, f"{n_images} images vs {n_masks} masks: pairs are misaligned"

In [None]:
import torch.nn.functional as F

class StabilityPredictor(nn.Module):
    """Predicts a per-pixel "instability" (edited-region) probability map.

    A timm backbone (ConvNeXt-Tiny by default) is used as a feature extractor. Only
    its deepest feature map is decoded: a small conv head, a 2x upsample, a sigmoid,
    then a bilinear resize to ``output_size``.

    Args:
        output_size: ``(H, W)`` of the returned map. Defaults to ``(256, 256)``, the
            size the training masks are resized to. ``None`` returns a map with the
            same spatial size as each input.
        backbone: timm model name used as the encoder.
        pretrained: forwarded to ``timm.create_model``.

    Submodule names (``encoder``, ``decoder``) and the decoder layout are unchanged,
    so existing ``state_dict`` checkpoints for the default backbone still load.
    """

    def __init__(self, output_size=(256, 256), backbone="convnext_tiny", pretrained=True):
        super().__init__()
        self.output_size = output_size
        self.encoder = timm.create_model(backbone, pretrained=pretrained, features_only=True)
        # Channels of the deepest feature map (768 for convnext_tiny).
        in_channels = self.encoder.feature_info.channels()[-1]
        self.decoder = nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map images ``[B, 3, H, W]`` to probabilities ``[B, 1, *output_size]``."""
        out_size = tuple(x.shape[-2:]) if self.output_size is None else self.output_size
        feats = self.encoder(x)[-1]   # deepest stage, e.g. [B, 768, 8, 8] for a 256x256 input
        probs = self.decoder(feats)   # 2x upsampled, e.g. [B, 1, 16, 16] for a 256x256 input
        return F.interpolate(probs, size=out_size, mode='bilinear', align_corners=False)

In [None]:
class StabilityDataset(Dataset):
    """Image / edit-mask pairs read from two folders.

    Files are paired by the index in their names (``img_00042.jpg`` <->
    ``mask_00042.png``). Each item is ``(image, mask)``: ``ToTensor`` output of the
    RGB image and of the single-channel mask, both resized to ``size``.

    Args:
        image_dir: folder of images.
        mask_dir: folder of masks.
        size: ``(H, W)`` both are resized to; defaults to 256x256.
    """

    def __init__(self, image_dir, mask_dir, size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_paths, self.mask_paths = self._pair_files(
            os.listdir(image_dir), os.listdir(mask_dir)
        )
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor()
        ])
        # Masks are resized with NEAREST so they keep their discrete label values;
        # the default bilinear filter would blend region borders into fractional labels.
        self.mask_transform = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

    @staticmethod
    def _pair_files(image_names, mask_names):
        """Align image and mask file names on their shared index.

        The key is the file stem after its first ``_`` (``img_00042.jpg`` and
        ``mask_00042.png`` both map to ``00042``). Hidden entries (names starting
        with ``.``) and files without a counterpart are skipped, so one missing
        file cannot shift every later pair. Returns two equal-length lists in
        sorted image-name order.
        """
        def index_key(name):
            return os.path.splitext(name)[0].split("_", 1)[-1]

        masks_by_key = {index_key(n): n for n in mask_names if not n.startswith(".")}
        images = [n for n in sorted(image_names)
                  if not n.startswith(".") and index_key(n) in masks_by_key]
        masks = [masks_by_key[index_key(n)] for n in images]
        return images, masks

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_paths[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_paths[idx])

        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        return self.transform(image), self.mask_transform(mask)

# Training set from the saved pairs, served in shuffled batches of 4.
dataset = StabilityDataset("/content/data/images", "/content/data/masks")
loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = StabilityPredictor().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.BCELoss(reduction='none')  # per-pixel loss, weighted below

NUM_EPOCHS = 30
# Extra BCE weight on edited (mask = 1) pixels; presumably counters class imbalance
# (edited regions cover a minority of pixels) — confirm.
EDITED_PIXEL_WEIGHT = 12

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        preds = model(imgs)

        # Weight 1 on unedited pixels, 1 + EDITED_PIXEL_WEIGHT on edited ones.
        weights = masks * EDITED_PIXEL_WEIGHT + 1
        weighted_loss = (loss_fn(preds, masks) * weights).mean()

        optimizer.zero_grad()
        weighted_loss.backward()
        optimizer.step()
        total_loss += weighted_loss.item()
    print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

# Always keep a local copy: the Drive folder only exists once Drive is mounted, and
# torch.save into a missing directory raises, which would lose the trained weights.
LOCAL_CKPT = "/content/stability_predictor.pth"
DRIVE_CKPT = "/content/drive/MyDrive/stability_predictor.pth"
torch.save(model.state_dict(), LOCAL_CKPT)
if os.path.isdir(os.path.dirname(DRIVE_CKPT)):
    torch.save(model.state_dict(), DRIVE_CKPT)
else:
    print(f"Google Drive is not mounted; weights saved only to {LOCAL_CKPT}.")

In [None]:
# Mount Google Drive at /content/drive.
# NOTE(review): the training cell above already saves to /content/drive/MyDrive,
# so in a top-to-bottom run this mount should come before it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Persist the weights trained above to Drive (requires the Drive mount cell).
# Do NOT construct a new StabilityPredictor() here: a fresh instance has an
# untrained decoder and would overwrite the checkpoint with useless weights.
torch.save(model.state_dict(), "/content/drive/MyDrive/stability_predictor.pth")

In [None]:
# Redundant: same mount as the earlier cell. Presumably a no-op when Drive is
# already mounted — kept only so this cell can be run on its own.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Quick visual check on one image file, resized to the 256x256 training resolution.
model.eval()
img_path = '/content/img_00952_wm.png'

img = Image.open(img_path).convert("RGB")
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
img = transform(img)

with torch.no_grad():
    pred = model(img.unsqueeze(0).to(device)).squeeze().cpu().numpy()

fig, (ax_img, ax_pred) = plt.subplots(1, 2, figsize=(10, 4))
ax_img.imshow(img.permute(1, 2, 0))
ax_img.set_title("Input Image")
ax_img.axis("off")

# Plot the raw probabilities so the colorbar shows the model's real 0-1 scale
# (imshow auto-scales colours, so multiplying the map only mislabels the colorbar).
shown = ax_pred.imshow(pred)
ax_pred.set_title("Predicted Edit Probability (high = unstable)")
ax_pred.axis("off")
fig.colorbar(shown, ax=ax_pred)
plt.show()

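## Testing

Evaluate the saved checkpoint on held-out PIPE rows 4300–4999 (the training cells used the first 4300 rows).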

In [None]:
!pip install -q timm opencv-python matplotlib Pillow
!pip install datasets
import torch.nn.functional as F
from itertools import islice
from datasets import load_dataset
import os
from PIL import Image
from tqdm import tqdm
import requests
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import timm

In [None]:
# Mount Drive so this section can load the checkpoint saved under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def save_mask_test(raw_mask, index, convert_mask=False, out_dir="/content/test"):
    """Resize a held-out mask to 512x512 and save it as ``<out_dir>/masks/mask_<index:05d>.png``.

    Args:
        raw_mask: PIL image of the mask.
        index: zero-padded file index; must equal the matching test image's index.
        convert_mask: if True, turn black (edited) pixels into white 255 and all
            others into 0; otherwise only convert to mode "L".
        out_dir: test-set root. Defaults to the absolute folder the makedirs cell
            creates, so the write no longer depends on the working directory.
    """
    if convert_mask:
        # Inlined black -> white conversion: this section does not define
        # convert_to_binary_mask, which would raise NameError in a fresh kernel.
        rgb = np.asarray(raw_mask.convert("RGB"))
        mask = Image.fromarray((rgb == 0).all(axis=-1).astype(np.uint8) * 255)
    else:
        mask = raw_mask.convert("L")
    mask = mask.resize((512, 512), Image.NEAREST)
    mask.save(os.path.join(out_dir, "masks", f"mask_{index:05d}.png"))

In [None]:
def save_img_test(img, index, convert_mask=False, out_dir="/content/test"):
    """Resize a held-out image to 512x512 RGB and save it as ``<out_dir>/images/img_<index:05d>.jpg``.

    Args:
        img: PIL image.
        index: zero-padded file index; must equal the matching test mask's index.
        convert_mask: unused; kept so existing call sites keep working.
        out_dir: test-set root. Defaults to the absolute folder the makedirs cell
            creates, so the write no longer depends on the working directory.
    """
    rgb = img.convert("RGB").resize((512, 512), Image.BILINEAR)
    rgb.save(os.path.join(out_dir, "images", f"img_{index:05d}.jpg"))

In [None]:
# Output folders for the held-out test split.
for sub in ("images", "masks"):
    os.makedirs(os.path.join("/content/test", sub), exist_ok=True)

In [None]:
from datasets import load_dataset

In [None]:
# Held-out PIPE rows 4300-4999 (training used rows 0-4299). islice still iterates
# over the first 4300 streamed rows in order to skip them.
pipe_dataset_stream = load_dataset("paint-by-inpaint/PIPE", split="train", streaming=True)
pipe_dataset_test = list(islice(pipe_dataset_stream, 4300, 5000))

In [None]:
# Held-out images are indexed from 0; the masks below use the same indices.
for i, entry in enumerate(tqdm(pipe_dataset_test)):
    save_img_test(entry["target_img"], i)

In [None]:
# Release the held-out PIPE rows and the stream handle.
del pipe_dataset_stream, pipe_dataset_test

In [None]:
# Held-out PIPE_Masks rows 4300-4999, intended to pair row-for-row with the test images.
mask_dataset_stream = load_dataset("paint-by-inpaint/PIPE_Masks", split="train", streaming=True)
mask_dataset_test = list(islice(mask_dataset_stream, 4300, 5000))

In [None]:
# Held-out masks are indexed from 0, matching the test images saved above.
for i, entry in enumerate(tqdm(mask_dataset_test)):
    save_mask_test(entry["mask"], i)

In [None]:
# Release the held-out PIPE_Masks rows and the stream handle.
del mask_dataset_stream, mask_dataset_test

In [None]:
# Sanity check: test pairs are matched by index, so every image needs a mask.
n_test_images = len(os.listdir("/content/test/images"))
n_test_masks = len(os.listdir("/content/test/masks"))
print(n_test_images)
print(n_test_masks)
assert n_test_images == n_test_masks, (
    f"{n_test_images} test images vs {n_test_masks} test masks: pairs are misaligned"
)

In [None]:
import torch.nn.functional as F

class StabilityPredictor(nn.Module):
    """Predicts a per-pixel "instability" (edited-region) probability map.

    A timm backbone (ConvNeXt-Tiny by default) is used as a feature extractor. Only
    its deepest feature map is decoded: a small conv head, a 2x upsample, a sigmoid,
    then a bilinear resize to ``output_size``.

    Args:
        output_size: ``(H, W)`` of the returned map. Defaults to ``(256, 256)``, the
            size the training masks are resized to. ``None`` returns a map with the
            same spatial size as each input.
        backbone: timm model name used as the encoder.
        pretrained: forwarded to ``timm.create_model``.

    Submodule names (``encoder``, ``decoder``) and the decoder layout are unchanged,
    so existing ``state_dict`` checkpoints for the default backbone still load.
    """

    def __init__(self, output_size=(256, 256), backbone="convnext_tiny", pretrained=True):
        super().__init__()
        self.output_size = output_size
        self.encoder = timm.create_model(backbone, pretrained=pretrained, features_only=True)
        # Channels of the deepest feature map (768 for convnext_tiny).
        in_channels = self.encoder.feature_info.channels()[-1]
        self.decoder = nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map images ``[B, 3, H, W]`` to probabilities ``[B, 1, *output_size]``."""
        out_size = tuple(x.shape[-2:]) if self.output_size is None else self.output_size
        feats = self.encoder(x)[-1]   # deepest stage, e.g. [B, 768, 8, 8] for a 256x256 input
        probs = self.decoder(feats)   # 2x upsampled, e.g. [B, 1, 16, 16] for a 256x256 input
        return F.interpolate(probs, size=out_size, mode='bilinear', align_corners=False)

In [None]:
class StabilityTestDataset(Dataset):
    """Held-out image / edit-mask pairs read from two folders.

    Files are paired by the index in their names (``img_00042.jpg`` <->
    ``mask_00042.png``). Each item is ``(image, mask)``: ``ToTensor`` output of the
    RGB image and of the single-channel mask, both resized to ``size``.

    Args:
        image_dir: folder of images.
        mask_dir: folder of masks.
        size: ``(H, W)`` both are resized to; defaults to 256x256.
    """

    def __init__(self, image_dir, mask_dir, size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_paths, self.mask_paths = self._pair_files(
            os.listdir(image_dir), os.listdir(mask_dir)
        )
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor()
        ])
        # Masks are resized with NEAREST so they keep their discrete label values;
        # the default bilinear filter would blend region borders into fractional labels.
        self.mask_transform = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

    @staticmethod
    def _pair_files(image_names, mask_names):
        """Align image and mask file names on their shared index.

        The key is the file stem after its first ``_`` (``img_00042.jpg`` and
        ``mask_00042.png`` both map to ``00042``). Hidden entries (names starting
        with ``.``) and files without a counterpart are skipped, so one missing
        file cannot shift every later pair. Returns two equal-length lists in
        sorted image-name order.
        """
        def index_key(name):
            return os.path.splitext(name)[0].split("_", 1)[-1]

        masks_by_key = {index_key(n): n for n in mask_names if not n.startswith(".")}
        images = [n for n in sorted(image_names)
                  if not n.startswith(".") and index_key(n) in masks_by_key]
        masks = [masks_by_key[index_key(n)] for n in images]
        return images, masks

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        with Image.open(os.path.join(self.image_dir, self.image_paths[idx])) as image_file:
            image = image_file.convert("RGB")
        with Image.open(os.path.join(self.mask_dir, self.mask_paths[idx])) as mask_file:
            mask = mask_file.convert("L")
        return self.transform(image), self.mask_transform(mask)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model. map_location lets a checkpoint saved on a GPU runtime load on a
# CPU-only one; weights_only restricts unpickling to tensors/plain containers.
model = StabilityPredictor().to(device)
state_dict = torch.load(
    "/content/drive/MyDrive/stability_predictor.pth",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()

# Load test data (fixed order so sample indices are reproducible).
test_dataset = StabilityTestDataset("/content/test/images", "/content/test/masks")
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import torch

# Visual check on one held-out sample: input, ground-truth edit mask, and the
# model's top-50% highest-scoring ("unstable") pixels.
SAMPLE_IDX = 416

img, mask = test_dataset[SAMPLE_IDX]
with torch.no_grad():
    pred = model(img.unsqueeze(0).to(device)).squeeze().cpu().numpy()

# Flag the k highest-scoring pixels (k = half the image) as unstable. High
# scores mean "edited", so 1 in binary_mask marks an *unstable* pixel.
k = pred.size // 2
threshold = np.partition(pred.ravel(), -k)[-k]  # k-th largest score
binary_mask = (pred >= threshold).astype(float)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (img.permute(1, 2, 0), None, "Input Image"),
    (mask.squeeze(), "gray", "Ground Truth Mask"),
    (binary_mask, "gray", "Binary Stability Mask (1=Unstable)"),
]
for ax, (data, cmap, title) in zip(axes, panels):
    ax.imshow(data, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")

fig.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import torch

# Run the model on an arbitrary image file. The image is resized to 256x256 to
# match the training preprocessing; the model's output map is 256x256 regardless.
img_path = '/content/img_00952_wm.png'

img = Image.open(img_path).convert("RGB")
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
img = transform(img)

with torch.no_grad():
    pred = model(img.unsqueeze(0).to(device)).squeeze().cpu().numpy()

fig, (ax_img, ax_pred) = plt.subplots(1, 2, figsize=(8, 4))
ax_img.imshow(img.permute(1, 2, 0))
ax_img.set_title("Input Image")
ax_img.axis("off")

# Continuous probabilities (not a binary mask): brighter = more likely edited/unstable.
shown = ax_pred.imshow(pred, cmap="gray")
ax_pred.set_title("Predicted Edit Probability (high = unstable)")
ax_pred.axis("off")
fig.colorbar(shown, ax=ax_pred)

fig.tight_layout()
plt.show()

In [None]:
# Recall on edited pixels when the top 50% highest-scoring pixels of each image
# are flagged as "unstable". NOTE: flagging half the pixels at random already
# gives ~0.5 expected recall, so compare the result against that baseline.
total_correct_unstable = 0.0
total_edited = 0.0

model.eval()
with torch.no_grad():
    for imgs, masks in test_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        preds = model(imgs)

        B, _, H, W = preds.shape
        preds_flat = preds.reshape(B, -1)

        k = (H * W) // 2
        _, topk_indices = torch.topk(preds_flat, k=k, largest=True, dim=1)  # highest = most unstable

        binary_mask_flat = torch.zeros_like(preds_flat)
        binary_mask_flat.scatter_(1, topk_indices, 1.0)
        predicted_unstable = binary_mask_flat.view(B, 1, H, W).bool()

        # Threshold instead of `== 1`: masks resized with an interpolating filter
        # have fractional values on region borders, which `== 1` would drop.
        edited = masks > 0.5

        total_correct_unstable += (predicted_unstable & edited).sum().item()
        total_edited += edited.sum().item()

recall_on_edited = total_correct_unstable / (total_edited + 1e-8)

print(f"Edited region correctly predicted as unstable (recall): {recall_on_edited:.2%}")
