# Import Libraries

In [None]:
# Standard Libraries
import os, sys, time, gc, random

# PyTorch Core Libraries
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision import datasets, models, transforms
from torchsummary import summary

# Additional Libraries
import numpy as np
import pandas as pd

# System Monitoring (GPU/CPU)
import psutil
from pynvml import *

# Bring NVML up once so the GPU statistics helpers can query the driver later
nvmlInit()
gc.collect()  # drop unreachable Python objects first ...
torch.cuda.empty_cache()  # ... then hand PyTorch's cached, unused GPU memory back

# ---------------------------- Reproducibility ----------------------------
# One seed value is fed to every random number generator used by the script.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    # Python's `random`, NumPy, and PyTorch (CPU), in that order
    _seed_fn(seed)

# Seed the CUDA generators as well when a GPU is present
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)      # current device
    torch.cuda.manual_seed_all(seed)  # every device (multi-GPU)

# Ask cuDNN for deterministic kernels and switch its autotuner off, since the
# autotuner may pick different algorithms from run to run
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#### Redirecting Print Output to a Log File

In [None]:
# Main run parameters. MODEL_NAME and DATASET_NAME together form the output
# folder ("<DATASET_NAME>/<MODEL_NAME>") and appear in the training banner.
MODEL_NAME = "(d)784-256-128-64-32-10" # Target model label (used for paths/logs only; it does not configure the network)
DATASET_NAME = "EMNIST" # Target dataset; also the root folder for downloaded data and results
# NOTE(review): NUM_CLASSES is not read anywhere in the visible code, and the
# trailing "-10" in MODEL_NAME disagrees with it -- confirm which one is intended.
NUM_CLASSES = 26 # Number of target classes
split_ratio = 0.85 # Fraction of the training set kept for training; the remainder becomes the validation set
# One full train/evaluate run is executed per entry. Each ALPHA is the number of
# fc1 neuron pairs that Callback ties together, and it prefixes every output file.
ALPHAS = [0]

