In [1]:
import re

def extract_epoch_from_filename(filename):
    """Return the epoch number embedded in a training-feature file name.

    The name must end with ``train_data_checkpoint_<digits>.npz``; anything
    before that suffix (for example a directory prefix) is ignored.

    Args:
        filename: File name or path to inspect.

    Returns:
        The epoch as a string, left-padded with zeros to at least three
        characters (``"1"`` -> ``"001"``, ``"100"`` -> ``"100"``). Epochs of
        1000 and above are returned in full instead of being cut down to
        their last three digits.

    Raises:
        ValueError: If ``filename`` does not end with the expected pattern.
    """
    match = re.search(r"train_data_checkpoint_(\d+)\.npz$", filename)
    if match is None:
        raise ValueError(f"Could not extract epoch number from filename: {filename}")
    # Normalise through int() so redundant leading zeros collapse ("0007" ->
    # "007"), then pad. zfill never drops digits, unlike slicing the last
    # three characters, which would turn epoch 1000 into "000".
    return str(int(match.group(1))).zfill(3)


In [2]:
import os
import csv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm
import gc


def get_epoch_from_filename(file_name):
    """Parse the last three characters of the file stem as an integer epoch.

    The extension is stripped with ``os.path.splitext`` first. Returns the
    parsed integer, or ``None`` when those trailing characters are not a
    valid integer literal.
    """
    stem = os.path.splitext(file_name)[0]
    tail = stem[-3:]
    try:
        epoch = int(tail)
    except ValueError:
        # Trailing characters are not numeric (e.g. "t_1").
        return None
    return epoch

def train_on_npz(file_path, input_dim, num_classes, input_parameters):
    """Train a linear probe on the features stored in one ``.npz`` file.

    The probe is ``BatchNorm1d(input_dim, affine=False)`` followed by
    ``Linear(input_dim, num_classes)``, optimised with SGD and a cosine
    learning-rate schedule stepped once per epoch.

    Args:
        file_path: Path to an ``.npz`` archive containing the arrays
            ``features`` and ``labels``.
        input_dim: Number of input features per sample.
        num_classes: Number of output classes.
        input_parameters: Dict of optional settings. Recognised keys and
            their defaults: ``batch_size`` (256), ``learning_rate`` (0.1),
            ``num_epochs`` (20), ``momentum`` (0.9), ``weight_decay`` (1e-4),
            ``device`` ("cuda" when available, else "cpu") and
            ``num_workers`` (4, passed to the ``DataLoader``).

    Returns:
        A tuple ``(avg_loss, accuracy, model)``: the mean per-batch loss of
        the last training epoch (it stays NaN if no epoch runs), the accuracy
        on the same training data measured in eval mode, and the trained
        model.

    Raises:
        ValueError: If the archive contains no samples.
    """
    batch_size = input_parameters.get("batch_size", 256)
    learning_rate = input_parameters.get("learning_rate", 0.1)
    num_epochs = input_parameters.get("num_epochs", 20)
    momentum = input_parameters.get("momentum", 0.9)
    weight_decay = input_parameters.get("weight_decay", 1e-4)
    device = input_parameters.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    num_workers = input_parameters.get("num_workers", 4)

    # np.load returns an NpzFile that keeps the archive open until closed.
    # Arrays are read fully into memory on access, so the tensors built here
    # remain valid after the file is closed.
    with np.load(file_path) as data:
        X = torch.from_numpy(data["features"]).float()
        y = torch.from_numpy(data["labels"]).long()
    dataset = TensorDataset(X, y)
    if len(dataset) == 0:
        raise ValueError(f"No samples found in {file_path}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    model = nn.Sequential(
        nn.BatchNorm1d(input_dim, affine=False, eps=1e-6),
        nn.Linear(input_dim, num_classes)
    ).to(device)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Defined up front so the return value exists even when the loop body
    # never executes.
    avg_loss = float("nan")

    model.train()
    for _ in tqdm(range(num_epochs)):
        total_loss = 0.0
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            optimizer.zero_grad()
            logits = model(batch_X)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # Overwritten every epoch; the value that survives is the last epoch's.
        avg_loss = total_loss / len(dataloader)

    # Training-set accuracy, measured in eval mode (BatchNorm uses its
    # running statistics here rather than per-batch statistics).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            logits = model(batch_X)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)

    accuracy = correct / total
    return avg_loss, accuracy, model

