In [1]:
from copy import deepcopy
import yaml
from pathlib import Path
import torch
import shutil
import matplotlib.pyplot as plt

%reload_ext autoreload
%autoreload 2

# Cut boxes

In [2]:
dataset_name = "dataset_2d_small_100dp inputs_gksi"
# `paths` is the parsed paths.yaml mapping that every later cell reads
# (e.g. paths["datasets_prepared_dir"]). The open file handle gets its own name so
# it is not shadowed by the parsed dict.
with open("paths.yaml", "r") as paths_file:
    paths = yaml.safe_load(paths_file)
prepared1_dir = Path(paths["datasets_prepared_dir"]) / dataset_name

## Cut every datapoint into `number_boxes` equal boxes along tensor dimension 1

In [5]:
number_boxes = 4
prepared_pieces_dir = Path(paths["datasets_prepared_dir"]) / f"{dataset_name} cut_{number_boxes}pieces separate_boxes"
prepared_pieces_dir.mkdir(parents=True, exist_ok=True)
for box in range(number_boxes):
    (prepared_pieces_dir / f"Inputs Box {box}").mkdir(parents=True, exist_ok=True)
    (prepared_pieces_dir / f"Label Box {box}").mkdir(parents=True, exist_ok=True)
shutil.copy(prepared1_dir / "info.yaml", prepared_pieces_dir / "info.yaml")

# Pair every input file with its label file deterministically.
# `Path.iterdir()` yields entries in arbitrary, filesystem-dependent order, so zipping two
# unsorted listings can pair an input with the label of a *different* run (and the label box
# would then be saved under the wrong run name). Sort both listings and verify they line up.
input_files = sorted((prepared1_dir / "Inputs").iterdir())
label_files = sorted((prepared1_dir / "Labels").iterdir())
if len(input_files) != len(label_files):
    raise ValueError(
        f"{len(input_files)} input files but {len(label_files)} label files in {prepared1_dir}"
    )

for input_file, label_file in zip(input_files, label_files):
    # NOTE(review): assumes inputs and labels of one run share a file name (e.g. RUN_<id>.pt),
    # as the 2nd-level cell below also assumes for the cut boxes.
    if input_file.stem != label_file.stem:
        raise ValueError(f"Input/label mismatch: {input_file.name} vs {label_file.name}")
    input_tensor = torch.load(input_file)
    label_tensor = torch.load(label_file)
    name = input_file.stem

    # Width of one box along dim 1 (taken from the input). Any trailing columns beyond
    # number_boxes * len_box are dropped when the width is not divisible by number_boxes.
    len_box = input_tensor.shape[1] // number_boxes
    for i in range(number_boxes):
        box_slice = slice(i * len_box, (i + 1) * len_box)
        # .clone(): a slice is a view, and torch.save serializes the *entire* underlying
        # storage of a view - without cloning each box file would hold the full uncut tensor.
        torch.save(input_tensor[:, box_slice, :].clone(), prepared_pieces_dir / f"Inputs Box {i}" / f"{name}.pt")
        torch.save(label_tensor[:, box_slice, :].clone(), prepared_pieces_dir / f"Label Box {i}" / f"{name}.pt")

## Store the boxes as two datasets: 1st level (box 0 only) and 2nd level (box i as input → box i+1 as label)

In [12]:
# prepare 1st level: box 0 of every run, copied unchanged (inputs and labels).
prepared_dir_1stlevel = Path(paths["datasets_prepared_dir"]) / f"{dataset_name} cut_{number_boxes}pieces separate_boxes 1st level"
prepared_dir_1stlevel.mkdir(parents=True, exist_ok=True)

shutil.copy(prepared_pieces_dir / "info.yaml", prepared_dir_1stlevel / "info.yaml")
# dirs_exist_ok=True keeps this cell re-runnable: plain copytree raises FileExistsError as
# soon as the destination directory exists (e.g. on a second run of the notebook).
shutil.copytree(prepared_pieces_dir / "Inputs Box 0", prepared_dir_1stlevel / "Inputs", dirs_exist_ok=True)
shutil.copytree(prepared_pieces_dir / "Label Box 0", prepared_dir_1stlevel / "Labels", dirs_exist_ok=True)

