In [55]:
import os
import pandas as pd
from IPython.display import display, display_html
import numpy as np
import matplotlib.pyplot as plt
def samples_seen_from_str(samples):
    """Convert a human-readable sample count such as "1.28M" into a float.

    Recognised suffixes are "K" (10**3), "M" (10**6) and "B" (10**9).
    A string without one of these suffixes is parsed as a plain number, so
    the function always returns a float (previously such input was handed
    back unchanged as a string, which broke numeric sorting/arithmetic in
    callers).

    Args:
        samples: String like "64K", "12.8M", "3B" or "500". Surrounding
            whitespace is ignored.

    Returns:
        The number of samples as a float.

    Raises:
        ValueError: If the numeric part cannot be parsed by ``float``.
    """
    multipliers = {"K": 10**3, "M": 10**6, "B": 10**9}
    text = samples.strip()
    # text[-1:] (rather than text[-1]) is "" for empty input, so the lookup
    # below is safe and float("") reports the problem as a ValueError.
    suffix = text[-1:]
    if suffix in multipliers:
        return float(text[:-1]) * multipliers[suffix]
    return float(text)
def human(v, tol=1e9, known_scales=None):
    """Turn a numeric sample count into a short label such as "128M".

    The value is first snapped to the closest known scale when it lies
    within ``tol`` of it; otherwise it is formatted generically with an
    "M" (>= 10**6) or "B" (>= 10**9) suffix, or as a plain number below
    10**6.

    Args:
        v: Number of samples seen.
        tol: Maximum absolute distance (in samples) to the closest known
            scale for snapping to happen. The default (1e9) is the value
            that was previously hard-coded.
            NOTE(review): with the notebook's scales this default snaps
            every value below roughly 4e9 to some scale label -- confirm
            that such a wide tolerance is intended.
        known_scales: Optional mapping ``label -> numeric value`` to snap
            to. When omitted, the notebook-level ``scales`` /
            ``scales_numeric`` lists are used.

    Returns:
        A string label. Unlike the previous implementation, values exactly
        equal to 10**6 or 10**9 are formatted ("1M", "1B") instead of
        falling through and returning None.
    """
    if known_scales is None:
        labels, numeric = scales, scales_numeric
    else:
        labels, numeric = list(known_scales.keys()), list(known_scales.values())

    # Snap to the closest known scale when it is near enough.
    if len(numeric):
        dist = np.abs(np.array(numeric) - v)
        idx = dist.argmin()
        if dist[idx] < tol:
            return labels[idx]

    # Generic formatting; thresholds are inclusive so the exact powers of
    # ten are handled too.
    if v >= 10**9:
        return (str(v / 10**9) + "B").replace(".0B", "B")
    if v >= 10**6:
        return (str(v / 10**6) + "M").replace(".0M", "M")
    return str(v)
# Labels of the "samples seen" budgets, listed from smallest to largest.
scales = [
    "1.28M", "3M", "6.4M", "12.8M",
    "30M", "64M", "128M",
    "300M", "640M",
    "1.28B", "3B",
]
# Same budgets as numbers, position-aligned with `scales`.
scales_numeric = list(map(samples_seen_from_str, scales))

In [None]:
# Load all evaluation results and attach a pretty label for the number of
# samples seen by each run.
df = pd.read_csv("scaling_laws/data/all_results.csv.gz")
df["samples_seen_scale_pretty"] = df.total_samples_seen.apply(human)

# Keep three (pre-training dataset, LR scheduler) combinations:
# datacomp_1b with cosine or const, and relaion2b-en with cosine.
on_datacomp = df.pretrain_dataset == "datacomp_1b"
on_relaion = df.pretrain_dataset == "relaion2b-en"
cosine_lr = df.lr_scheduler == "cosine"
const_lr = df.lr_scheduler == "const"
df = df[(on_datacomp & cosine_lr) | (on_datacomp & const_lr) | (on_relaion & cosine_lr)]