def eval_on_val_file(model, file_path, input_dim, num_classes, input_parameters):
    """Compute the top-1 accuracy of ``model`` on the samples in one ``.npz`` file.

    Args:
        model: Module mapping a batch of features to per-class logits. It is
            switched to eval mode; it is NOT moved to ``device``, so it must
            already live there.
        file_path: Path to an ``.npz`` archive containing the arrays
            ``features`` and ``labels``.
        input_dim: Unused; kept so the signature parallels ``train_on_npz``.
        num_classes: Unused; kept so the signature parallels ``train_on_npz``.
        input_parameters: Dict of optional settings. Recognised keys and
            their defaults: ``batch_size`` (256), ``device`` ("cuda" when
            available, else "cpu") and ``num_workers`` (4, passed to the
            ``DataLoader``).

    Returns:
        Fraction of samples whose arg-max logit equals the label.

    Raises:
        ValueError: If the archive contains no samples.
    """
    batch_size = input_parameters.get("batch_size", 256)
    device = input_parameters.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    num_workers = input_parameters.get("num_workers", 4)

    # np.load returns an NpzFile that keeps the archive open until closed.
    # Arrays are read fully into memory on access, so the tensors built here
    # remain valid after the file is closed.
    with np.load(file_path) as data:
        X = torch.from_numpy(data["features"]).float()
        y = torch.from_numpy(data["labels"]).long()
    dataset = TensorDataset(X, y)
    if len(dataset) == 0:
        raise ValueError(f"No samples found in {file_path}")
    # Accuracy does not depend on sample order, so there is nothing to gain
    # from shuffling during evaluation.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            logits = model(batch_X)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)

    accuracy = correct / total
    return accuracy

