In [None]:
import gc
import os
import random
from collections import Counter
from datetime import datetime
from pathlib import Path
import h5py
import numpy as np
import pandas as pd
from imblearn.over_sampling import RandomOverSampler
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader, TensorDataset, random_split
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models, transforms
from transformers import ViTConfig, ViTFeatureExtractor, ViTModel
from tqdm import tqdm

Notes:

In [None]:
# Tabular inputs; neither path is referenced again in the visible part of the notebook.
structured_path = Path("all_structured_data.csv")
unstructured_path = Path("all_unstructured_data.csv")

# HDF5 archive that the loading and summary cells open.
hdf5_path = Path("processed_images_readmission.h5")

# Checkpoint handed to load_pretrained_model to initialise the hybrid encoder.
pretrained_model_path = Path("./chexpert_models/hybrid_encoder_chexpert_best.pt")

In [None]:
def load_processed_images_as_list(hdf5_path):
    """Read every complete sample of the HDF5 file into three parallel lists.

    Each top-level group is treated as one sample. Groups that lack either the
    ``"images"`` or the ``"icu_readmission_30d"`` dataset are skipped.

    Args:
        hdf5_path: Path of the HDF5 file to open read-only.

    Returns:
        A tuple ``(processed_images, dates_list, labels_list)`` where, for the
        i-th kept group:
          * ``processed_images[i]`` is the full ``"images"`` dataset as an array
            (the original author's note gives its shape as (3, 224, 224, 3) --
            not verified here),
          * ``dates_list[i]`` holds three strings taken from ``studydate_0`` ..
            ``studydate_2`` (``"Unknown"`` for any slot that is absent),
          * ``labels_list[i]`` is the label converted to ``int``.
    """
    required = ("images", "icu_readmission_30d")
    samples = []

    with h5py.File(hdf5_path, 'r') as hf:
        for subject_id in tqdm(list(hf.keys()), desc="Loading HDF5 Samples"):
            grp = hf[subject_id]
            if any(name not in grp for name in required):
                continue

            image_stack = grp["images"][:]

            study_dates = []
            for slot in range(3):
                date_key = f"studydate_{slot}"
                # Stored values are decoded from UTF-8 bytes; missing slots get a placeholder.
                study_dates.append(grp[date_key][()].decode("utf-8") if date_key in grp else "Unknown")

            samples.append((image_stack, study_dates, int(grp["icu_readmission_30d"][()])))

    processed_images = [sample[0] for sample in samples]
    dates_list = [sample[1] for sample in samples]
    labels_list = [sample[2] for sample in samples]

    print(f"Total samples: {len(processed_images)}")
    return processed_images, dates_list, labels_list


In [None]:
processed_images, dates_list, labels_list = load_processed_images_as_list(hdf5_path)

# Peek at the first sample, one item per line: image array shape, study dates, label.
print(processed_images[0].shape, dates_list[0], labels_list[0], sep="\n")

In [None]:
def summarize_final_feature_h5(hdf5_path, label_key="icu_readmission_30d"):
    """Print the number of groups in an HDF5 file and their label distribution.

    The label is read from the same dataset that ``load_processed_images_as_list``
    uses (``"icu_readmission_30d"``) by default. The previous hard-coded key,
    ``"mortality_icu"``, is not the label the rest of this notebook reads; pass
    ``label_key="mortality_icu"`` to get the old behaviour.

    Args:
        hdf5_path: Path of the HDF5 file to open read-only.
        label_key: Name of the per-group scalar dataset holding the label.

    Prints:
        The total number of top-level groups, then the counts of label 0,
        label 1 and "-1". The -1 bucket collects groups without ``label_key``
        and groups whose label is neither 0 nor 1.
    """
    label_counts = {0: 0, 1: 0, -1: 0}

    with h5py.File(hdf5_path, "r") as hf:
        subject_ids = list(hf.keys())
        print(f"Total samples: {len(subject_ids)}")

        for sid in subject_ids:
            grp = hf[sid]
            label = int(grp[label_key][()]) if label_key in grp else -1
            # Anything that is not a clean 0/1 label is reported as -1.
            label_counts[label if label in (0, 1) else -1] += 1

    print("Label Distribution：")
    print(f" - Label 0 : {label_counts[0]}")
    print(f" - Label 1 : {label_counts[1]}")
    print(f" - Label -1 : {label_counts[-1]}")


In [None]:
# Print the sample count and the label distribution of the HDF5 file at `hdf5_path`.
summarize_final_feature_h5(hdf5_path)

