In [3]:
import os

import pandas as pd
from PIL import Image
from torchvision import transforms
from torchvision.datasets import VisionDataset

class spuriousDermDataset(VisionDataset):
    """Image dataset over one split of a folder described by ``metadata.csv``.

    The directory ``file_path`` must contain a ``metadata.csv`` file with at
    least the columns ``image`` (image path relative to ``file_path``),
    ``benign_malignant`` and ``class``. A row belongs to ``split`` when the
    split name occurs as a plain substring of its ``image`` value.

    Args:
        file_path: Directory holding ``metadata.csv`` and the image files.
            A trailing path separator is optional.
        split: Substring used to select rows by their ``image`` path.
        transform: Optional callable applied to the loaded RGB PIL image.

    Indexing returns a tuple ``(img, melanoma_label, group_label)`` where
    ``img`` is the (optionally transformed) image, ``melanoma_label`` is the
    row's ``benign_malignant`` value and ``group_label`` its ``class`` value.
    """

    def __init__(self, file_path, split='train', transform=None):
        super(spuriousDermDataset, self).__init__(root=None, transform=transform)

        self.file_path = file_path
        self.split = split
        self.transform = transform

        # Load metadata from CSV. os.path.join keeps this correct whether or
        # not file_path ends with a path separator.
        self.metadata = pd.read_csv(os.path.join(file_path, 'metadata.csv'))

        # Keep only rows whose image path contains the split name.
        # regex=False -> literal substring match (same semantics as `in`);
        # na=False    -> rows with a missing image path are excluded instead
        #                of raising a TypeError.
        in_split = self.metadata['image'].str.contains(self.split, regex=False, na=False)
        self.metadata_for_split = self.metadata[in_split].reset_index(drop=True)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.metadata_for_split)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img, melanoma_label, group_label)``."""
        # Fetch the metadata row once rather than once per column.
        row = self.metadata_for_split.iloc[index]
        img_path = os.path.join(self.file_path, row['image'])
        melanoma_label = row['benign_malignant']
        group_label = row['class']

        # Load image. The context manager releases the file handle right
        # away; convert() returns a fully loaded copy that stays valid after
        # the file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, melanoma_label, group_label
    
# Preprocessing applied to every image: resize to a fixed 256x256, then
# convert the PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize image
    transforms.ToTensor(),           # Convert to tensor
])

# Dataset root (PATH_TO_SPURIOUS_ISIC). Set the SPURIOUS_ISIC_DIR environment
# variable to point at a local copy; otherwise the original location is used.
filepath = os.environ.get("SPURIOUS_ISIC_DIR", "/data/scratch/wgerych/spurious_ISIC/")
dataset = spuriousDermDataset(file_path=filepath, split='extra', transform=transform)

# Pull the first sample so the following cells can inspect each component.
img, melanoma_label, group_label = dataset[0]

In [4]:
# Inspect the transformed image tensor of the first sample.
img

tensor([[[0.4392, 0.4941, 0.4902,  ..., 0.4980, 0.4980, 0.4510],
         [0.5765, 0.8353, 0.8431,  ..., 0.8745, 0.8667, 0.6000],
         [0.5882, 0.8431, 0.8471,  ..., 0.8863, 0.8784, 0.6000],
         ...,
         [0.5922, 0.8431, 0.8431,  ..., 0.8784, 0.8784, 0.6039],
         [0.5961, 0.8314, 0.8392,  ..., 0.8627, 0.8588, 0.5961],
         [0.4510, 0.4863, 0.4824,  ..., 0.4980, 0.4980, 0.4510]],

        [[0.4353, 0.4902, 0.4863,  ..., 0.4980, 0.4980, 0.4510],
         [0.5725, 0.8314, 0.8392,  ..., 0.8745, 0.8667, 0.6000],
         [0.5843, 0.8392, 0.8431,  ..., 0.8863, 0.8784, 0.6000],
         ...,
         [0.5725, 0.8039, 0.8078,  ..., 0.8784, 0.8784, 0.6039],
         [0.5765, 0.8118, 0.8078,  ..., 0.8627, 0.8588, 0.5961],
         [0.4392, 0.4745, 0.4706,  ..., 0.4980, 0.4980, 0.4510]],

        [[0.4314, 0.4824, 0.4784,  ..., 0.5020, 0.4980, 0.4510],
         [0.5647, 0.8235, 0.8314,  ..., 0.8745, 0.8667, 0.6000],
         [0.5765, 0.8314, 0.8353,  ..., 0.8863, 0.8784, 0.

In [5]:
# Diagnosis label (the 'benign_malignant' column) of the first sample.
melanoma_label

'malignant'

In [6]:
# Group label (the 'class' column) of the first sample.
group_label

1