In [2]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("hojjatk/mnist-dataset")

print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Downloading from https://www.kaggle.com/api/v1/datasets/download/hojjatk/mnist-dataset?dataset_version_number=1...


100%|██████████| 22.0M/22.0M [00:15<00:00, 1.46MB/s]

Extracting files...





Path to dataset files: C:\Users\Hokta\.cache\kagglehub\datasets\hojjatk\mnist-dataset\versions\1


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import time
import os
import numpy as np
import matplotlib.pyplot as plt

# 1. Định nghĩa mô hình CNN đơn giản
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 64 * 7 * 7)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

# 2. Chuẩn bị dữ liệu
def load_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
    
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
    
    # Tập dữ liệu nhỏ cho calibration
    calib_sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=100)
    calib_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=32, num_workers=0, sampler=calib_sampler
    )
    
    return train_loader, test_loader, calib_loader

# 3. Huấn luyện mô hình
def train_model(model, train_loader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
            if i % 100 == 99:
                print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0
    
    print('Huấn luyện hoàn tất!')
    return model

# 4. Đánh giá mô hình
def evaluate_model(model, test_loader, name="FP32"):
    model.eval()
    correct = 0
    total = 0
    
    start_time = time.time()
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    inference_time = time.time() - start_time
    accuracy = 100 * correct / total
    
    print(f'Model {name}:')
    print(f'Accuracy: {accuracy:.2f}%')
    print(f'Inference time: {inference_time:.4f} seconds')
    
    return accuracy, inference_time

# 5. Quantization của mô hình
def prepare_for_quantization(model):
    # Thêm các quan sát giá trị (observers) cho quantization
    model_fused = torch.quantization.fuse_modules(model, [['conv1', 'relu1'], 
                                                         ['conv2', 'relu2'], 
                                                         ['fc1', 'relu3']])
    return model_fused

def quantize_model(model, calib_loader):
    # Chuẩn bị mô hình
    model_fused = prepare_for_quantization(model)
    
    # Cấu hình quantization
    model_fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model_fused, inplace=True)
    
    # Calibration
    model_fused.eval()
    with torch.no_grad():
        for inputs, _ in calib_loader:
            model_fused(inputs)
    
    # Chuyển đổi sang mô hình quantized
    model_quantized = torch.quantization.convert(model_fused, inplace=True)
    
    return model_quantized

# 6. So sánh kích thước mô hình
def compare_model_sizes(fp32_model, int8_model):
    import os
    def get_model_size(model):
        torch.save(model.state_dict(), "temp.p")
        size = os.path.getsize("temp.p") / 1e6  # MB
        os.remove('temp.p')
        return size
    
    fp32_size = get_model_size(fp32_model)
    int8_size = get_model_size(int8_model)
    
    print(f"FP32 Model Size: {fp32_size:.2f} MB")
    print(f"INT8 Model Size: {int8_size:.2f} MB")
    print(f"Compression ratio: {fp32_size/int8_size:.2f}x")
    
    return fp32_size, int8_size

# 7. Hiển thị kết quả
def plot_comparison(fp32_metrics, int8_metrics):
    labels = ['FP32', 'INT8']
    accuracy = [fp32_metrics[0], int8_metrics[0]]
    inference_time = [fp32_metrics[1], int8_metrics[1]]
    
    x = np.arange(len(labels))
    width = 0.35
    
    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax2 = ax1.twinx()
    
    bars1 = ax1.bar(x - width/2, accuracy, width, label='Accuracy (%)', color='blue')
    ax1.set_ylabel('Accuracy (%)', color='blue')
    ax1.set_ylim([0, 100])
    
    bars2 = ax2.bar(x + width/2, inference_time, width, label='Inference Time (s)', color='red')
    ax2.set_ylabel('Inference Time (s)', color='red')
    
    ax1.set_xticks(x)
    ax1.set_xticklabels(labels)
    ax1.legend(loc='upper left')
    ax2.legend(loc='upper right')
    
    plt.title('So sánh mô hình FP32 và INT8 Quantized')
    plt.tight_layout()
    plt.savefig('quantization_comparison.png')
    plt.show()