In [None]:
class Hybrid3DPatchEmbedding(nn.Module):
    """Turn a sequence of images into per-time-step patch tokens.

    A ``Conv3d`` whose kernel and stride are ``(1, patch_size, patch_size)``
    projects every non-overlapping spatial patch of every time step to an
    ``embed_dim`` vector; the temporal axis is never mixed. A learned positional
    embedding (one vector per spatial patch, shared by all time steps) is added.

    Args:
        time_steps: Stored on the module; not used by the computation itself.
        img_size: Side length used to size the positional embedding.
        patch_size: Side length of one square patch.
        in_channels: Channels of the input images.
        embed_dim: Size of each output token.
    """

    def __init__(self, time_steps=3, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.time_steps = time_steps
        self.patch_size = patch_size

        # Kernel == stride, with temporal extent 1: one token per patch per frame.
        frame_patch = (1, patch_size, patch_size)
        self.projection = nn.Conv3d(in_channels, embed_dim, kernel_size=frame_patch, stride=frame_patch)

        patches_per_side = img_size // patch_size
        self.pos_embedding = nn.Parameter(torch.randn(1, patches_per_side * patches_per_side, embed_dim))
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

    def forward(self, x):  # x: (B, T, C, H, W)
        """Return patch tokens of shape (B, T, N, D) for a 5-D input (B, T, C, H, W)."""
        batch_size, _steps, _channels, _height, _width = x.shape

        # Conv3d expects channels before the temporal axis: (B, C, T, H, W).
        projected = self.projection(x.transpose(1, 2))  # (B, D, T, H', W')
        _, dim, steps_out, rows, cols = projected.shape

        # Channels last, then merge the spatial grid row-major into one patch axis.
        tokens = projected.permute(0, 2, 3, 4, 1).reshape(batch_size, steps_out, rows * cols, dim)

        # The extra axis lets the (1, N, D) embedding broadcast over batch and time.
        return tokens + self.pos_embedding[:, None]

class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder block: self-attention and an MLP, each with a residual.

    Args:
        embed_dim: Token dimensionality.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward network.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_dim=3072):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.msa = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim)
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, tokens, embed_dim); the shape is preserved."""
        # Normalise once and reuse the result as query, key and value. The
        # original evaluated self.ln1(x) three times, which is mathematically
        # the same but triples the LayerNorm work.
        normed = self.ln1(x)
        attn_out, _ = self.msa(normed, normed, normed)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x

class ResNetFeatureExtractor(nn.Module):
    """Per-time-step CNN tokens from a ResNet-50 trunk.

    The trunk is ``models.resnet50(pretrained=True)`` with its last two child
    modules removed, so it outputs a spatial feature map. A 1x1 convolution maps
    the 2048 trunk channels to ``embed_dim``.

    Args:
        embed_dim: Size of each output token.
    """

    def __init__(self, embed_dim=768):
        super().__init__()
        trunk_layers = list(models.resnet50(pretrained=True).children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)
        self.proj = nn.Conv2d(2048, embed_dim, kernel_size=1)

    def forward(self, x):  # x: (B, T, C, H, W)
        """Return tokens of shape (B, T, N, D), where N is the number of feature-map cells."""
        _batch, _steps, _channels, _height, _width = x.shape  # input must be 5-D

        tokens_per_step = []
        # Each time step goes through the trunk on its own.
        for frame in x.unbind(dim=1):  # frame: (B, C, H, W)
            feature_map = self.proj(self.backbone(frame))  # (B, D, H', W')
            # Flatten the spatial grid and move channels last: (B, N, D).
            tokens_per_step.append(feature_map.flatten(2).transpose(1, 2))

        return torch.stack(tokens_per_step, dim=1)  # (B, T, N, D)

class VisionTransformerHybrid(nn.Module):
    """Hybrid encoder that concatenates ViT-style patch tokens with ResNet tokens.

    For each time step the input is (a) patch-embedded and passed through that
    time step's own stack of ``num_layers`` ``TransformerEncoder`` blocks, and
    (b) passed through ``ResNetFeatureExtractor``. The two token sets are
    concatenated along the token axis.

    Args:
        time_steps: Number of time steps; one Transformer stack is built per step.
        img_size, patch_size, in_channels, embed_dim: Forwarded to
            ``Hybrid3DPatchEmbedding``.
        num_heads: Attention heads per Transformer block.
        num_layers: Blocks per time-step stack.
        load_pretrained: When true and ``vit_weight_path`` is given, that file is
            loaded into the patch embedding with ``strict=False``.
        vit_weight_path: Path of the state dict used by ``load_pretrained``.
    """

    def __init__(self, time_steps=3, img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_heads=12, num_layers=12, load_pretrained=False, vit_weight_path=None):
        super().__init__()
        self.time_steps = time_steps
        self.patch_embed = Hybrid3DPatchEmbedding(time_steps, img_size, patch_size, in_channels, embed_dim)

        # One independent encoder stack per time step (no weight sharing).
        per_step_stacks = []
        for _ in range(time_steps):
            blocks = [TransformerEncoder(embed_dim, num_heads) for _ in range(num_layers)]
            per_step_stacks.append(nn.Sequential(*blocks))
        self.transformers = nn.ModuleList(per_step_stacks)

        self.resnet = ResNetFeatureExtractor(embed_dim)

        if load_pretrained and vit_weight_path is not None:
            weights = torch.load(vit_weight_path, map_location='cpu')
            # strict=False: keys that do not belong to the patch embedding are ignored.
            self.patch_embed.load_state_dict(weights, strict=False)
            print("✅ Patch embedding loaded")

    def forward(self, x):  # x: (B, T, C, H, W)
        """Return fused tokens of shape (B, T, N_vit + N_resnet, D)."""
        patch_tokens = self.patch_embed(x)  # (B, T, N, D)
        _batch, steps, _patches, _dim = patch_tokens.shape

        # Time step t is encoded by its own stack, self.transformers[t].
        encoded = torch.stack(
            [self.transformers[t](patch_tokens[:, t]) for t in range(steps)],
            dim=1,
        )  # (B, T, N, D)

        cnn_tokens = self.resnet(x)  # (B, T, N_resnet, D)

        # With the defaults the original author notes 245 tokens per time step.
        return torch.cat([encoded, cnn_tokens], dim=2)

def extract_vit_features_with_temporal_structure(images, vit_model, batch_size=32, device='cuda'):
    """Run ``vit_model`` over ``images`` in mini-batches without tracking gradients.

    Args:
        images: Sequence of equally shaped arrays laid out (T, H, W, C); slices
            of it are stacked with ``np.stack``.
        vit_model: Callable module applied to a float32 (B, T, C, H, W) tensor.
        batch_size: Number of samples per forward pass.
        device: Device the input batches are moved to.

    Returns:
        The per-batch outputs, moved to the CPU and concatenated along dim 0.
    """
    collected = []

    with torch.no_grad():
        for start in tqdm(range(0, len(images), batch_size), desc="Extracting Hybrid Features"):
            stacked = np.stack(images[start:start + batch_size])  # (B, T, H, W, C)
            channels_first = stacked.transpose(0, 1, 4, 2, 3)     # (B, T, C, H, W)
            inputs = torch.tensor(channels_first, dtype=torch.float32).to(device)
            collected.append(vit_model(inputs).cpu())

    return torch.cat(collected, dim=0)

In [None]:
def clean_isoformat(d):
    """Normalise the fractional seconds of an ISO-8601 string to exactly six digits.

    ``datetime.fromisoformat`` accepts at most microsecond precision, and before
    Python 3.11 only 3- or 6-digit fractions. The fraction is therefore truncated
    or zero-padded to six digits. Anything that follows the digits (for example a
    ``+00:00`` offset) is kept.

    Args:
        d: Timestamp string.

    Returns:
        ``d`` unchanged when it has no ``.`` or when the first ``.`` is not
        followed by a digit; otherwise the string with a six-digit fraction.
    """
    head, dot, tail = d.partition('.')
    if not dot:
        return d

    # Split the leading run of ASCII digits (the fraction) from whatever follows it.
    suffix = tail.lstrip("0123456789")
    digits = tail[:len(tail) - len(suffix)]
    if not digits:
        return d

    fraction = digits[:6].ljust(6, "0")
    return f"{head}.{fraction}{suffix}"

class TimeBiasLayer(nn.Module):
    """Learnable decay that maps a time distance to a weight in (0, 1).

    Computes ``1 / (1 + exp(a * R - c))`` with scalar parameters ``a``
    (initialised to 1.0) and ``c`` (initialised to 0.0).
    """

    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(1.0))
        self.c = nn.Parameter(torch.tensor(0.0))

    def forward(self, R):  # R: (B, T, T)
        """Return the element-wise decay weights; the output has the shape of ``R``."""
        # 1 / (1 + exp(z)) == sigmoid(-z). The explicit form overflows to inf for
        # large z (this notebook builds R from timestamp differences, so values
        # can be very large): the forward value is still 0, but the backward
        # pass yields NaN. torch.sigmoid stays finite in both directions.
        return torch.sigmoid(self.c - self.a * R)

class DistanceAwareAttention(nn.Module):
    """Multi-head attention over all (time step, token) pairs, scaled by a time decay.

    The T*N tokens of a sample attend to each other. Scaled dot-product scores
    are passed through ReLU, multiplied by the ``TimeBiasLayer`` weight of the
    (query time step, key time step) pair, and only then normalised by softmax.

    Args:
        embed_dim: Token dimensionality.
        num_heads: Number of heads; each head has ``embed_dim // num_heads`` channels.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.time_bias = TimeBiasLayer()

    def _to_heads(self, projected):
        """Reshape (B, T, N, D) into (B, heads, T*N, head_dim)."""
        batch, steps, tokens, _ = projected.shape
        per_head = projected.view(batch, steps, tokens, self.num_heads, self.head_dim)
        return per_head.permute(0, 3, 1, 2, 4).reshape(batch, self.num_heads, steps * tokens, self.head_dim)

    def forward(self, x, R):  # x: (B, T, N, D), R: (B, T, T)
        """Return attended tokens with the same shape as ``x``."""
        B, T, N, D = x.shape

        queries = self._to_heads(self.q_proj(x))
        keys = self._to_heads(self.k_proj(x))
        values = self._to_heads(self.v_proj(x))

        # Scaled dot-product scores, clipped at zero before the time weighting.
        scores = torch.relu(queries @ keys.transpose(-2, -1) / (self.head_dim ** 0.5))

        # Expand the (B, T, T) decay to token resolution: every token of time
        # step i shares the weight towards every token of time step j.
        decay = self.time_bias(R)[:, None].expand(B, self.num_heads, T, T)
        decay = decay.repeat_interleave(N, dim=2).repeat_interleave(N, dim=3)  # (B, H, T*N, T*N)

        attn_probs = torch.softmax(scores * decay, dim=-1)

        # Undo the head split: (B, H, T*N, hd) -> (B, T, N, D).
        context = (attn_probs @ values).view(B, self.num_heads, T, N, self.head_dim)
        context = context.permute(0, 2, 3, 1, 4).contiguous().view(B, T, N, D)

        return self.out_proj(context)

