In [1]:
import pandas as pd
import numpy as np
import branched_resnet_v2 as br
import preprocessing as pre
from preprocessing import combine_npzs
import os
import torch

print(torch.__version__)

# Report the CUDA setup. Query availability once and reuse it so that every
# call that needs an initialised CUDA runtime is guarded the same way.
cuda_available = torch.cuda.is_available()
print("CUDA available:", cuda_available)
print("CUDA version:", torch.version.cuda)
print("Device count:", torch.cuda.device_count())
# current_device() raises when no CUDA runtime is usable, so only call it
# when CUDA is actually available (mirrors the guard on get_device_name).
print("Current device:", torch.cuda.current_device() if cuda_available else "None")
print("Device name:", torch.cuda.get_device_name(0) if cuda_available else "None")

# Fall back to the CPU when no GPU is present.
device = torch.device('cuda' if cuda_available else 'cpu')
print(f'Using device: {device}')


2.6.0+cu118
CUDA available: True
CUDA version: 11.8
Device count: 1
Current device: 0
Device name: NVIDIA GeForce RTX 4070
Using device: cuda


In [None]:
# Combine the per-chunk ring-artifact .npz files of each split, train first
# and then val. The test split is intentionally left out for now.
ring_train, ring_val = (
    combine_npzs(f"data/ringv2/{split_name}") for split_name in ("train", "val")
)
# ring_test = combine_npzs("data/ringv2/test")




Processing file: train_subset_first_5000_RingArtifactv1_images.npz
Processing file: train_subset_second_5000_RingArtifactv1_images.npz
Processing file: train_subset_third_5000_RingArtifactv1_images.npz
Processing file: train_subset_fourth_5000_RingArtifactv1_images.npz
Processing file: train_subset_fifth_5000_RingArtifactv1_images.npz
Processing file: train_subset_sixth_5000_RingArtifactv1_images.npz
Processing file: train_subset_last_4561_RingArtifactv1_images.npz
Processing file: val_subset_first_5000_RingArtifactv1_images.npz
Processing file: val_subset_last_1491_RingArtifactv1_images.npz


In [None]:
# Load the fixed splits and merge the ring-artifact arrays into them.
# Unpacking the NpzFile inside the `with` block reads every array into memory,
# so the underlying file handle can be closed right away instead of leaking.
# Keys present in both mappings take the ring_* value (later unpacking wins).
with np.load("data/fixed/train.npz") as fixed_train:
    train = {**fixed_train, **ring_train}
with np.load("data/fixed/validation.npz") as fixed_val:
    val = {**fixed_val, **ring_val}

# The test split is currently not used:
# with np.load("data/fixed/test.npz") as fixed_test:
#     test = {**fixed_test, **ring_test}

In [None]:
# Show which arrays each merged split carries.
print(f"Train keys: {train.keys()}")
print(f"Validation keys: {val.keys()}")
# print("Test keys:", test.keys())

# Number of entries under the 'label' key, train first and then validation.
# (Add `test` to the tuple once the test split is loaded again.)
print(*(len(split['label']) for split in (train, val)))


