In [26]:
# Import packages and setup gpu configuration.
# This code block shouldnt need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import random
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from flat_models import *

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
torch.backends.cudnn.benchmark = True

## MODEL TO LOAD ##
model_name = "ps16_large"
outdir = os.path.abspath(f'checkpoints/{model_name}')
print("outdir", outdir)
# Load previous config.yaml if available
if os.path.exists(f"{outdir}/config.yaml"):
    config = yaml.load(open(f"{outdir}/config.yaml", 'r'), Loader=yaml.FullLoader)
    print(f"Loaded config.yaml from ckpt folder {outdir}")
    # create global variables from the config
    print("\n__CONFIG__")
    for attribute_name in config.keys():
        print(f"{attribute_name} = {config[attribute_name]}")
        globals()[attribute_name] = config[f'{attribute_name}']
    print("\n")

### Multi-GPU config ###
device_count = torch.cuda.device_count()
print(f"Number of available CUDA devices: {device_count}")

local_rank = os.getenv('LOCAL_RANK')
if local_rank is None: 
    local_rank = 0
else:
    local_rank = int(local_rank)
print(f"LOCAL RANK={local_rank}")

num_devices = os.getenv('NUM_GPUS')
if num_devices is None: 
    num_devices = 1
else:
    num_devices = int(num_devices)
print(f"NUM GPUS={num_devices}")
distributed = True if num_devices>1 else False
if distributed: assert device_count==num_devices

node = os.getenv('SLURM_NODEID')
if node is None:
    node = 0
else:
    node = int(node)
print(f"NODE={node}")

global_rank = os.getenv('RANK')
if global_rank is None:
    global_rank = 0
else:
    global_rank = int(global_rank)
print(f"GLOBAL RANK={global_rank}")

world_size = os.getenv('WORLD_SIZE')
if world_size is None: 
    world_size = 1
else:
    world_size = int(world_size)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Following allows you to change functions in models.py or utils.py and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2

# base_lr = probe_base_lr
batch_size = probe_batch_size
num_epochs = probe_num_epochs
data_type = torch.float32 # change depending on your mixed_precision
global_batch_size = batch_size * world_size

base_lr = 3e-4
batch_size = 128
num_epochs = 50

# FSDP Setup
if distributed:
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.api import BackwardPrefetch, CPUOffload, ShardingStrategy
    import functools
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
    print(f"setting device to cuda:{local_rank}")
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda',local_rank)
    dist.init_process_group("nccl", rank=global_rank, world_size=world_size)
    print(f"\nSuccessfully set cuda:{local_rank} | global_rank{global_rank} | node{node}")
    dist.barrier()
    print(f"global_rank{global_rank} passed barrier")
else:
    device = torch.device('cuda')