class DistanceAwareTemporalTransformer(nn.Module):
    """Stack of pre-norm blocks built around ``DistanceAwareAttention``.

    Every layer applies LayerNorm -> distance-aware attention -> residual, then
    LayerNorm -> 4x-wide ReLU MLP -> residual. A final LayerNorm and a linear
    layer follow the stack.

    Args:
        embed_dim: Token dimensionality.
        num_heads: Heads per attention layer.
        num_layers: Number of blocks.
    """

    def __init__(self, embed_dim=768, num_heads=8, num_layers=4):
        super().__init__()

        def make_block():
            # Index order (norm, attention, norm, MLP) is what forward() unpacks.
            return nn.ModuleList([
                nn.LayerNorm(embed_dim),
                DistanceAwareAttention(embed_dim, num_heads),
                nn.LayerNorm(embed_dim),
                nn.Sequential(
                    nn.Linear(embed_dim, embed_dim * 4),
                    nn.ReLU(),
                    nn.Linear(embed_dim * 4, embed_dim),
                ),
            ])

        self.layers = nn.ModuleList([make_block() for _ in range(num_layers)])
        self.final_norm = nn.LayerNorm(embed_dim)
        self.output_layer = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, R):  # x: (B, T, N, D), R: (B, T, T)
        """Return refined tokens with the same shape as ``x``."""
        for attn_norm, attention, mlp_norm, feed_forward in self.layers:
            x = x + attention(attn_norm(x), R)
            x = x + feed_forward(mlp_norm(x))
        return self.output_layer(self.final_norm(x))  # (B, T, N, D)

