In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from sklearn.model_selection import train_test_split
from torchvision import transforms

# Pipeline applied to input images: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and standard deviation 0.5.
imageTransform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Pipeline applied to masks: tensor conversion only, no normalization.
maskTransform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

class LandslideDataset(Dataset):
    """Paired image/mask dataset backed by a class-level train/test split.

    Call ``LandslideDataset.splitData(dataset_dir)`` once before creating
    instances. Each instance then serves either the training part or the
    test part of that split, depending on ``isTrain``.
    """

    # Class-level split shared by every instance; filled in by splitData().
    # Each list holds file paths, and image/mask lists are index-aligned.
    train_list = None
    train_mask = None
    test_list = None
    test_mask = None

    @classmethod
    def splitData(cls, dataset_dir):
        """Split the files in ``dataset_dir/img`` into train and test sets.

        The split is 80/20 with a fixed ``random_state`` of 42. For every
        image, the mask is looked up in ``dataset_dir/mask`` under the same
        file name. The resulting path lists are stored on the class.

        Args:
            dataset_dir: Directory containing the ``img`` and ``mask``
                sub-directories.
        """
        image_dir = os.path.join(dataset_dir, "img")
        mask_dir = os.path.join(dataset_dir, "mask")

        # Sort first so the split does not depend on the order in which
        # os.listdir happens to return the entries.
        all_images = sorted(os.listdir(image_dir))
        train_set, test_set = train_test_split(all_images, test_size=.2, random_state=42)

        # Store full paths for the images as well as the masks; bare file
        # names would be resolved against the current working directory
        # when opened in __getitem__.
        # NOTE(review): masks are assumed to share the image's file name
        # (including extension) -- confirm against the dataset layout.
        cls.train_list = [os.path.join(image_dir, name) for name in train_set]
        cls.train_mask = [os.path.join(mask_dir, name) for name in train_set]
        cls.test_list = [os.path.join(image_dir, name) for name in test_set]
        cls.test_mask = [os.path.join(mask_dir, name) for name in test_set]

    def __init__(self, isTrain):
        """Select the training or test part of the class-level split.

        Args:
            isTrain: Truthy to use the training split, falsy for the test
                split.

        Raises:
            RuntimeError: If ``splitData`` has not populated the split yet.
        """
        if isTrain:
            self.img_list = LandslideDataset.train_list
            self.mask_list = LandslideDataset.train_mask
        else:
            self.img_list = LandslideDataset.test_list
            self.mask_list = LandslideDataset.test_mask

        # Fail here with a clear message instead of a TypeError from
        # len(None) the first time the dataset is used.
        if self.img_list is None or self.mask_list is None:
            raise RuntimeError(
                "LandslideDataset.splitData(dataset_dir) must be called "
                "before creating a LandslideDataset instance"
            )

    def __len__(self):
        """Return the number of image/mask pairs in the selected split."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Load and transform the image/mask pair at ``index``.

        Args:
            index: Position within the selected split.

        Returns:
            A ``(img, mask)`` tuple holding the outputs of the module-level
            ``imageTransform`` and ``maskTransform`` pipelines.
        """
        img_path = self.img_list[index]
        mask_path = self.mask_list[index]

        # convert() returns a new, fully loaded image, so the underlying
        # file can be closed as soon as the with-block exits.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        # The transforms return new objects; keep their results rather
        # than returning the untouched PIL images.
        img = imageTransform(img)
        mask = maskTransform(mask)

        return img, mask


# Build the class-level train/test split from the Wenchuan data directory.
LandslideDataset.splitData("data/Wenchuan")

# One dataset object per side of the split.
train_dataset = LandslideDataset(True)
test_dataset = LandslideDataset(False)

batch_size = 4

# Training loader shuffles its samples; the test loader does not.
trainLoader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
testLoader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)


hello world
