In [1]:
import pandas as pd
import numpy as np
import branched_resnet_v2 as br
import preprocessing as pre
from preprocessing import combine_npzs
import os
import torch

# Report the PyTorch build and select the compute device for this session.
print(torch.__version__)

# Set device and verify CUDA availability
cuda_available = torch.cuda.is_available()
print("CUDA available:", cuda_available)
print("CUDA version:", torch.version.cuda)
print("Device count:", torch.cuda.device_count())
# torch.cuda.current_device() initializes the CUDA runtime and raises when no
# usable GPU/driver is present, so it is only queried when CUDA is available.
# Without this guard the CPU fallback below could never be reached.
print("Current device:", torch.cuda.current_device() if cuda_available else "None")
print("Device name:", torch.cuda.get_device_name(0) if cuda_available else "None")

device = torch.device('cuda' if cuda_available else 'cpu')
print(f'Using device: {device}')


2.6.0+cu118
CUDA available: True
CUDA version: 11.8
Device count: 1
Current device: 0
Device name: NVIDIA GeForce RTX 4070
Using device: cuda


In [None]:
# Merge the chunked ring-artifact archives of each split (train first, then
# val) into one object per split.
ring_train, ring_val = map(combine_npzs, ("data/ringv2/train", "data/ringv2/val"))
# ring_test = combine_npzs("data/ringv2/test")




Processing file: train_subset_first_5000_RingArtifactv1_images.npz
Processing file: train_subset_second_5000_RingArtifactv1_images.npz
Processing file: train_subset_third_5000_RingArtifactv1_images.npz
Processing file: train_subset_fourth_5000_RingArtifactv1_images.npz
Processing file: train_subset_fifth_5000_RingArtifactv1_images.npz
Processing file: train_subset_sixth_5000_RingArtifactv1_images.npz
Processing file: train_subset_last_4561_RingArtifactv1_images.npz
Processing file: val_subset_first_5000_RingArtifactv1_images.npz
Processing file: val_subset_last_1491_RingArtifactv1_images.npz


In [None]:
# Load the fixed train/validation archives and merge in the ring-artifact arrays.
#
# np.load on an .npz returns a lazy NpzFile that keeps the archive open. The
# `with` blocks copy every array into a plain dict and then close the file
# deterministically, instead of leaving that to garbage collection.
# On duplicate keys the ring_* entry wins because it is unpacked last.
with np.load("data/fixed/train.npz") as train_npz:
    train = {**train_npz, **ring_train}
with np.load("data/fixed/validation.npz") as val_npz:
    val = {**val_npz, **ring_val}

# test = np.load("data/fixed/test.npz")
# test = {**test, **ring_test}

In [None]:
# List the arrays held by each merged split.
for split_title, split_arrays in (("Train", train), ("Validation", val)):
    print(f"{split_title} keys:", split_arrays.keys())
# print("Test keys:", test.keys())

# Per-split sample counts, taken from the length of the 'label' array.
# print(len(train['label']), len(val['label']), len(test['label']))
print(*(len(split_arrays['label']) for split_arrays in (train, val)))


Train keys: dict_keys(['original', 'label', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1'])
Validation keys: dict_keys(['original', 'label', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1'])
34561 6491


In [None]:
# Length of the training 'label' array (same figure as printed in the previous cell).
len(train['label'])

34561

In [None]:
# Pool the splits: for every key of the first split, concatenate that key's
# arrays from all splits (in list order) into a single array.
# datasets = [train, val, test]
datasets = [train, val]
data = {}

reference_split = datasets[0]
for key in reference_split.keys():
    print(f"Processing key: {key}")
    per_split_arrays = [split[key] for split in datasets]
    data[key] = np.concatenate(per_split_arrays)



Processing key: original
Processing key: label
Processing key: Uniform_Noise
Processing key: Rotate_90deg
Processing key: Ring_Artifact_v1


In [None]:
print(len(data['label']))  # length of the pooled 'label' array (train + val)

41052


In [2]:
# Location of the cached, pooled dataset.
CROSS_VAL_NPZ = "D:/data/cross_val_data.npz"

# The cache was produced once from the in-memory `data` dict with:
# np.savez_compressed("D:/data/cross_val_data.npz", **data)

# Read the pooled dataset back; from here on `data` is the loaded archive.
data = np.load(CROSS_VAL_NPZ)

In [3]:
# Build the k-fold split over the "original" and "Rotate_90deg" arrays of `data`.
# NOTE(review): the meaning of the six returned values is defined in
# preprocessing.kfold_split, which is not shown here; presumably *_norm are the
# normalized images and *_domain the domain labels of each source — confirm.
first_norm, second_norm, folds, labels, first_domain, second_domain = pre.kfold_split(data, "original", "Rotate_90deg")

In [4]:
# Materialize fold 0 to sanity-check the split sizes. The result is unpacked
# as (held-out validation dataset, training dataset), in that order.
val_dataset, train_dataset = pre.retrieve_fold_data(
    fold_index=0,
    folds=folds,
    labels=labels,
    first_norm=first_norm,
    second_norm=second_norm,
    first_domain=first_domain,
    second_domain=second_domain
)

print("Train dataset size:", len(train_dataset))
# This is the held-out validation fold (it is what gets evaluated during
# training), not a separate test set, so it is labelled accordingly.
print("Validation dataset size:", len(val_dataset))

Train dataset size: 65672
Test dataset size: 16432


In [5]:
from transformers import Trainer, TrainingArguments, set_seed
import datetime

def parabolic_increasing_lambda_scheduler(epoch, total_epochs, start_value=0.0, end_value=1.0):
    """Quadratically ramp lambda from ``start_value`` towards ``end_value``.

    Computes ``start_value + (end_value - start_value) * (epoch / total_epochs) ** 2``.

    Args:
        epoch: Current epoch (may be fractional).
        total_epochs: Number of epochs the ramp spans.
        start_value: Lambda at ``epoch == 0``.
        end_value: Lambda at ``epoch == total_epochs``.

    Returns:
        The scheduled lambda as a float. When ``epoch`` only runs over
        ``0 .. total_epochs - 1`` the largest value produced is
        ``((total_epochs - 1) / total_epochs) ** 2`` of the way to
        ``end_value`` (e.g. 0.9604 for 50 epochs), not ``end_value`` itself.
        If ``total_epochs`` is not positive there is nothing to ramp over and
        ``end_value`` is returned.
    """
    if total_epochs <= 0:
        # Avoid ZeroDivisionError: with no epochs to spread the ramp over,
        # treat the schedule as already complete.
        return end_value
    progress = epoch / total_epochs
    return start_value + (end_value - start_value) * (progress ** 2)

# --- Run configuration ---------------------------------------------------

# Seed used for the global RNGs below and passed to TrainingArguments.
SEED = 42

# Session timestamp embedded in the names of all result folders and files.
DATE = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Set output directory
output_dir = './data/D20_cv_results/'

# Set seed
torch.manual_seed(SEED)
set_seed(SEED)

def run_experiment(lambda_scheduler, from_checkpoint=None, train_input=train_dataset, val_input=val_dataset, num_epochs=3, learning_rate=0.1, optimizer='sgd', weight_decay=1e-4, fold_index=0):
    """Train one two-branch ResNet run with a scheduled lambda and evaluate it.

    Args:
        lambda_scheduler: Callable handed to ``br.LambdaUpdateCallback``; its
            ``__name__`` is also used in the output folder names.
        from_checkpoint: Optional checkpoint path to resume from. ``None``
            starts a fresh training run.
        train_input: Training dataset. NOTE: this default is bound once, when
            the function is defined, to whatever ``train_dataset`` was at that
            moment; pass the fold's dataset explicitly.
        val_input: Evaluation dataset; same default-binding caveat as
            ``train_input``.
        num_epochs: Number of training epochs; also given to the lambda callback.
        learning_rate: Optimizer learning rate.
        optimizer: Optimizer name forwarded as ``TrainingArguments(optim=...)``.
        weight_decay: Weight decay forwarded to ``TrainingArguments``.
        fold_index: Fold number, used only to name the output folders.

    Returns:
        The metrics dict returned by ``trainer.evaluate`` on ``val_input``.
    """
    # Re-seed before the model is constructed so its random initialization does
    # not depend on how many runs (folds) were executed earlier in the kernel.
    set_seed(SEED)

    # Initialize model
    config = br.ResNetConfig()
    model = br.ResNetForMultiLabel(config=config, num_d1_classes=11, num_d2_classes=2, lamb=0)

    scheduler_name = lambda_scheduler.__name__
    run_tag = f"{scheduler_name}_fold_{fold_index}"

    # Set training arguments
    training_args = TrainingArguments(
        output_dir=f"{output_dir}/{run_tag}_results_{DATE}",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=32,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir='./logs',
        logging_steps=10,
        # NOTE(review): no metric_for_best_model is set, so the "best"
        # checkpoint is chosen by the Trainer's default criterion — confirm
        # that this is the intended selection rule.
        load_best_model_at_end=True,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        seed=SEED,
        optim=optimizer
    )

    # The metrics function and the lambda callback both receive this exact
    # model instance.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_input,
        eval_dataset=val_input,
        compute_metrics=br.make_metrics_fn(model),
        callbacks=[br.LambdaUpdateCallback(model, lambda_scheduler, num_epochs)]
    )

    # resume_from_checkpoint=None is Trainer.train's own default, so a single
    # call covers both the fresh-start and the resume case.
    trainer.train(resume_from_checkpoint=from_checkpoint)

    trainer.save_model(f"{output_dir}/{run_tag}_final_model_{DATE}")

    metrics = trainer.evaluate(eval_dataset=val_input)

    return metrics

# run_experiment(br.lambda_scheduler)


In [None]:
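# Resume a single fold from a saved checkpoint (kept commented out; uncomment and adjust the fold index and checkpoint path to rerun).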

# val_dataset, train_dataset = pre.retrieve_fold_data(
#     fold_index=4,
#     folds=folds,
#     labels=labels,
#     first_norm=first_norm,
#     second_norm=second_norm,
#     first_domain=first_domain,
#     second_domain=second_domain
# )

# run_experiment(parabolic_increasing_lambda_scheduler, val_input=val_dataset, train_input=train_dataset, num_epochs=50, fold_index=4, from_checkpoint="data/D21_cv_results/parabolic_increasing_lambda_scheduler_results_2025-09-26_01-45-15/checkpoint-61590")

[34m[1mwandb[0m: Currently logged in as: [33msamuelsavine[0m ([33msamuelsavine-johns-hopkins-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
31,1.7292,1.75306,0.996405,0.99529,0.995651,0.995469,0.5,0.25,0.5,0.333333,0.36
32,1.9119,1.75476,0.996222,0.995476,0.995373,0.995418,0.5,0.25,0.5,0.333333,0.3844
33,1.6088,1.914343,0.996283,0.995571,0.995289,0.995423,0.5,0.25,0.5,0.333333,0.4096
34,1.6607,1.392305,0.996405,0.995129,0.995612,0.995359,0.5,0.25,0.5,0.333333,0.4356
35,1.5312,1.42322,0.996161,0.995278,0.994957,0.995106,0.5,0.25,0.5,0.333333,0.4624
36,1.4968,1.421127,0.995857,0.994793,0.995079,0.994927,0.5,0.25,0.5,0.333333,0.49
37,1.4491,1.511501,0.996039,0.995406,0.995097,0.995246,0.5,0.25,0.5,0.333333,0.5184
38,1.3159,1.549118,0.996283,0.995343,0.995234,0.995283,0.5,0.25,0.5,0.333333,0.5476
39,1.3705,1.362503,0.995918,0.994742,0.994798,0.994767,0.5,0.25,0.5,0.333333,0.5776
40,1.1705,1.16728,0.995674,0.994155,0.994683,0.994415,0.5,0.25,0.5,0.333333,0.6084


{'eval_loss': 0.6999772191047668,
 'eval_accuracy_branch1': 0.9964050694613698,
 'eval_precision_branch1': 0.99511070931697,
 'eval_recall_branch1': 0.9954895578433749,
 'eval_f1_branch1': 0.9952970575890556,
 'eval_accuracy_branch2': 0.47532293443821594,
 'eval_precision_branch2': 0.35389043860169744,
 'eval_recall_branch2': 0.47532293443821594,
 'eval_f1_branch2': 0.3377159055561145,
 'eval_lambda': 0.9603999999999999,
 'eval_runtime': 22.318,
 'eval_samples_per_second': 735.37,
 'eval_steps_per_second': 91.944,
 'epoch': 50.0}

In [6]:
# 5-fold cross-validation: train and evaluate one model per fold, collecting
# each fold's evaluation metrics.
metrics = []

for i in range(5):
    fold_label = i + 1
    print(f"Starting fold {fold_label}/5")

    fold_kwargs = dict(
        fold_index=i,
        folds=folds,
        labels=labels,
        first_norm=first_norm,
        second_norm=second_norm,
        first_domain=first_domain,
        second_domain=second_domain,
    )
    val_dataset, train_dataset = pre.retrieve_fold_data(**fold_kwargs)

    fold_metrics = run_experiment(
        parabolic_increasing_lambda_scheduler,
        train_input=train_dataset,
        val_input=val_dataset,
        num_epochs=50,
        fold_index=i,
    )
    metrics.append(fold_metrics)
    print(f"Fold {fold_label} Metrics: {fold_metrics}")

    # Rewrite the results table after every fold so completed folds survive
    # an interrupted run.
    results_csv = f"{output_dir}/intermediate_results_{DATE}.csv"
    pd.DataFrame(metrics).to_csv(results_csv, index=False)
    


Starting fold 1/5


[34m[1mwandb[0m: Currently logged in as: [33msamuelsavine[0m ([33msamuelsavine-johns-hopkins-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,1.0608,0.993233,0.743549,0.684971,0.704751,0.665199,0.517344,0.51822,0.517344,0.511476,0.0
2,1.0075,0.879056,0.881207,0.865055,0.8619,0.854708,0.504443,0.674927,0.504443,0.344803,0.0004
3,0.8978,0.823445,0.92545,0.921836,0.918862,0.919802,0.559396,0.580957,0.559396,0.527968,0.0016
4,0.9104,0.766029,0.940543,0.932905,0.937695,0.934409,0.531037,0.540896,0.531037,0.500961,0.0036
5,0.8995,0.82052,0.955879,0.953708,0.952385,0.951789,0.510528,0.525941,0.510528,0.42514,0.0064
6,0.8002,0.871525,0.965616,0.960551,0.960068,0.959558,0.505112,0.555713,0.505112,0.359732,0.01
7,0.8885,0.923278,0.980708,0.979423,0.979374,0.979298,0.508885,0.542906,0.508885,0.387463,0.0144
8,0.8844,0.735703,0.976692,0.972365,0.972994,0.972018,0.510711,0.510711,0.510711,0.510711,0.0196
9,0.9258,0.786605,0.981682,0.979977,0.979629,0.979747,0.505599,0.506055,0.505599,0.496118,0.0256
10,0.9251,0.770236,0.985942,0.985278,0.985527,0.985381,0.506207,0.508083,0.506207,0.475793,0.0324


Fold 1 Metrics: {'eval_loss': 0.7057162523269653, 'eval_accuracy_branch1': 0.9931840311587147, 'eval_precision_branch1': 0.992146404568368, 'eval_recall_branch1': 0.9927677273863381, 'eval_f1_branch1': 0.9924431658377696, 'eval_accuracy_branch2': 0.502555988315482, 'eval_precision_branch2': 0.5027435112720953, 'eval_recall_branch2': 0.502555988315482, 'eval_f1_branch2': 0.49390795786775354, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 38.605, 'eval_samples_per_second': 425.645, 'eval_steps_per_second': 53.206, 'epoch': 50.0}
Starting fold 2/5


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,1.0799,1.03919,0.722182,0.685476,0.668397,0.660058,0.540175,0.542347,0.540175,0.534204,0.0
2,0.9773,0.869119,0.868395,0.848406,0.850555,0.845155,0.535062,0.551103,0.535062,0.49547,0.0004
3,0.9016,0.817331,0.910945,0.900317,0.892089,0.892704,0.525444,0.528578,0.525444,0.512068,0.0016
4,0.9039,0.812898,0.945824,0.940071,0.939695,0.939022,0.520575,0.53953,0.520575,0.455272,0.0036
5,0.8928,0.780617,0.964329,0.960735,0.959839,0.959839,0.521183,0.533187,0.521183,0.473582,0.0064
6,0.89,0.750164,0.955746,0.949748,0.943247,0.944389,0.511992,0.512273,0.511992,0.509182,0.01
7,0.8738,0.843512,0.972547,0.968158,0.968611,0.968152,0.5,0.25,0.5,0.333333,0.0144
8,0.8581,1.049388,0.976321,0.972645,0.973205,0.97275,0.5,0.25,0.5,0.333333,0.0196
9,0.9019,0.947505,0.979851,0.978008,0.976992,0.977411,0.507792,0.509036,0.507792,0.490244,0.0256
10,1.0168,1.071154,0.985634,0.983961,0.983832,0.983825,0.5,0.25,0.5,0.333333,0.0324


Fold 2 Metrics: {'eval_loss': 0.7048930525779724, 'eval_accuracy_branch1': 0.9923301680058437, 'eval_precision_branch1': 0.9912096723665182, 'eval_recall_branch1': 0.9910846096536828, 'eval_f1_branch1': 0.9911388782254315, 'eval_accuracy_branch2': 0.49038227416605795, 'eval_precision_branch2': 0.4837396650477146, 'eval_recall_branch2': 0.49038227416605795, 'eval_f1_branch2': 0.4324154089200065, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 39.3508, 'eval_samples_per_second': 417.476, 'eval_steps_per_second': 52.197, 'epoch': 50.0}
Starting fold 3/5


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,1.1524,0.955507,0.776769,0.771376,0.753251,0.73557,0.553965,0.556321,0.553965,0.549251,0.0
2,0.9741,0.854858,0.867463,0.864849,0.84099,0.830854,0.582775,0.59873,0.582775,0.565209,0.0004
3,0.9578,0.909637,0.919052,0.908707,0.908653,0.906742,0.50067,0.510707,0.50067,0.34783,0.0016
4,0.915,0.784539,0.929041,0.920069,0.923878,0.918946,0.56164,0.562432,0.56164,0.560244,0.0036
5,0.9101,0.810091,0.957608,0.954127,0.95395,0.953805,0.504751,0.56177,0.504751,0.356174,0.0064
6,0.9236,0.758689,0.968876,0.966599,0.967353,0.966895,0.512121,0.512185,0.512121,0.511477,0.01
7,0.8535,0.85774,0.96717,0.963127,0.964597,0.963085,0.505238,0.505578,0.505238,0.49758,0.0144
8,0.9234,0.836535,0.976976,0.975228,0.975643,0.975425,0.5,0.25,0.5,0.333333,0.0196
9,0.8857,0.98198,0.972713,0.965336,0.965463,0.964446,0.5,0.25,0.5,0.333333,0.0256
10,0.9491,0.741234,0.984407,0.981711,0.982926,0.9823,0.508954,0.513743,0.508954,0.46209,0.0324


Fold 3 Metrics: {'eval_loss': 0.7056677937507629, 'eval_accuracy_branch1': 0.9930564015105372, 'eval_precision_branch1': 0.9921696928877868, 'eval_recall_branch1': 0.9917975374711091, 'eval_f1_branch1': 0.9919726655715848, 'eval_accuracy_branch2': 0.5032890729686929, 'eval_precision_branch2': 0.5059841084706225, 'eval_recall_branch2': 0.503289072968693, 'eval_f1_branch2': 0.4402681004996274, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 22.6311, 'eval_samples_per_second': 725.463, 'eval_steps_per_second': 90.716, 'epoch': 50.0}
Starting fold 4/5


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,1.0737,1.142009,0.774887,0.740668,0.73317,0.723188,0.5,0.25,0.5,0.333333,0.0
2,0.9522,0.910444,0.883819,0.871051,0.85399,0.860109,0.504082,0.663648,0.504082,0.344228,0.0004
3,0.9282,0.793358,0.928415,0.922121,0.918568,0.919314,0.530827,0.556582,0.530827,0.470583,0.0016
4,0.9038,0.807028,0.948032,0.945097,0.943432,0.943852,0.507067,0.595956,0.507067,0.358505,0.0036
5,0.9276,1.062258,0.964177,0.963636,0.960164,0.961557,0.500731,0.583517,0.500731,0.336245,0.0064
6,0.8703,0.912584,0.969843,0.963738,0.966203,0.964766,0.509565,0.51246,0.509565,0.47932,0.01
7,0.8667,0.750673,0.978068,0.974794,0.975678,0.975166,0.508164,0.51266,0.508164,0.460238,0.0144
8,0.9237,0.784369,0.980626,0.979814,0.978798,0.979226,0.500122,0.75003,0.500122,0.333604,0.0196
9,0.8939,0.77521,0.983429,0.979166,0.98105,0.979976,0.511088,0.516132,0.511088,0.469632,0.0256
10,0.9414,0.842475,0.987145,0.985572,0.985336,0.985334,0.500305,0.540399,0.500305,0.33541,0.0324


Fold 4 Metrics: {'eval_loss': 0.7057554721832275, 'eval_accuracy_branch1': 0.9935420982088461, 'eval_precision_branch1': 0.9927932320172339, 'eval_recall_branch1': 0.9927548342093402, 'eval_f1_branch1': 0.9927531916191454, 'eval_accuracy_branch2': 0.49518703545753623, 'eval_precision_branch2': 0.47746515343679585, 'eval_recall_branch2': 0.49518703545753623, 'eval_f1_branch2': 0.3716501103034772, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 44.5297, 'eval_samples_per_second': 368.608, 'eval_steps_per_second': 46.082, 'epoch': 50.0}
Starting fold 5/5


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,1.0783,1.118824,0.689678,0.674399,0.647046,0.627289,0.500183,0.568228,0.500183,0.334172,0.0
2,0.9313,0.86744,0.848586,0.836667,0.816687,0.804483,0.564099,0.567242,0.564099,0.558946,0.0004
3,0.9676,0.841784,0.912747,0.89465,0.885554,0.880414,0.505118,0.521329,0.505118,0.389028,0.0016
4,0.9031,0.815248,0.945101,0.936641,0.937715,0.935977,0.512064,0.523591,0.512064,0.444169,0.0036
5,0.9274,0.775219,0.949366,0.946631,0.945998,0.945253,0.518279,0.531044,0.518279,0.463086,0.0064
6,0.8297,0.769423,0.970144,0.964831,0.963177,0.963317,0.541921,0.542668,0.541921,0.539906,0.01
7,0.8504,0.744854,0.972886,0.968406,0.967606,0.967641,0.536498,0.536799,0.536498,0.535546,0.0144
8,0.9342,0.76558,0.977334,0.970334,0.97508,0.972468,0.522423,0.522777,0.522423,0.520556,0.0196
9,0.8975,0.856122,0.986047,0.983775,0.984098,0.983912,0.498842,0.484721,0.498842,0.34825,0.0256
10,0.9193,1.322714,0.985559,0.983303,0.983267,0.983214,0.5,0.25,0.5,0.333333,0.0324


Fold 5 Metrics: {'eval_loss': 0.707304835319519, 'eval_accuracy_branch1': 0.9921398976358762, 'eval_precision_branch1': 0.9904138441270113, 'eval_recall_branch1': 0.9904026464279795, 'eval_f1_branch1': 0.9904028608186117, 'eval_accuracy_branch2': 0.48555934681940044, 'eval_precision_branch2': 0.47890445185007535, 'eval_recall_branch2': 0.48555934681940044, 'eval_f1_branch2': 0.44151370752355973, 'eval_lambda': 0.9603999999999999, 'eval_runtime': 22.5846, 'eval_samples_per_second': 726.689, 'eval_steps_per_second': 90.858, 'epoch': 50.0}
