In [1]:
# IMPORT PACKAGES

import os.path
import torch
import matplotlib.pyplot as plt
import functions as fs
import PRIMME
from pathlib import Path

# Pick the compute device: the first CUDA GPU when one is visible to PyTorch,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
# --- Test-case configuration -------------------------------------------------
nsteps = 800
grain_shape = "grain"  # other accepted values: "circle", "hex", "square"
grain_sizes = [[512, 512], 512]  # also tried: 257x257, 1024x1024, 2048x2048, 2400x2400

# Label describing the initial condition. "hex" uses a fixed label; every other
# shape encodes the three numbers held in grain_sizes.
if grain_shape == "hex":
    ic_shape = "hex"
else:
    ic_shape = f"{grain_shape}({grain_sizes[0][0]}_{grain_sizes[0][1]}_{grain_sizes[1]})"

# --- Cached initial-condition file --------------------------------------------
filename_test = f"{ic_shape}.pickle"
path_load = Path('./data').joinpath(filename_test)

# --- Initial conditions and misorientation data -------------------------------
# Build the data when no cached pickle exists; otherwise read it back.
# NOTE(review): generate_train_init presumably also writes the pickle so the
# next run can load it - confirm in functions.py.
if not path_load.is_file():
    ic, ea, miso_array, miso_matrix = fs.generate_train_init(filename_test, grain_shape, grain_sizes, device)
else:
    data_dict = fs.load_picke_files(load_dir=Path('./data'), filename_save=filename_test)
    ic, ea, miso_array, miso_matrix = (
        data_dict[key] for key in ("ic", "ea", "miso_array", "miso_matrix")
    )

data/grain(512_512_512).pickle start to be loaded

data/grain(512_512_512).pickle has been created



In [3]:
# Training data for PRIMME. This is the only hard-coded input file the
# notebook needs: the model location and the simulation output path are
# produced by the training and run cells further down.
trainset = (
    "./data/trainset_spparks_sz(257x257)_ng(256-256)_nsets(200)_future(4)_max(100)_kt(0.66)_cut(0).h5"
)
# Example paths kept for reference only (not executed):
# model_location = "./data/model_dim(2)_sz(17_17)_lr(5e-05)_reg(1)_ep(1000)_kt(0.66)_cut(0).h5"
# fp_primme = "./data/primme_shape(grain(512_512_512))_model_dim(2)_sz(17_17)_lr(5e-05)_reg(1)_ep(1000)_kt(0.66)_cut(0).h5"

In [None]:
# Train PRIMME on the SPPARKS training set defined above.
# The returned value is kept as `model_location` and later handed to
# PRIMME.run_primme as `modelname`. `n_step` reuses the `nsteps` value from the
# test-case configuration cell.
model_location = PRIMME.train_primme(
    trainset,
    n_step=nsteps,
    n_samples=200,
    mode="Single_Step",
    num_eps=100,
    dims=2,
    obs_dim=17,
    act_dim=17,
    lr=5e-5,
    reg=1,
    pad_mode="circular",
    if_plot=False,
)

<KeysViewHDF5 ['euler_angles', 'ims_energy', 'ims_id', 'miso_array', 'num_features_17', 'num_features_25', 'num_features_35', 'num_features_49']>


  features = my_unfoldNd(local_energy.float(), obs_dim, pad_mode=pad_mode).T.reshape((np.prod(size),)+(obs_dim,)*(len(size)-1))


In [None]:
# Run the trained PRIMME model on the initial condition for `nsteps` steps.
# Two values come back: `ims_id` (unused below in this notebook) and
# `fp_primme`, which the statistics/plotting cell consumes.
ims_id, fp_primme = PRIMME.run_primme(
    ic,
    ea,
    miso_array,
    miso_matrix,
    nsteps=nsteps,
    ic_shape=ic_shape,
    modelname=model_location,
    pad_mode='circular',
    if_plot=False,
)


In [None]:
# Post-process the PRIMME run referenced by `fp_primme` (returned by
# PRIMME.run_primme) and generate plots.
# NOTE(review): compute_grain_stats presumably has to run before the two
# plotting helpers so the statistics they use exist - confirm in functions.py.
fs.compute_grain_stats(fp_primme)
fs.make_videos(fp_primme, ic_shape=ic_shape) # saves to 'plots'
fs.make_time_plots(fp_primme, ic_shape=ic_shape) # saves to 'plots'