# 1. Setup and Configuration
- Loads necessary libraries
- Defines model hyperparameters and training configurations
- Specifies the path to the processed data file from the preprocessing notebook

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import random
import pickle
import os
import time
from collections import Counter, deque 
from tqdm.notebook import tqdm 
import multiprocessing

In [2]:
PROCESSED_DATA_PICKLE_PATH = 'processed_cache_data.pkl' 
SEQ_LENGTH = 20      
MODEL_MAX_SEQ_LENGTH = 50
BATCH_SIZE = 8             
NUM_EPOCHS = 5 
K_PREFETCH_MODEL = 1
LRU_CACHE_SIZE_PERCENTAGE = 0.001 
GRAD_CLIP = 1.0
TRAIN_SPLIT_RATIO = 0.8
NUM_WORKERS_DATALOADER = 8
NUM_INIT_WORKERS_DATASET = 0

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# 2. Define Transformer Model Components

In [4]:
'''
Kelas ini mengimplementasikan mekanisme positional encoding dari transformer
Mekanisme ini akan menambahkan informasi mengenai posisi dari token pada sequence masukan
Mekanisme ini penting karena transformer memroses token secara parallel sehingga perlu diketahui konteks posisi dari token
'''
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()

        # Ditambahkan dropout untuk mencegah overfitting
        self.dropout = nn.Dropout(p=dropout)

        # Tensor ini akan merepresentasikan posisi token
        position = torch.arange(max_len).unsqueeze(1)

        # Menghitung pembagi dari fungsi sinus dan cosinus yang akan digunakan pada positional encoding
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        # Inisialisasi tensor positional encoding
        pe = torch.zeros(max_len, d_model)

        # Nilai dari positional encoding untuk elemen berindeks ganjil adalah cosinus dan genap adalah sinus
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Memasukkan positional encoding sebagai buffer --> buffer adalah state dari model yang tidak dilatih
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Menghitung nilai positional encoding dengan menambahkan nilai positional encoding ke tensor masukan
        # Setelah itu aplikasikan dropout
        x = x + self.pe[:x.size(1), :]
        return self.dropout(x)

In [5]:
'''
Kelas ini mengimplementasikan multi-query self-attention (MQSA)
MQSA merupakan optimasi dari Multi-Head Self-Attention (MHSA)
MQSA tetap memiliki multiple query heads akan tetapi hanya menggunakan satu projeksi Key dan Value pada seluruh Query head
Optimasi MQSA ini menyebabkan berkurangnya jumlah parameter yang dibutuhkan
'''
class MultiQuerySelfAttention(nn.Module):
    def __init__(self, d_model: int, num_query_heads: int, dropout: float = 0.1):
        super().__init__()

        # Inisialisasi jumlah attention head dan dimensinya
        assert d_model % num_query_heads == 0, "d_model must be divisible by num_query_heads"
        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.query_head_dim = d_model // num_query_heads # Dimension of each query head

        # Dimensi dari projeksi Key dan Value pada implementasi ini akan sama dengan dimensi dari Query head
        self.kv_dim = self.query_head_dim 

        # Inisialisasi linear layer proyeksi Query, Key, dan Value
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, self.kv_dim)
        self.W_v = nn.Linear(d_model, self.kv_dim)

        # Inisialisasi linear layer yang akan menjadi output dari Q, K, dan V
        self.W_o = nn.Linear(d_model, d_model) # Output projection

        # Digunakan juga dropout untuk mencegah overfitting
        self.dropout_attn = nn.Dropout(dropout)

    def forward(self, query_input: torch.Tensor, key_input: torch.Tensor, value_input: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size = query_input.size(0)
        seq_len_q = query_input.size(1)
        seq_len_k = key_input.size(1)
        seq_len_v = value_input.size(1)

        # Dilakukan proyeksi Query untuk setiap attention head
        Q = self.W_q(query_input)
        Q = Q.view(batch_size, seq_len_q, self.num_query_heads, self.query_head_dim).transpose(1, 2)

        # Dilakukan proyeksi Key dan Value pada shared projection
        K_shared = self.W_k(key_input)    # (B, L_k, D_kv)
        V_shared = self.W_v(value_input)  # (B, L_v, D_kv)

        # Dilakukan pemrosesan supaya shared Key dan Value dapat dibroadcast ke setiap head
        K_shared = K_shared.unsqueeze(1) # (B, 1, L_k, D_kv)
        V_shared = V_shared.unsqueeze(1) # (B, 1, L_v, D_kv)
        
        # Dihitung nilai atensi (kemiripan query ke key) menggunakan scaled dot-product attention
        attention_scores = torch.matmul(Q, K_shared.transpose(-2, -1)) / math.sqrt(self.query_head_dim)

        # Apabila ada masking, maka diterapkan
        # Masking ini digunakan untuk mencegah decoder dari melihat token di depan (future token)
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == True, float('-inf'))
            
        attention_weights = F.softmax(attention_scores, dim=-1) # (B, H_q, L_q, L_k)
        attention_weights = self.dropout_attn(attention_weights)
        
        # Hitung context vector
        # Hasil vector ini adalah weighted average dari Values
        context_vector = torch.matmul(attention_weights, V_shared)
        context_vector = context_vector.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model) # D_model = H_q * D_kv (if D_kv=D_qh)
        output = self.W_o(context_vector) 
        return output, attention_weights

