In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use('Agg') # Use Agg backend for non-interactive plotting
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.model_selection import ParameterGrid, train_test_split
import os
from datetime import datetime
import numpy as np
from torch.nn.utils.rnn import pad_sequence

# Ensure compatibility with loading .sav files
if hasattr(np, 'core') and hasattr(np.core, 'multiarray') and hasattr(np.core.multiarray, 'scalar'):
    torch.serialization.add_safe_globals([np.core.multiarray.scalar])
elif hasattr(np, '_core') and hasattr(np._core, 'multiarray') and hasattr(np._core.multiarray, 'scalar'):
    torch.serialization.add_safe_globals([np._core.multiarray.scalar])

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Model Class Definitions ---
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.attn = nn.Linear(hidden_size, 1)

    def forward(self, gru_out):
        attn_weights = torch.softmax(self.attn(gru_out), dim=1)
        weighted = gru_out * attn_weights
        return weighted.sum(dim=1)

class DenseBlock(nn.Module):
    def __init__(self, input_size, growth_rate, num_layers, dropout_rate=0.3):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(
                nn.Sequential(
                    nn.LayerNorm(input_size + i * growth_rate),
                    nn.Linear(input_size + i * growth_rate, growth_rate * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(growth_rate * 4, growth_rate),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate)
                )
            )
        self.num_layers = num_layers
        self.growth_rate = growth_rate

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_features = layer(torch.cat(features, dim=-1))
            features.append(new_features)
        return torch.cat(features, dim=-1)

class TransitionLayer(nn.Module):
    def __init__(self, input_size, output_size, dropout_rate=0.3):
        super(TransitionLayer, self).__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(input_size),
            nn.Linear(input_size, output_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate))

    def forward(self, x):
        return self.net(x)

