# Libraries used

In [None]:
import warnings
warnings.filterwarnings('ignore')
from torch.utils.data import Dataset, DataLoader, Subset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import (
    Compose,
    ToTensor,
    PILToTensor,
    Normalize,
    CenterCrop,
    RandAugment,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    Resize,
    RandomRotation,
    RandomResizedCrop,
    InterpolationMode
)
from timm.data.transforms import RandomResizedCropAndInterpolation
from tqdm import tqdm
import math
import numpy as np
import sys
import os
import pandas as pd
from mae.models_mae import mae_vit_base_patch16_dec512d8b
from US_data_loading import USImagesDataset
from PIL import Image, ImageDraw, UnidentifiedImageError
import matplotlib.pyplot as plt
import glob
from timm.models.vision_transformer import Block
from sklearn.model_selection import KFold
import random
import csv

# Single seed shared by every random number generator used in this notebook
# (Python's `random`, NumPy, and PyTorch on CPU and CUDA).
seed = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)

# Make the name `np.float` resolve to the builtin float.
# NOTE(review): presumably a compatibility shim for code that still uses the
# removed NumPy alias — confirm which dependency needs it.
setattr(np, "float", float)

# Training the USF-MAE model

In [None]:
def adjust_learning_rate(epoch, sched_config):
    """Return the learning rate for position ``epoch`` of a warmup + cosine schedule.

    Args:
        epoch: Current position in the schedule. It only has to be expressed in
            the same unit as the ``*_epochs`` entries of ``sched_config``.
        sched_config: Mapping with the keys ``'warmup_epochs'``, ``'max_lr'``,
            ``'min_lr'`` and ``'total_epochs'``.

    Returns:
        During warmup (``epoch < warmup_epochs``) a value rising linearly from
        0 towards ``max_lr``; afterwards a half-cycle cosine decay from
        ``max_lr`` down to ``min_lr`` at ``total_epochs``.
    """
    warmup = sched_config['warmup_epochs']

    # Linear warmup phase.
    if epoch < warmup:
        return sched_config['max_lr'] * epoch / warmup

    # Cosine decay phase: the angle sweeps from 0 to pi between the end of
    # warmup and total_epochs.
    floor_lr = sched_config['min_lr']
    amplitude = (sched_config['max_lr'] - floor_lr) * 0.5
    angle = math.pi * (epoch - warmup) / (sched_config['total_epochs'] - warmup)
    return floor_lr + amplitude * (1. + math.cos(angle))

def get_lr(optimizer):
    """Return the current learning rate of the optimizer's first parameter group.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts that
            each hold an ``'lr'`` entry (as ``torch.optim`` optimizers do).

    Returns:
        The ``'lr'`` value of the first parameter group, or ``None`` when the
        optimizer has no parameter groups.
    """
    # Param groups store the live learning rate under 'lr'. The previous lookup
    # of 'max_lr' (a key of the scheduler config, not of the param groups)
    # raised KeyError.
    for param_group in optimizer.param_groups:
        return param_group['lr']
    return None

In [None]:
# Root folder that is searched for training images.
images_path_parent = "../US Datasets for USF-MAE Training"

# Recursively get all .png files in all subfolders.
# glob returns matches in arbitrary (filesystem-dependent) order, so the paths
# are sorted: the seeded K-fold indices computed from this list then refer to
# the same files on every run and machine.
all_image_paths = sorted(
    glob.glob(os.path.join(images_path_parent, '**', '*.png'), recursive=True)
)

print(f"Found {len(all_image_paths)} .png images.")

def create_folds(data, n_splits=5):
    """Build shuffled K-fold splits over ``data``.

    Args:
        data: Sequence to split; only positions (indices) are returned.
        n_splits: Number of folds.

    Returns:
        A list with one entry per fold, each as produced by ``KFold.split``
        (unpacked elsewhere in this file as ``(train_idx, val_idx)``). The
        shuffle uses the module-level ``seed`` as ``random_state``.
    """
    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return [index_pair for index_pair in splitter.split(data)]

# Build the K-fold index splits over the discovered image paths.
# The returned indices are positions in all_image_paths as it exists at this
# point; if that list is changed afterwards (e.g. filtered), the folds must be
# recomputed to stay aligned with it.
folds = create_folds(all_image_paths)


In [None]:
def filter_valid_images(image_paths):
    """Return the paths from ``image_paths`` that PIL can open and verify.

    Every path is opened with ``Image.open`` and checked with ``verify()``.
    A path that raises ``OSError``, ``UnidentifiedImageError`` or
    ``SyntaxError`` is reported on stdout and left out of the result. The
    relative order of the remaining paths is preserved.
    """
    def _passes_verification(path):
        # Report and reject a file that fails PIL's integrity check.
        try:
            with Image.open(path) as img:
                img.verify()  # integrity check only, no pixel decoding
        except (OSError, UnidentifiedImageError, SyntaxError) as e:
            print(f"Skipping corrupted file: {path} — {str(e)}")
            return False
        return True

    return [candidate for candidate in image_paths if _passes_verification(candidate)]

# Drop every file that fails PIL verification before building datasets.
all_image_paths = filter_valid_images(all_image_paths)

