In [1]:
import torch
import torchvision

from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import transforms
from torchvision.transforms import v2

import os
import json
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

In [2]:
class DatasetReg(Dataset):
    """Image dataset whose targets are coordinates read from ``cords.json``.

    ``path`` must be a directory holding the image files plus a ``cords.json``
    file mapping each image file name to its list of coordinates.

    Args:
        path: Directory with the images and ``cords.json``.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform

        # Sorted so the index -> file mapping (and hence any random_split built
        # on top of it) does not depend on the OS directory-listing order.
        # Only regular files are kept, so folders such as ".ipynb_checkpoints"
        # are not mistaken for images; the annotation file itself is excluded.
        self.list_name_file = sorted(
            name
            for name in os.listdir(path)
            if name != "cords.json" and os.path.isfile(os.path.join(path, name))
        )
        self.len_dataset = len(self.list_name_file)

        # JSON is UTF-8 by spec; don't rely on the platform default encoding.
        with open(os.path.join(self.path, "cords.json"), "r", encoding="utf-8") as f:
            self.dict_cords = json.load(f)

    def __len__(self):
        """Return the number of image files found in ``path``."""
        return self.len_dataset

    def __getitem__(self, index):
        """Return ``(img, cords)`` for the image at ``index``.

        ``img`` is the PIL image, or ``transform(img)`` when a transform was
        given; ``cords`` is a float32 tensor built from the image's entry in
        ``cords.json``.

        Raises:
            IndexError: if ``index`` is out of range (this is also what ends a
                plain ``for`` loop over the dataset).
            KeyError: if the image has no entry in ``cords.json``.
        """
        name_file = self.list_name_file[index]
        path_img = os.path.join(self.path, name_file)

        # Image.open is lazy and keeps the file open; copy() forces the pixel
        # data into memory so the handle can be released when ``with`` exits.
        with Image.open(path_img) as img_file:
            img = img_file.copy()
        cords = torch.tensor(self.dict_cords[name_file], dtype=torch.float32)

        if self.transform is not None:
            img = self.transform(img)
        return img, cords
        

In [3]:
# Preprocessing: PIL image -> image tensor -> float32 in [0, 1] -> [-1, 1].
# Normalize computes (x - 0.5) / 0.5, mapping [0, 1] onto [-1, 1]; the one-element
# mean/std tuples suit the single-channel samples shown in the outputs below.
transform = v2.Compose([
    v2.ToImage(),                             # PIL image -> tensor (integer pixels)
    v2.ToDtype(torch.float32, scale=True),    # rescale integer pixels to [0, 1]
    v2.Normalize(mean=(0.5, ), std=(0.5, )),  # [0, 1] -> [-1, 1]
])

In [4]:
# Dataset root: set the DATASET_DIR environment variable to run the notebook on
# another machine; otherwise fall back to the original local path.
DATASET_DIR = os.environ.get("DATASET_DIR", r"C:\Users\voron\IT\pytorch\dataset")
dataset = DatasetReg(path=DATASET_DIR, transform=transform)

In [5]:
# Sanity check on the first 3 individual samples (not batches): tensor shapes
# and the value range after normalization. zip stops as soon as range(3) is
# exhausted, so no 4th sample is loaded just to hit a `break`, and datasets
# with fewer than 3 items are handled too.
for i, (sample, target) in zip(range(3), dataset):
    print(f"номер sample: {i + 1}")
    print(f"размер sample: {sample.shape}")
    print(f"размер target: {target.shape}")
    print(f"min-- {sample.min()}, max-- {sample.max()}")
    
    
    

номер bach: 1
размер sample:  torch.Size([1, 64, 64])
размер target:  torch.Size([2])
min-- -1.0, max-- 0.7098040580749512
номер bach: 2
размер sample:  torch.Size([1, 64, 64])
размер target:  torch.Size([2])
min-- -1.0, max-- 0.6627452373504639
номер bach: 3
размер sample:  torch.Size([1, 64, 64])
размер target:  torch.Size([2])
min-- -1.0, max-- 0.7019609212875366


In [6]:
# 70% / 20% / 10% train / valid / test split. A seeded generator makes the split
# reproducible across kernel restarts; without it every Run All reshuffles which
# samples land in train vs. test.
SPLIT_SEED = 42
train_data, valid_data, test_data = random_split(
    dataset, [0.7, 0.2, 0.1], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [7]:
# Samples per batch (variable name kept as-is because later cells may use it).
bach_size = 32

# Only the training loader reshuffles each epoch; valid/test keep a fixed order.
train_loader, valid_loader, test_loader = (
    DataLoader(split, batch_size=bach_size, shuffle=is_train)
    for split, is_train in (
        (train_data, True),
        (valid_data, False),
        (test_data, False),
    )
)


In [8]:
# Sanity check on the first 3 training batches: batch shapes and the value range
# after normalization. zip stops as soon as range(3) is exhausted, so no 4th
# shuffled batch is assembled just to reach a `break`.
for i, (sample, target) in zip(range(3), train_loader):
    print(f"номер batch: {i + 1}")
    print(f"размер sample: {sample.shape}")
    print(f"размер target: {target.shape}")
    print(f"min-- {sample.min()}, max-- {sample.max()}")
    
    
    

номер bach: 1
размер sample:  torch.Size([32, 1, 64, 64])
размер target:  torch.Size([32, 2])
min-- -1.0, max-- 0.7960785627365112
номер bach: 2
размер sample:  torch.Size([32, 1, 64, 64])
размер target:  torch.Size([32, 2])
min-- -1.0, max-- 0.8039216995239258
номер bach: 3
размер sample:  torch.Size([32, 1, 64, 64])
размер target:  torch.Size([32, 2])
min-- -1.0, max-- 0.7960785627365112
