In [1]:
# Import packages and setup gpu configuration.
# This code block shouldn't need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import random
from tqdm.auto import tqdm
import webdataset as wds
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils.flat_models import *
import h5py
from typing import List, Dict, Any, Tuple
from sklearn.preprocessing import StandardScaler
import argparse

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
torch.backends.cudnn.benchmark = True

# Run configuration (model_name, target, lr, ...) is parsed with argparse further down;
# the values below are only constants/defaults needed before that happens.

# Fall back to CPU when no GPU is visible instead of hard-coding 'cuda'
# (matches the device selection used when the model is created).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_frames = 16   # passed as frames= to create_hcp_flat in the commented-out HDF5 export code
gsr = False       # passed as gsr= to create_hcp_flat in the commented-out HDF5 export code
batch_size = 128  # interpolated into jupyter_args; replaced by --batch_size once args are parsed

print("PID of this process =",os.getpid())

PID of this process = 502574


In [2]:
# if running this interactively, can specify jupyter_args here for argparser to use
# (outside a notebook the real command-line arguments are parsed instead, in the next cell)
if utils.is_interactive():
    model_name = "HCPflat_large_gsrFalse_"
    print("model_name:", model_name)
    outdir = os.path.abspath(f'checkpoints/{model_name}')
    # batch_size must already be defined (first cell); it is interpolated into the args below
    jupyter_args = f"--hcp_flat_path=/weka/proj-medarc/shared/HCP-Flat \
                    --target=sex \
                    --model_name={model_name} \
                    --batch_size={batch_size} \
                    --max_lr=1e-3 --num_epochs=40 --no-save_ckpt --no-wandb_log --num_workers=10 \
                    --weight_decay=1e-5 \
                    --parquet_folder=epoch99 \
                    --outdir={outdir} \
                    --global_pooling"
    # --multisubject_ckpt=../train_logs/multisubject_subj01_1024_24bs_nolow
    # recommended max_lr on trial_type 9e-3, on age 1e-3, on sex 1e-3
    print(jupyter_args)
    # whitespace split turns the string into an argv-style list for parser.parse_args
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: HCPflat_large_gsrFalse_
--hcp_flat_path=/weka/proj-medarc/shared/HCP-Flat                     --target=sex                     --model_name=HCPflat_large_gsrFalse_                     --batch_size=128                     --max_lr=1e-3 --num_epochs=40 --no-save_ckpt --no-wandb_log --num_workers=10                     --weight_decay=1e-5                     --parquet_folder=epoch99                     --outdir=/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_                     --global_pooling


In [3]:
# Command-line configuration for the linear-probe run. In a notebook, the arguments come
# from `jupyter_args`; otherwise from sys.argv. Every parsed option is then exposed as a
# module-level global of the same name (e.g. `target`, `max_lr`).
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="HCPflat_large_gsrFalse_",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--hcp_flat_path", type=str, default=os.getcwd(),
    help="Root of the HCP-Flat dataset; subjects_data_restricted.csv is read from here for age/sex targets",
)
parser.add_argument(
    "--batch_size", type=int, default=128,
    help="Batch size for the train and test DataLoaders",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--num_epochs",type=int,default=150,
    help="number of epochs of training",
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
    help="per-step LR schedule: 'cycle' (OneCycleLR) or 'linear' (LinearLR)",
)
parser.add_argument(
    "--save_ckpt",action=argparse.BooleanOptionalAction,default=True,
    help="save the trained probe weights and config after training",
)
parser.add_argument(
    "--seed",type=int,default=42,
    help="random seed passed to utils.seed_everything",
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
    help="AdamW learning rate and OneCycleLR peak learning rate",
)
parser.add_argument(
    "--target",type=str,default='trial_type',choices=['trial_type','sex','age'],
    help="prediction target for the linear probe",
)
parser.add_argument(
    "--num_workers",type=int,default=10,
    help="data-loading workers (NOTE: the feature DataLoaders currently use num_workers=0)",
)
parser.add_argument(
    "--weight_decay",type=float,default=1e-5,
    help="AdamW weight decay",
)
parser.add_argument(
    "--parquet_folder",type=str,default='epoch99',
    help="sub-folder holding the precomputed HCP train/test parquet features",
)
parser.add_argument(
    "--outdir",type=str,default='./checkpoints/HCPflat_large_gsrFalse_/',
    help="Path where the precomputed features are located"
)
parser.add_argument(
    "--global_pooling",action=argparse.BooleanOptionalAction,default=True,
    help="use features extracted with global pooling (selects the '<outdir>_gp<True|False>' folder)",
)

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

print(f"------ ARGS ------- \n {args}")

# create global variables without the args prefix
globals().update(vars(args))

# seed all random functions
utils.seed_everything(seed)


------ ARGS ------- 
 Namespace(model_name='HCPflat_large_gsrFalse_', hcp_flat_path='/weka/proj-medarc/shared/HCP-Flat', batch_size=128, wandb_log=False, num_epochs=40, lr_scheduler_type='cycle', save_ckpt=False, seed=42, max_lr=0.001, target='sex', num_workers=10, weight_decay=1e-05, parquet_folder='epoch99', outdir='/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_', global_pooling=True)


In [4]:
#### UNCOMMENT THIS TO SAVE THE HCP-FLAT IN HDF5 FORMAT


# from torch.utils.data import default_collate
# from mae_utils.flat import load_hcp_flat_mask
# from mae_utils.flat import create_hcp_flat
# from mae_utils.flat import batch_unmask
# import mae_utils.visualize as vis


# batch_size = 26
# print(f"changed batch_size to {batch_size}")

# ## Test ##
# datasets_to_include = "HCP"
# assert "HCP" in datasets_to_include
# test_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=num_frames, shuffle=False, gsr=gsr, sub_list = 'test')
# test_dl = wds.WebLoader(
#     test_dataset.batched(batch_size, partial=False, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# ## Train ##
# assert "HCP" in datasets_to_include
# train_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=num_frames, shuffle=False, gsr=gsr, sub_list = 'train')
# train_dl = wds.WebLoader(
#     train_dataset.batched(batch_size, partial=False, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# def flatten_meta(meta_dict):
#     """
#     Flatten the meta dictionary by:
#     - Replacing single-item lists with the item itself.
#     - Converting tensors to scalar numbers.
#     """
#     flattened = {}
#     for key, value in meta_dict.items():
#         if isinstance(value, list):
#             if len(value) == 1:
#                 flattened[key] = value[0]  # Replace list with its single item
#             else:
#                 flattened[key] = value  # Keep as is if multiple items
#         elif isinstance(value, torch.Tensor):
#             # Convert tensor to scalar
#             if value.numel() == 1:
#                 flattened[key] = value.item()
#             else:
#                 flattened[key] = value.tolist()  # Convert multi-element tensor to list
#         else:
#             flattened[key] = value  # Keep the value as is
#     return flattened

# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File('train_hcp_raw_flatmaps.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(train_dl), total = 120000):
#         images = batch['image'][0]
#         meta = batch['meta']
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save('metadata_train_HCP_raw_flatmaps.npy', meta_array)  # train split metadata (matches metadata_train_* loader)


# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File('test_hcp_raw_flatmaps.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(test_dl), total = 12000):
#         images = batch['image'][0]
#         meta = batch['meta']
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save('metadata_test_HCP_raw_flatmaps.npy', meta_array)  # test split metadata (matches metadata_test_* loader)

### Data

In [5]:
# Precomputed features are stored as <outdir>_gp<global_pooling>/<parquet_folder>/HCP/{train,test}.parquet.
# (Raw-flatmap alternative: open the train/test HDF5 files and metadata .npy files written by
#  the commented-out export cell instead.)
features_dir = f"{outdir}_gp{str(global_pooling)}/{parquet_folder}/HCP"
train_features, test_features = (
    pd.read_parquet(f"{features_dir}/{split_name}.parquet")
    for split_name in ("train", "test")
)