print("PID of this process =",os.getpid())
print("device =", device, "distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)

# seed all random functions
utils.seed_everything(seed + global_rank)

outdir /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/ps16_replicate
Loaded config.yaml from ckpt folder /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/ps16_replicate

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 5
ckpt_saving = True
cls_embed = True
decoder_embed_dim = 512
grad_accumulation_steps = 1
grad_clip = 1.0
hcp_flat_path = /weka/proj-medarc/shared/hcp_flat
mask_ratio = 0.75
model_name = ps16_replicate
no_qkv_bias = False
norm_pix_loss = False
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_workers = 10
patch_size = 16
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_num_epochs = 30
probe_num_samples_per_epoch = 100000
resume_from_ckpt = True
seed = 42
sep_pos_embed = True
t_patch_size = 2
test_num_samples_per_epoch = 50000
test_set = False
trunc_init = False
use_contrastive_loss = False
wandb_log = True


Number of available CUDA devices: 1
LOCAL RANK=0
NUM GPUS=1
NODE=

# hcp_flat

In [27]:
from util.hcp_flat import load_hcp_flat_mask
from util.hcp_flat import create_hcp_flat
import util.visualize as vis

model = mae_vit_small_fmri(
    patch_size=patch_size,
    decoder_embed_dim=decoder_embed_dim,
    t_patch_size=t_patch_size,
    pred_t_dim=pred_t_dim,
    decoder_depth=4,
    cls_embed=cls_embed,
    norm_pix_loss=norm_pix_loss,
    no_qkv_bias=no_qkv_bias,
    sep_pos_embed=sep_pos_embed,
    trunc_init=trunc_init,
)

# Load ckpt
if not os.path.exists(outdir) or not os.path.isdir(outdir):
    print(f"\nCheckpoint folder {outdir} does not exist.\n")
else:
    checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

    # # Find the latest ckpt to load
    epoch_numbers = []
    for file in checkpoint_files:
        try:
            epoch_number = int(file.split('epoch')[-1].split('.')[0])
            epoch_numbers.append(epoch_number)
        except ValueError:
            continue
    latest_epoch = max(epoch_numbers)
    checkpoint_name = f"epoch{latest_epoch}.pth"

    # Load the checkpoint
    checkpoint_path = os.path.join(outdir, checkpoint_name)
    state = torch.load(checkpoint_path)
    model.load_state_dict(state["model_state_dict"], strict=True)

    print(f"\nLoaded checkpoint {checkpoint_name} from {outdir}\n")

model.eval()
model.requires_grad_(False)
model.to(device)
pass

img_size (144, 320) patch_size (16, 16) frames 16 t_patch_size 2
model initialized

Loaded checkpoint epoch99.pth from /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/ps16_replicate



## Create dataset and data loaders

In [28]:
import pickle
from pathlib import Path
import pandas as pd
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

parquet_folder = f"epoch{latest_epoch}"
print("parquet_folder", parquet_folder)

target = "trial_type"
print(f"Target: {target}")

train_features = pd.read_parquet(f"{outdir}/{parquet_folder}/train.parquet")
test_features = pd.read_parquet(f"{outdir}/{parquet_folder}/test.parquet")
print(f"train: {train_features.shape}, test: {test_features.shape}")

X_train = np.stack(train_features["feature"])
X_test = np.stack(test_features["feature"])
print(f"X_train: {X_train.shape}, X_test: {X_test.shape}")

if target == "task":
    labels_train = train_features["task"].str.rstrip("1234").values
    labels_test = test_features["task"].str.rstrip("1234").values
elif target == "trial_type":
    labels_train = train_features["trial_type"].values
    labels_test = test_features["trial_type"].values

label_enc = LabelEncoder()
y_train = label_enc.fit_transform(labels_train)
y_test = label_enc.transform(labels_test)

print(f"classes ({len(label_enc.classes_)}): {label_enc.classes_}")
print(
    f"\ny_train: {y_train.shape} {y_train[:20]}\n"
    f"y_test: {y_test.shape} {y_test[:20]}"
)
del train_features, test_features

X_train = torch.Tensor(X_train)
X_test = torch.Tensor(X_test)
y_train = torch.Tensor(y_train)
y_test = torch.Tensor(y_test)

parquet_folder epoch99
Target: trial_type
train: (118336, 9), test: (12608, 9)
X_train: (118336, 384), X_test: (12608, 384)
classes (21): ['0bk_body' '0bk_faces' '0bk_places' '0bk_tools' '2bk_body' '2bk_faces'
 '2bk_places' '2bk_tools' 'fear' 'lf' 'lh' 'match' 'math' 'mental' 'neut'
 'relation' 'rf' 'rh' 'rnd' 'story' 't']

y_train: (118336,) [14  8 14  8 14 13 13 13 13 18 18 13 13 18 18 15 11 11 15 11]
y_test: (12608,) [ 7  7  0  0  5  5  3  3  4  4  6  6  1  1  2  2 19 19 12 12]


In [29]:
from torch.utils.data import DataLoader, TensorDataset
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

probe_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

probe_num_batches = len(probe_dl)
test_num_batches = len(test_dl)

print("probe_num_batches", probe_num_batches)
print("test_num_batches", test_num_batches)

probe_num_batches 925
test_num_batches 99


# Downstream probe

In [30]:
class LinearProbe(nn.Module):
    def __init__(self, input_dim, h=256, num_classes=8):
        super(LinearProbe, self).__init__()
        # self.classifier = nn.Linear(input_dim, num_classes)
        
        # self.classifier = nn.Sequential(
        #     nn.LayerNorm(input_dim),
        #     nn.Linear(input_dim, num_classes)
        # )
        
        # self.linear = nn.Linear(input_dim, h)
        
        self.classifier = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.GELU(),
            nn.Dropout(p=0.35),
            nn.Linear(input_dim, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Dropout(p=0.15),
            nn.Linear(h, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Dropout(p=0.15),
            nn.Linear(h, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Dropout(p=0.15),
            nn.Linear(h, num_classes)
        )
    def forward(self, x):
        x = self.classifier(x)
        return x

In [31]:
linear_probe = LinearProbe(384,#model.embed_dim,#457344
                 h=5024,
                 num_classes=len(np.unique(y_train))).to(device)

# Set up optimizer and saving functions

In [32]:
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
probe_opt_grouped_parameters = [
    {'params': [p for n, p in linear_probe.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.05},
    {'params': [p for n, p in linear_probe.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

lr = base_lr * global_batch_size / 256
print(f"multiply base lr {base_lr} by effective batch size {global_batch_size}")
print(f"lr = {lr}")

probe_optimizer = torch.optim.AdamW(probe_opt_grouped_parameters, lr=lr, betas=(0.9, 0.95))

def adjust_learning_rate(optimizer, epoch, warmup_epochs=5, min_lr=0.0):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < warmup_epochs:
        lr_ = lr * epoch / warmup_epochs
    else:
        lr_ = min_lr + (lr - min_lr) * 0.5 * (
            1.0
            + math.cos(
                math.pi
                * (epoch - warmup_epochs)
                / (num_epochs - warmup_epochs)
            )
        )
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr_
    return lr_

print("\nDone with model preparations!")
num_params = utils.count_params(linear_probe)

multiply base lr 0.0003 by effective batch size 8
lr = 9.375e-06

Done with model preparations!
param counts:
52,561,877 total
52,561,877 trainable


# Start training

In [33]:
epoch = 0
lrs, probe_losses, probe_accs, test_probe_losses, test_probe_accs = [], [], [], [], []

In [34]:
mse = nn.MSELoss()
l1 = nn.L1Loss()
crossentropy = nn.CrossEntropyLoss()
progress_bar = tqdm(range(epoch, num_epochs), disable=local_rank!=0, desc="Overall")
linear_probe.train()
for epoch in progress_bar:
    with torch.cuda.amp.autocast(dtype=data_type):  
        for probe_i, batch in enumerate(probe_dl):
            probe_optimizer.zero_grad()
            adjust_learning_rate(probe_optimizer, probe_i / probe_num_batches + epoch)

            latents = batch[0].to(device, non_blocking=True)
            label = batch[1].to(device).long()
            
            task_pred = linear_probe(latents)
            probe_loss = crossentropy(task_pred, label)
            probe_acc = (torch.max(task_pred,1).indices == label).sum() / len(label)
            
            probe_loss.backward()
            probe_optimizer.step()

            probe_losses.append(probe_loss.item())
            probe_accs.append(probe_acc.item())
            
            # if probe_i%print_interval==0 and probe_i>0:
            #     print(f"Ep. {epoch} | probe_loss {np.mean(probe_losses[-print_interval:]):.3f} | probe_acc {np.mean(probe_accs[-print_interval:]):.3f} | lr {probe_optimizer.param_groups[0]['lr']} | {probe_i}/{probe_num_batches}")

        # print(f"Ep. {epoch} | iter {probe_i} | probe_loss {np.mean(probe_losses[-probe_i:]):.3f} | probe_acc {np.mean(probe_accs[-probe_i:]):.3f} | lr {probe_optimizer.param_groups[0]['lr']}")

        logs = {"train/probe_loss": np.mean(probe_losses[-probe_i:]),
                "train/probe_acc": np.mean(probe_accs[-probe_i:])}
    
    # Evaluate performance on held-out test dataset
    linear_probe.eval()
    with torch.no_grad():
        for test_i, batch in enumerate(test_dl):
            latents = batch[0].to(device, non_blocking=True)
            label = batch[1].to(device).long()

            task_pred = linear_probe(latents)
            probe_loss = crossentropy(task_pred, label)
            probe_acc = (torch.max(task_pred,1).indices == label).sum() / len(label)

            test_probe_losses.append(probe_loss.item())
            test_probe_accs.append(probe_acc.item())

            # if test_i%print_interval==0 and test_i>0:
            #     print(f"Test | probe_loss {np.mean(test_probe_losses[-print_interval:]):.3f} | probe_acc {np.mean(test_probe_accs[-print_interval:]):.3f} | {test_i}")

    print(f"{model_name} {checkpoint_name} | Test | iter {test_i} | probe_loss {np.mean(test_probe_losses[-test_i:]):.3f} | probe_acc {np.mean(test_probe_accs[-test_i:]):.3f}")
    logs = {"test/probe_loss": np.mean(test_probe_losses[-test_i:]),
            "test/probe_acc": np.mean(test_probe_accs[-test_i:])}

Overall:   2%|▋                                  | 1/50 [00:04<03:38,  4.45s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 2.480 | probe_acc 0.243


Overall:   4%|█▍                                 | 2/50 [00:08<03:16,  4.09s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.865 | probe_acc 0.735


Overall:   6%|██                                 | 3/50 [00:12<03:06,  3.98s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.631 | probe_acc 0.798


Overall:   8%|██▊                                | 4/50 [00:15<03:00,  3.92s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.540 | probe_acc 0.820


Overall:  10%|███▌                               | 5/50 [00:19<02:54,  3.89s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.442 | probe_acc 0.858


Overall:  12%|████▏                              | 6/50 [00:23<02:50,  3.87s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.527 | probe_acc 0.826


Overall:  14%|████▉                              | 7/50 [00:27<02:45,  3.85s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.390 | probe_acc 0.871


Overall:  16%|█████▌                             | 8/50 [00:31<02:41,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.362 | probe_acc 0.878


Overall:  18%|██████▎                            | 9/50 [00:35<02:37,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.343 | probe_acc 0.885


Overall:  20%|██████▊                           | 10/50 [00:38<02:33,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.341 | probe_acc 0.884


Overall:  22%|███████▍                          | 11/50 [00:42<02:29,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.356 | probe_acc 0.882


Overall:  24%|████████▏                         | 12/50 [00:46<02:25,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.322 | probe_acc 0.893


Overall:  26%|████████▊                         | 13/50 [00:50<02:21,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.318 | probe_acc 0.895


Overall:  28%|█████████▌                        | 14/50 [00:54<02:17,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.335 | probe_acc 0.888


Overall:  30%|██████████▏                       | 15/50 [00:58<02:13,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.316 | probe_acc 0.896


Overall:  32%|██████████▉                       | 16/50 [01:01<02:10,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.309 | probe_acc 0.897


Overall:  34%|███████████▌                      | 17/50 [01:05<02:06,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.323 | probe_acc 0.894


Overall:  36%|████████████▏                     | 18/50 [01:09<02:02,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.303 | probe_acc 0.899


Overall:  38%|████████████▉                     | 19/50 [01:13<01:58,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.308 | probe_acc 0.899


Overall:  40%|█████████████▌                    | 20/50 [01:17<01:54,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.295 | probe_acc 0.905


Overall:  42%|██████████████▎                   | 21/50 [01:21<01:52,  3.87s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.300 | probe_acc 0.904


Overall:  44%|██████████████▉                   | 22/50 [01:25<01:48,  3.89s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.293 | probe_acc 0.906


Overall:  46%|███████████████▋                  | 23/50 [01:29<01:45,  3.90s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.305 | probe_acc 0.903


Overall:  48%|████████████████▎                 | 24/50 [01:32<01:41,  3.91s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.299 | probe_acc 0.905


Overall:  50%|█████████████████                 | 25/50 [01:36<01:37,  3.90s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.291 | probe_acc 0.910


Overall:  52%|█████████████████▋                | 26/50 [01:40<01:33,  3.88s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.297 | probe_acc 0.910


Overall:  54%|██████████████████▎               | 27/50 [01:44<01:28,  3.87s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.296 | probe_acc 0.909


Overall:  56%|███████████████████               | 28/50 [01:48<01:24,  3.85s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.296 | probe_acc 0.910


Overall:  58%|███████████████████▋              | 29/50 [01:52<01:20,  3.85s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.295 | probe_acc 0.911


Overall:  60%|████████████████████▍             | 30/50 [01:56<01:16,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.298 | probe_acc 0.912


Overall:  62%|█████████████████████             | 31/50 [01:59<01:12,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.301 | probe_acc 0.912


Overall:  64%|█████████████████████▊            | 32/50 [02:03<01:09,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.301 | probe_acc 0.911


Overall:  66%|██████████████████████▍           | 33/50 [02:07<01:05,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.310 | probe_acc 0.911


Overall:  68%|███████████████████████           | 34/50 [02:11<01:01,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.309 | probe_acc 0.912


Overall:  70%|███████████████████████▊          | 35/50 [02:15<00:57,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.317 | probe_acc 0.911


Overall:  72%|████████████████████████▍         | 36/50 [02:18<00:53,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.314 | probe_acc 0.914


Overall:  74%|█████████████████████████▏        | 37/50 [02:22<00:49,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.315 | probe_acc 0.913


Overall:  76%|█████████████████████████▊        | 38/50 [02:26<00:46,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.318 | probe_acc 0.913


Overall:  78%|██████████████████████████▌       | 39/50 [02:30<00:42,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.320 | probe_acc 0.915


Overall:  80%|███████████████████████████▏      | 40/50 [02:34<00:38,  3.85s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.322 | probe_acc 0.914


Overall:  82%|███████████████████████████▉      | 41/50 [02:38<00:34,  3.85s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.324 | probe_acc 0.914


Overall:  84%|████████████████████████████▌     | 42/50 [02:42<00:30,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.323 | probe_acc 0.915


Overall:  86%|█████████████████████████████▏    | 43/50 [02:45<00:26,  3.84s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.327 | probe_acc 0.914


Overall:  88%|█████████████████████████████▉    | 44/50 [02:49<00:22,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.326 | probe_acc 0.915


Overall:  90%|██████████████████████████████▌   | 45/50 [02:53<00:19,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.327 | probe_acc 0.914


Overall:  92%|███████████████████████████████▎  | 46/50 [02:57<00:15,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.327 | probe_acc 0.915


Overall:  94%|███████████████████████████████▉  | 47/50 [03:01<00:11,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.327 | probe_acc 0.915


Overall:  96%|████████████████████████████████▋ | 48/50 [03:04<00:07,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.327 | probe_acc 0.915


Overall:  98%|█████████████████████████████████▎| 49/50 [03:08<00:03,  3.83s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.328 | probe_acc 0.914


Overall: 100%|██████████████████████████████████| 50/50 [03:12<00:00,  3.85s/it]

ps16_replicate epoch99.pth | Test | iter 98 | probe_loss 0.328 | probe_acc 0.914





In [35]:
## MLP ##
# clane epoch99.pth | Test | iter 100 | probe_loss 0.233 | probe_acc 0.923
# ps16 (using connor codebase) epoch99.pth | Test | iter 98 | probe_loss 0.410 | probe_acc 0.918
# ps16_replicate (using paul codebase) epoch99.pth | Test | iter 98 | probe_loss 0.324 | probe_acc 0.914
# ps8_30pct epoch99.pth | Test | iter 100 | probe_loss 0.442 | probe_acc 0.918
# ps16_large epoch99.pth | Test | iter 98 | probe_loss 0.188 | probe_acc 0.969
# ps8_30pct_512dec epoch99.pth | Test | iter 99 | probe_loss 0.290 | probe_acc 0.952

# ps16_mask9_3losses_bs32 epoch49.pth | Test | iter 98 | probe_loss 0.315 | probe_acc 0.941
# ps16_mask9_3losses_bs32 epoch65.pth | Test | iter 98 | probe_loss 0.335 | probe_acc 0.942
# ps16_mask9_3losses_bs32 epoch99.pth | Test | iter 98 | probe_loss 0.304 | probe_acc 0.949
# ps8_mask75_bs32_l epoch30.pth | Test | iter 100 | probe_loss 0.997 | probe_acc 0.695
# ps16_mask75_bs32_l epoch30.pth | Test | iter 98 | probe_loss 0.592 | probe_acc 0.864
# ps16_mask75_bs32_l epoch99.pth | Test | iter 98 | probe_loss 0.393 | probe_acc 0.894

## Linear ##
# ps16_mask9_3losses_bs32 epoch99.pth | Test | iter 98 | probe_loss 0.477 | probe_acc 0.864