def build_time_distance_matrix(study_dates_batch):
    """Build one (T, T) matrix of time gaps per sample.

    For each sample the date strings are parsed with ``datetime.fromisoformat``
    (after ``clean_isoformat``) and converted to POSIX timestamps. Column ``j``
    holds the absolute gap, in seconds, between study ``j`` and the last study
    of the list.

    NOTE(review): all T rows of a sample are identical by construction. If
    pairwise gaps |t_i - t_j| were intended, this needs changing.

    Args:
        study_dates_batch: Iterable of per-sample lists of date strings.

    Returns:
        A float32 tensor; (B, T, T) when every sample has T dates.
    """
    matrices = []
    for dates in study_dates_batch:
        stamps = [datetime.fromisoformat(clean_isoformat(d)).timestamp() for d in dates]
        latest = stamps[-1]
        gaps_to_latest = [abs(latest - stamp) for stamp in stamps]
        # The same row is repeated once per study.
        matrices.append([list(gaps_to_latest) for _ in stamps])
    return torch.tensor(matrices, dtype=torch.float32)

def run_temporal_encoder_example(fused_features, study_dates=None, device=None):
    """Run a freshly initialised temporal transformer over fused features, one sample at a time.

    The model is created inside this function with random weights and put in
    eval mode, so the outputs only demonstrate shapes and data flow.

    Args:
        fused_features: Indexable collection of per-sample tensors (T, N, D).
        study_dates: Per-sample lists of date strings aligned with
            ``fused_features``. Defaults to the notebook-level ``dates_list``,
            which is what the original implementation always used.
        device: Device to run on. Defaults to CUDA when available, otherwise
            CPU (the original hard-coded "cuda" and failed on CPU-only machines).

    Returns:
        A list with one CPU tensor (T, N, D) per sample.

    Raises:
        ValueError: If there are fewer date lists than feature tensors.
    """
    if study_dates is None:
        study_dates = dates_list  # notebook-level global
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if len(study_dates) < len(fused_features):
        raise ValueError(
            f"Got {len(fused_features)} feature tensors but only {len(study_dates)} date lists."
        )

    model = DistanceAwareTemporalTransformer().to(device).eval()
    all_outputs = []

    with torch.no_grad():
        for i in range(len(fused_features)):
            x = fused_features[i].unsqueeze(0).to(device)                 # (1, T, N, D)
            R = build_time_distance_matrix([study_dates[i]]).to(device)   # (1, T, T)
            all_outputs.append(model(x, R).squeeze(0).cpu())

    return all_outputs