### Create the dataloader

In [7]:
from torch.utils.data import Dataset, DataLoader

class HCPFlatDataset(Dataset):
    """Map-style dataset over a DataFrame of precomputed HCP features.

    Each row must provide:
      * ``feature``    -- array-like feature values (possibly an object array of arrays,
                          e.g. from a nested parquet list column),
      * ``sub``        -- list-like whose first element is the subject id,
      * ``trial_type`` -- list-like whose first element is the trial-type label.

    ``__getitem__`` returns ``(feature float32 tensor, trial_type, sub)``.
    """

    def __init__(self, parquet_data):
        self.parquet_data = parquet_data

    def __len__(self):
        return len(self.parquet_data)

    @staticmethod
    def _to_float_array(value):
        """Convert a (possibly nested, object-dtype) array-like into a dense float32 ndarray."""
        arr = np.asarray(value)
        if arr.dtype == object and arr.size > 0:
            # torch.Tensor() on a list/array of ndarrays takes a very slow path and warns;
            # stack the inner arrays explicitly instead.
            arr = np.stack([HCPFlatDataset._to_float_array(v) for v in arr])
        return arr.astype(np.float32, copy=False)

    def __getitem__(self, idx):
        row = self.parquet_data.iloc[idx]  # one positional lookup instead of three
        c_feature = torch.tensor(self._to_float_array(row['feature']))
        c_sub = row['sub'][0]
        c_trial_type = row['trial_type'][0]
        return c_feature, c_trial_type, c_sub

# Wrap the feature tables; samples are read in the main process (num_workers=0).
train_dataset = HCPFlatDataset(train_features)
train_dl = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,   # new sample order every epoch
    num_workers=0,
)

test_dataset = HCPFlatDataset(test_features)
test_dl = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # fixed evaluation order
    num_workers=0,
)


In [8]:
# for a in test_dl:
#     print(a[0], a[1], a[2])
#     break

### Load subject information

In [9]:
# open the file containing subject information
if target == "age" or target == "sex":
    subject_information_HCP_path = os.path.join(hcp_flat_path, "subjects_data_restricted.csv")
    try:
        subject_information_HCP = pd.read_csv(subject_information_HCP_path)
    except:
        try:
            subject_information_HCP = pd.read_csv('./unrestricted_clane9_4_23_2024_13_28_14.csv')   
        except:
            assert False, "Subject information file not found"

    ###### This is for unrestricted
    # age_related_columns = [
    #     'Age', 'PicSeq_AgeAdj', 'CardSort_AgeAdj', 'Flanker_AgeAdj',
    #     'ReadEng_AgeAdj', 'PicVocab_AgeAdj', 'ProcSpeed_AgeAdj',
    #     'CogFluidComp_AgeAdj', 'CogEarlyComp_AgeAdj', 'CogTotalComp_AgeAdj',
    #     'CogCrystalComp_AgeAdj', 'Endurance_AgeAdj', 'Dexterity_AgeAdj',
    #     'Strength_AgeAdj', 'Odor_AgeAdj', 'Taste_AgeAdj'
    # ]
    
    # sex_related_columns = [
    #     'Gender'
    # ]

    ###### This is for restricted
    gender_related_columns = [
        'Gender'
    ]

    age_related_columns = [
        'Age_in_Yrs',
        'Menstrual_AgeBegan',
        'Menstrual_AgeIrreg',
        'Menstrual_AgeStop',
        'SSAGA_Alc_Age_1st_Use',
        'SSAGA_TB_Age_1st_Cig',
        'SSAGA_Mj_Age_1st_Use',
        'Endurance_AgeAdj',
        'Dexterity_AgeAdj',
        'Strength_AgeAdj',
        'PicSeq_AgeAdj',
        'CardSort_AgeAdj',
        'Flanker_AgeAdj',
        'ReadEng_AgeAdj',
        'PicVocab_AgeAdj',
        'ProcSpeed_AgeAdj',
        'Odor_AgeAdj',
        'Taste_AgeAdj'
    ]

    # # show the first few rows of the subject information
    # subject_information_HCP[age_related_columns + sex_related_columns].head()

    # Handle missing values (e.g., impute with mean)
    mean_age = subject_information_HCP['Age_in_Yrs'].mean()
    
    # Initialize the scaler
    scaler = StandardScaler()
    
    # Perform z-score normalization
    subject_information_HCP['Age_in_Yrs_z'] = scaler.fit_transform(subject_information_HCP[['Age_in_Yrs']])


    
def get_label_unrestricted(subject_id: List[str], target: str, method_for_age: str = 'mean',
                           subject_info: pd.DataFrame = None) -> List:
    """
    Get the label for the given subject ids and target from the unrestricted HCP table.

    For sex 0 is F and 1 is M.

    Args:
        subject_id: subject ids; each must be convertible with ``int()`` (e.g. "100307").
        target: "age" or "sex".
        method_for_age: how an age range such as "22-25" is collapsed to one number:
            "mean", "min" or "max". Open-ended entries such as "36+" map to their lower bound.
        subject_info: table with a "Subject" column plus "Age" / "Gender" columns.
            Defaults to the module-level ``subject_information_HCP``.

    Returns:
        ``np.ndarray`` of ages for "age"; ``list`` of 0/1 ints for "sex".

    Raises:
        ValueError: unknown target or age method, or a subject missing from the table.
    """
    if subject_info is None:
        subject_info = subject_information_HCP
    if target not in ("age", "sex"):
        raise ValueError(f"Target {target} not recognized")
    reducers = {'mean': np.mean, 'min': np.min, 'max': np.max}
    if target == "age" and method_for_age not in reducers:
        raise ValueError(f"Method {method_for_age} not recognized")

    def _lookup(subject, column):
        # First matching row's value; raise (not assert) so the check survives `python -O`.
        values = subject_info[subject_info['Subject'] == subject][column].values
        if len(values) == 0:
            raise ValueError(f"Subject {subject} not found in subject information file")
        if len(values) > 1:
            print(f"Warning: Multiple entries for subject {subject}")
        return values[0]

    # convert to list of ints
    subject_id = [int(x) for x in subject_id]

    if target == "age":
        age_array = []
        for subject in subject_id:
            bounds = _lookup(subject, 'Age').split('-')
            if len(bounds) < 2:
                # single value, possibly open-ended like "36+"
                age_array.append(int(bounds[0].split('+')[0]))
            else:
                age_array.append(reducers[method_for_age]([int(x) for x in bounds]))
        return np.array(age_array)

    return [int(_lookup(subject, 'Gender') == 'M') for subject in subject_id]

def get_label_restricted(subject_id: List[str], target: str, normalized: bool = True,
                         subject_info: pd.DataFrame = None) -> List:
    """
    Get the label for the given subject ids and target from the restricted HCP table.

    For sex 0 is F and 1 is M.

    Args:
        subject_id: subject ids; each must be convertible with ``int()`` (e.g. "100307").
        target: "age" or "sex".
        normalized: for "age", return the z-scored 'Age_in_Yrs_z' column instead of the
            raw 'Age_in_Yrs' column.
        subject_info: table with a "Subject" column plus the age / "Gender" columns.
            Defaults to the module-level ``subject_information_HCP``.

    Returns:
        float32 ``np.ndarray`` of ages for "age"; ``list`` of 0/1 ints for "sex".

    Raises:
        ValueError: unknown target, or a subject missing from the table.
    """
    if subject_info is None:
        subject_info = subject_information_HCP
    if target not in ("age", "sex"):
        raise ValueError(f"Target {target} not recognized")

    def _lookup(subject, column):
        # First matching row's value; raise (not assert) so the check survives `python -O`.
        values = subject_info[subject_info['Subject'] == subject][column].values
        if len(values) == 0:
            raise ValueError(f"Subject {subject} not found in subject information file")
        if len(values) > 1:
            print(f"Warning: Multiple entries for subject {subject}")
        return values[0]

    # convert to list of ints
    subject_id = [int(x) for x in subject_id]

    if target == "age":
        age_column = 'Age_in_Yrs_z' if normalized else 'Age_in_Yrs'
        # float32, not int8: z-scored ages are fractional and int8 truncated them
        # (e.g. -0.75 -> 0), destroying the regression target.
        return np.array([_lookup(subject, age_column) for subject in subject_id], dtype=np.float32)

    return [int(_lookup(subject, 'Gender') == 'M') for subject in subject_id]

