In [1]:
# Import packages and setup gpu configuration.
# This code block shouldnt need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import random
from tqdm.auto import tqdm
import webdataset as wds
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils.flat_models import *
import h5py
from typing import List, Dict, Any, Tuple
from sklearn.preprocessing import StandardScaler
import argparse

# Global backend switches, applied once for the whole notebook:
#   * cudnn.benchmark -- per the original note, turning this on fixes a Conv3D
#     CUDNN_NOT_SUPPORTED error.
#   * cuda.matmul.allow_tf32 -- the tf32 data type is faster than standard float32.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# ## MODEL TO LOAD ##
# model_name = "HCPflat_large_gsrFalse_"
# parquet_folder = "epoch99"

# # outdir = os.path.abspath(f'checkpoints/{model_name}')
# outdir = os.path.abspath(f'checkpoints/{model_name}')

# print("outdir", outdir)
# # Load previous config.yaml if available
# if os.path.exists(f"{outdir}/config.yaml"):
#     config = yaml.load(open(f"{outdir}/config.yaml", 'r'), Loader=yaml.FullLoader)
#     print(f"Loaded config.yaml from ckpt folder {outdir}")
#     # create global variables from the config
#     print("\n__CONFIG__")
#     for attribute_name in config.keys():
#         print(f"{attribute_name} = {config[attribute_name]}")
#         globals()[attribute_name] = config[f'{attribute_name}']
#     print("\n")

# world_size = os.getenv('WORLD_SIZE')
# if world_size is None: 
#     world_size = 1
# else:
#     world_size = int(world_size)
# print(f"WORLD_SIZE={world_size}")

# if utils.is_interactive():
#     # Following allows you to change functions in models.py or utils.py and 
#     # have this notebook automatically update with your revisions
#     %load_ext autoreload
#     %autoreload 2

# batch_size = probe_batch_size
# num_epochs = probe_num_epochs

# data_type = torch.float32 # change depending on your mixed_precision
# global_batch_size = batch_size * world_size

# Pick the GPU when one is visible, otherwise fall back to CPU (same rule the
# model cell uses before moving the probe to `device`).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# hcp_flat_path = "/weka/proj-medarc/shared/HCP-Flat"
# seed = 42
num_frames = 16  # only referenced by the commented-out HDF5 export cell
gsr = False      # only referenced by the commented-out HDF5 export cell
# num_workers = 5
# Default used when building jupyter_args; the parsed --batch_size is injected
# into globals afterwards and overrides this value.
batch_size = 128
# target = 'sex' # This can be 'trial_type' 'age' 'sex'

print("PID of this process =",os.getpid())

PID of this process = 1392922


In [2]:
# if running this interactively, can specify jupyter_args here for argparser to use
# (outside a notebook, the parser cell reads sys.argv instead).
if utils.is_interactive():
    model_name = "NSDflat_large_gsrFalse_5sess_57734"
    print("model_name:", model_name)
    # Passed on as --outdir; the features are later read from f"{outdir}_gp{global_pooling}/...".
    outdir = os.path.abspath(f'checkpoints/{model_name}')
    # batch_size is defined in the first (imports/config) cell and interpolated below.
    # The trailing backslashes continue the string literal, so the result is one
    # long string that .split() turns into an argv-style list.
    jupyter_args = f"--hcp_flat_path=/weka/proj-medarc/shared/HCP-Flat \
                    --target=subject_id \
                    --model_name={model_name} \
                    --batch_size={batch_size} \
                    --max_lr=3e-3 --num_epochs=40 --no-save_ckpt --no-wandb_log --num_workers=10 \
                    --weight_decay=1e-5 \
                    --parquet_folder=epoch99 \
                    --outdir={outdir} \
                    --global_pooling"
    # --multisubject_ckpt=../train_logs/multisubject_subj01_1024_24bs_nolow
    # recommended max_lr on trial_type 9e-3, on age 1e-3, on sex 1e-3
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    # clear_output is not used in any of the visible cells; kept for interactive convenience.
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: NSDflat_large_gsrFalse_5sess_57734
--hcp_flat_path=/weka/proj-medarc/shared/HCP-Flat                     --target=subject_id                     --model_name=NSDflat_large_gsrFalse_5sess_57734                     --batch_size=128                     --max_lr=3e-3 --num_epochs=40 --no-save_ckpt --no-wandb_log --num_workers=10                     --weight_decay=1e-5                     --parquet_folder=epoch99                     --outdir=/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_5sess_57734                     --global_pooling


In [3]:
# Command-line configuration for the linear probe. In a notebook the arguments
# come from `jupyter_args`; otherwise from sys.argv.
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="HCPflat_large_gsrFalse_",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--hcp_flat_path", type=str, default=os.getcwd(),
    help="Path to the HCP-Flat data folder (subjects_data_restricted.csv is read from here)",
)
parser.add_argument(
    "--batch_size", type=int, default=128,
    help="Batch size used to train and evaluate the linear probe",
)
parser.add_argument(
    "--wandb_log", action=argparse.BooleanOptionalAction, default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--num_epochs", type=int, default=150,
    help="number of epochs of training",
)
parser.add_argument(
    "--lr_scheduler_type", type=str, default='cycle', choices=['cycle', 'linear'],
    help="learning-rate schedule: OneCycleLR ('cycle') or LinearLR ('linear'), stepped every batch",
)
parser.add_argument(
    "--save_ckpt", action=argparse.BooleanOptionalAction, default=True,
    help="save the trained probe weights and its config after training",
)
parser.add_argument(
    "--seed", type=int, default=42,
    help="random seed passed to utils.seed_everything",
)
parser.add_argument(
    "--max_lr", type=float, default=3e-4,
    help="peak learning rate for AdamW / the LR scheduler",
)
parser.add_argument(
    "--target", type=str, default='trial_type',
    help="'trial_type', 'subject_id', or a phenotype column from categorical_columns / numerical_columns",
)
parser.add_argument(
    "--num_workers", type=int, default=10,
    help="number of data-loading workers (recorded in the run config)",
)
parser.add_argument(
    "--weight_decay", type=float, default=1e-5,
    help="AdamW weight decay",
)
parser.add_argument(
    "--parquet_folder", type=str, default='epoch99',
    help="subfolder of the features directory that contains HCP/{train,test}.parquet",
)
parser.add_argument(
    # No trailing slash: the features folder is built as f"{outdir}_gp{global_pooling}".
    "--outdir", type=str, default='./checkpoints/HCPflat_large_gsrFalse_',
    help="Path where the precomputed features are located"
)
parser.add_argument(
    "--global_pooling", action=argparse.BooleanOptionalAction, default=True,
    help="use features extracted with global pooling (selects the <outdir>_gp<True|False> folder)",
)

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# A trailing separator (e.g. "ckpt/") would put "_gp..." inside outdir instead of
# appending it to the folder name, so normalise the path once here.
args.outdir = os.path.normpath(args.outdir)

print(f"------ ARGS ------- \n {args}")

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)

# seed all random functions
utils.seed_everything(seed)


------ ARGS ------- 
 Namespace(model_name='NSDflat_large_gsrFalse_5sess_57734', hcp_flat_path='/weka/proj-medarc/shared/HCP-Flat', batch_size=128, wandb_log=False, num_epochs=40, lr_scheduler_type='cycle', save_ckpt=False, seed=42, max_lr=0.003, target='subject_id', num_workers=10, weight_decay=1e-05, parquet_folder='epoch99', outdir='/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_5sess_57734', global_pooling=True)


