In [5]:
import os
import json
import numpy as np
import cv2
from tqdm import tqdm
from PIL import Image
from skimage.draw import polygon

def create_multiclass_mask(image_info, annotations, mask_output_dir, category_map):
    height, width = image_info['height'], image_info['width']
    mask = np.zeros((height, width), dtype=np.uint8)
    image_id = image_info['id']

    for ann in annotations:
        if ann['image_id'] != image_id:
            continue

        category_id = ann['category_id']
        if category_id not in category_map:
            continue

        category_value = category_map[category_id]
        seg = ann.get('segmentation', [])

        if not seg or not isinstance(seg, list):
            continue

        for poly in seg:
            if len(poly) < 6:
                continue  # invalid polygon
            try:
                poly = np.array(poly).reshape((-1, 2))
                rr, cc = polygon(poly[:, 1], poly[:, 0], shape=mask.shape)
                mask[rr, cc] = category_value
            except Exception as e:
                print(f"Error processing polygon: {e}")
                continue

    # Save with same name as image, but in mask folder
    file_stem = os.path.splitext(image_info['file_name'])[0]
    mask_filename = f"{file_stem}.png"
    Image.fromarray(mask).save(os.path.join(mask_output_dir, mask_filename))

def main(json_path, image_dir, mask_output_dir, limit=6000):
    with open(json_path, 'r') as f:
        coco = json.load(f)

    os.makedirs(mask_output_dir, exist_ok=True)

    images = coco['images'][:limit]
    annotations = coco['annotations']
    categories = coco['categories']

    # category_id → consecutive integers starting from 1
    category_map = {cat['id']: idx + 1 for idx, cat in enumerate(categories)}
    print(f"Category map: {category_map}")

    for image_info in tqdm(images, desc="Processing COCO images"):
        create_multiclass_mask(image_info, annotations, mask_output_dir, category_map)

if __name__ == '__main__':
    json_path = 'annotations/instances_train2017.json'
    image_dir = 'train2017'
    mask_output_dir = 'masks'
    main(json_path, image_dir, mask_output_dir, limit=8000)


Category map: {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 27: 25, 28: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31, 36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 44: 40, 46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 65: 60, 67: 61, 70: 62, 72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71, 81: 72, 82: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 90: 80}


Processing COCO images: 100%|██████████| 8000/8000 [11:15:38<00:00,  5.07s/it]      


In [11]:
import os
import json
import cv2
import numpy as np
from PIL import Image
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


In [19]:
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.psi(F.relu(g1 + x1))
        return x * psi


class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=91):
        super(UNet, self).__init__()

        self.down1 = ConvBlock(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.down2 = ConvBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.down3 = ConvBlock(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.down4 = ConvBlock(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        self.bridge = ConvBlock(512, 1024)

        self.att4 = AttentionBlock(512, 512, 256)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_conv4 = ConvBlock(1024, 512)

        self.att3 = AttentionBlock(256, 256, 128)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up_conv3 = ConvBlock(512, 256)

        self.att2 = AttentionBlock(128, 128, 64)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_conv2 = ConvBlock(256, 128)

        self.att1 = AttentionBlock(64, 64, 32)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up_conv1 = ConvBlock(128, 64)

        self.final = nn.Conv2d(64, out_ch, kernel_size=1)

    def forward(self, x):
        c1 = self.down1(x)
        p1 = self.pool1(c1)

        c2 = self.down2(p1)
        p2 = self.pool2(c2)

        c3 = self.down3(p2)
        p3 = self.pool3(c3)

        c4 = self.down4(p3)
        p4 = self.pool4(c4)

        bridge = self.bridge(p4)

        up4 = self.up4(bridge)
        att4 = self.att4(g=up4, x=c4)
        merge4 = torch.cat([up4, att4], dim=1)
        c_up4 = self.up_conv4(merge4)

        up3 = self.up3(c_up4)
        att3 = self.att3(g=up3, x=c3)
        merge3 = torch.cat([up3, att3], dim=1)
        c_up3 = self.up_conv3(merge3)

        up2 = self.up2(c_up3)
        att2 = self.att2(g=up2, x=c2)
        merge2 = torch.cat([up2, att2], dim=1)
        c_up2 = self.up_conv2(merge2)

        up1 = self.up1(c_up2)
        att1 = self.att1(g=up1, x=c1)
        merge1 = torch.cat([up1, att1], dim=1)
        c_up1 = self.up_conv1(merge1)

        return self.final(c_up1)


In [15]:
class COCOSegmentation(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_paths = sorted(glob(os.path.join(image_dir, "*.jpg")))
        self.mask_paths = sorted(glob(os.path.join(mask_dir, "*.png")))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx])

        image = np.array(image)
        mask = np.array(mask, dtype=np.int64)

        if self.transform:
            aug = self.transform(image=image, mask=mask)
            image = aug['image']
            mask = aug['mask']

        image = T.ToTensor()(image)
        return image, torch.tensor(mask, dtype=torch.long)


In [20]:
def train_model(model, dataloader, criterion, optimizer, device, epochs=5):
    model.to(device)
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for images, masks in tqdm(dataloader):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader):.4f}")


In [None]:
# Paths to your dataset (images and masks generated previously)
image_dir = "train2017"
mask_dir = "masks"

# Transforms
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])

# Dataset and DataLoader
dataset = COCOSegmentation(image_dir, mask_dir, transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# Model, loss, optimizer
model = UNet(in_channels=3, num_classes=91).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 25

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, masks in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1} loss: {epoch_loss / len(dataloader):.4f}")


  0%|          | 0/3442 [00:00<?, ?it/s]

In [None]:
# Inference on a single batch
model.eval()
with torch.no_grad():
    for images, masks in dataloader:
        images = images.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1).cpu().numpy()
        break  # Show only first batch

# Visualize results
for i in range(4):
    plt.figure(figsize=(10, 3))
    plt.subplot(1, 3, 1)
    plt.imshow(images[i].permute(1, 2, 0).cpu())
    plt.title("Input Image")

    plt.subplot(1, 3, 2)
    plt.imshow(masks[i].cpu())
    plt.title("Ground Truth")

    plt.subplot(1, 3, 3)
    plt.imshow(preds[i])
    plt.title("Prediction")

    plt.tight_layout()
    plt.show()
