# Visualize CryinGAN training

In [None]:
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt
from ase.io import read, write
from ase.visualize import view
import numpy as np
from glob import glob

In [None]:
# Results directory of the training run to inspect. Earlier experiments are
# kept below for reference; move the active assignment to switch.
# EXPERIMENT = "../CCGAN/results/4-with-dist-extreme-weight"
# EXPERIMENT = "../CCGAN/results/3-with-dist-high-weight-long"
EXPERIMENT = "../CCGAN/results/3-2-with-dist-high-weight-only-84"

# Last path component with a trailing slash; used as a plot title below.
EXPERIMENT_NAME = f"{EXPERIMENT.rsplit('/', 1)[-1]}/"



In [None]:
# Load the per-epoch losses. Rows whose "epoch" value is the literal string
# "epoch" are repeated CSV header lines; each one marks the start of another
# training run appended to the same file.
all_runs = pd.read_csv(EXPERIMENT+"/losses.csv") # With distance discriminator 2
run_starts = all_runs["epoch"] == "epoch"

# Label every row with its run index = number of header rows at positions
# 1..i (vectorised replacement for the former slow per-row .loc loop; data
# rows get the same labels). Position 0 is never treated as a run boundary.
header_rows = run_starts.to_numpy(copy=True)
header_rows[:1] = False
all_runs["run"] = np.cumsum(header_rows)

# Drop the header rows. .copy() makes this an independent frame so the column
# assignments below do not trigger pandas' SettingWithCopyWarning.
all_runs = all_runs[~run_starts].copy()
print("Number of runs: ", all_runs["run"].nunique())
all_runs["epoch"] = all_runs["epoch"].astype(int)
all_runs.set_index("epoch", inplace=True)
all_runs = all_runs.astype(float)

# Plot every logged loss column against epoch. The bookkeeping "run" label is
# not a loss, so it is left out of the plot (it stays in all_runs).
fig, ax = plt.subplots(figsize=(20,10))
sns.lineplot(data=all_runs.drop(columns="run"), ax=ax, marker="o")
# ax.grid(which="both")
ax.set_title(EXPERIMENT_NAME)

In [None]:
# Learning-rate log, indexed by epoch; every remaining column (e.g.
# generator_lr, coord_disc_lr, dist_disc_lr) is cast to float and plotted.
lr_df = (
    pd.read_csv(f"{EXPERIMENT}/learning_rate.csv")
    .set_index("epoch")
    .astype(float)
)

fig, ax = plt.subplots(figsize=(20, 10))
sns.lineplot(data=lr_df, marker="o", ax=ax)

## Check on the generated structures

Load the NumPy coordinate arrays (`gen_coords_*.npy`) saved during training.

In [None]:
# Reference structure: frame index 1 of the phi-0.84 samples file.
ref_struct = read("../data/processed/samples/phi-0.84/samples.extxyz", index=1, format="extxyz")

# 2D projection of the fractional x/y positions, coloured and sized by the
# per-atom 'rmt' array (presumably a radius — verify against the data format).
atoms = ref_struct.copy()
radii = atoms.get_array('rmt')
scaled = atoms.get_scaled_positions()

plt.figure()
plt.scatter(
    scaled[:, 0],
    scaled[:, 1],
    c=radii,
    s=radii*20,
    alpha=0.5,
)

# Keep the 3D view as the cell's LAST expression: Jupyter only renders the
# value of the final expression, so calling view() earlier in the cell
# silently discarded the x3d widget.
view(ref_struct,  viewer="x3d", show_unit_cell=True, repeat=(1,1,1))

In [None]:


def _snapshot_number(path):
    """Return N from a snapshot path ending in ``gen_coords_<N>.npy``.

    Splits on the last underscore rather than on "/", so it works with any
    path separator.
    """
    return int(path.rsplit("_", 1)[-1].split(".")[0])


