In [1]:
import pandas as pd
import json
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path
from functools import partial
import numpy as np


In [2]:
from neptune.internal.utils.logger import get_logger
get_logger().setLevel('ERROR')

In [3]:
from typing import Optional, Any, Tuple, Dict, List, Iterable, Union
from ast import literal_eval as make_tuple
from functools import reduce

import neptune.api
import neptune.attributes


NEPTUNE_PROJECT = "pmtest/llm-random"
STD_TAIL_QUANTILE = 0.9
STD_TAIL_KEY = f"std_q{STD_TAIL_QUANTILE}"
IMAGES_PATH = Path("/Users/szysad/mimdir/magisterka/images")


project = neptune.init_project(
  project=NEPTUNE_PROJECT,
  mode="read-only"
)
columns = [
    "sys/tags",
    "loss_interval/100",
    "args/learning_rate",
    "sys/name",
    "args/grad_modif_params",
    "step",
    "sys/id",
    "sys/failed"
]

placements = set([
    "post_attn_and_ff",
    "post_norm",
    "post_add",
    "all",
])


class SeriesCache:
    def __init__(self, index_path: Path, data_dir: Path):
        self.index_path = index_path
        self.data_dir = data_dir
        self.index = {}
        if not self.index_path.exists():
            with open(self.index_path, 'w') as f:
                json.dump({}, f)
        else:
            with open(self.index_path, 'r') as f:
                self.index = json.load(f)

        if not self.data_dir.exists():
            data_dir.mkdir(parents=True)
    
    def add_entry(self, path: List[str], series: pd.DataFrame, overwrite: bool = False):
        sub_data = self.index
        for p in path[:-1]:
            if p not in sub_data:
                sub_data[p] = {}
            sub_data = sub_data[p]
        
        if path[-1] in sub_data and not overwrite:
            return

        fname = '-'.join(path) + ".pkl"
        series.to_pickle(self.data_dir / fname)
        sub_data[path[-1]] = fname
    
    def get_entry(self, path: List[str]) -> pd.DataFrame:
        sub_data = self.index
        for p in path[:-1]:
            if p not in sub_data:
                return None
            sub_data = sub_data[p]
        
        if path[-1] not in sub_data:
            return None
        
        fname = sub_data[path[-1]]
        return pd.read_pickle(self.data_dir / fname)
    
    def save(self):
        with open(self.index_path, 'w') as f:
            json.dump(self.index, f)

series_cache = SeriesCache(Path("./grad_series_index.json"), Path("./grad_series_data"))

def rename_to_common(df: pd.DataFrame):
    return df.rename(columns={"loss_interval/100": "loss", "args/learning_rate": "lr", "args/grad_modif_params": "grad_modif_params"}, inplace=False)


def infere_layer_type(tags: str):
    tags = tags.split(',')
    if "true_baseline" in tags:
        return "baseline"
    else:
        return "regular"

def infere_k(grad_modif_params: str):
    params = grad_modif_params.split(',')
    for p in params:
        key, val = p.split('=')
        if key == "k":
            if isinstance(val, str):
                return val
            return float(val)
    return None

def infere_norm_type(tags: str):
    tags = tags.split(',')
    for t in tags:
        if t == "std_norm":
            return t
        elif t == "scale_norm":
            return t
    raise ValueError(f"Unknown norm type: {tags}")

def infere_placement_type(tags: str):
    tags = tags.split(',')
    for t in tags:
        if t in placements:
            return t


def infere_c(grad_modif_params: str):
    params = grad_modif_params.split(',')
    for p in params:
        key, val = p.split('=')
        if key == "c":
            return float(val)
    return None

def infere_eps(grad_modif_params: str):
    params = grad_modif_params.split(',')
    for p in params:
        key, val = p.split('=')
        if key == "eps":
            return float(val)
    return None

def infere_std_norm_dims_from_tags(tags: str) -> str:
    tags = tags.split(',')
    for t in tags:
        if t == "std_v1":
            return (0,1,2)
        if t == "std_v4":
            return (1,2)
        if t == "std_v2":
            return (2,)
    raise ValueError(f"Unknown norm dims: {tags}")

def infere_norm_dims(grad_modif_params: str) -> Tuple[int, ...]:
    start_idx = grad_modif_params.find("norm_dims")
    if start_idx == -1:
        return None

    next_opening_bracket_idx = grad_modif_params.find("(", start_idx)
    next_closing_bracket_idx = grad_modif_params.find(")", start_idx)

    return tuple(make_tuple(grad_modif_params[next_opening_bracket_idx:next_closing_bracket_idx+1]))

def extract_structure_metrics(structure: Dict[str, Any], key: str, path: List[str] = []) -> Iterable[Tuple[str, str, Any]]:
    for k, v in structure.items():
        if k == key:
            yield '.'.join(path), k, v
        elif isinstance(v, dict):
            yield from extract_structure_metrics(v, key, path + [k])


def fetch_run_mean_grad_stds_series(run_id: str) -> pd.DataFrame:
    with neptune.init_run(project=NEPTUNE_PROJECT, with_id=run_id, mode="read-only") as run:
        structure = run.get_structure()
        grad_norms = extract_structure_metrics(structure, "raw_grad_norms")
        dfs = []
        for path, _, series in grad_norms:
            series_df = series['std'].fetch_values(include_timestamp=False)
            series_df.rename(columns={'value': path}, inplace=True)
            dfs.append(series_df)
        
        df_merged = reduce(lambda left, right: pd.merge(left, right, on=['step'], how='outer'), dfs)
        df_merged['mean_stds'] = df_merged.drop(columns=['step']).mean(axis=1)
        return df_merged[['step', 'mean_stds']]


def fetch_and_save_run_grad_series(run_id: str, series_cache: SeriesCache, overwrite: bool = False):
    if run_id in series_cache.index and not overwrite:
        return
    with neptune.init_run(project=NEPTUNE_PROJECT, with_id=run_id, mode="read-only") as run:
        structure = run.get_structure()
        grad_norms = extract_structure_metrics(structure, "raw_grad_norms")
        statistics = ['std', 'mean']
        dfs = dict(zip(statistics, [list() for _ in statistics]))
        for path, _, series in grad_norms:
            for statistic in statistics:
                series_df = series[statistic].fetch_values(include_timestamp=False)
                series_df.rename(columns={'value': path}, inplace=True)
                dfs[statistic].append(series_df)
        
        for statistic in statistics:
            df_merged = reduce(lambda left, right: pd.merge(left, right, on=['step'], how='outer'), dfs[statistic])
            series_cache.add_entry([run_id, statistic], df_merged, overwrite=overwrite)
        
        series_cache.save()
    
def get_run_mean_grad_std_quantile_diff(run_id: str, series_cache: SeriesCache, qbase: float, qref: float) -> pd.DataFrame:
    if run_id not in series_cache.index:
        fetch_and_save_run_grad_series(run_id, series_cache)
    grad_std_series = series_cache.get_entry([run_id, 'std'])
    std_norm_qbase = grad_std_series.drop(columns=['step']).quantile(q=qbase, axis=1)
    std_norm_qref = grad_std_series.drop(columns=['step']).quantile(q=qref, axis=1)
    std_norm_quantile = (std_norm_qref - std_norm_qbase) / std_norm_qbase
    return std_norm_quantile.to_list()

def get_run_std_quantile_series(run_id: str, series_cache: SeriesCache, q: float) -> pd.DataFrame:
    if run_id not in series_cache.index:
        fetch_and_save_run_grad_series(run_id, series_cache)
    grad_std_series = series_cache.get_entry([run_id, 'std'])
    std_norm_q = grad_std_series.drop(columns=['step']).quantile(q=q, axis=1)
    return std_norm_q.to_list()

def get_run_mean_grad_std_steps(run_id: str, series_cache: SeriesCache) -> pd.DataFrame:
    if run_id not in series_cache.index:
        fetch_and_save_run_grad_series(run_id, series_cache)
    grad_std_series = series_cache.get_entry([run_id, 'std'])
    return grad_std_series['step'].to_list()

def get_run_loss_steps_series(run_id: str, series_cache: SeriesCache, smoothing: str = "100") -> pd.DataFrame:
    key = f"loss_{smoothing}"
    if run_id in series_cache.index and key in series_cache.index[run_id]:
        return series_cache.get_entry([run_id, key])['step'].to_list()
    
    with neptune.init_run(project=NEPTUNE_PROJECT, with_id=run_id, mode="read-only") as run:
        loss_series_df = run[f'loss_interval/{smoothing}'].fetch_values(include_timestamp=False)
        loss_series_df.rename(columns={'value': key}, inplace=True)
        series_cache.add_entry([run_id, key], loss_series_df)
        return loss_series_df['step'].to_list()

def get_run_loss_series(run_id: str, series_cache: SeriesCache, smoothing: str = "100") -> pd.DataFrame:
    key = f"loss_{smoothing}"
    if run_id in series_cache.index and key in series_cache.index[run_id]:
        return series_cache.get_entry([run_id, key])[key].to_list()
    
    with neptune.init_run(project=NEPTUNE_PROJECT, with_id=run_id, mode="read-only") as run:
        loss_series_df = run[f'loss_interval/{smoothing}'].fetch_values(include_timestamp=False)
        loss_series_df.rename(columns={'value': key}, inplace=True)
        series_cache.add_entry([run_id, key], loss_series_df)
        return loss_series_df[key].to_list()

def get_run_mean_grad_q_series(run_id: str, series_cache: SeriesCache, q: float) -> pd.DataFrame:
    if run_id not in series_cache.index:
        fetch_and_save_run_grad_series(run_id, series_cache)
    grad_std_series = series_cache.get_entry([run_id, 'mean'])
    mean_norm_q = grad_std_series.drop(columns=['step']).quantile(q=q, axis=1)
    return mean_norm_q.to_list()

def get_run_mean_grad_avg_series(run_id: str, series_cache: SeriesCache) -> pd.DataFrame:
    if run_id not in series_cache.index:
        fetch_and_save_run_grad_series(run_id, series_cache)
    grad_std_series = series_cache.get_entry([run_id, 'mean'])
    mean_norm_q = grad_std_series.drop(columns=['step']).mean(axis=1)
    return mean_norm_q.to_list()

