In [1]:
import re

def extract_epoch_from_filename(filename):
    """Return the checkpoint epoch encoded in a training-feature filename.

    ``filename`` (a bare name or a full path) must end in
    ``train_data_checkpoint_<N>.npz``. The epoch ``<N>`` is returned as a
    string left-padded with zeros to at least three characters, e.g.
    ``"7"`` -> ``"007"`` and ``"100"`` -> ``"100"``.

    Args:
        filename: File name or path of a training ``.npz`` file.

    Returns:
        str: The zero-padded epoch number.

    Raises:
        ValueError: If ``filename`` does not match the expected pattern.
    """
    match = re.search(r"train_data_checkpoint_(\d+)\.npz$", filename)
    if match is None:
        raise ValueError(f"Could not extract epoch number from filename: {filename}")
    # Normalise through int() so existing zero padding is ignored, then pad
    # with zfill: unlike slicing the last three characters, this never drops
    # the leading digits of epochs >= 1000.
    return str(int(match.group(1))).zfill(3)


In [2]:
import os
import csv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm
import gc


def get_epoch_from_filename(file_name):
    """Parse the last three characters of a file's stem as an integer epoch.

    The extension is removed with ``os.path.splitext`` and the final three
    characters of what remains are handed to ``int``.

    Args:
        file_name: File name or path.

    Returns:
        The parsed integer, or ``None`` when those characters do not form a
        valid integer literal.
    """
    stem = os.path.splitext(file_name)[0]
    try:
        epoch = int(stem[-3:])
    except ValueError:
        epoch = None
    return epoch

def train_on_npz(file_path, input_dim, num_classes, input_parameters):
    """Fit a linear probe (BatchNorm1d + Linear) on features stored in an ``.npz`` file.

    The archive must contain the arrays ``features`` and ``labels``. The probe
    is trained with SGD under a cosine-annealed learning rate and is then
    scored on the same data it was trained on.

    Args:
        file_path: Path to the ``.npz`` archive.
        input_dim: Width of the feature vectors (input size of both layers).
        num_classes: Number of output logits.
        input_parameters: Dict of optional settings. Recognised keys and their
            defaults: ``batch_size`` (256), ``learning_rate`` (0.1),
            ``num_epochs`` (20), ``momentum`` (0.9), ``weight_decay`` (1e-4),
            ``device`` ("cuda" when available, else "cpu") and
            ``num_workers`` (4, DataLoader worker processes).

    Returns:
        tuple: ``(avg_loss, accuracy, model)``. ``avg_loss`` is the mean batch
        loss of the final epoch (``nan`` when ``num_epochs`` is 0),
        ``accuracy`` is the fraction of training samples classified correctly,
        and ``model`` is the trained module, left in eval mode.

    Raises:
        ValueError: If the archive contains no samples.
    """
    batch_size = input_parameters.get("batch_size", 256)
    learning_rate = input_parameters.get("learning_rate", 0.1)
    num_epochs = input_parameters.get("num_epochs", 20)
    momentum = input_parameters.get("momentum", 0.9)
    weight_decay = input_parameters.get("weight_decay", 1e-4)
    device = input_parameters.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    num_workers = input_parameters.get("num_workers", 4)

    # np.load on an .npz keeps the file open until the NpzFile is closed, so
    # read both arrays inside the context manager and release the handle.
    with np.load(file_path) as data:
        X = torch.from_numpy(data["features"]).float()
        y = torch.from_numpy(data["labels"]).long()
    dataset = TensorDataset(X, y)
    if len(dataset) == 0:
        # Fail with a clear message instead of a later division by zero.
        raise ValueError(f"No samples found in {file_path}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # affine=False: the normalisation layer has no learnable scale/shift, so
    # the only trained parameters are those of the Linear layer.
    # NOTE(review): BatchNorm1d in training mode is expected to reject a
    # single-sample batch, so a dataset whose size is 1 modulo batch_size may
    # fail on its last batch - confirm for the datasets in use.
    model = nn.Sequential(
        nn.BatchNorm1d(input_dim, affine=False, eps=1e-6),
        nn.Linear(input_dim, num_classes)
    ).to(device)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Defined up front so the return value exists even when no epoch runs.
    avg_loss = float("nan")
    model.train()
    for _ in tqdm(range(num_epochs)):
        total_loss = 0.0
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            optimizer.zero_grad()
            logits = model(batch_X)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # Overwritten every epoch; the value that survives is the last epoch's.
        avg_loss = total_loss / len(dataloader)

    # Training-set accuracy, measured with BatchNorm using its running stats.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            logits = model(batch_X)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)

    accuracy = correct / total
    return avg_loss, accuracy, model

def eval_on_val_file(model, file_path, input_dim, num_classes, input_parameters):
    """Compute the top-1 accuracy of ``model`` on an ``.npz`` validation file.

    The archive must contain the arrays ``features`` and ``labels``.

    Args:
        model: Trained module mapping a feature batch to class logits. It is
            moved to the evaluation device and switched to eval mode.
        file_path: Path to the ``.npz`` archive.
        input_dim: Unused; kept so the signature mirrors ``train_on_npz``.
        num_classes: Unused; kept so the signature mirrors ``train_on_npz``.
        input_parameters: Dict of optional settings. Recognised keys and their
            defaults: ``batch_size`` (256), ``device`` ("cuda" when available,
            else "cpu") and ``num_workers`` (4, DataLoader worker processes).

    Returns:
        float: Fraction of samples whose arg-max logit equals the label.

    Raises:
        ValueError: If the archive contains no samples.
    """
    batch_size = input_parameters.get("batch_size", 256)
    device = input_parameters.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    num_workers = input_parameters.get("num_workers", 4)

    # Close the NpzFile as soon as both arrays have been read into memory.
    with np.load(file_path) as data:
        X = torch.from_numpy(data["features"]).float()
        y = torch.from_numpy(data["labels"]).long()
    dataset = TensorDataset(X, y)
    if len(dataset) == 0:
        # Fail with a clear message instead of a later division by zero.
        raise ValueError(f"No samples found in {file_path}")
    # Accuracy is independent of sample order, so there is nothing to shuffle.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Batches are sent to `device` below; make sure the model lives there too.
    model = model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            logits = model(batch_X)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)

    accuracy = correct / total
    return accuracy