In [6]:
'''
Implementasi dari mekanisme position wise feed forward
Mekanisme / lapisan ini berada tepat setelah MHSA
Lapisan ini digunakan untuk menambahkan atau membuat fitur menjadi non-linear
'''
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

In [7]:
'''
Kelas yang merepresentasikan satu blok decoder pada transformer
Kelas ini menggabungkan MHQA dan Position-Wise Feed-Forward
'''
class DecoderBlockScratch(nn.Module):
    def __init__(self, d_model: int, num_query_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attention = MultiQuerySelfAttention(d_model, num_query_heads, dropout) 
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        norm_x = self.norm1(x)
        attn_output, _ = self.self_attention(norm_x, norm_x, norm_x, attention_mask)
        x = x + self.dropout1(attn_output)
        norm_x = self.norm2(x)
        ff_output = self.feed_forward(norm_x)
        x = x + self.dropout2(ff_output)
        return x

In [8]:
'''
Kelas yang merepresentasikan decoder yang telah di-assembly menjadi transformer
Transformer ini melakukan embedding dengan mengubah token menjadi dense vector
'''
class DecoderOnlyTransformerScratch(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, num_query_heads: int, num_layers: int, # num_heads -> num_query_heads
                 d_ff: int, max_seq_length: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_length)
        self.decoder_blocks = nn.ModuleList([
            DecoderBlockScratch(d_model, num_query_heads, d_ff, dropout) for _ in range(num_layers)]) # Pass num_query_heads
        self.final_norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout_emb = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def _generate_causal_mask(self, size: int, device: torch.device) -> torch.Tensor:
        # (Unchanged)
        return torch.triu(torch.ones(size, size, device=device, dtype=torch.bool), diagonal=1)

    def forward(self, src: torch.Tensor, src_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        batch_size, seq_len = src.shape
        device = src.device
        if seq_len > self.max_seq_length:
            pass 
        emb_out = self.embedding(src) * math.sqrt(self.d_model)
        x = self.dropout_emb(self.pos_encoder(emb_out))
        
        causal_mask_base = self._generate_causal_mask(seq_len, device) 
        combined_attention_mask = causal_mask_base.unsqueeze(0).unsqueeze(0) 

        if src_key_padding_mask is not None:
            expanded_padding_mask = src_key_padding_mask.unsqueeze(1).unsqueeze(2) 
            combined_attention_mask = combined_attention_mask | expanded_padding_mask
            combined_attention_mask = combined_attention_mask.bool() 

        for block in self.decoder_blocks:
            x = block(x, combined_attention_mask)
            
        x = self.final_norm(x)
        logits = self.fc_out(x)
        return logits

# 3. Define `CacheDataset`

In [9]:
'''
Fungsi ini digunakan untuk menciptakan slices dari sequence
Slices yang dihasilkan menggunakan prinsip sliding window
'''
def _create_single_sequence_pair_for_mp(args_tuple):
    indexed_obj_ids_ref, i, sequence_length_val = args_tuple
    
    input_seq_list = indexed_obj_ids_ref[i : i + sequence_length_val]
    target_seq_list = indexed_obj_ids_ref[i + 1 : i + sequence_length_val + 1]
    
    return torch.tensor(input_seq_list, dtype=torch.long), \
           torch.tensor(target_seq_list, dtype=torch.long)

In [10]:
'''
Kelas dataset dari sekuens akses objek
Kelas ini akan merubah sekuens masukan menjadi potongan sekuens berbasis sliding window
'''
class CacheDataset(Dataset):
    def __init__(self, filtered_obj_id_sequence: list, list_of_popular_objects: list, sequence_length: int, num_init_workers: int = 0):
        super().__init__()
        self.sequence_length = sequence_length
        self.popular_objects_vocab = sorted(list(set(list_of_popular_objects)))
        self.obj_to_idx = {obj: i for i, obj in enumerate(self.popular_objects_vocab)}
        self.idx_to_obj = {i: obj for obj, i in self.obj_to_idx.items()}
        self.vocab_size = len(self.popular_objects_vocab)
        
        self.indexed_obj_ids = [self.obj_to_idx[obj] for obj in filtered_obj_id_sequence if obj in self.obj_to_idx]
        
        self.input_sequences = []
        self.target_sequences = []
        
        if len(self.indexed_obj_ids) >= self.sequence_length + 1:
            num_total_sequences = len(self.indexed_obj_ids) - self.sequence_length

            actual_init_workers = 0
            if num_init_workers > 0:
                 actual_init_workers = min(num_init_workers, os.cpu_count() if os.cpu_count() else 1) 
            
            min_sequences_for_parallel = 1000 
            min_sequences_per_worker = 50 

            if actual_init_workers > 0 and \
               num_total_sequences >= min_sequences_for_parallel and \
               (num_total_sequences / actual_init_workers) >= min_sequences_per_worker:
                
                print(f"Using {actual_init_workers} workers for CacheDataset sequence creation ({num_total_sequences} sequences).")
                tasks_args = [(self.indexed_obj_ids, i, self.sequence_length) for i in range(num_total_sequences)]
                
                with multiprocessing.Pool(processes=actual_init_workers) as pool:
                    results_list_pairs = []
                    for pair in tqdm(pool.imap_unordered(_create_single_sequence_pair_for_mp, tasks_args), 
                                     total=num_total_sequences, 
                                     desc="Creating Dataset Sequences (Parallel)", 
                                     unit="sequence", 
                                     leave=False):
                        results_list_pairs.append(pair)
                
                if results_list_pairs: 
                    raw_input_sequences, raw_target_sequences = zip(*results_list_pairs)
                    self.input_sequences = [torch.tensor(s, dtype=torch.long) for s in raw_input_sequences]
                    self.target_sequences = [torch.tensor(s, dtype=torch.long) for s in raw_target_sequences]
            else:
                if actual_init_workers > 0 : 
                    print(f"Dataset size ({num_total_sequences} sequences) or worker load too small for parallel init with {actual_init_workers} workers. Using sequential.")
                
                for i in tqdm(range(num_total_sequences), desc="Creating Dataset Sequences (Sequential)", unit="sequence", leave=False):
                    self.input_sequences.append(torch.tensor(self.indexed_obj_ids[i : i + self.sequence_length], dtype=torch.long))
                    self.target_sequences.append(torch.tensor(self.indexed_obj_ids[i + 1 : i + self.sequence_length + 1], dtype=torch.long))

    def __len__(self):
        return len(self.input_sequences)
    
    def __getitem__(self, idx):
        return self.input_sequences[idx], self.target_sequences[idx]
    
    def get_vocab_info(self):
        return {'obj_to_idx': self.obj_to_idx, 'idx_to_obj': self.idx_to_obj, 'vocab_size': self.vocab_size}

# 4. Define Training and Evaluation Loops

In [20]:
def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, 
                optimizer: optim.Optimizer, device: torch.device, grad_clip_value: float = None, epoch_num: int = 0, config_name: str = ""):
    model.train()
    total_loss = 0.0
    batch_iterator = tqdm(dataloader, desc=f"Epoch {epoch_num} Training", leave=False, unit="batch")
    for batch_idx, (input_seqs, target_seqs) in enumerate(batch_iterator):
        input_seqs, target_seqs = input_seqs.to(device), target_seqs.to(device)
        optimizer.zero_grad()
        outputs = model(input_seqs, src_key_padding_mask=None)
        loss = criterion(outputs.view(-1, outputs.size(-1)), target_seqs.view(-1))
        loss.backward()
        if grad_clip_value:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_value)
        optimizer.step()
        total_loss += loss.item()
        batch_iterator.set_postfix_str(f"Loss: {loss.item():.4f}")
    return total_loss / len(dataloader) if len(dataloader) > 0 else 0.0

