In [1]:
import torch
from torchvision.transforms import v2
from torch.utils.data import DataLoader, Dataset
import os

from img_methods import *

In [2]:
# Target spatial size (height, width) of every augmented image.
H, W = 200, 200

# Augmentation pipeline applied to each image loaded by the dataset.
# Assumes the input is a PIL image (the dataset hands over the result of
# Image.open) -- TODO confirm against img_methods.
img_aug = v2.Compose([
    v2.Grayscale(num_output_channels=1),
    # Use the shared H/W constants so the crop size cannot drift from them.
    v2.RandomResizedCrop(size=(H, W), antialias=True),
    v2.RandomRotation(degrees=(0, 360)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    # Order matters: v2.ToDtype only converts tensors and leaves PIL images
    # untouched, so the tensor conversion has to come first. With the two
    # steps the other way round the pipeline silently produced unscaled
    # uint8 tensors instead of float32 values in [0, 1].
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
])

class DinoNuggieDataset(Dataset):
    """Unlabelled image dataset read from a directory of sub-folders.

    Every regular file located one level below ``root_dir``
    (``root_dir/<folder>/<file>``) becomes one sample. The sub-folder named
    ``'mis'`` is excluded.

    Args:
        root_dir: Directory whose immediate sub-directories hold the images.
        transform: Optional callable applied to each opened image before it
            is returned.

    Attributes are plain instance fields: ``root`` (the given ``root_dir``),
    ``transform`` and ``file_list`` (paths of all samples, in sorted order).
    """

    def __init__(self, root_dir, transform=None):
        self.root = root_dir
        self.transform = transform
        self.file_list = []

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        for folder in sorted(os.listdir(root_dir)):
            if folder == 'mis':
                continue
            folder_fp = os.path.join(root_dir, folder)
            # Stray files directly under root_dir (e.g. .DS_Store, a README)
            # would make the inner os.listdir raise NotADirectoryError.
            if not os.path.isdir(folder_fp):
                continue
            for img in sorted(os.listdir(folder_fp)):
                img_fp = os.path.join(folder_fp, img)
                # Nested directories are not samples; keep regular files only.
                if os.path.isfile(img_fp):
                    self.file_list.append(img_fp)

    def __len__(self):
        """Return the number of image files collected at construction time."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Open the ``idx``-th file and return it, transformed if a transform is set."""
        # NOTE(review): `Image` is not imported explicitly in this file; it is
        # presumably PIL.Image re-exported by `from img_methods import *` --
        # confirm, and prefer an explicit import.
        img = Image.open(self.file_list[idx])
        if self.transform:
            img = self.transform(img)
        return img
    
    
# The images live in an `images` folder that is a sibling of the directory
# the notebook is executed from, i.e. <parent of cwd>/images.
current_dir = os.getcwd()
project_dir = os.path.split(current_dir)[0]
images_dir = os.path.join(project_dir, 'images')

# Wrap the folder in the dataset (with augmentation) and serve it one
# shuffled sample per batch.
normal_dataset = DinoNuggieDataset(images_dir, transform=img_aug)
normal_dataloader = DataLoader(
    dataset=normal_dataset,
    shuffle=True,
    batch_size=1,
)


In [5]:
# One pass over the loader: every sample of every batch is flattened to a
# 1-D tensor and converted to a NumPy array.
normal_flattened = [
    sample.flatten().numpy()
    for samples in normal_dataloader
    for sample in samples
]