def fetch_run_mean_grad_last_std_series(run_id: str) -> pd.DataFrame:
    with neptune.init_run(project=NEPTUNE_PROJECT, with_id=run_id, mode="read-only") as run:
        structure = run.get_structure()
        grad_norms = extract_structure_metrics(structure, "raw_grad_norms")
        last_stds = []
        for _, _, series in grad_norms:
            last_std = series['std'].fetch_last()
            last_stds.append(last_std)
        
        return sum(last_stds) / len(last_stds)

def common_line_plot(df: pd.DataFrame, x: str, y: str, color: str, title: str, trace: Optional[Any] = None, x_range: Optional[Tuple[float, float]] = None, y_range: Optional[Tuple[float, float]] = None) -> go.Figure:
    fig = px.line(df, x=x, y=y, color=color, markers=True, log_x=True, log_y=True)

    if trace:
        fig.add_trace(trace)

    fig.update_layout(
        title=title,
        yaxis = dict(
            showexponent = 'all',
            exponentformat = 'power',
            range = y_range
        ),
        xaxis = dict(
            showexponent = 'all',
            exponentformat = 'power',
            range = x_range
        )
    )

    return fig

def common_line_plot_traces(df: pd.DataFrame, x: str, y: str, color: str, title: str, trace: Optional[Any], x_range: Optional[Tuple[float, float]] = None, y_range: Optional[Tuple[float, float]] = None) -> List[go.Figure]:
    fig = px.line(df, x=x, y=y, color=color, markers=True, log_x=True, log_y=True)

    objs = []
    objs.extend(fig.data)

    if trace is not None:
        objs.append(trace)

    return objs

def common_plot_nested_list(df: pd.DataFrame, color: str, x: Union[str, List[float]], y: Union[str, List[float]], title: str, trace: Optional[Any] = None, **kwargs) -> go.Figure:
    fig = go.Figure()
    runs = df[color].unique()
    for run in runs:
        run_df = df[df[color] == run]
        assert len(run_df) == 1
        x_vals = run_df[x].values[0] if isinstance(x, str) else x
        y_vals = run_df[y].values[0] if isinstance(y, str) else y
        fig.add_trace(go.Scatter(x=x_vals, y=y_vals, mode='lines', name=run, hovertext=run, **kwargs))
    
    if trace:
        fig.add_trace(trace)

    
    fig.update_layout(
        title=title,
        yaxis = dict(
            showexponent = 'all',
            exponentformat = 'power',
            type='log'
        ),
        #xaxis = dict(
        #    showexponent = 'all',
        #    exponentformat = 'power'
        #)
    )

    return fig

#### Baseline loss vs lr categorised by eps, total_steps

In [4]:
baseline_df = project.fetch_runs_table(tag="true_baseline", columns=columns).to_pandas()
baseline_df = rename_to_common(baseline_df)
baseline_df = baseline_df[baseline_df['grad_modif_params'].notna()]
baseline_df['eps'] = baseline_df['grad_modif_params'].apply(infere_eps)
baseline_df['sys/id'].apply(partial(fetch_and_save_run_grad_series, series_cache=series_cache, overwrite=False))

baseline_df.sort_values('lr', inplace=True)

In [5]:
fig = common_line_plot(baseline_df, x="lr", y="loss", color="step", title="Baseline Loss vs LR. Categorised by eps, total_steps")
fig.show()

In [6]:
baseline_df_optimal_lr = baseline_df[baseline_df['lr'] == 1e-3].copy()
baseline_df_optimal_lr['steps'] = baseline_df_optimal_lr['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
baseline_df_optimal_lr['loss_steps_1000'] = baseline_df_optimal_lr['sys/id'].apply(partial(get_run_loss_steps_series, series_cache=series_cache, smoothing="1000"))
baseline_df_optimal_lr[STD_TAIL_KEY] = baseline_df_optimal_lr['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
fig = common_plot_nested_list(baseline_df_optimal_lr, color='sys/id', x="steps", y=STD_TAIL_KEY, title="Baseline std_tail_change vs steps")
fig.show()

baseline_df_short = baseline_df[baseline_df['step'] == 2000]
baseline_df_long = baseline_df[~(baseline_df['step'] == 10000)]
baseline_trace_loss_lr_short = go.Scatter(x=baseline_df_short['lr'], y=baseline_df_short['loss'], name='baseline', line=dict(color='black', width=2, dash='dash'))

baseline_grad_norm_std_long = go.Scatter(x=baseline_df_optimal_lr['steps'].values[0], y=baseline_df_optimal_lr[STD_TAIL_KEY].values[0], mode='lines', name='baseline', line=dict(color='black', width=2, dash='dash'))
baseline_grad_norm_std_short = go.Scatter(x=baseline_df_optimal_lr['steps'].values[1], y=baseline_df_optimal_lr[STD_TAIL_KEY].values[1], mode='lines', name='baseline', line=dict(color='black', width=2, dash='dash'))

# Short Experiments (2k steps)

#### Loss vs LR categorised by baseline and sanity_checks for v1, v2 std norm

In [7]:
grad_std_norm_type_tag = {
    "std_v1": "std_v1_c_lr_grid_placement_short",
    "std_v2": "std_v2_c_lr_grid_placement_short",
    "std_v4": "std_v4_c_lr_grid_placement_short"
}

std_norm_dims = {
    "std_v1": "(0, 1, 2)",
    "std_v2": "(2,)",
    "std_v4": "(1, 2)"
}

SHORT_RUN_LOSS_RANGE = tuple(map(np.log10, (4.5, 8.0)))

_dfs = []
for grad_norm_type, tag in grad_std_norm_type_tag.items():
    df = project.fetch_runs_table(tag=tag, columns=columns).to_pandas()
    df['norm_dims'] = std_norm_dims[grad_norm_type]
    _dfs.append(df)

std_df = pd.concat(_dfs)
std_df = rename_to_common(std_df)
std_df['eps'] = std_df['grad_modif_params'].apply(infere_eps)
std_df['placement'] = std_df['sys/tags'].apply(infere_placement_type)
std_df['c'] = std_df['grad_modif_params'].apply(infere_c)

Fetching table...: 0 [00:00, ?/s]

In [8]:
std_sanity_check_df = std_df[std_df['c'] == 0].copy()
std_sanity_check_df.sort_values('lr', inplace=True)
std_sanity_check_df['color'] = std_sanity_check_df['placement'] + ", agg_dims=" + std_sanity_check_df['norm_dims']

fig = common_line_plot(std_sanity_check_df, x="lr", y="loss", color="color", title="Sanity check (c=0) - Loss vs LR Categorised by placement and agg_dims.", y_range=SHORT_RUN_LOSS_RANGE, trace=baseline_trace_loss_lr_short)
fig.update_layout(
    autosize=False,
    width=1600,
    height=600,
)
fig.show()

In [9]:
AGG_STD_DICT = {STD_TAIL_KEY: lambda ls: [np.mean(vs) for vs in zip(*ls)], 'steps': 'first'}

In [10]:
std_sanity_check_df_optimal_lr = std_sanity_check_df[std_sanity_check_df['lr'] == 1e-3].copy()
std_sanity_check_df_optimal_lr['steps'] = std_sanity_check_df_optimal_lr['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
std_sanity_check_df_optimal_lr[STD_TAIL_KEY] = std_sanity_check_df_optimal_lr['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=0.9))
std_sanity_check_df_optimal_lr['color'] = std_sanity_check_df_optimal_lr['placement'] + ", agg_dims=" + std_sanity_check_df_optimal_lr['norm_dims']
#std_sanity_check_df_optimal_lr = std_sanity_check_df_optimal_lr.groupby('placement').agg(AGG_STD_DICT).reset_index()

fig = common_plot_nested_list(std_sanity_check_df_optimal_lr, color='color', x="steps", y=STD_TAIL_KEY, title="Sanity check (c=0) - std_tail_change vs steps", trace=baseline_grad_norm_std_short)
fig.update_layout(
    xaxis = dict(title='step'),
    yaxis = dict(title='grad_std_percentile_90'),
    autosize=False,
    width=1600,
    height=600,
)
fig.show()

In [11]:
std_df = std_df[(std_df['c'] != 0) & (std_df['eps'] == 0) & (std_df['step'] == 2000)].copy()
std_df['color'] = std_df['placement'] + ", agg_dims=" + std_df['norm_dims']
std_df.sort_values('lr', inplace=True)

for c in sorted(std_df['c'].unique()):
    std_df_placement = std_df[std_df['c'] == c]
    fig = common_line_plot(std_df_placement, x="lr", y="loss", color='color', title=f"Loss vs LR Categorised for std norm. c={c}.", y_range=SHORT_RUN_LOSS_RANGE, trace=baseline_trace_loss_lr_short)
    fig.update_layout(
        autosize=False,
        width=1600,
        height=800,
    )
    fig.show()

# Chosen examples of parameter impact on loss vs lr

In [12]:
# 1. imapct of norm_dims on loss vs lr for post_add std norm when c=1e-2
def std_norm_dims_stringify(norm_dims: str) -> str:
    if norm_dims == "(2,)":
        return "d_m"
    elif norm_dims == "(1, 2)":
        return "seq_len x d_m"
    elif norm_dims == "(0, 1, 2)":
        return "all"

FONT_SIZE = 22
_exmpl1_df = std_df[(std_df['placement'] == 'post_add') & (std_df['c'] == 1e-2)].copy()
_exmpl1_df['norm_dims'] = _exmpl1_df['norm_dims'].apply(std_norm_dims_stringify)
_exmpl1_df.sort_values(['lr', 'norm_dims'], inplace=True)
fig = common_line_plot(_exmpl1_df, x="lr", y="loss", color='norm_dims', title=None, trace=baseline_trace_loss_lr_short)
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Learning Rate'),
    yaxis = dict(title='Loss'),
)
fig.update_traces(line=dict(width=3))
fig.show()
fig.write_image((IMAGES_PATH / "loss_vs_lr_std_norm_post_add_c_1e-2_short.pdf"), format="pdf")

In [13]:
# 2. imapct of c for post_add placements for std_v1 norm_dims=(0, 1, 2)
_exmpl2_df = std_df[(std_df['norm_dims'] == "(1, 2)") & (std_df['placement'] == 'post_add')].copy()
_exmpl2_df.sort_values(['lr', 'c'], inplace=True)
fig = common_line_plot(_exmpl2_df, x="lr", y="loss", color='c', title=None, y_range=SHORT_RUN_LOSS_RANGE, trace=baseline_trace_loss_lr_short)
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Learning Rate'),
    yaxis = dict(title='Loss'),
)
fig.update_traces(line=dict(width=3))
fig.write_image((IMAGES_PATH / "loss_vs_lr_std_norm_post_add_norm_dims_1_2_short.pdf"), format="pdf")
fig.show()

In [14]:
# 3. imapct of placement for std_norm c=1e-3 and norm_dims=(1, 2)
# very hard to predict impact of placement on loss vs lr for std norm but 'all' seems to be the one performing the worst
_exp_3_yrange = tuple(map(np.log10, (4.5, 8)))
_exmpl3_df_1 = std_df[(std_df['norm_dims'] == "(1, 2)") & (std_df['c'] == 1e-1)].copy()
_exmpl3_df_2 = std_df[(std_df['norm_dims'] == "(1, 2)") & (std_df['c'] == 1e-4)].copy()
_exmpl3_df_1.sort_values(['lr', 'placement'], inplace=True)
_exmpl3_df_2.sort_values(['lr', 'placement'], inplace=True)
fig1 = common_line_plot(_exmpl3_df_1, x="lr", y="loss", color='placement', title=None, y_range=_exp_3_yrange, trace=baseline_trace_loss_lr_short)
fig2 = common_line_plot(_exmpl3_df_2, x="lr", y="loss", color='placement', title=None, y_range=_exp_3_yrange, trace=baseline_trace_loss_lr_short)
fig1.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Learning Rate'),
    yaxis = dict(title='Loss'),
)
fig1.update_traces(line=dict(width=3))
fig2.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Learning Rate'),
    yaxis = dict(title='Loss'),
)
fig2.update_traces(line=dict(width=3))
fig1.show()
fig1.write_image((IMAGES_PATH / "loss_vs_lr_std_norm_post_add_norm_dims_1_2_c_1e-1_short.pdf"), format="pdf")
fig2.show()
fig2.write_image((IMAGES_PATH / "loss_vs_lr_std_norm_post_add_norm_dims_1_2_c_1e-2_short.pdf"), format="pdf")