Train keys: dict_keys(['original', 'label', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1'])
Validation keys: dict_keys(['original', 'label', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1'])
34561 6491


In [None]:
# Number of entries under the 'label' key of the merged train dict.
len(train['label'])

34561

In [None]:
# Pool the splits into one dataset; the folds are re-drawn from it later.
# datasets = [train, val, test]   # variant that includes the test split
datasets = [train, val]
data = {}

# The first split decides which keys are pooled; every other split must
# provide the same keys, otherwise the lookup below raises KeyError.
for key in datasets[0]:
    print(f"Processing key: {key}")
    per_split_arrays = [split[key] for split in datasets]
    data[key] = np.concatenate(per_split_arrays)



Processing key: original
Processing key: label
Processing key: Uniform_Noise
Processing key: Rotate_90deg
Processing key: Ring_Artifact_v1


In [None]:
# Number of 'label' entries in the combined dataset.
print(len(data['label']))

41052


In [2]:
# Location of the cached, combined cross-validation dataset.
CV_DATA_PATH = "D:/data/D21_cross_val_data.npz"

# Write the cache only when it does not exist yet, so the notebook still runs
# top to bottom on a machine where the file was never created. This branch
# needs the in-memory `data` dict built by the concatenation cell.
if not os.path.exists(CV_DATA_PATH):
    np.savez_compressed(CV_DATA_PATH, **data)

# Always continue from the on-disk copy so every session uses the same arrays.
data = np.load(CV_DATA_PATH)

In [3]:
# Build the cross-validation folds from the "original" and "Ring_Artifact_v1"
# arrays of the combined dataset. The six return values are passed unchanged to
# pre.retrieve_fold_data below.
# NOTE(review): presumably *_norm hold normalisation statistics and *_domain the
# per-domain image arrays — confirm against preprocessing.kfold_split.
first_norm, second_norm, folds, labels, first_domain, second_domain = pre.kfold_split(data, "original", "Ring_Artifact_v1")

In [5]:
# Materialise fold 0. retrieve_fold_data returns the validation dataset first
# and the training dataset second.
# These two names must exist before run_experiment is defined, because its
# default arguments reference them.
val_dataset, train_dataset = pre.retrieve_fold_data(
    fold_index=0,
    folds=folds,
    labels=labels,
    first_norm=first_norm,
    second_norm=second_norm,
    first_domain=first_domain,
    second_domain=second_domain
)

print("Train dataset size:", len(train_dataset))
# This is the held-out validation fold, not a test set, so label it as such.
print("Validation dataset size:", len(val_dataset))

Train dataset size: 65672
Test dataset size: 16432


In [6]:
from transformers import Trainer, TrainingArguments, set_seed
import datetime

def parabolic_increasing_lambda_scheduler(epoch, total_epochs, start_value=0.0, end_value=1.0):
    """Quadratic ramp from ``start_value`` towards ``end_value``.

    Args:
        epoch: Current epoch index.
        total_epochs: Total number of epochs; must be non-zero.
        start_value: Value returned when ``epoch`` is 0.
        end_value: Value returned when ``epoch`` equals ``total_epochs``.

    Returns:
        ``start_value + (end_value - start_value) * (epoch / total_epochs) ** 2``.

    Raises:
        ZeroDivisionError: If ``total_epochs`` is 0.
    """
    fraction_done = epoch / total_epochs
    span = end_value - start_value
    return start_value + span * (fraction_done ** 2)

# Seed every RNG once for the session.
SEED = 42
torch.manual_seed(SEED)
set_seed(SEED)

# Timestamp taken when this cell runs; used as a suffix in output names so
# separate sessions do not overwrite each other.
DATE = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Root directory for checkpoints, final models and result tables.
output_dir = './data/D21_cv_results/'

def run_experiment(lambda_scheduler, from_checkpoint=None, train_input=train_dataset, val_input=val_dataset, num_epochs=3, learning_rate=0.1, optimizer='sgd', weight_decay=1e-4, fold_index=0):
    """Train one branched ResNet run with the HF ``Trainer`` and return its eval metrics.

    Args:
        lambda_scheduler: Callable handed to ``br.LambdaUpdateCallback`` together
            with the model and the epoch count. Its ``__name__`` is used in the
            output directory names.
        from_checkpoint: Optional checkpoint path forwarded to
            ``trainer.train(resume_from_checkpoint=...)``; ``None`` starts fresh.
        train_input: Dataset used for training.
        val_input: Dataset used for per-epoch and final evaluation.
            NOTE: the defaults of ``train_input`` and ``val_input`` are evaluated
            once, when this ``def`` executes, so they stay bound to whatever fold
            was loaded at that moment. Pass both explicitly for any other fold.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate given to ``TrainingArguments``.
        optimizer: Optimizer name given to ``TrainingArguments(optim=...)``.
        weight_decay: Weight decay given to ``TrainingArguments``.
        fold_index: Only used to label the output directories.

    Returns:
        The metrics dict produced by ``trainer.evaluate`` on ``val_input``.
    """
    # Trainer re-seeds itself on construction, but the model is created before
    # that. Re-seed here so the initial weights do not depend on which runs
    # were executed earlier in the same kernel.
    set_seed(SEED)

    # Initialize model
    config = br.ResNetConfig()
    model = br.ResNetForMultiLabel(config=config, num_d1_classes=11, num_d2_classes=2, lamb=0)

    run_prefix = f"{lambda_scheduler.__name__}_fold_{fold_index}"

    # os.path.join avoids the doubled separator that plain string formatting
    # produced, because `output_dir` already ends with a slash.
    training_args = TrainingArguments(
        output_dir=os.path.join(output_dir, f"{run_prefix}_results_{DATE}"),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=32,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir='./logs',
        logging_steps=10,
        load_best_model_at_end=True,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        seed=SEED,
        optim=optimizer,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_input,
        eval_dataset=val_input,
        compute_metrics=br.make_metrics_fn(model),
        callbacks=[br.LambdaUpdateCallback(model, lambda_scheduler, num_epochs)],
    )

    # resume_from_checkpoint=None is Trainer.train's default, so a single call
    # covers both the fresh and the resumed case.
    trainer.train(resume_from_checkpoint=from_checkpoint)

    trainer.save_model(os.path.join(output_dir, f"{run_prefix}_final_model_{DATE}"))

    # load_best_model_at_end=True means this evaluates the best checkpoint,
    # which is not necessarily the last epoch.
    return trainer.evaluate(eval_dataset=val_input)

# run_experiment(br.lambda_scheduler)


In [9]:
# Resume the 50-epoch run for fold 4 from a saved Trainer checkpoint.

resume_checkpoint = "data/D21_cv_results/parabolic_increasing_lambda_scheduler_results_2025-09-26_01-45-15/checkpoint-61590"

# Rebuild the fold-4 datasets (validation first, training second).
val_dataset, train_dataset = pre.retrieve_fold_data(
    folds=folds,
    labels=labels,
    fold_index=4,
    first_norm=first_norm,
    second_norm=second_norm,
    first_domain=first_domain,
    second_domain=second_domain,
)

# Last expression of the cell, so the returned metrics dict is displayed.
run_experiment(
    parabolic_increasing_lambda_scheduler,
    from_checkpoint=resume_checkpoint,
    train_input=train_dataset,
    val_input=val_dataset,
    num_epochs=50,
    fold_index=4,
)

[34m[1mwandb[0m: Currently logged in as: [33msamuelsavine[0m ([33msamuelsavine-johns-hopkins-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
31,1.7292,1.75306,0.996405,0.99529,0.995651,0.995469,0.5,0.25,0.5,0.333333,0.36
32,1.9119,1.75476,0.996222,0.995476,0.995373,0.995418,0.5,0.25,0.5,0.333333,0.3844
33,1.6088,1.914343,0.996283,0.995571,0.995289,0.995423,0.5,0.25,0.5,0.333333,0.4096
34,1.6607,1.392305,0.996405,0.995129,0.995612,0.995359,0.5,0.25,0.5,0.333333,0.4356
35,1.5312,1.42322,0.996161,0.995278,0.994957,0.995106,0.5,0.25,0.5,0.333333,0.4624
36,1.4968,1.421127,0.995857,0.994793,0.995079,0.994927,0.5,0.25,0.5,0.333333,0.49
37,1.4491,1.511501,0.996039,0.995406,0.995097,0.995246,0.5,0.25,0.5,0.333333,0.5184
38,1.3159,1.549118,0.996283,0.995343,0.995234,0.995283,0.5,0.25,0.5,0.333333,0.5476
39,1.3705,1.362503,0.995918,0.994742,0.994798,0.994767,0.5,0.25,0.5,0.333333,0.5776
40,1.1705,1.16728,0.995674,0.994155,0.994683,0.994415,0.5,0.25,0.5,0.333333,0.6084


{'eval_loss': 0.6999772191047668,
 'eval_accuracy_branch1': 0.9964050694613698,
 'eval_precision_branch1': 0.99511070931697,
 'eval_recall_branch1': 0.9954895578433749,
 'eval_f1_branch1': 0.9952970575890556,
 'eval_accuracy_branch2': 0.47532293443821594,
 'eval_precision_branch2': 0.35389043860169744,
 'eval_recall_branch2': 0.47532293443821594,
 'eval_f1_branch2': 0.3377159055561145,
 'eval_lambda': 0.9603999999999999,
 'eval_runtime': 22.318,
 'eval_samples_per_second': 735.37,
 'eval_steps_per_second': 91.944,
 'epoch': 50.0}

In [None]:
# Run the full 5-fold cross-validation, collecting one metrics dict per fold.
metrics = []

for fold_number in range(1, 6):
    fold_idx = fold_number - 1
    print(f"Starting fold {fold_number}/5")

    # Validation dataset comes back first, training dataset second.
    val_dataset, train_dataset = pre.retrieve_fold_data(
        folds=folds,
        labels=labels,
        fold_index=fold_idx,
        first_norm=first_norm,
        second_norm=second_norm,
        first_domain=first_domain,
        second_domain=second_domain,
    )

    fold_metrics = run_experiment(
        parabolic_increasing_lambda_scheduler,
        train_input=train_dataset,
        val_input=val_dataset,
        num_epochs=50,
        fold_index=fold_idx,
    )
    metrics.append(fold_metrics)
    print(f"Fold {fold_number} Metrics: {fold_metrics}")

    # Rewrite the results table after every fold so that finished folds are
    # kept if a later fold crashes.
    pd.DataFrame(metrics).to_csv(f"{output_dir}/intermediate_results_{DATE}.csv", index=False)
    


Starting fold 1/5


[34m[1mwandb[0m: Currently logged in as: [33msamuelsavine[0m ([33msamuelsavine-johns-hopkins-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,0.9472,0.864756,0.814387,0.791411,0.767778,0.749032,0.635467,0.675978,0.635467,0.613207,0.0
2,0.8363,0.853568,0.924111,0.909769,0.907789,0.902871,0.509859,0.657085,0.509859,0.35987,0.0004
3,0.8946,0.730714,0.961417,0.958933,0.959501,0.95897,0.584895,0.602227,0.584895,0.566522,0.0016
4,0.8818,0.810985,0.970545,0.961501,0.969549,0.965169,0.538522,0.548598,0.538522,0.513296,0.0036
5,0.8604,0.739387,0.982717,0.98045,0.981736,0.981015,0.527933,0.551581,0.527933,0.466824,0.0064
6,0.8193,0.859353,0.986551,0.983498,0.985626,0.984531,0.501643,0.598239,0.501643,0.339208,0.01
7,0.8563,0.804416,0.987342,0.986427,0.985667,0.985988,0.510711,0.518867,0.510711,0.451423,0.0144
8,0.8921,0.969468,0.992575,0.991209,0.992114,0.991645,0.505903,0.508751,0.505903,0.462143,0.0196
9,0.9317,0.734502,0.991358,0.990749,0.990145,0.990429,0.516492,0.519379,0.516492,0.49779,0.0256
10,0.9625,0.742344,0.990506,0.990219,0.989761,0.989942,0.512658,0.518445,0.512658,0.47118,0.0324


Fold 1 Metrics: {'eval_loss': 0.7047597169876099, 'eval_accuracy_branch1': 0.9943403115871471, 'eval_precision_branch1': 0.992970529034539, 'eval_recall_branch1': 0.9935458971959097, 'eval_f1_branch1': 0.9932472574910519, 'eval_accuracy_branch2': 0.4807692307692308, 'eval_precision_branch2': 0.2599729006910915, 'eval_recall_branch2': 0.4807692307692308, 'eval_f1_branch2': 0.32570041287398527, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 37.7702, 'eval_samples_per_second': 435.052, 'eval_steps_per_second': 54.381, 'epoch': 50.0}
Starting fold 2/5


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,0.9617,0.809442,0.83662,0.823221,0.813016,0.814492,0.721756,0.72562,0.721756,0.720559,0.0
2,0.8788,0.736154,0.942111,0.930014,0.933177,0.929833,0.587473,0.589584,0.587473,0.585028,0.0004
3,0.8307,0.747517,0.963051,0.957907,0.958766,0.957452,0.548758,0.54884,0.548758,0.548568,0.0016
4,0.8576,0.747082,0.977721,0.973793,0.976734,0.975092,0.538958,0.54247,0.538958,0.529226,0.0036
5,0.9078,0.772093,0.985269,0.982879,0.983561,0.983176,0.517044,0.530537,0.517044,0.457071,0.0064
6,0.8534,0.718655,0.987643,0.986654,0.984526,0.985511,0.523618,0.5239,0.523618,0.52221,0.01
7,0.916,0.908654,0.991721,0.989999,0.990847,0.990406,0.501096,0.673351,0.501096,0.336193,0.0144
8,0.8689,0.831836,0.992513,0.991116,0.991562,0.991334,0.502922,0.525143,0.502922,0.361945,0.0196
9,0.9732,0.759303,0.992635,0.990738,0.992484,0.99159,0.501948,0.50289,0.501948,0.457753,0.0256
10,1.055,0.90688,0.992513,0.991297,0.990797,0.991005,0.502496,0.598247,0.502496,0.342231,0.0324


Fold 2 Metrics: {'eval_loss': 0.700166642665863, 'eval_accuracy_branch1': 0.9959215972729486, 'eval_precision_branch1': 0.9950443166647754, 'eval_recall_branch1': 0.9951925624639483, 'eval_f1_branch1': 0.9951032674925379, 'eval_accuracy_branch2': 0.4988434380326272, 'eval_precision_branch2': 0.3557708835842413, 'eval_recall_branch2': 0.4988434380326272, 'eval_f1_branch2': 0.33357259373635306, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 38.023, 'eval_samples_per_second': 432.054, 'eval_steps_per_second': 54.02, 'epoch': 50.0}
Starting fold 3/5


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,1.0396,0.852961,0.831344,0.8308,0.802964,0.803182,0.64009,0.693712,0.64009,0.613331,0.0
2,0.8382,0.736545,0.94372,0.935299,0.940669,0.936649,0.610245,0.610422,0.610245,0.610088,0.0004
3,0.8864,0.892802,0.971982,0.963943,0.970279,0.966934,0.504203,0.604355,0.504203,0.347694,0.0016
4,0.8243,0.714296,0.976489,0.973685,0.974422,0.973902,0.566269,0.618879,0.566269,0.512312,0.0036
5,0.8297,0.81092,0.973444,0.963762,0.956073,0.955256,0.513948,0.591037,0.513948,0.38342,0.0064
6,0.8559,0.748565,0.990803,0.987986,0.989072,0.98847,0.522597,0.528711,0.522597,0.495755,0.01
7,0.8382,0.762,0.989828,0.98803,0.988997,0.988418,0.504203,0.504358,0.504203,0.499732,0.0144
8,0.8732,0.902488,0.994031,0.993703,0.992972,0.993324,0.5,0.25,0.5,0.333333,0.0196
9,0.968,1.014607,0.990255,0.985925,0.989383,0.987487,0.501523,0.652821,0.501523,0.337564,0.0256
10,1.0102,0.780924,0.995432,0.993678,0.994439,0.994052,0.501766,0.523178,0.501766,0.352146,0.0324


Fold 3 Metrics: {'eval_loss': 0.6988590955734253, 'eval_accuracy_branch1': 0.9972591058594226, 'eval_precision_branch1': 0.99684934690912, 'eval_recall_branch1': 0.9963503720197664, 'eval_f1_branch1': 0.9965964986887846, 'eval_accuracy_branch2': 0.4971372883420636, 'eval_precision_branch2': 0.4828181689036303, 'eval_recall_branch2': 0.49713728834206355, 'eval_f1_branch2': 0.36479417508602807, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 21.8683, 'eval_samples_per_second': 750.767, 'eval_steps_per_second': 93.88, 'epoch': 50.0}
Starting fold 4/5


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,1.004,0.85693,0.831059,0.827185,0.797665,0.8045,0.641647,0.754717,0.641647,0.596915,0.0
2,0.8824,0.841672,0.938345,0.933658,0.919287,0.922284,0.513769,0.611793,0.513769,0.377258,0.0004
3,0.8351,0.720073,0.976057,0.971802,0.973024,0.972323,0.570062,0.648409,0.570062,0.504693,0.0016
4,0.8655,0.812672,0.983855,0.980444,0.981418,0.980897,0.501279,0.681355,0.501279,0.336599,0.0036
5,0.8724,1.158125,0.987754,0.987494,0.986173,0.986782,0.5,0.25,0.5,0.333333,0.0064
6,0.842,0.789446,0.992202,0.990507,0.991195,0.990833,0.518521,0.524356,0.518521,0.487846,0.01
7,0.8787,0.753924,0.992872,0.99183,0.99199,0.991903,0.50396,0.512681,0.50396,0.400968,0.0144
8,0.7836,0.745982,0.993969,0.99298,0.992755,0.99285,0.516815,0.516988,0.516815,0.515579,0.0196
9,0.836,0.860126,0.993298,0.99222,0.992358,0.992254,0.501828,0.518851,0.501828,0.356565,0.0256
10,0.8923,0.904337,0.995431,0.994259,0.994764,0.9945,0.507311,0.511384,0.507311,0.458913,0.0324


Fold 4 Metrics: {'eval_loss': 0.6985048055648804, 'eval_accuracy_branch1': 0.9970147435116364, 'eval_precision_branch1': 0.9960965589707853, 'eval_recall_branch1': 0.996177764083982, 'eval_f1_branch1': 0.9961298236898856, 'eval_accuracy_branch2': 0.49031314731326914, 'eval_precision_branch2': 0.3788711149148839, 'eval_recall_branch2': 0.4903131473132692, 'eval_f1_branch2': 0.33806288686603275, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 20.3006, 'eval_samples_per_second': 808.549, 'eval_steps_per_second': 101.081, 'epoch': 50.0}
Starting fold 5/5


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,0.9952,0.91114,0.814465,0.818479,0.792955,0.785353,0.591884,0.648038,0.591884,0.549128,0.0
2,0.8538,0.764374,0.940958,0.931033,0.938486,0.934256,0.572752,0.576087,0.572752,0.568018,0.0004
3,0.8718,0.777444,0.96137,0.952997,0.959506,0.955729,0.531136,0.544348,0.531136,0.493405,0.0016
4,0.8706,0.748921,0.975384,0.971244,0.970058,0.970254,0.534853,0.534891,0.534853,0.534723,0.0036
5,0.928,0.737616,0.982452,0.97761,0.980823,0.979013,0.513588,0.524474,0.513588,0.452728,0.0064
6,0.7788,0.7461,0.988301,0.983023,0.986791,0.98474,0.512613,0.518403,0.512613,0.471001,0.01
7,0.8731,0.778522,0.98757,0.98238,0.984617,0.983161,0.507982,0.512094,0.507982,0.462273,0.0144
8,0.9701,0.782978,0.994333,0.993713,0.993197,0.993443,0.500853,0.500926,0.500853,0.490854,0.0196
9,0.9158,0.85193,0.994455,0.993616,0.993148,0.993371,0.500061,0.503984,0.500061,0.336796,0.0256
10,0.877,1.274584,0.993176,0.990741,0.992812,0.991737,0.5,0.25,0.5,0.333333,0.0324