class ReadmissionPredictor(nn.Module):
    """Two-layer MLP head that outputs one logit per sample.

    Args:
        embed_dim: Size of the input feature vector.
    """

    def __init__(self, embed_dim=768):
        super().__init__()
        hidden_dim = 256
        head_layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map features (B, embed_dim) to raw logits (B,); no sigmoid is applied."""
        logits = self.classifier(x)  # (B, 1)
        return logits.squeeze(1)

In [None]:
def map_pretrained_keys(pretrained_dict, model_dict):
    """Keep the pretrained entries whose key also exists in the model.

    A warning is printed for every pretrained key that has no identically named
    counterpart in ``model_dict``.

    Args:
        pretrained_dict: Mapping of parameter name to value from a checkpoint.
        model_dict: Mapping (or other container) of the model's parameter names.

    Returns:
        A tuple ``(matched, unmatched, count)``: the filtered dict, the list of
        pretrained keys without a match, and the number of matched keys.
    """
    matched = {}
    unmatched = []

    for k, value in pretrained_dict.items():
        if k in model_dict:
            matched[k] = value
        else:
            print(f"Warning: Key '{k}' from the pretrained model does not have a direct match in the model.")
            unmatched.append(k)

    return matched, unmatched, len(matched)

def map_pretrained_to_current(pretrained_keys, current_keys):
    """Match checkpoint keys to model keys, ignoring every ``'backbone.'`` substring.

    Two keys match when they are equal after all occurrences of ``'backbone.'``
    have been removed from both. When several model keys normalise to the same
    string, the first one in ``current_keys`` wins, as in the original nested
    scan. The lookup is now a dictionary access instead of an O(n*m) comparison.

    Args:
        pretrained_keys: Iterable of key names from the checkpoint.
        current_keys: Iterable of key names from the model.

    Returns:
        Dict mapping each matched pretrained key to its model key, in the order
        of ``pretrained_keys``. Unmatched pretrained keys are omitted.
    """
    def _normalize(key):
        return key.replace('backbone.', '')

    first_by_normalized = {}
    for c_key in current_keys:
        # setdefault keeps the earliest model key for each normalised name.
        first_by_normalized.setdefault(_normalize(c_key), c_key)

    mapping = {}
    for p_key in pretrained_keys:
        normalized = _normalize(p_key)
        if normalized in first_by_normalized:
            mapping[p_key] = first_by_normalized[normalized]

    return mapping

# Hand-written sample of checkpoint key names: every entry carries a leading
# 'backbone.' prefix. Used only for the mapping demonstration below.
pretrained_keys = [
    'backbone.patch_embed.pos_embedding', 'backbone.patch_embed.projection.weight',
    'backbone.patch_embed.projection.bias', 'backbone.transformers.0.0.ln1.weight', 'backbone.transformers.0.0.ln1.bias',
    'backbone.transformers.0.0.msa.in_proj_weight', 'backbone.transformers.0.0.msa.in_proj_bias',
    'backbone.transformers.0.0.msa.out_proj.weight', 'backbone.transformers.0.0.msa.out_proj.bias',
    'backbone.transformers.0.0.ln2.weight', 'backbone.transformers.0.0.ln2.bias', 'backbone.transformers.0.0.mlp.0.weight',
    'backbone.transformers.0.0.mlp.0.bias', 'backbone.transformers.0.0.mlp.2.weight', 'backbone.transformers.0.0.mlp.2.bias',
    'backbone.transformers.0.1.ln1.weight', 'backbone.transformers.0.1.ln1.bias', 'backbone.transformers.0.1.msa.in_proj_weight',
    'backbone.transformers.0.1.msa.in_proj_bias', 'backbone.transformers.0.1.msa.out_proj.weight',
    'backbone.transformers.0.1.msa.out_proj.bias', 'backbone.transformers.0.1.ln2.weight', 'backbone.transformers.0.1.ln2.bias',
    'backbone.transformers.0.1.mlp.0.weight', 'backbone.transformers.0.1.mlp.0.bias', 'backbone.transformers.0.1.mlp.2.weight',
    'backbone.transformers.0.1.mlp.2.bias', 'backbone.resnet.backbone.0.weight', 'backbone.resnet.backbone.1.weight',
    'backbone.resnet.backbone.1.bias', 'backbone.resnet.backbone.1.running_mean', 'backbone.resnet.backbone.1.running_var'
]

# The same names without the leading 'backbone.' prefix, i.e. in the form the
# sample treats as the current model's keys.
current_keys = [
    'patch_embed.pos_embedding', 'patch_embed.projection.weight', 'patch_embed.projection.bias',
    'transformers.0.0.ln1.weight', 'transformers.0.0.ln1.bias', 'transformers.0.0.msa.in_proj_weight',
    'transformers.0.0.msa.in_proj_bias', 'transformers.0.0.msa.out_proj.weight', 'transformers.0.0.msa.out_proj.bias',
    'transformers.0.0.ln2.weight', 'transformers.0.0.ln2.bias', 'transformers.0.0.mlp.0.weight', 'transformers.0.0.mlp.0.bias',
    'transformers.0.0.mlp.2.weight', 'transformers.0.0.mlp.2.bias', 'transformers.0.1.ln1.weight', 'transformers.0.1.ln1.bias',
    'transformers.0.1.msa.in_proj_weight', 'transformers.0.1.msa.in_proj_bias', 'transformers.0.1.msa.out_proj.weight',
    'transformers.0.1.msa.out_proj.bias', 'transformers.0.1.ln2.weight', 'transformers.0.1.ln2.bias', 'transformers.0.1.mlp.0.weight',
    'transformers.0.1.mlp.0.bias', 'transformers.0.1.mlp.2.weight', 'transformers.0.1.mlp.2.bias', 'resnet.backbone.0.weight',
    'resnet.backbone.1.weight', 'resnet.backbone.1.bias', 'resnet.backbone.1.running_mean', 'resnet.backbone.1.running_var'
]

# Demonstrate the key matching on the two sample lists.
mapping_result = map_pretrained_to_current(pretrained_keys, current_keys)

# Show each checkpoint key next to the model key it was matched with.
for p_key, c_key in mapping_result.items():
    print(f"Pretrained Key: {p_key} -> Current Key: {c_key}")


def load_pretrained_model(model, pretrained_model_path, device='cuda'):
    """Load matching weights from a checkpoint into ``model`` and switch it to eval mode.

    Checkpoint keys have every ``'module.'`` substring removed and are then
    matched to model keys with ``map_pretrained_to_current``. Only matched
    entries whose shape equals the model's are loaded (``strict=False``).

    Changes from the original implementation:
      * Checkpoint keys without a counterpart in the model are now reported. The
        old "missing" branch could never run, so the success message was printed
        even when nothing matched.
      * Entries whose shape differs from the model's are skipped and reported
        instead of making ``load_state_dict`` raise.
      * ``model.state_dict()`` is built once instead of once per key.

    Args:
        model: Module that receives the weights.
        pretrained_model_path: Path of a file that ``torch.load`` returns a
            flat ``name -> tensor`` mapping for.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The same ``model`` object, in eval mode.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(pretrained_model_path, map_location=device)

    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

    model_state = model.state_dict()
    mapping_result = map_pretrained_to_current(list(state_dict.keys()), list(model_state.keys()))

    new_state_dict = {}
    missing_keys = []
    shape_mismatches = []
    loaded_keys_count = 0

    for p_key, value in state_dict.items():
        c_key = mapping_result.get(p_key)
        if c_key is None:
            missing_keys.append(p_key)
            continue

        source_shape = getattr(value, "shape", None)
        target_shape = model_state[c_key].shape
        if source_shape is None or tuple(source_shape) != tuple(target_shape):
            shape_mismatches.append(p_key)
            continue

        new_state_dict[c_key] = value
        loaded_keys_count += 1

    model.load_state_dict(new_state_dict, strict=False)

    print(f"✅ Pre-trained model loaded from {pretrained_model_path}")
    if missing_keys:
        print(f"❌ Missing keys: {missing_keys}")
    if shape_mismatches:
        print(f"❌ Keys skipped because of a shape mismatch: {shape_mismatches}")
    if not missing_keys and not shape_mismatches:
        print("All keys matched successfully!")
    print(f"🔑 {loaded_keys_count} valid keys were loaded from the pretrained model.")

    model.eval()

    return model

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the hybrid encoder on that device, then load the checkpoint weights into it.
vit_model = VisionTransformerHybrid().to(device)
vit_model = load_pretrained_model(vit_model, pretrained_model_path, device)

