In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, WeightedRandomSampler, SubsetRandomSampler
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import KFold

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class GalaxyDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
        
        # Store all labels for sampling weights calculation
        self.labels = []
        for img_name in self.image_files:
            label_str = img_name.split('_')[-1].split('.')[0]
            self.labels.append(int(label_str))
        
        # Calculate class weights
        label_counter = Counter(self.labels)
        self.class_weights = {cls: 1.0/count for cls, count in label_counter.items()}
        
        # Store sample weights for WeightedRandomSampler
        self.sample_weights = [self.class_weights[label] for label in self.labels]
        self.sample_weights = torch.FloatTensor(self.sample_weights)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.root_dir, img_name)
        
        label = self.labels[idx]
        
        # Load and preprocess image
        image = Image.open(img_path).convert("RGB")
        image = image.resize((128, 128))
        
        # Data augmentation for training
        if torch.rand(1) > 0.5:  # Random horizontal flip
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        
        # Convert to numpy array and normalize
        image_array = np.array(image) / 255.0
        
        # Convert to tensor
        image_tensor = torch.FloatTensor(image_array).permute(2, 0, 1)
        
        return image_tensor, label

class GalaxyCNN(nn.Module):
    def __init__(self, num_classes):
        super(GalaxyCNN, self).__init__()
        
        self.features = nn.Sequential(
            # First block
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # Second block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # Third block
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # Fourth block
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),  # Reduced dropout
            nn.Linear(512 * 8 * 8, 2048),  # Smaller first dense layer
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),  # Reduced dropout
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)
        )
        
        # Initialize weights
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0.01):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score - self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, early_stopping, num_epochs, fold=None):
    model.to(device)
    best_val_acc = 0.0
    train_losses, val_losses = [], []
    fold_str = f" (Fold {fold})" if fold is not None else ""

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}{fold_str}')
        
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            progress_bar.set_postfix({'loss': running_loss / len(train_loader), 'accuracy': 100. * correct / total})
        
        train_losses.append(running_loss / len(train_loader))
        
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        
        val_loss /= len(val_loader)
        val_accuracy = 100. * val_correct / val_total
        val_losses.append(val_loss)
        
        print(f'\nEpoch {epoch+1}/{num_epochs}{fold_str}: Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')
        
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print(f"Early stopping triggered{fold_str}")
            break
        
        scheduler.step(val_accuracy)
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            if fold is not None:
                torch.save(model.state_dict(), f'best_model_fold{fold}.pth')
            else:
                torch.save(model.state_dict(), 'best_model.pth')

    if fold is not None:
        plot_loss(train_losses, val_losses, fold)
    else:
        plot_loss(train_losses, val_losses)
        
    return best_val_acc

