# Load packages

In [None]:
import os
import torch
from utils.visualize       import visualize_images_from_classes
from utils.datasets        import returnDataLoader, returnFeatureLoader
from utils.model_execution import runFMA, runDLA, runKBA
from utils.analysis        import print_analysis

In [None]:
def _select_device():
    """Return the best available torch device: CUDA, then XPU, then MPS, then CPU.

    ``torch.xpu`` only exists in newer PyTorch releases, and the documented
    MPS availability check lives under ``torch.backends.mps``. Both are looked
    up defensively so that older or CPU-only installs fall through to the next
    option instead of raising ``AttributeError``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = _select_device()
print(f"Using device: {device}")

# Load Patches

In [None]:
# Ask for the dataset root until an existing directory is given. An empty
# (or whitespace-only) answer selects DEFAULT_DATA_DIR.
DEFAULT_DATA_DIR = "./data/combined/"

while True:
    # strip() matches the class-folder prompts below and lets stray spaces
    # fall back to the default instead of failing the isdir check.
    data_dir = (
        input(f"Please enter the dataset directory (default: {DEFAULT_DATA_DIR}): ").strip()
        or DEFAULT_DATA_DIR
    )
    if os.path.isdir(data_dir):
        break
    print(f"Directory '{data_dir}' does not exist. Please try again.\n")

print("\nDataset directory confirmed.\n")

def _prompt_class_folder(data_dir, prompt, role, default):
    """Ask for one class folder inside *data_dir* and return the name to use.

    Args:
        data_dir: dataset root the folder must live in.
        prompt: text shown to the user.
        role: human-readable description used in messages ("artifact free", ...).
        default: sample folder name used when the answer is empty or invalid.

    Returns:
        The user's folder name if it exists under *data_dir*, otherwise *default*.
        Always returning a name keeps every class in its fixed slot of
        ``classes`` (artifact free first, artifact second).
    """
    name = input(prompt).strip()
    path = os.path.join(data_dir, name)
    if name and os.path.isdir(path):
        print(f"Folder '{path}' exists and was added as {role}.")
        return name
    if name:
        print(f"Folder '{path}' does not exist. Please check the name.")
    print(f"Defaulting to sample class '{default}' for {role}.")
    default_path = os.path.join(data_dir, default)
    if not os.path.isdir(default_path):
        print(f"Warning: sample folder '{default_path}' was not found either.")
    return default


# Fixed order: index 0 = artifact free, index 1 = artifact. Falling back per
# slot (instead of dropping a missing entry) guarantees two classes in this
# order. The list position presumably defines the label used by the loaders
# and the binary models below — confirm in utils.datasets.
classes = [
    _prompt_class_folder(
        data_dir,
        "Please enter the name of the artifact free folder (e.g., artifact_free): ",
        "artifact free",
        "artifact_free",
    ),
    _prompt_class_folder(
        data_dir,
        "Please enter the name of the artifact folder (e.g., artifact): ",
        "artifact",
        "blur",
    ),
]

print(f"\nSelected classes: {classes}")


# Visualize random samples from each class

In [None]:
# Show sample images from each selected class folder under data_dir
# (sampling/layout is decided inside utils.visualize).
visualize_images_from_classes(data_dir, classes)

# Prepare data for classification

In [None]:
# Build one image loader per image-based model. Both read the same folders
# (data_dir + classes); only the third argument differs — presumably the batch
# size (16 for FMA, 64 for DLA); confirm in utils.datasets.returnDataLoader.
# Each size is labelled so the two lines are distinguishable in the output.
fma_dataloader = returnDataLoader(data_dir, classes, 16)
print("FMA dataset size (samples): ", len(fma_dataloader.dataset))

dla_dataloader = returnDataLoader(data_dir, classes, 64)
print("DLA dataset size (samples): ", len(dla_dataloader.dataset))


In [None]:
# Feature loader for the KBA model.
# NOTE(review): unlike returnDataLoader, this is not given `classes`; confirm it
# restricts itself to the same class folders (and order) as the image loaders.
feature_loader = returnFeatureLoader(data_dir)
print("KBA feature dataset size (samples): ", len(feature_loader.dataset))

# Classify images with different models

### FMA

In [None]:
# Run the FMA model on the FMA image loader using the saved checkpoint below.
# NOTE(review): this checkpoint name is blur-specific, unlike the DLA/KBA ones;
# confirm it matches the artifact class chosen above.
fma_checkpoint = './models/fma_binary_blur.pth'
runFMA(fma_dataloader, device, fma_checkpoint)

### DLA

In [None]:
# Run the DLA model on the DLA image loader using the saved checkpoint below.
dla_checkpoint = './models/dla_binary.pth'
runDLA(dla_dataloader, device, dla_checkpoint)

### KBA

In [None]:
# Run the KBA model on the feature loader; no device is passed to this one.
# NOTE(review): the model file is a .pkl — if runKBA unpickles it, only load
# model files from a trusted source.
kba_model_file = './models/kba_binary.pkl'
runKBA(feature_loader, kba_model_file)

# Analysis

In [None]:
# Summarize each model's results file, in FMA, DLA, KBA order.
# Presumably these CSVs are written by the run* cells above — confirm in
# utils.model_execution.
for results_csv in (
    './results/fma_results.csv',
    './results/dla_results.csv',
    './results/kba_results.csv',
):
    print_analysis(results_csv)