# Per-model profile table, indexed by model name for `.loc` lookups.
model_profile = pd.read_csv("model_profile.csv").set_index("model_simple_namespace")

In [None]:
# (downstream dataset, metric) pairs; one table is rendered per pair and
# per (scheduler, pre-training dataset) combination that has data.
evals = [
    ("imagenet1k", "acc1"), # ImageNet-1k zero-shot classification
    ("mscoco_captions", "image_retrieval_recall@5"), # MS-COCO Recall@5 zero-shot image retrieval (T->I)
    ("mscoco_captions", "text_retrieval_recall@5"), # MS-COCO Recall@5 zero-shot text retrieval (I->T)
    ("datacomp_classification", "acc1"), # average over 35 classification tasks from DataComp (Tab.15 from https://arxiv.org/abs/2304.14108)
    ("imagenet_distribution_shift", "acc1"), # average over ImageNet-v2, ImageNet-R, ImageNet-Sketch, ImageNet-A, ObjectNet zero-shot classification
]

# Loop-invariant helpers, computed once instead of on every iteration.
# Scale labels in increasing numeric order: this is the column order of
# every rendered table.
ordered_scales = sorted(scales, key=samples_seen_from_str)
# One row is kept per (samples seen, model) pair.
group_keys = ["samples_seen_scale_pretty", "model_simple_namespace"]

