In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm_notebook
from tqdm.auto import tqdm

from models import Generator
from minkowski import compute_minkowski
from utils import fix_random_seed, postprocess_cube, two_point_correlation

# Per-realisation generator seeds. A dedicated, seeded sampler makes the seed
# list reproducible across runs. Sampling without replacement keeps every seed
# unique, so no per-seed covariance CSV is overwritten and no realisation is
# counted twice in the Minkowski statistics.
NUM_SAMPLES = 1000
SEED_POOL_SIZE = 3 * NUM_SAMPLES
SEED_SAMPLER_SEED = 42
seeds = np.random.default_rng(SEED_SAMPLER_SEED).choice(
    SEED_POOL_SIZE, size=NUM_SAMPLES, replace=False
)

# All inputs and outputs of this analysis live under one experiment directory.
EXPERIMENT_DIR = Path('../../experiments') / 'original-berea'
checkpoint_path = EXPERIMENT_DIR / 'berea_generator_epoch_24.pth'
covariance_dir = EXPERIMENT_DIR / 'covariance_stats'
covariance_dir.mkdir(parents=True, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# These architecture hyper-parameters must match the checkpoint loaded below.
# z_dim must also match the latent size sampled in the generation loop (512).
net_g = Generator(
    img_size=64,
    z_dim=512,
    num_channels=1,
    num_filters=64
).to(device)

In [2]:
from collections import OrderedDict
# Rename the checkpoint's parameters onto this notebook's Generator. The
# checkpoint stores layer <idx> under "main.<idx>"; the Generator expects
# "net.<idx>". map_location puts the tensors on `device`, so a checkpoint saved
# from a GPU run still loads when only a CPU is available.
new_dict = OrderedDict()
pre_dict = torch.load(checkpoint_path, map_location=device)

# Layers that are copied as a single weight tensor.
WEIGHT_ONLY_LAYERS = (0, 3, 6, 9, 12)
# Layers that are copied with affine parameters plus running statistics.
NORM_LAYERS = (1, 4, 7, 10)
NORM_PARAMS = ('weight', 'bias', 'running_mean', 'running_var')

# Ascending layer order gives new_dict the same key order as a hand-written
# mapping. BatchNorm num_batches_tracked buffers are not copied.
for layer_idx in sorted(WEIGHT_ONLY_LAYERS + NORM_LAYERS):
    param_names = NORM_PARAMS if layer_idx in NORM_LAYERS else ('weight',)
    for param_name in param_names:
        new_key = 'net.{}.{}'.format(layer_idx, param_name)
        old_key = 'main.{}.{}'.format(layer_idx, param_name)
        new_dict[new_key] = pre_dict[old_key]

In [3]:
# Load the remapped weights. strict=True (PyTorch's default) raises if any key
# is missing or unexpected. Binding the returned key report to a name stops the
# notebook from echoing it.
# NOTE(review): net_g.eval() is not called anywhere in this notebook section. If
# net_g is still in its default training mode, the BatchNorm-style layers will
# normalise with per-sample batch statistics during generation instead of the
# loaded running_mean/running_var. Confirm this is intended.
load_report = net_g.load_state_dict(new_dict, strict=True)

In [5]:
# For each seed, generate one realisation and:
#   - record its Minkowski functionals in `data` (saved to seeds_analyze.csv);
#   - write its directional two-point covariance to covariance_dir/seed_<seed>.csv.
data = {
    'V': [],
    'S': [],
    'B': [],
    'Xi': []
}
for seed in tqdm(seeds, desc="Generate iteration"):
    _ = fix_random_seed(seed)
    noise = torch.randn(1, 512, 1, 1, 1, device=device)
    # Inference only: no_grad stops autograd from recording a graph on every
    # forward pass. That graph would be discarded anyway.
    with torch.no_grad():
        cube = net_g(noise).squeeze().cpu()
    # x -> 0.5 * x + 0.5 maps [-1, 1] onto [0, 1]. This presumably matches the
    # generator's output range; confirm against models.Generator.
    cube = cube.mul(0.5).add(0.5).numpy()
    cube = postprocess_cube(cube)
    # Add one voxel of zeros on every face before computing the Minkowski functionals.
    cube = np.pad(cube, ((1, 1), (1, 1), (1, 1)), mode='constant', constant_values=0)
    v, s, b, xi = compute_minkowski(cube)
    data['V'].append(v)
    data['S'].append(s)
    data['B'].append(b)
    data['Xi'].append(xi)

    # Compute the two-point correlation along each axis, passing the cube
    # maximum as `var` (presumably the grain phase label; confirm against
    # utils.two_point_correlation). Each result is then averaged over its first
    # two axes ("phase averaging").
    # NOTE(review): this uses the zero-padded cube, so the border added for the
    # Minkowski computation also enters the covariance. Confirm this is intended.
    grain_value = cube.max()
    direct_covariances = {}
    for axis, direct in enumerate(["x", "y", "z"]):
        two_point_direct = two_point_correlation(cube, axis, var=grain_value)
        direct_covariances[direct] = np.mean(np.mean(two_point_direct, axis=0), axis=0)
    # Store this seed's covariance.
    covariance_df = pd.DataFrame(direct_covariances)
    covariance_df.to_csv(covariance_dir / "seed_{}.csv".format(seed), index=False)

df = pd.DataFrame(data)
df.to_csv(Path('../../experiments') / 'original-berea' / 'seeds_analyze.csv', index=False)

HBox(children=(IntProgress(value=0, description='Generate iteration', max=1000, style=ProgressStyle(descriptio…