def run_linear_probing_on_folder(folder_path, val_folder_path, input_dim, num_classes, input_parameters, csv_path="results.csv", save_path="/kaggle/working/models/best_model.pth.tar"):
    """Train and validate one linear probe per feature checkpoint and log the results.

    Every ``.npz`` file found in ``folder_path`` is treated as the training
    features of one checkpoint. For each, a probe is trained with
    ``train_on_npz`` and evaluated with ``eval_on_val_file`` on
    ``val_data_checkpoint_<epoch>.npz`` from ``val_folder_path``. Files are
    processed in ascending epoch order.

    Args:
        folder_path: A directory, or a list/tuple of directories, holding the
            training ``.npz`` files.
        val_folder_path: Directory holding the matching validation files.
        input_dim: Feature width, forwarded to the train/eval helpers.
        num_classes: Number of classes, forwarded to the train/eval helpers.
        input_parameters: Hyperparameter dict, forwarded to the train/eval helpers.
        csv_path: Where the per-checkpoint results table is written.
        save_path: Where the best probe's checkpoint is written.

    Returns:
        list[dict]: One dict per checkpoint with the keys ``epoch``, ``loss``,
        ``accuracy`` and ``val_accuracy``.

    Raises:
        ValueError: Propagated from ``extract_epoch_from_filename`` when an
            ``.npz`` file name does not follow the expected pattern.
    """
    # Inspect the argument itself; the previous check looked at a global named
    # `folder`, which only existed when the notebook's driver cell had run.
    folders = folder_path if isinstance(folder_path, (list, tuple)) else [folder_path]
    npz_files = [
        os.path.join(directory, entry)
        for directory in folders
        for entry in os.listdir(directory)
        if entry.endswith(".npz")
    ]
    # Compare epochs as integers so the order stays numeric for any width.
    npz_files.sort(key=lambda path: int(extract_epoch_from_filename(path)))

    results = []

    best_accuracy = 0

    for file_name in npz_files:
        epoch = extract_epoch_from_filename(file_name)
        # Validation files carry the un-padded epoch number in their name.
        val_file_path = os.path.join(val_folder_path, f"val_data_checkpoint_{int(epoch)}.npz")

        print(f"Working on file: {file_name}")
        print(f"Eval file: {val_file_path}")

        loss, accuracy, model = train_on_npz(file_name, input_dim, num_classes, input_parameters)
        val_accuracy = eval_on_val_file(model, val_file_path, input_dim, num_classes, input_parameters)
        print(f"{epoch} file → Final Avg Loss: {loss:.4f} Final Accuracy: {accuracy:.4f} Val Accuracy: {val_accuracy:.4f}")
        results.append({"epoch": epoch, "loss": loss, "accuracy": accuracy, 'val_accuracy': val_accuracy})

        # NOTE: "best" is decided on training accuracy, not validation accuracy.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            save_dir = os.path.dirname(save_path)
            if save_dir:
                # A bare file name has no directory part to create.
                os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'epoch_file_epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': loss,
                'accuracy': best_accuracy,
                'val_accuracy': val_accuracy,
            }, save_path)
            print(f"Saved new best model with accurcy {best_accuracy * 100:.4f} at {save_path}")

        # Drop the probe and reclaim memory before the next checkpoint.
        del model
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    with open(csv_path, mode="w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["epoch", "loss", "accuracy", "val_accuracy"])
        writer.writeheader()
        writer.writerows(results)

    print(f"Results saved to {csv_path}")
    return results


In [3]:
# Settings forwarded to every probe that gets trained and evaluated.
input_params = dict(
    batch_size=512,
    learning_rate=0.1,
    num_epochs=40,
    momentum=0.9,
    weight_decay=1e-4,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Two directories of training-feature archives; both are scanned for .npz files.
folder = (
    "/kaggle/input/linearprobing-mae-base-featureextractorreal-2/extracted_features",
    "/kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70",
)
# Directory holding the matching validation-feature archives.
val_folder = "/kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features"

input_dim = 768
num_classes = 100

# One probe per checkpoint file; the summary is also written to results.csv.
results = run_linear_probing_on_folder(
    folder_path=folder,
    val_folder_path=val_folder,
    input_dim=input_dim,
    num_classes=num_classes,
    input_parameters=input_params,
)

Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_1.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_1.npz


100%|██████████| 40/40 [00:44<00:00,  1.10s/it]


001 file → Final Avg Loss: 3.6800 Final Accuracy: 0.1637 Val Accuracy: 0.1458
Saved new best model with accurcy 16.3677 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_6.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_6.npz


100%|██████████| 40/40 [00:43<00:00,  1.09s/it]


006 file → Final Avg Loss: 3.4884 Final Accuracy: 0.2008 Val Accuracy: 0.1686
Saved new best model with accurcy 20.0846 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_11.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_11.npz


100%|██████████| 40/40 [00:44<00:00,  1.10s/it]


011 file → Final Avg Loss: 3.0780 Final Accuracy: 0.2818 Val Accuracy: 0.1948
Saved new best model with accurcy 28.1754 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_16.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_16.npz


100%|██████████| 40/40 [00:43<00:00,  1.09s/it]


016 file → Final Avg Loss: 3.0471 Final Accuracy: 0.2844 Val Accuracy: 0.2098
Saved new best model with accurcy 28.4377 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_21.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_21.npz


100%|██████████| 40/40 [00:44<00:00,  1.10s/it]


021 file → Final Avg Loss: 2.9237 Final Accuracy: 0.3099 Val Accuracy: 0.2252
Saved new best model with accurcy 30.9877 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_26.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_26.npz


100%|██████████| 40/40 [00:43<00:00,  1.08s/it]


026 file → Final Avg Loss: 2.8249 Final Accuracy: 0.3288 Val Accuracy: 0.2404
Saved new best model with accurcy 32.8785 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_31.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_31.npz


100%|██████████| 40/40 [00:43<00:00,  1.09s/it]


031 file → Final Avg Loss: 2.7420 Final Accuracy: 0.3448 Val Accuracy: 0.2560
Saved new best model with accurcy 34.4800 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_36.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_36.npz


100%|██████████| 40/40 [00:43<00:00,  1.09s/it]


036 file → Final Avg Loss: 2.7375 Final Accuracy: 0.3468 Val Accuracy: 0.2622
Saved new best model with accurcy 34.6815 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_41.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_41.npz


100%|██████████| 40/40 [00:43<00:00,  1.08s/it]


041 file → Final Avg Loss: 2.6974 Final Accuracy: 0.3519 Val Accuracy: 0.2718
Saved new best model with accurcy 35.1869 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_46.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_46.npz


100%|██████████| 40/40 [00:43<00:00,  1.10s/it]


046 file → Final Avg Loss: 2.6487 Final Accuracy: 0.3632 Val Accuracy: 0.2798
Saved new best model with accurcy 36.3177 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_51.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_51.npz


100%|██████████| 40/40 [00:43<00:00,  1.08s/it]


051 file → Final Avg Loss: 2.6399 Final Accuracy: 0.3654 Val Accuracy: 0.2874
Saved new best model with accurcy 36.5408 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_56.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_56.npz


100%|██████████| 40/40 [00:43<00:00,  1.09s/it]


056 file → Final Avg Loss: 2.5480 Final Accuracy: 0.3838 Val Accuracy: 0.3010
Saved new best model with accurcy 38.3769 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_61.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_61.npz


100%|██████████| 40/40 [00:43<00:00,  1.10s/it]


061 file → Final Avg Loss: 2.5672 Final Accuracy: 0.3797 Val Accuracy: 0.3052
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_66.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_66.npz


100%|██████████| 40/40 [00:43<00:00,  1.08s/it]


066 file → Final Avg Loss: 2.5307 Final Accuracy: 0.3875 Val Accuracy: 0.3118
Saved new best model with accurcy 38.7538 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/train-data-checkpoint-mae-base-1-70/train_data_checkpoint_mae_base_1-70/train_data_checkpoint_71.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_71.npz


100%|██████████| 40/40 [00:44<00:00,  1.10s/it]


071 file → Final Avg Loss: 2.4815 Final Accuracy: 0.3987 Val Accuracy: 0.3244
Saved new best model with accurcy 39.8654 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/linearprobing-mae-base-featureextractorreal-2/extracted_features/train_data_checkpoint_76.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_76.npz


100%|██████████| 40/40 [00:43<00:00,  1.09s/it]


076 file → Final Avg Loss: 2.4673 Final Accuracy: 0.4020 Val Accuracy: 0.3252
Saved new best model with accurcy 40.1992 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/linearprobing-mae-base-featureextractorreal-2/extracted_features/train_data_checkpoint_81.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_81.npz


100%|██████████| 40/40 [00:43<00:00,  1.10s/it]


081 file → Final Avg Loss: 2.4162 Final Accuracy: 0.4123 Val Accuracy: 0.3370
Saved new best model with accurcy 41.2262 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/linearprobing-mae-base-featureextractorreal-2/extracted_features/train_data_checkpoint_86.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_86.npz


100%|██████████| 40/40 [00:43<00:00,  1.08s/it]


086 file → Final Avg Loss: 2.3942 Final Accuracy: 0.4170 Val Accuracy: 0.3348
Saved new best model with accurcy 41.6962 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/linearprobing-mae-base-featureextractorreal-2/extracted_features/train_data_checkpoint_91.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_91.npz


100%|██████████| 40/40 [00:43<00:00,  1.09s/it]


091 file → Final Avg Loss: 2.3909 Final Accuracy: 0.4180 Val Accuracy: 0.3384
Saved new best model with accurcy 41.7969 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/linearprobing-mae-base-featureextractorreal-2/extracted_features/train_data_checkpoint_96.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_96.npz


100%|██████████| 40/40 [00:43<00:00,  1.10s/it]


096 file → Final Avg Loss: 2.3846 Final Accuracy: 0.4193 Val Accuracy: 0.3400
Saved new best model with accurcy 41.9308 at /kaggle/working/models/best_model.pth.tar
Working on file: /kaggle/input/linearprobing-mae-base-featureextractorreal-2/extracted_features/train_data_checkpoint_100.npz
Eval file: /kaggle/input/linearprobing-mae-base-featureextractorreal-val/extracted_features/val_data_checkpoint_100.npz


100%|██████████| 40/40 [00:43<00:00,  1.09s/it]


100 file → Final Avg Loss: 2.3834 Final Accuracy: 0.4193 Val Accuracy: 0.3374
Saved new best model with accurcy 41.9315 at /kaggle/working/models/best_model.pth.tar
Results saved to results.csv