def plot_loss(train_losses, val_losses, fold=None):
    plt.figure(figsize=(8, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    title = 'Training & Validation Loss'
    if fold is not None:
        title += f' (Fold {fold})'
    plt.title(title)
    plt.savefig(f'loss_plot{"_fold" + str(fold) if fold is not None else ""}.png')
    plt.close()

def evaluate_model(model, loader, num_classes=10, fold=None):
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []
    
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=f'Evaluating{"" if fold is None else " (Fold " + str(fold) + ")"}'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = 100. * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    
    conf_matrix = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=range(num_classes), yticklabels=range(num_classes))
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    title = 'Confusion Matrix'
    if fold is not None:
        title += f' (Fold {fold})'
    plt.title(title)
    plt.savefig(f'confusion_matrix{"_fold" + str(fold) if fold is not None else ""}.png')
    plt.close()
    
    print("\nClassification Report:\n", classification_report(all_labels, all_preds, digits=4))

    return accuracy

def k_fold_cross_validation(dataset, num_folds=5, num_classes=10, num_epochs=30):
    # First, check that the dataset is loaded properly
    print(f"Dataset contains {len(dataset)} samples")
    if len(dataset) == 0:
        print("ERROR: Dataset is empty! Check the data directory path.")
        return 0
    
    # Print some sample filenames to verify
    print(f"Sample image filenames: {dataset.image_files[:3]}")
    
    # Initialize k-fold cross validation
    kfold = KFold(n_splits=num_folds, shuffle=True, random_state=42)
    
    # Store accuracy for each fold
    fold_accuracies = []
    
    # Create indices for KFold to use
    indices = list(range(len(dataset)))
    
    # Loop through each fold
    for fold, (train_ids, val_ids) in enumerate(kfold.split(indices)):
        print(f"\n{'='*20} FOLD {fold+1}/{num_folds} {'='*20}")
        
        # Create data samplers for train and validation sets
        train_sampler = SubsetRandomSampler(train_ids)
        val_sampler = SubsetRandomSampler(val_ids)
        
        # Create data loaders for current fold
        train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler)
        val_loader = DataLoader(dataset, batch_size=32, sampler=val_sampler)
        
        # Initialize model, criterion, optimizer, and early stopping for current fold
        model = GalaxyCNN(num_classes=num_classes).to(device)
        
        # Get class weights for criterion from the dataset
        class_weights = torch.FloatTensor([1.0/1081, 1.0/1853, 1.0/2645, 1.0/2027, 1.0/334, 
                                          1.0/2043, 1.0/1829, 1.0/2628, 1.0/1423, 1.0/1873]).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        early_stopping = EarlyStopping(patience=5, verbose=True, delta=0.01)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True, min_lr=1e-6)
        
        # Train and evaluate model for current fold
        train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, early_stopping, num_epochs, fold+1)
        accuracy = evaluate_model(model, val_loader, num_classes, fold+1)
        fold_accuracies.append(accuracy)
        
        # Save model for current fold
        torch.save(model.state_dict(), f"galaxy_cnn_fold{fold+1}.pth")
        
    # Print average and per-fold accuracies
    print("\n" + "="*50)
    print(f"K-Fold Cross-Validation Results for {num_folds} Folds")
    print("="*50)
    print(f"Average Accuracy: {np.mean(fold_accuracies):.2f}% (±{np.std(fold_accuracies):.2f}%)")
    print("\nPer-fold Accuracy:")
    for fold, accuracy in enumerate(fold_accuracies):
        print(f"Fold {fold+1}: {accuracy:.2f}%")
    print("="*50)
    
    # Return average accuracy
    return np.mean(fold_accuracies)

# Dataset path
data_dir = os.path.expanduser("Decals_data/Decals_data_images")

# Check if data directory exists
if not os.path.exists(data_dir):
    print(f"ERROR: Data directory '{data_dir}' does not exist!")
    print("Please check the path and try again.")
    # You could add code to exit here if needed
else:
    print(f"Data directory found: {data_dir}")
    # List first few items to verify
    files = os.listdir(data_dir)[:5]
    print(f"First few files in directory: {files}")

# Create dataset
dataset = GalaxyDataset(root_dir=data_dir)

# Print dataset information
print(f"Dataset loaded with {len(dataset)} images")
print(f"Class distribution: {Counter(dataset.labels)}")

# Perform 5-fold cross-validation
num_classes = 10  # Based on your class weights, seems like 10 classes
num_epochs = 30    # Same as original
k_fold_cross_validation(dataset, num_folds=5, num_classes=num_classes, num_epochs=num_epochs)

# After finding the best fold, you might want to train on the entire dataset
print("\nTraining final model on the entire dataset...")
# Create data loader for full dataset
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize model, criterion, optimizer, and early stopping
model = GalaxyCNN(num_classes=num_classes).to(device)
class_weights = torch.FloatTensor([1.0/1081, 1.0/1853, 1.0/2645, 1.0/2027, 1.0/334, 
                                  1.0/2043, 1.0/1829, 1.0/2628, 1.0/1423, 1.0/1873]).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
early_stopping = EarlyStopping(patience=5, verbose=True, delta=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True, min_lr=1e-6)

# We need a small validation set for the scheduler and early stopping
val_size = int(0.1 * len(dataset))
train_size = len(dataset) - val_size
temp_train_dataset, temp_val_dataset = random_split(dataset, [train_size, val_size])
val_loader = DataLoader(temp_val_dataset, batch_size=32, shuffle=False)

# Train final model
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, early_stopping, num_epochs)

# Save final model
torch.save(model.state_dict(), "galaxy_cnn_final.pth")
print("Final model saved successfully!")

Using device: cuda
Data directory found: Decals_data/Decals_data_images
First few files in directory: ['image_9483_5.png', 'image_15099_8.png', 'image_9226_5.png', 'image_16911_9.png', 'image_3980_2.png']
Dataset loaded with 17736 images
Class distribution: Counter({2: 2645, 7: 2628, 5: 2043, 3: 2027, 9: 1873, 1: 1853, 6: 1829, 8: 1423, 0: 1081, 4: 334})
Dataset contains 17736 samples
Sample image filenames: ['image_9483_5.png', 'image_15099_8.png', 'image_9226_5.png']