# 8. Hàm chính
def main():
    # Tải dữ liệu
    train_loader, test_loader, calib_loader = load_data()
    
    # Tạo và huấn luyện mô hình
    model = SimpleCNN()
    if not os.path.exists('mnist_fp32.pth'):
        model = train_model(model, train_loader)
        torch.save(model.state_dict(), 'mnist_fp32.pth')
    else:
        model.load_state_dict(torch.load('mnist_fp32.pth'))
    
    # Đánh giá mô hình FP32
    fp32_metrics = evaluate_model(model, test_loader, "FP32")
    
    # Thực hiện quantization
    quantized_model = quantize_model(model, calib_loader)
    
    # Đánh giá mô hình quantized
    int8_metrics = evaluate_model(quantized_model, test_loader, "INT8")
    
    # So sánh kích thước mô hình
    compare_model_sizes(model, quantized_model)
    
    # Hiển thị kết quả so sánh
    plot_comparison(fp32_metrics, int8_metrics)
    
    # Lưu mô hình quantized
    torch.save(quantized_model.state_dict(), 'mnist_int8.pth')
    
    print("Pipeline quantization hoàn tất!")

if __name__ == "__main__":
    main()

[1, 100] loss: 0.572
[1, 200] loss: 0.155
[1, 300] loss: 0.107
[1, 400] loss: 0.099
[1, 500] loss: 0.084
[1, 600] loss: 0.078
[1, 700] loss: 0.060
[1, 800] loss: 0.065
[1, 900] loss: 0.063
[2, 100] loss: 0.040
[2, 200] loss: 0.047
[2, 300] loss: 0.040
[2, 400] loss: 0.044
[2, 500] loss: 0.042
[2, 600] loss: 0.047
[2, 700] loss: 0.038
[2, 800] loss: 0.044
[2, 900] loss: 0.041
[3, 100] loss: 0.027
[3, 200] loss: 0.033
[3, 300] loss: 0.031
[3, 400] loss: 0.028
[3, 500] loss: 0.028
[3, 600] loss: 0.027
[3, 700] loss: 0.030
[3, 800] loss: 0.029
[3, 900] loss: 0.037
[4, 100] loss: 0.015
[4, 200] loss: 0.023
[4, 300] loss: 0.025
[4, 400] loss: 0.020
[4, 500] loss: 0.027
[4, 600] loss: 0.018
[4, 700] loss: 0.023
[4, 800] loss: 0.028
[4, 900] loss: 0.027
[5, 100] loss: 0.010
[5, 200] loss: 0.013
[5, 300] loss: 0.014
[5, 400] loss: 0.017
[5, 500] loss: 0.014
[5, 600] loss: 0.012
[5, 700] loss: 0.021
[5, 800] loss: 0.023
[5, 900] loss: 0.020
Huấn luyện hoàn tất!
Model FP32:
Accuracy: 98.92%
Infer



