In [1]:
import os
# Set before torch is imported / CUDA is initialised further down the notebook.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (synchronous
# kernel launches give accurate stack traces) and tends to slow training —
# confirm it is still wanted for full training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Mount Google Drive; the dataset archive and model checkpoints used by later
# cells live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install wandb



In [3]:
# !cp /content/drive/MyDrive/OSCD.zip /content/
# !unzip /content/OSCD.zip -d /content/OSCD/

In [4]:
import zipfile

zip_file_path = '/content/drive/MyDrive/OSCD.zip'
dataset_folder = '/content/drive/MyDrive/OSCD/'

# Make sure the target folder exists: os.listdir() raises FileNotFoundError on
# a missing directory, which is exactly the situation on a first run.
os.makedirs(dataset_folder, exist_ok=True)

# Only extract when the folder is still empty so re-running the cell is cheap.
if not os.listdir(dataset_folder):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(dataset_folder)

    print(f"Unzipped to: {dataset_folder}")

In [5]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# All one-class carton data lives under a common root inside the dataset folder.
_oneclass_root = os.path.join(dataset_folder, 'coco_carton/oneclass_carton')

# Image directories for the two splits.
train_folder = os.path.join(_oneclass_root, 'images/train2017')
val_folder = os.path.join(_oneclass_root, 'images/val2017')

# COCO-format annotation files for the two splits.
train_annotation = os.path.join(_oneclass_root, 'annotations/instances_train2017.json')
val_annotation = os.path.join(_oneclass_root, 'annotations/instances_val2017.json')

In [6]:
import torch
from torchvision.datasets import CocoDetection
from PIL import Image
from torchvision.tv_tensors import Mask
from torchvision.transforms.functional import to_tensor
from skimage.draw import polygon as sk_polygon
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

class OSCDDataset(CocoDetection):
    """COCO-format instance-segmentation dataset producing Mask R-CNN style targets.

    Each item is ``(image, target)``. For annotated images ``target`` is a dict
    with the keys ``boxes`` (XYXY ``BoundingBoxes``), ``labels``, ``image_id``,
    ``ids``, ``area``, ``iscrowd`` and ``masks`` (one boolean mask per
    instance). For images without annotations the raw image and an empty dict
    are returned, without applying the transforms, so that the collate
    function can drop the sample.

    Args:
        img_folder: directory containing the images.
        ann_file: path to the COCO annotation JSON.
        transforms: optional callable ``(image, target) -> (image, target)``
            applied jointly to annotated samples.
    """

    def __init__(self, img_folder, ann_file, transforms=None):
        # The parent receives no transforms: they are applied in __getitem__
        # after the targets have been wrapped as tv_tensors.
        super().__init__(img_folder, ann_file, transforms=None)
        self._transforms = transforms

    def __len__(self) -> int:
        return super().__len__()

    def __getitem__(self, idx):
        img, anns = super().__getitem__(idx)

        if not anns:
            # target = {
            #   'boxes': torch.empty((0, 4), dtype=torch.float32),
            #   'labels': torch.empty(0, dtype=torch.int64),
            #   'masks': torch.empty((0, img.size[1], img.size[0]), dtype=torch.bool),
            #   'area': torch.empty(0, dtype=torch.float32),
            #   'iscrowd': torch.zeros(0, dtype=torch.int64),
            # }
            # Unannotated image: hand back an empty target (untransformed).
            return img, {}

        height, width = F.get_size(img)

        boxes, labels, areas, iscrowd, masks, ids = [], [], [], [], [], []
        for ann in anns:
            # COCO boxes are [x, y, w, h]; keep them as floats instead of
            # truncating every component to int, which shifts box edges.
            x, y, w, h = ann['bbox']
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])
            areas.append(ann['area'])
            iscrowd.append(ann['iscrowd'])
            masks.append(self.get_mask(ann['segmentation'], height, width))
            ids.append(ann['id'])

        boxes = tv_tensors.BoundingBoxes(
            torch.tensor(boxes, dtype=torch.float32),
            format="XYXY",
            canvas_size=(height, width),
        )
        # Every annotation of one image carries the same image_id; read it
        # explicitly rather than from a loop variable that leaked out of the for.
        img_id = torch.tensor(anns[0]['image_id'], dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": torch.tensor(labels, dtype=torch.int64),
            "image_id": img_id,
            "ids": torch.tensor(ids, dtype=torch.int64),
            # float32, not float16: pixel areas above 65504 overflow to inf in half precision.
            "area": torch.tensor(areas, dtype=torch.float32),
            "iscrowd": torch.tensor(iscrowd, dtype=torch.uint8),
            "masks": Mask(torch.stack(masks, dim=0)),
        }
        img = tv_tensors.Image(img)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target

    def get_mask(self, segmentation, height, width):
        """Rasterise a COCO polygon segmentation into a boolean ``(height, width)`` mask.

        Args:
            segmentation: list of polygons, each a flat ``[x0, y0, x1, y1, ...]`` list.
            height: mask height in pixels.
            width: mask width in pixels.

        Returns:
            ``torch.bool`` tensor that is True inside any of the polygons.

        Raises:
            TypeError: if ``segmentation`` is not a list of polygons (e.g. an
                RLE dict), which this rasteriser does not support.
        """
        if not isinstance(segmentation, (list, tuple)):
            raise TypeError(
                f"Only polygon segmentations are supported, got {type(segmentation).__name__}"
            )
        mask = torch.zeros((height, width), dtype=torch.bool)
        # An instance may consist of several disjoint polygons; draw all of
        # them, not just the first one.
        for poly in segmentation:
            poly_x = poly[::2]
            poly_y = poly[1::2]
            rr, cc = sk_polygon(poly_y, poly_x, shape=(height, width))
            mask[rr, cc] = 1
        return mask