# The folds created earlier hold indices into the *unfiltered* path list.
# Once files are removed those indices can be out of range or point at
# different images, so rebuild the splits on the list actually used from here on.
folds = create_folds(all_image_paths)

dataset = USImagesDataset(image_paths=all_image_paths)
print(f"Total usable images after filtering: {len(dataset)}")

In [None]:
%%time

# Cross-validated hyperparameter sweep: for every fold and every
# (base learning rate, weight decay, epoch count) combination, train a fresh
# model from the pretrained weights, evaluate it once on the fold's validation
# split, save its state_dict and append one row to a CSV log.

# Load the pretrained checkpoint and keep its "model" entry (the state dict
# passed to load_state_dict below).
pretrained_weights = torch.load("mae_checkpoint/mae_pretrain_vit_base_full.pth")["model"]

# Output directory for the saved models and the loss log.
save_dir = 'mae_US_saved_models'
os.makedirs(save_dir, exist_ok=True)

# NOTE(review): `results` is never appended to or read in the visible code.
results = []

# Start a fresh log: mode 'w' truncates any existing file before the header row is written.
log_csv_path = os.path.join(save_dir, "loss_log.csv")
with open(log_csv_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["fold", "lr", "weight_decay", "epochs", "final_train_loss", "val_loss"])

for fold_idx, (train_idx, val_idx) in enumerate(folds):

    print(f"Training fold {fold_idx+1}/{len(folds)}")

    # train_idx / val_idx are positions in all_image_paths, so `folds` must have
    # been computed from this same list — verify when changing the filtering step.
    # Augmentation is enabled for the training subset only.
    train_dataset = Subset(USImagesDataset(all_image_paths, do_augmentation=True), train_idx)
    val_dataset = Subset(USImagesDataset(all_image_paths, do_augmentation=False), val_idx)

    # The loaders are built once per fold and reused by every hyperparameter combination.
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

    # Grid search over base learning rate, weight decay and number of epochs.
    for base_learning_rate in [3e-4, 1e-3]:
        for weight_decay in [0.01, 0.05]:
            for num_epochs in [100]:
                # Effective LR = base LR * 128 / 256.
                # NOTE(review): the DataLoaders above use batch_size=64 — confirm
                # that the factor 128 is the intended effective batch size.
                learning_rate = base_learning_rate * 128 / 256

                # Fresh model per configuration, initialised from the pretrained weights.
                model = mae_vit_base_patch16_dec512d8b().cuda()
                model.load_state_dict(pretrained_weights)
                optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

                # Scheduler configuration.
                # Despite the "*_epochs" names these values count optimizer steps:
                # total_epochs = batches per epoch * num_epochs, and the warmup
                # covers the first 10% of those steps.
                total_epochs = len(train_loader) * num_epochs
                warmup_epochs = int(total_epochs * 0.1)
                sched_config = {
                    "max_lr": learning_rate,
                    "min_lr": 1.0e-5,
                    "total_epochs": total_epochs,
                    "warmup_epochs": warmup_epochs,
                }

                # Global step counter (incremented once per batch) that drives the LR schedule.
                epoch_counter = 0
                for epoch in tqdm(range(num_epochs)):
                    model.train()
                    # Reset every epoch, so it holds only the current epoch's batch losses.
                    all_losses = []
                    for images in train_loader:
                        # Set this step's learning rate on every param group before the update.
                        new_learning_rate = adjust_learning_rate(epoch=epoch_counter, sched_config=sched_config)
                        for g in optimizer.param_groups:
                            g["lr"] = new_learning_rate
                        # Forward pass under bfloat16 autocast; the model returns three
                        # values and only the first (the loss) is used.
                        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                            loss, _, _ = model(images.cuda())

                        loss.backward()
                        all_losses.append(loss.item())
                        # Clip the gradient norm to at most 1.0 before the optimizer step.
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                        optimizer.step()
                        optimizer.zero_grad()
                        epoch_counter += 1

                    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {torch.tensor(all_losses).mean().item()}")

                # Validation runs once per configuration, after the last training epoch.
                model.eval()
                val_losses = []
                with torch.no_grad():
                    for images in val_loader:
                        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                            val_loss, _, _ = model(images.cuda())
                        val_losses.append(val_loss.item())

                # Unweighted mean over validation batches.
                mean_val_loss = torch.tensor(val_losses).mean().item()
                print(f"Val Loss: {mean_val_loss}, Hyperparameters: {base_learning_rate}, {weight_decay}, {num_epochs}")

                # Save the trained weights; the file name encodes the fold and the hyperparameters.
                model_filename = f"USF-MAE_fold{fold_idx+1}_lr{base_learning_rate}_wd{weight_decay}_epochs{num_epochs}.pt"
                model_path = os.path.join(save_dir, model_filename)
                torch.save(model.state_dict(), model_path)
                print(f"Model saved at: {model_path}")

                # Log results to CSV (append mode keeps the rows of earlier runs).
                # "final_train_loss" is the mean batch loss of the last epoch only,
                # because all_losses is reset at the start of each epoch.
                # The "lr" column records the base learning rate, not the scaled one.
                with open(log_csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([
                        fold_idx + 1,
                        base_learning_rate,
                        weight_decay,
                        num_epochs,
                        round(torch.tensor(all_losses).mean().item(), 6),
                        round(mean_val_loss, 6)
                    ])
