In [1]:
# Imports and setup
import os
from torch.cuda.amp import GradScaler, autocast
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt
import seaborn as sns
import time

# Inicjalizacja modelu EfficientNet-B1
def initialize_model(device, num_classes=4):
    """
    Inicjalizuje model EfficientNet-B1 z odpowiednią liczbą klas i przenosi go na wybrane urządzenie.
    Args:
        device (torch.device): Urządzenie do trenowania (CPU lub GPU).
        num_classes (int, opcjonalnie): Liczba klas w problemie klasyfikacji. Domyślnie 4.
    Returns:
        model (torch.nn.Module): Zainicjalizowany model EfficientNet-B1.
    """
    weights = EfficientNet_B1_Weights.DEFAULT
    model = efficientnet_b1(weights=weights)
    num_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_features, num_classes)
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print(f"Let's use {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    return model

class Alaska2Dataset(Dataset):
    """
    Klasa dataset do ładowania danych Alaska2.

    Args:
        filenames (list): Lista ścieżek do plików obrazów.
        labels (list): Lista etykiet odpowiadających plikom obrazów.
        transform (callable, opcjonalnie): Funkcja transformacji stosowana do obrazów.
    """

    def __init__(self, filenames, labels, transform=None):
        self.filenames = filenames
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        image_path = self.filenames[idx]
        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

# Loaders with train, val, and test split
def prepare_loaders(base_dir, transform, test_size=0.2, val_size=0.1, random_state=42, batch_size=32, num_workers=4):
    """
    Przygotowuje DataLoadery dla zbiorów trenowania, walidacji i testowania.

    Args:
        base_dir (str): Katalog bazowy z danymi.
        transform (callable): Funkcja transformacji stosowana do obrazów.
        test_size (float, opcjonalnie): Procentowy rozmiar zbioru testowego. Domyślnie 0.2.
        val_size (float, opcjonalnie): Procentowy rozmiar zbioru walidacyjnego. Domyślnie 0.1.
        random_state (int, opcjonalnie): Stan losowy do podziału danych. Domyślnie 42.
        batch_size (int, opcjonalnie): Rozmiar batcha. Domyślnie 32.
        num_workers (int, opcjonalnie): Liczba wątków pracowników do ładowania danych. Domyślnie 4.

    Returns:
        train_loader (DataLoader): DataLoader dla zbioru trenowania.
        val_loader (DataLoader): DataLoader dla zbioru walidacyjnego.
        test_loader (DataLoader): DataLoader dla zbioru testowego.
    """

    
    all_paths, all_labels = get_image_paths_and_labels(base_dir)
    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        all_paths, all_labels, test_size=(test_size + val_size), random_state=random_state
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths, temp_labels, test_size=(test_size / (test_size + val_size)), random_state=random_state
    )
    train_dataset = Alaska2Dataset(train_paths, train_labels, transform)
    val_dataset = Alaska2Dataset(val_paths, val_labels, transform)
    test_dataset = Alaska2Dataset(test_paths, test_labels, transform)
   
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    print(f"Training labels: {np.unique(train_labels, return_counts=True)}")
    print(f"Validation labels: {np.unique(val_labels, return_counts=True)}")
    print(f"Test labels: {np.unique(test_labels, return_counts=True)}")

    return train_loader, val_loader, test_loader

# Paths and Labels
def get_image_paths_and_labels(base_dir):
    """
    Pobiera ścieżki do obrazów i odpowiadające im etykiety z katalogu bazowego.

    Args:
        base_dir (str): Katalog bazowy z danymi.

    Returns:
        paths (list): Lista ścieżek do obrazów.
        labels (list): Lista etykiet odpowiadających obrazom.
    """

    
    dirs = ['Cover','JMiPOD', "UERD", "JUNIWARD"]
    paths = []
    labels = []
    for idx, folder in enumerate(dirs):
        full_path = os.path.join(base_dir, folder)
        folder_paths = [os.path.join(full_path, file_name) for file_name in os.listdir(full_path) if file_name.endswith('.jpg')]
        paths.extend(folder_paths)
        
        labels.extend([idx] * len(folder_paths))
        
        print(np.unique(labels))
    return paths, labels



# Training the model
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10, patience=10, save_path='steganalysis_model_high_resolution.pt'):
    """
    Trenuje model z wykorzystaniem podanych loaderów danych, kryterium, optymalizatora i urządzenia.

    Args:
        model (torch.nn.Module): Model do trenowania.
        train_loader (DataLoader): DataLoader dla zbioru trenowania.
        val_loader (DataLoader): DataLoader dla zbioru walidacyjnego.
        criterion (torch.nn.Module): Funkcja strat.
        optimizer (torch.optim.Optimizer): Optymalizator.
        device (torch.device): Urządzenie do trenowania (CPU lub GPU).
        num_epochs (int, opcjonalnie): Liczba epok trenowania. Domyślnie 10.
        patience (int, opcjonalnie): Liczba epok do wczesnego zatrzymania bez poprawy. Domyślnie 10.
        save_path (str, opcjonalnie): Ścieżka do zapisu najlepszego modelu. Domyślnie 'steganalysis_model_high_resolution.pt'.
    """
   
    scaler = GradScaler()
    best_loss = np.inf
    patience_counter = 0
    start_time = time.time()
    max_duration = 11 * 3600  # 11 hours in seconds
    
    for epoch in range(num_epochs):
        if time.time() - start_time > max_duration:
            print("Training stopped due to time limit.")
            break
        
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            running_loss += loss.item() * inputs.size(0)
        
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch + 1}/{num_epochs}, Training Loss: {epoch_loss:.4f}')
        
        val_loss = validate_model(model, val_loader, criterion, device)
        print(f'Epoch {epoch + 1}/{num_epochs}, Validation Loss: {val_loss:.4f}')
        
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print("Saving the best model...")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping! No improvement in validation loss for {patience} epochs.')
                break
    
    print("Training completed.")