In [7]:
from torchvision.transforms import v2

def collate_fn(batch):
    """Split a batch of ``(image, target)`` pairs into parallel lists.

    Pairs whose target is empty (falsy) are dropped, so the returned lists may
    be shorter than ``batch`` — or empty altogether.
    """
    kept = [(image, target) for image, target in batch if target]
    images = [image for image, _ in kept]
    targets = [target for _, target in kept]
    return images, targets

def get_transforms(train=False):
  """Build the v2 transform pipeline applied jointly to image and target.

  With ``train=True`` the pipeline starts with random zoom-out and random
  perspective augmentation followed by bounding-box sanitisation. In both
  modes it ends by converting to float32 (with value scaling) and to pure
  tensors.
  """
  augmentations = [
      v2.RandomZoomOut(p=0.7, side_range=(1.0, 1.8), fill=0),
      v2.RandomPerspective(distortion_scale=0.5, p=0.5),
      v2.SanitizeBoundingBoxes(),
  ] if train else []

  finalize = [
      v2.ToDtype(torch.float32, scale=True),
      v2.ToPureTensor(),
  ]
  return v2.Compose(augmentations + finalize)

# Full datasets: augmentation is enabled for the training split only.
train_dataset = OSCDDataset(
    img_folder=train_folder,
    ann_file=train_annotation,
    transforms=get_transforms(train=True),
)
val_dataset = OSCDDataset(
    img_folder=val_folder,
    ann_file=val_annotation,
    transforms=get_transforms(train=False),
)

# Small fixed-prefix subsets (first 2000 / first 100 samples) for quick experiments.
train_dataset_small = torch.utils.data.Subset(train_dataset, indices=list(range(2000)))
val_dataset_small = torch.utils.data.Subset(val_dataset, indices=list(range(100)))

loading annotations into memory...
Done (t=3.22s)
creating index...
index created!
loading annotations into memory...
Done (t=0.48s)
creating index...
index created!


In [8]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.rpn import RegionProposalNetwork
from torchvision.models.detection.rpn import AnchorGenerator