# Most recent snapshot of generated coordinates, chosen by numeric N
# (a plain string sort would put gen_coords_100 before gen_coords_20).
snapshot_files = glob(f"{EXPERIMENT}/gen_coords_*.npy")
if not snapshot_files:
    raise FileNotFoundError(f"No gen_coords_*.npy files found in {EXPERIMENT}")
latest_generated_sample = max(snapshot_files, key=_snapshot_number)
print("Latest generated sample: ", latest_generated_sample)

file = latest_generated_sample
# Or set manually
# file = f"{EXPERIMENT}/gen_coords_60.npy"

np_coords = np.load(file)
atoms = ref_struct.copy()

# Pick a couple of structures to visualize
idxs = [0, 1, 2]

radii = atoms.get_array('rmt')

for idx in idxs:
    # Plot sample `idx`. This was hard-coded to np_coords[3, 0], so every
    # figure showed the same structure. .squeeze() drops singleton axes,
    # matching how frames are built for the video below.
    atoms.set_scaled_positions(np_coords[idx].squeeze())
    scaled = atoms.get_scaled_positions()
    plt.figure()
    plt.scatter(
        scaled[:, 0],
        scaled[:, 1],
        c=radii,
        s=radii*20,
        alpha=0.5,
    )
    plt.title(f"Generated sample {idx}")

## Make a video


In [None]:
idx = 2 # Pick one generated sample


def _frame_number(path):
    """Return N from a snapshot path ending in ``gen_coords_<N>.npy``."""
    number_part = path.split("_")[-1]  # "<N>.npy"
    return int(number_part.split(".")[0])


# Snapshot files in numeric (not lexicographic) order of N.
file_list = sorted(glob(f"{EXPERIMENT}/gen_coords_*.npy"), key=_frame_number)

# One frame per snapshot: the reference cell carrying sample `idx`'s
# fractional coordinates.
atoms_list = []
for file in file_list:
    np_coords = np.load(file)
    atoms = ref_struct.copy()
    atoms.set_scaled_positions(np_coords[idx].squeeze())
    atoms_list.append(atoms)
len(file_list)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.animation import FuncAnimation


# Pick a good cmap
plt.set_cmap("viridis")


lr_df = pd.read_csv(f"{EXPERIMENT}/learning_rate.csv")
lr_df.set_index("epoch", inplace=True)
epochs = lr_df.index.values
gen_lr = lr_df['generator_lr'].values
coord_lr = lr_df['coord_disc_lr'].values
dist_lr = lr_df['dist_disc_lr'].values

# Generator cost aligned one-to-one with the learning-rate epochs. all_runs
# can contain the same epoch several times (concatenated runs) or lack some
# epochs; a plain .loc lookup then returns the wrong length or raises.
# Collapse duplicates by keeping the last logged value (assumes later runs
# supersede earlier ones — verify) and reindex so missing epochs become NaN.
cost_gen = (
    all_runs["cost_gen"]
    .groupby(level=0)
    .last()
    .reindex(lr_df.index)
    .values
)

# Epoch shown in each frame, parsed from the snapshot file names in the same
# order atoms_list was built from file_list. This replaces the hard-coded
# "20 epochs per frame" assumption, which mislabelled frames (or raised
# IndexError) whenever snapshots were saved at another interval.
# Assumes N in gen_coords_<N>.npy is the training epoch.
frame_epochs = [int(f.split("_")[-1].split(".")[0]) for f in file_list]

n_frames = len(atoms_list)

fig = plt.figure(figsize=(20, 15))
gs  = gridspec.GridSpec(2, 2, height_ratios=[3, 1], hspace=0.3)

# top row: two scatter plots
ax_gen = fig.add_subplot(gs[0, 0])
ax_ref = fig.add_subplot(gs[0, 1])

# bottom row: learning-rate curve spanning both columns
ax_lr  = fig.add_subplot(gs[1, :])

# Second y-axis for the generator cost, created ONCE. Calling twinx() inside
# update() stacked a new, never-cleared axes onto the figure every frame.
ax_cost_gen = ax_lr.twinx()


