# Visualize CryinGAN training

In [None]:
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt
import ase
from ase.io import read, write
from ase.visualize import view

In [None]:
# Results directory of the training run to inspect; the cells below read losses,
# learning rates and generated-coordinate snapshots from here.
EXPERIMENT = "/".join(["..", "CCGAN", "results", "3-with-dist-high-weight"])

In [None]:
# Read the losses CSV. Rows whose "epoch" cell is the literal string "epoch"
# (a repeated header) mark the start of a new training run.
all_runs = pd.read_csv(EXPERIMENT + "/losses.csv")  # With distance discriminator 2
run_starts = all_runs["epoch"] == "epoch"

# Label every row with its run number = number of repeated header rows seen so far.
# Vectorized; the previous per-row loop with `.loc[i:, "run"] = ...` was O(n^2).
all_runs["run"] = run_starts.cumsum()

# Remove the repeated header rows; .copy() so the assignment below does not
# trigger pandas' SettingWithCopyWarning on a filtered view.
all_runs = all_runs[~run_starts].copy()
print("Number of runs: ", all_runs["run"].max() + 1)

all_runs["epoch"] = all_runs["epoch"].astype(int)
all_runs.set_index(["run", "epoch"], inplace=True)
all_runs = all_runs.astype(float)

# Read and visualize losses csv

run_idx = 0  # CHANGE THIS TO SELECT A RUN
df = all_runs.loc[run_idx]

fig, ax = plt.subplots(figsize=(20, 10))
sns.lineplot(data=df, ax=ax, marker="o")
# plt.grid(which="both")

In [None]:
# Read the learning-rate CSV, indexed by epoch, with every schedule column as float.
lr_df = pd.read_csv(f"{EXPERIMENT}/learning_rate.csv")
lr_df.set_index("epoch", inplace=True)
lr_df = lr_df.astype(float)