class EnhancedKhmerDigitGRU(nn.Module):
    def __init__(self, input_size=2, hidden_size=256, num_layers=4, 
                 output_size=10, dropout_rate=0.3, gmm_components=10,
                 growth_rate=32, dense_layers=4):
        super(EnhancedKhmerDigitGRU, self).__init__()
        self.gmm_components = gmm_components
        self.input_feature_size = input_size
        
        self.initial_dense = DenseBlock(self.input_feature_size + 2, growth_rate, dense_layers, dropout_rate)
        current_size = self.input_feature_size + 2 + growth_rate * dense_layers
        
        self.transition1 = TransitionLayer(current_size, hidden_size // 2, dropout_rate)
        current_size = hidden_size // 2
        
        self.gru = nn.GRU(current_size, hidden_size, num_layers, 
                         batch_first=True, dropout=dropout_rate, bidirectional=True)
        
        self.attention = Attention(hidden_size * 2)
        
        self.final_dense = DenseBlock(hidden_size * 2, growth_rate, dense_layers, dropout_rate)
        final_size = hidden_size * 2 + growth_rate * dense_layers
        
        self.fc_class = nn.Sequential(
            nn.LayerNorm(final_size),
            nn.Linear(final_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, output_size)
        )
        
        self.fc_gmm = nn.Sequential(
            nn.LayerNorm(hidden_size * 2), 
            nn.Linear(hidden_size * 2, gmm_components * 5)
        )
        
        self.dropout = nn.Dropout(dropout_rate)
        self.softmax = nn.Softmax(dim=-1)

    def preprocess_input(self, x): 
        batch_size, seq_len, current_input_features = x.shape
        delta = torch.zeros(batch_size, seq_len, 2, device=x.device) 
        if seq_len > 1:
            delta[:, 1:, 0] = x[:, 1:, 0] - x[:, :-1, 0]
            delta[:, 1:, 1] = x[:, 1:, 1] - x[:, :-1, 1]
        return torch.cat([x, delta], dim=-1)

    def gmm_loss(self, gmm_params, targets):
        batch_size, seq_len_gmm_out, _ = gmm_params.shape
        _, seq_len_targets, _ = targets.shape 

        if seq_len_gmm_out == 0 or seq_len_targets == 0:
             return torch.tensor(0.0, device=gmm_params.device, requires_grad=True)

        M = self.gmm_components
        gmm_params_reshaped = gmm_params.view(batch_size, seq_len_gmm_out, M, 5)

        pi_logits = gmm_params_reshaped[..., 0]
        mu_x = gmm_params_reshaped[..., 1]
        mu_y = gmm_params_reshaped[..., 2]
        log_sigma_x = gmm_params_reshaped[..., 3]
        log_sigma_y = gmm_params_reshaped[..., 4]

        pi = self.softmax(pi_logits)
        sigma_x = torch.exp(log_sigma_x)
        sigma_y = torch.exp(log_sigma_y)

        targets_expanded = targets.unsqueeze(2)

        x_term_exp = ((targets_expanded[..., 0] - mu_x) / (sigma_x + 1e-8))**2
        y_term_exp = ((targets_expanded[..., 1] - mu_y) / (sigma_y + 1e-8))**2
        
        log_pdf_x = -0.5 * x_term_exp - log_sigma_x - 0.5 * torch.log(torch.tensor(2 * torch.pi, device=gmm_params.device))
        log_pdf_y = -0.5 * y_term_exp - log_sigma_y - 0.5 * torch.log(torch.tensor(2 * torch.pi, device=gmm_params.device))
        
        log_likelihood_components = log_pdf_x + log_pdf_y
        
        log_pi = torch.log(pi + 1e-10)
        sum_terms = log_pi + log_likelihood_components
        
        log_prob_sum = torch.logsumexp(sum_terms, dim=2)
        
        return -log_prob_sum.mean()

    def forward(self, x, lengths=None):
        x_aug = self.preprocess_input(x)
        x_dense = self.initial_dense(x_aug)
        x_trans = self.transition1(x_dense)
        
        gru_out, _ = self.gru(x_trans)
        gru_out_dp = self.dropout(gru_out)
        
        attention_out = self.attention(gru_out_dp)
        
        final_features_input = attention_out.unsqueeze(1)
        final_features = self.final_dense(final_features_input).squeeze(1)
        
        class_out = self.fc_class(final_features)
        gmm_out = self.fc_gmm(gru_out_dp)
            
        return class_out, gmm_out

# --- Dataset and Data Loading ---
class CustomDataset(Dataset):
    def __init__(self, data_sequences, labels):
        self.data_sequences = data_sequences
        self.labels = labels

    def __len__(self):
        return len(self.data_sequences)

    def __getitem__(self, idx):
        return self.data_sequences[idx], self.labels[idx]

def collate_fn(batch):
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_sequences, labels, lengths

# --- Training Function ---
def train_model(model, train_loader, val_loader, criterion_class, optimizer, device, epochs, gmm_weight=0.1, results_dir="results"):
    model.to(device)
    train_losses_epoch, val_losses_epoch = [], []
    best_val_loss = float('inf')
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [T]", leave=False)
        for inputs, labels, lengths in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            class_out, gmm_out = model(inputs, lengths=lengths)
            loss_class = criterion_class(class_out, labels)
            loss_gmm = torch.tensor(0.0, device=device)
            if inputs.shape[1] > 1:
                actual_deltas = torch.zeros(inputs.shape[0], inputs.shape[1] - 1, 2, device=device)
                actual_deltas[..., 0] = inputs[:, 1:, 0] - inputs[:, :-1, 0]
                actual_deltas[..., 1] = inputs[:, 1:, 1] - inputs[:, :-1, 1]
                gmm_out_for_loss = gmm_out[:, :-1, :] 
                if gmm_out_for_loss.shape[1] == actual_deltas.shape[1] and gmm_out_for_loss.shape[1] > 0:
                    loss_gmm = model.gmm_loss(gmm_out_for_loss, actual_deltas)
            
            loss = loss_class + gmm_weight * loss_gmm
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item(), cls=loss_class.item(), gmm=loss_gmm.item())

        avg_train_loss = running_loss / len(train_loader)
        train_losses_epoch.append(avg_train_loss)

        model.eval()
        val_running_loss = 0.0
        val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [V]", leave=False)
        with torch.no_grad():
            for inputs, labels, lengths in val_progress_bar:
                inputs, labels = inputs.to(device), labels.to(device)
                class_out, gmm_out_val = model(inputs, lengths=lengths)
                loss_class_val = criterion_class(class_out, labels)
                loss_gmm_val = torch.tensor(0.0, device=device)
                if inputs.shape[1] > 1:
                    actual_deltas_val = torch.zeros(inputs.shape[0], inputs.shape[1] - 1, 2, device=device)
                    actual_deltas_val[..., 0] = inputs[:, 1:, 0] - inputs[:, :-1, 0]
                    actual_deltas_val[..., 1] = inputs[:, 1:, 1] - inputs[:, :-1, 1]
                    gmm_out_val_for_loss = gmm_out_val[:, :-1, :]
                    if gmm_out_val_for_loss.shape[1] == actual_deltas_val.shape[1] and gmm_out_val_for_loss.shape[1] > 0:
                        loss_gmm_val = model.gmm_loss(gmm_out_val_for_loss, actual_deltas_val)
                current_val_sample_loss = loss_class_val + gmm_weight * loss_gmm_val
                val_running_loss += current_val_sample_loss.item()
                val_progress_bar.set_postfix(val_loss=current_val_sample_loss.item())
        avg_val_loss = val_running_loss / len(val_loader)
        val_losses_epoch.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = model.state_dict().copy()
            print(f"✅ New best validation loss: {avg_val_loss:.4f}.")
    if best_model_state:
        model.load_state_dict(best_model_state)
        print("Loaded best model state based on validation loss.")
    return train_losses_epoch, val_losses_epoch, model