def run_linear_probing_on_folder(folder_path, val_folder_path, input_dim, num_classes, input_parameters, csv_path="results.csv", best_model_path="/kaggle/working/models/best_model.pth.tar"):
    """Train and evaluate one linear probe per feature file in a folder.

    Every ``.npz`` file in ``folder_path`` is processed in ascending epoch
    order. For each one a probe is trained with ``train_on_npz`` and scored
    with ``eval_on_val_file`` on ``val_data_checkpoint_<epoch>.npz`` from
    ``val_folder_path``. Whenever the training accuracy beats the best seen
    so far, the probe's state dict is written to ``best_model_path``. Once
    all files are done the per-file metrics are written to ``csv_path``.

    Args:
        folder_path: Directory holding ``train_data_checkpoint_<epoch>.npz`` files.
        val_folder_path: Directory holding the matching validation files.
        input_dim: Feature dimensionality, forwarded to the probe.
        num_classes: Number of classes, forwarded to the probe.
        input_parameters: Hyper-parameter dict forwarded to ``train_on_npz``
            and ``eval_on_val_file``.
        csv_path: Destination of the results CSV.
        best_model_path: Destination of the best-model checkpoint.

    Returns:
        A list with one dict per processed file, with the keys ``epoch``,
        ``loss``, ``accuracy`` and ``val_accuracy``.

    Raises:
        ValueError: If an ``.npz`` file name in ``folder_path`` does not match
            the expected pattern (raised by ``extract_epoch_from_filename``).
        FileNotFoundError: If the validation file for an epoch is missing.
    """
    npz_files = [f for f in os.listdir(folder_path) if f.endswith(".npz")]
    # Sort on the numeric epoch so the order does not depend on how the
    # epoch string happens to be padded.
    npz_files.sort(key=lambda file_name: int(extract_epoch_from_filename(file_name)))

    results = []

    best_accuracy = 0

    for file_name in npz_files:
        file_path = os.path.join(folder_path, file_name)
        epoch = extract_epoch_from_filename(file_name)
        val_file_path = os.path.join(val_folder_path, f"val_data_checkpoint_{int(epoch)}.npz")
        # Check before training: otherwise a missing validation file is only
        # noticed after the whole probe has been trained.
        if not os.path.isfile(val_file_path):
            raise FileNotFoundError(f"Validation file not found: {val_file_path}")

        print(f"Working on file: {file_name}")
        print(f"Eval file: {val_file_path}")

        loss, accuracy, model = train_on_npz(file_path, input_dim, num_classes, input_parameters)
        val_accuracy = eval_on_val_file(model, val_file_path, input_dim, num_classes, input_parameters)
        print(f"{epoch} file → Final Avg Loss: {loss:.4f} Final Accuracy: {accuracy:.4f} Val Accuracy: {val_accuracy:.4f}")
        results.append({"epoch": epoch, "loss": loss, "accuracy": accuracy, 'val_accuracy': val_accuracy})

        # NOTE(review): the "best" probe is chosen on TRAINING accuracy, not
        # on val_accuracy. Kept as-is to preserve existing behaviour; confirm
        # that this is the intended selection criterion.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            save_dir = os.path.dirname(best_model_path)
            if save_dir:
                # dirname is empty for a bare file name; makedirs("") would raise.
                os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'epoch_file_epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': loss,
                'accuracy': best_accuracy,
                'val_accuracy': val_accuracy,
            }, best_model_path)
            print(f"Saved new best model with accurcy {best_accuracy * 100:.4f} at {best_model_path}")

        # Drop the probe before the next iteration builds a new one.
        del model
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    with open(csv_path, mode="w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["epoch", "loss", "accuracy", "val_accuracy"])
        writer.writeheader()
        writer.writerows(results)

    print(f"Results saved to {csv_path}")
    return results


In [3]:
# Hyper-parameters shared by every linear probe trained in this run.
input_params = {
    "batch_size": 512,
    "learning_rate": 0.1,
    "num_epochs": 40,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    # Fall back to the CPU so the cell still runs on a kernel without a GPU.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# Kaggle input folders holding the extracted train / validation feature archives.
folder = "/kaggle/input/linearprobing-mae-small-featureextractorreal/extracted_features"
val_folder = "/kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features"
input_dim = 384
num_classes = 100

results = run_linear_probing_on_folder(folder, val_folder, input_dim, num_classes, input_params)

Working on file: train_data_checkpoint_1.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_1.npz


100%|██████████| 40/40 [00:37<00:00,  1.05it/s]


001 file → Final Avg Loss: 3.7521 Final Accuracy: 0.1493 Val Accuracy: 0.1298
Saved new best model with accurcy 14.9292 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_6.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_6.npz


100%|██████████| 40/40 [00:37<00:00,  1.08it/s]


006 file → Final Avg Loss: 3.5340 Final Accuracy: 0.1904 Val Accuracy: 0.1694
Saved new best model with accurcy 19.0400 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_11.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_11.npz


100%|██████████| 40/40 [00:36<00:00,  1.08it/s]


011 file → Final Avg Loss: 3.2107 Final Accuracy: 0.2548 Val Accuracy: 0.2114
Saved new best model with accurcy 25.4777 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_16.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_16.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


016 file → Final Avg Loss: 3.0970 Final Accuracy: 0.2773 Val Accuracy: 0.2294
Saved new best model with accurcy 27.7262 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_21.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_21.npz


100%|██████████| 40/40 [00:37<00:00,  1.08it/s]


021 file → Final Avg Loss: 3.0139 Final Accuracy: 0.2923 Val Accuracy: 0.2454
Saved new best model with accurcy 29.2262 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_26.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_26.npz


100%|██████████| 40/40 [00:37<00:00,  1.08it/s]


026 file → Final Avg Loss: 2.9376 Final Accuracy: 0.3060 Val Accuracy: 0.2608
Saved new best model with accurcy 30.6023 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_31.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_31.npz


100%|██████████| 40/40 [00:36<00:00,  1.08it/s]


031 file → Final Avg Loss: 2.9185 Final Accuracy: 0.3084 Val Accuracy: 0.2544
Saved new best model with accurcy 30.8408 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_36.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_36.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


036 file → Final Avg Loss: 2.8712 Final Accuracy: 0.3188 Val Accuracy: 0.2554
Saved new best model with accurcy 31.8762 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_41.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_41.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


041 file → Final Avg Loss: 2.8243 Final Accuracy: 0.3258 Val Accuracy: 0.2720
Saved new best model with accurcy 32.5846 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_46.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_46.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


046 file → Final Avg Loss: 2.7822 Final Accuracy: 0.3343 Val Accuracy: 0.2798
Saved new best model with accurcy 33.4323 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_51.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_51.npz


100%|██████████| 40/40 [00:37<00:00,  1.06it/s]


051 file → Final Avg Loss: 2.7257 Final Accuracy: 0.3460 Val Accuracy: 0.2866
Saved new best model with accurcy 34.5985 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_56.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_56.npz


100%|██████████| 40/40 [00:37<00:00,  1.06it/s]


056 file → Final Avg Loss: 2.6952 Final Accuracy: 0.3514 Val Accuracy: 0.2936
Saved new best model with accurcy 35.1446 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_61.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_61.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


061 file → Final Avg Loss: 2.6879 Final Accuracy: 0.3523 Val Accuracy: 0.2916
Saved new best model with accurcy 35.2308 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_66.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_66.npz


100%|██████████| 40/40 [00:37<00:00,  1.08it/s]


066 file → Final Avg Loss: 2.6448 Final Accuracy: 0.3607 Val Accuracy: 0.3062
Saved new best model with accurcy 36.0685 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_71.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_71.npz


100%|██████████| 40/40 [00:37<00:00,  1.08it/s]


071 file → Final Avg Loss: 2.5779 Final Accuracy: 0.3731 Val Accuracy: 0.3184
Saved new best model with accurcy 37.3131 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_76.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_76.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


076 file → Final Avg Loss: 2.5241 Final Accuracy: 0.3845 Val Accuracy: 0.3230
Saved new best model with accurcy 38.4531 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_81.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_81.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


081 file → Final Avg Loss: 2.5127 Final Accuracy: 0.3882 Val Accuracy: 0.3270
Saved new best model with accurcy 38.8200 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_86.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_86.npz


100%|██████████| 40/40 [00:37<00:00,  1.06it/s]


086 file → Final Avg Loss: 2.4881 Final Accuracy: 0.3914 Val Accuracy: 0.3286
Saved new best model with accurcy 39.1369 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_91.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_91.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


091 file → Final Avg Loss: 2.4793 Final Accuracy: 0.3942 Val Accuracy: 0.3310
Saved new best model with accurcy 39.4192 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_96.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_96.npz


100%|██████████| 40/40 [00:37<00:00,  1.07it/s]


096 file → Final Avg Loss: 2.4683 Final Accuracy: 0.3964 Val Accuracy: 0.3348
Saved new best model with accurcy 39.6408 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_100.npz
Eval file: /kaggle/input/linearprobing-mae-small-featureextractor-real-val/extracted_features/val_data_checkpoint_100.npz


100%|██████████| 40/40 [00:37<00:00,  1.06it/s]


100 file → Final Avg Loss: 2.4674 Final Accuracy: 0.3970 Val Accuracy: 0.3372
Saved new best model with accurcy 39.6977 at /kaggle/working/models/best_model.pth.tar
Results saved to results.csv
