In [46]:
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transform

In [50]:
class CustomImageDataset(Dataset):
    """In-memory dataset of ``(start_image, end_image, months_apart)`` samples.

    The directory tree under ``path`` is searched for ``<folder>/images``
    directories. Inside each one the file names are sorted and paired
    ``time_skip`` positions apart, for every ``time_skip`` from 1 to 12.
    File names are split on ``"_"``; field 2 is read as the year and field 3
    as the month, and the label of a pair is the month difference between its
    end and start image.

    ``transform`` and ``target_transform`` are applied once, while the dataset
    is being built; the resulting samples are kept in ``final_dataset``.
    """

    def __init__(self, path, transform=None, target_transform=None, data=None):
        """Build the dataset from disk, or wrap already-prepared samples.

        Args:
            path: Root directory that is walked for ``images`` sub-folders.
            transform: Optional callable applied to every opened image.
            target_transform: Optional callable applied to both images of each
                sample after ``transform``; defaults to the identity.
            data: Optional ready-made list of samples. When given, nothing is
                read from disk and the samples are stored as-is.
        """
        if target_transform is None:
            # A named function instead of an inline lambda.
            target_transform = self._identity

        self.path = path
        self.transform = transform
        self.target_transform = target_transform

        if data is not None:
            # Used by slicing: reuse the prepared samples without touching the disk.
            self.final_dataset = data
            self.final_dataset_size = len(data)
            return

        self.final_dataset = []
        for skip in range(1, 13):
            pairs, _ = self.read_dataset(time_skip=skip)
            self.final_dataset.extend(
                (target_transform(start), target_transform(end), months_apart)
                for start, end, months_apart in pairs
            )
        self.final_dataset_size = len(self.final_dataset)

    @staticmethod
    def _identity(value):
        """Return ``value`` unchanged (the default ``target_transform``)."""
        return value

    def __len__(self):
        """Return the number of stored samples."""
        return self.final_dataset_size

    def __getitem__(self, idx):
        """Return one sample for an ``int`` index, or a sub-dataset for a slice.

        Raises:
            TypeError: If ``idx`` is neither an ``int`` nor a ``slice``.
        """
        if isinstance(idx, slice):
            return CustomImageDataset(
                path=self.path,
                transform=self.transform,
                target_transform=self.target_transform,
                data=self.final_dataset[idx],
            )
        if isinstance(idx, int):
            return self.final_dataset[idx]
        raise TypeError("Invalid argument type.")

    def _load_image(self, image_path):
        """Read one image into memory and apply ``self.transform`` if it is set."""
        # Image.open is lazy and keeps the file open; copy the pixels into an
        # independent in-memory image so the handle can be released right away.
        with Image.open(image_path) as image_file:
            image = image_file.copy()
        if self.transform is not None:
            image = self.transform(image)
        return image

    def read_dataset(self, time_skip=1):
        """Collect image pairs that are ``time_skip`` files apart.

        Args:
            time_skip: Distance, in positions of the sorted file list, between
                the start and end image of a pair; also the stride between
                consecutive start images.

        Returns:
            A ``(samples, count)`` tuple where ``samples`` is a list of
            ``(start_image, end_image, months_apart)`` and ``count`` is its
            length.
        """
        tuples = []
        for root, dirs, _files in os.walk(self.path):
            for folder in dirs:
                img_folder = os.path.join(root, folder, 'images')
                if not os.path.exists(img_folder):
                    continue

                img_files = sorted(os.listdir(img_folder))
                if len(img_files) < time_skip + 1:
                    continue

                for i in range(0, len(img_files) - time_skip, time_skip):
                    start_img_name = img_files[i]
                    end_img_name = img_files[i + time_skip]

                    start_img_parts = start_img_name.split("_")
                    end_img_parts = end_img_name.split("_")
                    start_year, start_month = start_img_parts[2], start_img_parts[3]
                    end_year, end_month = end_img_parts[2], end_img_parts[3]

                    # Images are always loaded, so a dataset built without a
                    # transform still gets real images instead of failing on
                    # unassigned variables.
                    start_img = self._load_image(os.path.join(img_folder, start_img_name))
                    end_img = self._load_image(os.path.join(img_folder, end_img_name))

                    tuple_time_skip = (int(end_year) - int(start_year)) * 12 + int(end_month) - int(start_month)
                    tuples.append((start_img, end_img, tuple_time_skip))
        return tuples, len(tuples)

In [None]:
class Layer:
    """Plain container that pairs a ``weights`` value with a ``biases`` value."""

    def __init__(self, weights, biases):
        """Store both arguments unchanged as instance attributes."""
        self.weights, self.biases = weights, biases

In [49]:
def rotation_transform(image):
    """Rotate ``image`` by a random whole number of degrees.

    The angle is drawn with ``random.randint(-90, 90)`` (both ends included)
    and the result of ``image.rotate(angle)`` is returned.
    """
    return image.rotate(random.randint(-90, 90))

In [51]:
# Build the whole dataset up front: every image is randomly rotated when it is
# read, then passed through ToTensor().
dataset = CustomImageDataset(
    path="Homework Dataset",
    transform=rotation_transform,
    target_transform=transform.ToTensor(),
)

In [44]:
# Shuffled mini-batches of 32 samples drawn from the full dataset.
dataloader: DataLoader = DataLoader(dataset, shuffle=True, batch_size=32)

In [45]:
# Sanity check: look at the first batch only and show its time-skip labels.
for first_batch in dataloader:
    starts, ends, time_skips = first_batch
    print(time_skips)
    break

tensor([ 5,  5,  1,  6,  5,  8,  4,  3,  3,  3,  3,  1,  1,  2,  2,  3,  2,  3,
         1,  3,  1,  1,  4,  9,  4,  2,  2,  2,  9,  1, 10, 12])


In [55]:
train_ratio = 0.7
validate_ratio = 0.15
test_ratio = 0.15

train_split_index = int(len(dataset) * train_ratio)
val_split_index = int(len(dataset) * (train_ratio + validate_ratio))

# Subset sizes are the same as cutting the dataset at the two indices above.
train_size = train_split_index
val_size = val_split_index - train_split_index
test_size = len(dataset) - val_split_index

# The dataset is filled one time-skip group after another (1, 2, ..., 12), so
# contiguous slices would hand validation and test only the last groups.
# Draw the three subsets at random instead; the fixed seed keeps the split
# identical across re-runs.
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(0),
)

train_dataloader : DataLoader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader : DataLoader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_dataloader : DataLoader = DataLoader(test_dataset, batch_size=32, shuffle=True)

print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))

2776
595
595


In [56]:
def run(epochs):
    """Print the time-skip labels of the first training batch, once per epoch.

    Reads the module-level ``train_dataloader``; only the first batch of each
    epoch is consumed before moving on to the next epoch.
    """
    for _epoch_index in range(epochs):
        for _starts, _ends, batch_time_skips in train_dataloader:
            print(batch_time_skips)
            break

In [None]:
# Single randomly initialised layer: a 16384 x 200 weight matrix and a
# 200-element bias vector (16384 presumably matches the flattened input size — TODO confirm).
net = [Layer(weights=torch.randn(16384, 200), biases=torch.randn(200))]