'''
Evaluasi digunakan mensimulasikan performa LRU cache dengan prefetcher
Prefetcher yang digunakan hanya melakukan prefetching terhadap satu objek saja
'''
def evaluate_model_with_prefetcher(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, 
                                   device: torch.device, model_vocab_size: int, 
                                   k_items_to_prefetch: int = 1, 
                                   cache_size_percentage: float = 0.1):
    model.eval()
    total_loss = 0.0
    total_lru_misses = 0
    total_lru_accesses = 0
    
    if model_vocab_size == 0:
        print("Warning: model_vocab_size is 0. LRU simulation might not be meaningful.")
        cache_capacity = k_items_to_prefetch 
    else:
        cache_capacity = max(1, int(model_vocab_size * cache_size_percentage))
    
    lru_cache = deque(maxlen=cache_capacity)

    batch_iterator = tqdm(dataloader, desc="Evaluating with Prefetcher+LRU", leave=False, unit="batch")
    with torch.no_grad():
        for input_seqs, target_seqs in batch_iterator: 
            input_seqs, target_seqs = input_seqs.to(device), target_seqs.to(device)
            outputs = model(input_seqs, src_key_padding_mask=None) 
            loss = criterion(outputs.view(-1, outputs.size(-1)), target_seqs.view(-1))
            total_loss += loss.item() 

            for b in range(input_seqs.size(0)): 
                for s in range(target_seqs.size(1)): 
                    current_logits_for_prefetch = outputs[b, s, :]
                    actual_demanded_item_idx = target_seqs[b, s].item()

                    if k_items_to_prefetch > 0:
                        _, predicted_indices_to_prefetch = torch.topk(
                            current_logits_for_prefetch, 
                            k=min(k_items_to_prefetch, model_vocab_size if model_vocab_size > 0 else k_items_to_prefetch), 
                            dim=-1
                        )
                        for pred_idx_tensor in predicted_indices_to_prefetch:
                            pred_idx = pred_idx_tensor.item()
                            if pred_idx in lru_cache:
                                lru_cache.remove(pred_idx) 
                            lru_cache.append(pred_idx) 
                    
                    total_lru_accesses += 1
                    if actual_demanded_item_idx in lru_cache:
                        lru_cache.remove(actual_demanded_item_idx) 
                        lru_cache.append(actual_demanded_item_idx)
                    else:
                        total_lru_misses += 1
                        if actual_demanded_item_idx in lru_cache: 
                            lru_cache.remove(actual_demanded_item_idx)
                        lru_cache.append(actual_demanded_item_idx)
            
    avg_loss_per_batch = total_loss / len(dataloader) if len(dataloader) > 0 else 0.0
    lru_miss_ratio = total_lru_misses / total_lru_accesses if total_lru_accesses > 0 else 0.0
    return avg_loss_per_batch, lru_miss_ratio