# Visualize the learning-rate schedules; pass `ax` explicitly rather than relying
# on seaborn picking up the current axes.
fig, ax = plt.subplots(figsize=(20, 10))
sns.lineplot(data=lr_df, ax=ax, marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("Learning rate")

## Check on the generated structures

Inspect the NumPy coordinate snapshots (`gen_coords_<epoch>.npy`) saved during training

In [None]:
# Load one reference structure (index 1) at phi = 0.84 for comparison with generated samples.
ref_struct = read("../data/processed/samples/phi-0.84/samples.extxyz", index=1, format="extxyz")

# In a notebook only the last expression of a cell is auto-rendered, so the
# x3d view must be passed to display() explicitly or it is silently discarded.
display(view(ref_struct, viewer="x3d", show_unit_cell=True, repeat=(1, 1, 1)))

atoms = ref_struct.copy()
radii = atoms.get_array("rmt")
frac = atoms.get_scaled_positions()

# 2D projection of the fractional coordinates, coloured and sized by the 'rmt' array.
plt.figure()
plt.scatter(frac[:, 0], frac[:, 1], c=radii, s=radii * 20, alpha=0.5)

In [None]:
import numpy as np

# Coordinates saved by the training loop at epoch 60.
file = f"{EXPERIMENT}/gen_coords_60.npy"

np_coords = np.load(file)
atoms = ref_struct.copy()

# Pick a couple of structures to visualize
idxs = [0, 1, 2]

radii = atoms.get_array("rmt")

for idx in idxs:
    # Select sample `idx` on axis 0. Axis 1 is taken at 0, which is presumably a
    # singleton channel axis (TODO confirm against the saved array's shape).
    atoms.set_scaled_positions(np_coords[idx, 0])
    frac = atoms.get_scaled_positions()
    plt.figure()
    plt.scatter(frac[:, 0], frac[:, 1], c=radii, s=radii * 20, alpha=0.5)
    plt.title(f"Generated sample {idx}")

## Make a video


In [None]:
import glob


def _snapshot_epoch(path):
    """Return the epoch number encoded at the end of a ``gen_coords_<epoch>.npy`` path."""
    tail = path.rsplit("_", 1)[-1]  # e.g. "60.npy"
    return int(tail.split(".")[0])


idx = 2  # Pick one generated sample

# Snapshots in numeric epoch order (plain string sort would put 100 before 20).
file_list = sorted(glob.glob(f"{EXPERIMENT}/gen_coords_*.npy"), key=_snapshot_epoch)

# One Atoms object per snapshot: the reference cell with sample `idx`'s coordinates.
atoms_list = []
for file in file_list:
    np_coords = np.load(file)
    atoms = ref_struct.copy()
    atoms.set_scaled_positions(np_coords[idx].squeeze())
    atoms_list.append(atoms)
len(file_list)

In [None]:
# Pull the epoch index and the three learning-rate schedules out as NumPy arrays.
epochs = lr_df.index.values
gen_lr, coord_lr, dist_lr = (
    lr_df[column].values for column in ("generator_lr", "coord_disc_lr", "dist_disc_lr")
)

In [None]:
# Generator cost of run 0, as a Series indexed by epoch.
all_runs["cost_gen"].loc[0]

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from ase.visualize.plot import plot_atoms
import matplotlib.gridspec as gridspec
from matplotlib.animation import FuncAnimation


# Pick a good cmap
plt.set_cmap("viridis")

# Training epochs between consecutive gen_coords_*.npy snapshots (one snapshot per frame).
# NOTE(review): value carried over from the hard-coded 20 — confirm against the training config.
EPOCHS_PER_FRAME = 20

# Parse the learning-rate log with the epoch as index and float schedule columns.
lr_df = pd.read_csv(f"{EXPERIMENT}/learning_rate.csv")
lr_df.set_index("epoch", inplace=True)
lr_df = lr_df.astype(float)
epochs = lr_df.index.values
gen_lr = lr_df["generator_lr"].values
coord_lr = lr_df["coord_disc_lr"].values
dist_lr = lr_df["dist_disc_lr"].values

# Generator cost of run 0, plotted against the epochs it was logged at: the losses
# CSV and the learning-rate CSV are separate files and need not have equal length.
run0_losses = all_runs.loc[0]
cost_epochs = run0_losses.index.values
cost_gen = run0_losses["cost_gen"].values

n_frames = len(atoms_list)

fig = plt.figure(figsize=(20, 15))
gs = gridspec.GridSpec(2, 2, height_ratios=[3, 1], hspace=0.3)

# top row: generated (redrawn per frame) and reference (static) scatter plots
ax_gen = fig.add_subplot(gs[0, 0])
ax_ref = fig.add_subplot(gs[0, 1])

# bottom row: learning-rate curves spanning both columns, generator cost on a twin y axis.
# The twin is created exactly once: creating it inside update() stacked a new axes on
# every frame, because ax_lr.clear() does not remove twin axes.
ax_lr = fig.add_subplot(gs[1, :])
ax_cost_gen = ax_lr.twinx()

AXIS_LIMITS = (-0.05, 1.05)


def _anim_scatter(ax, structure):
    """Scatter the first two fractional coordinates of *structure*, coloured/sized by its 'rmt' array."""
    frac = structure.get_scaled_positions()
    radii = structure.get_array("rmt")
    ax.scatter(frac[:, 0], frac[:, 1], c=radii, s=radii * 20, alpha=0.5)
    # Same limits on both panels so generated and reference are directly comparable.
    ax.set_xlim(*AXIS_LIMITS)
    ax.set_ylim(*AXIS_LIMITS)


# The reference panel never changes, so draw it once. This also captures `ref_struct`
# now rather than whatever the global holds when the animation is rendered or saved.
_anim_scatter(ax_ref, ref_struct)
ax_ref.set_title("Reference")

# Static learning-rate and cost curves; only the red epoch marker moves per frame.
ax_lr.plot(epochs, gen_lr, label="Generator LR", linewidth=2)
ax_lr.plot(epochs, coord_lr, label="Coord Disc LR", linewidth=2)
ax_lr.plot(epochs, dist_lr, label="Dist Disc LR", linewidth=2)
# Black rather than orange: orange is the default colour of the second LR curve.
ax_cost_gen.plot(cost_epochs, cost_gen, label="Generator cost", color="black", linewidth=2)
ax_cost_gen.set_ylabel("Generator cost")

epoch_marker = ax_lr.axvline(epochs[0], color="red", linestyle="--", linewidth=2)

ax_lr.set_xlabel("Epoch")
ax_lr.set_ylabel("Learning rate")
ax_lr.set_title("Learning‐rate schedules")
ax_lr.set_xlim(epochs.min(), epochs.max())
ax_lr.set_ylim(
    min(gen_lr.min(), coord_lr.min(), dist_lr.min()),
    max(gen_lr.max(), coord_lr.max(), dist_lr.max()),
)

# Concatenate the legend entries from both y axes; attach the legend to the twin,
# which is drawn on top, so the cost curve does not paint over it.
lines, labels = ax_lr.get_legend_handles_labels()
lines2, labels2 = ax_cost_gen.get_legend_handles_labels()
ax_cost_gen.legend(lines + lines2, labels + labels2, loc="upper right")


def update(i):
    """Draw frame *i*: generated snapshot i, the figure title and the epoch marker."""
    # Snapshot i corresponds to epoch row i * EPOCHS_PER_FRAME; clamp so a snapshot
    # past the end of the learning-rate log does not raise IndexError mid-render.
    current_epoch = epochs[min(i * EPOCHS_PER_FRAME, len(epochs) - 1)]

    ax_gen.clear()
    _anim_scatter(ax_gen, atoms_list[i])
    ax_gen.set_title("Generated")

    fig.suptitle(f"Generated vs Reference — epoch {current_epoch}")
    epoch_marker.set_xdata([current_epoch, current_epoch])


# — Create and show the animation —
ani = FuncAnimation(fig, update, frames=n_frames, interval=200)

In [None]:
from matplotlib import animation as mpl_animation

ani.save(f"{EXPERIMENT}/training.gif", writer="pillow", fps=0.5)

# Only attempt the mp4 export when matplotlib can find ffmpeg. A bare `except:` here
# would also hide unrelated failures (bad path, disk full, KeyboardInterrupt) behind
# the "ffmpeg not installed" message.
if mpl_animation.writers.is_available("ffmpeg"):
    ani.save(f"{EXPERIMENT}/training.mp4", writer="ffmpeg", fps=1, dpi=72)  # Low dpi for fast export
else:
    print("ffmpeg not installed, skipping mp4 export")


## Read manually generated samples

NOTE: To generate these samples, run a command such as:

```
python generate_structures.py --load_generator results/3-with-dist-high-weight/generator_600 --n_struc 50 --ref_struc ../data/processed/samples/phi-0.80/samples.extxyz --label_phis 0.80 --n_labels 1 --gen_channels_1 256 --latent_dim 256 --gen_label_dim 128
```

Make sure to pick the right generator and have all arguments identical to your training run.

Choose the packing fraction (phi) of the generated structures with `--label_phis`.


In [None]:
PHI = 0.8
file = f"../CCGAN/gen-phi-{PHI}.extxyz"  # For generated data

# index=":" returns a list with one Atoms object per generated structure.
all_atoms = read(file, index=":", format="extxyz")
print("Read {} structures".format(len(all_atoms)))

# Show the first, middle and last structure. Indices are derived from the count, so
# files with fewer than 50 samples no longer raise IndexError. For 50 samples these
# are 0, 25 and 49.
if all_atoms:
    n_samples = len(all_atoms)
    for sample_idx in sorted({0, n_samples // 2, n_samples - 1}):
        display(view(all_atoms[sample_idx], viewer="x3d", show_unit_cell=True))

# Next steps

1. Provide Anshul with generated samples (50 each) in our data format
   1. Low packing fraction 0.70
   2. Mid 0.78
   3. High packing fraction 0.84
   4. Very high 0.86
2. Special assignment: Try out different loss functions to make the generated structures more physically feasible (or something else)
   1. Adding the bond distance discriminator
   2. Radius / overlap loss
   3. NN distance based loss
   4. Hexatic (k=5) order loss
   5. Other physically motivated losses
3. (Later: Conditioning on phi / other descriptors)

In [None]:
# Generate samples for all different phi values
import os
phis = [
    "0.86",
    "0.84",
    "0.80",
    "0.70",
]

experiment_name = EXPERIMENT.split("/")[-1] + "/"  # Get the experiment name from the path


def command(phi):
    """Return the shell command that runs generate_structures.py for packing fraction *phi*.

    The command first changes into ../CCGAN, so every relative path in it is resolved
    from there (e.g. EXPERIMENT, which starts with "../CCGAN/", still points at the same
    directory). The trailing backslashes are Python line continuations, so the shell
    receives a single line.
    """
    return f"""
cd .. && cd CCGAN && python generate_structures.py \
    --load_generator {EXPERIMENT}/generator_600 \
    --n_struc 50 \
    --ref_struc ../data/processed/samples/phi-{phi}/samples.extxyz \
    --label_phis {phi} \
    --n_labels 1 \
    --gen_channels_1 256 \
    --latent_dim 256 \
    --gen_label_dim 128 \
    --write_fname ../data/gen-extxyz/{experiment_name} \
"""


# os.system returns the command's exit status; report failures instead of silently
# continuing, so a missing generator checkpoint does not go unnoticed.
for phi in phis:
    status = os.system(command(phi))
    if status != 0:
        print(f"generate_structures.py failed for phi={phi} (exit status {status})")

In [None]:
# List the per-phi extxyz files written by generate_structures.py for this experiment.
gen_pattern = f"../data/gen-extxyz/{experiment_name}/phi-*.extxyz"
files = glob.glob(gen_pattern)
files

In [None]:
# Save the coordinates to tab-separated text files, one file per generated structure.

import os
from pathlib import Path

files = glob.glob(f"../data/gen-extxyz/{experiment_name}/phi-*.extxyz")

for file in files:
    print(file)
    # "phi-0.8.extxyz" -> "phi-0.8". Path.stem drops only the final suffix and does not
    # assume the phi value starts with "0." (the old split(".")[1] parsing did).
    output_folder = Path(f"../data/gen/{experiment_name}") / Path(file).stem
    output_folder.mkdir(parents=True, exist_ok=True)
    all_atoms = read(file, index=":", format="extxyz")

    for i, atoms in enumerate(all_atoms):
        output_file = output_folder / f"sample-{i}"

        L = atoms.info["L"]
        N = atoms.info["N"]
        # NOTE(review): the third header field is a hard-coded constant; its meaning is
        # not visible here — confirm against the target data format.
        header = "\t{N}\t{L}\t1626.81570886301\t\n".format(N=N, L=L)

        with open(output_file, "w") as f:
            f.write(header)
            # One row per atom: atomic number ("class") and x/y Cartesian positions.
            pd.DataFrame({
                "class": atoms.get_atomic_numbers(),
                "x": atoms.get_positions()[:, 0],
                "y": atoms.get_positions()[:, 1],
            }).to_csv(f, index=False, header=False, sep="\t")


In [None]:
# Check one sample from each generated phi


def _plot_phi_panel(ax, structure, title):
    """Scatter the first two fractional coordinates of *structure*, coloured/sized by its 'rmt' array."""
    frac = structure.get_scaled_positions()
    radii = structure.get_array("rmt")
    ax.scatter(frac[:, 0], frac[:, 1], c=radii, s=radii * 20, alpha=0.5)
    ax.set_title(title)


for PHI in phis:
    # Generated files are looked up with the float-formatted phi ("0.80" -> "phi-0.8"),
    # reference directories with the string as listed ("phi-0.80").
    gen_atoms = read(f"../data/gen-extxyz/{experiment_name}/phi-{float(PHI)}.extxyz", index=0, format="extxyz")
    # Local name on purpose: overwriting the notebook-wide `ref_struct` would silently
    # change the reference used by earlier cells (e.g. the training animation).
    phi_ref = read(f"../data/processed/samples/phi-{PHI}/samples.extxyz", index=1, format="extxyz")

    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    _plot_phi_panel(axs[0], gen_atoms, "Generated")
    _plot_phi_panel(axs[1], phi_ref, "Reference")
    fig.suptitle(f"Generated vs Reference, phi {PHI}")