# In [1]:
import torch
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import os
        
# Dataset files, in the order: images, masks, labels, bounding boxes.
paths = [
    './data/hw3_mycocodata_img_comp_zlib.h5',
    './data/hw3_mycocodata_mask_comp_zlib.h5',
    './data/hw3_mycocodata_labels_comp_zlib.npy',
    './data/hw3_mycocodata_bboxes_comp_zlib.npy',
]
imgs_path, masks_path, labels_path, bboxes_path = paths
img_path, mask_path, label_path, bbox_path = paths

# Keep only a short prefix of every array in memory (30 labels / boxes /
# images, 100 masks) so the notebook stays quick to iterate on.
# NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
gb_labels = np.load(label_path, allow_pickle=True)[:30]
gb_bboxes = np.load(bbox_path, allow_pickle=True)[:30]

# Each HDF5 file is read through its first top-level key.
with h5py.File(img_path, 'r') as file:
    gb_images = file[list(file.keys())[0]][:30]
with h5py.File(mask_path, 'r') as file:
    gb_masks = file[list(file.keys())[0]][:100]

# In [None]:
## Author: Lishuo Pan 2020/4/18

import torch
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import os

class BuildDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each image with its labels, boxes and masks.

    Images, labels and bounding boxes are indexed per image. Masks are
    assumed to be stored as one flat array holding one mask per labelled
    object, in image order (TODO confirm against the data files);
    ``match_masks_to_images`` regroups them per image.
    """

    # Default number of images to match. This mirrors the cap that used to be
    # hard-coded in the matching loop; pass ``max_images=None`` for no cap.
    _DEFAULT_MAX_IMAGES = 11

    def __init__(self, paths, max_images=_DEFAULT_MAX_IMAGES):
        """Load the dataset files and group the masks per image.

        Args:
            paths: Sequence of exactly four paths, in order: image HDF5 file,
                mask HDF5 file, label ``.npy`` file, bounding-box ``.npy`` file.
            max_images: Maximum number of images to load and match, or
                ``None`` to use every image.

        Raises:
            ValueError: If ``paths`` does not hold four entries, or the mask
                file holds fewer masks than the labels require.
        """
        img_path, mask_path, label_path, bbox_path = paths

        # Labels and boxes are read first: the per-image object counts tell us
        # how many masks are needed, so only that prefix is read from disk.
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # trusted files.
        self.labels = np.load(label_path, allow_pickle=True)[:max_images]
        self.bboxes = np.load(bbox_path, allow_pickle=True)[:max_images]
        n_masks = sum(len(label) for label in self.labels)

        # Each HDF5 file is read through its first top-level key.
        with h5py.File(img_path, 'r') as file:
            self.images = file[list(file.keys())[0]][:max_images]
        with h5py.File(mask_path, 'r') as file:
            self.masks = file[list(file.keys())[0]][:n_masks]

        # Match masks to images
        self.match_masks_to_images(max_images)

    @staticmethod
    def _as_object_array(items):
        """Return a 1-D object array whose k-th element is ``items[k]``.

        ``np.array(items, dtype=object)`` would instead build an N-D object
        array whenever all items happen to share a shape, so the elements are
        assigned one by one.
        """
        out = np.empty(len(items), dtype=object)
        for k, item in enumerate(items):
            out[k] = item
        return out

    def calculate_mask_bounding_box(self, mask):
        """Calculate the bounding box of the non-zero elements in a 2-D mask.

        Args:
            mask: 2-D array indexed as ``mask[row, col]``.

        Returns:
            ``(xmin, ymin, xmax, ymax)`` with x as the column index and y as
            the row index; both maxima are inclusive.

        Raises:
            IndexError: If the mask has no non-zero element.
        """
        # Reducing over axis 0 (rows) leaves one flag per column, and
        # reducing over axis 1 (columns) leaves one flag per row.
        col_has_obj = np.any(mask, axis=0)
        row_has_obj = np.any(mask, axis=1)
        ymin, ymax = np.where(row_has_obj)[0][[0, -1]]
        xmin, xmax = np.where(col_has_obj)[0][[0, -1]]
        return xmin, ymin, xmax, ymax

    def match_masks_to_images(self, max_images=_DEFAULT_MAX_IMAGES):
        """Group the flat mask array so every image gets its own masks.

        Image ``i`` is given the next ``len(self.labels[i])`` masks, stacked
        into one array whose first axis is the object index.

        Args:
            max_images: Maximum number of images to match, or ``None`` for all.

        Raises:
            ValueError: If the labels require more masks than are available.
        """
        images, masks, labels, bboxes = [], [], [], []
        next_mask = 0  # index of the first mask not yet assigned to an image

        for i, label in enumerate(self.labels):  # labels[i] corresponds to image i
            if max_images is not None and i >= max_images:
                break

            stop = next_mask + len(label)
            if stop > len(self.masks):
                raise ValueError(
                    f"image {i} needs masks [{next_mask}, {stop}) but only "
                    f"{len(self.masks)} masks are available"
                )

            masks.append(np.asarray(self.masks[next_mask:stop]))
            images.append(self.images[i])
            labels.append(label)
            bboxes.append(self.bboxes[i])
            next_mask = stop

        self.matched_masks = self._as_object_array(masks)
        self.matched_images = np.array(images)
        self.matched_labels = self._as_object_array(labels)
        self.matched_bboxes = self._as_object_array(bboxes)

        print(len(self.matched_images), len(self.matched_masks), len(self.masks))

    def __getitem__(self, index):
        """Return the preprocessed image and raw annotations for one sample.

        The image (3, H, W) is converted to float, resized to 800x1066,
        normalised with the ImageNet mean/std and zero-padded to 800x1088.

        Returns:
            ``(image, label, mask, bbox)``. ``image`` is a float tensor of
            shape (3, 800, 1088). ``mask`` is an array whose first axis is the
            object index.

        NOTE(review): ``mask`` and ``bbox`` are returned at their stored
        resolution; resizing them to match the image is still a TODO.
        """
        img = self.matched_images[index]
        mask = self.matched_masks[index]
        bbox = self.matched_bboxes[index]
        label = self.matched_labels[index]

        tensor = torch.tensor(img)

        if tensor.dtype == torch.uint8:
            # The mean/std below are defined for images in [0, 1]; 8-bit
            # images are in [0, 255] and must be rescaled first.
            tensor = tensor.float() / 255.0
        elif tensor.dtype != torch.float32:
            tensor = tensor.float()

        # F.interpolate expects a batch dimension, so add and remove one.
        tensor = F.interpolate(tensor.unsqueeze(0), size=(800, 1066), mode='bilinear', align_corners=False).squeeze(0)

        tensor = transforms.functional.normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Pad the width by 11 on each side: 1066 -> 1088.
        tensor = F.pad(tensor, (11, 11, 0, 0), "constant", 0)

        # check flag
        assert list(tensor.shape) == [3, 800, 1088]
        assert bbox.shape[0] == mask.shape[0]

        return tensor, label, mask, bbox

    def __len__(self):
        """Return the number of matched images."""
        return len(self.matched_images)

    # This function is meant to take care of the pre-process of img, mask and
    # bbox in the input mini-batch; the preprocessing itself is still a TODO.
    # input:
        # img: 3*800*1088 (enforced by the assert below)
        # mask: after squeezing a leading axis of size 1, first axis is n_box
        # bbox: n_box*4
    def pre_process_batch(self, img, mask, bbox):
        """Check the sample shapes and return the inputs unchanged."""
        # TODO: image preprocess

        # check flag
        assert img.shape == (3, 800, 1088)
        assert bbox.shape[0] == mask.squeeze(0).shape[0]
        return img, mask, bbox


class BuildDataLoader(torch.utils.data.DataLoader):
    """DataLoader that batches images into one tensor and keeps annotations as lists.

    Labels, masks and boxes are kept as per-image lists rather than stacked,
    since only the images are stacked into a single tensor here.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers):
        """Build the loader.

        Args:
            dataset: Map-style dataset yielding ``(img, label, mask, bbox)``.
            batch_size: Number of samples per batch.
            shuffle: Whether to reshuffle the samples every epoch.
            num_workers: Number of worker processes used for loading.
        """
        # Initialise the base class so the instance itself is a working
        # DataLoader (iterable, supports len()). The base class stores
        # dataset, batch_size and num_workers as attributes.
        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,
                         num_workers=num_workers, collate_fn=self.collect_fn)
        self.shuffle = shuffle
        # Kept for callers that use ``.dataloader`` / ``loader()``.
        self.dataloader = self

    # output:
        # img: (bz, 3, 800, 1088)
        # label_list: list, len:bz, each (n_obj,)
        # transed_mask_list: list, len:bz, one mask entry per image
        # transed_bbox_list: list, len:bz, each (n_obj, 4)
    def collect_fn(self, batch):
        """Collate samples: stack the images, keep the annotations as lists.

        Args:
            batch: Non-empty sequence of ``(img, label, mask, bbox)`` samples
                whose images all share one shape.

        Returns:
            ``(imgs, label_list, mask_list, bbox_list)`` where ``imgs`` is the
            images stacked along a new leading axis.
        """
        transed_img_list, label_list, transed_mask_list, transed_bbox_list = (
            list(field) for field in zip(*batch)
        )
        return torch.stack(transed_img_list, dim=0), label_list, transed_mask_list, transed_bbox_list

    def loader(self):
        """Return the DataLoader to iterate over."""
        return self.dataloader

