### Imports

In [1]:
import csv
import numpy as np
from collections import defaultdict
from PIL import Image

In [2]:
def rle_decode(mask_rle, shape): 
    '''     
    mask_rle: run-length as string formated (start length)     
    shape: (height,width) of array to return      
    Returns numpy array, 1 - mask, 0 - background      
    ''' 
    s = mask_rle.split() 
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])] 
    starts -= 1 
    ends = starts + lengths 
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8) 
    for lo, hi in zip(starts, ends): 
         img[lo:hi] = 1 
    return img.reshape(shape)

In [85]:
def get_masks(file, size):
    segm = csv.reader(open(file, 'r'))
    masks = {}
    
    next(segm)
       
    for _ in range(size):
        try:
            curr = next(segm);
            if curr[0] not in masks:
                masks[curr[0]] = []
            if curr[1] != '':
                masks[curr[0]].append(rle_decode(curr[1], (768, 768)))
        except StopIteration:
            break
        
    return masks

In [4]:
# def generate_segmentation_mask(img):
#     segmentation_mask = np.zeros((768,768), dtype=np.uint8)
#     og_image = Image.open(f'airbus-ship-detection/train_v2/{img}')
#     adjusted_img = og_image.rotate(90).transpose(Image.FLIP_TOP_BOTTOM)
#     img_arr = np.asarray(adjusted_img).copy()
        
#     img_arr[mask == 1] = [255,255,255]

#     segmentation_image = Image.fromarray(img_arr)
#     segmentation_image.save(f'sample_imgs/{img}_segmentation_result.png')

In [86]:
from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
import torch
import os

class ShipDataset(torch.utils.data.Dataset):
    def __init__(self, root, size, transforms=None):
        self.root = root
        self.transforms = transforms
        masks = get_masks(os.path.join(root, 'train_ship_segmentations_v2.csv'), size)
        self.imgs = list(masks.keys())
        self.masks = list(masks.values())

    def __getitem__(self, idx):
        #load image
        img_name = os.path.join(self.root, 'train_v2', self.imgs[idx])
        image = read_image(img_name)
        
        num_objs = len(self.masks[idx])
        # get bounding boxes coordinates for each mask
        boxes = torch.zeros((0, 4), dtype=torch.float32)
        area = torch.zeros((0,), dtype=torch.float32)
        
        if num_objs != 0:
            masks = torch.stack([torch.tensor(m, dtype=torch.uint8) for m in self.masks[idx]])
            boxes = masks_to_boxes(masks)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        
        img_id = idx
        
        image = tv_tensors.Image(image)
        
        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format='xyxy', canvas_size=F.get_size(image))
        target["masks"] = tv_tensors.Mask(self.masks[idx])
        target["labels"] = labels
        target["image_id"] = img_id
        target["area"] = area
        
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        
        return image, target
    
    def __len__(self):
        return len(self.imgs)

In [6]:
from torchvision.transforms import v2 as T

def get_transforms(train):
    transforms = []
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.ToDtype(torch.float, scale=True))
    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)

In [32]:
from torchvision.models import resnet101 
import Utils

data_size = 4000

# model = resnet101(pretrained=True)
dataset = ShipDataset('airbus-ship-detection', data_size, get_transforms(train=True))      
data_loader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=2, 
    shuffle=True, 
    num_workers=0, 
    collate_fn=Utils.collate_fn
)

# For Training
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
# output = model(images, targets)   # Returns losses and detections

# For inference
# model.eval()
# x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
# predictions = model(x)           # Returns predictions
# print(predictions[0])

RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.