<a href="https://colab.research.google.com/github/aina-c/Octopus-arm/blob/main/U_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!git clone https://github.com/pedropro/TACO.git



Cloning into 'TACO'...
remote: Enumerating objects: 740, done.[K
remote: Total 740 (delta 0), reused 0 (delta 0), pack-reused 740 (from 1)[K
Receiving objects: 100% (740/740), 107.77 MiB | 17.73 MiB/s, done.
Resolving deltas: 100% (498/498), done.


In [9]:

#!mkdir /content/TACO/data/images

!python3 TACO/download.py #open download and put a default /content/TACO/data/annotations.json


Note. If for any reason the connection is broken. Just call me again and I will start where I left.
Finished


In [10]:
import os
import json
import numpy as np
import cv2
from pycocotools.coco import COCO
from tqdm import tqdm


# Paths
dataset_dir = "/content/TACO/data"
annotation_file = os.path.join(dataset_dir, "annotations.json")
image_dir = os.path.join(dataset_dir, "images")
mask_dir = os.path.join(dataset_dir, "masks")

# Create directory for masks
os.makedirs(mask_dir, exist_ok=True)

# Load annotations
coco = COCO(annotation_file)
category_ids = coco.getCatIds()
categories = coco.loadCats(category_ids)
category_mapping = {cat["id"]: i for i, cat in enumerate(categories)}  # Map category ID to index

# Process images
for img_id in tqdm(coco.getImgIds()):
    img_info = coco.loadImgs(img_id)[0]
    img_path = os.path.join(image_dir, img_info["file_name"])
    mask_path = os.path.join(mask_dir, img_info["file_name"].replace(".jpg", ".png"))
    #print(img_path,img_info)

    # Load image and create blank mask
    img = cv2.imread(img_path)
    #print(img)  #seems to work but kinda sus
    h, w = img_info["height"], img_info["width"]
    mask = np.zeros((h, w), dtype=np.uint8)

    # Load annotations for this image
    ann_ids = coco.getAnnIds(imgIds=img_id)
    anns = coco.loadAnns(ann_ids)

    for ann in anns:
        cat_id = ann["category_id"]
        mask_idx = category_mapping[cat_id]  # Convert to sequential class index

        # Decode segmentation mask
        if "segmentation" in ann:
            for seg in ann["segmentation"]:
                poly = np.array(seg).reshape(-1, 2).astype(np.int32)
                cv2.fillPoly(mask, [poly], mask_idx)

    # Save mask
    cv2.imwrite(mask_path, mask)

loading annotations into memory...
Done (t=0.22s)
creating index...
index created!


100%|██████████| 1500/1500 [00:02<00:00, 551.79it/s]


In [13]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
import cv2
import os

# Define transformations
transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet normalization
    ToTensorV2()
])

# Custom dataset
class TrashSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = os.listdir(image_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.image_filenames[idx].replace(".jpg", ".png"))

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # Single-channel mask

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"].long()  # Convert to long tensor for loss calculation

        return image, mask

# Initialize dataset and dataloader
train_dataset = TrashSegmentationDataset("/content/TACO/data/", "/content/TACO/data/masks/", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)

In [None]:
!pip install segmentation-models-pytorch torch torchvision albumentations


In [15]:

import torch
import segmentation_models_pytorch as smp

NUM_CLASESS =len(category_mapping)

# Define the model
model = smp.Unet(
    encoder_name="resnet34",        # Backbone model
    encoder_weights="imagenet",     # Load pretrained ImageNet weights
    in_channels=3,                  # Number of input channels (RGB images)
    classes=10              # Number of classes for segmentation
)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 269MB/s]


Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [16]:
import torch.nn as nn
import segmentation_models_pytorch.losses as smp_losses

# Define loss functions
loss_fn = smp_losses.DiceLoss(mode="multiclass") + nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

TypeError: unsupported operand type(s) for +: 'DiceLoss' and 'CrossEntropyLoss'

In [None]:
from tqdm import tqdm

def train(model, train_loader, loss_fn, optimizer, device, epochs=10):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        loop = tqdm(train_loader, leave=True)

        for images, masks in loop:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(outputs, masks)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            loop.set_description(f"Epoch {epoch+1}/{epochs}")
            loop.set_postfix(loss=epoch_loss / len(train_loader))

    print("Training Complete.")

# Train the model
train(model, train_loader, loss_fn, optimizer, device, epochs=10)

In [None]:
# Save trained model
torch.save(model.state_dict(), "unet_trash_segmentation.pth")

# Load for inference
model.load_state_dict(torch.load("unet_trash_segmentation.pth"))
model.eval()

# Function to visualize prediction
import matplotlib.pyplot as plt
import numpy as np

def predict(image_path, model, transform, device):
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    augmented = transform(image=image)
    input_tensor = augmented["image"].unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        mask_pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title("Original Image")

    plt.subplot(1, 2, 2)
    plt.imshow(mask_pred, cmap="jet")
    plt.title("Predicted Mask")
    plt.show()

# Test on an image
predict("test_image.jpg", model, transform, device)