Epoch 1/30 (Fold 1): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.67it/s, loss=2.04, accuracy=24.6]



Epoch 1/30 (Fold 1): Validation Loss: 1.8693, Accuracy: 27.79%


Epoch 2/30 (Fold 1): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.71it/s, loss=1.73, accuracy=33.6]



Epoch 2/30 (Fold 1): Validation Loss: 1.5027, Accuracy: 42.02%


Epoch 3/30 (Fold 1): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.70it/s, loss=1.54, accuracy=41.2]



Epoch 3/30 (Fold 1): Validation Loss: 1.5093, Accuracy: 41.26%


Epoch 4/30 (Fold 1): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.75it/s, loss=1.4, accuracy=48.2]



Epoch 4/30 (Fold 1): Validation Loss: 1.4695, Accuracy: 47.35%


Epoch 5/30 (Fold 1): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.72it/s, loss=1.27, accuracy=53.6]



Epoch 5/30 (Fold 1): Validation Loss: 1.2263, Accuracy: 54.17%


Epoch 6/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.68it/s, loss=1.18, accuracy=58]



Epoch 6/30 (Fold 1): Validation Loss: 1.1453, Accuracy: 59.08%


Epoch 7/30 (Fold 1): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.75it/s, loss=1.1, accuracy=60.4]



Epoch 7/30 (Fold 1): Validation Loss: 1.1166, Accuracy: 58.99%


Epoch 8/30 (Fold 1): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.75it/s, loss=1.05, accuracy=63.2]



Epoch 8/30 (Fold 1): Validation Loss: 1.3099, Accuracy: 56.96%
EarlyStopping counter: 1 out of 5


Epoch 9/30 (Fold 1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.66it/s, loss=0.981, accuracy=65.8]



Epoch 9/30 (Fold 1): Validation Loss: 0.9545, Accuracy: 66.12%


Epoch 10/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.56it/s, loss=0.927, accuracy=67.7]



Epoch 10/30 (Fold 1): Validation Loss: 0.9635, Accuracy: 66.54%


Epoch 11/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.67it/s, loss=0.882, accuracy=69.2]



Epoch 11/30 (Fold 1): Validation Loss: 0.9162, Accuracy: 68.66%


Epoch 12/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.62it/s, loss=0.853, accuracy=70.1]



Epoch 12/30 (Fold 1): Validation Loss: 0.8511, Accuracy: 71.98%


Epoch 16/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.86it/s, loss=0.698, accuracy=75.7]



Epoch 16/30 (Fold 1): Validation Loss: 0.8548, Accuracy: 71.28%


Epoch 17/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.89it/s, loss=0.686, accuracy=75.8]



Epoch 17/30 (Fold 1): Validation Loss: 0.9262, Accuracy: 69.67%
EarlyStopping counter: 1 out of 5


Epoch 18/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.87it/s, loss=0.648, accuracy=77.8]



Epoch 18/30 (Fold 1): Validation Loss: 0.9925, Accuracy: 68.57%
EarlyStopping counter: 2 out of 5


Epoch 19/30 (Fold 1): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.82it/s, loss=0.626, accuracy=78]



Epoch 19/30 (Fold 1): Validation Loss: 0.9393, Accuracy: 69.73%
EarlyStopping counter: 3 out of 5


Epoch 20/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.65it/s, loss=0.492, accuracy=82.3]



Epoch 20/30 (Fold 1): Validation Loss: 0.7133, Accuracy: 77.85%


Epoch 21/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.69it/s, loss=0.459, accuracy=83.2]



Epoch 21/30 (Fold 1): Validation Loss: 0.7232, Accuracy: 77.62%


Epoch 22/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.75it/s, loss=0.443, accuracy=83.8]



Epoch 22/30 (Fold 1): Validation Loss: 0.7051, Accuracy: 78.49%


Epoch 23/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.82it/s, loss=0.412, accuracy=84.8]



Epoch 23/30 (Fold 1): Validation Loss: 0.7356, Accuracy: 77.99%
EarlyStopping counter: 1 out of 5


Epoch 24/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.67it/s, loss=0.412, accuracy=84.7]



Epoch 24/30 (Fold 1): Validation Loss: 0.7331, Accuracy: 78.44%
EarlyStopping counter: 2 out of 5


Epoch 25/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.64it/s, loss=0.392, accuracy=85.6]