In [10]:
from sklearn.preprocessing import LabelEncoder

# Trial-type classification: map each condition name to an integer class id.
if target == "trial_type":
    # Condition names (as found in the `trial_type` column) that make up the classes.
    INCLUDE_CONDS = {
        "fear", "neut",
        "math", "story",
        "lf", "lh", "rf", "rh", "t",
        "match", "relation",
        "mental", "rnd",
        "0bk_body", "2bk_body",
        "0bk_faces", "2bk_faces",
        "0bk_places", "2bk_places",
        "0bk_tools", "2bk_tools",
    }

    # Fitting on the sorted names gives a deterministic name -> id mapping.
    label_encoder = LabelEncoder().fit(sorted(INCLUDE_CONDS))

    num_classes = label_encoder.classes_.shape[0]
    print(f"Number of classes: {num_classes}")

In [11]:
# for sample in tqdm(train_dl):
#     x = sample[0]
#     subject_id = sample[1]['sub']

#     # benchmark time
#     start = time.time()
#     y = get_label(subject_id, 'age')
#     end = time.time()
#     print(f"Time taken: {end - start}")
#     print(x.shape, y, subject_id, torch.Tensor(y).shape)
#     break

### Create pytorch model

In [12]:
class LinearClassifier(nn.Module):
    """Linear probe: a single ``nn.Linear`` applied to each flattened input sample.

    Args:
        input_dim: number of values per sample after flattening all non-batch dims.
        num_classes: number of output logits (1 for the age and sex heads).
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Map input of shape (batch, ...) to raw logits of shape (batch, num_classes)."""
        # torch.flatten copies only when needed, so non-contiguous inputs and empty
        # batches work; x.view(x.size(0), -1) raised on both.
        x = torch.flatten(x, start_dim=1)
        return self.linear(x)  # Raw logits

# Infer the probe's input size from one real training batch: LinearClassifier flattens
# each sample, so input_dim is the number of values in a single feature tensor.
sample_batch = next(iter(train_dl))
first_features = sample_batch[0]       # stacked features for the whole batch
sample_image = first_features[0]       # features of the first sample
input_dim = sample_image.view(-1).shape[0]
print(f"Input dimension: {input_dim}")


Input dimension: 1024


In [13]:
# Initialize the model and the loss matching the target type.
if target == "trial_type":
    model = LinearClassifier(input_dim=input_dim, num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()

elif target == "age":
    # regression on a single age value per sample
    model = LinearClassifier(input_dim=input_dim, num_classes=1)
    criterion = nn.MSELoss()

elif target == "sex":
    # binary classification from a single logit
    model = LinearClassifier(input_dim=input_dim, num_classes=1)
    criterion = nn.BCEWithLogitsLoss()

else:
    raise ValueError(f"Unsupported target: {target}")

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# import schedulefree
# optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)

# The scheduler is stepped once per batch, so its horizon is epochs * batches-per-epoch.
num_iterations_per_epoch = math.ceil(train_features.shape[0]/batch_size)
total_steps = num_epochs * num_iterations_per_epoch

if lr_scheduler_type == 'linear':
    # NOTE(review): LinearLR's defaults (start_factor=1/3, end_factor=1.0) ramp the LR *up*
    # from max_lr/3 to max_lr over training; confirm that is the intended schedule.
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=total_steps,
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    print("total_steps", total_steps)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1,
        # ~2 epochs of warm-up; capped at 1.0 because OneCycleLR rejects pct_start > 1
        # (num_epochs=1 previously raised ValueError).
        pct_start=min(2 / num_epochs, 1.0),
    )


total_steps 34800


### Wandb logging

In [14]:
import wandb
import uuid

# Unique suffix so repeated runs with identical settings get distinct W&B ids.
myuuid = uuid.uuid4()
if utils.is_interactive():
    print("Running in interactive notebook. Disabling W&B and ckpt saving.")
    wandb_log = False
    save_ckpt = False

# `model_suffix` is not one of the argparse options, so enabling W&B used to raise
# NameError; fall back to model_name unless something defined it already.
if "model_suffix" not in globals():
    model_suffix = model_name

# Built unconditionally: the checkpoint-saving code reads wandb_config and random_id
# even when W&B logging is off.
wandb_config = {
    "model_name": f"HCPflat_raw_{target}",
    "batch_size": batch_size,
    "weight_decay": weight_decay,
    "num_epochs": num_epochs,
    "seed": seed,
    "lr_scheduler_type": lr_scheduler_type,
    "save_ckpt": save_ckpt,
    "max_lr": max_lr,
    "target": target,
    "num_workers": num_workers,
}
random_id = random.randint(0, 100000)

if wandb_log:
    wandb_project = 'fMRI-foundation-model'
    print("wandb_config:\n", wandb_config)
    wandb_id = "HCPflat_raw" + f"_{model_suffix}_{target}_{myuuid}"
    print("wandb_id:", wandb_id)
    wandb.init(
        id=wandb_id,
        project=wandb_project,
        name="HCPflat_raw"+ f"_{model_suffix}_{target}",
        config=wandb_config,
        resume="allow",
    )

Running in interactive notebook. Disabling W&B and ckpt saving.


### Training loop

