In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
"""
Generate a CSV file containing matched pairs of geometries and flow field file names.
"""

import os
path = Path("spacio_training_2/processed")
all_files = os.listdir(path)
geom_files = [x for x in all_files if "geom" in x]
cfd_files = [x for x in all_files if "_2.npy" in x]
# cfd_files.pop() # Drop mask file

print(f"Geometry files: {len(geom_files)}, Flow files: {len(cfd_files)}")

df = pd.DataFrame({'X':geom_files, 'Y':cfd_files})

def match(X, Y):
    """Return True if geometry file name ``X`` pairs with flow file name ``Y``.

    Both names are split on underscores. They pair when their first fields are
    equal and the second field of ``X`` equals the third field of ``Y``.
    A name with too few underscore-separated fields never pairs. It returns
    False rather than raising IndexError, so the caller's mismatch check
    reports it.
    """
    x_parts = X.split("_")
    y_parts = Y.split("_")
    if len(x_parts) < 2 or len(y_parts) < 3:
        return False
    return x_parts[0] == y_parts[0] and x_parts[1] == y_parts[2]

# Flag every geometry/flow row whose file names do not correspond.
matched = pd.Series(
    [match(x_name, y_name) for x_name, y_name in zip(df['X'], df['Y'])],
    index=df.index,
    dtype=bool,
)

if not matched.all():
    mismatched = df.loc[~matched]
    raise ValueError(
        f"Train/Test files matched incorrectly! {len(mismatched)} mismatched "
        f"pair(s), e.g.:\n{mismatched.head().to_string()}"
    )

df.to_csv(path/'../test_train.csv', index=False)

Geometry files: 2608, Flow files: 2608


In [93]:
"""
Create the data set. This uses a csv file with the names of all of the geometries
and flow patterns to index and find the appropriate file when called.
"""
class TestTrainData(Dataset):
    def __init__(self, root_dir, dir_file) -> None:
        self.root_dir = root_dir
        self.file_directory = pd.read_csv(dir_file)

    def __len__(self):
        return len(self.file_directory)
    
    def __getitem__(self, idx):
        x = np.load(path/self.file_directory.iloc[idx].X)
        y = np.load(path/self.file_directory.iloc[idx].Y)
        x_batch = torch.tensor(x).permute(2, 0, 1)
        y_batch = torch.tensor(y).permute(2, 0, 1)
        return x_batch, y_batch
    
dataset = TestTrainData(path, path/"../test_train_matrix.csv")

# 80/20 test/train split.
train_set, test_set = random_split(dataset, [int(len(dataset) * 0.8), len(dataset) - int(len(dataset) * 0.8)])

# On Windows num_workers must be set to 0.
train_loader = DataLoader(train_set,
                          batch_size=1, shuffle=True,
                          num_workers=0)

test_loader = DataLoader(test_set,
                         batch_size=1, shuffle=True,
                         num_workers=0)

for i, data in enumerate(train_loader, 0):
    X, Y = data
    print(X.shape, Y.shape)
    break


torch.Size([1, 3, 1024, 1024]) torch.Size([1, 3, 1024, 1024])
