In [1]:
def get_epoch_from_filename(file_path):
    base = os.path.basename(file_path)  # e.g. 'train_data_checkpoint_005.npz'
    name, _ = os.path.splitext(base)    # e.g. 'train_data_checkpoint_005'
    epoch_str = "000"+name               # last 3 chars before extension, e.g. '005'
    return epoch_str[-3:]

In [2]:
import os
import csv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm
import gc

def train_on_npz(file_path, input_dim, num_classes, input_parameters):
    batch_size = input_parameters.get("batch_size", 256)
    learning_rate = input_parameters.get("learning_rate", 0.1)
    num_epochs = input_parameters.get("num_epochs", 20)
    momentum = input_parameters.get("momentum", 0.9)
    weight_decay = input_parameters.get("weight_decay", 1e-4)
    device = input_parameters.get("device", "cuda" if torch.cuda.is_available() else "cpu")

    data = np.load(file_path)
    X = torch.from_numpy(data["features"]).float()
    y = torch.from_numpy(data["labels"]).long()
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    model = nn.Sequential(
        nn.BatchNorm1d(input_dim, affine=False, eps=1e-6),
        nn.Linear(input_dim, num_classes)
    ).to(device)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    model.train()
    for epoch in tqdm(range(num_epochs)):
        total_loss = 0.0
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            optimizer.zero_grad()
            logits = model(batch_X)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
    avg_loss = total_loss / len(dataloader)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            logits = model(batch_X)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)

    del optimizer
    
    accuracy = correct / total
    return avg_loss, accuracy, model

def eval_on_val_file(model, file_path, input_dim, num_classes, input_parameters):
    batch_size = input_parameters.get("batch_size", 256)
    device = input_parameters.get("device", "cuda" if torch.cuda.is_available() else "cpu")

    data = np.load(file_path)
    X = torch.from_numpy(data["features"]).float()
    y = torch.from_numpy(data["labels"]).long()
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            logits = model(batch_X)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)

    accuracy = correct / total
    return accuracy