In [12]:
'''
Bagian ini akan mengevaluasi performa cache LRU tanpa prefetcher sebagai baseline
'''
def evaluate_lru_only_cache(dataloader: DataLoader, device: torch.device, 
                            model_vocab_size: int, cache_size_percentage: float = 0.1):
    total_lru_misses = 0
    total_lru_accesses = 0

    if model_vocab_size == 0:
        print("Warning: model_vocab_size is 0 for LRU-only. Cache capacity might be 0.")
        cache_capacity = 1
    else:
        cache_capacity = max(1, int(model_vocab_size * cache_size_percentage))
    
    print(f"Simulating Baseline LRU cache with capacity: {cache_capacity} ({cache_size_percentage*100:.1f}% of vocab {model_vocab_size}).")
    
    lru_cache = deque(maxlen=cache_capacity)

    batch_iterator = tqdm(dataloader, desc="Evaluating Baseline LRU", leave=False, unit="batch")
    with torch.no_grad():
        for _, target_seqs in batch_iterator:
            target_seqs = target_seqs.to(device)

            for b in range(target_seqs.size(0)):
                for s in range(target_seqs.size(1)):
                    actual_demanded_item_idx = target_seqs[b, s].item()
                    
                    total_lru_accesses += 1

                    if actual_demanded_item_idx in lru_cache:
                        # Cache Hit
                        lru_cache.remove(actual_demanded_item_idx)
                        lru_cache.append(actual_demanded_item_idx)
                    else:
                        total_lru_misses += 1
                        if actual_demanded_item_idx in lru_cache:
                            lru_cache.remove(actual_demanded_item_idx)
                        lru_cache.append(actual_demanded_item_idx)
            
    lru_miss_ratio = total_lru_misses / total_lru_accesses if total_lru_accesses > 0 else 0.0
    return lru_miss_ratio

# 5. Load Processed Data and Prepare Datasets

In [13]:
print(f"Loading processed data from {PROCESSED_DATA_PICKLE_PATH}...")
if not os.path.exists(PROCESSED_DATA_PICKLE_PATH):
    print(f"Error: Processed data file not found at {PROCESSED_DATA_PICKLE_PATH}.")
    print("Please run the preprocessing notebook first to generate this file.")
    raise FileNotFoundError(f"Missing {PROCESSED_DATA_PICKLE_PATH}")

with open(PROCESSED_DATA_PICKLE_PATH, 'rb') as f:
    processed_data = pickle.load(f)

filtered_sequence = processed_data['filtered_sequence_popular_obj_ids']
list_of_popular_objects_for_vocab = processed_data['list_of_popular_obj_ids']

print(f"Loaded filtered sequence of length: {len(filtered_sequence)}")
print(f"Number of unique popular objects for vocabulary: {len(list_of_popular_objects_for_vocab)}")

if not filtered_sequence or not list_of_popular_objects_for_vocab:
    print("Error: Loaded data is empty. Cannot proceed with training.")
    raise ValueError("Empty data loaded from pickle file.")

# Split the filtered sequence for training and validation
split_idx = int(TRAIN_SPLIT_RATIO * len(filtered_sequence))
train_filtered_ids = filtered_sequence[:split_idx]
val_filtered_ids = filtered_sequence[split_idx:]

print(f"Training sequence length: {len(train_filtered_ids)}")
print(f"Validation sequence length: {len(val_filtered_ids)}")

Loading processed data from processed_cache_data.pkl...
Loaded filtered sequence of length: 15000
Number of unique popular objects for vocabulary: 7588
Training sequence length: 12000
Validation sequence length: 3000


In [14]:
print("Creating Training Dataset...")
train_dataset = CacheDataset(train_filtered_ids, list_of_popular_objects_for_vocab, SEQ_LENGTH, num_init_workers=NUM_INIT_WORKERS_DATASET)
print("Creating Validation Dataset...")
val_dataset = CacheDataset(val_filtered_ids, list_of_popular_objects_for_vocab, SEQ_LENGTH, num_init_workers=NUM_INIT_WORKERS_DATASET)

if len(train_dataset) == 0:
    raise ValueError("Training dataset is empty after processing. Insufficient data for SEQ_LENGTH.")
if len(val_dataset) == 0:
    print("Warning: Validation dataset is empty. Evaluation will be skipped.")

Creating Training Dataset...


Creating Dataset Sequences (Sequential):   0%|          | 0/11980 [00:00<?, ?sequence/s]

Creating Validation Dataset...


Creating Dataset Sequences (Sequential):   0%|          | 0/2980 [00:00<?, ?sequence/s]

In [15]:
pin_memory_flag = True if device.type == 'cuda' else False
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=NUM_WORKERS_DATALOADER, pin_memory=pin_memory_flag)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS_DATALOADER, pin_memory=pin_memory_flag) if len(val_dataset) > 0 else None

In [16]:
MODEL_VOCAB_SIZE = train_dataset.vocab_size
print(f"Effective Vocabulary size for all models: {MODEL_VOCAB_SIZE}")
print(f"Number of training sequences: {len(train_dataset)}")
print(f"Number of validation sequences: {len(val_dataset)}")

Effective Vocabulary size for all models: 7588
Number of training sequences: 11980
Number of validation sequences: 2980