def create_model(num_classes=2,
                 rpn_fg_iou_thresh=0.7,
                 rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256,
                 rpn_positive_fraction=0.5,
                 rpn_nms_thresh=0.7,
                 rpn_pre_nms_top_n_train=2000,
                 rpn_pre_nms_top_n_test=1000,
                 rpn_post_nms_top_n_train=2000,
                 rpn_post_nms_top_n_test=1000,
                 rpn_score_thresh=0
                 ):
  """Create a Mask R-CNN (ResNet-50 FPN) with new heads and a reconfigured RPN.

  The model is loaded with its default pretrained weights; the box and mask
  predictors are replaced by fresh ones sized for ``num_classes``. The region
  proposal network is rebuilt around the existing anchor generator and RPN
  head so that its sampling / NMS hyper-parameters can be set by the caller.

  Args:
      num_classes: number of output classes of the new predictors.
      rpn_fg_iou_thresh / rpn_bg_iou_thresh: IoU thresholds forwarded to the RPN.
      rpn_batch_size_per_image / rpn_positive_fraction: RPN sampling settings.
      rpn_nms_thresh: NMS threshold forwarded to the RPN.
      rpn_pre_nms_top_n_train / rpn_pre_nms_top_n_test: proposals kept before NMS.
      rpn_post_nms_top_n_train / rpn_post_nms_top_n_test: proposals kept after NMS.
      rpn_score_thresh: score threshold forwarded to the RPN.

  Returns:
      The configured model.
  """
  model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

  # Replace the box classification/regression head.
  in_features = model.roi_heads.box_predictor.cls_score.in_features
  model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

  # Replace the mask head (256 hidden channels).
  in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
  model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, 256, num_classes)
  old_rpn = model.rpn

  pre_nms_top_n = {"training": rpn_pre_nms_top_n_train, "testing": rpn_pre_nms_top_n_test}
  # Fixed: this dict used to be filled with the *pre*-NMS values, which made
  # the rpn_post_nms_top_n_* arguments dead parameters.
  post_nms_top_n = {"training": rpn_post_nms_top_n_train, "testing": rpn_post_nms_top_n_test}

  # anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
  # aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
  # anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
  # Reuse the existing anchor generator and RPN head; only the RPN
  # hyper-parameters change.
  anchor_generator = old_rpn.anchor_generator
  new_rpn = RegionProposalNetwork(
          anchor_generator=anchor_generator,
          head=old_rpn.head,
          fg_iou_thresh=rpn_fg_iou_thresh,
          bg_iou_thresh=rpn_bg_iou_thresh,
          batch_size_per_image=rpn_batch_size_per_image,
          positive_fraction=rpn_positive_fraction,
          nms_thresh=rpn_nms_thresh,
          post_nms_top_n=post_nms_top_n,
          pre_nms_top_n=pre_nms_top_n,
          score_thresh=rpn_score_thresh
          )
  model.rpn = new_rpn
  return model

In [9]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [10]:
import math
from tqdm.auto import tqdm

from torch.amp import autocast, GradScaler