Epoch 25/30 (Fold 1): Validation Loss: 0.7416, Accuracy: 78.24%
EarlyStopping counter: 3 out of 5


Epoch 26/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.65it/s, loss=0.387, accuracy=85.8]



Epoch 26/30 (Fold 1): Validation Loss: 0.7409, Accuracy: 78.69%
EarlyStopping counter: 4 out of 5


Epoch 27/30 (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.65it/s, loss=0.374, accuracy=86.1]



Epoch 27/30 (Fold 1): Validation Loss: 0.7307, Accuracy: 78.78%
EarlyStopping counter: 5 out of 5
Early stopping triggered (Fold 1)


Evaluating (Fold 1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:12<00:00,  8.78it/s]


Accuracy: 78.55%

Classification Report:
               precision    recall  f1-score   support

           0     0.4813    0.4147    0.4455       217
           1     0.8338    0.8061    0.8197       361
           2     0.8789    0.9194    0.8987       521
           3     0.8380    0.9061    0.8707       394
           4     0.5714    0.6792    0.6207        53
           5     0.8073    0.7844    0.7957       422
           6     0.6675    0.7867    0.7222       347
           7     0.6899    0.5859    0.6336       524
           8     0.8384    0.9228    0.8786       298
           9     0.9016    0.8467    0.8733       411

    accuracy                         0.7855      3548
   macro avg     0.7508    0.7652    0.7559      3548
weighted avg     0.7830    0.7855    0.7824      3548




Epoch 1/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.64it/s, loss=2.07, accuracy=23.9]



Epoch 1/30 (Fold 2): Validation Loss: 1.8383, Accuracy: 29.86%


Epoch 2/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.73it/s, loss=1.77, accuracy=32.9]



Epoch 2/30 (Fold 2): Validation Loss: 1.7291, Accuracy: 29.32%


Epoch 3/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.80it/s, loss=1.61, accuracy=38.5]



Epoch 3/30 (Fold 2): Validation Loss: 3.5280, Accuracy: 20.72%
EarlyStopping counter: 1 out of 5


Epoch 4/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.78it/s, loss=1.51, accuracy=42.5]



Epoch 4/30 (Fold 2): Validation Loss: 2.0738, Accuracy: 24.08%
EarlyStopping counter: 2 out of 5


Epoch 5/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.67it/s, loss=1.37, accuracy=48.8]



Epoch 5/30 (Fold 2): Validation Loss: 1.3052, Accuracy: 50.47%


Epoch 6/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.60it/s, loss=1.28, accuracy=52.7]



Epoch 6/30 (Fold 2): Validation Loss: 1.1601, Accuracy: 57.94%


Epoch 7/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.59it/s, loss=1.18, accuracy=56.9]



Epoch 7/30 (Fold 2): Validation Loss: 1.2118, Accuracy: 58.30%
EarlyStopping counter: 1 out of 5


Epoch 8/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.65it/s, loss=1.11, accuracy=59.9]



Epoch 8/30 (Fold 2): Validation Loss: 1.0595, Accuracy: 59.66%


Epoch 9/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.62it/s, loss=1.06, accuracy=62.2]



Epoch 9/30 (Fold 2): Validation Loss: 1.1550, Accuracy: 57.12%
EarlyStopping counter: 1 out of 5


Epoch 10/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.69it/s, loss=1, accuracy=64.2]



Epoch 10/30 (Fold 2): Validation Loss: 0.9798, Accuracy: 65.44%


Epoch 11/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.68it/s, loss=0.972, accuracy=64.9]



Epoch 11/30 (Fold 2): Validation Loss: 1.1021, Accuracy: 61.60%
EarlyStopping counter: 1 out of 5


Epoch 12/30 (Fold 2): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.71it/s, loss=0.91, accuracy=67.7]



Epoch 12/30 (Fold 2): Validation Loss: 1.2069, Accuracy: 57.63%
EarlyStopping counter: 2 out of 5


Epoch 13/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.69it/s, loss=0.878, accuracy=68.8]



Epoch 13/30 (Fold 2): Validation Loss: 0.9130, Accuracy: 67.01%


Epoch 14/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.70it/s, loss=0.848, accuracy=69.8]



Epoch 14/30 (Fold 2): Validation Loss: 0.8734, Accuracy: 68.62%


Epoch 15/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.74it/s, loss=0.803, accuracy=72.1]



Epoch 15/30 (Fold 2): Validation Loss: 0.8423, Accuracy: 70.82%