for scheduler in ("cosine", "const",):
    for pretrain_dataset in ("datacomp_1b", "relaion2b-en"):
        for ds, metric in evals:
            for eval_type in ("similarity",):
                # Build a single boolean mask rather than copying the whole
                # frame and filtering it step by step.
                keep = (
                    (df.eval_type == eval_type)
                    & (df.pretrain_dataset == pretrain_dataset)
                    & (df.downstream_dataset == ds)
                    & (df.lr_scheduler == scheduler)
                    & df.samples_seen_scale_pretty.isin(scales)
                )
                if scheduler != "const":
                    # For non-constant schedules only rows from the last
                    # epoch (epoch == total_epochs) are used.
                    keep &= df.epoch == df.total_epochs
                d = df[keep]
                if not len(d):
                    # Nothing to show for this combination; skip before
                    # doing any reshaping work.
                    continue

                # Show best result for each model x samples seen, then
                # reshape to (samples seen) x (model).
                d = (
                    d.sort_values(by=metric, ascending=False)
                    .drop_duplicates(subset=group_keys, keep="first")
                    .pivot(
                        index="samples_seen_scale_pretty",
                        columns="model_simple_namespace",
                        values=metric,
                    )
                )

                # Order models by their GFLOPs from the profile table.
                cols = sorted(d.columns, key=lambda k: model_profile.loc[k].gflops)
                d = d[cols]
                # Reverse the "_"-separated parts of each model name for display.
                # Assigning a plain list drops any name on the columns axis,
                # exactly as the rendered tables already show.
                d.columns = ["_".join(k.split("_")[::-1]) for k in d.columns]

                # Add scales without results as NaN rows and put all rows in
                # increasing order of samples seen in one step.
                d = d.reindex(ordered_scales)
                d.index.name = "Samples Seen"
                d = d.T

                # Colour scale spans the observed (non-NaN) value range.
                maxval = np.nanmax(d.values)
                minval = np.nanmin(d.values)
                style = d.style.background_gradient(cmap="viridis", vmin=minval, vmax=maxval, axis=None)
                style = style.format(precision=3, na_rep="NA")
                display_html(f"<font color='green'>'{ds}'</font> {metric} results for models pre-trained on <font color='red'>'{pretrain_dataset}'</font>  using scheduler <font color='blue'>'{scheduler}'</font>:", raw=True)
                display(style)
            

Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.013,0.03,0.06,0.102,0.187,0.27,0.354,0.44,0.497,0.531,0.559
siglip_ViT-S-32,0.013,0.029,0.056,0.096,0.177,0.262,0.346,0.428,0.487,,
clip_ViT-M-32,0.014,0.032,0.063,0.105,0.201,0.292,0.383,0.478,0.537,0.573,0.608
siglip_ViT-M-32,0.014,0.03,0.059,0.101,0.191,0.283,0.373,0.468,0.525,,
clip_ViT-S-16,0.02,0.044,0.08,0.134,0.237,0.329,0.423,0.505,,0.596,0.628
siglip_ViT-S-16,,,,,,,0.414,,,,
mammut_ViT-S-32,0.015,0.033,0.064,0.111,0.204,0.293,0.372,0.452,0.503,0.514,0.565
clip_ViT-S-14,0.02,0.046,0.08,0.135,0.237,,0.42,0.513,,0.598,0.637
clip_ViT-B-32,0.014,0.032,0.065,0.112,0.216,0.312,0.404,0.503,,0.608,0.653
siglip_ViT-B-32,0.013,0.031,0.062,0.108,0.206,0.302,0.397,0.498,0.561,,


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.02,0.038,0.066,0.106,0.183,0.256,0.326,0.394,0.44,0.477,0.497
siglip_ViT-S-32,0.02,0.036,0.064,0.104,0.181,0.25,0.32,0.388,0.438,,
clip_ViT-M-32,0.022,0.039,0.07,0.113,0.197,0.276,0.349,0.426,0.477,0.517,0.547
siglip_ViT-M-32,0.021,0.037,0.066,0.112,0.193,0.267,0.346,0.42,0.475,,
clip_ViT-S-16,0.029,0.053,0.087,0.138,0.228,0.308,0.39,0.457,,0.536,0.557
siglip_ViT-S-16,,,,,,,0.382,,,,
mammut_ViT-S-32,0.02,0.038,0.068,0.118,0.201,0.278,0.346,0.41,0.456,0.459,0.507
clip_ViT-S-14,0.032,0.054,0.086,0.145,0.234,,0.384,0.454,,0.54,0.566
clip_ViT-B-32,0.019,0.039,0.072,0.119,0.209,0.294,0.372,0.453,,0.548,0.586
siglip_ViT-B-32,0.021,0.037,0.067,0.114,0.204,0.283,0.365,0.445,0.503,,


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.031,0.054,0.107,0.17,0.277,0.385,0.479,0.565,0.619,0.656,0.675
siglip_ViT-S-32,0.027,0.054,0.102,0.163,0.274,0.368,0.474,0.55,0.609,,
clip_ViT-M-32,0.03,0.055,0.111,0.178,0.307,0.413,0.506,0.602,0.665,0.702,0.722
siglip_ViT-M-32,0.031,0.056,0.1,0.174,0.288,0.394,0.501,0.599,0.652,,
clip_ViT-S-16,0.041,0.08,0.139,0.22,0.348,0.463,0.554,0.635,,0.713,0.734
siglip_ViT-S-16,,,,,,,0.55,,,,
mammut_ViT-S-32,0.03,0.06,0.109,0.187,0.313,0.414,0.513,0.579,0.641,0.642,0.692
clip_ViT-S-14,0.045,0.084,0.139,0.22,0.347,,0.555,0.623,,0.713,0.742
clip_ViT-B-32,0.026,0.059,0.116,0.187,0.316,0.437,0.535,0.634,,0.727,0.772
siglip_ViT-B-32,0.026,0.055,0.104,0.169,0.301,0.416,0.522,0.618,0.686,,


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.118,0.141,0.177,0.211,0.27,0.328,0.376,0.43,0.461,0.491,0.5
siglip_ViT-S-32,0.126,0.138,0.171,0.203,0.27,0.322,0.371,0.423,0.463,,
clip_ViT-M-32,0.123,0.14,0.181,0.218,0.277,0.339,0.395,0.451,0.494,0.52,0.527
siglip_ViT-M-32,,,,,,,0.386,0.446,,,
clip_ViT-S-16,0.123,0.157,0.193,0.229,0.296,0.358,0.411,0.468,,0.515,0.537
mammut_ViT-S-32,0.121,0.144,0.179,0.22,0.284,0.339,0.383,0.436,0.473,0.477,0.51
clip_ViT-B-32,0.117,0.145,,0.222,0.289,,0.407,0.471,,0.535,0.575
siglip_ViT-B-32,0.122,0.142,0.18,0.222,0.287,0.341,0.398,0.461,0.506,,
coca_ViT-S-32,,0.142,,,0.276,,0.373,0.426,,,
mammut_ViT-S-16,0.123,0.158,0.202,0.242,0.317,0.369,0.42,0.473,0.502,0.524,0.553


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.016,0.027,0.043,0.068,0.128,0.186,0.246,0.307,0.353,0.385,0.408
siglip_ViT-S-32,0.016,0.026,0.041,0.062,0.119,0.176,0.237,0.299,0.349,,
clip_ViT-M-32,0.016,0.028,0.045,0.074,0.137,0.203,0.271,0.341,0.388,0.426,0.46
siglip_ViT-M-32,,,,,,,0.259,0.326,,,
clip_ViT-S-16,0.021,0.035,0.055,0.085,0.156,0.228,0.297,0.371,,0.457,0.485
mammut_ViT-S-32,0.017,0.028,0.045,0.079,0.144,0.208,0.27,0.329,0.371,0.376,0.417
clip_ViT-B-32,0.016,0.028,,0.077,0.147,,0.288,0.37,,0.466,0.516
siglip_ViT-B-32,0.016,0.027,0.043,0.072,0.143,0.205,0.279,0.36,0.42,,
coca_ViT-S-32,0.017,0.029,,0.078,0.136,,0.248,0.308,,,
mammut_ViT-S-16,0.022,0.036,0.061,0.101,0.18,0.259,0.322,0.388,0.434,0.467,0.488


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.013,0.029,0.05,0.086,0.158,0.229,0.304,0.379,0.434,0.48,0.517
clip_ViT-M-32,0.013,0.028,0.052,0.09,0.171,0.247,0.325,0.416,0.476,0.51,
clip_ViT-S-16,0.018,0.037,0.066,0.111,0.195,0.277,0.352,0.439,0.499,,
mammut_ViT-S-32,0.013,0.029,0.05,0.09,0.168,0.243,0.315,0.386,0.435,0.458,0.491
clip_ViT-B-32,0.014,0.029,0.054,0.096,0.182,0.262,0.346,0.44,0.512,0.56,0.607
mammut_ViT-S-16,0.019,0.039,0.067,0.121,0.21,0.295,0.37,0.448,0.5,0.536,0.576
mammut_ViT-M-32,0.013,0.03,0.055,0.101,0.185,0.259,0.343,0.425,0.482,0.522,0.561
mammut_ViT-B-32,0.013,0.03,0.056,0.105,0.197,0.283,0.365,0.457,0.518,0.563,0.608
clip_ViT-B-16-text-plus,0.016,,,,,,0.414,0.51,0.581,,
mammut_ViT-B-16,,,,,,,0.449,0.543,0.609,,


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.018,0.041,0.069,0.112,0.191,0.265,0.334,0.407,0.451,0.491,0.528
clip_ViT-M-32,0.019,0.039,0.071,0.12,0.202,0.287,0.36,0.442,0.493,0.518,
clip_ViT-S-16,0.027,0.054,0.087,0.141,0.238,0.32,0.397,0.467,0.515,,
mammut_ViT-S-32,0.018,0.038,0.067,0.112,0.2,0.286,0.355,0.424,0.46,0.482,0.505
clip_ViT-B-32,0.019,0.04,0.076,0.124,0.218,0.301,0.381,0.464,0.523,0.567,0.605
mammut_ViT-S-16,0.026,0.051,0.088,0.146,0.253,0.341,0.425,0.486,0.521,0.561,0.587
mammut_ViT-M-32,0.019,0.04,0.071,0.124,0.223,0.303,0.388,0.458,0.506,0.547,0.579
mammut_ViT-B-32,0.019,0.039,0.073,0.129,0.237,0.331,0.409,0.492,0.544,0.58,0.614
clip_ViT-B-16-text-plus,0.025,,,,,,0.453,0.533,0.584,,
mammut_ViT-B-16,,,,,,,0.507,0.573,0.617,,


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.032,0.06,0.113,0.174,0.296,0.398,0.492,0.578,0.632,0.665,0.696
clip_ViT-M-32,0.036,0.064,0.119,0.179,0.312,0.421,0.52,0.618,0.672,0.689,
clip_ViT-S-16,0.043,0.081,0.139,0.223,0.363,0.47,0.563,0.651,0.697,,
mammut_ViT-S-32,0.032,0.058,0.106,0.179,0.303,0.427,0.512,0.601,0.642,0.653,0.681
clip_ViT-B-32,0.033,0.061,0.116,0.192,0.327,0.447,0.533,0.637,0.704,0.73,0.768
mammut_ViT-S-16,0.046,0.084,0.147,0.238,0.381,0.503,0.589,0.664,0.705,0.723,0.752
mammut_ViT-M-32,0.032,0.062,0.11,0.193,0.34,0.45,0.547,0.644,0.689,0.713,0.748
mammut_ViT-B-32,0.034,0.06,0.116,0.201,0.365,0.485,0.576,0.664,0.718,0.738,0.771
clip_ViT-B-16-text-plus,0.044,,,,,,0.627,0.712,0.764,,
mammut_ViT-B-16,,,,,,,0.68,0.749,0.792,,


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.122,0.139,0.16,0.193,0.246,0.293,0.346,0.386,0.421,,
clip_ViT-M-32,0.123,0.137,0.17,0.197,0.256,0.312,0.361,0.416,0.457,,
clip_ViT-S-16,,,,,,,0.371,0.415,0.453,,
mammut_ViT-S-32,0.115,0.136,0.164,0.201,0.252,0.308,0.355,0.398,0.431,,
clip_ViT-B-32,0.121,0.137,0.169,0.205,0.263,0.321,0.374,0.429,0.468,,
mammut_ViT-M-32,0.117,0.135,0.163,0.207,0.268,0.316,0.37,0.423,0.461,,
mammut_ViT-B-32,0.117,0.136,0.165,0.212,0.273,0.342,0.383,0.435,0.484,,
clip_ViT-L-14,0.126,0.146,0.182,0.236,0.302,0.376,0.438,0.5,0.542,,
mammut_ViT-L-14,0.122,0.149,0.183,,0.318,0.397,0.463,0.528,0.571,,
clip_ViT-H-14,0.125,0.138,0.182,0.224,0.305,0.384,0.444,0.515,0.56,,


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32,0.016,0.025,0.038,0.061,0.11,0.158,0.208,0.265,0.306,,
clip_ViT-M-32,0.015,0.026,0.041,0.064,0.118,0.172,0.229,0.293,0.341,,
clip_ViT-S-16,,,,,,,0.249,0.314,0.363,,
mammut_ViT-S-32,0.016,0.025,0.04,0.064,0.119,0.174,0.226,0.277,0.316,,
clip_ViT-B-32,0.015,0.027,0.042,0.068,0.126,0.186,0.246,0.318,0.376,,
mammut_ViT-M-32,0.016,0.027,0.042,0.072,0.131,0.188,0.248,0.312,0.357,,
mammut_ViT-B-32,0.016,0.026,0.043,0.074,0.141,0.205,0.271,0.343,0.393,,
clip_ViT-L-14,0.019,0.033,0.053,0.092,0.175,0.261,0.35,0.447,0.52,,
mammut_ViT-L-14,0.019,0.034,0.055,,0.194,0.3,0.401,0.505,0.569,,
clip_ViT-H-14,0.017,0.026,0.051,0.088,0.169,0.271,0.366,0.471,0.546,,


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32-alt,,,,0.003,0.013,0.037,0.138,0.31,0.391,0.447,0.462
clip_ViT-S-32,0.021,0.039,0.069,0.089,0.175,0.248,0.325,0.358,0.416,0.469,0.485
clip_ViT-M-32-alt,,,,0.003,0.015,0.049,0.18,0.364,0.448,0.502,0.521
clip_ViT-S-16-alt,,,,0.003,0.015,0.052,0.171,0.367,0.452,0.506,0.525
clip_ViT-M-32,0.022,0.042,0.076,0.095,0.188,0.276,0.352,0.39,0.465,0.513,0.532
clip_ViT-S-16,,,,0.003,0.016,0.052,0.19,0.398,0.478,0.533,0.554
mammut_ViT-S-32,0.021,0.042,0.076,0.094,0.015,0.053,0.212,0.363,0.422,0.435,
clip_ViT-S-14,,,,0.004,0.016,0.052,0.183,0.392,0.481,0.537,0.557
clip_ViT-B-32,0.023,0.044,0.08,0.101,0.208,0.299,0.376,0.422,0.507,0.552,0.572
clip_ViT-M-16-alt,,,,0.001,0.003,0.026,0.107,0.353,0.494,0.559,0.59


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32-alt,,,,0.003,0.014,0.039,0.128,0.277,0.343,0.386,0.404
clip_ViT-S-32,0.027,0.048,0.082,0.099,0.17,0.24,0.303,0.343,0.368,0.417,0.44
clip_ViT-M-32-alt,,,,0.004,0.019,0.052,0.168,0.332,0.397,0.447,0.467
clip_ViT-S-16-alt,,,,0.004,0.022,0.052,0.157,0.332,0.406,0.453,0.473
clip_ViT-M-32,0.028,0.052,0.086,0.106,0.183,0.267,0.324,0.367,0.414,0.462,0.485
clip_ViT-S-16,,,,0.004,0.023,0.055,0.172,0.355,0.43,0.483,0.508
mammut_ViT-S-32,0.025,0.047,0.082,0.1,0.02,0.053,0.208,0.339,0.388,0.403,
clip_ViT-S-14,,,,0.004,0.023,0.048,0.165,0.352,0.435,0.483,0.511
clip_ViT-B-32,0.03,0.054,0.089,0.111,0.205,0.277,0.347,0.387,0.45,0.499,0.52
clip_ViT-M-16-alt,,,,0.001,0.004,0.028,0.104,0.317,0.451,0.512,0.533