In [15]:

scale_norm_df = project.fetch_runs_table(tag="scale_norm_c_lr_grid_placement_short", columns=columns).to_pandas()
scale_norm_df = rename_to_common(scale_norm_df)
scale_norm_df['k'] = scale_norm_df['grad_modif_params'].apply(infere_k)
scale_norm_df['c'] = scale_norm_df['grad_modif_params'].apply(infere_c)
scale_norm_df['placement'] = scale_norm_df['sys/tags'].apply(infere_placement_type)
scale_norm_df['norm_dims'] = scale_norm_df['grad_modif_params'].apply(infere_norm_dims)
scale_norm_df.sort_values('lr', inplace=True)

Fetching table...: 0 [00:00, ?/s]



In [16]:
# 4. imapct of norm dims on loss vs lr for scale norm when c=1e-2
_exmpl4_df = scale_norm_df[(scale_norm_df['placement'] == 'post_norm') & (scale_norm_df['c'] == 1e-5) & (scale_norm_df['k'] == '1')].copy()
_exmpl4_df.sort_values(['lr', 'norm_dims'], inplace=True)
fig = common_line_plot(_exmpl4_df, x="lr", y="loss", color='norm_dims', title="Loss vs LR Categorised for scale norm.", trace=baseline_trace_loss_lr_short)
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
)
fig.show()

In [17]:
# 5. compare l2 norm to std norm for c=1e-3 and norm_dims=(1, 2) and placement=post_norm
_c = 1e-2
_placement = 'post_attn_and_ff'
_exp_3_yrange = tuple(map(np.log10, (4.5, 6.7)))
_exampl5_std_df = std_df[(std_df['norm_dims'] == "(1, 2)") & (std_df['c'] == _c) & (std_df['placement'] == _placement)].copy()
_exampl5_scale_df = scale_norm_df[(scale_norm_df['norm_dims'] == (1, 2)) & (scale_norm_df['c'] == _c) & (scale_norm_df['placement'] == _placement) & (scale_norm_df['k'] == '1')].copy()
_exmpl5_df = pd.concat([_exampl5_std_df.assign(norm_type='Std Norm'), _exampl5_scale_df.assign(norm_type='L2 Norm')])
_exmpl5_df.sort_values('lr', inplace=True)
fig = common_line_plot(_exmpl5_df, x="lr", y="loss", color='norm_type', title=None, trace=baseline_trace_loss_lr_short, y_range=_exp_3_yrange)
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Learning Rate'),
    yaxis = dict(title='Loss'),
)
fig.show()
fig.write_image((IMAGES_PATH / "loss_vs_lr_std_vs_l2_norm_post_layer_norm_dims_1_2_c_1e-2_short.pdf"), format="pdf")

In [18]:
std_df_optimal_lr = std_df[std_df['lr'] == 1e-3].copy()
std_df_optimal_lr['steps'] = std_df_optimal_lr['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
std_df_optimal_lr[STD_TAIL_KEY] = std_df_optimal_lr['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
std_df_optimal_lr['color'] = std_df_optimal_lr['placement'] + ", agg_dims=" + std_df_optimal_lr['norm_dims']

In [19]:
for c in sorted(std_df['c'].unique()):
    std_df_c = std_df_optimal_lr[std_df_optimal_lr['c'] == c]
    title = f"90th percentyle of gradients standard derivation vs steps for Std Norm, c={c}"
    fig = common_plot_nested_list(std_df_c, color='color', x="steps", y=STD_TAIL_KEY, title=title, trace=baseline_grad_norm_std_short)
    fig.update_layout(
        xaxis = dict(title='step'),
        yaxis = dict(title='grad_std_percentile_90'),
        autosize=False,
        width=1600,
        height=600,
    )
    fig.show()

In [20]:
# 1. impact of c on GST_90 for std norm
_exampl1_std_df = std_df_optimal_lr[(std_df_optimal_lr['placement'] == 'post_add') & (std_df_optimal_lr['norm_dims'] == "(1, 2)")].copy()
_exampl1_std_df.sort_values(['c'], inplace=True)
fig = common_plot_nested_list(_exampl1_std_df, color='c', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_short)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Step'),
    yaxis = dict(title='GST_90'),
)
fig.show()
fig.write_image((IMAGES_PATH / "gst_90_vs_steps_std_norm_post_add_norm_dims_1_2_short.pdf"), format="pdf")

In [21]:
# 2. impact of norm_dims on GST_90 for std norm in noise environments c=1e-2
_exampl2_std_df = std_df[(std_df['placement'] == 'post_add') & (std_df['c'] == 1e-2)  & (std_df['lr'] == 1e-3)].copy()
_exampl2_std_df['steps'] = _exampl2_std_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
_exampl2_std_df[STD_TAIL_KEY] = _exampl2_std_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
_exampl2_std_df['norm_dims'] = _exampl2_std_df['norm_dims'].astype(str).apply(std_norm_dims_stringify)
_exampl2_std_df.sort_values(['norm_dims'], inplace=True)
fig = common_plot_nested_list(_exampl2_std_df, color='norm_dims', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_short)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Step'),
    yaxis = dict(title='QST_90'),
)
fig.show()
fig.write_image((IMAGES_PATH / "gst_90_vs_steps_std_norm_post_add_c_1e-2_short.pdf"), format="pdf")

In [22]:
# 2. impact of norm_dims on GST_90 for std norm in stable environments c=1e-4
_exampl2_std_df = std_df[(std_df['placement'] == 'post_add') & (std_df['c'] == 1e-4)  & (std_df['lr'] == 1e-3)].copy()
_exampl2_std_df['steps'] = _exampl2_std_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
_exampl2_std_df[STD_TAIL_KEY] = _exampl2_std_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
_exampl2_std_df['norm_dims'] = _exampl2_std_df['norm_dims'].astype(str).apply(std_norm_dims_stringify)
_exampl2_std_df.sort_values(['norm_dims'], inplace=True)
fig = common_plot_nested_list(_exampl2_std_df, color='norm_dims', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_short)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Step'),
    yaxis = dict(title='GST_90'),
)
fig.show()
fig.write_image((IMAGES_PATH / "gst_90_vs_steps_std_norm_post_add_c_1e-4_short.pdf"), format="pdf")

In [23]:
# 3. impact of norm_dims on GST_90 for l2 norm
_exampl2_std_df = scale_norm_df[(scale_norm_df['placement'] == 'post_add') & (scale_norm_df['c'] == 1e-3) & (scale_norm_df['k'] == '1') & (scale_norm_df['lr'] == 1e-3)].copy()
_exampl2_std_df['steps'] = _exampl2_std_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
_exampl2_std_df[STD_TAIL_KEY] = _exampl2_std_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
_exampl2_std_df['norm_dims'] = _exampl2_std_df['norm_dims'].astype(str).apply(std_norm_dims_stringify)
_exampl2_std_df.sort_values(['norm_dims'], inplace=True)
fig = common_plot_nested_list(_exampl2_std_df, color='norm_dims', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_short)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Step'),
    yaxis = dict(title='QST_90'),
)
fig.show()
fig.write_image((IMAGES_PATH / "gst_90_vs_steps_l2_norm_post_add_c_1e-3_short.pdf"), format="pdf")

In [24]:
# 4. impact of placement on GST_90 for std norm for c=1e-3 and norm_dims=(1, 2)
_exampl4_std_df = std_df[(std_df['norm_dims'] == "(1, 2)") & (std_df['c'] == 1e-3) & (std_df['lr'] == 1e-3)].copy()
_exampl4_std_df['steps'] = _exampl4_std_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
_exampl4_std_df[STD_TAIL_KEY] = _exampl4_std_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
_exampl4_std_df['placement'] = _exampl4_std_df['placement'].apply(lambda p: p if p != "post_attn_and_ff" else "post_layer")
_exampl4_std_df.sort_values(['placement'], inplace=True)
fig = common_plot_nested_list(_exampl4_std_df, color='placement', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_short)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Step'),
    yaxis = dict(title='GST_90'),
)
fig.show()
fig.write_image((IMAGES_PATH / "gst_90_vs_steps_std_norm_norm_dims_1_2_c_1e-3_short.pdf"), format="pdf")