In [15]:
# Train the linear probe for num_epochs. After each training pass, evaluate on test_dl,
# print epoch-level metrics, and (if wandb_log) log them to W&B.
for epoch in range(num_epochs):
    # Per-epoch accumulators: sums over samples, divided by total_train at epoch end.
    running_train_loss = 0.0
    correct_train = 0
    mse_age_train = 0.0
    total_train = 0
    step = 0

    # Training Phase
    model.train()
    optimizer.zero_grad()  # Reset gradients before starting training (also reset per batch below)

    for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        optimizer.zero_grad()
        # batch = (features, trial_type labels, subject ids), as returned by HCPFlatDataset;
        # the model flattens every non-batch dimension of the features.
        images = batch[0].to(device).float()

        # Prepare labels based on target type
        if target == "trial_type":
            labels = batch[1]  # List of labels
            labels = label_encoder.transform(labels)
            labels = torch.tensor(labels, dtype=torch.long).to(device)  # Shape: [batch_size]
        elif target == "age":
            labels = get_label_restricted(batch[2], 'age')
            labels = torch.tensor(labels, dtype=torch.float).to(device)  # Shape: [batch_size]
        elif target == "sex":
            labels = get_label_restricted(batch[2], 'sex')
            labels = torch.tensor(labels, dtype=torch.float).to(device)  # Shape: [batch_size]
        # labels = labels.unsqueeze(1)
        # Forward pass
        outputs = model(images)  # Output shape depends on the target

        # Compute loss
        # squeeze() drops the size-1 logit dim for the age/sex heads; for trial_type the
        # (batch, num_classes) logits are unchanged unless the batch has a single sample.
        if target in ["trial_type", "sex"]:
            # For classification, ensure outputs are logits
            loss = criterion(outputs.squeeze(), labels.squeeze())
        elif target == "age":
            # For regression, ensure outputs are single values
            loss = criterion(outputs.squeeze(), labels.squeeze())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate loss (loss is a batch mean, so weight by batch size)
        running_train_loss += loss.item() * images.size(0)

        # Calculate and accumulate metrics
        if target == "trial_type":
            _, predicted = torch.max(outputs, 1)
            correct_train += (predicted == labels).sum().item()
        elif target == "age":
            mse_age_train += (torch.sum((outputs.squeeze() - labels) ** 2).item())
        elif target == "sex":
            threshold = 0.5  # decision threshold on sigmoid(logit)
            predicted = (torch.sigmoid(outputs) > threshold).float().squeeze()
            correct_train += (predicted == labels).sum().item()

        total_train += labels.size(0)
        step += 1

        # Print intermediate metrics every 100 steps
        if step % 100 == 0:
            if target in ["trial_type", "sex"]:
                current_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
                print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training Accuracy: {current_accuracy:.2f}%")
            elif target == "age":
                current_mse = mse_age_train / total_train if total_train > 0 else 0.0
                print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training MSE: {current_mse:.4f}")

        # Scheduler steps once per batch (its total_steps was sized as epochs * batches).
        # lr_scheduler_type is always 'cycle' or 'linear' (argparse choices), so this always runs.
        if lr_scheduler_type is not None:
                lr_scheduler.step()

    # Calculate epoch-level metrics
    epoch_train_loss = running_train_loss / total_train if total_train > 0 else 0.0

    if target in ["trial_type", "sex"]:
        train_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
    elif target == "age":
        train_mse = mse_age_train / total_train if total_train > 0 else 0.0

    # Validation Phase
    model.eval()
    running_val_loss = 0.0
    correct_val = 0
    mse_age_val = 0.0
    total_val = 0

    with torch.no_grad():
        for batch in tqdm(test_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            images = batch[0].to(device).float()  # Removed unsqueeze(1) unless specifically needed

            # Prepare labels based on target type (same as in the training phase)
            if target == "trial_type":
                labels = batch[1]  # List of labels
                labels = label_encoder.transform(labels)
                labels = torch.tensor(labels, dtype=torch.long).to(device)  # Shape: [batch_size]
            elif target == "age":
                labels = get_label_restricted(batch[2], 'age')
                labels = torch.tensor(labels, dtype=torch.float).to(device)  # Shape: [batch_size]
            elif target == "sex":
                labels = get_label_restricted(batch[2], 'sex')
                labels = torch.tensor(labels, dtype=torch.float).to(device)  # Shape: [batch_size]

            # labels = labels.unsqueeze(1)
            
            # Forward pass
            outputs = model(images)

            # Compute loss
            if target in ["trial_type", "sex"]:
                loss = criterion(outputs.squeeze(), labels.squeeze())
            elif target == "age":
                loss = criterion(outputs.squeeze(), labels.squeeze())

            # Accumulate loss
            running_val_loss += loss.item() * images.size(0)

            # Calculate and accumulate metrics
            if target == "trial_type":
                _, predicted = torch.max(outputs, 1)
                correct_val += (predicted == labels).sum().item()
            elif target == "age":
                mse_age_val += (torch.sum((outputs.squeeze() - labels) ** 2).item()) 
            elif target == "sex":
                threshold = 0.5
                predicted = (torch.sigmoid(outputs) > threshold).float().squeeze()
                correct_val += (predicted == labels).sum().item()

            total_val += labels.size(0)

    # Calculate epoch-level validation metrics
    epoch_val_loss = running_val_loss / total_val if total_val > 0 else 0.0

    if target in ["trial_type", "sex"]:
        val_accuracy = 100 * correct_val / total_val if total_val > 0 else 0.0
    elif target == "age":
        val_mse = mse_age_val / total_val if total_val > 0 else 0.0

    # Print epoch-level metrics
    if target in ["trial_type", "sex"]:
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"- Training Loss: {epoch_train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}% "
              f"- Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    elif target == "age":
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"- Training Loss: {epoch_train_loss:.4f}, Training MSE: {train_mse:.4f} "
              f"- Validation Loss: {epoch_val_loss:.4f}, Validation MSE: {val_mse:.4f}")

    # Log metrics with wandb
    if wandb_log:
        log_dict = {
            "epoch_train_loss": epoch_train_loss,
            "epoch_val_loss": epoch_val_loss,
        }
        if target in ["trial_type", "sex"]:
            log_dict.update({
                f"train_accuracy_{target}": train_accuracy,
                f"val_accuracy_{target}": val_accuracy,
            })
        elif target == "age":
            log_dict.update({
                f"train_mse_{target}": train_mse,
                f"val_mse_{target}": val_mse,
            })
        wandb.log(log_dict)

# Save checkpoint if required
if save_ckpt:
    # `model_suffix`, `random_id` and `wandb_config` come from earlier cells and may be
    # undefined (e.g. when W&B logging was disabled); fall back instead of raising NameError.
    ckpt_suffix = globals().get("model_suffix", model_name)
    ckpt_id = random_id if "random_id" in globals() else random.randint(0, 100000)
    ckpt_config = wandb_config if "wandb_config" in globals() else dict(vars(args))

    # Separate name so the feature directory stored in `outdir` is not overwritten.
    ckpt_dir = os.path.abspath(f'checkpoints/HCPflat_raw_{ckpt_suffix}_{target}_{ckpt_id}')
    os.makedirs(ckpt_dir, exist_ok=True)
    print("Saving checkpoint to:", ckpt_dir)
    # Save model state
    torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pth"))
    # Save configuration
    with open(os.path.join(ckpt_dir, "config.yaml"), 'w') as f:
        yaml.dump(ckpt_config, f)
    print(f"Model and config saved to {ckpt_dir}")


Epoch 1/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.6415 - Training Accuracy: 58.36%
Step [200/870] - Training Loss: 0.6210 - Training Accuracy: 64.36%
Step [300/870] - Training Loss: 0.5219 - Training Accuracy: 69.53%
Step [400/870] - Training Loss: 0.4441 - Training Accuracy: 73.06%
Step [500/870] - Training Loss: 0.4192 - Training Accuracy: 75.42%
Step [600/870] - Training Loss: 0.4268 - Training Accuracy: 77.22%
Step [700/870] - Training Loss: 0.3428 - Training Accuracy: 78.51%
Step [800/870] - Training Loss: 0.4011 - Training Accuracy: 79.49%


Epoch 1/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

  c_feature = torch.Tensor(self.parquet_data.iloc[idx]['feature'])


Epoch [1/40] - Training Loss: 0.4790, Training Accuracy: 80.04% - Validation Loss: 0.3822, Validation Accuracy: 82.91%


Epoch 2/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3535 - Training Accuracy: 86.66%
Step [200/870] - Training Loss: 0.2562 - Training Accuracy: 86.59%
Step [300/870] - Training Loss: 0.3247 - Training Accuracy: 86.82%
Step [400/870] - Training Loss: 0.2158 - Training Accuracy: 87.11%
Step [500/870] - Training Loss: 0.3321 - Training Accuracy: 87.26%
Step [600/870] - Training Loss: 0.3279 - Training Accuracy: 87.47%
Step [700/870] - Training Loss: 0.2726 - Training Accuracy: 87.51%
Step [800/870] - Training Loss: 0.2795 - Training Accuracy: 87.54%


Epoch 2/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [2/40] - Training Loss: 0.2960, Training Accuracy: 87.58% - Validation Loss: 0.3657, Validation Accuracy: 83.86%


Epoch 3/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3745 - Training Accuracy: 88.53%
Step [200/870] - Training Loss: 0.2819 - Training Accuracy: 88.42%
Step [300/870] - Training Loss: 0.2079 - Training Accuracy: 88.27%
Step [400/870] - Training Loss: 0.3103 - Training Accuracy: 88.25%
Step [500/870] - Training Loss: 0.2557 - Training Accuracy: 88.39%
Step [600/870] - Training Loss: 0.3672 - Training Accuracy: 88.33%
Step [700/870] - Training Loss: 0.3294 - Training Accuracy: 88.32%
Step [800/870] - Training Loss: 0.2537 - Training Accuracy: 88.36%