Samples Seen,1.28M,3M,6.4M,12.8M,30M,64M,128M,300M,640M,1.28B,3B
clip_ViT-S-32-alt,,,,0.004,0.027,0.056,0.194,0.421,0.502,0.555,0.578
clip_ViT-S-32,0.043,0.078,0.124,0.154,0.26,0.368,0.441,0.493,0.536,0.597,0.62
clip_ViT-M-32-alt,,,,0.006,0.029,0.077,0.25,0.473,0.564,0.631,0.651
clip_ViT-S-16-alt,,,,0.006,0.029,0.079,0.248,0.49,0.581,0.62,0.645
clip_ViT-M-32,0.043,0.079,0.138,0.167,0.276,0.397,0.48,0.526,0.591,0.647,0.664
clip_ViT-S-16,,,,0.005,0.034,0.082,0.255,0.534,0.61,0.661,0.683
mammut_ViT-S-32,0.041,0.077,0.127,0.156,0.028,0.079,0.306,0.493,0.555,0.577,
clip_ViT-S-14,,,,0.006,0.034,0.075,0.254,0.525,0.614,0.67,0.687
clip_ViT-B-32,0.041,0.08,0.138,0.17,0.302,0.416,0.503,0.55,0.623,0.674,0.702
clip_ViT-M-16-alt,,,,0.001,0.004,0.045,0.161,0.475,0.633,0.684,0.712