# --- Evaluation Function ---
def evaluate_model(model, test_loader, device, verbose=True):
    model.eval()
    model.to(device)
    all_preds, all_labels_list = [], []
    with torch.no_grad():
        for inputs, labels, lengths in tqdm(test_loader, desc="Evaluating Test Set", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            class_out, _ = model(inputs, lengths=lengths)
            preds = torch.argmax(class_out, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels_list.extend(labels.cpu().numpy())
    metrics = {
        'accuracy': accuracy_score(all_labels_list, all_preds),
        'precision': precision_score(all_labels_list, all_preds, average="macro", zero_division=0),
        'recall': recall_score(all_labels_list, all_preds, average="macro", zero_division=0),
        'f1': f1_score(all_labels_list, all_preds, average="macro", zero_division=0)
    }
    if verbose:
        print(f"Test Accuracy: {metrics['accuracy']:.4f}")
        print(f"Test Precision: {metrics['precision']:.4f}")
        print(f"Test Recall: {metrics['recall']:.4f}")
        print(f"Test F1-score: {metrics['f1']:.4f}")
    return metrics

# --- Save Model Metadata ---
def save_model_metadata(model_path, params, metrics, results_dir):
    metadata_filename = os.path.splitext(os.path.basename(model_path))[0] + ".txt"
    with open(os.path.join(results_dir, metadata_filename), "w") as f:
        f.write(f"Parameters:\n{params}\n\n")
        f.write(f"Metrics:\n")
        for k, v in metrics.items():
            f.write(f"  {k}: {v}\n")
    print(f"✅ Saved model metadata to {metadata_filename}")

# --- Main Execution ---
if __name__ == "__main__":
    # Correct path to the .sav file
    data_path = "/kaggle/input/fortest/normalization_augmentation.sav"
    print(f"Loading data from {data_path}...")
    
    try:
        all_sequences_raw = torch.load(data_path, map_location=torch.device('cpu'), weights_only=False)
        print(f"Successfully loaded data. Number of samples: {len(all_sequences_raw)}")
    except Exception as e:
        print(f"Error loading data: {e}")
        raise

    # Class distribution from the dataset
    class_counts = [2444, 1898, 1924, 2249, 1911, 1846, 2145, 2002, 1859, 2301]
    all_labels_list_generated = []
    for i, count in enumerate(class_counts):
        all_labels_list_generated.extend([i] * count)
    
    if len(all_labels_list_generated) != len(all_sequences_raw):
        raise ValueError(f"Mismatch in number of labels ({len(all_labels_list_generated)}) and sequences ({len(all_sequences_raw)})")

    # Process sequences and labels
    final_sequences = []
    final_labels = []
    
    for i, sample_entry in enumerate(tqdm(all_sequences_raw, desc="Processing sequences and labels")):
        if not (isinstance(sample_entry, list) and len(sample_entry) >= 2):
            continue

        list_of_normalized_strokes = sample_entry[1]

        if not isinstance(list_of_normalized_strokes, list):
            continue

        flat_sequence_points_for_digit = []
        for stroke_as_list_of_points in list_of_normalized_strokes:
            if not isinstance(stroke_as_list_of_points, list):
                continue
            
            for point_tuple in stroke_as_list_of_points:
                if isinstance(point_tuple, (tuple, list)) and len(point_tuple) == 2:
                    try:
                        x_coord = float(point_tuple[0])
                        y_coord = float(point_tuple[1])
                        flat_sequence_points_for_digit.append((x_coord, y_coord))
                    except (ValueError, TypeError, IndexError):
                        continue
        
        if flat_sequence_points_for_digit:
            final_sequences.append(torch.tensor(flat_sequence_points_for_digit, dtype=torch.float32))
            final_labels.append(all_labels_list_generated[i])

    print(f"Processed {len(final_sequences)} sequences with {len(final_labels)} labels.")
    if not final_sequences:
        raise ValueError("No sequences were successfully processed.")

    # Split data
    train_sequences, temp_sequences, train_labels, temp_labels = train_test_split(
        final_sequences, final_labels, test_size=0.2, stratify=final_labels, random_state=42
    )
    val_sequences, test_sequences, val_labels, test_labels = train_test_split(
        temp_sequences, temp_labels, test_size=0.5, stratify=temp_labels, random_state=42
    )
    print(f"Train: {len(train_sequences)}, Val: {len(val_sequences)}, Test: {len(test_sequences)}")

    # Create datasets and dataloaders
    train_dataset = CustomDataset(train_sequences, train_labels)
    val_dataset = CustomDataset(val_sequences, val_labels)
    test_dataset = CustomDataset(test_sequences, test_labels)

    batch_size = 16
    num_workers_dl = 0

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                            collate_fn=collate_fn, pin_memory=True if device.type == 'cuda' else False, 
                            num_workers=num_workers_dl)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, 
                          collate_fn=collate_fn, pin_memory=True if device.type == 'cuda' else False, 
                          num_workers=num_workers_dl)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, 
                           collate_fn=collate_fn, pin_memory=True if device.type == 'cuda' else False, 
                           num_workers=num_workers_dl)
    
    # Model parameters
    current_params = {
        "lr": 0.001,
        "epochs": 10,  # Increased from 1 to get better results
        "gmm_weight": 0.05,
        "growth_rate": 32,
        "dense_layers": 3,
        "hidden_size": 128,
        "num_gru_layers": 2 
    }
    
    # Initialize model
    model = EnhancedKhmerDigitGRU(
        input_size=2,
        hidden_size=current_params['hidden_size'],
        num_layers=current_params['num_gru_layers'],
        output_size=10,
        dropout_rate=0.3,
        gmm_components=10,
        growth_rate=current_params['growth_rate'],
        dense_layers=current_params['dense_layers']
    )
    
    criterion_class = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=current_params['lr'])
    
    # Create results directory
    results_dir = "khmer_model_results_optimized"
    os.makedirs(results_dir, exist_ok=True)
    
    # Train model
    print(f"Starting training with params: {current_params}")
    train_losses, val_losses, best_trained_model = train_model(
        model, train_loader, val_loader, criterion_class, optimizer, device,
        current_params['epochs'], current_params['gmm_weight'], results_dir=results_dir
    )
    
    # Evaluate model
    print("Training finished. Evaluating on test set...")
    final_metrics = evaluate_model(best_trained_model, test_loader, device, verbose=True)
    
    # Save model and metadata
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    final_model_filename = f"best_model_val_loss_{timestamp}.pt"
    final_model_path = os.path.join(results_dir, final_model_filename)
    torch.save(best_trained_model.state_dict(), final_model_path)
    print(f"✅ Saved final best model to {final_model_path}")
    save_model_metadata(final_model_path, current_params, final_metrics, results_dir)

    # Plot training curves
    if current_params['epochs'] > 0 and train_losses and val_losses:
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
        plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training and Validation Losses')
        plt.legend()
        plt.grid(True)
        loss_plot_path = os.path.join(results_dir, f"loss_plot_{timestamp}.png")
        plt.savefig(loss_plot_path)
        print(f"Saved loss plot to {loss_plot_path}")
    else:
        print("Skipping loss plot generation as epochs was 0 or no losses recorded.")

    # Print final results
    print("\n🏆 Final Optimized Model Results:")
    print(f"Parameters: {current_params}")
    if val_losses:
        print(f"Best Validation Loss achieved: {min(val_losses):.4f}")
    print("Test Metrics:")
    for k, v in final_metrics.items():
        print(f"  {k}: {v:.4f}")
    print(f"\nAll results saved in ./{results_dir}")