# Validate the model
def validate_model(model, loader, criterion, device):
    """
    Waliduje model na podanym loaderze danych.

    Args:
        model (torch.nn.Module): Model do walidacji.
        loader (DataLoader): DataLoader dla zbioru walidacyjnego.
        criterion (torch.nn.Module): Funkcja strat.
        device (torch.device): Urządzenie do walidacji (CPU lub GPU).

    Returns:
        loss (float): Średnia strata na zbiorze walidacyjnym.
    """

    
    
    
    model.eval()
    running_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)

    loss = running_loss / total_samples
    return loss

# Confusion matrix plotting
def save_confusion_matrix_plot(true_labels, pred_labels, class_names, file_path):
    """
    Zapisuje wykres macierzy pomyłek do pliku.
    Args:
        true_labels (list): Lista prawdziwych etykiet.
        pred_labels (list): Lista przewidzianych etykiet.
        class_names (list): Lista nazw klas.
        file_path (str): Ścieżka do zapisu wykresu.
    """

    cm = confusion_matrix(true_labels, pred_labels)
    cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1, 2, 3])

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.savefig(file_path)
    plt.close()

def main():
    # Ustawienie urządzenia do trenowania (GPU jeśli dostępne, w przeciwnym razie CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Transformacje stosowane do obrazów
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Katalog bazowy zawierający dane
    base_dir = '/kaggle/input/alaska2-image-steganalysis'
    batch_size = 60  
    num_workers = 4

    # Przygotowanie loaderów danych
    train_loader, val_loader, test_loader = prepare_loaders(base_dir, transform, batch_size=batch_size, num_workers=num_workers)

    # Inicjalizacja modelu
    num_classes = 4
    model = initialize_model(device, num_classes=num_classes)

    # Ustawienia do obliczania strat i optymalizatora
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Trenowanie modelu z dostosowaną szybkością uczenia i cierpliwością
    train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=100, patience=15)
   
    # Ewaluacja na zbiorze testowym
    true_labels = []
    pred_labels = []
    model.eval()  # Ustawienie modelu w trybie ewaluacji
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)  # Uzyskanie przewidywanej klasy
            true_labels.extend(labels.cpu().numpy())
            pred_labels.extend(preds.cpu().numpy())

    # Definiowanie nazw klas według danych
    class_names = ['Cover', 'JMiPOD', 'UERD', 'JUNIWARD']

    # Zapisanie wykresu macierzy pomyłek
    save_confusion_matrix_plot(true_labels, pred_labels, class_names, 'confusion_matrix_test.png')
   
if __name__ == '__main__':
    main()


Using device: cuda
[0]
[0 1]
[0 1 2]
[0 1 2 3]
Training labels: (array([0, 1, 2, 3]), array([52552, 52547, 52403, 52497]))
Validation labels: (array([0, 1, 2, 3]), array([7434, 7425, 7614, 7527]))
Test labels: (array([0, 1, 2, 3]), array([15014, 15028, 14983, 14976]))


Downloading: "https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b1-c27df63c.pth
100%|██████████| 30.1M/30.1M [00:00<00:00, 105MB/s] 


Let's use 2 GPUs!
Epoch 1/100, Training Loss: 1.0329
Epoch 1/100, Validation Loss: 0.9110
Saving the best model...
Epoch 2/100, Training Loss: 0.9281
Epoch 2/100, Validation Loss: 0.8824
Saving the best model...
Epoch 3/100, Training Loss: 0.8927
Epoch 3/100, Validation Loss: 0.8685
Saving the best model...
Epoch 4/100, Training Loss: 0.8655
Epoch 4/100, Validation Loss: 0.8488
Saving the best model...
Epoch 5/100, Training Loss: 0.8420
Epoch 5/100, Validation Loss: 0.8366
Saving the best model...
Epoch 6/100, Training Loss: 0.8224
Epoch 6/100, Validation Loss: 0.8086
Saving the best model...
Epoch 7/100, Training Loss: 0.8052
Epoch 7/100, Validation Loss: 0.7960
Saving the best model...
Epoch 8/100, Training Loss: 0.7876
Epoch 8/100, Validation Loss: 0.8242
Epoch 9/100, Training Loss: 0.7711
Epoch 9/100, Validation Loss: 0.7969
Epoch 10/100, Training Loss: 0.7569
Epoch 10/100, Validation Loss: 0.7845
Saving the best model...
Epoch 11/100, Training Loss: 0.7426
Epoch 11/100, Validation