# Setup

In [1]:
import copy
import os
import random
import h5py

import torch
import torch.nn as nn

from torchvision.models import *
from torchvision.transforms import *
from torch.utils.data import DataLoader, Dataset
from torch.optim import *
from torch.optim.lr_scheduler import *

import torch_pruning as tp

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from utils import *

SEED = 0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch).
# torch.manual_seed seeds all devices; GPU kernels (e.g. cuDNN) may still be
# non-deterministic. The trailing `;` suppresses the torch.Generator repr that
# manual_seed returns, which otherwise shows up as stray cell output.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

  from .autonotebook import tqdm as notebook_tqdm


Using cuda:0


<torch._C.Generator at 0x1d42f56b9f0>

In [2]:
class ArrhythmiaLabels:
    """Mapping from integer class index to heartbeat class code.

    The five codes (N, S, V, F, Q) presumably follow the AAMI heartbeat
    classes commonly used with MIT-BIH data -- TODO confirm against the
    dataset documentation.
    """

    labels = {
        0: "N",
        1: "S",
        2: "V",
        3: "F",
        4: "Q",
    }
    # Number of classes, derived from `labels` so the two can never disagree.
    size = len(labels)

class EcgDataset(Dataset):
    """Map-style dataset reading ECG images and labels from an HDF5 file.

    The file must contain an ``images`` and a ``labels`` dataset, both indexed
    along their first axis. The file stays open for the lifetime of the object;
    call :meth:`close` to release it early (it is also closed on garbage
    collection).

    NOTE(review): a single h5py handle is opened in the constructing process;
    with ``num_workers > 0`` each worker may need its own handle -- confirm
    before enabling multi-process loading.
    """

    def __init__(self, file_path: str, transform=None) -> None:
        super().__init__()

        self.h5_file   = h5py.File(file_path, "r")
        self.images    = self.h5_file["images"]
        self.labels    = self.h5_file["labels"]
        self.transform = transform

    def close(self) -> None:
        """Close the underlying HDF5 file; safe to call more than once."""
        h5_file = getattr(self, "h5_file", None)
        if h5_file is not None:
            h5_file.close()
            self.h5_file = None

    def __del__(self):
        # `__init__` may have raised before `h5_file` was assigned (e.g. a
        # missing file); `close` tolerates that instead of raising
        # AttributeError during garbage collection.
        self.close()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is float32 (after ``transform`` if one is set); ``label`` is a
        0-d int64 array.
        """
        image = self.images[idx]
        label = self.labels[idx].astype(int)

        # Move stored axis 0 to the end: (a0, a1, a2) -> (a1, a2, a0).
        # Presumably the file stores (C, H, W), giving the (H, W, C) layout
        # that ToTensor expects -- TODO confirm against the HDF5 writer.
        image = np.transpose(image, (1, 2, 0))
        image = np.array(image, dtype=np.float32)

        label = np.array(label, dtype=np.int64)

        if self.transform:
            image = self.transform(image)

        return image, label
    
def build_dataloader(train_path: str, test_path: str, batch_size: int, transform=None) -> dict[str, DataLoader]:
    """Build train/test ``DataLoader``s over two HDF5 ECG files.

    Both splits use the same preprocessing: ``ToTensor()`` followed by
    ``transform`` (only ``ToTensor()`` when ``transform`` is ``None``).
    Only the ``"train"`` split is shuffled.

    Args:
        train_path: HDF5 file for the training split.
        test_path: HDF5 file for the test split.
        batch_size: Samples per batch for both loaders.
        transform: Extra transform applied after ``ToTensor()``, or ``None``.

    Returns:
        ``{"train": DataLoader, "test": DataLoader}``.
    """
    steps = [ToTensor()]
    if transform is not None:
        steps.append(transform)
    preprocess = Compose(steps)

    # Pinned host memory only speeds up host->GPU copies; requesting it
    # without CUDA just triggers a DataLoader warning.
    use_pinned_memory = torch.cuda.is_available()

    dataloader = {}
    for split, path in (("train", train_path), ("test", test_path)):
        dataloader[split] = DataLoader(
            EcgDataset(path, transform=preprocess),
            batch_size  = batch_size,
            shuffle     = (split == "train"),
            num_workers = 0,
            pin_memory  = use_pinned_memory,
        )

    return dataloader

def visualize_ecg_data(dataloader: DataLoader) -> None:
    """Plot each channel of one randomly chosen image from the first batch.

    One grayscale subplot is drawn per channel (dim 0 of the image tensor),
    so images with any number of channels are supported.
    """
    images = next(iter(dataloader))[0]
    idx = random.randrange(len(images))
    channels = torch.split(images[idx], 1, dim=0)

    # squeeze=False always yields a 2-D array of axes, even for one channel.
    _, axes = plt.subplots(1, len(channels), squeeze=False)
    for ax, channel in zip(axes[0], channels):
        ax.imshow(channel.squeeze(0).numpy(), cmap="gray")
        ax.axis("off")
    plt.show()


In [3]:
def load_model(name: str) -> nn.Module:
    match name:
        case "alexnet":
            return alexnet(weights=AlexNet_Weights.DEFAULT)
        case "resnet18":
            return resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        case "resnet50":
            return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        case "vgg11_bn":
            return vgg11_bn(weights=VGG11_BN_Weights.IMAGENET1K_V1)
        case "vgg16_bn":
            return vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)       
        case "vit_b_16":
            return vit_b_16(weights=ViT_B_16_Weights.DEFAULT) 
        case "mobilenet_v3":
            return mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.IMAGENET1K_V2)
        case "mobilenet_v3_small":
            return mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
    
def load_model_from_pretrained(name: str, path: str, num_classes: int) -> nn.Module:
    model = None
    match name:
        case "alexnet":
            model = alexnet()
            model.classifier[4] = nn.Linear(4096, 512)
            model.classifier[6] = nn.Linear(512, num_classes)

        case "resnet18":
            model = resnet18()
            model.fc = nn.Linear(512, num_classes)

        case "resnet50":
            model = resnet50()
            model.fc = nn.Linear(2048, num_classes)
            
        case "vgg11_bn":
            model = vgg11_bn()
            model.classifier[6] = nn.Linear(4096, num_classes)

        case "vgg16_bn":
            model = vgg16_bn()
            model.classifier[6] = nn.Linear(4096, num_classes)

        case "vit_b_16":
            model = vit_b_16()
            model.heads.head = nn.Linear(768, num_classes)
        
        case "mobilenet_v3":
            model = mobilenet_v3_large()
            model.classifier[3] = nn.Linear(1280, num_classes)

        case "mobilenet_v3_small":
            model = mobilenet_v3_small()
            model.classifier[3] = nn.Linear(1024, num_classes)

    model.load_state_dict(torch.load(path))
    return model

# Pruning

In [4]:
def load_vgg_from_pruned(
    path:          str,
    pruning_ratio: float,
    dummy_input:   torch.Tensor,
    num_classes:   int = 5,
) -> nn.Module:
    """Rebuild a channel-pruned VGG16-BN and load its saved ``state_dict``.

    A fresh (randomly initialised) ``vgg16_bn`` with a ``num_classes`` head is
    pruned with ``tp.pruner.MagnitudePruner`` at ``pruning_ratio``, so that its
    layer shapes match the checkpoint. The weights from ``path`` are then loaded
    into it. The final classifier layer is never pruned.

    NOTE(review): the pruner here runs without ``global_pruning`` and with the
    default per-layer cap. Importance is computed on random weights, so this can
    presumably only reproduce checkpoints pruned the same (non-global) way.
    Globally pruned checkpoints choose layer widths from the trained weights, and
    their shapes will likely not match. Files written with
    ``torch.save(model, ...)`` (a whole pickled module) should be loaded with
    ``torch.load`` directly instead.

    Args:
        path: File holding the pruned model's ``state_dict``.
        pruning_ratio: Ratio that was used when the checkpoint was pruned.
        dummy_input: Example input used to trace layer dependencies.
        num_classes: Output size of the final classifier layer (default 5).
    """
    model = vgg16_bn()
    model.classifier[6] = nn.Linear(4096, num_classes)
    # Keep the output layer at num_classes logits.
    ignored_layers = [model.classifier[6]]

    model.to("cpu")
    model.eval()

    imp = tp.importance.MagnitudeImportance(p=2)
    pruner = tp.pruner.MagnitudePruner(
        model           = model,
        example_inputs  = dummy_input.to("cpu"),  # trace on the model's device
        importance      = imp,
        pruning_ratio   = pruning_ratio,
        ignored_layers  = ignored_layers,
    )
    pruner.step()

    # map_location="cpu" so GPU-saved checkpoints load into this CPU model.
    model.load_state_dict(torch.load(path, map_location="cpu"))
    return model

In [5]:
# Configuration: input resolution and the fine-tuned checkpoint to prune.
# The "_i152" suffix in the checkpoint name presumably records the training
# image size -- keep it in sync with `image_size`.
image_size      = 152
model_name      = "vgg16_bn"
base_model_path = "Pretrained/vgg16_bn_ecg_ep50_i152.pth"
num_classes     = ArrhythmiaLabels.size

# Data: both splits resized to image_size x image_size, batches of 32.
dataloader = build_dataloader(
    "Data/mitbih_mif_train_small.h5",
    "Data/mitbih_mif_test.h5",
    32,
    Resize((image_size, image_size)),
)

# Model: rebuild the architecture with a num_classes head and load its weights.
base_model = load_model_from_pretrained(
    name=model_name,
    path=base_model_path,
    num_classes=num_classes,
)

In [None]:
# Global L2-magnitude pruning of a CPU copy of the base model
# (deepcopy leaves base_model itself untouched).
imp = tp.importance.MagnitudeImportance(p=2)

pruned_model = copy.deepcopy(base_model).to("cpu")
pruned_model.eval()

# Example input; used for tracing the model's layer dependencies.
dummy_input = torch.rand((1, 3, image_size, image_size))

# Leave the final classification layer alone so it keeps num_classes outputs.
ignored_layers = [pruned_model.classifier[6]]

pruner = tp.pruner.MagnitudePruner(
    model             = pruned_model,
    importance        = imp,
    example_inputs    = dummy_input,
    ignored_layers    = ignored_layers,
    global_pruning    = True,
    pruning_ratio     = 0.5,
    max_pruning_ratio = 0.9,
)
pruner.step()

# Benchmark the pruned (not yet fine-tuned) model on the test split;
# benchmark_model / display_model_stats come from the local utils module.
pruned_model_stats = benchmark_model(pruned_model, dataloader["test"], name="Pruned")
display_model_stats(pruned_model_stats)

In [None]:
# Finetuning: recover accuracy of the pruned model from the previous cell and
# keep the weights from the epoch with the best test accuracy.
num_finetune_epochs = 20
pruned_model.to(device)
optimizer = SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# T_max is given in epochs; presumably `train` (from utils) steps the
# scheduler once per epoch -- TODO confirm.
scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

# "g0.5" records the pruning ratio used in the pruning cell; keep them in sync.
save_path = f"Pretrained/VGG-Pruned/{model_name}_ecg_ep{num_finetune_epochs}_i152_g0.5.pth"
# Create the output folder before training so a missing directory cannot
# throw away a full fine-tuning run at the final torch.save.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Start from the un-fine-tuned weights so a checkpoint always exists to save;
# otherwise the final save raised KeyError if no epoch beat best_accuracy.
best_model_checkpoint = {"state_dict": copy.deepcopy(pruned_model.state_dict())}
best_accuracy         = 0

print("Finetuning")
for epoch in range(num_finetune_epochs):
    train(pruned_model, dataloader["train"], criterion, optimizer, scheduler)
    accuracy = evaluate(pruned_model, dataloader["test"])

    if accuracy > best_accuracy:
        best_model_checkpoint["state_dict"] = copy.deepcopy(pruned_model.state_dict())
        best_accuracy = accuracy

    print(f"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%")

# Save best model (a state_dict, not the whole pickled module).
torch.save(best_model_checkpoint["state_dict"], save_path)

In [7]:
# Sweep global pruning ratios. For each ratio: prune a fresh CPU copy of
# base_model, fine-tune it, and save the best model (whole pickled module).
ratios = [0.5, 0.6, 0.7, 0.8, 0.9]
num_finetune_epochs = 20
save_dir = "Pretrained/VGG-Pruned"
# Create the output folder before any training so a missing directory cannot
# throw away a finished run at torch.save.
os.makedirs(save_dir, exist_ok=True)

for ratio in ratios:
    # Upper bound on any single layer's pruning ratio (torch_pruning's
    # max_pruning_ratio); loosened for the most aggressive ratio, presumably so
    # the 0.9 global target stays reachable.
    max_prune = 0.95 if ratio == 0.9 else 0.9

    # Copy base model
    pruned_model = copy.deepcopy(base_model).to("cpu")
    pruned_model.eval()

    # Example input; used for tracing layer dependencies
    dummy_input = torch.rand((1, 3, image_size, image_size))

    # Ignore the last classification layer so it keeps num_classes outputs
    ignored_layers = [pruned_model.classifier[6]]

    # Pruning objects
    imp = tp.importance.MagnitudeImportance(p=2)

    pruner = tp.pruner.MagnitudePruner(
        model             = pruned_model,
        example_inputs    = dummy_input,
        importance        = imp,
        global_pruning    = True,
        pruning_ratio     = ratio,
        max_pruning_ratio = max_prune,
        ignored_layers    = ignored_layers,
    )

    pruner.step()

    # Finetuning
    pruned_model.to(device)
    optimizer = SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)
    criterion = nn.CrossEntropyLoss()

    # Start from the un-fine-tuned pruned model so there is always something to save.
    best_model    = copy.deepcopy(pruned_model)
    best_accuracy = 0

    print(f"Finetuning (pruning ratio {ratio})")
    for epoch in range(num_finetune_epochs):
        train(pruned_model, dataloader["train"], criterion, optimizer, scheduler)
        accuracy = evaluate(pruned_model, dataloader["test"])

        if accuracy > best_accuracy:
            best_model = copy.deepcopy(pruned_model)
            best_accuracy = accuracy

        print(f"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%")

    # Save best model
    save_path = f"{save_dir}/{model_name}_ecg_ep{num_finetune_epochs}_i152_g{ratio}.pth"
    # Clear accumulated gradients before pickling the module (with
    # set_to_none, the default in recent PyTorch, they are dropped entirely).
    best_model.zero_grad()
    # Whole-module save: pruned layer widths differ from a stock vgg16_bn, so
    # reload with torch.load(...) rather than load_state_dict into a fresh model.
    torch.save(best_model, save_path)

Finetuning


                                                        

Epoch 1 Accuracy 97.43% / Best Accuracy: 97.43%


                                                        

Epoch 2 Accuracy 97.07% / Best Accuracy: 97.43%


                                                        

Epoch 3 Accuracy 97.60% / Best Accuracy: 97.60%


                                                        

Epoch 4 Accuracy 97.37% / Best Accuracy: 97.60%


                                                        

Epoch 5 Accuracy 97.77% / Best Accuracy: 97.77%


                                                        

Epoch 6 Accuracy 98.01% / Best Accuracy: 98.01%


                                                        

Epoch 7 Accuracy 98.12% / Best Accuracy: 98.12%


                                                        

Epoch 8 Accuracy 97.35% / Best Accuracy: 98.12%


                                                        

Epoch 9 Accuracy 95.12% / Best Accuracy: 98.12%


                                                        

Epoch 10 Accuracy 98.07% / Best Accuracy: 98.12%


                                                        

Epoch 11 Accuracy 96.69% / Best Accuracy: 98.12%


                                                        

Epoch 12 Accuracy 97.54% / Best Accuracy: 98.12%


                                                        

Epoch 13 Accuracy 94.54% / Best Accuracy: 98.12%


                                                        

Epoch 14 Accuracy 97.59% / Best Accuracy: 98.12%


                                                        

Epoch 15 Accuracy 97.97% / Best Accuracy: 98.12%


                                                        

Epoch 16 Accuracy 98.17% / Best Accuracy: 98.17%


                                                        

Epoch 17 Accuracy 98.05% / Best Accuracy: 98.17%


                                                        

Epoch 18 Accuracy 97.90% / Best Accuracy: 98.17%


                                                        

Epoch 19 Accuracy 98.25% / Best Accuracy: 98.25%


                                                        

Epoch 20 Accuracy 91.06% / Best Accuracy: 98.25%
Finetuning


                                                        

Epoch 1 Accuracy 95.25% / Best Accuracy: 95.25%


                                                        

Epoch 2 Accuracy 96.62% / Best Accuracy: 96.62%


                                                        

Epoch 3 Accuracy 96.44% / Best Accuracy: 96.62%


                                                        

Epoch 4 Accuracy 97.27% / Best Accuracy: 97.27%


                                                        

Epoch 5 Accuracy 96.88% / Best Accuracy: 97.27%


                                                        

Epoch 6 Accuracy 96.65% / Best Accuracy: 97.27%


                                                        

Epoch 7 Accuracy 97.48% / Best Accuracy: 97.48%


                                                        

Epoch 8 Accuracy 97.93% / Best Accuracy: 97.93%


                                                        

Epoch 9 Accuracy 98.11% / Best Accuracy: 98.11%


                                                        

Epoch 10 Accuracy 97.90% / Best Accuracy: 98.11%


                                                        

Epoch 11 Accuracy 98.03% / Best Accuracy: 98.11%


                                                        

Epoch 12 Accuracy 97.38% / Best Accuracy: 98.11%


                                                        

Epoch 13 Accuracy 95.47% / Best Accuracy: 98.11%


                                                        

Epoch 14 Accuracy 98.22% / Best Accuracy: 98.22%


                                                        

Epoch 15 Accuracy 97.83% / Best Accuracy: 98.22%


                                                        

Epoch 16 Accuracy 98.22% / Best Accuracy: 98.22%


                                                        

Epoch 17 Accuracy 95.31% / Best Accuracy: 98.22%


                                                        

Epoch 18 Accuracy 98.00% / Best Accuracy: 98.22%


                                                        

Epoch 19 Accuracy 96.31% / Best Accuracy: 98.22%


                                                        

Epoch 20 Accuracy 97.32% / Best Accuracy: 98.22%
Finetuning


                                                        

Epoch 1 Accuracy 97.36% / Best Accuracy: 97.36%


                                                        

Epoch 2 Accuracy 97.40% / Best Accuracy: 97.40%


                                                        

Epoch 3 Accuracy 97.86% / Best Accuracy: 97.86%


                                                        

Epoch 4 Accuracy 98.16% / Best Accuracy: 98.16%


                                                        

Epoch 5 Accuracy 97.89% / Best Accuracy: 98.16%


                                                        

Epoch 6 Accuracy 96.86% / Best Accuracy: 98.16%


                                                        

Epoch 7 Accuracy 98.03% / Best Accuracy: 98.16%


                                                        

Epoch 8 Accuracy 97.91% / Best Accuracy: 98.16%


                                                        

Epoch 9 Accuracy 97.12% / Best Accuracy: 98.16%


                                                        

Epoch 10 Accuracy 97.37% / Best Accuracy: 98.16%


                                                        

Epoch 11 Accuracy 96.05% / Best Accuracy: 98.16%


                                                        

Epoch 12 Accuracy 97.58% / Best Accuracy: 98.16%


                                                        

Epoch 13 Accuracy 98.20% / Best Accuracy: 98.20%


                                                        

Epoch 14 Accuracy 97.99% / Best Accuracy: 98.20%


                                                        

Epoch 15 Accuracy 96.21% / Best Accuracy: 98.20%


                                                        

Epoch 16 Accuracy 98.12% / Best Accuracy: 98.20%


                                                        

Epoch 17 Accuracy 97.88% / Best Accuracy: 98.20%


                                                        

Epoch 18 Accuracy 97.89% / Best Accuracy: 98.20%


                                                        

Epoch 19 Accuracy 97.86% / Best Accuracy: 98.20%


                                                        

Epoch 20 Accuracy 97.72% / Best Accuracy: 98.20%
Finetuning


                                                        

Epoch 1 Accuracy 96.23% / Best Accuracy: 96.23%


                                                        

Epoch 2 Accuracy 92.49% / Best Accuracy: 96.23%


                                                        

Epoch 3 Accuracy 97.50% / Best Accuracy: 97.50%


                                                        

Epoch 4 Accuracy 97.37% / Best Accuracy: 97.50%


                                                        

Epoch 5 Accuracy 96.97% / Best Accuracy: 97.50%


                                                        

Epoch 6 Accuracy 97.76% / Best Accuracy: 97.76%


                                                        

Epoch 7 Accuracy 97.76% / Best Accuracy: 97.76%


                                                        

Epoch 8 Accuracy 97.42% / Best Accuracy: 97.76%


                                                        

Epoch 9 Accuracy 98.44% / Best Accuracy: 98.44%


                                                        

Epoch 10 Accuracy 97.80% / Best Accuracy: 98.44%


                                                        

Epoch 11 Accuracy 96.88% / Best Accuracy: 98.44%


                                                        

Epoch 12 Accuracy 98.04% / Best Accuracy: 98.44%


                                                        

Epoch 13 Accuracy 98.12% / Best Accuracy: 98.44%


                                                        

Epoch 14 Accuracy 97.75% / Best Accuracy: 98.44%


                                                        

Epoch 15 Accuracy 97.33% / Best Accuracy: 98.44%


                                                        

Epoch 16 Accuracy 97.41% / Best Accuracy: 98.44%


                                                        

Epoch 17 Accuracy 95.35% / Best Accuracy: 98.44%


                                                        

Epoch 18 Accuracy 97.57% / Best Accuracy: 98.44%


                                                        

Epoch 19 Accuracy 97.89% / Best Accuracy: 98.44%


                                                        

Epoch 20 Accuracy 98.23% / Best Accuracy: 98.44%
Finetuning


                                                        

Epoch 1 Accuracy 92.63% / Best Accuracy: 92.63%


                                                        

Epoch 2 Accuracy 93.67% / Best Accuracy: 93.67%


                                                        

Epoch 3 Accuracy 96.67% / Best Accuracy: 96.67%


                                                        

Epoch 4 Accuracy 95.94% / Best Accuracy: 96.67%


                                                        

Epoch 5 Accuracy 96.99% / Best Accuracy: 96.99%


                                                        

Epoch 6 Accuracy 97.37% / Best Accuracy: 97.37%


                                                        

Epoch 7 Accuracy 96.65% / Best Accuracy: 97.37%


                                                        

Epoch 8 Accuracy 96.98% / Best Accuracy: 97.37%


                                                        

Epoch 9 Accuracy 97.01% / Best Accuracy: 97.37%


                                                        

Epoch 10 Accuracy 97.48% / Best Accuracy: 97.48%


                                                        

Epoch 11 Accuracy 97.78% / Best Accuracy: 97.78%


                                                        

Epoch 12 Accuracy 97.49% / Best Accuracy: 97.78%


                                                        

Epoch 13 Accuracy 97.91% / Best Accuracy: 97.91%


                                                        

Epoch 14 Accuracy 96.19% / Best Accuracy: 97.91%


                                                        

Epoch 15 Accuracy 97.41% / Best Accuracy: 97.91%


                                                        

Epoch 16 Accuracy 97.93% / Best Accuracy: 97.93%


                                                        

Epoch 17 Accuracy 97.14% / Best Accuracy: 97.93%


                                                        

Epoch 18 Accuracy 96.72% / Best Accuracy: 97.93%


                                                        

Epoch 19 Accuracy 97.42% / Best Accuracy: 97.93%


                                                        

Epoch 20 Accuracy 97.93% / Best Accuracy: 97.93%