NotImplementedError: Could not run 'quantized::conv2d_relu.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d_relu.new' is only available for these backends: [Meta, QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMeta, Tracer, AutocastCPU, AutocastMTIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\quantized\cpu\qconv.cpp:2044 [kernel]
QuantizedCUDA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\quantized\cudnn\Conv.cpp:386 [kernel]
BackendSelect: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\DynamicLayer.cpp:479 [backend fallback]
Functionalize: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\FunctionalizeFallbackKernel.cpp:349 [backend fallback]
Named: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:100 [backend fallback]
AutogradOther: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:63 [backend fallback]
AutogradCPU: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:67 [backend fallback]
AutogradCUDA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:75 [backend fallback]
AutogradXLA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:83 [backend fallback]
AutogradMPS: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:91 [backend fallback]
AutogradXPU: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:71 [backend fallback]
AutogradHPU: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:104 [backend fallback]
AutogradLazy: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:87 [backend fallback]
AutogradMTIA: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:79 [backend fallback]
AutogradMeta: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:95 [backend fallback]
Tracer: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\TraceTypeManual.cpp:294 [backend fallback]
AutocastCPU: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:322 [backend fallback]
AutocastMTIA: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:466 [backend fallback]
AutocastXPU: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:504 [backend fallback]
AutocastMPS: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\TensorWrapper.cpp:208 [backend fallback]
PythonTLSSnapshot: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\functorch\DynamicLayer.cpp:475 [backend fallback]
PreDispatch: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:198 [backend fallback]


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import time
import os
import numpy as np
import matplotlib.pyplot as plt


# 1. Define models with clear separation between quantizable and regular versions
class QuantizableCNN(nn.Module):
    def __init__(self):
        super(QuantizableCNN, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.conv1(x))
        x = self.pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        x = self.dequant(x)
        return x
    
    def fuse_model(self):
        # Safer fusion approach
        for module_name, module in self.named_children():
            if module_name == 'quant' or module_name == 'dequant':
                continue
            if type(module) == nn.Sequential:
                torch.quantization.fuse_modules(module, ['0', '1'], inplace=True)
        
        # Fuse conv-bn-relu
        torch.quantization.fuse_modules(
            self, 
            [['conv1', 'relu1'], 
             ['conv2', 'relu2'], 
             ['fc1', 'relu3']], 
            inplace=True
        )


# Regular model for training
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x


# 2. Improved data loading with robust error handling
def load_data(batch_size=64):
    print("Loading MNIST dataset...")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Create data directory if it doesn't exist
    os.makedirs('./data', exist_ok=True)
    
    try:
        # Load dataset with proper error handling
        train_full = datasets.MNIST('./data', train=True, download=True, transform=transform)
        test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
        
        # Split dataset into train and validation
        train_size = int(0.8 * len(train_full))
        val_size = len(train_full) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_full, [train_size, val_size], generator=torch.Generator().manual_seed(42)
        )
        
        # Create DataLoaders with appropriate batch sizes
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
        
        # Small dataset for calibration - keep it small and manageable
        calib_dataset = torch.utils.data.Subset(
            train_dataset, 
            indices=torch.randperm(len(train_dataset))[:100].tolist()
        )
        calib_loader = torch.utils.data.DataLoader(
            calib_dataset,
            batch_size=10,
            shuffle=False
        )
        
        print(f"Dataset loaded successfully. Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
        return train_loader, val_loader, test_loader, calib_loader
        
    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise


# 3. Train model with early stopping and robust checkpointing
def train_model(model, train_loader, val_loader, epochs=10, patience=3, device='cpu', checkpoint_path='checkpoint.pth'):
    print(f"Training model on {device}...")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Move model to device
    model = model.to(device)
    
    best_val_loss = float('inf')
    best_model_state = None
    counter = 0
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    
    try:
        for epoch in range(epochs):
            # Training phase
            model.train()
            running_loss = 0.0
            train_total_loss = 0.0
            
            for i, (inputs, labels) in enumerate(train_loader):
                # Move tensors to device
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                train_total_loss += loss.item()
                
                if i % 100 == 99:
                    print(f'[{epoch + 1}, {i + 1}] train loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
            
            avg_train_loss = train_total_loss / len(train_loader)
            
            # Validation phase
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            
            val_loss = val_loss / len(val_loader)
            val_accuracy = 100 * correct / total
            
            # Save history
            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_accuracy)
            
            print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss:.3f}, Val Loss: {val_loss:.3f}, Val Acc: {val_accuracy:.2f}%')
            
            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = model.state_dict().copy()
                counter = 0
                
                # Save best model checkpoint
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': best_model_state,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': best_val_loss,
                    'val_accuracy': val_accuracy
                }, checkpoint_path)
                print(f"Checkpoint saved to {checkpoint_path}")
                
            else:
                counter += 1
                if counter >= patience:
                    print(f'Early stopping triggered after epoch {epoch + 1}')
                    break
        
        # Restore best model
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        
        # Plot training history
        plot_training_history(history)
        
        print('Training completed successfully!')
        return model
        
    except Exception as e:
        print(f"Error during training: {e}")
        # Try to restore from checkpoint if available
        if os.path.exists(checkpoint_path):
            print(f"Attempting to load model from checkpoint {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Restored model from checkpoint (epoch {checkpoint['epoch']})")
        return model


# 4. Evaluate model with better error handling
def evaluate_model(model, test_loader, name="FP32", device='cpu'):
    try:
        print(f"Evaluating {name} model...")
        model.eval()
        
        # If using quantized model, ensure it's on CPU
        if 'INT8' in name and next(model.parameters()).device.type != 'cpu':
            print(f"Moving {name} model to CPU for evaluation (quantization requires CPU)")
            model = model.cpu()
        
        # Start timing
        model_device = next(model.parameters()).device
        print(f"Model is on {model_device}")
        
        correct = 0
        total = 0
        
        # Batch processing with timing
        start_time = time.time()
        with torch.no_grad():
            for inputs, labels in test_loader:
                # Match device with model
                inputs = inputs.to(model_device)
                labels = labels.to(model_device)
                
                # Forward pass
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        inference_time = time.time() - start_time
        accuracy = 100 * correct / total
        
        print(f"Model {name}:")
        print(f"Accuracy: {accuracy:.2f}%")
        print(f"Inference time: {inference_time:.4f} seconds")
        
        return accuracy, inference_time
        
    except Exception as e:
        print(f"Error evaluating model: {e}")
        return 0.0, 0.0