def run_epoch(model, dataloader, optimizer, device, scaler, is_training):
    """Run one pass over ``dataloader`` and return the mean of every loss term.

    Args:
        model: detection model whose ``model(images, targets)`` call returns a
            dict of loss tensors.
        dataloader: yields ``(images, targets)`` lists; batches whose target
            list is empty are skipped.
        optimizer: optimizer stepped when ``is_training`` is True (unused otherwise).
        device: ``torch.device`` the batch is moved to; its type is also used for autocast.
        scaler: optional gradient scaler; when falsy a plain backward/step is done.
        is_training: when False, the forward pass runs under ``torch.no_grad()``
            and no optimisation step is taken.

    Returns:
        Dict mapping each loss name to its average over the processed batches
        (all zeros when no batch was processed).

    Raises:
        AssertionError: if the training loss becomes NaN or infinite.
    """
    # The model stays in train mode for validation too, because the forward
    # pass must return the loss dict (torchvision detection models only do so
    # in training mode).
    model.train()
    epoch_total_loss = 0.
    epoch_losses = {
      'loss_classifier': 0.,
      'loss_box_reg': 0.,
      'loss_mask': 0.,
      'loss_objectness': 0.,
      'loss_rpn_box_reg': 0.}
    num_batches = 0

    # Context manager guarantees the bar is closed even if the epoch aborts.
    with tqdm(total=len(dataloader), desc="Train" if is_training else "Valid") as progress_bar:
        for images, targets in dataloader:
            if len(targets) == 0:
                # Every sample of this batch was dropped by the collate function.
                progress_bar.update()
                continue
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            images = [image.to(device) for image in images]
            num_batches += 1

            with autocast(device_type=device.type, dtype=torch.bfloat16):
                if is_training:
                    losses = model(images, targets)
                else:
                    with torch.no_grad():
                        losses = model(images, targets)

                total_loss = sum(losses.values())

            loss_value = total_loss.item()

            if is_training:
                # Abort *before* the optimizer step so a non-finite loss can
                # never be propagated into the weights.
                assert math.isfinite(loss_value), "Loss is NaN or infinite. Stopping training."

                optimizer.zero_grad()
                if scaler:
                    scaler.scale(total_loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    total_loss.backward()
                    optimizer.step()

            epoch_losses = {k: v.item() + epoch_losses[k] for k, v in losses.items()}
            epoch_total_loss += loss_value
            # num_batches already counts the current batch; dividing by
            # num_batches + 1 (as before) under-reported every average.
            progress_bar.set_postfix(dict(avg_loss=epoch_total_loss / num_batches))
            progress_bar.update()

    # Guard against an epoch in which every batch was skipped.
    denominator = max(num_batches, 1)
    epoch_losses = {k: v / denominator for k, v in epoch_losses.items()}
    return epoch_losses

In [15]:
# Metric the sweep optimises.
metric = dict(name='valid/loss', goal='minimize')

# Every hyper-parameter currently has exactly one candidate value, so they are
# listed flat here and expanded into the {'values': [...]} form wandb expects.
_fixed_hyperparameters = {
    'epochs': 60,
    'lr': 1e-5,
    'weight_decay': 1e-2,
    'bs': 2,
    'save_model_every': 10,
    'scheduler': 'step',
    'step_size': 500,
    'gamma': 0.1,
    'optimizer_type': 'adamw',
    'rpn_fg_iou_thresh': 0.8,         # default 0.7. increase it. be more strict to detect true positives because of overlap
    'rpn_bg_iou_thresh': 0.4,         # default 0.3. increase it. increase the number of background detection
    'rpn_batch_size_per_image': 256,  # default 256
    'rpn_positive_fraction': 0.5,     # default 0.5
    'rpn_nms_thresh': 0.6,            # default 0.7 reduce it. this will reduce overlap
    'rpn_pre_nms_top_n_train': 500,   # default 2000
    'rpn_pre_nms_top_n_test': 500,    # default 1000
    'rpn_post_nms_top_n_train': 250,  # default 2000
    'rpn_post_nms_top_n_test': 250,   # default 1000
    'rpn_score_thresh': 0,            # default 0
}
parameters_dict = {
    name: {'values': [value]} for name, value in _fixed_hyperparameters.items()
}

# Random search over the (single-valued) parameter grid above.
sweep_config = {
    'method': 'random',
    'metric': metric,
    'parameters': parameters_dict,
}

In [16]:
import wandb
# Register the sweep under the "box_segmentation" project; the returned id is
# what the agent cell uses to pull runs.
sweep_id = wandb.sweep(sweep=sweep_config, project="box_segmentation")

Create sweep with ID: 51ihy9r0
Sweep URL: https://wandb.ai/abdelrahman-farhan/box_segmentation/sweeps/51ihy9r0


In [17]:
from torch.optim import AdamW
import datetime
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


# Cap the worker count at the CPUs actually available: a hard-coded 12
# oversubscribes small hosts and makes DataLoader warn and slow down.
num_workers = min(12, os.cpu_count() or 1)

# Datasets consumed by train(); point these at the *_small subsets for quick runs.
train_data = train_dataset
val_data = val_dataset

def train(config=None):
  """Train and validate one model inside a wandb run, logging losses per epoch.

  Hyper-parameters are read from ``wandb.config`` (populated by the sweep
  agent or from ``config``). Every ``save_model_every`` epochs the model's
  state dict is written to ``<dataset_folder>/model``.

  Args:
      config: optional configuration forwarded to ``wandb.init``.

  Raises:
      ValueError: if ``optimizer_type`` or ``scheduler`` names an unsupported option.
  """
  with wandb.init(config=config):
    config = wandb.config

    model = create_model(
        rpn_fg_iou_thresh=config.rpn_fg_iou_thresh,
        rpn_bg_iou_thresh=config.rpn_bg_iou_thresh,
        rpn_batch_size_per_image=config.rpn_batch_size_per_image,
        rpn_positive_fraction=config.rpn_positive_fraction,
        rpn_nms_thresh=config.rpn_nms_thresh,
        rpn_pre_nms_top_n_train=config.rpn_pre_nms_top_n_train,
        rpn_pre_nms_top_n_test=config.rpn_pre_nms_top_n_test,
        rpn_post_nms_top_n_train=config.rpn_post_nms_top_n_train,
        rpn_post_nms_top_n_test=config.rpn_post_nms_top_n_test,
        rpn_score_thresh=config.rpn_score_thresh
    )
    model.to(device)
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer_type = config.optimizer_type
    if optimizer_type == 'sgd':
      optimizer = torch.optim.SGD(params, lr=config.lr, momentum=0.9, weight_decay=config.weight_decay)
    elif optimizer_type == 'adamw':
      optimizer = torch.optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay)
    else:
      # Fail with a clear message instead of a NameError further down.
      raise ValueError(f"Unsupported optimizer_type: {optimizer_type!r}")

    # All schedulers are stepped once per epoch (see the loop below).
    if config.scheduler == 'step':
      lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
    elif config.scheduler == 'linear':
      lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=config.epochs)
    elif config.scheduler == 'cyclic':
      lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=config.lr, total_steps=config.epochs)
    else:
      raise ValueError(f"Unsupported scheduler: {config.scheduler!r}")

    train_loader = DataLoader(train_data, batch_size=config.bs, shuffle=True, collate_fn=collate_fn, num_workers=num_workers)
    val_loader = DataLoader(val_data, batch_size=config.bs, shuffle=False, collate_fn=collate_fn, num_workers=num_workers)

    # Create the checkpoint directory up front so torch.save cannot fail on a
    # missing folder after hours of training.
    checkpoint_dir = os.path.join(dataset_folder, 'model')
    os.makedirs(checkpoint_dir, exist_ok=True)

    scaler = GradScaler()
    for epoch in tqdm(range(config.epochs), desc="Epochs"):

      train_losses = run_epoch(model, train_loader, optimizer, device, scaler, is_training=True)

      with torch.no_grad():
        valid_losses = run_epoch(model, val_loader, None, device, scaler, is_training=False)
      lr_scheduler.step()

      # Log everything for this epoch in ONE call: each wandb.log call
      # advances the step counter, so separate calls would scatter a single
      # epoch's metrics over several steps.
      metrics = {f'train/{k}': v for k, v in train_losses.items()}
      metrics['train/loss'] = sum(train_losses.values())
      metrics.update({f'valid/{k}': v for k, v in valid_losses.items()})
      metrics['valid/loss'] = sum(valid_losses.values())
      metrics['lr'] = lr_scheduler.get_last_lr()[0]
      metrics['epoch'] = epoch + 1
      wandb.log(metrics)

      if (epoch+1) % config.save_model_every == 0:
        model_name = f'model_{wandb.run.name}_{wandb.run.sweep_id}_{epoch+1}.pth'
        model_path = os.path.join(checkpoint_dir, model_name)
        torch.save(model.state_dict(), model_path)

