In [1]:
from PIL import Image
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torchvision.transforms import Compose, ToTensor

In [2]:
def rgb_to_hexint_image(rgb_image):
    """
    Pack each pixel's R, G and B values into a single 0xRRGGBB integer.

    Takes an (H, W, 3) image and returns an (H, W) ``np.uint32`` array. Only
    channels 0-2 are read, so any additional channel (e.g. alpha) is ignored.
    Example: [255, 0, 170] -> 0xFF00AA -> 16711850
    """
    red, green, blue = (
        rgb_image[:, :, channel].astype(np.uint32) for channel in (0, 1, 2)
    )
    # For 8-bit channel values: red lands in bits 16-23, green in 8-15, blue in 0-7.
    return (red << 16) + (green << 8) + blue


def hexint_to_rgb_image(hexint_image):
    """
    Inverse of the RGB packing: split each 0xRRGGBB integer into three 8-bit channels.

    A (H, W) integer input yields an (H, W, 3) ``np.uint8`` array whose last axis
    is ordered red, green, blue.
    Example: 16711850 -> [255, 0, 170]
    """
    # Shift each channel's byte down to the low 8 bits, then mask off the rest.
    channels = [(hexint_image >> shift) & 0xFF for shift in (16, 8, 0)]
    return np.stack(channels, axis=-1).astype(np.uint8)

In [7]:
def load_rgb_array(path):
    """
    Read the image file at `path` and return it as an (H, W, 3) uint8 RGB array.

    The mode is normalised with ``convert("RGB")``. Without it, grayscale ("L")
    and palette ("P") files produce 2-D arrays that make `rgb_to_hexint_image`
    fail, and CMYK files produce 4-channel arrays whose C/M/Y values would be
    silently packed as if they were R/G/B. The ``with`` block closes the file
    handle as soon as the pixels have been read.
    """
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))


def cast_to_int64(array):
    """
    Cast a packed 0xRRGGBB array to int64 before it reaches `ToTensor`.

    `rgb_to_hexint_image` returns uint32. `ToTensor` converts NumPy arrays with
    `torch.from_numpy`, which rejects uint32 on older PyTorch releases and gives
    uint32 tensors only limited operator support on newer ones. Packed values
    never exceed 0xFFFFFF, so the cast is lossless.
    """
    return array.astype(np.int64)


# Per-sample pipeline: file path -> RGB array -> packed 0xRRGGBB integers
# -> int64 -> tensor -> flattened 1-D tensor.
transform = Compose([
    load_rgb_array,
    rgb_to_hexint_image,
    cast_to_int64,
    ToTensor(),
    torch.flatten,
])

In [17]:
current_dir = Path(".")


def list_image_files(directory):
    """
    Return the visible regular files directly inside `directory`, sorted by path.

    `Path.glob` yields entries in filesystem order, which differs between
    machines, so the result is sorted to keep the file order (and anything
    seeded downstream) reproducible. Sub-directories and hidden files such as
    `.DS_Store` are skipped because `Image.open` cannot read them. A missing
    directory yields an empty list.
    """
    return sorted(
        path
        for path in directory.glob("*")
        if path.is_file() and not path.name.startswith(".")
    )


# One folder per class label: data/0 and data/1.
zeros_files = list_image_files(current_dir / "data" / "0")
ones_files = list_image_files(current_dir / "data" / "1")