In [25]:
# 5. impact of placement on GST_90 for l2 norm for c=1e-3 and norm_dims=(1, 2)
_exampl5_std_df = scale_norm_df[(scale_norm_df['norm_dims'] == (1, 2)) & (scale_norm_df['c'] == 1e-3) & (scale_norm_df['k'] == '1') & (scale_norm_df['lr'] == 1e-3)].copy()
_exampl5_std_df['steps'] = _exampl5_std_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
_exampl5_std_df[STD_TAIL_KEY] = _exampl5_std_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
_exampl5_std_df['placement'] = _exampl5_std_df['placement'].apply(lambda p: p if p != "post_attn_and_ff" else "post_layer")
_exampl5_std_df.sort_values(['placement'], inplace=True)
fig = common_plot_nested_list(_exampl5_std_df, color='placement', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_short)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Step'),
    yaxis = dict(title='QST_90'),
)


In [26]:
# 6. comparing std norm and l2 norm for c=1e-3 and norm_dims=(1, 2) and placement=post_add on GST_90
_c = 1e-3
_placement = 'post_norm'
_exampl6_std_df = std_df[(std_df['norm_dims'] == "(1, 2)") & (std_df['c'] == _c) & (std_df['lr'] == 1e-3) & (std_df['placement'] == _placement)].copy()
_exampl6_std_df['steps'] = _exampl6_std_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
_exampl6_std_df[STD_TAIL_KEY] = _exampl6_std_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
_exampl6_std_df['norm_type'] = 'Std Norm'
_exampl6_scale_df = scale_norm_df[(scale_norm_df['norm_dims'] == (1, 2)) & (scale_norm_df['c'] == _c) & (scale_norm_df['k'] == '1') & (scale_norm_df['lr'] == 1e-3) & (scale_norm_df['placement'] == _placement)].copy()
_exampl6_scale_df['steps'] = _exampl6_scale_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
_exampl6_scale_df[STD_TAIL_KEY] = _exampl6_scale_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
_exampl6_scale_df['norm_type'] = 'L2 Norm'
_exmpl6_df = pd.concat([_exampl6_std_df, _exampl6_scale_df])
_exmpl6_df.sort_values('norm_type', inplace=True)
fig = common_plot_nested_list(_exmpl6_df, color='norm_type', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_short)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='Step'),
    yaxis = dict(title='GST_90'),
)
fig.show()
fig.write_image((IMAGES_PATH / "gst_90_vs_steps_std_vs_l2_norm_post_norm_norm_dims_1_2_c_1e-3_short.pdf"), format="pdf")

#### Loss vs LR categorised by c and norm_dims for each placement in 'norm_scale' grad norm

In [27]:
sanity_check_df = project.fetch_runs_table(owner="szysad", tag=["c_0", "scale_norm_c_lr_grid_placement_short"], columns=columns).to_pandas()
sanity_check_df = rename_to_common(sanity_check_df)
sanity_check_df['k'] = sanity_check_df['grad_modif_params'].apply(infere_k)
#sanity_check_df = sanity_check_df[sanity_check_df['k'] == "auto"]
sanity_check_df['c'] = sanity_check_df['grad_modif_params'].apply(infere_c)
sanity_check_df['placement'] = sanity_check_df['sys/tags'].apply(infere_placement_type)
sanity_check_df['norm_dims'] = sanity_check_df['grad_modif_params'].apply(infere_norm_dims)
sanity_check_df['category'] = sanity_check_df.apply(lambda x: f"placement={x.placement}, c={x.c}, k={x.k}, norm_dims={x.norm_dims}", axis=1)
sanity_check_df.sort_values('lr', inplace=True)

fig = common_line_plot(sanity_check_df, x="lr", y="loss", color="category", title=f"Sanity check Loss vs LR categorised by placement, c and norm_dims for grad norm scale_norm", trace=baseline_trace_loss_lr_short)
fig.show()

In [28]:

plot_cnt = len(placements)
L2_NORM_SHORT_RUN_LOSS_RANGE = tuple(map(np.log10, (4.5, 6.8)))
scale_norm_df = project.fetch_runs_table(tag="scale_norm_c_lr_grid_placement_short", columns=columns).to_pandas()
scale_norm_df = rename_to_common(scale_norm_df)
scale_norm_df['k'] = scale_norm_df['grad_modif_params'].apply(infere_k)
#scale_norm_df = scale_norm_df[scale_norm_df['k'] == "auto"]
scale_norm_df = scale_norm_df[scale_norm_df['loss'] < 8.0]
scale_norm_df['c'] = scale_norm_df['grad_modif_params'].apply(infere_c)
scale_norm_df = scale_norm_df[scale_norm_df['c'] != 0]
scale_norm_df['placement'] = scale_norm_df['sys/tags'].apply(infere_placement_type)
scale_norm_df['norm_dims'] = scale_norm_df['grad_modif_params'].apply(infere_norm_dims)
scale_norm_df['category'] = scale_norm_df.apply(lambda x: f"{x.placement} k={x.k}, norm_dims={x.norm_dims}", axis=1)
scale_norm_df.sort_values('lr', inplace=True)
for i, c in enumerate(sorted(std_df['c'].unique())):
    df = scale_norm_df[scale_norm_df['c'] == c]
    title = f"Loss vs Learning Rate for L2 norm, c={c}. Runs with final loss < 8.0"
    fig = common_line_plot(df, x="lr", y="loss", color="category", title=title, trace=baseline_trace_loss_lr_short, y_range=L2_NORM_SHORT_RUN_LOSS_RANGE)
    fig.update_layout(
        autosize=False,
        width=1600,
        height=600,
    )
    fig.show()

Fetching table...: 0 [00:00, ?/s]

In [29]:
scale_norm_optimal_lr_df = scale_norm_df[(scale_norm_df['lr'] == 1e-3) & (scale_norm_df['sys/failed'] == False)].copy()
scale_norm_optimal_lr_df[STD_TAIL_KEY] = scale_norm_optimal_lr_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
scale_norm_optimal_lr_df['steps'] = scale_norm_optimal_lr_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
scale_norm_optimal_lr_df['color'] = scale_norm_optimal_lr_df['placement'].astype(str) + ", k=" + scale_norm_optimal_lr_df['k'].astype(str) + ", norm_dims=" + scale_norm_optimal_lr_df['norm_dims'].astype(str)
scale_norm_optimal_lr_df.sort_values('color', inplace=True)

for c in sorted(std_df['c'].unique()):
    df = scale_norm_optimal_lr_df[scale_norm_optimal_lr_df['c'] == c]
    title = f"90th percentyle of gradients standard derivation vs steps for L2 Norm, c={c}"
    fig = common_plot_nested_list(df, color='color', x="steps", y=STD_TAIL_KEY, title=title, trace=baseline_grad_norm_std_short)
    fig.update_layout(
        xaxis = dict(title='step'),
        yaxis = dict(title='grad_std_percentile_90'),
        autosize=False,
        width=1600,
        height=600,
    )
    fig.show()

## Relation of loss to gradient std

In [30]:
cs = [1e-1, 1e-2, 1e-3]

scale_norm_df = project.fetch_runs_table(tag="scale_norm_c_lr_grid_placement_short", columns=columns).to_pandas()
scale_norm_df = rename_to_common(scale_norm_df)
scale_norm_df['k'] = scale_norm_df['grad_modif_params'].apply(infere_k)
scale_norm_df = scale_norm_df[scale_norm_df['k'] == "auto"]
scale_norm_df['c'] = scale_norm_df['grad_modif_params'].apply(infere_c)
scale_norm_df = scale_norm_df[scale_norm_df['c'] != 0]
scale_norm_df['placement'] = scale_norm_df['sys/tags'].apply(infere_placement_type)
scale_norm_df['norm_dims'] = scale_norm_df['grad_modif_params'].apply(infere_norm_dims)
scale_norm_df['category'] = scale_norm_df.apply(lambda x: f"{x.placement} k={x.k}, norm_dims={x.norm_dims}", axis=1)
scale_norm_df.sort_values('lr', inplace=True)

Fetching table...: 0 [00:00, ?/s]

In [31]:
dfs = []
std_norm = {
    "std_v1": "(0, 1, 2)",
    "std_v2": "(2)",
    "std_v4": "(1, 2)"
}
for gn_name, gn_tag in grad_std_norm_type_tag.items():
    _df = project.fetch_runs_table(tag=gn_tag, columns=columns).to_pandas()
    _df['type'] = "std_norm"
    _df['std_v'] = gn_name
    _df['norm_dims'] = std_norm[gn_name]
    dfs.append(_df)

scale_norm_df = project.fetch_runs_table(tag="scale_norm_c_lr_grid_placement_short", columns=columns).to_pandas()
scale_norm_df['type'] = "scale_norm"
dfs.append(scale_norm_df)

all_df = pd.concat(dfs, ignore_index=True)
all_df = rename_to_common(all_df)
all_df['k'] = all_df['grad_modif_params'].apply(infere_k)
all_df['c'] = all_df['grad_modif_params'].apply(infere_c)
all_df = all_df[(all_df['c'] != 0) & (all_df['lr'] == 1e-3) & (all_df['sys/failed'] == False)]

all_df['placement'] = all_df['sys/tags'].apply(infere_placement_type)
all_df['grad_std_q_series'] = all_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
all_df['grad_std_q_mean'] = all_df['grad_std_q_series'].apply(np.mean)
all_df['color'] = all_df['placement'].astype(str)



Fetching table...: 0 [00:00, ?/s]

Fetching table...: 0 [00:00, ?/s]