In [4]:
# #### UNCOMMENT THIS TO SAVE THE HCP-FLAT IN HDF5 FORMAT


# from torch.utils.data import default_collate
# from mae_utils.flat import load_hcp_flat_mask
# from mae_utils.flat import create_hcp_flat
# from mae_utils.flat import batch_unmask
# import mae_utils.visualize as vis


# batch_size = 1
# print(f"changed batch_size to {batch_size}")
# load_file_frames = num_frames * 2
# print(f"Calculating with {load_file_frames} frames, doubling to approximate TR")

# ## Test ##
# datasets_to_include = "HCP"
# assert "HCP" in datasets_to_include
# test_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=load_file_frames, shuffle=False, gsr=gsr, sub_list = 'test')
# test_dl = wds.WebLoader(
#     test_dataset.batched(batch_size, partial=False, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# ## Train ##
# assert "HCP" in datasets_to_include
# train_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=load_file_frames, shuffle=False, gsr=gsr, sub_list = 'train')
# train_dl = wds.WebLoader(
#     train_dataset.batched(batch_size, partial=True, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# def flatten_meta(meta_dict):
#     """
#     Flatten the meta dictionary by:
#     - Replacing single-item lists with the item itself.
#     - Converting tensors to scalar numbers.
#     """
#     flattened = {}
#     for key, value in meta_dict.items():
#         if isinstance(value, list):
#             if len(value) == 1:
#                 flattened[key] = value[0]  # Replace list with its single item
#             else:
#                 flattened[key] = value  # Keep as is if multiple items
#         elif isinstance(value, torch.Tensor):
#             # Convert tensor to scalar
#             if value.numel() == 1:
#                 flattened[key] = value.item()
#             else:
#                 flattened[key] = value.tolist()  # Convert multi-element tensor to list
#         else:
#             flattened[key] = value  # Keep the value as is
#     return flattened


# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File(f'test_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(test_dl), total = 12000):
#         images = batch[0]
#         meta = batch[1]
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save(f'metadata_test_HCP_raw_flatmaps_{load_file_frames}f.npy', meta_array)


# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File(f'train_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(train_dl), total = 120000):
#         images = batch[0]
#         meta = batch[1]
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save(f'metadata_train_HCP_raw_flatmaps_{load_file_frames}f.npy', meta_array)

### Data

In [5]:
from sklearn.preprocessing import LabelEncoder

###### Phenotype columns from the (restricted) HCP subject-information table.

# A --target listed here is label-encoded and probed as a classification task.
categorical_columns = [
    "Gender", "Race", "Ethnicity",
    "SSAGA_PanicDisorder",   # Panic disorder diagnosis (yes/no)
    "SSAGA_Depressive_Ep",   # Depressive episode diagnosis (yes/no)
]

# A --target listed here is z-scored and probed as a regression task
# (continuous values, counts, raw and standardized scores).
numerical_columns = [
    "Age_in_Yrs",                                              # basic demographics
    "PMAT24_A_CR",                                             # cognitive / "IQ-like"
    "CardSort_Unadj", "CardSort_AgeAdj",
    "ListSort_Unadj", "ListSort_AgeAdj",
    "PicSeq_Unadj", "PicSeq_AgeAdj",
    "NEOFAC_A", "NEOFAC_O", "NEOFAC_C", "NEOFAC_N", "NEOFAC_E",  # Big Five
    # (item-level NEO responses "NEORAW_01" ... "NEORAW_60" could be added here)
    "ASR_Anxd_Raw", "ASR_Attn_Raw", "ASR_Aggr_Raw",            # psychopathology
    "DSM_Depr_Raw", "DSM_Anxi_Raw",
    "SSAGA_Depressive_Sx",                                     # symptom count or severity
    "SSAGA_Alc_12_Frq", "SSAGA_Alc_12_Max_Drinks",             # substance use
    "SSAGA_Times_Used_Illicits", "SSAGA_Times_Used_Cocaine",
    "Total_Drinks_7days", "Total_Any_Tobacco_7days",
    "BMI", "Height", "Weight",                                 # anthropometric / health
    "BPSystolic", "BPDiastolic", "HbA1C", "ThyroidHormone",
    "PSQI_Score",                                              # sleep / quality of life
    # (separate PSQI component scores could be listed here too)
    "PainInterf_Tscore", "LifeSatisf_Unadj", "MeanPurp_Unadj",
]


# How the probe treats the requested target:
#   'special'     -> trial_type / subject_id, labels come from the feature parquet
#   'categorical' -> label-encoded phenotype column (classification)
#   'numerical'   -> z-scored phenotype column (regression)
if target in ['subject_id', 'trial_type']:
    target_type = 'special'
elif target in categorical_columns:
    target_type = 'categorical'
elif target in numerical_columns:
    target_type = 'numerical'
else:
    # Fail here with a clear message instead of a NameError on `target_type` later.
    raise ValueError(
        f"Unsupported target {target!r}: expected 'subject_id', 'trial_type', "
        "or a column listed in categorical_columns / numerical_columns"
    )


# Phenotype targets need the per-subject information table.
if target_type != 'special':
    # Try the restricted table first, then the local unrestricted export.
    candidate_paths = [
        os.path.join(hcp_flat_path, "subjects_data_restricted.csv"),
        './unrestricted_clane9_4_23_2024_13_28_14.csv',
    ]
    subject_information_HCP = None
    for subject_information_HCP_path in candidate_paths:
        try:
            subject_information_HCP = pd.read_csv(subject_information_HCP_path)
            break
        except FileNotFoundError:
            # Only a missing file triggers the fallback; parse errors surface as-is.
            continue
    if subject_information_HCP is None:
        raise FileNotFoundError(f"Subject information file not found; tried {candidate_paths}")
    if target not in subject_information_HCP.columns:
        raise KeyError(f"Column {target!r} not found in {subject_information_HCP_path}")

    if target_type == 'numerical':
        # Count NaNs or missing values
        n_missing = subject_information_HCP[target].isnull().sum()
        print(f"Number of missing values in {target}: {n_missing}. Replacing with mean.")
        mean_ = subject_information_HCP[target].mean()
        # Assign the filled column back: `df[col].fillna(..., inplace=True)` is chained
        # assignment and may silently leave `df` unchanged under pandas Copy-on-Write.
        subject_information_HCP[target] = subject_information_HCP[target].fillna(mean_)
        # z-score normalization; the scaler is kept as a global in case it is needed later.
        scaler = StandardScaler()
        subject_information_HCP[f'{target}_z'] = scaler.fit_transform(
            subject_information_HCP[[target]]
        ).ravel()

    if target_type == 'categorical':
        # Integer class indices; num_classes is later read from label_enc.classes_.
        # NOTE(review): NaNs in the column are not filled here -- confirm they cannot occur.
        label_enc = LabelEncoder()
        subject_information_HCP[f'{target}_encoded'] = label_enc.fit_transform(subject_information_HCP[target])