In [18]:
# Launch a sweep agent that executes exactly one run of train().
wandb.agent(sweep_id, function=train, count=1)

[34m[1mwandb[0m: Agent Starting Run: 5a6ffiyi with config:
[34m[1mwandb[0m: 	bs: 2
[34m[1mwandb[0m: 	epochs: 60
[34m[1mwandb[0m: 	gamma: 0.1
[34m[1mwandb[0m: 	lr: 1e-05
[34m[1mwandb[0m: 	optimizer_type: adamw
[34m[1mwandb[0m: 	rpn_batch_size_per_image: 256
[34m[1mwandb[0m: 	rpn_bg_iou_thresh: 0.4
[34m[1mwandb[0m: 	rpn_fg_iou_thresh: 0.8
[34m[1mwandb[0m: 	rpn_nms_thresh: 0.6
[34m[1mwandb[0m: 	rpn_positive_fraction: 0.5
[34m[1mwandb[0m: 	rpn_post_nms_top_n_test: 250
[34m[1mwandb[0m: 	rpn_post_nms_top_n_train: 250
[34m[1mwandb[0m: 	rpn_pre_nms_top_n_test: 500
[34m[1mwandb[0m: 	rpn_pre_nms_top_n_train: 500
[34m[1mwandb[0m: 	rpn_score_thresh: 0
[34m[1mwandb[0m: 	save_model_every: 10
[34m[1mwandb[0m: 	scheduler: step
[34m[1mwandb[0m: 	step_size: 500
[34m[1mwandb[0m: 	weight_decay: 0.01


Epochs:   0%|          | 0/60 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

Valid:   0%|          | 0/500 [00:00<?, ?it/s]

Train:   0%|          | 0/3701 [00:00<?, ?it/s]

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
