In [None]:
import datetime
import os

import torch
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import transforms

from classes.OCTADataset import OCTADataset
from classes.Config import Config
from classes.Folds import Folds
from classes.train import train
from classes.test import test
from logger.Logger import Logger, TrainLogger, ValidationLogger, TestLogger
from utils.initialize_model import initialize_model

# Run configuration: the main config file names the model config to load.
CONFIG = Config("configs/config.py")
MODEL_CONFIG = Config(CONFIG.model_config)

# The configured device is only honoured when CUDA is available;
# otherwise everything runs on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device(CONFIG.device)
else:
    DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")  

# Preprocessing pipeline applied to every sample, in this order:
# tensor conversion -> resize -> normalisation (all parameters from CONFIG).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(
            size=CONFIG.transform.resize.size,
            antialias=CONFIG.transform.resize.antialias,
        ),
        transforms.Normalize(
            mean=CONFIG.transform.normalize.mean,
            std=CONFIG.transform.normalize.std,
        ),
    ]
)

# Dataset covering all samples; the per-fold subsets are carved out of it later.
_octa_cfg = CONFIG.OCTADataset
full_dataset = OCTADataset(
    images_directory_root_path=_octa_cfg.images_directory_root_path,
    labels_directory_root_path=_octa_cfg.labels_directory_root_path,
    transform=transform,
    num_classes=_octa_cfg.num_classes,
)

# Fold bookkeeping for the full dataset.
# NOTE(review): presumably the directory below holds the stored fold indices —
# confirm against the Folds class.
folds = Folds(
    dataset=full_dataset,
    folds_folder_directory_path=CONFIG.k_fold_train.fold_indices_directory_path,
    class_number=_octa_cfg.num_classes,
)

# Root logger; it later hands out one directory per fold.
LOGGER = Logger()
# Run timestamp (day-month-year_hour-minute-second), used as the folder name
# for saved model weights. The file imports the `datetime` *module*, so `now`
# has to be reached through the `datetime.datetime` class; calling
# `datetime.now()` on the module raises AttributeError.
timestamp = datetime.datetime.now().strftime("%d-%m-%Y_%H-%M-%S")

# K-Fold training and testing loop.
# Each iteration handles one fold end to end: build the train/val/test subsets
# and loaders, create fold-specific loggers, initialise a fresh model with its
# loss and optimizer, train, evaluate on the test split, and optionally save
# the resulting weights.
for i in range(folds.num_folds()):
    fold_indices = folds.get_fold(i)
    
    # Subset datasets for the current fold
    train_fold_dataset = torch.utils.data.Subset(full_dataset, fold_indices["train"])
    val_fold_dataset = torch.utils.data.Subset(full_dataset, fold_indices["val"])
    test_fold_dataset = torch.utils.data.Subset(full_dataset, fold_indices["test"])
    
    # DataLoader setup: only the training loader shuffles its samples.
    train_fold_loader = DataLoader(train_fold_dataset, batch_size=CONFIG.k_fold_train.batch_size, shuffle=True)
    val_fold_loader = DataLoader(val_fold_dataset, batch_size=CONFIG.k_fold_train.batch_size, shuffle=False)
    test_fold_loader = DataLoader(test_fold_dataset, batch_size=CONFIG.k_fold_train.batch_size, shuffle=False)
    
    # Create fold-specific directory
    fold_dir = LOGGER.create_fold_dir(fold_number=i)
    
    # Initialize loggers for training, validation, and testing
    train_log = TrainLogger(fold_dir)
    validation_log = ValidationLogger(fold_dir)
    test_log = TestLogger(fold_dir)
    
    # Model, criterion, optimizer, and scheduler setup.
    # A new model is created for every fold, so folds do not share weights.
    model = initialize_model(MODEL_CONFIG)
    model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), 
                           lr=CONFIG.train.optimizer.lr)
    # NOTE(review): this scheduler is constructed but never handed to train()
    # (the fourth argument below is None). If that parameter is the LR
    # scheduler, the StepLR schedule is never applied — confirm the intent.
    scheduler = lr_scheduler.StepLR(optimizer, 
                                    step_size=CONFIG.train.scheduler.step_size, 
                                    gamma=CONFIG.train.scheduler.gamma)
    num_epochs = CONFIG.k_fold_train.num_epochs
    # Debug output: prints the loader object's repr.
    print(train_fold_loader)
    # Training process
    train(model, 
          train_fold_loader, 
          val_fold_loader, 
          None, 
          optimizer, 
          criterion, 
          num_epochs, 
          DEVICE, 
          train_log, 
          validation_log)
    
    # Evaluate the trained model on this fold's held-out test split.
    test(model,
         test_fold_loader,
         DEVICE,
         test_log)
    
    if CONFIG.save_model:
        # torch.save does not create missing parent directories, and nothing
        # above creates the timestamped folder, so make sure it exists first.
        save_dir = f"{CONFIG.save_model_path}/{timestamp}"
        os.makedirs(save_dir, exist_ok=True)
        weights_name = (
            f"{MODEL_CONFIG.model.model!s}"
            f"_{CONFIG.OCTADataset.num_classes}_classes"
            f"_{CONFIG.k_fold_train.num_epochs}_epochs"
            f"_fold_{i}.pth"
        )
        torch.save(model.state_dict(), f"{save_dir}/{weights_name}")