def train_test_split_by_subject(df, test_ratio=0.1, random_state=42):
    """
    Split a dataframe into train and test so that
    every subject in test also appears in train at least once.

    Rows are split per subject: roughly ``test_ratio`` of each subject's rows go
    to test, but at least one row always stays in train. Subjects with a single
    row go entirely to train. The global NumPy RNG is not touched; all sampling
    uses ``random_state``.

    Parameters
    ----------
    df : pd.DataFrame
        Your dataset, containing at least the columns:
        ['sub', ...]. The index may contain duplicate labels.
    test_ratio : float
        Fraction (in [0, 1]) of each subject's rows to allocate to test.
    random_state : int
        Random seed for reproducibility.

    Returns
    -------
    train_df : pd.DataFrame
    test_df : pd.DataFrame
        Both shuffled, with a fresh RangeIndex. ``test_df`` is empty (same
        columns as ``df``) when no subject has more than one row.

    Raises
    ------
    ValueError
        If ``test_ratio`` is outside [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be within [0, 1], got {test_ratio}")

    train_dfs = []
    test_dfs = []

    for _, df_sub in df.groupby('sub'):
        # Give the subject's rows unique labels: dropping the sampled test rows by
        # label would otherwise also discard other rows sharing those labels
        # (e.g. after concatenating two frames without ignore_index).
        # Sampling is positional, so this does not change which rows are picked.
        df_sub = df_sub.reset_index(drop=True)
        n = len(df_sub)

        # If the subject only has 1 row, put it all in train
        if n == 1:
            train_dfs.append(df_sub)
            continue

        # Round to the nearest row count, but always keep at least one row in train.
        n_test = min(int(round(test_ratio * n)), n - 1)

        # NOTE: the same random_state is reused for every subject (kept so splits stay
        # reproducible), so equally sized subjects get test rows at the same positions.
        test_rows = df_sub.sample(n_test, random_state=random_state)
        test_dfs.append(test_rows)
        train_dfs.append(df_sub.drop(test_rows.index))

    def _concat_shuffled(frames):
        # Concatenate, shuffle and re-index; pd.concat raises on an empty list, so
        # return an empty frame with df's columns instead.
        if not frames:
            return df.iloc[0:0].reset_index(drop=True)
        return pd.concat(frames).sample(frac=1, random_state=random_state).reset_index(drop=True)

    return _concat_shuffled(train_dfs), _concat_shuffled(test_dfs)

    
def get_label_restricted(subject_id: List[str], target: str, normalized: bool = True) -> np.ndarray:
    """
    Look up one label per subject for a phenotype target in `subject_information_HCP`.

    Parameters
    ----------
    subject_id : list of str
        Subject IDs; each is converted with int() to match the 'Subject' column.
    target : str
        A column name from `numerical_columns` or `categorical_columns`.
    normalized : bool
        For numerical targets, read the z-scored '{target}_z' column instead of
        the raw '{target}' column. Ignored for categorical targets.

    Returns
    -------
    np.ndarray
        One entry per subject, in input order: float32 values for numerical
        targets, integer class indices (from '{target}_encoded') for categorical ones.

    Raises
    ------
    ValueError
        If `target` is not a phenotype column, or a subject is missing from the table.
    """
    if target in numerical_columns:
        column = f'{target}_z' if normalized else f'{target}'
        cast = np.float32
    elif target in categorical_columns:
        column = f'{target}_encoded'
        # int64 rather than int8: int8 cannot hold class indices above 127.
        cast = np.int64
    else:
        raise ValueError(f"get_label_restricted does not support target {target!r}")

    target_array = []
    for subject in (int(x) for x in subject_id):
        c_target = subject_information_HCP.loc[subject_information_HCP['Subject'] == subject, column].values
        # if the subject is not in the subject information file trigger an error
        if len(c_target) == 0:
            raise ValueError(f"Subject {subject} not found in subject information file")
        if len(c_target) > 1:
            print(f"Warning: Multiple entries for subject {subject}")
        target_array.append(cast(c_target[0]))

    return np.array(target_array)
    

In [6]:
# f_train = h5py.File('/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/train_hcp_raw_flatmaps.hdf5', 'r')
# flatmaps_train = f_train['flatmaps']

# f_test = h5py.File('/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/test_hcp_raw_flatmaps.hdf5', 'r')
# flatmaps_test = f_test['flatmaps']

# metadata_train = np.load('/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/metadata_train_HCP_raw_flatmaps.npy', allow_pickle=True)
# metadata_test = np.load('/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/metadata_test_HCP_raw_flatmaps.npy', allow_pickle=True)

# Precomputed features live in <outdir>_gp<global_pooling>/<parquet_folder>/HCP/{train,test}.parquet
train_features = pd.read_parquet(f"{outdir}_gp{str(global_pooling)}/{parquet_folder}/HCP/train.parquet")
test_features = pd.read_parquet(f"{outdir}_gp{str(global_pooling)}/{parquet_folder}/HCP/test.parquet")

if target == "subject_id":
    # We have to redo the train test split including every possible subject from test on train.
    # ignore_index=True keeps row labels unique; the two parquet frames may both be
    # indexed from 0, and duplicate labels make label-based row selection ambiguous.
    all_features = pd.concat([train_features, test_features], ignore_index=True)
    # 'sub' may be stored as a one-element list/array per row; unwrap only those, so an
    # already-scalar ID string is not truncated to its first character.
    all_features["sub"] = all_features["sub"].apply(
        lambda x: x[0] if isinstance(x, (list, tuple, np.ndarray)) else x
    )
    train_features, test_features = train_test_split_by_subject(all_features)

### Create the dataloader

In [7]:
from torch.utils.data import Dataset, DataLoader

class HCPFlatDataset(Dataset):
    """Map-style dataset over a DataFrame of precomputed features.

    Each row must provide:
      * 'feature'    -- array-like feature values, returned as a float tensor;
      * 'trial_type' -- indexable; its first element is returned;
      * 'sub'        -- subject id, possibly wrapped in a one-element list/array.

    Items are ``(feature_tensor, trial_type, sub)``.
    """

    def __init__(self, parquet_data):
        self.parquet_data = parquet_data

    def __len__(self):
        return len(self.parquet_data)

    def __getitem__(self, idx):
        # Fetch the row once instead of calling .iloc three times.
        row = self.parquet_data.iloc[idx]
        c_feature = torch.Tensor(row['feature'])
        c_trial_type = row['trial_type'][0]
        c_sub = row['sub']
        # Parquet list columns may come back as numpy arrays rather than Python
        # lists, so unwrap those too; an un-unwrapped array would fail to collate.
        if isinstance(c_sub, (list, tuple, np.ndarray)):
            c_sub = c_sub[0]
        return c_feature, c_trial_type, c_sub

# The feature tables are already in memory as DataFrames, so batches are built
# in the main process (num_workers=0). Only the training loader is shuffled.
train_dataset = HCPFlatDataset(train_features)
test_dataset = HCPFlatDataset(test_features)
train_dl = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=0)
test_dl = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, num_workers=0)


In [8]:
from sklearn.preprocessing import LabelEncoder

# Fit the label encoder for the 'special' targets. For categorical phenotype
# targets, label_enc was already fitted on the subject-information table.
if target == "trial_type":
    # Task conditions the probe classifies between.
    INCLUDE_CONDS = {
        "fear", "neut",
        "math", "story",
        "lf", "lh", "rf", "rh", "t",
        "match", "relation",
        "mental", "rnd",
        "0bk_body", "2bk_body",
        "0bk_faces", "2bk_faces",
        "0bk_places", "2bk_places",
        "0bk_tools", "2bk_tools",
    }
    label_enc = LabelEncoder()
    label_enc.fit(sorted(INCLUDE_CONDS))  # sorted for a deterministic input order
elif target == 'subject_id':
    # Fit on the training subjects; the per-subject re-split guarantees every
    # test subject also appears in train.
    all_subs = train_features['sub'].values
    label_enc = LabelEncoder().fit(all_subs)

if target_type in ('categorical', 'special'):
    num_classes = len(label_enc.classes_)
    print(f"Number of classes: {num_classes}")

Number of classes: 1093


### Create pytorch model

In [12]:
class LinearClassifier(nn.Module):
    """Linear probe: a single ``nn.Linear`` over the flattened input features.

    Args:
        input_dim: number of features after flattening all non-batch dimensions.
        num_classes: number of outputs (class logits, or 1 for regression).
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes) for x of shape (batch, ...)."""
        # torch.flatten also handles non-contiguous inputs, where .view() raises.
        x = torch.flatten(x, start_dim=1)
        return self.linear(x)  # Raw logits

# Size the probe from the data instead of hard-coding it: take one training batch,
# keep its first sample's features, and count the elements. LinearClassifier
# flattens every non-batch dimension, so this count is its in_features.
sample_batch = next(iter(train_dl))
sample_image = sample_batch[0][0]  # features of the first sample in the batch
input_dim = sample_image.numel()
print(f"Input dimension: {input_dim}")


Input dimension: 1024


  c_feature = torch.Tensor(self.parquet_data.iloc[idx]['feature'])


In [13]:
# Initialize the model: one logit per class for classification targets, a single
# output for regression targets.
if target_type in ('categorical', 'special'):
    model = LinearClassifier(input_dim=input_dim, num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()
elif target_type == 'numerical':
    model = LinearClassifier(input_dim=input_dim, num_classes=1)
    criterion = nn.MSELoss()
else:
    raise ValueError(f"Unsupported target_type: {target_type!r}")


# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# import schedulefree
# optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)

# One scheduler step per training batch (the DataLoader keeps the last partial batch).
num_iterations_per_epoch = math.ceil(train_features.shape[0]/batch_size)
total_steps = int(np.ceil(num_epochs*num_iterations_per_epoch))

if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=total_steps,
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    print("total_steps", total_steps)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        # Warm up for ~2 epochs. OneCycleLR requires pct_start <= 1, which
        # 2/num_epochs violates for a single-epoch run, so cap it.
        last_epoch=-1, pct_start=min(2 / num_epochs, 1.0)
    )


total_steps 27400


### Wandb logging

In [14]:
import wandb
import uuid

myuuid = uuid.uuid4()
if utils.is_interactive():
    print("Running in interactive notebook. Disabling W&B and ckpt saving.")
    wandb_log = False
    save_ckpt = False

# `model_suffix` is not set by the argument parser; fall back to the feature
# model name so run ids/names can be built without a NameError.
# NOTE(review): confirm model_name is the intended suffix.
model_suffix = globals().get("model_suffix", model_name)

# Run configuration. Built unconditionally because the checkpoint-saving code
# also writes it to config.yaml.
wandb_config = {
    "model_name": f"HCPflat_raw_{target}",
    "batch_size": batch_size,
    "weight_decay": weight_decay,
    "num_epochs": num_epochs,
    "seed": seed,
    "lr_scheduler_type": lr_scheduler_type,
    "save_ckpt": save_ckpt,
    "max_lr": max_lr,
    "target": target,
    "num_workers": num_workers,
}
# Used to name the checkpoint folder; created here so saving also works without W&B.
random_id = random.randint(0, 100000)

if wandb_log:
    wandb_project = 'fMRI-foundation-model'
    print("wandb_config:\n", wandb_config)
    wandb_id = "HCPflat_raw" + f"_{model_suffix}_{target}_{myuuid}"
    print("wandb_id:", wandb_id)
    wandb.init(
        id=wandb_id,
        project=wandb_project,
        name="HCPflat_raw"+ f"_{model_suffix}_{target}",
        config=wandb_config,
        resume="allow",
    )

Running in interactive notebook. Disabling W&B and ckpt saving.


### Training loop

In [15]:
def _prepare_labels(batch):
    """Build the label tensor for one DataLoader batch and move it to `device`.

    `batch` is the collated (feature, trial_type, sub) triple produced by
    HCPFlatDataset. trial_type / subject_id labels are encoded with the fitted
    `label_enc`; phenotype targets are looked up with `get_label_restricted`.
    Returns int64 labels for classification targets and float32 for numerical ones.

    Raises ValueError for a target none of the branches handle.
    """
    if target == "trial_type":
        labels = torch.tensor(label_enc.transform(batch[1]), dtype=torch.long)
    elif target == 'subject_id':
        labels = torch.tensor(label_enc.transform(batch[2]), dtype=torch.long)
    elif target in numerical_columns:
        labels = torch.tensor(get_label_restricted(batch[2], target), dtype=torch.float)
    elif target in categorical_columns:
        labels = torch.tensor(get_label_restricted(batch[2], target), dtype=torch.long)
    else:
        raise ValueError(f"Unsupported target: {target!r}")
    return labels.to(device)  # Shape: [batch_size]


for epoch in range(num_epochs):
    running_train_loss = 0.0
    correct_train = 0
    mse_age_train = 0.0  # summed squared error, used for any numerical target
    total_train = 0
    step = 0

    # Training Phase
    model.train()

    for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        optimizer.zero_grad()
        # batch[0] holds the precomputed feature vectors; LinearClassifier
        # flattens every non-batch dimension itself.
        images = batch[0].to(device).float()
        labels = _prepare_labels(batch)

        # Forward pass
        outputs = model(images)  # Output shape depends on the target

        # squeeze() drops the singleton output dim of the 1-unit regression head
        # (and the batch dim when the final batch holds a single sample).
        loss = criterion(outputs.squeeze(), labels.squeeze())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate loss
        running_train_loss += loss.item()

        # Calculate and accumulate metrics
        if target_type in ["categorical", "special"]:
            _, predicted = torch.max(outputs, 1)
            correct_train += (predicted == labels).sum().item()
        elif target_type == "numerical":
            mse_age_train += torch.sum((outputs.squeeze() - labels) ** 2).item()

        total_train += labels.size(0)
        step += 1

        # Print intermediate metrics every 100 steps
        if step % 100 == 0:
            if target_type in ['categorical', 'special']:
                current_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
                print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training Accuracy: {current_accuracy:.2f}%")
            elif target_type == 'numerical':
                current_mse = mse_age_train / total_train if total_train > 0 else 0.0
                print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training MSE: {current_mse:.4f}")

        # The schedulers were sized in optimizer steps, so step once per batch.
        if lr_scheduler_type is not None:
            lr_scheduler.step()

    # Calculate epoch-level metrics (loss is the mean of per-batch losses)
    epoch_train_loss = running_train_loss / step if step > 0 else 0.0

    if target_type in ['categorical', 'special']:
        train_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
    elif target_type == 'numerical':
        train_mse = mse_age_train / total_train if total_train > 0 else 0.0

    # Validation Phase
    model.eval()
    running_val_loss = 0.0
    correct_val = 0
    mse_age_val = 0.0
    total_val = 0
    step_val = 0
    with torch.no_grad():
        for batch in tqdm(test_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            images = batch[0].to(device).float()
            labels = _prepare_labels(batch)

            # Forward pass
            outputs = model(images)

            loss = criterion(outputs.squeeze(), labels.squeeze())

            # Accumulate loss
            running_val_loss += loss.item()

            # Calculate and accumulate metrics
            if target_type in ["categorical", "special"]:
                _, predicted = torch.max(outputs, 1)
                correct_val += (predicted == labels).sum().item()
            elif target_type == "numerical":
                mse_age_val += torch.sum((outputs.squeeze() - labels) ** 2).item()

            total_val += labels.size(0)
            step_val += 1

    # Calculate epoch-level validation metrics
    epoch_val_loss = running_val_loss / step_val if step_val > 0 else 0.0

    if target_type in ['categorical', 'special']:
        val_accuracy = 100 * correct_val / total_val if total_val > 0 else 0.0
    elif target_type == 'numerical':
        val_mse = mse_age_val / total_val if total_val > 0 else 0.0

    # Print epoch-level metrics
    if target_type in ['categorical', 'special']:
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"- Training Loss: {epoch_train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}% "
              f"- Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    elif target_type == 'numerical':
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"- Training Loss: {epoch_train_loss:.4f}, Training MSE: {train_mse:.4f} "
              f"- Validation Loss: {epoch_val_loss:.4f}, Validation MSE: {val_mse:.4f}")

    # Log metrics with wandb
    if wandb_log:
        log_dict = {
            "epoch_train_loss": epoch_train_loss,
            "epoch_val_loss": epoch_val_loss,
        }
        if target_type in ['categorical', 'special']:
            log_dict.update({
                f"train_accuracy_{target}": train_accuracy,
                f"val_accuracy_{target}": val_accuracy,
            })
        elif target_type == 'numerical':
            log_dict.update({
                f"train_mse_{target}": train_mse,
                f"val_mse_{target}": val_mse,
            })
        wandb.log(log_dict)

# Save checkpoint if required
if save_ckpt:
    # These names come from the W&B cell and may be missing (for example when
    # wandb_log is False); define fallbacks so saving cannot hit a NameError.
    if "model_suffix" not in globals():
        model_suffix = model_name  # NOTE(review): confirm the intended suffix
    if "random_id" not in globals():
        random_id = random.randint(0, 100000)
    if "wandb_config" not in globals():
        wandb_config = {
            "model_name": f"HCPflat_raw_{target}",
            "batch_size": batch_size,
            "weight_decay": weight_decay,
            "num_epochs": num_epochs,
            "seed": seed,
            "lr_scheduler_type": lr_scheduler_type,
            "save_ckpt": save_ckpt,
            "max_lr": max_lr,
            "target": target,
            "num_workers": num_workers,
        }
    # Separate name so the features `outdir` from the arguments is not overwritten
    # (re-running the feature-loading cell would otherwise read the wrong folder).
    ckpt_dir = os.path.abspath(f'checkpoints/HCPflat_raw_{model_suffix}_{target}_{random_id}')
    os.makedirs(ckpt_dir, exist_ok=True)
    print("Saving checkpoint to:", ckpt_dir)
    # Save model state
    torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pth"))
    # Save configuration
    with open(os.path.join(ckpt_dir, "config.yaml"), 'w') as f:
        yaml.dump(wandb_config, f)
    print(f"Model and config saved to {ckpt_dir}")


Epoch 1/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 7.0651 - Training Accuracy: 0.09%
Step [200/685] - Training Loss: 7.0668 - Training Accuracy: 0.15%
Step [300/685] - Training Loss: 7.0299 - Training Accuracy: 0.16%
Step [400/685] - Training Loss: 7.1200 - Training Accuracy: 0.22%
Step [500/685] - Training Loss: 7.1347 - Training Accuracy: 0.26%
Step [600/685] - Training Loss: 7.1563 - Training Accuracy: 0.30%


Epoch 1/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [1/40] - Training Loss: 7.0663, Training Accuracy: 0.35% - Validation Loss: 7.2991, Validation Accuracy: 0.69%


Epoch 2/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 7.1502 - Training Accuracy: 0.90%
Step [200/685] - Training Loss: 7.5711 - Training Accuracy: 0.83%
Step [300/685] - Training Loss: 7.4754 - Training Accuracy: 0.86%
Step [400/685] - Training Loss: 7.4018 - Training Accuracy: 0.86%
Step [500/685] - Training Loss: 7.9718 - Training Accuracy: 0.94%
Step [600/685] - Training Loss: 7.6814 - Training Accuracy: 0.99%


Epoch 2/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [2/40] - Training Loss: 7.3618, Training Accuracy: 1.03% - Validation Loss: 6.9951, Validation Accuracy: 1.19%


Epoch 3/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 7.0974 - Training Accuracy: 2.18%
Step [200/685] - Training Loss: 6.7064 - Training Accuracy: 2.02%
Step [300/685] - Training Loss: 7.1806 - Training Accuracy: 1.96%
Step [400/685] - Training Loss: 7.1565 - Training Accuracy: 2.04%
Step [500/685] - Training Loss: 6.7765 - Training Accuracy: 2.10%
Step [600/685] - Training Loss: 7.4101 - Training Accuracy: 2.11%


Epoch 3/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [3/40] - Training Loss: 7.0958, Training Accuracy: 2.11% - Validation Loss: 7.4936, Validation Accuracy: 1.68%


Epoch 4/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 6.7011 - Training Accuracy: 3.45%
Step [200/685] - Training Loss: 6.8007 - Training Accuracy: 3.29%
Step [300/685] - Training Loss: 6.7754 - Training Accuracy: 3.22%
Step [400/685] - Training Loss: 6.8147 - Training Accuracy: 3.20%
Step [500/685] - Training Loss: 6.9430 - Training Accuracy: 3.19%
Step [600/685] - Training Loss: 6.5853 - Training Accuracy: 3.21%


Epoch 4/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [4/40] - Training Loss: 6.7169, Training Accuracy: 3.23% - Validation Loss: 6.5050, Validation Accuracy: 2.40%


Epoch 5/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 6.7308 - Training Accuracy: 4.62%
Step [200/685] - Training Loss: 6.3986 - Training Accuracy: 4.39%
Step [300/685] - Training Loss: 6.2805 - Training Accuracy: 4.42%
Step [400/685] - Training Loss: 6.4492 - Training Accuracy: 4.42%
Step [500/685] - Training Loss: 6.3174 - Training Accuracy: 4.43%
Step [600/685] - Training Loss: 6.5973 - Training Accuracy: 4.47%


Epoch 5/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [5/40] - Training Loss: 6.4212, Training Accuracy: 4.44% - Validation Loss: 6.2611, Validation Accuracy: 2.47%


Epoch 6/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 6.3730 - Training Accuracy: 6.18%
Step [200/685] - Training Loss: 5.8078 - Training Accuracy: 5.95%
Step [300/685] - Training Loss: 6.3007 - Training Accuracy: 5.94%
Step [400/685] - Training Loss: 6.1255 - Training Accuracy: 5.88%
Step [500/685] - Training Loss: 6.4644 - Training Accuracy: 5.78%
Step [600/685] - Training Loss: 6.0168 - Training Accuracy: 5.71%


Epoch 6/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [6/40] - Training Loss: 6.1870, Training Accuracy: 5.75% - Validation Loss: 6.5903, Validation Accuracy: 2.98%


Epoch 7/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 5.9370 - Training Accuracy: 7.31%
Step [200/685] - Training Loss: 6.3515 - Training Accuracy: 6.99%
Step [300/685] - Training Loss: 5.6957 - Training Accuracy: 6.94%
Step [400/685] - Training Loss: 6.0466 - Training Accuracy: 6.94%
Step [500/685] - Training Loss: 5.5124 - Training Accuracy: 6.98%
Step [600/685] - Training Loss: 6.1644 - Training Accuracy: 6.89%


Epoch 7/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [7/40] - Training Loss: 5.9845, Training Accuracy: 6.85% - Validation Loss: 5.8626, Validation Accuracy: 3.45%


Epoch 8/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 5.6255 - Training Accuracy: 8.65%
Step [200/685] - Training Loss: 5.4993 - Training Accuracy: 8.56%
Step [300/685] - Training Loss: 5.5303 - Training Accuracy: 8.47%
Step [400/685] - Training Loss: 5.6364 - Training Accuracy: 8.33%
Step [500/685] - Training Loss: 5.6967 - Training Accuracy: 8.18%
Step [600/685] - Training Loss: 6.0530 - Training Accuracy: 8.07%


Epoch 8/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [8/40] - Training Loss: 5.7896, Training Accuracy: 8.01% - Validation Loss: 5.9707, Validation Accuracy: 3.09%


Epoch 9/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 5.3207 - Training Accuracy: 9.27%
Step [200/685] - Training Loss: 5.5739 - Training Accuracy: 9.32%
Step [300/685] - Training Loss: 5.5583 - Training Accuracy: 9.17%
Step [400/685] - Training Loss: 5.7431 - Training Accuracy: 9.08%
Step [500/685] - Training Loss: 5.5405 - Training Accuracy: 9.08%
Step [600/685] - Training Loss: 5.7401 - Training Accuracy: 9.10%


Epoch 9/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [9/40] - Training Loss: 5.6153, Training Accuracy: 9.10% - Validation Loss: 5.5362, Validation Accuracy: 4.06%


Epoch 10/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 5.7526 - Training Accuracy: 11.02%
Step [200/685] - Training Loss: 5.5888 - Training Accuracy: 10.58%
Step [300/685] - Training Loss: 5.5501 - Training Accuracy: 10.53%
Step [400/685] - Training Loss: 5.4063 - Training Accuracy: 10.62%
Step [500/685] - Training Loss: 5.7607 - Training Accuracy: 10.59%
Step [600/685] - Training Loss: 5.0256 - Training Accuracy: 10.55%


Epoch 10/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [10/40] - Training Loss: 5.4421, Training Accuracy: 10.48% - Validation Loss: 5.7312, Validation Accuracy: 4.01%


Epoch 11/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 5.5524 - Training Accuracy: 11.71%
Step [200/685] - Training Loss: 4.8207 - Training Accuracy: 11.50%
Step [300/685] - Training Loss: 4.7786 - Training Accuracy: 11.49%
Step [400/685] - Training Loss: 5.4091 - Training Accuracy: 11.48%
Step [500/685] - Training Loss: 4.5513 - Training Accuracy: 11.39%
Step [600/685] - Training Loss: 5.4039 - Training Accuracy: 11.39%


Epoch 11/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [11/40] - Training Loss: 5.2891, Training Accuracy: 11.39% - Validation Loss: 5.2787, Validation Accuracy: 4.66%


Epoch 12/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 5.3529 - Training Accuracy: 13.20%
Step [200/685] - Training Loss: 5.3526 - Training Accuracy: 12.95%
Step [300/685] - Training Loss: 5.1360 - Training Accuracy: 12.99%
Step [400/685] - Training Loss: 4.8241 - Training Accuracy: 13.00%
Step [500/685] - Training Loss: 4.8139 - Training Accuracy: 12.94%
Step [600/685] - Training Loss: 5.6051 - Training Accuracy: 12.88%


Epoch 12/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [12/40] - Training Loss: 5.1394, Training Accuracy: 12.79% - Validation Loss: 4.9638, Validation Accuracy: 4.72%


Epoch 13/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 4.6580 - Training Accuracy: 14.29%
Step [200/685] - Training Loss: 4.9083 - Training Accuracy: 14.35%
Step [300/685] - Training Loss: 4.9828 - Training Accuracy: 14.17%
Step [400/685] - Training Loss: 4.9253 - Training Accuracy: 14.01%
Step [500/685] - Training Loss: 4.9516 - Training Accuracy: 13.94%
Step [600/685] - Training Loss: 5.2470 - Training Accuracy: 13.90%


Epoch 13/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [13/40] - Training Loss: 5.0040, Training Accuracy: 13.86% - Validation Loss: 5.1335, Validation Accuracy: 5.76%


Epoch 14/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 4.9697 - Training Accuracy: 15.98%
Step [200/685] - Training Loss: 4.8468 - Training Accuracy: 15.44%
Step [300/685] - Training Loss: 5.0241 - Training Accuracy: 15.31%
Step [400/685] - Training Loss: 4.8068 - Training Accuracy: 15.22%
Step [500/685] - Training Loss: 4.7189 - Training Accuracy: 15.22%
Step [600/685] - Training Loss: 4.8032 - Training Accuracy: 15.19%


Epoch 14/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [14/40] - Training Loss: 4.8605, Training Accuracy: 15.23% - Validation Loss: 5.0107, Validation Accuracy: 5.50%


Epoch 15/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 4.4213 - Training Accuracy: 17.16%
Step [200/685] - Training Loss: 4.4352 - Training Accuracy: 17.10%
Step [300/685] - Training Loss: 4.7139 - Training Accuracy: 16.91%
Step [400/685] - Training Loss: 4.9286 - Training Accuracy: 16.69%
Step [500/685] - Training Loss: 5.0116 - Training Accuracy: 16.71%
Step [600/685] - Training Loss: 4.8148 - Training Accuracy: 16.54%


Epoch 15/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [15/40] - Training Loss: 4.7200, Training Accuracy: 16.47% - Validation Loss: 4.8986, Validation Accuracy: 5.56%


Epoch 16/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 4.3435 - Training Accuracy: 18.16%
Step [200/685] - Training Loss: 4.3457 - Training Accuracy: 17.95%
Step [300/685] - Training Loss: 4.8304 - Training Accuracy: 17.81%
Step [400/685] - Training Loss: 4.3093 - Training Accuracy: 17.71%
Step [500/685] - Training Loss: 4.9439 - Training Accuracy: 17.56%
Step [600/685] - Training Loss: 4.7677 - Training Accuracy: 17.47%


Epoch 16/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [16/40] - Training Loss: 4.5944, Training Accuracy: 17.50% - Validation Loss: 4.5564, Validation Accuracy: 6.11%


Epoch 17/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 4.6131 - Training Accuracy: 19.31%
Step [200/685] - Training Loss: 4.2037 - Training Accuracy: 19.21%
Step [300/685] - Training Loss: 4.3686 - Training Accuracy: 19.09%
Step [400/685] - Training Loss: 4.1070 - Training Accuracy: 18.94%
Step [500/685] - Training Loss: 4.4315 - Training Accuracy: 18.89%
Step [600/685] - Training Loss: 4.5392 - Training Accuracy: 18.80%


Epoch 17/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [17/40] - Training Loss: 4.4756, Training Accuracy: 18.75% - Validation Loss: 4.1061, Validation Accuracy: 6.56%


Epoch 18/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 4.5647 - Training Accuracy: 20.76%
Step [200/685] - Training Loss: 4.4254 - Training Accuracy: 20.95%
Step [300/685] - Training Loss: 4.4038 - Training Accuracy: 20.77%
Step [400/685] - Training Loss: 4.0970 - Training Accuracy: 20.57%
Step [500/685] - Training Loss: 4.2747 - Training Accuracy: 20.53%
Step [600/685] - Training Loss: 4.3488 - Training Accuracy: 20.42%


Epoch 18/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [18/40] - Training Loss: 4.3524, Training Accuracy: 20.32% - Validation Loss: 4.1377, Validation Accuracy: 7.08%


Epoch 19/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.9783 - Training Accuracy: 22.03%
Step [200/685] - Training Loss: 4.3368 - Training Accuracy: 22.04%
Step [300/685] - Training Loss: 4.4020 - Training Accuracy: 21.98%
Step [400/685] - Training Loss: 4.3665 - Training Accuracy: 21.80%
Step [500/685] - Training Loss: 4.2211 - Training Accuracy: 21.78%
Step [600/685] - Training Loss: 3.7774 - Training Accuracy: 21.76%


Epoch 19/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [19/40] - Training Loss: 4.2354, Training Accuracy: 21.74% - Validation Loss: 4.4352, Validation Accuracy: 6.84%


Epoch 20/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 4.2901 - Training Accuracy: 23.79%
Step [200/685] - Training Loss: 3.5569 - Training Accuracy: 23.37%
Step [300/685] - Training Loss: 4.4184 - Training Accuracy: 23.15%
Step [400/685] - Training Loss: 4.0333 - Training Accuracy: 23.13%
Step [500/685] - Training Loss: 4.0910 - Training Accuracy: 23.16%
Step [600/685] - Training Loss: 3.9474 - Training Accuracy: 23.07%


Epoch 20/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [20/40] - Training Loss: 4.1185, Training Accuracy: 23.02% - Validation Loss: 4.1839, Validation Accuracy: 7.66%


Epoch 21/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.8594 - Training Accuracy: 25.06%
Step [200/685] - Training Loss: 3.6599 - Training Accuracy: 24.71%
Step [300/685] - Training Loss: 4.0577 - Training Accuracy: 24.42%
Step [400/685] - Training Loss: 3.7699 - Training Accuracy: 24.46%
Step [500/685] - Training Loss: 4.0354 - Training Accuracy: 24.44%
Step [600/685] - Training Loss: 4.1638 - Training Accuracy: 24.33%


Epoch 21/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [21/40] - Training Loss: 4.0242, Training Accuracy: 24.27% - Validation Loss: 4.0819, Validation Accuracy: 8.09%


Epoch 22/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.4234 - Training Accuracy: 27.23%
Step [200/685] - Training Loss: 3.7431 - Training Accuracy: 26.72%
Step [300/685] - Training Loss: 4.0295 - Training Accuracy: 26.53%
Step [400/685] - Training Loss: 3.8466 - Training Accuracy: 26.32%
Step [500/685] - Training Loss: 3.8771 - Training Accuracy: 26.27%
Step [600/685] - Training Loss: 3.6276 - Training Accuracy: 26.16%


Epoch 22/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [22/40] - Training Loss: 3.9147, Training Accuracy: 26.13% - Validation Loss: 4.2804, Validation Accuracy: 7.67%


Epoch 23/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.6838 - Training Accuracy: 27.67%
Step [200/685] - Training Loss: 3.5610 - Training Accuracy: 28.11%
Step [300/685] - Training Loss: 3.6085 - Training Accuracy: 28.19%
Step [400/685] - Training Loss: 4.1397 - Training Accuracy: 27.99%
Step [500/685] - Training Loss: 4.0125 - Training Accuracy: 27.87%
Step [600/685] - Training Loss: 3.9097 - Training Accuracy: 27.81%


Epoch 23/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [23/40] - Training Loss: 3.8110, Training Accuracy: 27.71% - Validation Loss: 3.8287, Validation Accuracy: 8.51%


Epoch 24/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.7549 - Training Accuracy: 30.31%
Step [200/685] - Training Loss: 3.5923 - Training Accuracy: 29.82%
Step [300/685] - Training Loss: 3.9532 - Training Accuracy: 29.60%
Step [400/685] - Training Loss: 3.5826 - Training Accuracy: 29.44%
Step [500/685] - Training Loss: 3.5785 - Training Accuracy: 29.39%
Step [600/685] - Training Loss: 3.6985 - Training Accuracy: 29.25%


Epoch 24/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [24/40] - Training Loss: 3.7190, Training Accuracy: 29.24% - Validation Loss: 3.7695, Validation Accuracy: 8.46%


Epoch 25/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.4141 - Training Accuracy: 31.50%
Step [200/685] - Training Loss: 3.9823 - Training Accuracy: 30.86%
Step [300/685] - Training Loss: 3.4738 - Training Accuracy: 30.69%
Step [400/685] - Training Loss: 3.5371 - Training Accuracy: 30.63%
Step [500/685] - Training Loss: 3.6114 - Training Accuracy: 30.49%
Step [600/685] - Training Loss: 3.6810 - Training Accuracy: 30.54%


Epoch 25/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [25/40] - Training Loss: 3.6443, Training Accuracy: 30.49% - Validation Loss: 3.7543, Validation Accuracy: 9.07%


Epoch 26/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.3438 - Training Accuracy: 33.20%
Step [200/685] - Training Loss: 3.3975 - Training Accuracy: 32.71%
Step [300/685] - Training Loss: 3.8462 - Training Accuracy: 32.61%
Step [400/685] - Training Loss: 3.6358 - Training Accuracy: 32.55%
Step [500/685] - Training Loss: 3.6687 - Training Accuracy: 32.44%
Step [600/685] - Training Loss: 3.6830 - Training Accuracy: 32.42%


Epoch 26/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [26/40] - Training Loss: 3.5499, Training Accuracy: 32.30% - Validation Loss: 3.6901, Validation Accuracy: 9.48%


Epoch 27/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.6678 - Training Accuracy: 33.99%
Step [200/685] - Training Loss: 3.6192 - Training Accuracy: 34.11%
Step [300/685] - Training Loss: 3.0815 - Training Accuracy: 34.02%
Step [400/685] - Training Loss: 3.2146 - Training Accuracy: 34.06%
Step [500/685] - Training Loss: 3.4251 - Training Accuracy: 33.94%
Step [600/685] - Training Loss: 3.6897 - Training Accuracy: 33.95%


Epoch 27/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [27/40] - Training Loss: 3.4767, Training Accuracy: 34.01% - Validation Loss: 3.2724, Validation Accuracy: 10.23%


Epoch 28/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.5306 - Training Accuracy: 35.29%
Step [200/685] - Training Loss: 3.4618 - Training Accuracy: 35.80%
Step [300/685] - Training Loss: 3.5080 - Training Accuracy: 35.87%
Step [400/685] - Training Loss: 3.6245 - Training Accuracy: 35.92%
Step [500/685] - Training Loss: 3.2744 - Training Accuracy: 35.86%
Step [600/685] - Training Loss: 3.4511 - Training Accuracy: 36.00%


Epoch 28/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [28/40] - Training Loss: 3.3977, Training Accuracy: 35.82% - Validation Loss: 3.2292, Validation Accuracy: 10.35%


Epoch 29/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.2246 - Training Accuracy: 37.83%
Step [200/685] - Training Loss: 3.3784 - Training Accuracy: 37.42%
Step [300/685] - Training Loss: 3.5478 - Training Accuracy: 37.47%
Step [400/685] - Training Loss: 3.4478 - Training Accuracy: 37.44%
Step [500/685] - Training Loss: 3.2604 - Training Accuracy: 37.38%
Step [600/685] - Training Loss: 3.4485 - Training Accuracy: 37.29%


Epoch 29/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [29/40] - Training Loss: 3.3349, Training Accuracy: 37.34% - Validation Loss: 3.4021, Validation Accuracy: 10.81%


Epoch 30/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.0977 - Training Accuracy: 39.56%
Step [200/685] - Training Loss: 3.5580 - Training Accuracy: 39.09%
Step [300/685] - Training Loss: 3.4419 - Training Accuracy: 38.96%
Step [400/685] - Training Loss: 3.2744 - Training Accuracy: 39.04%
Step [500/685] - Training Loss: 3.2537 - Training Accuracy: 39.00%
Step [600/685] - Training Loss: 3.5476 - Training Accuracy: 38.94%


Epoch 30/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [30/40] - Training Loss: 3.2759, Training Accuracy: 38.93% - Validation Loss: 3.3148, Validation Accuracy: 11.33%


Epoch 31/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.4107 - Training Accuracy: 39.80%
Step [200/685] - Training Loss: 3.2864 - Training Accuracy: 40.46%
Step [300/685] - Training Loss: 3.3182 - Training Accuracy: 40.15%
Step [400/685] - Training Loss: 3.0827 - Training Accuracy: 40.22%
Step [500/685] - Training Loss: 3.2745 - Training Accuracy: 40.34%
Step [600/685] - Training Loss: 3.3113 - Training Accuracy: 40.28%


Epoch 31/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [31/40] - Training Loss: 3.2210, Training Accuracy: 40.35% - Validation Loss: 3.2296, Validation Accuracy: 11.46%


Epoch 32/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.1993 - Training Accuracy: 41.96%
Step [200/685] - Training Loss: 3.3173 - Training Accuracy: 41.58%
Step [300/685] - Training Loss: 3.2147 - Training Accuracy: 41.85%
Step [400/685] - Training Loss: 3.6356 - Training Accuracy: 41.97%
Step [500/685] - Training Loss: 3.1079 - Training Accuracy: 42.00%
Step [600/685] - Training Loss: 3.3706 - Training Accuracy: 42.10%


Epoch 32/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [32/40] - Training Loss: 3.1701, Training Accuracy: 42.11% - Validation Loss: 3.3302, Validation Accuracy: 12.02%


Epoch 33/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.1944 - Training Accuracy: 43.69%
Step [200/685] - Training Loss: 3.1110 - Training Accuracy: 43.65%
Step [300/685] - Training Loss: 3.0720 - Training Accuracy: 43.49%
Step [400/685] - Training Loss: 3.0428 - Training Accuracy: 43.36%
Step [500/685] - Training Loss: 3.5982 - Training Accuracy: 43.38%
Step [600/685] - Training Loss: 3.1281 - Training Accuracy: 43.54%


Epoch 33/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [33/40] - Training Loss: 3.1261, Training Accuracy: 43.58% - Validation Loss: 3.2640, Validation Accuracy: 12.06%


Epoch 34/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.0170 - Training Accuracy: 44.86%
Step [200/685] - Training Loss: 3.2428 - Training Accuracy: 45.38%
Step [300/685] - Training Loss: 3.0078 - Training Accuracy: 45.15%
Step [400/685] - Training Loss: 3.1076 - Training Accuracy: 45.08%
Step [500/685] - Training Loss: 3.0418 - Training Accuracy: 45.18%
Step [600/685] - Training Loss: 3.0315 - Training Accuracy: 45.16%


Epoch 34/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [34/40] - Training Loss: 3.0894, Training Accuracy: 45.09% - Validation Loss: 3.2809, Validation Accuracy: 12.41%


Epoch 35/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.0126 - Training Accuracy: 46.49%
Step [200/685] - Training Loss: 3.1835 - Training Accuracy: 46.23%
Step [300/685] - Training Loss: 2.8957 - Training Accuracy: 46.17%
Step [400/685] - Training Loss: 2.8843 - Training Accuracy: 46.05%
Step [500/685] - Training Loss: 2.9924 - Training Accuracy: 46.16%
Step [600/685] - Training Loss: 3.0543 - Training Accuracy: 46.19%


Epoch 35/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [35/40] - Training Loss: 3.0572, Training Accuracy: 46.27% - Validation Loss: 3.0989, Validation Accuracy: 12.82%


Epoch 36/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.2225 - Training Accuracy: 46.96%
Step [200/685] - Training Loss: 2.9372 - Training Accuracy: 46.83%
Step [300/685] - Training Loss: 2.9709 - Training Accuracy: 46.93%
Step [400/685] - Training Loss: 3.0636 - Training Accuracy: 47.07%
Step [500/685] - Training Loss: 2.9470 - Training Accuracy: 47.13%
Step [600/685] - Training Loss: 3.1220 - Training Accuracy: 47.27%


Epoch 36/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [36/40] - Training Loss: 3.0306, Training Accuracy: 47.28% - Validation Loss: 2.9946, Validation Accuracy: 13.12%


Epoch 37/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 2.8070 - Training Accuracy: 48.45%
Step [200/685] - Training Loss: 2.9360 - Training Accuracy: 48.04%
Step [300/685] - Training Loss: 2.9115 - Training Accuracy: 48.23%
Step [400/685] - Training Loss: 2.9374 - Training Accuracy: 48.19%
Step [500/685] - Training Loss: 2.9929 - Training Accuracy: 48.43%
Step [600/685] - Training Loss: 2.8517 - Training Accuracy: 48.39%


Epoch 37/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [37/40] - Training Loss: 3.0079, Training Accuracy: 48.34% - Validation Loss: 2.7943, Validation Accuracy: 13.20%


Epoch 38/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 3.0811 - Training Accuracy: 49.27%
Step [200/685] - Training Loss: 3.1923 - Training Accuracy: 48.98%
Step [300/685] - Training Loss: 3.0337 - Training Accuracy: 48.72%
Step [400/685] - Training Loss: 3.1966 - Training Accuracy: 48.68%
Step [500/685] - Training Loss: 2.8511 - Training Accuracy: 48.89%
Step [600/685] - Training Loss: 2.6366 - Training Accuracy: 48.91%


Epoch 38/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [38/40] - Training Loss: 2.9918, Training Accuracy: 48.91% - Validation Loss: 2.9813, Validation Accuracy: 13.37%


Epoch 39/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 2.8320 - Training Accuracy: 49.54%
Step [200/685] - Training Loss: 3.4554 - Training Accuracy: 48.98%
Step [300/685] - Training Loss: 2.9493 - Training Accuracy: 48.85%
Step [400/685] - Training Loss: 3.3073 - Training Accuracy: 49.24%
Step [500/685] - Training Loss: 3.2653 - Training Accuracy: 49.37%
Step [600/685] - Training Loss: 2.9312 - Training Accuracy: 49.40%


Epoch 39/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [39/40] - Training Loss: 2.9799, Training Accuracy: 49.49% - Validation Loss: 2.7177, Validation Accuracy: 13.49%


Epoch 40/40 - Training:   0%|          | 0/685 [00:00<?, ?it/s]

Step [100/685] - Training Loss: 2.8940 - Training Accuracy: 49.49%
Step [200/685] - Training Loss: 3.2272 - Training Accuracy: 49.66%
Step [300/685] - Training Loss: 3.0455 - Training Accuracy: 49.49%
Step [400/685] - Training Loss: 2.7973 - Training Accuracy: 49.52%
Step [500/685] - Training Loss: 3.0205 - Training Accuracy: 49.63%
Step [600/685] - Training Loss: 2.9294 - Training Accuracy: 49.72%


Epoch 40/40 - Validation:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch [40/40] - Training Loss: 2.9714, Training Accuracy: 49.70% - Validation Loss: 3.1319, Validation Accuracy: 13.59%