def run_linear_probing_on_folder(folder_path, val_folder_path, input_dim, num_classes, input_parameters, csv_path="results.csv"):
    files = os.listdir(folder_path)
    npz_files = [f for f in files if f.endswith(".npz")]
    val_files = [f for f in os.listdir(val_folder_path) if f.endswith(".npz")]
    npz_files.sort()
    
    results = []

    best_accuracy = 0
    
    for file_name in npz_files:
        file_path = os.path.join(folder_path, file_name)
        epoch = get_epoch_from_filename(file_name)
        epoch = "00"+epoch
        val_file_path = os.path.join(val_folder_path, f"val_data_checkpoint_{epoch[-3:]}.npz")
        
        print(f"Working on file: {file_name}")
        print(f"Eval file: {val_file_path}")
        
        loss, accuracy, model = train_on_npz(file_path, input_dim, num_classes, input_parameters)
        val_accuracy = eval_on_val_file(model, val_file_path, input_dim, num_classes, input_parameters)
        print(f"{epoch} file → Final Avg Loss: {loss:.4f} Final Accuracy: {accuracy:.4f} Val Accuracy: {val_accuracy:.4f}")
        results.append({"epoch": epoch, "loss": loss, "accuracy": accuracy, 'val_accuracy': val_accuracy})

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model = model
            save_path = "/kaggle/working/models/best_model.pth.tar"
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save({
                'epoch_file_epoch': epoch,
                'model_state_dict': best_model.state_dict(),
                'loss': loss,
                'accuracy': best_accuracy,
                'val_accuracy': val_accuracy,
            }, save_path)
            print(f"Saved new best model with accurcy {best_accuracy * 100:.4f} at {save_path}")

        del model
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        

    with open(csv_path, mode="w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["epoch", "loss", "accuracy", "val_accuracy"])
        writer.writeheader()
        writer.writerows(results)

    print(f"Results saved to {csv_path}")
    return results


In [3]:
input_params = {
    "batch_size": 512,
    "learning_rate": 0.1,
    "num_epochs": 40,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "device": "cuda"
}

folder = "/kaggle/input/linearprobing-simclr-featureextractor-train/extracted_features"
val_folder = "/kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features"
input_dim = 2048
num_classes = 100

results = run_linear_probing_on_folder(folder, val_folder, input_dim, num_classes, input_params)

Working on file: train_data_checkpoint_005.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_005.npz


100%|██████████| 40/40 [01:01<00:00,  1.55s/it]


00005 file → Final Avg Loss: 2.8985 Final Accuracy: 0.3077 Val Accuracy: 0.2128
Saved new best model with accurcy 30.7677 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_010.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_010.npz


100%|██████████| 40/40 [01:01<00:00,  1.53s/it]


00010 file → Final Avg Loss: 2.6041 Final Accuracy: 0.3690 Val Accuracy: 0.2568
Saved new best model with accurcy 36.9038 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_015.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_015.npz


100%|██████████| 40/40 [01:01<00:00,  1.53s/it]


00015 file → Final Avg Loss: 2.4149 Final Accuracy: 0.4078 Val Accuracy: 0.2976
Saved new best model with accurcy 40.7838 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_020.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_020.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00020 file → Final Avg Loss: 2.2978 Final Accuracy: 0.4325 Val Accuracy: 0.3172
Saved new best model with accurcy 43.2492 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_025.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_025.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00025 file → Final Avg Loss: 2.1750 Final Accuracy: 0.4610 Val Accuracy: 0.3396
Saved new best model with accurcy 46.0985 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_030.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_030.npz


100%|██████████| 40/40 [01:01<00:00,  1.53s/it]


00030 file → Final Avg Loss: 2.0802 Final Accuracy: 0.4821 Val Accuracy: 0.3534
Saved new best model with accurcy 48.2092 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_035.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_035.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00035 file → Final Avg Loss: 2.0328 Final Accuracy: 0.4927 Val Accuracy: 0.3604
Saved new best model with accurcy 49.2708 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_040.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_040.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00040 file → Final Avg Loss: 1.9308 Final Accuracy: 0.5155 Val Accuracy: 0.3856
Saved new best model with accurcy 51.5508 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_045.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_045.npz


100%|██████████| 40/40 [01:00<00:00,  1.51s/it]


00045 file → Final Avg Loss: 1.8654 Final Accuracy: 0.5305 Val Accuracy: 0.3900
Saved new best model with accurcy 53.0462 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_050.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_050.npz


100%|██████████| 40/40 [01:00<00:00,  1.51s/it]


00050 file → Final Avg Loss: 1.8211 Final Accuracy: 0.5416 Val Accuracy: 0.3976
Saved new best model with accurcy 54.1646 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_055.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_055.npz


100%|██████████| 40/40 [01:00<00:00,  1.51s/it]


00055 file → Final Avg Loss: 1.7811 Final Accuracy: 0.5510 Val Accuracy: 0.4002
Saved new best model with accurcy 55.0985 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_060.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_060.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00060 file → Final Avg Loss: 1.7346 Final Accuracy: 0.5607 Val Accuracy: 0.4074
Saved new best model with accurcy 56.0738 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_065.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_065.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00065 file → Final Avg Loss: 1.6842 Final Accuracy: 0.5720 Val Accuracy: 0.4202
Saved new best model with accurcy 57.1992 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_070.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_070.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00070 file → Final Avg Loss: 1.6710 Final Accuracy: 0.5763 Val Accuracy: 0.4270
Saved new best model with accurcy 57.6277 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_075.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_075.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00075 file → Final Avg Loss: 1.5938 Final Accuracy: 0.5938 Val Accuracy: 0.4322
Saved new best model with accurcy 59.3823 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_080.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_080.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00080 file → Final Avg Loss: 1.5488 Final Accuracy: 0.6046 Val Accuracy: 0.4422
Saved new best model with accurcy 60.4638 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_085.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_085.npz


100%|██████████| 40/40 [01:01<00:00,  1.53s/it]


00085 file → Final Avg Loss: 1.5255 Final Accuracy: 0.6106 Val Accuracy: 0.4422
Saved new best model with accurcy 61.0638 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_090.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_090.npz


100%|██████████| 40/40 [01:00<00:00,  1.52s/it]


00090 file → Final Avg Loss: 1.4765 Final Accuracy: 0.6236 Val Accuracy: 0.4496
Saved new best model with accurcy 62.3577 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_095.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_095.npz


100%|██████████| 40/40 [01:01<00:00,  1.53s/it]


00095 file → Final Avg Loss: 1.4454 Final Accuracy: 0.6284 Val Accuracy: 0.4546
Saved new best model with accurcy 62.8377 at /kaggle/working/models/best_model.pth.tar
Working on file: train_data_checkpoint_100.npz
Eval file: /kaggle/input/linearprobing-simclr-featureextractor-valid/extracted_features/val_data_checkpoint_100.npz


100%|██████████| 40/40 [01:01<00:00,  1.53s/it]


00100 file → Final Avg Loss: 1.4108 Final Accuracy: 0.6386 Val Accuracy: 0.4568
Saved new best model with accurcy 63.8615 at /kaggle/working/models/best_model.pth.tar
Results saved to results.csv