In [17]:
baseline_lru_miss_ratio = float('nan')
if val_dataloader and len(val_dataset) > 0:
    if MODEL_VOCAB_SIZE > 0:
        baseline_lru_miss_ratio = evaluate_lru_only_cache(
            val_dataloader, device, 
            model_vocab_size=MODEL_VOCAB_SIZE, 
            cache_size_percentage=LRU_CACHE_SIZE_PERCENTAGE
        )
        print(f"\nBaseline LRU-Only Cache Miss Ratio (on validation set): {baseline_lru_miss_ratio:.4f}\n")
    else:
        print("\nMODEL_VOCAB_SIZE is 0, cannot run baseline LRU-only cache evaluation meaningfully.")
else:
    print("\nValidation dataloader not available, skipping baseline LRU-only cache evaluation.")

Simulating Baseline LRU cache with capacity: 7 (0.1% of vocab 7588).


Evaluating Baseline LRU:   0%|          | 0/373 [00:00<?, ?batch/s]


Baseline LRU-Only Cache Miss Ratio (on validation set): 0.7043



# 6. Define Hyperparameter Configurations

In [18]:
hyperparameter_configs = [
    {
        "name": "Config1_MQA_Baseline", # Renamed to reflect MQA
        "D_MODEL": 128,
        "NUM_QUERY_HEADS": 4, # Number of heads for Query in MQA
        "NUM_LAYERS": 3,
        "D_FF": 256,
        "DROPOUT": 0.1,
        "LEARNING_RATE": 0.001,
    },
    {
        "name": "Config2_MQA_LargerModel_LowerLR",
        "D_MODEL": 256, 
        "NUM_QUERY_HEADS": 8,  
        "NUM_LAYERS": 4, 
        "D_FF": 512,   
        "DROPOUT": 0.15, 
        "LEARNING_RATE": 0.0005, 
    },
    {
        "name": "Config3_MQA_SmallerModel_HigherLR_MoreDropout",
        "D_MODEL": 64,  
        "NUM_QUERY_HEADS": 2,  
        "NUM_LAYERS": 2, 
        "D_FF": 128,   
        "DROPOUT": 0.2,  
        "LEARNING_RATE": 0.002, 
    },
    {
        "name": "Config4_MQA_MediumModel_MoreQueryHeads",
        "D_MODEL": 128,
        "NUM_QUERY_HEADS": 8, # More query heads for same D_MODEL
        "NUM_LAYERS": 3,
        "D_FF": 256,
        "DROPOUT": 0.1,
        "LEARNING_RATE": 0.001,
    }
]

# 7. Training and Evaluation Loop for Each Configuration

In [21]:
all_results = [] 

configurations_pbar = tqdm(hyperparameter_configs, desc="Configurations")
for config in configurations_pbar: 
    print(f"\n\n{'='*20} Starting Training for: {config['name']} {'='*20}")
    print(f"Hyperparameters: {config}")
    start_time_config = time.time()

    current_model = DecoderOnlyTransformerScratch(
        vocab_size=MODEL_VOCAB_SIZE,
        d_model=config['D_MODEL'],
        num_query_heads=config['NUM_QUERY_HEADS'], # Use NUM_QUERY_HEADS
        num_layers=config['NUM_LAYERS'],
        d_ff=config['D_FF'],
        max_seq_length=MODEL_MAX_SEQ_LENGTH,
        dropout=config['DROPOUT']
    ).to(device)

    criterion = nn.CrossEntropyLoss() 
    optimizer = optim.Adam(current_model.parameters(), lr=config['LEARNING_RATE'])

    num_params = sum(p.numel() for p in current_model.parameters() if p.requires_grad)
    print(f"Model for {config['name']} initialized with {num_params} trainable parameters.")

    config_results = {
        "config_name": config['name'],
        "hyperparameters": config,
        "num_parameters": num_params,
        "train_losses": [],
        "val_model_losses": [], 
        "val_lru_miss_ratios_with_prefetcher": [], 
        "best_val_lru_miss_ratio_with_prefetcher": float('inf'), 
        "best_epoch": -1
    }

    if len(train_dataloader) == 0:
        print(f"Training dataloader is empty for {config['name']}. Skipping training for this config.")
        all_results.append(config_results) 
        continue

    for epoch in tqdm(range(1, NUM_EPOCHS + 1), desc=f"Config '{config['name']}' Epochs", leave=False):
        epoch_start_time = time.time()
        
        avg_train_loss = train_epoch(current_model, train_dataloader, criterion, optimizer, device, GRAD_CLIP, epoch_num=epoch, config_name=config['name'])
        config_results["train_losses"].append(avg_train_loss)
        print(f"Config: {config['name']}, Epoch {epoch} Training: Avg Model Loss: {avg_train_loss:.4f}")
        
        current_val_model_loss = float('nan')
        current_val_lru_miss_ratio_with_prefetcher = float('nan')

        if val_dataloader and len(val_dataset) > 0:
            avg_val_model_loss, val_lru_miss_ratio_wp = evaluate_model_with_prefetcher( 
                current_model, val_dataloader, criterion, device, 
                model_vocab_size=MODEL_VOCAB_SIZE, 
                k_items_to_prefetch=K_PREFETCH_MODEL, 
                cache_size_percentage=LRU_CACHE_SIZE_PERCENTAGE
            )
            config_results["val_model_losses"].append(avg_val_model_loss)
            config_results["val_lru_miss_ratios_with_prefetcher"].append(val_lru_miss_ratio_wp)
            current_val_model_loss = avg_val_model_loss
            current_val_lru_miss_ratio_with_prefetcher = val_lru_miss_ratio_wp
            print(f"Config: {config['name']}, Epoch {epoch} Validation: Avg Model Loss: {avg_val_model_loss:.4f}, Prefetcher+LRU Miss Ratio: {val_lru_miss_ratio_wp:.4f}")
            
            if not math.isnan(val_lru_miss_ratio_wp) and val_lru_miss_ratio_wp < config_results["best_val_lru_miss_ratio_with_prefetcher"]:
                config_results["best_val_lru_miss_ratio_with_prefetcher"] = val_lru_miss_ratio_wp
                config_results["best_epoch"] = epoch
        else:
            config_results["val_model_losses"].append(float('nan')) 
            config_results["val_lru_miss_ratios_with_prefetcher"].append(float('nan'))
            if len(val_dataset) == 0:
                 print(f"Config: {config['name']}, Epoch {epoch}: Validation dataset is empty. Skipping validation.")
            else:
                 print(f"Config: {config['name']}, Epoch {epoch}: Validation dataloader not available. Skipping validation.")
        
        epoch_duration = time.time() - epoch_start_time
        if not math.isnan(current_val_lru_miss_ratio_with_prefetcher):
             configurations_pbar.set_description_str(f"Cfgs (Best P+LRUMR for {config['name']}: {config_results['best_val_lru_miss_ratio_with_prefetcher']:.4f})")
        print(f"Config: {config['name']}, Epoch {epoch} duration: {epoch_duration:.2f} seconds")

    config_duration = time.time() - start_time_config
    print(f"\nTraining for {config['name']} completed in {config_duration:.2f} seconds.")
    print(f"Best Validation Prefetcher+LRU Miss Ratio for {config['name']}: {config_results['best_val_lru_miss_ratio_with_prefetcher']:.4f} at Epoch {config_results['best_epoch']}")
    all_results.append(config_results)
    configurations_pbar.set_description_str("Configurations") 

