In [None]:
import torch
import torchvision.transforms as T
from tqdm import tqdm

# Test-time augmentation (TTA) views: the 8 symmetries of a square image
# (the dihedral group D4), i.e. the 4 quarter-turn rotations, each with and
# without a horizontal flip. Every orientation appears exactly once, so
# averaging predictions over this list weights all orientations equally.
# (A separate "hflip + vflip" entry would duplicate the 180-degree rotation.)
#
# Each callable expects a tensor laid out as [..., H, W], the same layout
# torchvision's functional ops use. torch.flip / torch.rot90 permute pixels
# exactly (no interpolation, no zero fill). The 90/270-degree views swap H and
# W, so for non-square inputs those views have transposed spatial size.
tta_transforms = [
    lambda x: x,  # Original
    lambda x: torch.flip(x, dims=(-1,)),  # Horizontal flip
    lambda x: torch.flip(x, dims=(-2,)),  # Vertical flip (= hflip + 180 degrees)
    lambda x: torch.rot90(x, 1, dims=(-2, -1)),  # 90 degrees counter-clockwise
    lambda x: torch.rot90(x, 2, dims=(-2, -1)),  # 180 degrees (= hflip + vflip)
    lambda x: torch.rot90(x, 3, dims=(-2, -1)),  # 270 degrees counter-clockwise
    lambda x: torch.rot90(torch.flip(x, dims=(-1,)), 1, dims=(-2, -1)),  # hflip + 90 degrees
    lambda x: torch.rot90(torch.flip(x, dims=(-1,)), 3, dims=(-2, -1)),  # hflip + 270 degrees
]

def test_evaluate_tta8(model, fold, test_loader, device, verbose=1, transforms=None):
    """Predict class probabilities for a test set using test-time augmentation.

    Every batch is passed through ``model`` once per transform. The softmax
    outputs of all views are averaged to produce that batch's prediction.

    Args:
        model: Classifier whose output has classes on dim 1 (softmax is taken
            over ``dim=1``). It is switched to eval mode.
        fold: Unused. Kept for call-site compatibility.
        test_loader: Iterable yielding ``(data, target)`` pairs. ``target`` is
            ignored. ``len(test_loader)`` is only needed when ``verbose == 1``.
        device: Device ``data`` is moved to before inference.
        verbose: ``1`` shows a tqdm progress bar. Any other value is silent.
        transforms: Iterable of callables applied to each batch. Defaults to
            the module-level ``tta_transforms``.

    Returns:
        Tensor of averaged probabilities with the batches concatenated along
        dim 0. It stays on whatever device the model's outputs are on.

    Raises:
        ValueError: If ``transforms`` is empty.
        RuntimeError: If ``test_loader`` yields no batches (``torch.vstack``
            of an empty list).
    """
    # Materialize once so a one-shot iterable is not exhausted after the
    # first batch.
    transforms = list(tta_transforms if transforms is None else transforms)
    if not transforms:
        raise ValueError("transforms must contain at least one callable")

    model.eval()
    batches = tqdm(test_loader, total=len(test_loader)) if verbose == 1 else test_loader

    preds = []
    with torch.no_grad():
        for data, _target in batches:
            # Only the inputs are needed for inference. The target is not
            # copied to the device, so non-tensor targets (e.g. ids) are fine.
            data = data.to(device)
            view_probs = [
                torch.softmax(model(transform(data)), dim=1) for transform in transforms
            ]
            # Average the per-view probabilities.
            preds.append(torch.stack(view_probs).mean(dim=0))

    return torch.vstack(preds)


# The ``test_`` prefix matches pytest's default discovery pattern. Opt out so
# that importing this helper into a test module does not make pytest try to
# run it (it would fail looking for ``model``/``fold``/... fixtures).
test_evaluate_tta8.__test__ = False