Epoch 3/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [3/40] - Training Loss: 0.2777, Training Accuracy: 88.34% - Validation Loss: 0.3577, Validation Accuracy: 84.62%


Epoch 4/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2606 - Training Accuracy: 88.38%
Step [200/870] - Training Loss: 0.2331 - Training Accuracy: 88.33%
Step [300/870] - Training Loss: 0.3483 - Training Accuracy: 88.52%
Step [400/870] - Training Loss: 0.3034 - Training Accuracy: 88.48%
Step [500/870] - Training Loss: 0.2563 - Training Accuracy: 88.48%
Step [600/870] - Training Loss: 0.2617 - Training Accuracy: 88.51%
Step [700/870] - Training Loss: 0.1840 - Training Accuracy: 88.48%
Step [800/870] - Training Loss: 0.3122 - Training Accuracy: 88.51%


Epoch 4/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [4/40] - Training Loss: 0.2718, Training Accuracy: 88.52% - Validation Loss: 0.3575, Validation Accuracy: 84.49%


Epoch 5/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2604 - Training Accuracy: 88.16%
Step [200/870] - Training Loss: 0.2991 - Training Accuracy: 88.25%
Step [300/870] - Training Loss: 0.2978 - Training Accuracy: 88.40%
Step [400/870] - Training Loss: 0.2350 - Training Accuracy: 88.49%
Step [500/870] - Training Loss: 0.2776 - Training Accuracy: 88.48%
Step [600/870] - Training Loss: 0.1749 - Training Accuracy: 88.62%
Step [700/870] - Training Loss: 0.2756 - Training Accuracy: 88.72%
Step [800/870] - Training Loss: 0.2740 - Training Accuracy: 88.79%


Epoch 5/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [5/40] - Training Loss: 0.2678, Training Accuracy: 88.77% - Validation Loss: 0.3552, Validation Accuracy: 84.50%


Epoch 6/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3825 - Training Accuracy: 88.21%
Step [200/870] - Training Loss: 0.2240 - Training Accuracy: 88.52%
Step [300/870] - Training Loss: 0.2702 - Training Accuracy: 88.59%
Step [400/870] - Training Loss: 0.1598 - Training Accuracy: 88.69%
Step [500/870] - Training Loss: 0.2580 - Training Accuracy: 88.68%
Step [600/870] - Training Loss: 0.2901 - Training Accuracy: 88.86%
Step [700/870] - Training Loss: 0.2212 - Training Accuracy: 88.87%
Step [800/870] - Training Loss: 0.2653 - Training Accuracy: 88.86%


Epoch 6/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [6/40] - Training Loss: 0.2656, Training Accuracy: 88.85% - Validation Loss: 0.3563, Validation Accuracy: 84.66%


Epoch 7/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3620 - Training Accuracy: 88.77%
Step [200/870] - Training Loss: 0.3050 - Training Accuracy: 88.86%
Step [300/870] - Training Loss: 0.2098 - Training Accuracy: 88.96%
Step [400/870] - Training Loss: 0.2813 - Training Accuracy: 88.90%
Step [500/870] - Training Loss: 0.2650 - Training Accuracy: 88.90%
Step [600/870] - Training Loss: 0.2733 - Training Accuracy: 88.92%
Step [700/870] - Training Loss: 0.2481 - Training Accuracy: 88.88%
Step [800/870] - Training Loss: 0.2111 - Training Accuracy: 88.87%


Epoch 7/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [7/40] - Training Loss: 0.2638, Training Accuracy: 88.89% - Validation Loss: 0.3536, Validation Accuracy: 84.89%


Epoch 8/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3048 - Training Accuracy: 88.70%
Step [200/870] - Training Loss: 0.2804 - Training Accuracy: 88.76%
Step [300/870] - Training Loss: 0.2652 - Training Accuracy: 88.78%
Step [400/870] - Training Loss: 0.1705 - Training Accuracy: 88.80%
Step [500/870] - Training Loss: 0.2534 - Training Accuracy: 88.88%
Step [600/870] - Training Loss: 0.2580 - Training Accuracy: 88.87%
Step [700/870] - Training Loss: 0.2868 - Training Accuracy: 88.88%
Step [800/870] - Training Loss: 0.2574 - Training Accuracy: 88.92%


Epoch 8/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [8/40] - Training Loss: 0.2623, Training Accuracy: 88.92% - Validation Loss: 0.3591, Validation Accuracy: 84.32%


Epoch 9/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2316 - Training Accuracy: 89.24%
Step [200/870] - Training Loss: 0.1661 - Training Accuracy: 89.19%
Step [300/870] - Training Loss: 0.3534 - Training Accuracy: 89.19%
Step [400/870] - Training Loss: 0.2333 - Training Accuracy: 89.08%
Step [500/870] - Training Loss: 0.2240 - Training Accuracy: 89.08%
Step [600/870] - Training Loss: 0.2304 - Training Accuracy: 89.09%
Step [700/870] - Training Loss: 0.2972 - Training Accuracy: 89.05%
Step [800/870] - Training Loss: 0.2120 - Training Accuracy: 89.05%


Epoch 9/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [9/40] - Training Loss: 0.2612, Training Accuracy: 89.06% - Validation Loss: 0.3579, Validation Accuracy: 84.53%


Epoch 10/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2834 - Training Accuracy: 89.11%
Step [200/870] - Training Loss: 0.2386 - Training Accuracy: 88.94%
Step [300/870] - Training Loss: 0.2149 - Training Accuracy: 88.91%
Step [400/870] - Training Loss: 0.2993 - Training Accuracy: 89.01%
Step [500/870] - Training Loss: 0.2148 - Training Accuracy: 89.08%
Step [600/870] - Training Loss: 0.3640 - Training Accuracy: 89.04%
Step [700/870] - Training Loss: 0.2907 - Training Accuracy: 88.99%
Step [800/870] - Training Loss: 0.2183 - Training Accuracy: 89.04%


Epoch 10/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [10/40] - Training Loss: 0.2597, Training Accuracy: 89.09% - Validation Loss: 0.3606, Validation Accuracy: 84.40%


Epoch 11/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2255 - Training Accuracy: 89.35%
Step [200/870] - Training Loss: 0.2876 - Training Accuracy: 89.49%
Step [300/870] - Training Loss: 0.2145 - Training Accuracy: 89.27%
Step [400/870] - Training Loss: 0.3012 - Training Accuracy: 89.18%
Step [500/870] - Training Loss: 0.2871 - Training Accuracy: 89.08%
Step [600/870] - Training Loss: 0.3310 - Training Accuracy: 89.08%
Step [700/870] - Training Loss: 0.3019 - Training Accuracy: 89.08%
Step [800/870] - Training Loss: 0.2710 - Training Accuracy: 89.12%


Epoch 11/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [11/40] - Training Loss: 0.2589, Training Accuracy: 89.11% - Validation Loss: 0.3552, Validation Accuracy: 84.99%


Epoch 12/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2265 - Training Accuracy: 89.41%
Step [200/870] - Training Loss: 0.2402 - Training Accuracy: 89.35%
Step [300/870] - Training Loss: 0.3017 - Training Accuracy: 89.22%
Step [400/870] - Training Loss: 0.2139 - Training Accuracy: 89.18%
Step [500/870] - Training Loss: 0.1920 - Training Accuracy: 89.10%
Step [600/870] - Training Loss: 0.2696 - Training Accuracy: 89.17%
Step [700/870] - Training Loss: 0.3290 - Training Accuracy: 89.16%
Step [800/870] - Training Loss: 0.3251 - Training Accuracy: 89.13%


Epoch 12/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [12/40] - Training Loss: 0.2579, Training Accuracy: 89.17% - Validation Loss: 0.3552, Validation Accuracy: 84.68%