def _scatter_positions(ax, atoms):
    """Scatter the fractional x/y positions of `atoms` on `ax`, coloured and sized by its 'rmt' array."""
    radii = atoms.get_array("rmt")
    scaled = atoms.get_scaled_positions()
    ax.scatter(scaled[:, 0], scaled[:, 1], c=radii, s=radii*20, alpha=0.5)
    # Same limits on every panel so generated and reference are comparable.
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)


# — Static content, drawn once: reference structure and LR / cost curves —
_scatter_positions(ax_ref, ref_struct)
ax_ref.set_title("Reference")

ax_lr.plot(epochs, gen_lr, label="Generator LR",    linewidth=2)
ax_lr.plot(epochs, coord_lr, label="Coord Disc LR", linewidth=2)
ax_lr.plot(epochs, dist_lr, label="Dist Disc LR",   linewidth=2)

# Black rather than 'orange': the default colour cycle already draws
# "Coord Disc LR" in orange, which made the two lines indistinguishable.
ax_cost_gen.plot(epochs, cost_gen, label="Generator cost", color='black', linewidth=2)
ax_cost_gen.set_ylabel("Generator cost")

ax_lr.set_xlabel("Epoch")
ax_lr.set_ylabel("Learning rate")
ax_lr.set_title("Learning‐rate schedules")
ax_lr.set_xlim(epochs.min(), epochs.max())
ax_lr.set_ylim(
    min(gen_lr.min(), coord_lr.min(), dist_lr.min()),
    max(gen_lr.max(), coord_lr.max(), dist_lr.max())
)

# Concatenate the legend entries from both axes
lines, labels = ax_lr.get_legend_handles_labels()
lines2, labels2 = ax_cost_gen.get_legend_handles_labels()
ax_lr.legend(lines + lines2, labels + labels2, loc='upper right')

# Moving marker for the epoch of the current frame; update() only shifts it.
epoch_marker = ax_lr.axvline(
    frame_epochs[0] if frame_epochs else epochs.min(),
    color='red', linestyle='--', linewidth=2,
)


def update(i):
    """Draw animation frame `i`: generated structure i and the epoch marker.

    Args:
        i: frame index into atoms_list / frame_epochs.
    """
    current_epoch = frame_epochs[i]

    # — Top left: generated structure for this snapshot —
    ax_gen.clear()
    _scatter_positions(ax_gen, atoms_list[i])
    ax_gen.set_title("Generated")

    fig.suptitle(f"Generated vs Reference — epoch {current_epoch}")

    # — Bottom: move the red epoch line —
    epoch_marker.set_xdata([current_epoch, current_epoch])


# — Create and show the animation —
ani = FuncAnimation(fig, update, frames=n_frames, interval=200)

In [None]:
from matplotlib.animation import writers

# GIF export via Pillow.
ani.save(f"{EXPERIMENT}/training.gif", writer='pillow', fps=0.5)

# MP4 export needs the ffmpeg writer. Check for it explicitly instead of a
# bare `except:`, which also swallowed unrelated failures (bad path,
# KeyboardInterrupt, encoder errors) behind the "ffmpeg not installed" message.
if writers.is_available("ffmpeg"):
    ani.save(f"{EXPERIMENT}/training.mp4", writer='ffmpeg', fps=1, dpi=72) # Low dpi for fast export
else:
    print("ffmpeg not installed, skipping mp4 export")


# Next steps

1. Provide Anshul with generated samples (50 each) in our data format
   1. Low packing fraction 0.70
   2. Mid 0.78
   3. High packing fraction 0.84
   4. Very high 0.86
2. Special assignment: try out different loss functions (or other approaches) to make the generated structures more physically plausible
   1. Adding the bond distance discriminator
   2. Radius / overlap loss
   3. NN distance based loss
   4. Hexatic (k=5) order loss
   5. Other physically motivated losses
3. (Later: Conditioning on phi / other descriptors)