print(f"\n\n{'='*20} All Configurations Processed {'='*20}")

Configurations:   0%|          | 0/4 [00:00<?, ?it/s]



Hyperparameters: {'name': 'Config1_MQA_Baseline', 'D_MODEL': 128, 'NUM_QUERY_HEADS': 4, 'NUM_LAYERS': 3, 'D_FF': 256, 'DROPOUT': 0.1, 'LEARNING_RATE': 0.001}
Model for Config1_MQA_Baseline initialized with 2273508 trainable parameters.


Config 'Config1_MQA_Baseline' Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 1 Training: Avg Model Loss: 2.8702


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 1 Validation: Avg Model Loss: 12.3316, Prefetcher+LRU Miss Ratio: 0.6894
Config: Config1_MQA_Baseline, Epoch 1 duration: 77.70 seconds


Epoch 2 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 2 Training: Avg Model Loss: 0.5214


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 2 Validation: Avg Model Loss: 13.8054, Prefetcher+LRU Miss Ratio: 0.6951
Config: Config1_MQA_Baseline, Epoch 2 duration: 79.00 seconds


Epoch 3 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 3 Training: Avg Model Loss: 0.4025


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 3 Validation: Avg Model Loss: 14.2197, Prefetcher+LRU Miss Ratio: 0.6964
Config: Config1_MQA_Baseline, Epoch 3 duration: 78.14 seconds


Epoch 4 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 4 Training: Avg Model Loss: 0.3610


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 4 Validation: Avg Model Loss: 14.4871, Prefetcher+LRU Miss Ratio: 0.6950
Config: Config1_MQA_Baseline, Epoch 4 duration: 75.89 seconds


Epoch 5 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 5 Training: Avg Model Loss: 0.3312


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config1_MQA_Baseline, Epoch 5 Validation: Avg Model Loss: 15.1268, Prefetcher+LRU Miss Ratio: 0.6895
Config: Config1_MQA_Baseline, Epoch 5 duration: 79.36 seconds

Training for Config1_MQA_Baseline completed in 390.24 seconds.
Best Validation Prefetcher+LRU Miss Ratio for Config1_MQA_Baseline: 0.6894 at Epoch 1


Hyperparameters: {'name': 'Config2_MQA_LargerModel_LowerLR', 'D_MODEL': 256, 'NUM_QUERY_HEADS': 8, 'NUM_LAYERS': 4, 'D_FF': 512, 'DROPOUT': 0.15, 'LEARNING_RATE': 0.0005}
Model for Config2_MQA_LargerModel_LowerLR initialized with 5541028 trainable parameters.


Config 'Config2_MQA_LargerModel_LowerLR' Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 1 Training: Avg Model Loss: 3.0454


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 1 Validation: Avg Model Loss: 11.5336, Prefetcher+LRU Miss Ratio: 0.6908
Config: Config2_MQA_LargerModel_LowerLR, Epoch 1 duration: 87.60 seconds


Epoch 2 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 2 Training: Avg Model Loss: 0.5527


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 2 Validation: Avg Model Loss: 13.6804, Prefetcher+LRU Miss Ratio: 0.6844
Config: Config2_MQA_LargerModel_LowerLR, Epoch 2 duration: 87.43 seconds