Using device: cpu
Loading data from /kaggle/input/fortest/normalization_augmentation.sav...
Successfully loaded data. Number of samples: 20579


Processing sequences and labels: 100%|██████████| 20579/20579 [00:02<00:00, 7112.27it/s]


Processed 20579 sequences with 20579 labels.
Train: 16463, Val: 2058, Test: 2058
Starting training with params: {'lr': 0.001, 'epochs': 10, 'gmm_weight': 0.05, 'growth_rate': 32, 'dense_layers': 3, 'hidden_size': 128, 'num_gru_layers': 2}


                                                                                

Epoch 1/10, Train Loss: 1.9217, Val Loss: 1.7412
✅ New best validation loss: 1.7412.


                                                                                

Epoch 2/10, Train Loss: 1.5803, Val Loss: 1.4246
✅ New best validation loss: 1.4246.


                                                                               

Epoch 3/10, Train Loss: 1.4166, Val Loss: 1.2912
✅ New best validation loss: 1.2912.


                                                                               

Epoch 4/10, Train Loss: 1.2663, Val Loss: 1.2062
✅ New best validation loss: 1.2062.


                                                                                 

Epoch 5/10, Train Loss: 1.1413, Val Loss: 0.9192
✅ New best validation loss: 0.9192.


                                                                               

