In [2]:
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class InputImages(Dataset):
    """Map-style dataset that loads images listed in a pandas DataFrame.

    Every sample is an ``(image, label)`` pair: the image is opened from the
    path stored in ``image_col`` and converted to RGB; the label is the raw
    value stored in ``label_col`` (no encoding is applied here).

    Args:
        dataframe: DataFrame with one row per sample. Rows are looked up by
            position (``iloc``), so the index values themselves are irrelevant.
        transform: Optional callable applied to the RGB PIL image
            (e.g. a ``transforms.Compose`` pipeline).
        image_col: Name of the column holding the image path.
        label_col: Name of the column holding the label.
    """

    def __init__(self, dataframe, transform=None,
                 image_col='frontview_funda_url_split', label_col='woningtype'):
        self.dataframe = dataframe
        self.transform = transform
        self.image_col = image_col
        self.label_col = label_col

    def __len__(self):
        """Return the number of samples (rows in the DataFrame)."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the sample at position ``idx`` and return ``(image, label)``."""
        row = self.dataframe.iloc[idx]

        # Image.open is lazy and holds on to the underlying file. The context
        # manager releases it deterministically; convert() hands back an
        # independent image, so it stays usable after the file is closed.
        with Image.open(row[self.image_col]) as source:
            image = source.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, row[self.label_col]


# Preprocessing pipeline applied to every image.
# The values depend on the model that is used (CNN, ResNet).
img_preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        # Normalisation should follow this step; mean and std depend on the
        # model used, e.g. transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]).
        transforms.ToTensor(),
    ]
)

In [3]:
# test code
import matplotlib.pyplot as plt

# Visual sanity check: draw the first few preprocessed samples side by side.
N_PREVIEW = 4  # number of samples shown in the preview row

frontview_images = InputImages(df_sample_with_urls, img_preprocess)

# Create the whole row of axes up front (explicit fig/axes interface)
# instead of adding subplots one by one through the pyplot state machine.
fig, axes = plt.subplots(1, N_PREVIEW, figsize=(12, 4))

for i, ax in enumerate(axes):
    image_tensor, _ = frontview_images[i]  # the label is not needed for the preview
    # Move axis 0 to the end: imshow expects the colour channels last
    # (assumes the ToTensor output is channels-first).
    image_np = image_tensor.numpy().transpose((1, 2, 0))

    ax.imshow(image_np)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')

# Lay out once, after all titles exist, rather than on every iteration.
fig.tight_layout()
plt.show()

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# Datasets and loaders for training and validation.
BATCH_SIZE = 32

# Use the preprocessing pipeline defined in this notebook (`img_preprocess`);
# the previously referenced `transform` is not defined in the visible cells.
# TODO: rename train_df / val_df to the names of the real train and validation frames.
train_dataset = InputImages(train_df, transform=img_preprocess)
val_dataset = InputImages(val_df, transform=img_preprocess)

# Shuffle only the training data; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)