Epoch 13/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2522 - Training Accuracy: 89.23%
Step [200/870] - Training Loss: 0.2105 - Training Accuracy: 89.21%
Step [300/870] - Training Loss: 0.1728 - Training Accuracy: 89.24%
Step [400/870] - Training Loss: 0.2380 - Training Accuracy: 89.26%
Step [500/870] - Training Loss: 0.2225 - Training Accuracy: 89.28%
Step [600/870] - Training Loss: 0.3091 - Training Accuracy: 89.24%
Step [700/870] - Training Loss: 0.1627 - Training Accuracy: 89.21%
Step [800/870] - Training Loss: 0.1989 - Training Accuracy: 89.19%


Epoch 13/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [13/40] - Training Loss: 0.2574, Training Accuracy: 89.21% - Validation Loss: 0.3519, Validation Accuracy: 85.04%


Epoch 14/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3299 - Training Accuracy: 89.21%
Step [200/870] - Training Loss: 0.1922 - Training Accuracy: 89.27%
Step [300/870] - Training Loss: 0.3108 - Training Accuracy: 89.29%
Step [400/870] - Training Loss: 0.2451 - Training Accuracy: 89.27%
Step [500/870] - Training Loss: 0.3729 - Training Accuracy: 89.28%
Step [600/870] - Training Loss: 0.2076 - Training Accuracy: 89.27%
Step [700/870] - Training Loss: 0.2210 - Training Accuracy: 89.24%
Step [800/870] - Training Loss: 0.2808 - Training Accuracy: 89.25%


Epoch 14/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [14/40] - Training Loss: 0.2564, Training Accuracy: 89.21% - Validation Loss: 0.3668, Validation Accuracy: 84.93%


Epoch 15/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2485 - Training Accuracy: 89.12%
Step [200/870] - Training Loss: 0.2556 - Training Accuracy: 89.33%
Step [300/870] - Training Loss: 0.2820 - Training Accuracy: 89.36%
Step [400/870] - Training Loss: 0.2813 - Training Accuracy: 89.28%
Step [500/870] - Training Loss: 0.2353 - Training Accuracy: 89.26%
Step [600/870] - Training Loss: 0.2444 - Training Accuracy: 89.22%
Step [700/870] - Training Loss: 0.3084 - Training Accuracy: 89.29%
Step [800/870] - Training Loss: 0.2246 - Training Accuracy: 89.30%


Epoch 15/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [15/40] - Training Loss: 0.2557, Training Accuracy: 89.30% - Validation Loss: 0.3574, Validation Accuracy: 84.91%


Epoch 16/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2892 - Training Accuracy: 89.39%
Step [200/870] - Training Loss: 0.2039 - Training Accuracy: 89.29%
Step [300/870] - Training Loss: 0.2716 - Training Accuracy: 89.14%
Step [400/870] - Training Loss: 0.2741 - Training Accuracy: 89.12%
Step [500/870] - Training Loss: 0.2889 - Training Accuracy: 89.27%
Step [600/870] - Training Loss: 0.2469 - Training Accuracy: 89.20%
Step [700/870] - Training Loss: 0.1238 - Training Accuracy: 89.25%
Step [800/870] - Training Loss: 0.3016 - Training Accuracy: 89.28%


Epoch 16/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [16/40] - Training Loss: 0.2549, Training Accuracy: 89.29% - Validation Loss: 0.3528, Validation Accuracy: 85.00%


Epoch 17/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2631 - Training Accuracy: 89.00%
Step [200/870] - Training Loss: 0.2335 - Training Accuracy: 89.01%
Step [300/870] - Training Loss: 0.3367 - Training Accuracy: 89.14%
Step [400/870] - Training Loss: 0.3175 - Training Accuracy: 89.14%
Step [500/870] - Training Loss: 0.3039 - Training Accuracy: 89.28%
Step [600/870] - Training Loss: 0.2610 - Training Accuracy: 89.26%
Step [700/870] - Training Loss: 0.2468 - Training Accuracy: 89.22%
Step [800/870] - Training Loss: 0.2748 - Training Accuracy: 89.27%


Epoch 17/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [17/40] - Training Loss: 0.2548, Training Accuracy: 89.31% - Validation Loss: 0.3528, Validation Accuracy: 85.21%


Epoch 18/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2917 - Training Accuracy: 89.39%
Step [200/870] - Training Loss: 0.3361 - Training Accuracy: 89.29%
Step [300/870] - Training Loss: 0.2642 - Training Accuracy: 89.30%
Step [400/870] - Training Loss: 0.2580 - Training Accuracy: 89.29%
Step [500/870] - Training Loss: 0.2440 - Training Accuracy: 89.28%
Step [600/870] - Training Loss: 0.1966 - Training Accuracy: 89.30%
Step [700/870] - Training Loss: 0.2978 - Training Accuracy: 89.36%
Step [800/870] - Training Loss: 0.2693 - Training Accuracy: 89.37%


Epoch 18/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [18/40] - Training Loss: 0.2542, Training Accuracy: 89.39% - Validation Loss: 0.3552, Validation Accuracy: 85.07%


Epoch 19/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2910 - Training Accuracy: 89.35%
Step [200/870] - Training Loss: 0.2365 - Training Accuracy: 89.39%
Step [300/870] - Training Loss: 0.2297 - Training Accuracy: 89.42%
Step [400/870] - Training Loss: 0.1939 - Training Accuracy: 89.33%
Step [500/870] - Training Loss: 0.2298 - Training Accuracy: 89.42%
Step [600/870] - Training Loss: 0.3033 - Training Accuracy: 89.41%
Step [700/870] - Training Loss: 0.2596 - Training Accuracy: 89.37%
Step [800/870] - Training Loss: 0.2253 - Training Accuracy: 89.37%


Epoch 19/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [19/40] - Training Loss: 0.2543, Training Accuracy: 89.38% - Validation Loss: 0.3544, Validation Accuracy: 85.18%


Epoch 20/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2197 - Training Accuracy: 89.54%
Step [200/870] - Training Loss: 0.2403 - Training Accuracy: 89.40%
Step [300/870] - Training Loss: 0.2352 - Training Accuracy: 89.41%
Step [400/870] - Training Loss: 0.2520 - Training Accuracy: 89.49%
Step [500/870] - Training Loss: 0.2055 - Training Accuracy: 89.47%
Step [600/870] - Training Loss: 0.2871 - Training Accuracy: 89.49%
Step [700/870] - Training Loss: 0.2301 - Training Accuracy: 89.47%
Step [800/870] - Training Loss: 0.2451 - Training Accuracy: 89.48%


Epoch 20/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [20/40] - Training Loss: 0.2536, Training Accuracy: 89.43% - Validation Loss: 0.3527, Validation Accuracy: 85.11%


Epoch 21/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2105 - Training Accuracy: 89.26%
Step [200/870] - Training Loss: 0.2614 - Training Accuracy: 89.43%
Step [300/870] - Training Loss: 0.2916 - Training Accuracy: 89.39%
Step [400/870] - Training Loss: 0.2044 - Training Accuracy: 89.46%
Step [500/870] - Training Loss: 0.1678 - Training Accuracy: 89.46%
Step [600/870] - Training Loss: 0.2583 - Training Accuracy: 89.46%
Step [700/870] - Training Loss: 0.2221 - Training Accuracy: 89.49%
Step [800/870] - Training Loss: 0.2864 - Training Accuracy: 89.47%


Epoch 21/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [21/40] - Training Loss: 0.2524, Training Accuracy: 89.47% - Validation Loss: 0.3524, Validation Accuracy: 84.95%


Epoch 22/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2700 - Training Accuracy: 89.71%
Step [200/870] - Training Loss: 0.2480 - Training Accuracy: 89.67%
Step [300/870] - Training Loss: 0.2623 - Training Accuracy: 89.44%
Step [400/870] - Training Loss: 0.2705 - Training Accuracy: 89.48%
Step [500/870] - Training Loss: 0.1247 - Training Accuracy: 89.49%
Step [600/870] - Training Loss: 0.2641 - Training Accuracy: 89.46%
Step [700/870] - Training Loss: 0.2294 - Training Accuracy: 89.52%
Step [800/870] - Training Loss: 0.1556 - Training Accuracy: 89.47%