In [None]:
# Put the encoder in evaluation mode. As the cell's last expression, the
# returned module is also displayed by the notebook.
vit_model.eval()

In [None]:
# CPU-only pipeline: the cells below run feature extraction and training with device='cpu'.

In [None]:
def _split_indices(labels, val_fraction, test_fraction, seed):
    """Partition ``range(len(labels))`` into disjoint train/val/test index arrays.

    Splits are stratified by label. When scikit-learn rejects the stratified
    split with ``ValueError`` (for instance a class with a single member), the
    split is retried without stratification.
    """
    def _split(idx, holdout_share):
        try:
            return train_test_split(idx, test_size=holdout_share, stratify=labels[idx], random_state=seed)
        except ValueError:
            return train_test_split(idx, test_size=holdout_share, random_state=seed)

    holdout_fraction = val_fraction + test_fraction
    train_idx, holdout_idx = _split(np.arange(len(labels)), holdout_fraction)
    val_idx, test_idx = _split(holdout_idx, test_fraction / holdout_fraction)
    return train_idx, val_idx, test_idx


def _predict_probabilities(classifier, loader, device):
    """Return ``(probabilities, labels)`` as flat lists for every sample in ``loader``."""
    classifier.eval()
    probs, labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in loader:
            batch_probs = classifier(X_batch.to(device)).sigmoid().cpu()
            probs.extend(batch_probs.numpy().tolist())
            labels.extend(y_batch.cpu().numpy().tolist())
    return probs, labels


def _auc_or_nan(labels, probs):
    """ROC AUC, or NaN when ``labels`` holds fewer than two classes (AUC is undefined then)."""
    if len(set(labels)) < 2:
        return float("nan")
    return roc_auc_score(labels, probs)


