# In [None]:
import glob
import os

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

# Configuration
class Config:
    """Static settings for this inference script: data/model paths and batch size."""

    # Root of the preprocessed data; the test images are read from its
    # 'test' sub-folder.
    ROOT_FOLDER = '/content/drive/MyDrive/dataset/TeamDeepwave/dataset/preprocessed/'
    # Number of images fed to the model per forward pass.
    BATCH_SIZE = 64
    # Path to your best model (the checkpoint passed to torch.load).
    MODEL_PATH = '/content/drive/MyDrive/models/best-model.pt'

CONFIG = Config()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Custom Dataset for Mel-spectrogram images
class CustomDataset(Dataset):
    """Unlabelled dataset over a list of mel-spectrogram image files.

    Item ``idx`` is the image stored at ``mel_files[idx]``, converted to
    3-channel RGB and, when a ``transform`` is supplied, passed through it.
    No label is returned, so this dataset is meant for inference.

    Args:
        mel_files: Sequence of image file paths; item order follows it.
        transform: Optional callable applied to the PIL image (for example a
            torchvision ``Compose``).
    """

    def __init__(self, mel_files, transform=None):
        self.mel_files = mel_files
        self.transform = transform

    def __len__(self):
        return len(self.mel_files)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released
        # deterministically. convert() fully loads the pixels and returns an
        # independent image (a copy even when already RGB), so the result
        # stays valid after the source image is closed.
        with Image.open(self.mel_files[idx]) as src:
            mel_image = src.convert('RGB')
        if self.transform is not None:
            mel_image = self.transform(mel_image)
        return mel_image

# Load file paths for the test dataset
def load_test_file_paths(root_folder):
    """Return the sorted paths of every ``.png`` file in ``root_folder/test``.

    ``glob.glob`` yields matches in arbitrary, filesystem-dependent order.
    Sorting makes the order deterministic, and with it the row order of any
    output built from this list.

    Args:
        root_folder: Dataset root directory that contains a ``test`` sub-folder.

    Returns:
        A sorted list of matching file paths; empty if nothing matches.
    """
    pattern = os.path.join(root_folder, 'test', '*.png')
    return sorted(glob.glob(pattern))

# Data transformations for Mel-spectrogram images
transform = transforms.Compose(
    [
        # Fixed 128x128 input size.
        transforms.Resize((128, 128)),
        # PIL image -> float tensor in [0, 1], channels first.
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1].
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Load test file paths
test_mel_files = load_test_file_paths(CONFIG.ROOT_FOLDER)

# Wrap the paths in a dataset and batch them for inference. Shuffling stays
# off so prediction rows keep the order of test_mel_files, from which the
# submission ids are later derived.
test_dataset = CustomDataset(test_mel_files, transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG.BATCH_SIZE,
    shuffle=False,
)

# Load the best model
# The rest of the script calls model(...) directly, so the checkpoint must hold
# a full pickled nn.Module (saved with torch.save(model, path)), not a
# state_dict. PyTorch >= 2.6 defaults to weights_only=True, which refuses to
# unpickle a full module, so the flag is passed explicitly (requires
# PyTorch >= 1.13).
# SECURITY: weights_only=False unpickles arbitrary Python objects and can
# execute code embedded in the file. Only load checkpoints you trust.
model = torch.load(CONFIG.MODEL_PATH, map_location=device, weights_only=False)
if not isinstance(model, torch.nn.Module):
    raise TypeError(
        f"{CONFIG.MODEL_PATH} does not contain a pickled nn.Module "
        f"(got {type(model).__name__}); if it is a state_dict, build the "
        "model architecture and call load_state_dict() instead."
    )
model.eval()

# Prediction on test dataset
def predict(model, loader, device):
    """Run ``model`` over every batch in ``loader`` and return softmax probabilities.

    The model is put in eval mode and evaluated with gradient tracking
    disabled. Each batch's output goes through a softmax over dim 1, which is
    presumably the class dimension of logits (confirm against the model). The
    result is moved to the CPU, and the per-batch arrays are stacked along
    the batch axis.

    Args:
        model: Callable model producing per-class scores for a batch.
        loader: Iterable yielding input tensors (no labels).
        device: Device each input batch is moved to before the forward pass.

    Returns:
        ``np.ndarray`` with one row per input sample, in loader order, or
        ``np.array([])`` if the loader yields no batches.
    """
    model.eval()
    batch_probs = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Predicting", leave=False):
            logits = model(batch.to(device))
            batch_probs.append(torch.softmax(logits, dim=1).cpu().numpy())
    # Empty loader: reproduce np.array([]) rather than letting concatenate raise.
    if not batch_probs:
        return np.array([])
    return np.concatenate(batch_probs, axis=0)

# Predict on test data
test_predictions = predict(model, test_loader, device)

# One row per test image. Column order follows the model's output indices,
# assumed to be 0 = fake, 1 = real (confirm against the training label map).
submission_df = pd.DataFrame(test_predictions, columns=['fake', 'real'])

# ids are the file names minus their extension, in the same order as
# test_mel_files (the loader does not shuffle). splitext removes only the
# trailing extension, whereas str.replace('.png', '') would also delete any
# '.png' occurring elsewhere in the name.
test_ids = [os.path.splitext(os.path.basename(path))[0] for path in test_mel_files]
submission_df.insert(0, 'id', test_ids)

# Write the submission into the current working directory.
submission_df.to_csv('submission.csv', index=False)
print('Submission file created successfully!')