Epoch 16/30 (Fold 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.60it/s, loss=0.771, accuracy=73]



Epoch 16/30 (Fold 2): Validation Loss: 0.8795, Accuracy: 70.20%
EarlyStopping counter: 1 out of 5


Epoch 17/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.71it/s, loss=0.744, accuracy=73.7]



Epoch 17/30 (Fold 2): Validation Loss: 0.8066, Accuracy: 72.79%


Epoch 18/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.76it/s, loss=0.722, accuracy=74.8]



Epoch 18/30 (Fold 2): Validation Loss: 0.8319, Accuracy: 72.93%
EarlyStopping counter: 1 out of 5


Epoch 19/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.77it/s, loss=0.669, accuracy=76.1]



Epoch 19/30 (Fold 2): Validation Loss: 0.7945, Accuracy: 74.18%


Epoch 20/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.65it/s, loss=0.654, accuracy=76.7]



Epoch 20/30 (Fold 2): Validation Loss: 0.7491, Accuracy: 74.82%


Epoch 21/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.57it/s, loss=0.603, accuracy=78.4]



Epoch 21/30 (Fold 2): Validation Loss: 0.8302, Accuracy: 70.71%
EarlyStopping counter: 1 out of 5


Epoch 22/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.57it/s, loss=0.598, accuracy=78.4]



Epoch 22/30 (Fold 2): Validation Loss: 0.8339, Accuracy: 71.72%
EarlyStopping counter: 2 out of 5


Epoch 23/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.74it/s, loss=0.556, accuracy=79.9]



Epoch 23/30 (Fold 2): Validation Loss: 0.8394, Accuracy: 74.06%
EarlyStopping counter: 3 out of 5


Epoch 24/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.67it/s, loss=0.536, accuracy=80.4]



Epoch 24/30 (Fold 2): Validation Loss: 0.9048, Accuracy: 70.51%
EarlyStopping counter: 4 out of 5


Epoch 25/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.60it/s, loss=0.437, accuracy=83.8]



Epoch 25/30 (Fold 2): Validation Loss: 0.7082, Accuracy: 77.47%


Epoch 26/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:07<00:00,  6.62it/s, loss=0.378, accuracy=85.9]



Epoch 26/30 (Fold 2): Validation Loss: 0.7106, Accuracy: 77.30%


Epoch 27/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.68it/s, loss=0.369, accuracy=86.1]



Epoch 27/30 (Fold 2): Validation Loss: 0.7227, Accuracy: 77.11%
EarlyStopping counter: 1 out of 5


Epoch 28/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.66it/s, loss=0.351, accuracy=86.5]



Epoch 28/30 (Fold 2): Validation Loss: 0.7158, Accuracy: 78.26%


Epoch 29/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.85it/s, loss=0.339, accuracy=86.8]



Epoch 29/30 (Fold 2): Validation Loss: 0.7190, Accuracy: 78.43%


Epoch 30/30 (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.73it/s, loss=0.329, accuracy=87.5]



Epoch 30/30 (Fold 2): Validation Loss: 0.7308, Accuracy: 78.32%
EarlyStopping counter: 1 out of 5


Evaluating (Fold 2): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:12<00:00,  8.82it/s]


Accuracy: 78.63%

Classification Report:
               precision    recall  f1-score   support

           0     0.4314    0.3777    0.4027       233
           1     0.8218    0.7977    0.8095       341
           2     0.8854    0.9021    0.8937       531
           3     0.8881    0.9315    0.9093       409
           4     0.5904    0.7903    0.6759        62
           5     0.7884    0.7939    0.7911       427
           6     0.6730    0.7696    0.7181       369
           7     0.7164    0.6065    0.6569       554
           8     0.8721    0.8929    0.8824       252
           9     0.8842    0.9106    0.8972       369

    accuracy                         0.7863      3547
   macro avg     0.7551    0.7773    0.7637      3547
weighted avg     0.7834    0.7863    0.7833      3547




Epoch 1/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.73it/s, loss=2.03, accuracy=25.5]



Epoch 1/30 (Fold 3): Validation Loss: 1.7674, Accuracy: 31.63%


Epoch 2/30 (Fold 3): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.72it/s, loss=1.7, accuracy=34.6]



Epoch 2/30 (Fold 3): Validation Loss: 1.6470, Accuracy: 36.93%


Epoch 3/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.68it/s, loss=1.53, accuracy=41.7]



Epoch 3/30 (Fold 3): Validation Loss: 1.6042, Accuracy: 38.17%