## Visualize debugging
if __name__ == '__main__':
    # Dataset files, in the order BuildDataset unpacks them:
    # images, masks, labels, bounding boxes.
    imgs_path = './data/hw3_mycocodata_img_comp_zlib.h5'
    masks_path = './data/hw3_mycocodata_mask_comp_zlib.h5'
    labels_path = './data/hw3_mycocodata_labels_comp_zlib.npy'
    bboxes_path = './data/hw3_mycocodata_bboxes_comp_zlib.npy'
    paths = [imgs_path, masks_path, labels_path, bboxes_path]
    # load the data into data.Dataset
    dataset = BuildDataset(paths)

    ## Visualize debugging
    # --------------------------------------------
    # build the dataloader
    # use 80% of the dataset for training and the remainder for testing
    full_size = len(dataset)
    train_size = int(full_size * 0.8)
    test_size = full_size - train_size
    # fix the seed so the random train/test split is reproducible
    torch.random.manual_seed(1)
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    # push the randomized training data into the dataloader
    batch_size = 2
    train_build_loader = BuildDataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    train_loader = train_build_loader.loader()
    test_build_loader = BuildDataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = test_build_loader.loader()

    mask_color_list = ["jet", "ocean", "Spectral", "spring", "cool"]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # plt.savefig does not create missing directories.
    os.makedirs("./testfig", exist_ok=True)

    # loop over the batches (named batch_idx so the builtin iter() is not shadowed)
    for batch_idx, data in enumerate(train_loader, 0):

        img, label, mask, bbox = data
        # check flag; the last batch may hold fewer than batch_size images
        n_imgs = img.shape[0]
        assert 0 < n_imgs <= batch_size
        assert img.shape[1:] == (3, 800, 1088)
        assert len(mask) == n_imgs

        # torch.as_tensor accepts tensors and numeric NumPy arrays alike, so
        # this works whichever of the two the dataset hands back.
        # NOTE(review): object-dtype arrays would still fail here - confirm
        # the per-image labels and boxes are numeric arrays.
        label = [torch.as_tensor(label_img).to(device) for label_img in label]
        mask = [torch.as_tensor(mask_img).to(device) for mask_img in mask]
        bbox = [torch.as_tensor(bbox_img).to(device) for bbox_img in bbox]

        # plot the origin img
        for i in range(n_imgs):
            ## TODO: plot images with annotations
            # NOTE(review): the file name does not include i, so every image
            # of a batch currently writes to the same file.
            plt.savefig("./testfig/visualtrainset"+str(batch_idx)+".png")
            plt.show()

        if batch_idx == 10:
            break