In [32]:
log_axis_layout = dict(
    showexponent = 'all',
    exponentformat = 'power',
    type='log'
)
corr_df = all_df[all_df['loss'] < 5]
corr_df_std = corr_df[corr_df['type'] == "std_norm"]
corr_df_scale_norm = corr_df[corr_df['type'] == "scale_norm"]
corr_df_scale_norm['norm_dims'] = corr_df_scale_norm['grad_modif_params'].apply(infere_norm_dims)
baseline_trace_std_loss = go.Scatter(x=[np.mean(baseline_df_optimal_lr[STD_TAIL_KEY].values[1])], y=[baseline_df_optimal_lr['loss'].values[1]], mode='markers', name='baseline', marker=dict(symbol='x', color='white', size=8, line=dict(width=1, color='black')))



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [33]:
fig = px.scatter(corr_df_std, x='grad_std_q_mean', y='loss', color=np.log10(corr_df_std['c']), hover_name='sys/id', hover_data=['c', 'placement', 'norm_dims', 'std_v'], title="Loss vs mean of 90th quantile of gradient norm std for std normalizations", log_x=True, log_y=True)
fig.add_trace(baseline_trace_std_loss)
fig.update_layout(yaxis = log_axis_layout, xaxis = log_axis_layout)
fig.show()

In [34]:
fig = px.scatter(corr_df_scale_norm, x='grad_std_q_mean', y='loss', color=np.log10(corr_df_scale_norm['c']), hover_name='sys/id', hover_data=['c', 'k', 'placement', 'norm_dims'], title="Loss vs mean of 90th quantile of gradient norm std for scale_norm", log_x=True, log_y=True)
fig.add_trace(baseline_trace_std_loss)
fig.update_layout(yaxis = log_axis_layout, xaxis = log_axis_layout)
fig.show()

In [35]:
# TODO add continous color palette for c
corr_df_std.loc[:, 'Norm Type'] = 'Std Norm'
corr_df_scale_norm.loc[:, 'Norm Type'] = 'L2 Norm'
corr_df_all = pd.concat([corr_df_std, corr_df_scale_norm], ignore_index=True)
corr_df_all.sort_values('c', inplace=True)

fig = px.scatter(
    corr_df_all,
    x='grad_std_q_mean',
    y='loss',
    color=np.log10(corr_df_all['c']),
    color_continuous_scale='Sunset_r',
    hover_name='sys/id',
    hover_data=['c', 'k', 'placement', 'norm_dims', 'std_v'],
    #title="Loss vs mean of 90th quantile of gradient norm std for all normalizations",
    log_x=True,
    log_y=True,
    symbol='Norm Type'
)
fig.add_trace(baseline_trace_std_loss)
fig.update_layout(
    yaxis=log_axis_layout,
    xaxis=log_axis_layout,
    autosize=False,
    width=1600,
    height=800,
    coloraxis_colorbar=dict(
        title='c',
        yanchor="top",
        y=1, 
        x=1.1,
        ticks="outside",
        tickprefix="1e",
    )
)
fig.update_xaxes(title_text='mean GST_90')
cs = set()
for i, trace in enumerate(fig.data):
    name = trace.name.split(',')
    if len(name) == 2:
        if float(name[0]) in cs:
            trace['name'] = ''
            trace['showlegend']=False
        else:
            cs.add(float(name[0]))
            trace['name'] = name[0]

fig.add_shape(
    type="rect",
    x1=baseline_trace_std_loss.x[0],
    y1=baseline_trace_std_loss.y[0],
    x0=5e-6,
    y0=4.55,
    fillcolor="red",
    opacity=0.1,
    #layer="below",
    line=dict(width=0)
)

'''
fig.add_shape(
    type="line",
    x0=baseline_trace_std_loss.x[0],
    y0=baseline_trace_std_loss.y[0],
    x1=baseline_trace_std_loss.x[0],
    y1=4.5,
    line=dict(
        color="red",
        width=2,
        dash="dash",
        opacity=0.5
    )
)
fig.add_shape(
    type="line",
    x0=baseline_trace_std_loss.x[0],
    y0=baseline_trace_std_loss.y[0],
    x1=5e-6,
    y1=baseline_trace_std_loss.y[0],
    line=dict(
        color="red",
        width=2,
        dash="dash"
    )
)
'''

#fig.add_trace(go.Scatter(y=[None], mode='markers',
##                         marker=dict(symbol='circle', color='black'),
#                         name='Std Norm',
#                         ))
#fig.add_trace(go.Scatter(y=[None], mode='markers',
#                         marker=dict(symbol='diamond', color='black'),
#                         name='L2 Norm',
#                         ))
fig.update_traces(marker=dict(size=10))
fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    font=dict(
        size=FONT_SIZE
    ),
    xaxis = dict(title='mean GST_90', range=[np.log10(0.95e-5), np.log10(6.5 * 1e-5)]),
    yaxis = dict(title='Loss', range=[np.log10(4.56), np.log10(4.616)]),
    coloraxis_colorbar=dict(yanchor="top", y=1, x=1, ticks="inside", tickprefix="1e"),
    legend=dict(yanchor="top", y=0.3, x=0.8557, xanchor="left"),
)
fig.show()
fig.write_image((IMAGES_PATH / "loss_vs_gst_90_all.pdf"), format="pdf")



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



# Long experiments

### std_v1 params
|     placement    	| c    	|
|:----------------:	|------	|
| post_add         	| 1e-4 	|
| post_attn_and_ff 	| 1e-4 	|
| post_norm        	| 1e-3 	|
| all              	| 1e-5 	|

### std_v2 params
|     placement    	| c    	|
|:----------------:	|------	|
| post_attn_and_ff 	| 1e-5 	|
| post_add         	| 1e-4 	|
| post_add         	| 1e-5 	|
| post_attn_and_ff 	| 1e-4 	|

### std_v4 params
|     placement    	|   c  	|
|:----------------:	|:----:	|
|     post_add     	| 1e-4 	|
|     post_norm    	| 1e-3 	|
| post_attn_and_ff 	| 1e-4 	|
|        all       	| 1e-5 	|

### scale_norm params
|     placement    	| c    	| k    	| norm_dims 	|
|:----------------:	|------	|------	|-----------	|
| post_add         	| 1e-4 	| 1    	| 0,1,2     	|
| post_add         	| 1e-4 	| 1    	| 1,2       	|                                   
| post_add         	| 1    	| auto 	| 0,1,2     	|
| post_add         	| 1    	| auto 	| 1,2       	|
| all              	| 1e-4 	| 1    	| 0,1,2     	|
| all              	| 1e-5 	| 1    	| 0,1,2     	|
| post_norm        	| 1e-4 	| 1    	| 1,2       	|
| post_norm        	| 1e-4 	| 1    	| 0,1,2     	|
| post_attn_and_ff 	| 1e-4 	| 1    	| 1,2       	|
| post_attn_and_ff 	| 1e-3 	| 1    	| 1,2       	|


# Long Experiments (16k steps)

In [36]:
baseline_long_loss_df = baseline_df_optimal_lr[baseline_df_optimal_lr['step'] == 16_000].copy()
baseline_long_loss_df['loss_100'] = baseline_long_loss_df['sys/id'].apply(partial(get_run_loss_series, series_cache=series_cache))
baseline_long_loss_df['loss_1000'] = baseline_long_loss_df['sys/id'].apply(partial(get_run_loss_series, series_cache=series_cache, smoothing="1000"))

baseline_long_loss_trace_100 = go.Scatter(x=baseline_long_loss_df['steps'].values[0], y=baseline_long_loss_df['loss_100'].values[0], mode='lines', name='baseline', line=dict(color='black', width=2, dash='dash'))
baseline_long_loss_trace_1000 = go.Scatter(x=baseline_long_loss_df['loss_steps_1000'].values[0], y=baseline_long_loss_df['loss_1000'].values[0], mode='lines', name='baseline', line=dict(color='black', dash='dash', width=2))

In [37]:
# TODO add continous color palette for c
corr_df_std.loc[:, 'Norm'] = 'Std Norm'
corr_df_scale_norm.loc[:, 'Norm'] = 'L2 Norm'
corr_df_all = pd.concat([corr_df_std, corr_df_scale_norm], ignore_index=True)
corr_df_all.sort_values('c', inplace=True)

fig = px.scatter(
    corr_df_all,
    x='grad_std_q_mean',
    y='loss',
    color=np.log10(corr_df_all['c']),
    color_continuous_scale='Sunset_r',
    hover_name='sys/id',
    hover_data=['c', 'k', 'placement', 'norm_dims', 'std_v'],
    #title="Loss vs mean of 90th quantile of gradient norm std for all normalizations",
    log_x=True,
    log_y=True,
    symbol='Norm'
)
fig.add_trace(baseline_trace_std_loss)
fig.update_layout(
    yaxis=log_axis_layout,
    xaxis=log_axis_layout,
    autosize=False,
    width=1600,
    height=800,
    coloraxis_colorbar=dict(
        title='c',
        yanchor="top",
        y=1, 
        x=1.1,
        ticks="outside",
        tickprefix="1e",
    )
)
fig.update_xaxes(title_text='mean GST_90')
cs = set()
for i, trace in enumerate(fig.data):
    name = trace.name.split(',')
    if len(name) == 2:
        if float(name[0]) in cs:
            trace['name'] = ''
            trace['showlegend']=False
        else:
            cs.add(float(name[0]))
            trace['name'] = name[0]

fig.add_shape(
    type="rect",
    x1=baseline_trace_std_loss.x[0],
    y1=baseline_trace_std_loss.y[0],
    x0=5e-6,
    y0=4.55,
    fillcolor="red",
    opacity=0.1,
    #layer="below",
    line=dict(width=0)
)

'''
fig.add_shape(
    type="line",
    x0=baseline_trace_std_loss.x[0],
    y0=baseline_trace_std_loss.y[0],
    x1=baseline_trace_std_loss.x[0],
    y1=4.5,
    line=dict(
        color="red",
        width=2,
        dash="dash",
        opacity=0.5
    )
)
fig.add_shape(
    type="line",
    x0=baseline_trace_std_loss.x[0],
    y0=baseline_trace_std_loss.y[0],
    x1=5e-6,
    y1=baseline_trace_std_loss.y[0],
    line=dict(
        color="red",
        width=2,
        dash="dash"
    )
)
'''

#fig.add_trace(go.Scatter(y=[None], mode='markers',
##                         marker=dict(symbol='circle', color='black'),
#                         name='Std Norm',
#                         ))
#fig.add_trace(go.Scatter(y=[None], mode='markers',
#                         marker=dict(symbol='diamond', color='black'),
#                         name='L2 Norm',
#                         ))
fig.show()



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [38]:

