In [26]:
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
from fastai.vision.all import untar_data, URLs
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

class CombinedDataset(Dataset):
    def __init__(self, transform=None):
        # Download and extract the Imagenette and Imagewoof datasets
        path_imagenette = untar_data(URLs.IMAGENETTE)
        path_imagewoof = untar_data(URLs.IMAGEWOOF)
        # Load both datasets
        self.dataset_imagenette = ImageFolder(root=path_imagenette, transform=transform)
        self.dataset_imagewoof = ImageFolder(root=path_imagewoof, transform=transform)

        # Combine the datasets and assign binary labels
        self.data = self.dataset_imagenette.samples + self.dataset_imagewoof.samples
        self.targets = [0] * len(self.dataset_imagenette.samples) + [1] * len(self.dataset_imagewoof.samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, _ = self.data[idx]
        label = self.targets[idx]
        img = Image.open(img_path).convert('RGB')
        if self.dataset_imagenette.transform is not None:
            img = self.dataset_imagenette.transform(img)
        return img, label


In [32]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Create an instance of your combined dataset
combined_dataset = CombinedDataset(
    transform=transform
)

# # Function to display images with their combined labels
# def show_combined_images(dataset, num_images=5):
#     fig, axes = plt.subplots(1, num_images, figsize=(15, 3))
#     for i, ax in enumerate(axes):
#         img, label = dataset[i]
#         ax.imshow(img.permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C)
#         ax.axis('off')
#         ax.set_title(f'Label: {label}')
#     plt.show()

# # Display some images from the combined dataset
# show_combined_images(combined_dataset)

# # Create a DataLoader
# loader = DataLoader(combined_dataset, batch_size=32, shuffle=True)

features = combined_dataset.data
targets = combined_dataset.targets

pos_class = [1]
neg_class = [0]
p_data_idx = np.where(np.isin(targets, pos_class))[0]
n_data_idx = np.where(np.isin(targets, neg_class) if neg_class 
                      else np.isin(targets, pos_class, invert=True))[0]

p_data = features[p_data_idx]

TypeError: only integer scalar arrays can be converted to a scalar index

In [31]:
# p_data_idx
n_data_idx

array([    0,     1,     2, ..., 13391, 13392, 13393])