In [30]:
# prepare 2nd level
# Each sample pairs box `box` of a run with box `box + 1` of the same run:
#   inputs = pressure gradient (g) and permeability (k) of box `box`, plus its temperature label,
#   label  = the label file of box `box + 1`.
prepared_dir_2ndlevel = Path(paths["datasets_prepared_dir"]) / f"{dataset_name} cut_{number_boxes}pieces separate_boxes 2nd level gkt"
prepared_dir_2ndlevel.mkdir(parents=True, exist_ok=True)
(prepared_dir_2ndlevel / "Inputs").mkdir(parents=True, exist_ok=True)
(prepared_dir_2ndlevel / "Labels").mkdir(parents=True, exist_ok=True)

# Build the 2nd-level info: the new inputs are all label fields plus g and k (with their
# original 1st-level channel indices); temperature is placed at channel 2.
with open(prepared_dir_1stlevel / "info.yaml", "r") as info_file:
    info = yaml.safe_load(info_file)
info_g = deepcopy(info["Inputs"]["Pressure Gradient [-]"])
info_k = deepcopy(info["Inputs"]["Permeability X [m^2]"])
info["Inputs"] = deepcopy(info["Labels"])
info["Inputs"]["Pressure Gradient [-]"] = info_g
info["Inputs"]["Permeability X [m^2]"] = info_k
info["Inputs"]["Temperature [C]"]["index"] = 2

idx_g = info["Inputs"]["Pressure Gradient [-]"]["index"]
idx_k = info["Inputs"]["Permeability X [m^2]"]["index"]
idx_t = info["Inputs"]["Temperature [C]"]["index"]
# The new input tensor has exactly 3 channels, so the indices must be a permutation of
# 0, 1, 2. Being pairwise distinct is not enough: e.g. an index of 3 would only fail later
# with an IndexError inside the loop.
assert {idx_g, idx_k, idx_t} == {0, 1, 2}, (
    f"indices of inputs must be a permutation of 0, 1, 2 (got g={idx_g}, k={idx_k}, t={idx_t})"
)

with open(prepared_dir_2ndlevel / "info.yaml", "w") as info_file:
    yaml.safe_dump(info, info_file)

# New file ids are file_id + box * id_offset, so samples from different boxes get distinct names.
id_offset = 1000
for box in range(number_boxes - 1):
    for file_in_temp in (prepared_pieces_dir / f"Label Box {box}").iterdir():
        file_id = int(file_in_temp.stem.split("_")[1])  # assumes names like "RUN_<id>"
        # An id >= id_offset would collide with an id of the next box and silently overwrite it.
        assert file_id < id_offset, f"file id {file_id} >= {id_offset} would collide across boxes"
        new_id = file_id + box * id_offset

        temp_in = torch.load(file_in_temp)
        file_inputs = prepared_pieces_dir / f"Inputs Box {box}" / f"RUN_{file_id}.pt"
        # Load the box's input tensor once and take the g and k channels from it.
        box_inputs = torch.load(file_inputs)
        g_in = box_inputs[idx_g]
        k_in = box_inputs[idx_k]
        inputs = torch.zeros([3, *g_in.shape])
        inputs[idx_g] = g_in
        inputs[idx_k] = k_in
        inputs[idx_t] = temp_in

        torch.save(inputs, prepared_dir_2ndlevel / "Inputs" / f"RUN_{new_id}.pt")

        # The label of this sample is the following box of the same run.
        file_label = prepared_pieces_dir / f"Label Box {box+1}" / f"RUN_{file_id}.pt"
        shutil.copy(file_label, prepared_dir_2ndlevel / "Labels" / f"RUN_{new_id}.pt")