long_df = project.fetch_runs_table(tag="grad_norm_long_final", columns=columns).to_pandas()
long_df = rename_to_common(long_df)
long_df['norm_type'] = long_df['sys/tags'].apply(infere_norm_type)
std_norm_mask = long_df['norm_type'] == "std_norm"

long_df.loc[std_norm_mask, 'norm_dims'] = long_df.loc[std_norm_mask, 'sys/tags'].apply(infere_std_norm_dims_from_tags)
long_df['placement'] = long_df['sys/tags'].apply(infere_placement_type)
long_df['placement_str'] = long_df['placement'].apply(lambda x: x if x != "post_attn_and_ff" else "post_layer")
long_df['c'] = long_df['grad_modif_params'].apply(infere_c)
long_df['k'] = long_df['grad_modif_params'].apply(infere_k)
long_df.loc[~std_norm_mask, 'norm_dims'] = long_df.loc[~std_norm_mask, 'grad_modif_params'].apply(infere_norm_dims)
long_df[STD_TAIL_KEY] = long_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
long_df['grad_std_q_mean'] = long_df[STD_TAIL_KEY].apply(np.mean)
long_df['steps'] = long_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
long_df['loss_100_steps'] = long_df['sys/id'].apply(partial(get_run_loss_steps_series, series_cache=series_cache, smoothing="100"))
long_df['loss_1000_steps'] = long_df['sys/id'].apply(partial(get_run_loss_steps_series, series_cache=series_cache, smoothing="1000"))
long_df['loss_100_series'] = long_df['sys/id'].apply(partial(get_run_loss_series, series_cache=series_cache, smoothing="100"))
long_df['loss_1000_series'] = long_df['sys/id'].apply(partial(get_run_loss_series, series_cache=series_cache, smoothing="1000"))
long_df_strong = long_df[long_df['k'] != "auto"].copy()
long_df_strong['c_str'] = long_df_strong['c'].apply(lambda x: {1e-3: "1e-3", 1e-4: "1e-4", 1e-5: "1e-5"}[x])
long_df_strong['norm_dims_str'] = long_df_strong['norm_dims'].apply(lambda x: {(0, 1, 2): 'all', (1, 2): 'seq_len x d_m', (2,): 'd_m'}[x])
long_df_strong['norm_type_str'] = long_df_strong['norm_type'].apply(lambda x: {"std_norm": "Std Norm", "scale_norm": "L2 Norm"}[x])
long_df_strong['color'] = long_df_strong['placement'].astype(str) + ", " + long_df_strong['norm_type_str'].astype(str) + ", c=" + long_df_strong['c_str'] + ", norm_dims=" + long_df_strong['norm_dims_str'].astype(str)


In [39]:
print(f"c: {long_df_strong['c'].unique()}")
print(f"norm_dims: {long_df_strong['norm_dims'].unique()}")
print(f"placement: {long_df_strong['placement'].unique()}")
print(f"norm_type: {long_df_strong['norm_type'].unique()}")

c: [1.e-04 1.e-03 1.e-05]
norm_dims: [(1, 2) (0, 1, 2) (2,)]
placement: ['post_norm' 'post_attn_and_ff' 'all' 'post_add']
norm_type: ['scale_norm' 'std_norm']


In [40]:
# loss curves
title = "Loss vs steps for long runs"
fig = common_plot_nested_list(long_df_strong, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig.update_layout(
    xaxis = dict(title='step'),
    yaxis = dict(title='loss'),
    autosize=False,
    width=1600,
    height=800
)
fig.show()

In [41]:
# 1. c and norm_dims impact on loss curves for std norm  !!! those with norm dims = d_m have higher loss
_df_1 = long_df_strong[(long_df_strong['placement'] == 'post_attn_and_ff') & (long_df_strong['norm_type'] == 'std_norm')].copy()
_df_1['color'] = _df_1['norm_dims_str'] + ", c=" + _df_1['c_str']
_df_1.sort_values('color', inplace=True)
fig = common_plot_nested_list(_df_1, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='Step', range=[8_000, 16_000]),
    yaxis = dict(title='Loss', range=[np.log10(3.9), np.log10(4.1)]),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=1, x=0.79, xanchor="left"),
)
fig.show()
fig.write_image((IMAGES_PATH / "loss_vs_steps_std_norm_post_attn_and_ff_long.pdf"), format="pdf")

In [42]:
# 2. c and norm_dims impact on loss curves for std norm  !!! all those have c=1e-3 and have lower loss then baseline
_df_2 = long_df_strong[(long_df_strong['placement'] == 'post_norm') & (long_df_strong['norm_type'] == 'std_norm')].copy()
_df_2['color'] = _df_2['norm_dims_str'] + ", c=" + _df_2['c_str']
_df_2.sort_values('color', inplace=True)
fig = common_plot_nested_list(_df_2, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='step', range=[8_000, 16_000]),
    yaxis = dict(title='loss', range=[np.log10(3.9), np.log10(4.1)]),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=1, x=0.79, xanchor="left"),
)
fig.show()
fig.write_image((IMAGES_PATH / "loss_vs_steps_std_norm_post_norm_long.pdf"), format="pdf")

In [43]:
# 2 prim. c and norm_dims impact on loss curves for std norm  !!! all those have c=1e-3 and have lower loss then baseline
_df_2 = long_df_strong[(long_df_strong['c'] == 1e-3)].copy()
_df_2['color'] = + _df_2['norm_type_str'] + ', ' + _df_2['norm_dims_str'].astype(str) + ", " + _df_2['placement_str']
_df_2.sort_values('color', inplace=True)
fig = common_plot_nested_list(_df_2, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='Step', range=[8_000, 16_000]),
    yaxis = dict(title='Loss', range=[np.log10(3.9), np.log10(4.1)]),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=1, x=0.674, xanchor="left"),
)
fig.show()
fig.write_image((IMAGES_PATH / "loss_vs_steps_c_1e-3_long.pdf"), format="pdf")

In [44]:
# 3. c and norm_dims impact on loss curves for l2 norm  !!! again here c=1e-3 is better then baseline
_df_3 = long_df_strong[(long_df_strong['norm_type'] == 'scale_norm') & (long_df_strong['placement'] == 'post_attn_and_ff')].copy()
_df_3['color'] = _df_3['c_str'] + ", " + _df_3['norm_dims_str'] + ", " + _df_3['placement']
fig = common_plot_nested_list(_df_3, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='step', range=[6000, 16000]),
    yaxis = dict(title='loss', range=[np.log10(3.9), np.log10(4.2)]),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=1, x=0.68, xanchor="left"),
)
fig.show()

In [45]:
# 4. c and placement for l2 norm and norm_dims=seq_len x d_m
_norm_dims = ['seq_len x d_m', 'all']
_norm_type = 'std_norm'

_df_4_1 = long_df_strong[(long_df_strong['norm_dims_str'] == _norm_dims[0]) & (long_df_strong['norm_type'] == _norm_type) & (long_df_strong['c'] == 1e-4)].copy()
_df_4_1['color'] = _df_4_1['placement'].astype(str) + ", " + _df_4_1['norm_dims_str']
_df_4_1.sort_values('color', inplace=True)
fig1 = common_plot_nested_list(_df_4_1, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig1.update_traces(line=dict(width=3))
fig1.update_layout(
    xaxis = dict(title='Step', range=[11_000, 16000]),
    yaxis = dict(title='Loss', range=[np.log10(3.9), np.log10(3.99)]),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=1, x=0.71, xanchor="left"),
)
fig1.show()
fig1.write_image((IMAGES_PATH / "loss_vs_steps_std_norm_c_1e-4_seq_len_dm_long.pdf"), format="pdf")

_df_4_2 = long_df_strong[(long_df_strong['norm_dims_str'] == _norm_dims[1]) & (long_df_strong['norm_type'] == _norm_type) & (long_df_strong['c'] == 1e-4)].copy()
_df_4_2['color'] = _df_4_2['placement'].astype(str) + ", " + _df_4_2['norm_dims_str']
_df_4_2.sort_values('color', inplace=True)
fig2 = common_plot_nested_list(_df_4_2, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig2.update_traces(line=dict(width=3))
fig2.update_layout(
    xaxis = dict(title='Step', range=[11_000, 16000]),
    yaxis = dict(title='Loss', range=[np.log10(3.9), np.log10(3.99)]),
    legend=dict(yanchor="top", y=1, x=0.807, xanchor="left"),
    font=dict(size=FONT_SIZE),
    autosize=False,
    width=1600,
    height=800
)
fig2.show()
fig2.write_image((IMAGES_PATH / "loss_vs_steps_std_norm_c_1e-4_all_long.pdf"), format="pdf")

In [46]:
# 4. influence of norm dims on loss for std norm and l2 norm placement=post_add and c=1e-3
_df_4 = long_df_strong[(long_df_strong['c'] == 1e-4) & (long_df_strong['placement'] == "post_norm")].copy()
_df_4['color'] = _df_4['norm_type_str'] + ", " + _df_4['norm_dims_str']
fig = common_plot_nested_list(_df_4, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='step'),
    yaxis = dict(title='loss'),
    autosize=False,
    width=1600,
    height=800
)
fig.show()

In [47]:
# 5. influence of norm dims on std norm placement=post_norm

_df_5 = long_df_strong[(long_df_strong['placement'] == "post_attn_and_ff") & (long_df_strong['norm_type'] == "std_norm")].copy()
_df_5['color'] = _df_5['norm_dims_str'] + ", c=" + _df_5['c_str']
fig = common_plot_nested_list(_df_5, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='step'),
    yaxis = dict(title='loss'),
    autosize=False,
    width=1600,
    height=800
)  
fig.show()


In [48]:
# 5. influence of norm_dims on loss for l2 norm and placement=post_add and c=1e-3 and k=auto
_df_5 = long_df[(long_df['c'] == 1.0) & (long_df['placement'] == "post_add") & (long_df['norm_type'] == 'scale_norm')].copy()
_df_5['color'] = _df_5['norm_dims'].astype(str) + ", k=" + _df_5['k'].astype(str)
fig = common_plot_nested_list(_df_5, color='color', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='step'),
    yaxis = dict(title='loss'),
    autosize=False,
    width=1600,
    height=800
)
fig.show()