Epoch 3 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 3 Training: Avg Model Loss: 0.3921


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 3 Validation: Avg Model Loss: 14.5103, Prefetcher+LRU Miss Ratio: 0.6848
Config: Config2_MQA_LargerModel_LowerLR, Epoch 3 duration: 87.55 seconds


Epoch 4 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 4 Training: Avg Model Loss: 0.3452


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 4 Validation: Avg Model Loss: 14.7940, Prefetcher+LRU Miss Ratio: 0.6863
Config: Config2_MQA_LargerModel_LowerLR, Epoch 4 duration: 87.73 seconds


Epoch 5 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 5 Training: Avg Model Loss: 0.3200


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config2_MQA_LargerModel_LowerLR, Epoch 5 Validation: Avg Model Loss: 14.9909, Prefetcher+LRU Miss Ratio: 0.6883
Config: Config2_MQA_LargerModel_LowerLR, Epoch 5 duration: 87.52 seconds

Training for Config2_MQA_LargerModel_LowerLR completed in 438.13 seconds.
Best Validation Prefetcher+LRU Miss Ratio for Config2_MQA_LargerModel_LowerLR: 0.6844 at Epoch 2


Hyperparameters: {'name': 'Config3_MQA_SmallerModel_HigherLR_MoreDropout', 'D_MODEL': 64, 'NUM_QUERY_HEADS': 2, 'NUM_LAYERS': 2, 'D_FF': 128, 'DROPOUT': 0.2, 'LEARNING_RATE': 0.002}
Model for Config3_MQA_SmallerModel_HigherLR_MoreDropout initialized with 1037604 trainable parameters.


Config 'Config3_MQA_SmallerModel_HigherLR_MoreDropout' Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 1 Training: Avg Model Loss: 3.7167


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 1 Validation: Avg Model Loss: 15.1157, Prefetcher+LRU Miss Ratio: 0.6867
Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 1 duration: 65.28 seconds


Epoch 2 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 2 Training: Avg Model Loss: 1.1389


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 2 Validation: Avg Model Loss: 17.0227, Prefetcher+LRU Miss Ratio: 0.6871
Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 2 duration: 65.52 seconds


Epoch 3 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 3 Training: Avg Model Loss: 0.8006


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 3 Validation: Avg Model Loss: 18.0005, Prefetcher+LRU Miss Ratio: 0.6859
Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 3 duration: 64.71 seconds


Epoch 4 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 4 Training: Avg Model Loss: 0.6973


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 4 Validation: Avg Model Loss: 18.7145, Prefetcher+LRU Miss Ratio: 0.6849
Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 4 duration: 47.65 seconds


Epoch 5 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 5 Training: Avg Model Loss: 0.6395


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 5 Validation: Avg Model Loss: 18.8699, Prefetcher+LRU Miss Ratio: 0.6900
Config: Config3_MQA_SmallerModel_HigherLR_MoreDropout, Epoch 5 duration: 64.83 seconds

Training for Config3_MQA_SmallerModel_HigherLR_MoreDropout completed in 308.10 seconds.
Best Validation Prefetcher+LRU Miss Ratio for Config3_MQA_SmallerModel_HigherLR_MoreDropout: 0.6849 at Epoch 4


Hyperparameters: {'name': 'Config4_MQA_MediumModel_MoreQueryHeads', 'D_MODEL': 128, 'NUM_QUERY_HEADS': 8, 'NUM_LAYERS': 3, 'D_FF': 256, 'DROPOUT': 0.1, 'LEARNING_RATE': 0.001}
Model for Config4_MQA_MediumModel_MoreQueryHeads initialized with 2261124 trainable parameters.


Config 'Config4_MQA_MediumModel_MoreQueryHeads' Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 1 Training: Avg Model Loss: 2.9206


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 1 Validation: Avg Model Loss: 12.1083, Prefetcher+LRU Miss Ratio: 0.6914
Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 1 duration: 78.48 seconds


Epoch 2 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 2 Training: Avg Model Loss: 0.5414


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 2 Validation: Avg Model Loss: 14.3730, Prefetcher+LRU Miss Ratio: 0.6880
Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 2 duration: 68.99 seconds


Epoch 3 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 3 Training: Avg Model Loss: 0.4090


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 3 Validation: Avg Model Loss: 15.1056, Prefetcher+LRU Miss Ratio: 0.6869
Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 3 duration: 79.12 seconds


Epoch 4 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 4 Training: Avg Model Loss: 0.3627


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 4 Validation: Avg Model Loss: 15.4308, Prefetcher+LRU Miss Ratio: 0.6888
Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 4 duration: 80.03 seconds


Epoch 5 Training:   0%|          | 0/1497 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 5 Training: Avg Model Loss: 0.3375


Evaluating with Prefetcher+LRU:   0%|          | 0/373 [00:00<?, ?batch/s]

Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 5 Validation: Avg Model Loss: 14.9194, Prefetcher+LRU Miss Ratio: 0.7044
Config: Config4_MQA_MediumModel_MoreQueryHeads, Epoch 5 duration: 78.16 seconds

Training for Config4_MQA_MediumModel_MoreQueryHeads completed in 384.95 seconds.
Best Validation Prefetcher+LRU Miss Ratio for Config4_MQA_MediumModel_MoreQueryHeads: 0.6869 at Epoch 3




# 8. Summarize Results