# 5. Improved static quantization
def static_quantize_model(trained_model, calib_loader):
    print("Preparing model for static quantization...")
    
    # Always ensure the model is on CPU for quantization
    trained_model.cpu()
    
    # Create a new quantizable model instance
    quantizable_model = QuantizableCNN()
    
    try:
        # Copy weights with proper error handling
        print("Copying weights to quantizable model...")
        for q_name, q_param in quantizable_model.named_parameters():
            if q_name.endswith(('.weight', '.bias')):
                # Extract module and param name
                param_parts = q_name.split('.')
                module_name = param_parts[0]  # e.g., 'conv1', 'fc1'
                param_type = param_parts[-1]  # 'weight' or 'bias'
                
                # Construct source parameter name
                source_name = f"{module_name}.{param_type}"
                
                # Copy parameter if found in source model
                if source_name in dict(trained_model.named_parameters()):
                    source_param = dict(trained_model.named_parameters())[source_name]
                    with torch.no_grad():
                        q_param.copy_(source_param)
        
        # Verify parameter copying
        print("Weight copying completed. Setting model to eval mode...")
        quantizable_model.eval()
        
        # Fuse modules
        print("Fusing modules...")
        quantizable_model.fuse_model()
        
        # Get supported quantization backend
        backend = 'fbgemm'  # Default for x86
        if torch.backends.quantized.supported_engines:
            # Check for supported backends
            if 'fbgemm' in torch.backends.quantized.supported_engines:
                backend = 'fbgemm'  # Better for server (x86)
            elif 'qnnpack' in torch.backends.quantized.supported_engines:
                backend = 'qnnpack'  # Better for mobile (ARM)
            
        print(f"Using quantization backend: {backend}")
        torch.backends.quantized.engine = backend
        
        # Set quantization configuration
        quantizable_model.qconfig = torch.quantization.get_default_qconfig(backend)
        print(f"Quantization config: {quantizable_model.qconfig}")
        
        # Prepare model for calibration
        print("Preparing model for calibration...")
        prepared_model = torch.quantization.prepare(quantizable_model)
        
        # Run calibration
        print("Running calibration...")
        with torch.no_grad():
            for i, (inputs, _) in enumerate(calib_loader):
                prepared_model(inputs)
                if i >= 9:  # Just use 10 batches for calibration
                    break
        
        # Convert to quantized model
        print("Converting to fully quantized model...")
        quantized_model = torch.quantization.convert(prepared_model)
        
        print("Static quantization completed successfully!")
        return quantized_model
        
    except Exception as e:
        print(f"Error during static quantization: {e}")
        raise


# 6. Improved dynamic quantization
def dynamic_quantize_model(model):
    print("Preparing for dynamic quantization...")
    # Ensure model is on CPU
    model.cpu().eval()
    
    try:
        # Dynamic quantization focuses on weights of specific layers
        print("Applying dynamic quantization to linear layers...")
        quantized_model = torch.quantization.quantize_dynamic(
            model,  # model to quantize
            {nn.Linear},  # a set of layers to dynamically quantize
            dtype=torch.qint8  # target quantization dtype
        )
        
        print("Dynamic quantization completed successfully!")
        return quantized_model
        
    except Exception as e:
        print(f"Error during dynamic quantization: {e}")
        raise