In [49]:
title = "GST_90 vs steps for long runs"
fig = common_plot_nested_list(long_df_strong, color='color', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_long)
fig.update_layout(
    xaxis = dict(title='Step'),
    yaxis = dict(title='GSQ_90'),
    autosize=False,
    width=1400,
    height=800,
    legend=dict(font = dict(size = 13))
)
fig.show()
#fig.write_image((IMAGES_PATH / "std_90_curves_promising_long.pdf").as_posix(), format="pdf")

In [50]:
from scipy.signal import savgol_filter
MOVING_AVG_WINDOW = 6
def moving_avg(x: List[float]) -> List[float]:
    return savgol_filter(x, MOVING_AVG_WINDOW, 0)
# steps are every 160 steps
# values are averaged 160 * MOVING_AVG_WINDOW steps
# 1. c and norm_dims impact on loss curves for std norm
baseline_grad_norm_std_long_smooth = go.Scatter(x=baseline_df_optimal_lr['steps'].values[0][:-MOVING_AVG_WINDOW], y=moving_avg(baseline_df_optimal_lr[STD_TAIL_KEY].values[0]), mode='lines', name='baseline', line=dict(color='black', width=2, dash='dash'))

In [51]:
# 1 impact of norm_dims on std norm placement=post_attn_and_ff c=1e-4
_df_1 = long_df_strong[(long_df_strong['placement'] == 'post_attn_and_ff') & (long_df_strong['norm_type'] == 'std_norm') & (long_df_strong['c'] == 1e-4)].copy()
_df_1['y_smoothed'] = _df_1[STD_TAIL_KEY].apply(moving_avg)
_df_1['steps'] = _df_1['steps'].apply(lambda x: x[:-MOVING_AVG_WINDOW])
_df_1['color'] = _df_1['norm_dims_str'] + ", c=" + _df_1['c_str']
_df_1.sort_values('color', inplace=True)
fig = common_plot_nested_list(_df_1, color='color', x='steps', y='y_smoothed', title=None, trace=baseline_grad_norm_std_long_smooth)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='Step'),
    yaxis = dict(title='GST_90'),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=0.221, x=0.78, xanchor="left"),
)
fig.show()
fig.write_image((IMAGES_PATH / "std_90_vs_steps_std_norm_post_layer_c_1e-4_long.pdf").as_posix(), format="pdf")

In [52]:
# 2. c and norm_dims impact on std norm curves for std norm . RESULTS AVERAGED!!!
_df_std_2 = long_df_strong[long_df_strong['c'] == 1e-3].copy()
_df_std_2['color'] = _df_std_2['norm_type_str'] + ", " + _df_std_2['norm_dims_str'] + ", " + _df_std_2['placement_str']
_df_std_2[STD_TAIL_KEY] = _df_std_2[STD_TAIL_KEY].apply(moving_avg)
_df_std_2['steps'] = _df_std_2['steps'].apply(lambda x: x[:-MOVING_AVG_WINDOW])
_df_std_2.sort_values('color', inplace=True)
fig = common_plot_nested_list(_df_std_2, color='color', x='steps', y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_long_smooth)
fig.update_traces(line=dict(width=3))
fig.update_layout(
    xaxis = dict(title='Step'),
    yaxis = dict(title='GST_90'),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=0.2215, x=0.661, xanchor="left"),
)
fig.show()
fig.write_image((IMAGES_PATH / "std_90_vs_steps_c_1e-3_long.pdf").as_posix(), format="pdf")

In [53]:
# impact of placement of gradient norm
_norm_dims = ['seq_len x d_m', 'all']
_norm_type = 'std_norm'

_df_4_1 = long_df_strong[(long_df_strong['norm_dims_str'] == _norm_dims[0]) & (long_df_strong['norm_type'] == _norm_type) & (long_df_strong['c'] == 1e-4)].copy()
_df_4_1[STD_TAIL_KEY] = _df_4_1[STD_TAIL_KEY].apply(moving_avg)
_df_4_1['steps'] = _df_4_1['steps'].apply(lambda x: x[:-MOVING_AVG_WINDOW])
_df_4_1['color'] = _df_4_1['placement'].astype(str) + ", " + _df_4_1['norm_dims_str']
_df_4_1.sort_values('color', inplace=True)
fig1 = common_plot_nested_list(_df_4_1, color='color', x='steps', y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_long_smooth)
fig1.update_traces(line=dict(width=3))
fig1.update_layout(
    xaxis = dict(title='MSQ_90'),
    yaxis = dict(title='Loss'),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=0.17, x=0.699, xanchor="left"),
)
fig1.show()
#fig1.write_image((IMAGES_PATH / "msq_90_vs_steps_std_norm_c_1e-4_seq_len_dm_long.pdf"), format="pdf")

_df_4_2 = long_df_strong[(long_df_strong['norm_dims_str'] == _norm_dims[1]) & (long_df_strong['norm_type'] == _norm_type) & (long_df_strong['c'] == 1e-4)].copy()
_df_4_2[STD_TAIL_KEY] = _df_4_2[STD_TAIL_KEY].apply(moving_avg)
_df_4_2['steps'] = _df_4_2['steps'].apply(lambda x: x[:-MOVING_AVG_WINDOW])
_df_4_2['color'] = _df_4_2['placement'].astype(str) + ", " + _df_4_2['norm_dims_str']
_df_4_2.sort_values('color', inplace=True)
fig2 = common_plot_nested_list(_df_4_2, color='color', x='steps', y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_long_smooth)
fig2.update_traces(line=dict(width=3))
fig2.update_layout(
    xaxis = dict(title='Step'),
    yaxis = dict(title='MSQ_90'),
    legend=dict(yanchor="top", y=0.17, x=0.799, xanchor="left"),
    font=dict(size=FONT_SIZE),
    autosize=False,
    width=1600,
    height=800
)
fig2.show()
#fig2.write_image((IMAGES_PATH / "msq_90_vs_steps_std_norm_c_1e-4_all_long.pdf"), format="pdf")

In [54]:
# average loss, msq_90, mean_grad_norm for strong long runs difference from baseline per step
def rel_change(a, b):
    if len(a) < len(b):
        b_cmp = b[:len(a)]
    else:
        b_cmp = b
    return (a - b_cmp) / b_cmp

def abs_change(a, b):
    if len(a) < len(b):
        b_cmp = b[:len(a)]
    else:
        b_cmp = b
    if len(a) != len(b_cmp):
        print(f"len(a)={len(a)}, len(b_cmp)={len(b_cmp)}")
    return a - b_cmp

avg_final_loss = long_df_strong['loss'].mean()
avg_msq_90_baseline_abs_diff = long_df_strong[STD_TAIL_KEY].apply(lambda l: abs_change(np.array(l), baseline_df_optimal_lr[STD_TAIL_KEY].values[0])).apply(np.mean)
avg_msq_90_baseline_rel_diff = long_df_strong[STD_TAIL_KEY].apply(lambda l: rel_change(np.array(l), baseline_df_optimal_lr[STD_TAIL_KEY].values[0])).apply(np.mean)

In [55]:
avg_msq_90_baseline_abs_diff.mean(), avg_msq_90_baseline_rel_diff.mean()

(-3.4578259242352195e-06, -0.09290103925786083)

