In [None]:
import pandas as pd
import wandb

api = wandb.Api()

# Project is specified by <entity/project-name>.
runs = api.runs("mszawerd-politechnika-warszawska/debug")

# Display name of the only run that is exported; every other run is skipped.
TARGET_RUN_NAME = 'radiant-sweep-16'

summary_list, config_list, name_list, history = [], [], [], []
for run in runs:
    if run.name != TARGET_RUN_NAME:
        continue
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files.
    summary_list.append(run.summary._json_dict)

    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config_list.append(
        {k: v for k, v in run.config.items()
         if not k.startswith('_')})

    # .name is the human-readable name of the run.
    name_list.append(run.name)

    # Fetch the history once: dump it to data.json and keep its rows so the
    # "history" column of runs_df is populated instead of holding an empty list.
    run_history = run.history(pandas=True)
    run_history.to_json('data.json', indent=4)
    history.append(run_history.to_dict('records'))

# One row per exported run; all four lists are appended to in lockstep above.
runs_df = pd.DataFrame({
    "summary": summary_list,
    "config": config_list,
    "name": name_list,
    "history": history
})


In [None]:
import json

# Write every row yielded by run.scan_history() to a.txt as indented JSON.
# NOTE(review): `run` is not defined in this cell; it is whatever the notebook
# namespace currently binds (e.g. the variable left over from a `for run in runs` loop).
with open('a.txt', 'w') as out_file:
    json.dump(list(run.scan_history()), out_file, indent=4)

In [None]:
import wandb
from tqdm import tqdm
from tools.project import RAW_PATH

# First path component (the W&B entity) of every sweep and artifact path built in this notebook.
USER_NAME = 'mszawerd-politechnika-warszawska'

# W&B public-API client shared by the helper functions defined below.
api = wandb.Api()


def get_sweep_runs(project: str, sweep_id: str):
    """Return the runs of one sweep in a project owned by ``USER_NAME``.

    Args:
        project: Project name under ``USER_NAME``.
        sweep_id: Identifier of the sweep inside that project.

    Returns:
        The ``runs`` attribute of the object returned by ``api.sweep`` for
        the path ``<USER_NAME>/<project>/<sweep_id>``.
    """
    return api.sweep(f"{USER_NAME}/{project}/{sweep_id}").runs


# Runs of sweep 4k5q7co8 in the textual-musicgen-small project.
# NOTE(review): this rebinds `runs`, which an earlier cell assigns from api.runs(...).
runs = get_sweep_runs('textual-musicgen-small', '4k5q7co8')
# Disabled alternative (the commented-out line was duplicated):
# artifact = api.artifact("mszawerd-politechnika-warszawska/debug/run-1b4mmwpt-history:v0")

In [None]:
import pandas as pd

# Filled by download_artifacts: run id -> {'params': run config, 'stats': dict returned by extract_data}.
stats = {}


def extract_data(path: str):
    """Collect the non-null FAD values stored in a parquet file.

    Args:
        path: Location of the parquet file, passed to ``pd.read_parquet``.

    Returns:
        A dict of plain Python lists: one ``fad_<name>`` entry for every column
        whose name starts with ``"FAD "`` (``<name>`` is the lower-cased second
        space-separated token of the column name), plus a ``fad_avg`` entry
        taken from the ``fad_avg`` column.

    Raises:
        KeyError: If the file has no ``fad_avg`` column.
    """
    df = pd.read_parquet(path)
    fad_dict = {
        f"fad_{col.split(' ')[1].lower()}": df[col].dropna().tolist()
        for col in df.columns
        if col.startswith("FAD ")
    }
    # Use .tolist() rather than list(.values): it yields native Python numbers
    # for every dtype, so the result stays JSON-serializable and is built the
    # same way as the per-column FAD lists above.
    fad_dict['fad_avg'] = df['fad_avg'].dropna().tolist()
    return fad_dict


def download_artifacts(project: str, runs):
    """Download each run's history artifact and record its config and FAD stats.

    For every run, the artifact ``<USER_NAME>/<project>/run-<run.id>-history:v0``
    is downloaded into ``RAW_PATH('runs', run.id)``. An entry holding the run's
    config and the output of ``extract_data`` is then stored in the
    module-level ``stats`` dict under the run id. Any exception raised while
    handling a run is printed and that run is skipped.

    Args:
        project: Project name used to build the artifact name.
        runs: Iterable of run objects exposing ``id`` and ``config``.
    """
    for run in tqdm(runs):
        run_id = run.id
        history_name = f"{USER_NAME}/{project}/run-{run_id}-history:v0"
        try:
            api.artifact(history_name).download(RAW_PATH('runs', run_id))
            stats[run_id] = {
                'params': run.config,
                'stats': extract_data(RAW_PATH('runs', run_id, '0000.parquet')),
            }
        except Exception as e:
            # Best effort: report the failure and continue with the next run.
            print(f"Error downloading artifact for run {run}: {str(e)}")


import json

# Fill `stats` for every run of the sweep, then persist it as indented JSON.
download_artifacts('textual-musicgen-small', runs)
with open(RAW_PATH('run_stats', 'stats.json'), 'w') as stats_file:
    json.dump(stats, stats_file, indent=4)

In [None]:
import wandb

# NOTE(review): this rebinds `run`, the name earlier cells use for the run
# they loop over and dump history from.
run = wandb.init()

# Request the history artifact of run 1b4mmwpt via run.use_artifact and download it;
# art_dir holds whatever .download() returns (presumably the local directory, TODO confirm).
art_dir = run.use_artifact('mszawerd-politechnika-warszawska/debug/run-1b4mmwpt-history:v0').download()

In [None]:
import os

# Read the history parquet from the directory the artifact was downloaded to
# (art_dir, set by the download cell above) instead of a hard-coded absolute
# path that only exists on one machine.
df = pd.read_parquet(os.path.join(art_dir, '0000.parquet'))

In [None]:
concept = 'metal-solos'

# Show the non-null entries of the selected concept's FAD column (row index preserved).
df.loc[df[f'FAD {concept}'].notna(), f'FAD {concept}']

In [None]:
import torch

In [None]:
# Only the top-level keys are inspected, so map all storages to the CPU: without
# map_location, a checkpoint saved from a GPU fails to load on a machine without one.
# NOTE(review): torch.load unpickles the file; only open checkpoints from a trusted source.
torch.load(
    "/home/mszawerda/musical-generative-models-conditioning/models/concepts-dataset/giddy-sweep-1-best.pt",
    map_location="cpu",
).keys()