In [1]:
# Mount Google Drive so the dataset and checkpoint paths under /content/drive resolve.
# Re-running this cell is harmless: Colab only reports that the drive is already mounted.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import glob

# Run configuration: where the data lives, the checkpoint to evaluate, and the batch size.
class Config:
    # Preprocessed data root; the test PNGs are read from its 'test' subfolder.
    ROOT_FOLDER = '/content/drive/MyDrive/dataset/TeamDeepwave/dataset/preprocessed/'
    # Trained checkpoint (a state_dict for the CNN defined further down).
    MODEL_PATH = '/content/drive/MyDrive/dataset/TeamDeepwave/sample_Lee/sample_2500/1.pt'
    # Number of images per inference batch.
    BATCH_SIZE = 128


CONFIG = Config()

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Custom Dataset for Mel-spectrogram images
class CustomDataset(Dataset):
    """Unlabelled dataset of mel-spectrogram images read from image files.

    Args:
        mel_files: sequence of image file paths; item ``idx`` is read from
            ``mel_files[idx]``.
        transform: optional callable applied to the RGB ``PIL.Image`` (e.g. a
            torchvision ``Compose``). Without it, ``__getitem__`` returns the
            PIL image itself.
    """

    def __init__(self, mel_files, transform=None):
        self.mel_files = mel_files
        self.transform = transform

    def __len__(self):
        return len(self.mel_files)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released
        # deterministically, even if decoding fails. convert() reads the pixel
        # data into a new image before the file is closed.
        with Image.open(self.mel_files[idx]) as img:
            mel_image = img.convert('RGB')
        if self.transform:
            mel_image = self.transform(mel_image)
        return mel_image

# Load file paths for the test dataset
def load_test_file_paths(root_folder):
    """Return the sorted list of ``*.png`` paths inside ``<root_folder>/test``.

    Args:
        root_folder: directory that contains the ``test`` subfolder.

    Returns:
        Sorted list of PNG file paths. The list is empty (and a message is
        printed) when the folder is missing or contains no PNG files, so the
        caller decides how to fail.
    """
    test_folder = os.path.join(root_folder, 'test')
    print(f"Looking for .png files in: {test_folder}")
    if not os.path.isdir(test_folder):
        print(f"Test directory does not exist or is not a directory: {test_folder}")
        return []
    # glob.escape: treat '[', '*' and '?' in the folder path literally.
    # sorted(): glob's order is filesystem-dependent; sorting gives a
    # reproducible row order for everything built from this list.
    mel_files = sorted(glob.glob(os.path.join(glob.escape(test_folder), '*.png')))
    if not mel_files:
        print(f"No .png files found in {test_folder}")
    return mel_files

# Data transformations for Mel-spectrogram images.
# Resize to 128x128 (the CNN's fc1 input size, 128 * 16 * 16, relies on this after
# three 2x2 poolings), convert to a float tensor in [0, 1], then map every channel
# to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Load test file paths
test_mel_files = load_test_file_paths(CONFIG.ROOT_FOLDER)

# Fail fast if there is nothing to predict on.
if not test_mel_files:
    raise ValueError("No test files found. Ensure the test directory contains .png files.")

print(f"Loaded {len(test_mel_files)} test files.")

# Create test dataset and loader. shuffle=False keeps prediction rows aligned with
# test_mel_files, which the submission's id column is built from.
# NOTE(review): the logged rate (~25 s per 128-image batch) suggests reading PNGs
# from the mounted Drive dominates; copying the test folder to local disk (e.g.
# under /content) before building the loader is likely much faster — TODO measure.
test_dataset = CustomDataset(test_mel_files, transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG.BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    # Page-locked host buffers speed up host->GPU copies; useless without CUDA.
    pin_memory=torch.cuda.is_available(),
)

# Define the CNN model for Mel-spectrogram images (same as the one used during training)
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """Three-stage convolutional classifier for 3x128x128 mel-spectrogram images.

    The layer attribute names are the keys of the saved ``state_dict`` that is
    loaded into this model, so they must not be renamed.

    Args:
        output_dim: number of output logits (classes).
    """

    def __init__(self, output_dim):
        super().__init__()
        # 3x3 convs with padding=1 keep the spatial size; each 2x2 pool halves it.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # 128 channels x 16 x 16 after three poolings of a 128x128 input.
        self.fc1 = nn.Linear(128 * 16 * 16, 256)
        self.fc2 = nn.Linear(256, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a (batch, 3, 128, 128) tensor to logits of shape (batch, output_dim)."""
        x = self.pool(F.gelu(self.conv1(x)))  # -> (batch, 32, 64, 64)
        x = self.pool(F.gelu(self.conv2(x)))  # -> (batch, 64, 32, 32)
        x = self.pool(F.gelu(self.conv3(x)))  # -> (batch, 128, 16, 16)
        # Flatten per sample, keeping the batch dimension. view(-1, 128 * 16 * 16)
        # would silently fold extra spatial elements into the batch dimension for
        # other input sizes; with flatten, a wrong input size fails loudly in fc1.
        x = torch.flatten(x, 1)
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

# Initialize the model and load the state dictionary.
model = CNN(output_dim=2).to(device)
# weights_only=True limits unpickling to tensors and plain containers, so a tampered
# checkpoint file cannot execute code on load; a plain state_dict loads fine this way.
# (Requires PyTorch >= 1.13.)
model.load_state_dict(torch.load(CONFIG.MODEL_PATH, map_location=device, weights_only=True))
model.eval()

# Prediction on test dataset
def predict(model, loader, device):
    """Run ``model`` over every batch in ``loader`` and return softmax probabilities.

    Args:
        model: module mapping an input batch to logits of shape (batch, num_classes).
        loader: iterable of input batch tensors (no labels), e.g. a DataLoader.
        device: device each batch is moved to before the forward pass.

    Returns:
        np.ndarray of shape (num_samples, num_classes), rows in loader order.
        An empty loader yields ``np.array([])``.
    """
    model.eval()  # inference behaviour for layers such as dropout
    batch_probs = []
    with torch.inference_mode():
        for mel in tqdm(loader, desc="Predicting", leave=False):
            logits = model(mel.to(device, non_blocking=True))
            batch_probs.append(torch.softmax(logits, dim=1).cpu().numpy())
    # Concatenate once at the end instead of keeping one small array per sample.
    return np.concatenate(batch_probs) if batch_probs else np.array([])

# Predict on test data
test_predictions = predict(model, test_loader, device)
# One probability row per file, in test_mel_files order (the loader does not shuffle).
assert len(test_predictions) == len(test_mel_files), "prediction count does not match file count"

# Create a DataFrame for submission.
# NOTE(review): assumes class index 0 = 'fake' and 1 = 'real', matching the label
# encoding used in training — TODO confirm against the training code.
submission_df = pd.DataFrame(test_predictions, columns=['fake', 'real'])

# Extracting IDs from test file paths: strip only the trailing extension
# (str.replace would also delete '.png' appearing elsewhere in the name).
test_ids = [os.path.splitext(os.path.basename(f))[0] for f in test_mel_files]
submission_df.insert(0, 'id', test_ids)

# Save to CSV
submission_csv_path = '/content/drive/MyDrive/submission.csv'
submission_df.to_csv(submission_csv_path, index=False)
print(f'Submission file created at {submission_csv_path}!')

Looking for .png files in: /content/drive/MyDrive/dataset/TeamDeepwave/dataset/preprocessed/test
Loaded 50000 test files.


Predicting:   4%|▎         | 14/391 [07:25<2:39:33, 25.39s/it]