In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
from ipywidgets import interact, IntSlider, FloatSlider, Dropdown

# Experiment whose results are visualised below.
seed = 42
model = "resnet18"
dataset = "cifar100"

# Root of this experiment's result tree: <results root>/<seed>/<model>/<dataset>
base_dir = "/".join(
    ["/jumbo/yaoqingyang/kinshuk/TempBalance/results/flatten", str(seed), model, dataset]
)

# Function to visualize the data
def _epoch_number(filename):
    """Return the integer N from an ``epoch_N`` filename (extension optional).

    ``"epoch_7"`` -> 7 and ``"epoch_7.csv"`` -> 7. Raises ``ValueError`` if the
    text after the first underscore is not an integer.
    """
    # Strip any extension first; otherwise "epoch_7.csv" would parse "7.csv".
    stem, _ = os.path.splitext(filename)
    return int(stem.split('_')[1])


def visualize_data(row_samples, q_ratio, num_ops):
    """Plot the per-layer ``alpha`` column for up to the first 50 epochs of a run.

    Reads every ``epoch_*`` file under
    ``{base_dir}/slide_True/row_{row_samples}/qr_{q_ratio}/ops_{num_ops}/esd_est``
    as a CSV and draws one line per file. The x value is the row position within
    that file (used as the layer index) and the y value is its ``alpha`` column.
    Each line is coloured by the epoch number parsed from its filename, and a
    matching colourbar is drawn before the figure is shown.

    Args:
        row_samples: value substituted into the ``row_<...>`` path component.
        q_ratio: value substituted into the ``qr_<...>`` path component.
        num_ops: value substituted into the ``ops_<...>`` path component.

    Returns:
        None. If the directory is missing or holds no ``epoch_*`` files, a
        message is printed and nothing is plotted.
    """
    # Layer sampling is always on here; with slide=False the row/qr/ops
    # sub-directories would be omitted from the path.
    slide = True
    tmp_str = f"row_{row_samples}/qr_{q_ratio}/ops_{num_ops}" if slide else ""
    data_dir = f"{base_dir}/slide_{slide}/{tmp_str}/esd_est"

    if not os.path.exists(data_dir):
        print(f"Directory does not exist: {data_dir}")
        return

    epoch_files = [f for f in os.listdir(data_dir) if f.startswith("epoch_")]
    if not epoch_files:
        print(f"No epoch files found in: {data_dir}")
        return

    # Numeric sort (epoch_10 after epoch_9), then keep only the first 50 epochs.
    epoch_files = sorted(epoch_files, key=_epoch_number)[:50]

    # One (epoch number, per-layer alphas) pair per file. Pairing each curve
    # with the epoch parsed from its own filename means no curve is lost,
    # whether epochs start at 0 or 1 or are saved only every k epochs.
    curves = []
    for epoch_file in epoch_files:
        df = pd.read_csv(os.path.join(data_dir, epoch_file))
        curves.append((_epoch_number(epoch_file), df['alpha'].to_numpy()))

    # Colour by the actual epoch number so the line colours and the colourbar
    # ticks agree. Widen a zero-width range (a single epoch) so the norm is valid.
    epochs = [epoch for epoch, _ in curves]
    lo, hi = min(epochs), max(epochs)
    norm = Normalize(vmin=lo, vmax=hi if hi > lo else lo + 1)
    cmap = plt.get_cmap('jet_r')

    fig, ax = plt.subplots(figsize=(12, 8))
    fig.patch.set_facecolor('white')

    for epoch, alphas in curves:
        ax.plot(
            np.arange(len(alphas)),  # row position == layer index
            alphas,
            color=cmap(norm(epoch)),
            alpha=0.8,
            linewidth=2.5,
            linestyle='-',
        )

    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    ax.set_title(f"ResNet | Layer Sampling: {slide}", fontsize=18, fontweight='bold', pad=15)
    ax.set_xlabel("Layer Index", fontsize=16)
    ax.set_ylabel("Alpha", fontsize=16)
    ax.tick_params(axis='both', which='major', labelsize=14)
    # Fixed y-range; values outside [1.25, 3] are clipped from view.
    ax.set_ylim(1.25, 3)

    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.02)
    cbar.set_label('Training Epochs', fontsize=16)
    cbar.ax.tick_params(labelsize=14)

    # Lay out only after all limits and decorations are set.
    fig.tight_layout()
    plt.show()

# Define interactive widgets
# Sliders for the three path components of the run to plot. With
# continuous_update=False the plot is redrawn only when a slider is released,
# instead of re-reading every epoch CSV for each intermediate drag value.
row_samples_widget = IntSlider(value=50, min=10, max=100, step=10, description='Row Samples:',
                               continuous_update=False)
q_ratio_widget = FloatSlider(value=2.0, min=1.0, max=5.0, step=0.5, description='Q Ratio:',
                             continuous_update=False)
num_ops_widget = IntSlider(value=10, min=1, max=20, step=1, description='Num Ops:',
                           continuous_update=False)

# Bind the sliders to visualize_data. interact() returns the function it wraps;
# the trailing semicolon stops Jupyter from echoing that repr as cell output.
interact(visualize_data,
         row_samples=row_samples_widget,
         q_ratio=q_ratio_widget,
         num_ops=num_ops_widget);


interactive(children=(IntSlider(value=50, description='Row Samples:', min=10, step=10), FloatSlider(value=2.0,…

<function __main__.visualize_data(row_samples, q_ratio, num_ops)>