In [56]:
baseline_long_opt_lr_eps = baseline_df_long[(baseline_df_long['lr'] == 1e-3) & (baseline_df_long['eps'] == 0)].copy()
baseline_long_opt_lr_eps[STD_TAIL_KEY] = baseline_long_opt_lr_eps['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
baseline_long_opt_lr_eps['grad_std_q_mean'] = baseline_long_opt_lr_eps[STD_TAIL_KEY].apply(np.mean)
baseline_long_opt_lr_eps

Unnamed: 0,sys/creation_time,sys/failed,sys/id,sys/name,sys/tags,grad_modif_params,lr,loss,step,eps,std_q0.9,grad_std_q_mean
3,2024-10-01 11:42:27.594,False,LLMRANDOM-16149,post_add_c_lr_grid_exp_28_lr_mul10_baseline lr...,"grad_norm,lr_mul10,post_add_c_lr_grid_long,std...","layer_type=v1,c=99,eps=0",0.001,3.906353,16000.0,0.0,"[1.7934009974851506e-06, 2.8616716463147895e-0...",3.6e-05


In [57]:
baseline_trace_std_loss_long = go.Scatter(x=[baseline_long_opt_lr_eps['grad_std_q_mean'].values[0]], y=[baseline_long_opt_lr_eps['loss'].values[0]], mode='markers', name='baseline', marker=dict(symbol='x', color='white', size=8, line=dict(width=1, color='black')))
fig = px.scatter(
    long_df_strong,
    x='grad_std_q_mean',
    y='loss',
    color=np.log10(long_df_strong['c']),
    color_continuous_scale='Sunset_r',
    hover_name='sys/id',
    hover_data=['c', 'k', 'placement', 'norm_dims', 'norm_type'],
    #title="Loss vs mean of 90th quantile of gradient norm std for all normalizations",
    log_x=True,
    log_y=True,
    symbol='norm_type',
)
fig.update_traces(marker=dict(
    size=6,
    line=dict(width=1, color='black')),
    selector=dict(mode='markers')
)
fig.add_trace(baseline_trace_std_loss_long)
fig.update_layout(
    yaxis=log_axis_layout,
    xaxis=log_axis_layout,
    autosize=False,
    width=1600,
    height=800,
    coloraxis_colorbar=dict(
        title='c',
        yanchor="top",
        y=1, 
        x=1.1,
        ticks="outside",
        tickprefix="1e",
    )
)

In [58]:
long_original_bgn_df = long_df[long_df['k'] == "auto"].copy()
long_original_bgn_df['norm_dims_str'] = long_original_bgn_df['norm_dims'].apply(lambda x: {(0, 1, 2): 'all', (1, 2): 'seq_len x d_m'}[x])
long_original_bgn_df['color'] = long_original_bgn_df['placement'].astype(str) + ", " + long_original_bgn_df['norm_dims_str'].astype(str)
long_original_bgn_df['color_str'] = "BGN, " + long_original_bgn_df['norm_dims_str']

In [59]:
long_original_bgn_df

Unnamed: 0,sys/creation_time,sys/failed,sys/id,sys/name,sys/tags,grad_modif_params,lr,loss,step,norm_type,...,std_q0.9,grad_std_q_mean,steps,loss_100_steps,loss_1000_steps,loss_100_series,loss_1000_series,norm_dims_str,color,color_str
6,2024-12-04 13:21:46.642,False,LLMRANDOM-25002,grad_norm_long_final_scale_norm_post_add_c_1_k...,"c_1,dims_0_1_2,final,grad_norm,grad_norm_long_...","c=1,k=auto,norm_dims=(0,1,2),eps=0",0.001,3.926646,16000.0,scale_norm,...,"[0.01426330599933863, 0.03376563005149365, 0.0...",0.102015,"[0.0, 160.0, 320.0, 480.0, 640.0, 800.0, 960.0...","[100.0, 200.0, 300.0, 400.0, 500.0, 600.0, 700...","[1000.0, 2000.0, 3000.0, 4000.0, 5000.0, 6000....","[8.088240146636963, 6.614896512031555, 6.27346...","[5.939662310600281, 4.807847929000855, 4.52605...",all,"post_add, all","BGN, all"
8,2024-12-04 13:21:46.246,False,LLMRANDOM-25000,grad_norm_long_final_scale_norm_post_add_c_1_k...,"c_1,dims_1_2,final,grad_norm,grad_norm_long_fi...","c=1,k=auto,norm_dims=(1,2),eps=0",0.001,3.970364,16000.0,scale_norm,...,"[0.1763723820447922, 0.34483287036418914, 0.42...",1.494172,"[0.0, 160.0, 320.0, 480.0, 640.0, 800.0, 960.0...","[100.0, 200.0, 300.0, 400.0, 500.0, 600.0, 700...","[1000.0, 2000.0, 3000.0, 4000.0, 5000.0, 6000....","[8.352403492927552, 6.758821048736572, 6.42250...","[6.125595627784729, 4.98669127368927, 4.645079...",seq_len x d_m,"post_add, seq_len x d_m","BGN, seq_len x d_m"


In [60]:
fig = common_plot_nested_list(long_original_bgn_df, color='color_str', x='loss_1000_steps', y='loss_1000_series', title=None, trace=baseline_long_loss_trace_1000)
all_loss_series = long_df_strong['loss_100_series'].values
long_strong_mean_loss = [np.mean(ls) for ls in zip(*all_loss_series)]
#fig.add_trace(go.Scatter(
#    x=long_original_bgn_df['steps'].values[0],
#    y=long_strong_mean_loss,
#    mode='lines',
#    name='average of other promising',
#    line=dict(color='green', width=3, dash='dot'),
#))
fig.update_traces(line=dict(width=3)),
fig.update_layout(
    xaxis = dict(title='Step', range=[8_000, 16_000]),
    yaxis = dict(title='Loss', range=[np.log10(3.9), np.log10(4.35)]),
    autosize=False,
    width=1600,
    height=800,
    font=dict(size=FONT_SIZE),
    legend=dict(yanchor="top", y=1, x=0.808, xanchor="left"), 
)
fig.show()
fig.write_image((IMAGES_PATH / "loss_vs_steps_bgn_long.pdf"), format="pdf")

In [61]:
from itertools import zip_longest
all_gst_series = long_df_strong[STD_TAIL_KEY].values
long_strong_mean_gst = []
for ls in zip_longest(*all_gst_series, fillvalue=np.nan):
    arr = np.array(ls)
    arr = arr[~np.isnan(arr)]
    long_strong_mean_gst.append(np.mean(arr))

fig = common_plot_nested_list(long_original_bgn_df, color='color_str', x="steps", y=STD_TAIL_KEY, title=None, trace=baseline_grad_norm_std_long)
fig.add_trace(go.Scatter(
    x=long_original_bgn_df['steps'].values[0],
    y=long_strong_mean_gst,
    mode='lines',
    name='average of other promising',
    line=dict(color='green', width=3, dash='dot'),
))
fig.update_layout(
    xaxis = dict(title='Step'),
    yaxis = dict(title='GSQ_90'),
    autosize=False,
    width=1600,
    height=800,
    legend=dict(font = dict(size = 13))
)
fig.show()
#fig.write_image((IMAGES_PATH / "gsq_curves_original_bgn_final.png").as_posix(), format="png")

In [62]:
long_original_bgn_df['mean_grad_norm'] = long_original_bgn_df['sys/id'].apply(partial(get_run_mean_grad_avg_series, series_cache=series_cache))
long_df_strong['mean_grad_norm'] = long_df_strong['sys/id'].apply(partial(get_run_mean_grad_avg_series, series_cache=series_cache))
baseline_df_optimal_lr['mean_grad_norm'] = baseline_df_optimal_lr['sys/id'].apply(partial(get_run_mean_grad_avg_series, series_cache=series_cache))

all_gst_series = long_df_strong['mean_grad_norm'].values
long_strong_mean_gst = []
for ls in zip_longest(*all_gst_series, fillvalue=np.nan):
    arr = np.array(ls)
    arr = arr[~np.isnan(arr)]
    long_strong_mean_gst.append(np.mean(arr))


baseline_grad_mean_trace = go.Scatter(x=baseline_df_optimal_lr['steps'].values[0], y=baseline_df_optimal_lr['mean_grad_norm'].values[0], mode='lines', name='baseline', line=dict(color='black', width=2, dash='dash'))

fig = common_plot_nested_list(long_original_bgn_df, color='color_str', x="steps", y="mean_grad_norm", title=None, trace=baseline_grad_mean_trace)
fig.add_trace(go.Scatter(
    x=long_original_bgn_df['steps'].values[0],
    y=long_strong_mean_gst,
    mode='lines',
    name='average of other promising runs',
    line=dict(color='green', width=3, dash='dot'),
))
fig.update_layout(
    xaxis = dict(title='Step'),
    yaxis = dict(title='Mean Gradient Norm'),
    autosize=False,
    width=1600,
    height=800,
    legend=dict(font = dict(size = 13))
)
fig.show()
#fig.write_image((IMAGES_PATH / "avg_gradient_norm_curves_original_bgn_long.png").as_posix(), format="png")

In [63]:
"""
fig = make_subplots(rows=4, cols=1, shared_xaxes=True, shared_yaxes=False, specs=[[{"rowspan": 2}], [{}], [{}], [{}]], subplot_titles=None, x_title='Step')
fig.add_trace(go.Scatter(
        x=baseline_grad_norm_std_long.x,
        y=baseline_grad_norm_std_long.y,
        name='GSQ_90',
        line=dict(color='black', width=2, dash='dot'),
        legendgroup='baseline',
        legendgrouptitle_text='Baselines',
    ),
    row=1, col=1
)
# baseline
fig.add_trace(go.Scatter(
    x=baseline_df_optimal_lr['steps'].values[0],
    y=baseline_df_optimal_lr['mean_grad_norm'].values[0],
    name='Mean Gradient Norm',
    line=dict(color='black', width=2),
    legendgroup='baseline',
    ),
    row=1, col=1,
)
"""
fig1 = go.Figure()
colors = ['blue', 'red']
# BGN 1
fig1.add_trace(go.Scatter(
        x=long_original_bgn_df.iloc[1]['steps'],
        y=long_original_bgn_df.iloc[1]['mean_grad_norm'],
        name='Mean Gradient Norm',
        line=dict(color=colors[1], width=2),
    )
)


fig1.add_trace(go.Scatter(
        x=long_original_bgn_df.iloc[1]['steps'],
        y=long_original_bgn_df.iloc[1][STD_TAIL_KEY],
        name='GSQ_90',
        line=dict(color=colors[1], width=2, dash='dot'),
    )
)

# BGN 2
fig2 = go.Figure()
fig2.add_trace(go.Scatter(
        x=long_original_bgn_df.iloc[0]['steps'],
        y=long_original_bgn_df.iloc[0]['mean_grad_norm'],
        name='Mean Gradient Norm',
        line=dict(color=colors[0], width=2),
    )
)

fig2.add_trace(go.Scatter(
        x=long_original_bgn_df.iloc[0]['steps'],
        y=long_original_bgn_df.iloc[0][STD_TAIL_KEY],
        name='GSQ_90',
        line=dict(color=colors[0], width=2, dash='dot'),
    )
)

'''
# Others promising mean
long_strong_mean_grad = []
for ls in zip_longest(*long_df_strong['mean_grad_norm'].values, fillvalue=np.nan):
    arr = np.array(ls)
    arr = arr[~np.isnan(arr)]
    long_strong_mean_grad.append(np.mean(arr))

fig.add_trace(go.Scatter(
    x=long_original_bgn_df['steps'].values[0],
    y=long_strong_mean_grad,
    mode='lines',
    name='Mean Gradient Norm',
    legendgrouptitle_text="Averaged from other promising",
    legendgroup='other_promising',
    line=dict(color='green', width=2),
))

# Others promising MSQ_90
long_strong_mean_msq = []
for ls in zip_longest(*long_df_strong[STD_TAIL_KEY].values, fillvalue=np.nan):
    arr = np.array(ls)
    arr = arr[~np.isnan(arr)]
    long_strong_mean_msq.append(np.mean(arr))

fig.add_trace(go.Scatter(
    x=long_original_bgn_df['steps'].values[0],
    y=long_strong_mean_msq,
    mode='lines',
    name='GSQ_90',
    legendgroup='other_promising',
    line=dict(color='green', width=3, dash='dot'),
))
'''

image_names = [
    "msq_90_and_grad_mean_bgn_seq_len_dm.pdf",
    "msq_90_and_grad_mean_bgn_all.pdf",
]
y_ranges = [(0.5, 2), (0.05, 0.13)]
legend = dict(yanchor="top", y=0.177, x=0.809, xanchor="left")

for fig, y_range, imgname in zip([fig1, fig2], y_ranges, image_names):
    fig.update_traces(line=dict(width=3))
    fig.update_layout(
        xaxis = dict(title="Steps"),
        yaxis = dict(
            title=None,
            showexponent = 'all',
            exponentformat = 'power',
            type='log',
            range=[np.log10(y) for y in y_range]
        ),
        autosize=False,
        width=1600,
        height=600,
        font=dict(size=FONT_SIZE),
        legend=legend,
    )
    fig.show()
    fig.write_image((IMAGES_PATH / imgname).as_posix(), format="pdf")
fig.write_image((IMAGES_PATH / "gradient_statistics_original_bgn_long.png").as_posix(), format="png")