In [22]:
for result in all_results:
    print(f"\nConfiguration: {result['config_name']}")
    print(f"  Hyperparameters: {result['hyperparameters']}")
    print(f"  Trainable Parameters: {result['num_parameters']}")
    if result['train_losses']: 
        print(f"  Final Training Model Loss (Epoch {len(result['train_losses'])}): {result['train_losses'][-1]:.4f}")
        if result['val_lru_miss_ratios_with_prefetcher'] and not math.isnan(result['val_lru_miss_ratios_with_prefetcher'][-1]):
             print(f"  Final Validation Model Loss (Epoch {len(result['val_model_losses'])}): {result['val_model_losses'][-1]:.4f}")
             print(f"  Final Validation Prefetcher+LRU Miss Ratio: {result['val_lru_miss_ratios_with_prefetcher'][-1]:.4f}")
        print(f"  Best Validation Prefetcher+LRU Miss Ratio: {result['best_val_lru_miss_ratio_with_prefetcher']:.4f} (Epoch {result['best_epoch']})")
        if not math.isnan(baseline_lru_miss_ratio) and not math.isnan(result['best_val_lru_miss_ratio_with_prefetcher']):
            improvement = baseline_lru_miss_ratio - result['best_val_lru_miss_ratio_with_prefetcher']
            improvement_percent = (improvement / baseline_lru_miss_ratio) * 100 if baseline_lru_miss_ratio > 0 else 0 # Avoid division by zero
            print(f"  Improvement over Baseline LRU: {improvement:.4f} ({improvement_percent:.2f}%)")
    else:
        print("  Training was not run for this configuration (e.g., empty dataloader).")


Configuration: Config1_MQA_Baseline
  Hyperparameters: {'name': 'Config1_MQA_Baseline', 'D_MODEL': 128, 'NUM_QUERY_HEADS': 4, 'NUM_LAYERS': 3, 'D_FF': 256, 'DROPOUT': 0.1, 'LEARNING_RATE': 0.001}
  Trainable Parameters: 2273508
  Final Training Model Loss (Epoch 5): 0.3312
  Final Validation Model Loss (Epoch 5): 15.1268
  Final Validation Prefetcher+LRU Miss Ratio: 0.6895
  Best Validation Prefetcher+LRU Miss Ratio: 0.6894 (Epoch 1)
  Improvement over Baseline LRU: 0.0149 (2.11%)

Configuration: Config2_MQA_LargerModel_LowerLR
  Hyperparameters: {'name': 'Config2_MQA_LargerModel_LowerLR', 'D_MODEL': 256, 'NUM_QUERY_HEADS': 8, 'NUM_LAYERS': 4, 'D_FF': 512, 'DROPOUT': 0.15, 'LEARNING_RATE': 0.0005}
  Trainable Parameters: 5541028
  Final Training Model Loss (Epoch 5): 0.3200
  Final Validation Model Loss (Epoch 5): 14.9909
  Final Validation Prefetcher+LRU Miss Ratio: 0.6883
  Best Validation Prefetcher+LRU Miss Ratio: 0.6844 (Epoch 2)
  Improvement over Baseline LRU: 0.0199 (2.82%)

C

In [23]:
best_overall_config_result = None
if all_results:
    valid_results = [r for r in all_results if r['best_val_lru_miss_ratio_with_prefetcher'] != float('inf') and \
                     r['best_val_lru_miss_ratio_with_prefetcher'] is not None and \
                     not math.isnan(r['best_val_lru_miss_ratio_with_prefetcher'])]
    if valid_results:
        best_overall_config_result = min(valid_results, key=lambda x: x['best_val_lru_miss_ratio_with_prefetcher'])

if best_overall_config_result:
    print(f"\n--- Overall Best Configuration (based on Prefetcher+LRU Miss Ratio) ---")
    print(f"Name: {best_overall_config_result['config_name']}")
    print(f"Best Validation Prefetcher+LRU Miss Ratio: {best_overall_config_result['best_val_lru_miss_ratio_with_prefetcher']:.4f} at Epoch {best_overall_config_result['best_epoch']}")
    print(f"Hyperparameters: {best_overall_config_result['hyperparameters']}")
    print(f"Trainable Parameters: {best_overall_config_result['num_parameters']}")
    if not math.isnan(baseline_lru_miss_ratio):
        improvement = baseline_lru_miss_ratio - best_overall_config_result['best_val_lru_miss_ratio_with_prefetcher']
        improvement_percent = (improvement / baseline_lru_miss_ratio) * 100 if baseline_lru_miss_ratio > 0 else 0
        print(f"  Improvement over Baseline LRU: {improvement:.4f} ({improvement_percent:.2f}%)")
else:
    print("\nCould not determine an overall best configuration (e.g., no valid validation results).")


--- Overall Best Configuration (based on Prefetcher+LRU Miss Ratio) ---
Name: Config2_MQA_LargerModel_LowerLR
Best Validation Prefetcher+LRU Miss Ratio: 0.6844 at Epoch 2
Hyperparameters: {'name': 'Config2_MQA_LargerModel_LowerLR', 'D_MODEL': 256, 'NUM_QUERY_HEADS': 8, 'NUM_LAYERS': 4, 'D_FF': 512, 'DROPOUT': 0.15, 'LEARNING_RATE': 0.0005}
Trainable Parameters: 5541028
  Improvement over Baseline LRU: 0.0199 (2.82%)
