In [1]:
# Imports
import pathlib
import csv

import random
import numpy as np
#from tqdm.auto import tqdm
from PIL import Image

#from IPython.display import display
import matplotlib.pyplot as plt

import torch
#import torch.nn as nn
#import torch.optim as optim

#from torchvision.utils import make_grid
import torchvision.transforms as transforms
#from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, TensorDataset

In [None]:
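# Reference code from Mobin for this task, kept commented out (not executed). To use it, `os`, `pandas as pd` and `torch.utils.data.Dataset` would also need to be imported.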
# class CustomImageDataset(Dataset):
#     def __init__(self, root_dir, transform=None):
#         """
#         Args:
#             root_dir (string): Directory with all the subdirectories containing images and CSV files.
#             transform (callable, optional): Optional transform to be applied on a sample.
#         """
#         self.root_dir = root_dir
#         self.transform = transform
#         self.image_paths = []
#         self.labels = []
#         self._load_data()

#     def _load_data(self):
#         # Iterate over each subdirectory
#         for sub_dir in os.listdir(self.root_dir):
#             sub_dir_path = os.path.join(self.root_dir, sub_dir)
#             if os.path.isdir(sub_dir_path):
#                 # Load the CSV file
#                 csv_file = [file for file in os.listdir(sub_dir_path) if file.endswith('.csv')]
#                 if len(csv_file) == 1:
#                     csv_path = os.path.join(sub_dir_path, csv_file[0])
#                     df = pd.read_csv(csv_path)

#                     # Iterate over each row in the CSV
#                     for _, row in df.iterrows():
#                         image_path = os.path.join(sub_dir_path, row['filename'])
#                         label = row['label']
#                         if os.path.isfile(image_path):
#                             self.image_paths.append(image_path)
#                             self.labels.append(label)
#                 else:
#                     print(f"Error: More than one or no CSV files found in {sub_dir_path}")

#     def __len__(self):
#         return len(self.image_paths)

#     def __getitem__(self, idx):
#         if torch.is_tensor(idx):
#             idx = idx.tolist()

#         img_path = self.image_paths[idx]
#         image = Image.open(img_path).convert('RGB')
#         label = self.labels[idx]

#         if self.transform:
#             image = self.transform(image)

#         return image, torch.tensor(label, dtype=torch.float32)

In [2]:
# @markdown `shuffle-split and dataset constructor`
def shuffle_and_split_data(X, y, seed, split_percent=0.2):
  """
  Shuffle paired data and split it into a held-out ("test") part and a remainder.

  Adapted from the W1D3 tutorial helper. X and y are permuted with the same
  random permutation, so every sample stays paired with its label.

  Note: this reseeds torch's *global* RNG via ``torch.manual_seed(seed)``, which
  also affects any torch random calls made afterwards.

  Args:
    X: torch.tensor
      Input data; samples are indexed along dim 0.
    y: torch.tensor
      Corresponding targets; must have the same length along dim 0 as X.
    seed: int
      Seed for the shuffle, for reproducibility.
    split_percent: float, default 0.2
      Fraction of samples, in [0, 1], placed in the first ("test") split.
      The split size is ``int(split_percent * N)``, i.e. rounded down.

  Returns:
    X_test: torch.tensor
      The first ``int(split_percent * N)`` shuffled samples of X.
    y_test: torch.tensor
      Labels corresponding to X_test.
    X_train: torch.tensor
      The remaining shuffled samples of X.
    y_train: torch.tensor
      Labels corresponding to X_train.

  Raises:
    ValueError: if X and y differ in length, or split_percent is outside [0, 1].
  """
  n_samples = X.shape[0]
  # Without this check a longer y would be silently truncated by the indexing below.
  if y.shape[0] != n_samples:
    raise ValueError(f"X has {n_samples} samples but y has {y.shape[0]}")
  if not 0.0 <= split_percent <= 1.0:
    raise ValueError(f"split_percent must be in [0, 1], got {split_percent}")

  # Seed the global RNG so the permutation is reproducible.
  torch.manual_seed(seed)
  shuffled_indices = torch.randperm(n_samples)
  X = X[shuffled_indices]
  y = y[shuffled_indices]

  test_size = int(split_percent * n_samples)
  return X[:test_size], y[:test_size], X[test_size:], y[test_size:]


def _encode_label(text, width):
    """Encode one label string as a float vector of length ``width``.

    Each character c maps to ``4 * (ord(c) - 96)`` ('a' -> 4 ... 'z' -> 104), and a
    space maps to 0 instead of -256. Characters outside 'a'-'z' and space produce
    values outside that range. Strings shorter than ``width`` are right-padded
    with 0, the same code as a space.
    """
    codes = np.zeros(width)
    for position, char in enumerate(text):
        codes[position] = 0 if char == " " else 4 * (ord(char) - 96)
    return codes


def _load_image(path):
    """Load one image as a float tensor (C, H, W) with inverted intensities ``1 - pixel / 255``."""
    # The context manager closes the file handle once the pixels are copied into the tensor.
    with Image.open(path) as im:
        return 1. - transforms.PILToTensor()(im).float() / 255.


def get_datasets(images_pathfull, image_name_prefix, CSVlabels_pathfull, N_samples, percent_test, percent_valid, transform=None, seed=None):
    """
    Build shuffled train/test/validation TensorDatasets from numbered images and a label CSV.

    Row ``i`` of the CSV (its first column) is the label of the image named
    ``image_name_prefix + str(i) + ".png"`` inside ``images_pathfull``. ``N_samples``
    rows are drawn at random (reproducibly, from ``seed``) and then split.

    Images are rescaled to [0, 1] and inverted (``1 - pixel / 255``). They are stacked
    into one tensor of shape (N_samples, C, H, W), so all images must share one shape.
    Labels are encoded per character as ``4 * (ord(c) - 96)``, with spaces -> 0.
    Shorter labels are zero-padded to the longest selected label. The result is a
    float64 tensor of shape (N_samples, longest_label_length).

    Args:
        images_pathfull: str or path-like
            Folder that contains all the images.
        image_name_prefix: str
            Image filenames are ``image_name_prefix + "<CSV row index>" + ".png"``.
        CSVlabels_pathfull: str or path-like
            CSV file whose first column holds the label strings, one row per image.
        N_samples: int
            Total number of samples to draw; must be >= 1 and <= the number of CSV rows.
        percent_test: float
            Fraction of N_samples moved to the test set, in [0, 1).
        percent_valid: float
            Fraction of N_samples moved to the validation set, in [0, 1 - percent_test].
        transform: callable, optional
            Extra transform, applied once to the whole stacked image tensor (not per image).
        seed: int, optional
            Seed for sampling and splitting. If omitted, the global ``SEED`` is used when
            it is defined, otherwise 0.

    Returns:
        train_data: torch.TensorDataset
            Shuffled remainder, about N_samples*(1-percent_test-percent_valid) samples.
        test_data: torch.TensorDataset
            Shuffled, int(N_samples*percent_test) samples.
        valid_data: torch.TensorDataset
            Shuffled, about N_samples*percent_valid samples (split sizes are rounded down).

    Raises:
        ValueError: if N_samples or the split fractions are invalid.
    """
    if seed is None:
        # Backward compatibility: this function originally read the global SEED.
        seed = globals().get("SEED", 0)
    if not 0 <= percent_test < 1:
        raise ValueError(f"percent_test must be in [0, 1), got {percent_test}")
    if not 0 <= percent_valid <= 1 - percent_test:
        raise ValueError(
            f"percent_valid must be in [0, 1 - percent_test], got {percent_valid}")
    if N_samples < 1:
        raise ValueError(f"N_samples must be >= 1, got {N_samples}")

    # LABELS: read the CSV, then draw N_samples random row indices reproducibly.
    with open(CSVlabels_pathfull, newline='') as csvfile:
        csv_rows = list(csv.reader(csvfile))
    if len(csv_rows) < N_samples:
        raise ValueError(
            f"requested N_samples={N_samples}, but {CSVlabels_pathfull} has only {len(csv_rows)} rows")
    generator = torch.Generator().manual_seed(seed)
    # Plain ints (not 0-d tensors), so they format cleanly into filenames.
    sample_indices = torch.randperm(len(csv_rows), generator=generator)[:N_samples].tolist()

    label_strings = [csv_rows[i][0] for i in sample_indices]
    width = max(len(text) for text in label_strings)
    y_tensor = torch.from_numpy(np.stack([_encode_label(text, width) for text in label_strings]))

    # IMAGES: one (C, H, W) tensor per sample, stacked once (not concatenated in a loop).
    images_dir = pathlib.Path(images_pathfull)
    im_tensor = torch.stack([
        _load_image(images_dir / f"{image_name_prefix}{i}.png") for i in sample_indices
    ])
    if transform is not None:
        im_tensor = transform(im_tensor)

    # Carve the test set out of everything. Then carve the validation set out of the
    # REMAINDER, rescaling its fraction so that it is ~percent_valid of N_samples.
    X_test, y_test, X_rest, y_rest = shuffle_and_split_data(
        X=im_tensor, y=y_tensor, seed=seed, split_percent=percent_test)
    valid_fraction_of_rest = min(1.0, percent_valid / (1 - percent_test))
    X_valid, y_valid, X_train, y_train = shuffle_and_split_data(
        X=X_rest, y=y_rest, seed=seed, split_percent=valid_fraction_of_rest)

    train_data = TensorDataset(X_train, y_train)
    test_data = TensorDataset(X_test, y_test)
    valid_data = TensorDataset(X_valid, y_valid)
    return train_data, test_data, valid_data


In [None]:
# Usage example: build the datasets, then wrap each one in a DataLoader.

# NOTE(review): `SEED` and `seed_worker` were used in this notebook but never defined
# in it, so a fresh kernel raised NameError. They are defined here.
SEED = 2021


def seed_worker(worker_id):
    """Seed numpy and `random` inside each DataLoader worker from torch's per-worker seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# One place to change when running on another machine.
PROJECT_DIR = pathlib.Path("/home/workstation319/Desktop/python/NMA codes and stuffs/project")

train_data, test_data, valid_data = get_datasets(
    images_pathfull=PROJECT_DIR / "temp",
    image_name_prefix="base_img",
    CSVlabels_pathfull=PROJECT_DIR / "labels.csv",
    N_samples=int(1e4),
    percent_test=0.1,
    percent_valid=0.2,
    transform=None,
)

# We can use the standard torch DataLoader on these TensorDatasets.
batch_size = 256
g_seed = torch.Generator()
g_seed.manual_seed(SEED)


def make_loader(dataset, train):
    """Build a DataLoader.

    Only the training loader shuffles and drops the last partial batch. Evaluation
    loaders keep every sample, in a fixed order.
    """
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=train,
                      drop_last=train,
                      worker_init_fn=seed_worker,
                      generator=g_seed)


train_loader = make_loader(train_data, train=True)
test_loader = make_loader(test_data, train=False)
valid_loader = make_loader(valid_data, train=False)