# 7. Compare model sizes with better visualization
def compare_model_sizes(fp32_model, int8_model, fp32_name="FP32", int8_name="INT8"):
    print("Comparing model sizes...")
    
    def get_model_size(model, filename="temp_model.pth"):
        try:
            # Clean up previous temp file if it exists
            if os.path.exists(filename):
                os.remove(filename)
                
            torch.save(model.state_dict(), filename)
            size = os.path.getsize(filename) / 1e6  # MB
            os.remove(filename)
            return size
        except Exception as e:
            print(f"Error measuring model size: {e}")
            return 0
    
    fp32_size = get_model_size(fp32_model, "temp_fp32.pth")
    int8_size = get_model_size(int8_model, "temp_int8.pth")
    
    compression_ratio = fp32_size / int8_size if int8_size > 0 else 0
    
    print(f"{fp32_name} Model Size: {fp32_size:.2f} MB")
    print(f"{int8_name} Model Size: {int8_size:.2f} MB")
    print(f"Compression ratio: {compression_ratio:.2f}x")
    
    # Create a simple bar chart for model sizes
    plt.figure(figsize=(10, 6))
    plt.bar([fp32_name, int8_name], [fp32_size, int8_size], color=['blue', 'green'])
    plt.title('Model Size Comparison')
    plt.ylabel('Size (MB)')
    plt.annotate(f"{fp32_size:.2f} MB", 
                 xy=(0, fp32_size), 
                 xytext=(0, fp32_size + 0.1),
                 ha='center')
    plt.annotate(f"{int8_size:.2f} MB", 
                 xy=(1, int8_size), 
                 xytext=(1, int8_size + 0.1),
                 ha='center')
    plt.savefig('model_size_comparison.png')
    plt.close()
    
    return fp32_size, int8_size, compression_ratio