Epoch 22/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [22/40] - Training Loss: 0.2524, Training Accuracy: 89.46% - Validation Loss: 0.3527, Validation Accuracy: 84.95%


Epoch 23/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3204 - Training Accuracy: 89.68%
Step [200/870] - Training Loss: 0.2704 - Training Accuracy: 89.57%
Step [300/870] - Training Loss: 0.2039 - Training Accuracy: 89.60%
Step [400/870] - Training Loss: 0.3014 - Training Accuracy: 89.56%
Step [500/870] - Training Loss: 0.3023 - Training Accuracy: 89.53%
Step [600/870] - Training Loss: 0.2841 - Training Accuracy: 89.54%
Step [700/870] - Training Loss: 0.2721 - Training Accuracy: 89.49%
Step [800/870] - Training Loss: 0.1878 - Training Accuracy: 89.49%


Epoch 23/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [23/40] - Training Loss: 0.2518, Training Accuracy: 89.50% - Validation Loss: 0.3628, Validation Accuracy: 84.94%


Epoch 24/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.1735 - Training Accuracy: 89.20%
Step [200/870] - Training Loss: 0.2967 - Training Accuracy: 89.27%
Step [300/870] - Training Loss: 0.2068 - Training Accuracy: 89.40%
Step [400/870] - Training Loss: 0.2170 - Training Accuracy: 89.51%
Step [500/870] - Training Loss: 0.2520 - Training Accuracy: 89.43%
Step [600/870] - Training Loss: 0.3148 - Training Accuracy: 89.43%
Step [700/870] - Training Loss: 0.2447 - Training Accuracy: 89.41%
Step [800/870] - Training Loss: 0.1883 - Training Accuracy: 89.43%


Epoch 24/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [24/40] - Training Loss: 0.2523, Training Accuracy: 89.44% - Validation Loss: 0.3524, Validation Accuracy: 85.09%


Epoch 25/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2999 - Training Accuracy: 89.22%
Step [200/870] - Training Loss: 0.2676 - Training Accuracy: 89.25%
Step [300/870] - Training Loss: 0.2659 - Training Accuracy: 89.37%
Step [400/870] - Training Loss: 0.2125 - Training Accuracy: 89.45%
Step [500/870] - Training Loss: 0.2812 - Training Accuracy: 89.39%
Step [600/870] - Training Loss: 0.2167 - Training Accuracy: 89.40%
Step [700/870] - Training Loss: 0.2210 - Training Accuracy: 89.40%
Step [800/870] - Training Loss: 0.1964 - Training Accuracy: 89.44%


Epoch 25/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [25/40] - Training Loss: 0.2514, Training Accuracy: 89.45% - Validation Loss: 0.3525, Validation Accuracy: 85.05%


Epoch 26/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2772 - Training Accuracy: 89.52%
Step [200/870] - Training Loss: 0.2697 - Training Accuracy: 89.50%
Step [300/870] - Training Loss: 0.2252 - Training Accuracy: 89.51%
Step [400/870] - Training Loss: 0.2866 - Training Accuracy: 89.53%
Step [500/870] - Training Loss: 0.2551 - Training Accuracy: 89.56%
Step [600/870] - Training Loss: 0.2151 - Training Accuracy: 89.57%
Step [700/870] - Training Loss: 0.2836 - Training Accuracy: 89.54%
Step [800/870] - Training Loss: 0.3179 - Training Accuracy: 89.51%


Epoch 26/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [26/40] - Training Loss: 0.2510, Training Accuracy: 89.52% - Validation Loss: 0.3540, Validation Accuracy: 85.03%


Epoch 27/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.1788 - Training Accuracy: 89.95%
Step [200/870] - Training Loss: 0.2768 - Training Accuracy: 89.95%
Step [300/870] - Training Loss: 0.2256 - Training Accuracy: 89.68%
Step [400/870] - Training Loss: 0.2802 - Training Accuracy: 89.61%
Step [500/870] - Training Loss: 0.3605 - Training Accuracy: 89.52%
Step [600/870] - Training Loss: 0.2487 - Training Accuracy: 89.54%
Step [700/870] - Training Loss: 0.2436 - Training Accuracy: 89.48%
Step [800/870] - Training Loss: 0.2434 - Training Accuracy: 89.43%


Epoch 27/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [27/40] - Training Loss: 0.2509, Training Accuracy: 89.48% - Validation Loss: 0.3535, Validation Accuracy: 85.13%


Epoch 28/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.1924 - Training Accuracy: 89.31%
Step [200/870] - Training Loss: 0.2306 - Training Accuracy: 89.57%
Step [300/870] - Training Loss: 0.2965 - Training Accuracy: 89.56%
Step [400/870] - Training Loss: 0.1997 - Training Accuracy: 89.53%
Step [500/870] - Training Loss: 0.2739 - Training Accuracy: 89.50%
Step [600/870] - Training Loss: 0.3258 - Training Accuracy: 89.51%
Step [700/870] - Training Loss: 0.2402 - Training Accuracy: 89.48%
Step [800/870] - Training Loss: 0.2738 - Training Accuracy: 89.51%


Epoch 28/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [28/40] - Training Loss: 0.2504, Training Accuracy: 89.53% - Validation Loss: 0.3517, Validation Accuracy: 85.15%


Epoch 29/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2029 - Training Accuracy: 89.82%
Step [200/870] - Training Loss: 0.2183 - Training Accuracy: 89.86%
Step [300/870] - Training Loss: 0.2624 - Training Accuracy: 89.66%
Step [400/870] - Training Loss: 0.3512 - Training Accuracy: 89.56%
Step [500/870] - Training Loss: 0.2375 - Training Accuracy: 89.62%
Step [600/870] - Training Loss: 0.2249 - Training Accuracy: 89.61%
Step [700/870] - Training Loss: 0.1583 - Training Accuracy: 89.63%
Step [800/870] - Training Loss: 0.2761 - Training Accuracy: 89.58%


Epoch 29/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [29/40] - Training Loss: 0.2502, Training Accuracy: 89.57% - Validation Loss: 0.3534, Validation Accuracy: 85.19%


Epoch 30/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2647 - Training Accuracy: 90.14%
Step [200/870] - Training Loss: 0.4083 - Training Accuracy: 89.88%
Step [300/870] - Training Loss: 0.2445 - Training Accuracy: 89.90%
Step [400/870] - Training Loss: 0.1627 - Training Accuracy: 89.76%
Step [500/870] - Training Loss: 0.2396 - Training Accuracy: 89.72%
Step [600/870] - Training Loss: 0.2744 - Training Accuracy: 89.70%
Step [700/870] - Training Loss: 0.3266 - Training Accuracy: 89.66%
Step [800/870] - Training Loss: 0.3038 - Training Accuracy: 89.59%


Epoch 30/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [30/40] - Training Loss: 0.2500, Training Accuracy: 89.57% - Validation Loss: 0.3524, Validation Accuracy: 85.09%


Epoch 31/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2558 - Training Accuracy: 89.56%
Step [200/870] - Training Loss: 0.1837 - Training Accuracy: 89.53%
Step [300/870] - Training Loss: 0.3007 - Training Accuracy: 89.50%
Step [400/870] - Training Loss: 0.2389 - Training Accuracy: 89.52%
Step [500/870] - Training Loss: 0.2459 - Training Accuracy: 89.54%
Step [600/870] - Training Loss: 0.2737 - Training Accuracy: 89.56%
Step [700/870] - Training Loss: 0.1775 - Training Accuracy: 89.62%
Step [800/870] - Training Loss: 0.2039 - Training Accuracy: 89.56%