Epoch 6/10, Train Loss: 1.0432, Val Loss: 1.1910


                                                                                 

Epoch 7/10, Train Loss: 0.9415, Val Loss: 0.9825


                                                                                

Epoch 8/10, Train Loss: 0.7994, Val Loss: 5.0254


                                                                                 

Epoch 9/10, Train Loss: 0.7104, Val Loss: 0.4835
✅ New best validation loss: 0.4835.


                                                                                  

Epoch 10/10, Train Loss: 0.5856, Val Loss: 0.5323
Loaded best model state based on validation loss.
Training finished. Evaluating on test set...


                                                                      

Test Accuracy: 0.7143
Test Precision: 0.7212
Test Recall: 0.7027
Test F1-score: 0.7022
✅ Saved final best model to khmer_model_results_optimized/best_model_val_loss_20250507_150309.pt
✅ Saved model metadata to best_model_val_loss_20250507_150309.txt
Saved loss plot to khmer_model_results_optimized/loss_plot_20250507_150309.png

🏆 Final Optimized Model Results:
Parameters: {'lr': 0.001, 'epochs': 10, 'gmm_weight': 0.05, 'growth_rate': 32, 'dense_layers': 3, 'hidden_size': 128, 'num_gru_layers': 2}
Best Validation Loss achieved: 0.4835
Test Metrics:
  accuracy: 0.7143
  precision: 0.7212
  recall: 0.7027
  f1: 0.7022

All results saved in ./khmer_model_results_optimized