Epoch 4/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.66it/s, loss=1.37, accuracy=48.7]



Epoch 4/30 (Fold 3): Validation Loss: 1.2606, Accuracy: 53.43%


Epoch 5/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.77it/s, loss=1.25, accuracy=54.4]



Epoch 5/30 (Fold 3): Validation Loss: 1.1345, Accuracy: 57.88%


Epoch 6/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.77it/s, loss=1.16, accuracy=58.2]



Epoch 6/30 (Fold 3): Validation Loss: 1.1297, Accuracy: 61.40%


Epoch 7/30 (Fold 3): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.68it/s, loss=1.1, accuracy=60.7]



Epoch 7/30 (Fold 3): Validation Loss: 1.0812, Accuracy: 64.28%


Epoch 8/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.70it/s, loss=1.04, accuracy=62.9]



Epoch 8/30 (Fold 3): Validation Loss: 1.0493, Accuracy: 64.03%


Epoch 9/30 (Fold 3): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.73it/s, loss=0.976, accuracy=66]



Epoch 9/30 (Fold 3): Validation Loss: 0.9591, Accuracy: 69.33%


Epoch 10/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.69it/s, loss=0.933, accuracy=67]



Epoch 10/30 (Fold 3): Validation Loss: 0.8720, Accuracy: 69.69%


Epoch 11/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.84it/s, loss=0.859, accuracy=70.1]



Epoch 11/30 (Fold 3): Validation Loss: 1.0405, Accuracy: 62.67%
EarlyStopping counter: 1 out of 5


Epoch 12/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.87it/s, loss=0.837, accuracy=70.6]



Epoch 12/30 (Fold 3): Validation Loss: 1.0849, Accuracy: 60.30%
EarlyStopping counter: 2 out of 5


Epoch 13/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.87it/s, loss=0.791, accuracy=72.5]



Epoch 13/30 (Fold 3): Validation Loss: 0.9249, Accuracy: 67.16%
EarlyStopping counter: 3 out of 5


Epoch 14/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.88it/s, loss=0.749, accuracy=73.4]



Epoch 14/30 (Fold 3): Validation Loss: 0.8103, Accuracy: 73.30%


Epoch 15/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.85it/s, loss=0.741, accuracy=74.7]



Epoch 15/30 (Fold 3): Validation Loss: 0.8686, Accuracy: 72.43%
EarlyStopping counter: 1 out of 5


Epoch 16/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.81it/s, loss=0.698, accuracy=75.9]



Epoch 16/30 (Fold 3): Validation Loss: 0.9648, Accuracy: 68.85%
EarlyStopping counter: 2 out of 5


Epoch 17/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.70it/s, loss=0.662, accuracy=76.5]



Epoch 17/30 (Fold 3): Validation Loss: 0.7635, Accuracy: 76.26%


Epoch 18/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.71it/s, loss=0.625, accuracy=78]



Epoch 18/30 (Fold 3): Validation Loss: 0.7554, Accuracy: 74.77%


Epoch 19/30 (Fold 3): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.69it/s, loss=0.6, accuracy=78.4]



Epoch 19/30 (Fold 3): Validation Loss: 0.7942, Accuracy: 74.40%
EarlyStopping counter: 1 out of 5


Epoch 20/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.72it/s, loss=0.574, accuracy=79.6]



Epoch 20/30 (Fold 3): Validation Loss: 0.7632, Accuracy: 74.80%


Epoch 21/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.78it/s, loss=0.562, accuracy=80.3]



Epoch 21/30 (Fold 3): Validation Loss: 0.7861, Accuracy: 74.63%
EarlyStopping counter: 1 out of 5


Epoch 22/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:05<00:00,  6.73it/s, loss=0.432, accuracy=84.3]



Epoch 22/30 (Fold 3): Validation Loss: 0.7033, Accuracy: 78.52%


Epoch 23/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:06<00:00,  6.73it/s, loss=0.397, accuracy=85.4]



Epoch 23/30 (Fold 3): Validation Loss: 0.7097, Accuracy: 79.19%


Epoch 24/30 (Fold 3): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [01:04<00:00,  6.87it/s, loss=0.381, accuracy=86.3]



Epoch 24/30 (Fold 3): Validation Loss: 0.6866, Accuracy: 79.56%


Epoch 25/30 (Fold 3):  25%|█████████████████████████                                                                           | 111/444 [00:16<00:48,  6.92it/s, loss=0.0918, accuracy=86.6]