In [None]:
for ALPHA in ALPHAS:
    # nvmlShutdown() is called near the end of every iteration, while the
    # module-level nvmlInit() runs only once. NVML reference-counts its
    # init/shutdown calls, so re-initialising here keeps the GPU monitoring
    # helpers usable for the second and later ALPHA values.
    nvmlInit()

    # All artefacts of this run (log, CSV, checkpoints) go under <dataset>/<model>
    FOLDER_NAME = f"{DATASET_NAME}/{MODEL_NAME}"
    os.makedirs(FOLDER_NAME, exist_ok=True)

    # Redirect every subsequent print() into a per-ALPHA log file. The file is
    # written as UTF-8 because the monitoring output contains non-ASCII text
    # ("°C"), which a platform-default encoding may not be able to encode.
    # NOTE(review): stdout is only restored at the very end of the loop body,
    # so an exception in between leaves it pointing at this file.
    log_file_path = os.path.join(FOLDER_NAME, f"a{ALPHA}-logfile.txt")
    log_file = open(log_file_path, 'w', encoding='utf-8')
    sys.stdout = log_file

    # Pick the compute device and record what hardware is visible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'{device} is available...')
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")
    if num_gpus > 0:
        print(f"GPU Name: {torch.cuda.get_device_name(0)}\n")

    # Touch the GPU once up front so CUDA start-up happens before any timing
    if device.type == "cuda":
        torch.cuda.init()
        torch.cuda.current_device()
        torch.zeros(1).to(device)
    
    # ---- Data loading / optimisation settings ----
    batch_size, num_workers = 128, 16
    learning_rate, num_epochs = 1e-3, 300
    criterion = nn.CrossEntropyLoss()

    # ---- Early-stopping state ----
    patience = 15                      # epochs without improvement tolerated
    best_val_loss = float('inf')       # lowest validation loss seen so far
    epochs_without_improvement = 0

    # Normalisation statistics shared by the training and evaluation pipelines
    pixel_mean = 0.17222732305526733
    pixel_std = 0.3309466242790222

    # Training: random affine augmentation, then tensor conversion + normalisation
    train_transform = transforms.Compose([
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[pixel_mean], std=[pixel_std]),
    ])

    # Validation / test: no augmentation, same tensor conversion + normalisation
    val_test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[pixel_mean], std=[pixel_std]),
    ])
    
    class TransformedDataset(Dataset):
        """Dataset view that applies ``transform`` to each image when it is read.

        The wrapped ``dataset`` must be indexable and return ``(image, label)``
        pairs. This lets two subsets of the same untransformed dataset use
        different transforms.
        """

        def __init__(self, dataset, transform):
            self.dataset, self.transform = dataset, transform

        def __len__(self):
            # Same length as the wrapped dataset
            return len(self.dataset)

        def __getitem__(self, index):
            # Only the image is transformed; the label is passed through untouched
            image, target = self.dataset[index]
            return self.transform(image), target
            
    # The training split is loaded WITHOUT a transform so the train and
    # validation subsets can each get their own pipeline below; the test split
    # takes the evaluation transform directly.
    data_root = f"./{DATASET_NAME}/data"
    train_set = datasets.EMNIST(root=data_root, split='letters', train=True, download=True, transform=None)
    test_set = datasets.EMNIST(root=data_root, split='letters', train=False, download=True, transform=val_test_transform)

    train_size = int(split_ratio * len(train_set))
    val_size = len(train_set) - train_size

    # Use a dedicated, freshly seeded generator for the split. Drawing from the
    # global RNG would give every ALPHA (and every re-run of this cell) a
    # different train/validation partition, because earlier work has already
    # advanced the global state.
    split_generator = torch.Generator().manual_seed(seed)
    raw_train, raw_val = random_split(train_set, [train_size, val_size], generator=split_generator)

    # Augmentation only for the training subset; plain normalisation for validation
    train_dataset = TransformedDataset(raw_train, train_transform)
    val_dataset = TransformedDataset(raw_val, val_test_transform)
    
    # Options common to all three loaders; only the dataset and shuffling differ
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)   # reshuffled every epoch
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
    test_loader = DataLoader(test_set, shuffle=False, **loader_options)

    # Report how many samples each loader serves
    print("\n" + "-"*60)
    for loader_name, loader in (("train_loader", train_loader),
                                ("val_loader", val_loader),
                                ("test_loader", test_loader)):
        print(f"Number of imgaes in {loader_name}: {len(loader.dataset)}")
    print("-"*60)
    
    class FCNN(nn.Module):
        """Fully connected classifier: input_dim -> 256 -> 128 -> 64 -> 32 -> output_dim.

        Inputs are flattened per sample, passed through four ReLU-activated
        linear layers (``fc1`` .. ``fc4``) and a final linear layer (``fc5``)
        that returns raw, un-normalised scores.

        Args:
            input_dim: Number of features per sample after flattening.
            output_dim: Number of output scores per sample.
        """

        def __init__(self, input_dim=784, output_dim=10):
            super().__init__()
            self.flatten = nn.Flatten()
            # Layers are registered as fc1..fc5, in order, so other code can
            # address them by name.
            widths = (input_dim, 256, 128, 64, 32, output_dim)
            for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
                setattr(self, f"fc{position}", nn.Linear(fan_in, fan_out))
            self.relu = nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            hidden = self.flatten(x)
            for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
                hidden = self.relu(layer(hidden))
            # No activation on the output layer
            return self.fc5(hidden)
    
    # Size the output layer from NUM_CLASSES. FCNN's default of 10 outputs is
    # too narrow here: the train/eval loops shift labels down by one, so with
    # NUM_CLASSES = 26 the targets presumably span 0..25, and the loss needs
    # one output score per target class.
    model = FCNN(output_dim=NUM_CLASSES)
    model.to(device)

    print("\n" + "="*60)
    print("           Model & Data Pipeline Summary")
    print("="*60)
    print(f"Model Architecture:\n{model}")

    # Layer-by-layer summary for a single-channel 28x28 input
    print("\n\nModel Summary:")
    summary(model, (1, 28, 28))
    
    def train_loop(model, train_loader, optimizer, criterion, device, label_offset=1):
        """Train ``model`` for one pass over ``train_loader``.

        Progress (loss and accuracy of the current batch) is printed roughly
        five times per epoch.

        Args:
            model: Network to optimise; switched to training mode here.
            train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
            optimizer: Optimizer stepped once per batch.
            criterion: Loss callable invoked as ``criterion(outputs, labels)``.
                Its value is weighted by the batch size, so a batch-mean loss
                is assumed.
            device: Device every batch is moved to.
            label_offset: Value subtracted from the labels before the loss is
                computed. The default of 1 keeps the previous hard-coded shift
                (presumably because the dataset's labels start at 1); pass 0
                for labels that already start at 0.

        Returns:
            tuple: ``(train_loss, train_accuracy)`` -- the per-sample average
            loss and the top-1 accuracy in percent over the whole epoch.

        Raises:
            ValueError: If ``train_loader`` yields no samples.
        """
        model.train()
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        num_batches = len(train_loader)
        batch_print_interval = max(1, num_batches // 5)

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            # The shifted labels already live on `device`; no second transfer needed
            labels = labels.to(device, non_blocking=True) - label_offset

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalars back once per batch and reuse them below
            samples_in_batch = inputs.size(0)
            batch_loss = loss.item()
            batch_correct = (outputs.argmax(dim=1) == labels).sum().item()

            running_loss += batch_loss * samples_in_batch
            correct_predictions += batch_correct
            total_samples += samples_in_batch

            if batch_idx % batch_print_interval == 0:
                batch_accuracy = batch_correct / samples_in_batch * 100.0
                print(f"Batch {batch_idx+1}/{num_batches} - "
                      f"Batch Loss: {batch_loss:.6f}, Batch Accuracy: {batch_accuracy:.4f}%")

        if total_samples == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError
            raise ValueError("train_loader yielded no samples")

        train_loss = running_loss / total_samples
        train_accuracy = correct_predictions / total_samples * 100.0

        return train_loss, train_accuracy
    
    def val_test_loop(model, data_loader, criterion, device, compute_top5=False, label_offset=1):
        """Evaluate ``model`` on ``data_loader`` without tracking gradients.

        Args:
            model: Network to evaluate; switched to evaluation mode here.
            data_loader: Iterable yielding ``(inputs, labels)`` batches.
            criterion: Loss callable invoked as ``criterion(outputs, labels)``.
                Its value is weighted by the batch size, so a batch-mean loss
                is assumed.
            device: Device every batch is moved to.
            compute_top5: When true, also report top-5 accuracy.
            label_offset: Value subtracted from the labels before comparison.
                The default of 1 keeps the previous hard-coded shift; pass 0
                for labels that already start at 0.

        Returns:
            dict: ``'loss'`` (per-sample average) and ``'top1_accuracy'``
            (percent), plus ``'top5_accuracy'`` (percent) when
            ``compute_top5`` is true.

        Raises:
            ValueError: If ``data_loader`` yields no samples.
        """
        model.eval()
        running_loss = 0.0
        correct_predictions_top1 = 0
        correct_predictions_top5 = 0
        total_samples = 0

        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs = inputs.to(device, non_blocking=True)
                # The shifted labels already live on `device`; no second transfer needed
                labels = labels.to(device, non_blocking=True) - label_offset

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                samples_in_batch = inputs.size(0)
                running_loss += loss.item() * samples_in_batch
                predictions = outputs.argmax(dim=1)
                correct_predictions_top1 += (predictions == labels).sum().item()
                total_samples += samples_in_batch

                if compute_top5:
                    # topk raises when k exceeds the number of classes, so cap it
                    k = min(5, outputs.size(1))
                    _, top5_predictions = outputs.topk(k, dim=1)
                    # A sample counts as a hit when its label is among the top-k indices
                    correct_predictions_top5 += (top5_predictions == labels.view(-1, 1)).any(dim=1).sum().item()

        if total_samples == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError
            raise ValueError("data_loader yielded no samples")

        val_test_results = {
            'top1_accuracy': correct_predictions_top1 / total_samples * 100.0,
            'loss': running_loss / total_samples,
        }

        if compute_top5:
            val_test_results['top5_accuracy'] = correct_predictions_top5 / total_samples * 100.0

        return val_test_results
    
    def save_epoch_results(epoch, current_lr, running_train_loss, running_train_accuracy, val_loss, val_accuracy, KE_val_loss, KE_val_accuracy,
                            epoch_time, KE_time, KE_with_V_time, FOLDER_NAME):
        """Append one row of per-epoch metrics to ``a<ALPHA>-epoch_results.csv``.

        The file lives in ``FOLDER_NAME``. The column header is written only
        when the file does not exist yet; otherwise the row is appended to
        whatever is already there. ``epoch`` is zero-based and stored as
        ``epoch + 1``. ``ALPHA`` is read from the enclosing scope.
        """
        record = {
            'Epoch': epoch + 1,
            'LR': current_lr,
            'Running Train Loss': running_train_loss,
            'Running Train Accuracy': running_train_accuracy,
            'Validation Loss': val_loss,
            'Validation Accuracy': val_accuracy,
            'KE Validation Loss': KE_val_loss,
            'KE Validation Accuracy': KE_val_accuracy,
            'Epoch Time': epoch_time,
            'KE Time': KE_time,
            'KE + Forward Pass Time': KE_with_V_time,
        }
        single_row = pd.DataFrame([record])

        csv_path = os.path.join(FOLDER_NAME, f'a{ALPHA}-epoch_results.csv')
        is_new_file = not os.path.exists(csv_path)
        single_row.to_csv(csv_path, mode='a', header=is_new_file, index=False)
        print(f"Epoch {epoch+1} results saved...")
    
    def save_checkpoint(model, optimizer, epoch, FOLDER_NAME):
        """Write a resumable checkpoint for ``epoch`` into ``FOLDER_NAME``.

        The file is named ``a<ALPHA>-checkpoint_epoch_<epoch+1>.pth`` and holds
        the zero-based ``epoch`` together with the model and optimizer state
        dicts. ``ALPHA`` is read from the enclosing scope.
        """
        target_path = os.path.join(FOLDER_NAME, f"a{ALPHA}-checkpoint_epoch_{epoch+1}.pth")
        payload = dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
        )
        torch.save(payload, target_path)
        print(f"Checkpoint saved at epoch {epoch+1} to {target_path}...")
    
    def summarize_gpu_info():
        """Print a resource snapshot: per-GPU NVML statistics, then host CPU and RAM usage.

        NVML must already be initialised when this is called.
        """
        bytes_per_mb = 1024 ** 2
        # NOTE(review): the torch device index is used directly as the NVML
        # index -- confirm both enumerate the GPUs in the same order here.
        for gpu_index in range(torch.cuda.device_count()):
            handle = nvmlDeviceGetHandleByIndex(gpu_index)
            memory = nvmlDeviceGetMemoryInfo(handle)
            utilization = nvmlDeviceGetUtilizationRates(handle)
            temperature = nvmlDeviceGetTemperature(handle, NVML_TEMPERATURE_GPU)

            used_mb = memory.used / bytes_per_mb
            total_mb = memory.total / bytes_per_mb
            print(f"GPU {gpu_index}:")
            print(f"  Memory Usage: {used_mb} MB (Used) / {total_mb} MB (Total)")
            print(f"  GPU Utilization: {utilization.gpu} %")
            print(f"  GPU Temperature: {temperature} °C")

        # Host-side figures are printed even when no GPU is present
        print(f"CPU Usage: {psutil.cpu_percent()}%")
        print(f"Memory Usage: {psutil.virtual_memory().percent}%")
        print("-" * 60)
    
    # With more than one CUDA device visible, wrap the model in nn.DataParallel.
    # Code that needs the raw layers (e.g. Callback) unwraps it via `.module`.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    
    def Callback(model):
        """Tie the first ``ALPHA`` pairs of fc1 neurons together, in place.

        For every pair ``(2k, 2k + 1)`` with ``k < ALPHA``, the weight row and
        bias of neuron ``2k + 1`` are overwritten with those of neuron ``2k``.
        A ``nn.DataParallel`` wrapper is unwrapped first. ``ALPHA`` is read
        from the enclosing scope; with ``ALPHA == 0`` nothing is modified.
        """
        core = model.module if isinstance(model, nn.DataParallel) else model
        first_layer = core.fc1
        with torch.no_grad():  # plain parameter surgery, nothing to record for autograd
            for pair in range(ALPHA):
                source, target = 2 * pair, 2 * pair + 1
                first_layer.weight[target].copy_(first_layer.weight[source])
                first_layer.bias[target].copy_(first_layer.bias[source])
    
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    # AdamW with decoupled weight decay. The step size comes from the
    # `learning_rate` hyper-parameter declared above rather than a second
    # hard-coded literal, so changing that one setting actually takes effect.
    optimizer = AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=1e-4,
        betas=(0.9, 0.999)
    )

    # Halve the learning rate once the monitored value (the validation loss
    # passed to scheduler.step) has not improved for 5 consecutive epochs.
    # NOTE(review): `verbose` is deprecated in newer PyTorch releases --
    # confirm it is still accepted by the installed version.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
        verbose=True
    )
    
    # Banner marking the start of training in the log
    print("\n" + "#" * 60)
    print(f"#         TRAINING STARTED: {MODEL_NAME} on {DATASET_NAME} Dataset        #")
    print("#" * 60 + "\n")
    
    # Wall-clock reference for the total training time reported at the end
    start_time = time.time()
    
    # Main loop: train, validate, apply Callback, validate again, log, then
    # handle checkpointing, best-model tracking, early stopping and LR scheduling.
    for epoch in range(num_epochs):
        epoch_start = time.time()
    
        print("\n" + "="*60)
        print(f"........... Epoch {epoch+1} Start - GPU & CPU monitoring ...........")
        # Free Python garbage and cached GPU memory before taking the snapshot
        gc.collect()
        torch.cuda.empty_cache()
        summarize_gpu_info()
    
        print("\n" + "="*60)
        print(f"...................... Epoch {epoch+1} Start .......................")
        print("="*60)
    
        # One training pass followed by validation of the just-trained weights
        running_train_loss, running_train_accuracy = train_loop(model, train_loader, optimizer, criterion, device)
        val_metrics = val_test_loop(model, val_loader, criterion, device)
        val_loss = val_metrics['loss']
        val_accuracy = val_metrics['top1_accuracy']
        # epoch_time also includes the monitoring above, but stops before Callback
        epoch_time = time.time() - epoch_start 
        
        # KE_time measures only the in-place fc1 weight tying done by Callback
        KE_start = time.time()
        Callback(model)
        KE_time = time.time() - KE_start
        
        # Re-validate with the modified weights; KE_with_V_time spans Callback
        # plus this second validation pass
        KE_val_metrics = val_test_loop(model, val_loader, criterion, device)
        KE_val_loss = KE_val_metrics['loss']
        KE_val_accuracy = KE_val_metrics['top1_accuracy']
        KE_with_V_time = time.time() - KE_start
        
        print("\n" + "="*60)
        print(f"...................... Epoch {epoch+1} Results .....................")
        print('--------------- Before Callback -----------------')
        print(f"Epoch {epoch+1}/{num_epochs} Summary:")
        print("-" * 60)
        print(f"Running Train Loss: {running_train_loss:.6f}, Running Train Accuracy: {running_train_accuracy:.4f}%")
        print(f"Validation Loss: {val_loss:.6f}, Validation Accuracy: {val_accuracy:.4f}%")
    
        print('\n---------------- After Callback -----------------')
        print(f"Epoch {epoch+1}/{num_epochs} Summary:")
        print("-" * 60)
        print(f"Validation Loss: {KE_val_loss:.6f}, Validation Accuracy: {KE_val_accuracy:.4f}%")
        
        print(f"\n~~~*~~~ Time Consumptions ~~~*~~~")
        print(f"Epoch {epoch+1} completed in {epoch_time/60:.4f} minutes.")
        print(f"Time Consumption of Callback: {KE_time:.6f} seconds.")
        print(f"Time Consumption of Callback + Forward Pass Validation dataset: {KE_with_V_time/60:.4f} minutes.")
        print(f"Time Elapsed Since Epoch {epoch + 1} Started: {(time.time() - epoch_start)/60:.4f} minutes.")
        print(f"Total Time Elapsed Since Training Started: {(time.time() - start_time)/60:.4f} minutes.\n")  
            
        # Learning rate of the first parameter group, as used during this epoch
        current_lr = optimizer.param_groups[0]['lr']
    
        # Append this epoch's metrics and timings to the per-ALPHA CSV
        save_epoch_results(epoch, current_lr, running_train_loss, running_train_accuracy, val_loss, val_accuracy, KE_val_loss, KE_val_accuracy,
                            epoch_time, KE_time, KE_with_V_time, FOLDER_NAME)
    
        print(f"\n.........Epoch {epoch+1} End - GPU & CPU monitoring.........")
        summarize_gpu_info()
    
        # Periodic resumable checkpoint every 50 epochs
        if (epoch + 1) % 50 == 0:
            save_checkpoint(model, optimizer, epoch, FOLDER_NAME)
    
        print("="*60)
        print("\n")
    
        # Track the best validation loss and keep the matching weights on disk.
        # NOTE(review): the comparison uses val_loss measured BEFORE Callback,
        # while the state_dict saved here is taken AFTER Callback modified fc1
        # -- confirm this is intended (KE_val_loss describes the saved weights).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            model_weights_path = os.path.join(FOLDER_NAME, f"a{ALPHA}-best_model.pth")
            torch.save(model.state_dict(), model_weights_path)
        else:
            epochs_without_improvement += 1
            
        # Early stopping: leave the loop after `patience` epochs without
        # improvement. The break comes before scheduler.step, so the scheduler
        # is not stepped for the stopping epoch.
        if epochs_without_improvement >= patience:
            print("\n================ Early Stopping Triggered! =================\n")
            break
    
        # Let the plateau scheduler see the pre-Callback validation loss
        scheduler.step(val_loss)
        
        torch.cuda.empty_cache()
        gc.collect()
    
    end_time = time.time()
    total_time = end_time - start_time
    
    # Report the total training time in several units
    print(f"\nTraining complete! Total time : {total_time:.4f} seconds.")
    print(f"                              : {total_time / 60:.4f} minutes.")
    print(f"                              : {total_time / 3600:.4f} hours.")
    print(f"                              : {total_time / 86400:.4f} days.")
    print("=" * 60)
    print("\n*~*~*~*~*~*~*~*~*~*~*~*~*~* THE END *~*~*~*~*~*~*~*~*~*~*~*~*~*\n")
    print("=" * 60)
    
    # GPU monitoring is no longer needed for this run
    nvmlShutdown()

    # ---- Reload the best weights saved during training ----
    weights_path = os.path.join(FOLDER_NAME, f"a{ALPHA}-best_model.pth")
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)

    # Weights saved from an nn.DataParallel model carry a leading "module."
    # on every key. Strip it only where it is a prefix, so a key that merely
    # contains "module." elsewhere is left intact.
    dp_prefix = "module."
    state_dict = {
        (key[len(dp_prefix):] if key.startswith(dp_prefix) else key): value
        for key, value in state_dict.items()
    }

    # Size the output layer from the checkpoint itself, so the rebuilt network
    # always matches the trained one instead of relying on FCNN's default.
    Best_model = FCNN(output_dim=state_dict["fc5.weight"].shape[0])
    Best_model.load_state_dict(state_dict)

    if torch.cuda.device_count() > 1:
        Best_model = nn.DataParallel(Best_model)
    Best_model = Best_model.to(device)
    # Re-apply the fc1 weight tying before evaluating
    Callback(Best_model)
    Best_model.eval()
    test_metrics  = val_test_loop(Best_model, test_loader, criterion, device, compute_top5=True)

    print('\n')
    print("\n" + "="*60)
    print("######################## CoupledNet ########################")
    # These metrics come from test_loader, so the banner says "Test"
    print("---------------- Inference on Test Dataset -----------------")
    print(f"Top-1 Accuracy: {test_metrics['top1_accuracy']:.4f}%, Top-5 Accuracy: {test_metrics['top5_accuracy']:.4f}%, Test Loss: {test_metrics['loss']:.6f}.")
    print("="*60)
    
    import copy

    # RE_model starts as an exact copy of the best model and is edited by the
    # reverse-engineering step; Best_model itself stays untouched as reference.
    RE_model = copy.deepcopy(Best_model)

    # Work on the raw modules from here on (drop any nn.DataParallel wrapper)
    Best_model = Best_model.module if isinstance(Best_model, nn.DataParallel) else Best_model
    RE_model = RE_model.module if isinstance(RE_model, nn.DataParallel) else RE_model

    # Freshly initialised network that supplies replacement weights. Its output
    # layer is sized from the best model so per-layer weight shapes line up;
    # FCNN's default head would only match a 10-output model. It is created
    # unwrapped, so it needs no DataParallel unwrapping.
    Random_model = FCNN(output_dim=Best_model.fc5.out_features).to(device)

    RE_model.eval()
    Random_model.eval()
    
    def RE_calc(layer_name: str, affected_layer_name: str, ALPHA: int, Best_model: nn.Module, RE_model: nn.Module, Random_model: nn.Module):
        """Overwrite selected weights of one RE_model layer with Random_model's weights.

        A weight is "doubly negative" when it is negative and belongs to an
        output neuron whose bias is negative too, both judged on ``Best_model``.

        * Every layer: each output row made up entirely of doubly-negative
          weights is replaced by the matching row of ``Random_model``.
        * Additionally, when ``layer_name == affected_layer_name``: the first
          ``2 * ALPHA`` input columns are replaced for every row, and the
          remaining columns are replaced for rows whose remaining columns are
          all doubly negative.

        Only ``RE_model`` is modified; biases are never touched.
        """
        trained_layer = getattr(Best_model, layer_name)
        target_layer = getattr(RE_model, layer_name)
        donor_layer = getattr(Random_model, layer_name)

        weights = trained_layer.weight.data.clone()
        biases = trained_layer.bias.data.clone()

        # (out_features, in_features) mask: negative weight on a negative-bias neuron
        doubly_negative = (biases < 0).unsqueeze(1).expand_as(weights) & (weights < 0)

        if layer_name == affected_layer_name:
            split = 2 * ALPHA
            tail_all_negative = doubly_negative[:, split:].all(dim=1)
            affected_mask = torch.zeros_like(doubly_negative)
            affected_mask[:, :split] = True                               # tied-input columns: always replaced
            affected_mask[:, split:] = tail_all_negative.unsqueeze(1)     # rest: only for all-negative rows
            target_layer.weight.data[affected_mask] = donor_layer.weight.data[affected_mask]

        # Rows that are doubly negative across ALL inputs are replaced completely
        row_all_negative = doubly_negative.all(dim=1)
        row_mask = row_all_negative.unsqueeze(1).expand_as(doubly_negative).clone()
        target_layer.weight.data[row_mask] = donor_layer.weight.data[row_mask]
    
    print('\n')
    print("################## Reverse Engineering #####################")
    # Apply RE_calc to every linear layer except fc1; fc2 is the layer fed by
    # the fc1 neurons that Callback ties together, so it gets the extra
    # column-wise treatment.
    with torch.no_grad():
        for name, layer in Best_model.named_children():
            if not isinstance(layer, nn.Linear) or name == "fc1":
                continue
            print(f"Processing layer: {name}")
            RE_calc(layer_name=name, affected_layer_name="fc2", ALPHA=ALPHA, Best_model=Best_model, RE_model=RE_model, Random_model=Random_model)
            print("Done..")

    # Epoch (1-based, like every other log line) at which the best model was
    # saved: epochs actually run minus the trailing epochs without improvement.
    # Subtracting `patience` was only meaningful when early stopping fired,
    # and even then came out one lower than the 1-based epoch number.
    print(f"Number of Epochs: {epoch + 1 - epochs_without_improvement}")
    
    # NOTE(review): these DataParallel wrappers are built but not used by the
    # evaluations below, which run on the raw modules -- confirm whether the
    # wrapped versions were meant to be evaluated instead.
    RE_model_parallel = nn.DataParallel(RE_model).to(device)
    Random_model_parallel = nn.DataParallel(Random_model).to(device)

    # Evaluate the reverse-engineered model and the untrained reference on the test set
    RE_test_metrics  = val_test_loop(RE_model, test_loader, criterion, device, compute_top5=True)
    Random_test_metrics  = val_test_loop(Random_model, test_loader, criterion, device, compute_top5=True)

    # These metrics come from test_loader, so the banner says "Test"
    print("\n---------------- Inference on Test Dataset -----------------")
    print(f"Top-1 Accuracy: {RE_test_metrics['top1_accuracy']:.4f}%, Top-5 Accuracy: {RE_test_metrics['top5_accuracy']:.4f}%, Test Loss: {RE_test_metrics['loss']:.6f}.")
    print("="*60)

    print("\n------------- Random Model - testing purpose --------------")
    print(f"Top-1 Accuracy: {Random_test_metrics['top1_accuracy']:.4f}%, Top-5 Accuracy: {Random_test_metrics['top5_accuracy']:.4f}%, Test Loss: {Random_test_metrics['loss']:.6f}.")
    print("="*60)

    # Finish this run's log and send print() output back to the interpreter's
    # original stdout.
    # NOTE(review): inside a notebook, sys.__stdout__ may not be the stream the
    # cell output was using before the redirect -- confirm this restores
    # printing where expected.
    log_file.close()
    sys.stdout = sys.__stdout__