Epoch 31/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [31/40] - Training Loss: 0.2498, Training Accuracy: 89.56% - Validation Loss: 0.3524, Validation Accuracy: 85.17%


Epoch 32/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2843 - Training Accuracy: 89.69%
Step [200/870] - Training Loss: 0.2442 - Training Accuracy: 89.54%
Step [300/870] - Training Loss: 0.2238 - Training Accuracy: 89.55%
Step [400/870] - Training Loss: 0.3089 - Training Accuracy: 89.59%
Step [500/870] - Training Loss: 0.3343 - Training Accuracy: 89.59%
Step [600/870] - Training Loss: 0.2383 - Training Accuracy: 89.56%
Step [700/870] - Training Loss: 0.2727 - Training Accuracy: 89.59%
Step [800/870] - Training Loss: 0.2914 - Training Accuracy: 89.59%


Epoch 32/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [32/40] - Training Loss: 0.2500, Training Accuracy: 89.58% - Validation Loss: 0.3524, Validation Accuracy: 85.22%


Epoch 33/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.1518 - Training Accuracy: 89.96%
Step [200/870] - Training Loss: 0.2731 - Training Accuracy: 89.60%
Step [300/870] - Training Loss: 0.2193 - Training Accuracy: 89.59%
Step [400/870] - Training Loss: 0.2341 - Training Accuracy: 89.63%
Step [500/870] - Training Loss: 0.2151 - Training Accuracy: 89.64%
Step [600/870] - Training Loss: 0.2645 - Training Accuracy: 89.65%
Step [700/870] - Training Loss: 0.2915 - Training Accuracy: 89.59%
Step [800/870] - Training Loss: 0.3704 - Training Accuracy: 89.60%


Epoch 33/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [33/40] - Training Loss: 0.2495, Training Accuracy: 89.59% - Validation Loss: 0.3533, Validation Accuracy: 84.94%


Epoch 34/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3287 - Training Accuracy: 89.43%
Step [200/870] - Training Loss: 0.1638 - Training Accuracy: 89.57%
Step [300/870] - Training Loss: 0.2753 - Training Accuracy: 89.59%
Step [400/870] - Training Loss: 0.2974 - Training Accuracy: 89.67%
Step [500/870] - Training Loss: 0.3586 - Training Accuracy: 89.65%
Step [600/870] - Training Loss: 0.2959 - Training Accuracy: 89.60%
Step [700/870] - Training Loss: 0.2120 - Training Accuracy: 89.57%
Step [800/870] - Training Loss: 0.2232 - Training Accuracy: 89.60%


Epoch 34/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [34/40] - Training Loss: 0.2494, Training Accuracy: 89.61% - Validation Loss: 0.3525, Validation Accuracy: 85.04%


Epoch 35/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2531 - Training Accuracy: 89.77%
Step [200/870] - Training Loss: 0.3092 - Training Accuracy: 89.58%
Step [300/870] - Training Loss: 0.2040 - Training Accuracy: 89.56%
Step [400/870] - Training Loss: 0.2643 - Training Accuracy: 89.50%
Step [500/870] - Training Loss: 0.2427 - Training Accuracy: 89.57%
Step [600/870] - Training Loss: 0.2212 - Training Accuracy: 89.53%
Step [700/870] - Training Loss: 0.2518 - Training Accuracy: 89.56%
Step [800/870] - Training Loss: 0.2854 - Training Accuracy: 89.60%


Epoch 35/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [35/40] - Training Loss: 0.2492, Training Accuracy: 89.59% - Validation Loss: 0.3523, Validation Accuracy: 85.09%


Epoch 36/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2030 - Training Accuracy: 89.65%
Step [200/870] - Training Loss: 0.2785 - Training Accuracy: 89.70%
Step [300/870] - Training Loss: 0.1533 - Training Accuracy: 89.72%
Step [400/870] - Training Loss: 0.2158 - Training Accuracy: 89.77%
Step [500/870] - Training Loss: 0.1960 - Training Accuracy: 89.69%
Step [600/870] - Training Loss: 0.2353 - Training Accuracy: 89.67%
Step [700/870] - Training Loss: 0.2195 - Training Accuracy: 89.69%
Step [800/870] - Training Loss: 0.2312 - Training Accuracy: 89.62%


Epoch 36/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [36/40] - Training Loss: 0.2491, Training Accuracy: 89.61% - Validation Loss: 0.3523, Validation Accuracy: 85.13%


Epoch 37/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2743 - Training Accuracy: 89.12%
Step [200/870] - Training Loss: 0.1692 - Training Accuracy: 89.70%
Step [300/870] - Training Loss: 0.2304 - Training Accuracy: 89.60%
Step [400/870] - Training Loss: 0.2555 - Training Accuracy: 89.75%
Step [500/870] - Training Loss: 0.1932 - Training Accuracy: 89.77%
Step [600/870] - Training Loss: 0.2291 - Training Accuracy: 89.66%
Step [700/870] - Training Loss: 0.3038 - Training Accuracy: 89.63%
Step [800/870] - Training Loss: 0.2637 - Training Accuracy: 89.62%


Epoch 37/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [37/40] - Training Loss: 0.2491, Training Accuracy: 89.60% - Validation Loss: 0.3523, Validation Accuracy: 85.09%


Epoch 38/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2809 - Training Accuracy: 89.44%
Step [200/870] - Training Loss: 0.1959 - Training Accuracy: 89.68%
Step [300/870] - Training Loss: 0.2330 - Training Accuracy: 89.61%
Step [400/870] - Training Loss: 0.3635 - Training Accuracy: 89.67%
Step [500/870] - Training Loss: 0.3186 - Training Accuracy: 89.64%
Step [600/870] - Training Loss: 0.2932 - Training Accuracy: 89.72%
Step [700/870] - Training Loss: 0.2346 - Training Accuracy: 89.67%
Step [800/870] - Training Loss: 0.2753 - Training Accuracy: 89.67%


Epoch 38/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [38/40] - Training Loss: 0.2490, Training Accuracy: 89.62% - Validation Loss: 0.3522, Validation Accuracy: 85.14%


Epoch 39/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.3513 - Training Accuracy: 89.24%
Step [200/870] - Training Loss: 0.2517 - Training Accuracy: 89.62%
Step [300/870] - Training Loss: 0.2804 - Training Accuracy: 89.47%
Step [400/870] - Training Loss: 0.2709 - Training Accuracy: 89.54%
Step [500/870] - Training Loss: 0.2327 - Training Accuracy: 89.55%
Step [600/870] - Training Loss: 0.2713 - Training Accuracy: 89.53%
Step [700/870] - Training Loss: 0.1986 - Training Accuracy: 89.55%
Step [800/870] - Training Loss: 0.2839 - Training Accuracy: 89.60%


Epoch 39/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [39/40] - Training Loss: 0.2490, Training Accuracy: 89.61% - Validation Loss: 0.3522, Validation Accuracy: 85.12%


Epoch 40/40 - Training:   0%|          | 0/870 [00:00<?, ?it/s]

Step [100/870] - Training Loss: 0.2507 - Training Accuracy: 89.55%
Step [200/870] - Training Loss: 0.2611 - Training Accuracy: 89.37%
Step [300/870] - Training Loss: 0.2392 - Training Accuracy: 89.54%
Step [400/870] - Training Loss: 0.2013 - Training Accuracy: 89.65%
Step [500/870] - Training Loss: 0.1872 - Training Accuracy: 89.59%
Step [600/870] - Training Loss: 0.3102 - Training Accuracy: 89.59%
Step [700/870] - Training Loss: 0.3982 - Training Accuracy: 89.63%
Step [800/870] - Training Loss: 0.2867 - Training Accuracy: 89.59%


Epoch 40/40 - Validation:   0%|          | 0/95 [00:00<?, ?it/s]

Epoch [40/40] - Training Loss: 0.2489, Training Accuracy: 89.61% - Validation Loss: 0.3522, Validation Accuracy: 85.13%