# 8. Plot comparison with better layout
def plot_comparison(fp32_metrics, int8_metrics, int8_name="INT8"):
    labels = ['FP32', int8_name]
    accuracy = [fp32_metrics[0], int8_metrics[0]]
    inference_time = [fp32_metrics[1], int8_metrics[1]]
    
    # Create a figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Plot accuracy
    bars1 = ax1.bar(labels, accuracy, color=['blue', 'green'])
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_title('Model Accuracy Comparison')
    ax1.set_ylim([min(accuracy) - 5 if min(accuracy) > 5 else 0, 100])
    
    # Add value labels on accuracy bars
    for bar in bars1:
        height = bar.get_height()
        ax1.annotate(f'{height:.2f}%',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')
    
    # Plot inference time
    bars2 = ax2.bar(labels, inference_time, color=['blue', 'green'])
    ax2.set_ylabel('Inference Time (s)')
    ax2.set_title('Inference Time Comparison')
    
    # Add value labels on time bars
    for bar in bars2:
        height = bar.get_height()
        ax2.annotate(f'{height:.4f}s',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')
    
    # Add a speedup indicator on the inference time chart
    if inference_time[0] > 0 and inference_time[1] > 0:
        speedup = inference_time[0] / inference_time[1]
        ax2.text(0.5, 0.5, f'{speedup:.2f}x speedup',
                 horizontalalignment='center',
                 verticalalignment='center',
                 transform=ax2.transAxes,
                 bbox=dict(facecolor='white', alpha=0.5))
    
    plt.tight_layout()
    plt.savefig('performance_comparison.png')
    plt.close()


# 9. Main function with improved error handling and fallback options
def main():
    print(f"PyTorch version: {torch.__version__}")
    
    # Check device availability
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Check quantization support
    print("Checking quantization support...")
    if hasattr(torch, 'quantization'):
        print(f"Quantization is supported. Available engines: {torch.backends.quantized.supported_engines}")
    else:
        print("WARNING: This PyTorch version doesn't properly support quantization!")
    
    # Set paths
    fp32_model_path = 'mnist_fp32.pth'
    static_int8_model_path = 'mnist_static_int8.pth'
    dynamic_int8_model_path = 'mnist_dynamic_int8.pth'
    checkpoint_path = 'mnist_training_checkpoint.pth'
    
    try:
        # 1. Load and prepare data
        train_loader, val_loader, test_loader, calib_loader = load_data(batch_size=64)
        
        # 2. Create and train/load FP32 model
        fp32_model = SimpleCNN()
        
        if os.path.exists(fp32_model_path):
            print(f"Loading pre-trained model from {fp32_model_path}...")
            fp32_model.load_state_dict(torch.load(fp32_model_path, map_location=device))
            fp32_model = fp32_model.to(device)
        else:
            print(f"Training new model, will save to {fp32_model_path}...")
            fp32_model = train_model(
                fp32_model, 
                train_loader, 
                val_loader, 
                epochs=5, 
                patience=2, 
                device=device,
                checkpoint_path=checkpoint_path
            )
            # Save model to file (move to CPU first for better compatibility)
            torch.save(fp32_model.cpu().state_dict(), fp32_model_path)
            # Move back to original device
            fp32_model = fp32_model.to(device)
            print(f"Model saved to {fp32_model_path}")
        
        # 3. Evaluate FP32 model
        fp32_metrics = evaluate_model(fp32_model, test_loader, "FP32", device)
        
        # 4. Quantization attempts - try both static and dynamic
        
        # Track which methods succeeded
        static_success = False
        dynamic_success = False
        
        # Try static quantization
        try:
            print("\n===== ATTEMPTING STATIC QUANTIZATION =====")
            
            if os.path.exists(static_int8_model_path):
                print(f"Loading pre-quantized static model from {static_int8_model_path}...")
                static_quantized_model = QuantizableCNN()
                static_quantized_model.load_state_dict(torch.load(static_int8_model_path, map_location='cpu'))
            else:
                print("Creating new static quantized model...")
                # Copy model for quantization
                model_for_static = SimpleCNN()
                model_for_static.load_state_dict(fp32_model.cpu().state_dict())
                
                # Perform static quantization
                static_quantized_model = static_quantize_model(model_for_static, calib_loader)
                
                # Save quantized model
                torch.save(static_quantized_model.state_dict(), static_int8_model_path)
                print(f"Static quantized model saved to {static_int8_model_path}")
            
            # Evaluate static quantized model
            static_int8_metrics = evaluate_model(static_quantized_model, test_loader, "INT8 (Static)", 'cpu')
            
            # Compare models
            fp32_size, static_int8_size, static_ratio = compare_model_sizes(
                fp32_model.cpu(), static_quantized_model, "FP32", "INT8 (Static)"
            )
            
            # Plot comparison
            plot_comparison(fp32_metrics, static_int8_metrics, "INT8 (Static)")
            
            static_success = True
            print("Static quantization workflow completed successfully!")
            
        except Exception as e:
            print(f"Static quantization workflow failed: {e}")
        
        # Try dynamic quantization
        try:
            print("\n===== ATTEMPTING DYNAMIC QUANTIZATION =====")
            
            if os.path.exists(dynamic_int8_model_path):
                print(f"Loading pre-quantized dynamic model from {dynamic_int8_model_path}...")
                dynamic_quantized_model = torch.load(dynamic_int8_model_path, map_location='cpu')
            else:
                print("Creating new dynamic quantized model...")
                # Copy model for quantization
                model_for_dynamic = SimpleCNN()
                model_for_dynamic.load_state_dict(fp32_model.cpu().state_dict())
                
                # Perform dynamic quantization
                dynamic_quantized_model = dynamic_quantize_model(model_for_dynamic)
                
                # Save quantized model
                torch.save(dynamic_quantized_model, dynamic_int8_model_path)
                print(f"Dynamic quantized model saved to {dynamic_int8_model_path}")
            
            # Evaluate dynamic quantized model
            dynamic_int8_metrics = evaluate_model(dynamic_quantized_model, test_loader, "INT8 (Dynamic)", 'cpu')
            
            # Compare models
            fp32_size, dynamic_int8_size, dynamic_ratio = compare_model_sizes(
                fp32_model.cpu(), dynamic_quantized_model, "FP32", "INT8 (Dynamic)"
            )
            
            # Plot comparison
            plot_comparison(fp32_metrics, dynamic_int8_metrics, "INT8 (Dynamic)")
            
            dynamic_success = True
            print("Dynamic quantization workflow completed successfully!")
            
        except Exception as e:
            print(f"Dynamic quantization workflow failed: {e}")
        
        # 5. Summary
        print("\n===== QUANTIZATION RESULTS SUMMARY =====")
        if static_success or dynamic_success:
            print("Successfully performed model quantization!")
            
            if static_success and dynamic_success:
                print("Both static and dynamic quantization methods were successful.")
                print(f"Static quantization compression: {static_ratio:.2f}x")
                print(f"Dynamic quantization compression: {dynamic_ratio:.2f}x")
                
                # Compare which is better
                if static_int8_metrics[0] > dynamic_int8_metrics[0]:
                    print("Static quantization achieved better accuracy.")
                else:
                    print("Dynamic quantization achieved better accuracy.")
                    
                if static_int8_metrics[1] < dynamic_int8_metrics[1]:
                    print("Static quantization achieved faster inference.")
                else:
                    print("Dynamic quantization achieved faster inference.")
                    
            elif static_success:
                print("Only static quantization was successful.")
            else:
                print("Only dynamic quantization was successful.")
        else:
            print("Both quantization methods failed.")
            print("Possible reasons for failure:")
            print("1. PyTorch version incompatibility (version should be >= 1.8.0)")
            print("2. Missing quantization operators for your specific model/layers")
            print("3. Platform compatibility issues with quantization backends")
            print("Suggestion: Try upgrading PyTorch or simplifying your model architecture.")
    
    except Exception as e:
        print(f"An error occurred in the main workflow: {e}")


if __name__ == "__main__":
    main()

PyTorch version: 2.7.0+cu118
Using device: cuda:0
Checking quantization support...
Quantization is supported. Available engines: ['none', 'onednn', 'x86', 'fbgemm']
Loading MNIST dataset...
Dataset loaded successfully. Train: 48000, Val: 12000, Test: 10000
Loading pre-trained model from mnist_fp32.pth...
Evaluating FP32 model...
Model is on cuda:0
Model FP32:
Accuracy: 98.92%
Inference time: 5.0005 seconds

===== ATTEMPTING STATIC QUANTIZATION =====
Creating new static quantized model...
Preparing model for static quantization...
Copying weights to quantizable model...
Weight copying completed. Setting model to eval mode...
Fusing modules...
Using quantization backend: fbgemm
Quantization config: QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
Preparing model for cali



: 

In [4]:
#!/usr/bin/env python3
"""
Simple Load Test Script - 100 concurrent requests với login
"""

import requests
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# Cấu hình
URL = "http://meizu1908.id.vn/"
LOGIN_URL = "http://meizu1908.id.vn/api/auth/login"  # Thử API endpoint
LOGIN_URL_FALLBACK = "http://meizu1908.id.vn/login"  # Fallback nếu không có API
USERNAME = "levanminh19102003@gmail.com"
PASSWORD = "123456"
NUM_REQUESTS = 100

# Biến lưu kết quả
results = []
lock = threading.Lock()

def make_request(request_id):
    """Thực hiện 1 request với đăng nhập"""
    start_time = time.time()
    
    try:
        # Tạo session để giữ cookie
        session = requests.Session()
        
        # Lấy trang chính trước để lấy cookies/tokens
        session.get(URL, timeout=10)
        
        # Thử đăng nhập với JSON API trước
        login_data = {
            'username': USERNAME,
            'password': PASSWORD
        }
        
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        
        login_success = False
        
        # Thử API endpoint trước
        try:
            login_response = session.post(
                LOGIN_URL, 
                json=login_data,  # Gửi JSON
                headers=headers,
                timeout=10
            )
            
            if login_response.status_code == 200:
                # Kiểm tra response JSON
                try:
                    json_response = login_response.json()
                    if json_response.get('success') or json_response.get('token') or json_response.get('user'):
                        login_success = True
                except:
                    pass
            
        except Exception as e:
            # Nếu API fail, thử form thông thường
            try:
                login_response = session.post(
                    LOGIN_URL_FALLBACK,
                    data=login_data,  # Gửi form data
                    timeout=10
                )
                
                if login_response.status_code == 302:  # Redirect
                    login_success = True
                elif login_response.status_code == 200:
                    if any(keyword in login_response.text.lower() 
                          for keyword in ['dashboard', 'welcome', 'logout', 'profile']):
                        login_success = True
                        
            except Exception as e2:
                pass
        
        # Request chính sau khi đăng nhập
        response = session.get(URL, timeout=10)
        
        # Kiểm tra thêm bằng cách xem có content khác không
        if not login_success and response.status_code == 200:
            # Nếu có những từ này trong response thì có thể đã login
            if any(keyword in response.text.lower() 
                  for keyword in ['dashboard', 'profile', 'logout', 'welcome']):
                login_success = True
        
        end_time = time.time()
        response_time = end_time - start_time
        
        # Lưu kết quả
        with lock:
            results.append({
                'id': request_id,
                'login_success': login_success,
                'status': response.status_code,
                'time': response_time,
                'success': response.status_code == 200 and login_success
            })
        
        login_status = "✅ LOGIN OK" if login_success else "❌ LOGIN FAIL"
        print(f"Request {request_id}: {login_status} | Response: {response.status_code} - {response_time:.2f}s")
        
    except Exception as e:
        end_time = time.time()
        response_time = end_time - start_time
        
        with lock:
            results.append({
                'id': request_id,
                'login_success': False,
                'status': 0,
                'time': response_time,
                'success': False,
                'error': str(e)
            })
            
        print(f"Request {request_id}: ❌ ERROR - {str(e)}")

def run_load_test():
    """Chạy load test"""
    print(f"🚀 Bắt đầu test {NUM_REQUESTS} requests đồng thời")
    print(f"🎯 URL: {URL}")
    print(f"👤 Username: {USERNAME}")
    print("-" * 50)
    
    start_time = time.time()
    
    # Chạy 100 requests đồng thời
    with ThreadPoolExecutor(max_workers=50) as executor:
        futures = [executor.submit(make_request, i) for i in range(NUM_REQUESTS)]
        
        # Đợi tất cả hoàn thành
        for future in futures:
            future.result()
    
    end_time = time.time()
    total_time = end_time - start_time
    
    # In kết quả
    print("\n" + "="*50)
    print("📊 KẾT QUẢ")
    print("="*50)
    
    successful = [r for r in results if r['success']]
    failed = [r for r in results if not r['success']]
    login_success = [r for r in results if r.get('login_success', False)]
    login_failed = [r for r in results if not r.get('login_success', True)]
    
    print(f"⏱️  Tổng thời gian: {total_time:.2f}s")
    print(f"📊 Tổng requests: {len(results)}")
    print(f"🔐 Đăng nhập thành công: {len(login_success)}")
    print(f"🚫 Đăng nhập thất bại: {len(login_failed)}")
    print(f"✅ Requests thành công: {len(successful)}")
    print(f"❌ Requests thất bại: {len(failed)}")
    
    if successful:
        avg_time = sum(r['time'] for r in successful) / len(successful)
        min_time = min(r['time'] for r in successful)
        max_time = max(r['time'] for r in successful)
        
        print(f"⚡ Thời gian trung bình: {avg_time:.2f}s")
        print(f"🏃 Nhanh nhất: {min_time:.2f}s")
        print(f"🐌 Chậm nhất: {max_time:.2f}s")
        print(f"🔥 Requests/giây: {len(successful)/total_time:.2f}")
    
    # Hiển thị một số lỗi login nếu có
    if login_failed:
        print(f"\n⚠️  CẢNH BÁO: {len(login_failed)} lần đăng nhập thất bại!")
        print("Kiểm tra lại:")
        print(f"- URL login: {LOGIN_URL}")
        print(f"- Username: {USERNAME}")
        print("- Mật khẩu có đúng không?")
        print("- Trang có dùng React/MUI form không? (có thể cần API endpoint khác)")

if __name__ == "__main__":
    run_load_test()

🚀 Bắt đầu test 100 requests đồng thời
🎯 URL: http://meizu1908.id.vn/
👤 Username: levanminh19102003@gmail.com
--------------------------------------------------
Request 0: ❌ LOGIN FAIL | Response: 200 - 0.29s
Request 5: ❌ LOGIN FAIL | Response: 200 - 0.30s
Request 2: ❌ LOGIN FAIL | Response: 200 - 0.31s
Request 19: ❌ LOGIN FAIL | Response: 200 - 0.30s
Request 13: ❌ LOGIN FAIL | Response: 200 - 0.31s
Request 1: ❌ LOGIN FAIL | Response: 200 - 0.34s
Request 12: ❌ LOGIN FAIL | Response: 200 - 0.32s
Request 7: ❌ LOGIN FAIL | Response: 200 - 0.33s
Request 18: ❌ LOGIN FAIL | Response: 200 - 0.32s
Request 14: ❌ LOGIN FAIL | Response: 200 - 0.32s
Request 9: ❌ LOGIN FAIL | Response: 200 - 0.33s
Request 23: ❌ LOGIN FAIL | Response: 200 - 0.34s
Request 20: ❌ LOGIN FAIL | Response: 200 - 0.35s
Request 6: ❌ LOGIN FAIL | Response: 200 - 0.38s
Request 24: ❌ LOGIN FAIL | Response: 200 - 0.35s
Request 33: ❌ LOGIN FAIL | Response: 200 - 0.34s
Request 32: ❌ LOGIN FAIL | Response: 200 - 0.35s
Request 25: ❌ 