def train_model_with_temporal_features(
    processed_images, dates_list, labels_list,
    epochs=300, batch_size=32, lr=1e-4, device='cpu',
    save_path='model_checkpoint.pth',
    enable_temporal_grad=False,
    enable_vit_grad=False,
    use_pretrained_vit=True,
    pretrained_model_path='pretrained.pth',
    val_fraction=0.15,
    test_fraction=0.15,
    accumulation_steps=4,
    seed=42,
):
    """Extract frozen hybrid + temporal features and train a readmission classifier on them.

    Pipeline:
      1. ``VisionTransformerHybrid`` (optionally initialised from a checkpoint)
         produces fused tokens for each mini-batch of image stacks.
      2. A newly created ``DistanceAwareTemporalTransformer`` refines the tokens
         of each sample using its study dates; the result is mean-pooled over
         time steps and tokens into one vector per sample.
      3. The vectors are split into train / validation / test subsets. Only the
         training subset is oversampled with ``RandomOverSampler``.
      4. A ``ReadmissionPredictor`` is trained with AdamW, cosine annealing and
         gradient accumulation. A checkpoint is written every 10 epochs and
         whenever the validation AUC improves.
      5. The weights with the best validation AUC are evaluated on the test subset.

    Fixes relative to the original implementation:
      * Validation and test used to run on the very data the classifier was
        trained on, so the reported metrics were training metrics. Evaluation
        now uses held-out subsets; expect lower numbers than before.
      * Gradients of the last, incomplete accumulation group were thrown away.
        With fewer than ``accumulation_steps`` batches per epoch the optimizer
        never stepped at all. The trailing group is now applied.
      * The final test used the last-epoch weights while announcing the best
        model. It now restores the best-validation weights.
      * Features are always computed under ``torch.no_grad()``. Only the
        classifier is optimised, and with either ``enable_*_grad`` flag set the
        old code failed when converting the features to NumPy.
      * Per-sample temporal outputs are pooled immediately instead of keeping
        every fused and temporal token tensor in memory.
      * AUC is reported as NaN instead of raising when a subset holds one class.

    Args:
        processed_images: Sequence of arrays laid out (T, H, W, C), one per sample.
        dates_list: Per-sample lists of ISO date strings, aligned with the images.
        labels_list: Per-sample binary labels.
        epochs: Number of training epochs.
        batch_size: Batch size for feature extraction and for the data loaders.
        lr: AdamW learning rate.
        device: Device used for all models and batches.
        save_path: Directory (created if needed) that receives the checkpoints.
        enable_temporal_grad: Sets train/eval mode and ``requires_grad`` of the
            temporal model; the model is never optimised.
        enable_vit_grad: Same, for the hybrid encoder.
        use_pretrained_vit: Load ``pretrained_model_path`` into the encoder.
        pretrained_model_path: Checkpoint for ``load_pretrained_model``.
        val_fraction: Share of samples held out for validation.
        test_fraction: Share of samples held out for the final test.
        accumulation_steps: Number of batches whose gradients are summed before
            each optimizer step.
        seed: Random state for the data split and the oversampler.

    Returns:
        Dict with the test metrics: ``auc``, ``accuracy``, ``f1``, ``precision``,
        ``recall`` and ``best_val_auc``.

    Raises:
        ValueError: On empty or misaligned inputs, or invalid fractions / steps.
    """
    if not (len(processed_images) == len(dates_list) == len(labels_list)):
        raise ValueError("processed_images, dates_list and labels_list must have the same length.")
    if len(processed_images) == 0:
        raise ValueError("No samples were provided.")
    if val_fraction <= 0 or test_fraction <= 0 or val_fraction + test_fraction >= 1:
        raise ValueError("val_fraction and test_fraction must be positive and sum to less than 1.")
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be at least 1.")

    # ---- Frozen feature extractors -------------------------------------
    vit_model = VisionTransformerHybrid().to(device)
    if use_pretrained_vit:
        vit_model = load_pretrained_model(vit_model, pretrained_model_path, device)
    vit_model.train(mode=enable_vit_grad)
    for param in vit_model.parameters():
        param.requires_grad = enable_vit_grad

    temporal_model = DistanceAwareTemporalTransformer().to(device)
    temporal_model.train(mode=enable_temporal_grad)
    for param in temporal_model.parameters():
        param.requires_grad = enable_temporal_grad

    # ---- Feature extraction: one pooled vector per sample --------------
    pooled = []
    with torch.no_grad():
        for start in tqdm(range(0, len(processed_images), batch_size), desc="Extracting Hybrid Features"):
            batch = np.stack(processed_images[start:start + batch_size])  # (B, T, H, W, C)
            batch = batch.transpose(0, 1, 4, 2, 3)                        # (B, T, C, H, W)
            batch_tensor = torch.tensor(batch, dtype=torch.float32).to(device)
            fused_out = vit_model(batch_tensor)                           # (B, T, N, D)

            # The temporal model runs per sample to bound the size of the
            # (T*N) x (T*N) attention matrices.
            for offset in range(fused_out.size(0)):
                R = build_time_distance_matrix([dates_list[start + offset]]).to(device)
                temporal_output = temporal_model(fused_out[offset:offset + 1], R)
                pooled.append(temporal_output.squeeze(0).mean(dim=(0, 1)).cpu())

            del batch_tensor, fused_out, batch

    X = torch.stack(pooled)
    y = torch.tensor(labels_list).float()
    X_np, y_np = X.numpy(), y.numpy()

    # ---- Held-out splits; oversample the training part only ------------
    train_idx, val_idx, test_idx = _split_indices(y_np, val_fraction, test_fraction, seed)

    ros = RandomOverSampler(sampling_strategy='auto', random_state=seed)
    X_resampled, y_resampled = ros.fit_resample(X_np[train_idx], y_np[train_idx])

    resampled_train_dataset = TensorDataset(
        torch.tensor(X_resampled, dtype=torch.float32),
        torch.tensor(y_resampled, dtype=torch.float32)
    )
    val_set = TensorDataset(X[val_idx], y[val_idx])
    test_set = TensorDataset(X[test_idx], y[test_idx])

    train_loader = DataLoader(resampled_train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    # ---- Classifier training -------------------------------------------
    classifier = ReadmissionPredictor(embed_dim=768).to(device)
    optimizer = torch.optim.AdamW(classifier.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.BCEWithLogitsLoss()

    # save_path is used as a directory, whatever its name looks like.
    os.makedirs(save_path, exist_ok=True)

    def _checkpoint(epoch_number, auc_value):
        return {
            'epoch': epoch_number,
            'classifier_state_dict': classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'auc': auc_value
        }

    best_auc = float("-inf")
    best_state = None
    num_batches = len(train_loader)

    for epoch in range(epochs):
        classifier.train()
        total_loss = 0
        optimizer.zero_grad()

        for step, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            logits = classifier(X_batch)
            loss = criterion(logits, y_batch)
            loss.backward()

            # Step after every full group and for the trailing partial group.
            if (step + 1) % accumulation_steps == 0 or (step + 1) == num_batches:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * X_batch.size(0)

        scheduler.step()
        avg_loss = total_loss / len(train_loader.dataset)

        all_preds, all_labels = _predict_probabilities(classifier, val_loader, device)
        preds_bin = [1 if p > 0.5 else 0 for p in all_preds]
        auc_score = _auc_or_nan(all_labels, all_preds)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, AUC={auc_score:.4f}, F1={f1_score(all_labels, preds_bin):.4f}")

        if (epoch + 1) % 10 == 0:
            torch.save(_checkpoint(epoch + 1, auc_score), os.path.join(save_path, f"model_epoch_{epoch+1}.pth"))

        # A NaN AUC never compares greater, so it cannot become the best score.
        if auc_score > best_auc:
            best_auc = auc_score
            best_state = {name: tensor.detach().clone() for name, tensor in classifier.state_dict().items()}
            torch.save(_checkpoint(epoch + 1, best_auc), os.path.join(save_path, "model_best.pth"))

    print("\nBest model saved to:", save_path)

    # ---- Final test with the best-validation weights -------------------
    if best_state is not None:
        classifier.load_state_dict(best_state)

    all_preds, all_labels = _predict_probabilities(classifier, test_loader, device)
    preds_bin = [1 if p > 0.5 else 0 for p in all_preds]

    metrics = {
        'auc': _auc_or_nan(all_labels, all_preds),
        'accuracy': accuracy_score(all_labels, preds_bin),
        'f1': f1_score(all_labels, preds_bin),
        'precision': precision_score(all_labels, preds_bin),
        'recall': recall_score(all_labels, preds_bin),
        'best_val_auc': best_auc if best_state is not None else float("nan"),
    }

    print("\nTest Results:")
    print(f"AUC       : {metrics['auc']:.4f}")
    print(f"Accuracy  : {metrics['accuracy']:.4f}")
    print(f"F1 Score  : {metrics['f1']:.4f}")
    print(f"Precision : {metrics['precision']:.4f}")
    print(f"Recall    : {metrics['recall']:.4f}")

    return metrics

In [None]:
# Run 1: encoder initialised from `pretrained_model_path`; checkpoints go to a
# timestamped directory so repeated runs do not overwrite each other.
save_path = f"./cxr_models/pretrained_readmission_{datetime.now().strftime('%Y%m%d_%H%M%S')}/"
train_model_with_temporal_features(
    processed_images,
    dates_list,
    labels_list,
    # optimisation settings
    epochs=1000,
    batch_size=32,
    lr=1e-4,
    device='cpu',
    # encoder settings
    use_pretrained_vit=True,
    pretrained_model_path=pretrained_model_path,
    enable_vit_grad=False,
    enable_temporal_grad=False,
    # output
    save_path=save_path,
)

In [None]:
# Run 2: same settings, but the encoder checkpoint is not loaded
# (use_pretrained_vit=False), so `pretrained_model_path` is passed but unused.
save_path = f"./cxr_models/readmission_{datetime.now().strftime('%Y%m%d_%H%M%S')}/"
train_model_with_temporal_features(
    processed_images,
    dates_list,
    labels_list,
    # optimisation settings
    epochs=1000,
    batch_size=32,
    lr=1e-4,
    device='cpu',
    # encoder settings
    use_pretrained_vit=False,
    pretrained_model_path=pretrained_model_path,
    enable_vit_grad=False,
    enable_temporal_grad=False,
    # output
    save_path=save_path,
)