# Setup

In [1]:
import random
import h5py
import copy
import os

import torch
import torch.nn as nn

from torchvision.models import *
from torchvision.transforms import *
from torch.utils.data import DataLoader, Dataset
from torch.optim import *
from torch.optim.lr_scheduler import *

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from Dataset.dataset import *
from Dataset.data_generation import ArrhythmiaLabels
from Utils.utils import *
from Utils.model_loading import *

# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

# Seed every RNG this notebook draws from (Python, NumPy, PyTorch) so runs are repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


Using cuda:0


<torch._C.Generator at 0x28ea8f1f9d0>

# Training Different Models

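- **Important:** before training, replace the model's final classification layer with one that outputs `ArrhythmiaLabels.size` classes. Its input width must match the layer it replaces. Otherwise training either fails (shape mismatch) or silently keeps the 1000-class ImageNet head and produces misleading results.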

In [2]:
# Side length of the square images fed to the model.
image_size = 224
# Transform handed to build_dataloader (presumably applied to every sample it yields).
resize_to_square = Resize((image_size, image_size), antialias=None)

dataloader = build_dataloader(
    train_path="Data/mitbih_mif_train_small.h5",
    test_path="Data/mitbih_mif_test.h5",
    batch_size=32,
    transform=resize_to_square,
)

# Uncomment to inspect a few training samples:
# visualize_ecg_data(dataloader["train"])

In [3]:
model_name = "vgg16_bn"
model = load_model(model_name)

# Replace the classification head with one sized for the arrhythmia label set.
# Reading `in_features` from the layer being replaced (4096 for VGG16-BN) avoids a magic
# number that silently breaks if the backbone is swapped.
num_classes = ArrhythmiaLabels.size
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

# Finetuning hyper-parameters
num_finetune_epochs = 50
model.to(device)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# NOTE(review): T_max is counted in epochs. That is one cosine decay over the run only if
# `train` (from Utils) calls scheduler.step() once per epoch — confirm. If it steps once per
# batch, use num_finetune_epochs * len(dataloader["train"]) instead.
scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

best_model_checkpoint = dict()
# Start below any reachable accuracy so the first epoch always records a checkpoint; starting
# at 0 would leave the dict empty (KeyError when saving) if every epoch scored 0%.
best_accuracy = float("-inf")

# Every `safety_interval` epochs, dump the full training state so an interrupted run can resume.
safety = True
safety_interval = 10
safety_dir = "Pretrained/Safety"

os.makedirs("Pretrained", exist_ok=True)
if safety:
    os.makedirs(safety_dir, exist_ok=True)

print("Finetuning")
for epoch in range(num_finetune_epochs):
    train(model, dataloader["train"], criterion, optimizer, scheduler)
    accuracy = evaluate(model, dataloader["test"])

    if accuracy > best_accuracy:
        # Deep copy: state_dict() returns references to the live tensors, which keep training.
        best_model_checkpoint["state_dict"] = copy.deepcopy(model.state_dict())
        best_accuracy = accuracy

    if safety and (epoch + 1) % safety_interval == 0:
        checkpoint = {
            "model_state_dict":     model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch":                epoch,
        }
        torch.save(checkpoint, f"{safety_dir}/check_{epoch+1}.pth")

    print(f"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%")

# Save the best weights seen during finetuning
save_path = f"Pretrained/{model_name}_ecg_ep{num_finetune_epochs}_i{image_size}.pth"
torch.save(best_model_checkpoint["state_dict"], save_path)

# Reload the saved file through the project loader and re-evaluate to confirm it round-trips.
model = load_model_from_pretrained(model_name, save_path, num_classes=num_classes)
model.to(device)
acc = evaluate(model, dataloader["test"])
print(f"Accuracy of Loaded Model: {acc:.2f}")

Finetuning


                                                        

Epoch 1 Accuracy 97.36% / Best Accuracy: 97.36%


                                                        

Epoch 2 Accuracy 95.74% / Best Accuracy: 97.36%


                                                        

Epoch 3 Accuracy 97.03% / Best Accuracy: 97.36%


                                                        

Epoch 4 Accuracy 97.83% / Best Accuracy: 97.83%


                                                        

Epoch 5 Accuracy 97.31% / Best Accuracy: 97.83%


                                                        

Epoch 6 Accuracy 98.09% / Best Accuracy: 98.09%


                                                        

Epoch 7 Accuracy 98.11% / Best Accuracy: 98.11%


                                                        

Epoch 8 Accuracy 97.35% / Best Accuracy: 98.11%


                                                        

Epoch 9 Accuracy 97.55% / Best Accuracy: 98.11%


                                                        

Epoch 10 Accuracy 97.35% / Best Accuracy: 98.11%


                                                        

Epoch 11 Accuracy 98.31% / Best Accuracy: 98.31%


                                                        

Epoch 12 Accuracy 98.61% / Best Accuracy: 98.61%


                                                        

Epoch 13 Accuracy 96.06% / Best Accuracy: 98.61%


                                                        

Epoch 14 Accuracy 98.14% / Best Accuracy: 98.61%


                                                        

Epoch 15 Accuracy 98.31% / Best Accuracy: 98.61%


                                                        

Epoch 16 Accuracy 98.31% / Best Accuracy: 98.61%


                                                        

Epoch 17 Accuracy 98.39% / Best Accuracy: 98.61%


                                                        

Epoch 18 Accuracy 98.43% / Best Accuracy: 98.61%


                                                        

Epoch 19 Accuracy 98.33% / Best Accuracy: 98.61%


                                                        

Epoch 20 Accuracy 98.53% / Best Accuracy: 98.61%


                                                        

Epoch 21 Accuracy 98.20% / Best Accuracy: 98.61%


                                                        

Epoch 22 Accuracy 98.39% / Best Accuracy: 98.61%


                                                        

Epoch 23 Accuracy 98.41% / Best Accuracy: 98.61%


                                                        

Epoch 24 Accuracy 98.31% / Best Accuracy: 98.61%


                                                        

Epoch 25 Accuracy 97.94% / Best Accuracy: 98.61%


                                                        

Epoch 26 Accuracy 96.18% / Best Accuracy: 98.61%


                                                        

Epoch 27 Accuracy 98.40% / Best Accuracy: 98.61%


                                                        

Epoch 28 Accuracy 98.10% / Best Accuracy: 98.61%


                                                        

Epoch 29 Accuracy 98.59% / Best Accuracy: 98.61%


                                                        

Epoch 30 Accuracy 98.55% / Best Accuracy: 98.61%


                                                        

Epoch 31 Accuracy 98.54% / Best Accuracy: 98.61%


                                                        

Epoch 32 Accuracy 98.47% / Best Accuracy: 98.61%


                                                        

Epoch 33 Accuracy 98.41% / Best Accuracy: 98.61%


                                                        

Epoch 34 Accuracy 98.41% / Best Accuracy: 98.61%


                                                        

Epoch 35 Accuracy 98.41% / Best Accuracy: 98.61%


                                                        

Epoch 36 Accuracy 98.39% / Best Accuracy: 98.61%


                                                        

Epoch 37 Accuracy 98.33% / Best Accuracy: 98.61%


                                                        

Epoch 38 Accuracy 98.30% / Best Accuracy: 98.61%


                                                        

Epoch 39 Accuracy 97.40% / Best Accuracy: 98.61%


                                                        

Epoch 40 Accuracy 98.40% / Best Accuracy: 98.61%


                                                        

Epoch 41 Accuracy 98.55% / Best Accuracy: 98.61%


                                                        

Epoch 42 Accuracy 98.47% / Best Accuracy: 98.61%


                                                        

Epoch 43 Accuracy 98.42% / Best Accuracy: 98.61%


                                                        

Epoch 44 Accuracy 98.49% / Best Accuracy: 98.61%


                                                        

Epoch 45 Accuracy 98.55% / Best Accuracy: 98.61%


                                                        

Epoch 46 Accuracy 98.55% / Best Accuracy: 98.61%


                                                        

Epoch 47 Accuracy 98.53% / Best Accuracy: 98.61%


                                                        

Epoch 48 Accuracy 98.49% / Best Accuracy: 98.61%


                                                        

Epoch 49 Accuracy 98.56% / Best Accuracy: 98.61%


                                                        

Epoch 50 Accuracy 98.63% / Best Accuracy: 98.63%


                                                       

Accuracy of Loaded Model: 98.63




# Resolution Scaling

In [4]:
def train_with_loss_record(
    model:      nn.Module,
    dataloader: DataLoader,
    criterion:  nn.Module,
    optimizer:  Optimizer,
    scheduler:  "LambdaLR | None" = None,
) -> list:
    """Train ``model`` for one pass over ``dataloader`` and record every batch loss.

    Args:
        model: network to train; it is switched to train mode here.
        dataloader: iterable of ``(inputs, labels)`` batches; both are moved to the
            module-level ``device`` before the forward pass.
        criterion: loss applied as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        scheduler: any learning-rate scheduler (LambdaLR, CosineAnnealingLR, ...).
            It is stepped once per *batch*, right after the optimizer, so any horizon
            such as ``T_max`` must be measured in iterations, not epochs. Pass ``None``
            to leave scheduling to the caller (e.g. a per-epoch ``scheduler.step()``).

    Returns:
        The per-batch loss values (``loss.item()``) in iteration order.
    """
    model.train()
    running_loss = []

    for inputs, labels in tqdm(dataloader, desc="Train", leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # .item() copies the scalar out of the tensor so the history holds no autograd graph.
        running_loss.append(loss.item())

        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

    return running_loss

In [None]:
model_name  = "mobilenet_v3_small"
num_classes = ArrhythmiaLabels.size
num_finetune_epochs = 15
resolutions = [80, 104, 128, 152, 176, 200, 224]
# Per-resolution histories: one test accuracy per epoch, and one list of batch losses per epoch.
running_accuracy = {r: [] for r in resolutions}
running_loss = {r: [] for r in resolutions}

save_dir = "Pretrained/MobileNetV3-Small"
os.makedirs(save_dir, exist_ok=True)

for resolution in resolutions:
    # Fresh model per resolution so the runs are independent. Size the new head from the
    # replaced layer's in_features (1024 for MobileNetV3-Small) rather than a magic number.
    model = load_model(model_name)
    model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
    model.to(device)

    dataloader = build_dataloader(
        train_path = "Data/mitbih_mif_train_small.h5",
        test_path  = "Data/mitbih_mif_test.h5",
        batch_size = 128,
        transform  = Resize((resolution, resolution))
    )

    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    # train_with_loss_record calls scheduler.step() after every batch, so the cosine horizon must
    # be counted in optimizer steps. With T_max = num_finetune_epochs the LR would reach 0 after
    # 15 batches and then oscillate between 0 and the base LR instead of decaying once.
    # (Assumes the train loader defines __len__, i.e. a map-style dataset.)
    total_steps = num_finetune_epochs * len(dataloader["train"])
    scheduler = CosineAnnealingLR(optimizer, total_steps)
    criterion = nn.CrossEntropyLoss()

    best_model_checkpoint = dict()
    # Start below any reachable accuracy so a checkpoint is always recorded.
    best_accuracy = float("-inf")

    # Finetune
    print(f"Finetuning at resolution {resolution}")
    for epoch in range(num_finetune_epochs):
        losses = train_with_loss_record(model, dataloader["train"], criterion, optimizer, scheduler)
        accuracy = evaluate(model, dataloader["test"])

        running_accuracy[resolution].append(accuracy)
        running_loss[resolution].append(losses)

        if accuracy > best_accuracy:
            best_model_checkpoint["state_dict"] = copy.deepcopy(model.state_dict())
            best_accuracy = accuracy

        print(f"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%")

    # Save best model
    save_path = f"{save_dir}/{model_name}_ecg_ep{num_finetune_epochs}_i{resolution}.pth"
    torch.save(best_model_checkpoint["state_dict"], save_path)

    # Append this resolution's history to text files so it survives a kernel crash.
    # One line per resolution, in the order of `resolutions`.
    with open("running_acc.txt", "a") as file:
        file.write(" ".join(map(str, running_accuracy[resolution])))
        file.write("\n")

    with open("running_loss.txt", "a") as file:
        total_losses = [loss for epoch_losses in running_loss[resolution] for loss in epoch_losses]
        file.write(" ".join(map(str, total_losses)))
        file.write("\n")

# OFA (Once-for-All) specialized subnets

In [None]:
# Fetch the Once-for-All factory for its pre-specialized subnets via PyTorch Hub.
# The returned callable is used below as `model, image_size = ofa_specialized_get(name, pretrained=True)`.
ofa_specialized_get = torch.hub.load(
    repo_or_dir="mit-han-lab/once-for-all",
    model="ofa_specialized_get",
)

In [12]:
# OFA also returns the resolution its subnet was specialized for. We finetune at our own
# resolution, so keep OFA's value separate. Overwriting `image_size` with it would mislabel
# the checkpoint filename built from `image_size` when the model is saved.
model, ofa_image_size = ofa_specialized_get("pixel1_lat@20ms_top1@71.4_finetune@25", pretrained=True)

image_size = 152
dataloader = build_dataloader(
    train_path = "Data/mitbih_mif_train_small.h5",
    test_path  = "Data/mitbih_mif_test.h5",
    batch_size = 128,
    transform  = Resize((image_size, image_size))
)

# Swap the ImageNet head for one sized to the arrhythmia label set, reading the input width
# (1280 for this subnet) from the layer being replaced instead of hard-coding it.
model.classifier.linear = nn.Linear(model.classifier.linear.in_features, ArrhythmiaLabels.size)

In [None]:
model_name = "ofa_pixel1_20"
num_classes = ArrhythmiaLabels.size

# --- Finetuning setup ---
num_finetune_epochs = 50
model.to(device)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# NOTE(review): T_max is in epochs; confirm that Utils' `train` steps the scheduler once per
# epoch rather than once per batch.
scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

best_accuracy = 0
best_model_checkpoint = {}

print("Finetuning")
for epoch_idx in range(num_finetune_epochs):
    train(model, dataloader["train"], criterion, optimizer, scheduler)
    accuracy = evaluate(model, dataloader["test"])

    improved = accuracy > best_accuracy
    if improved:
        best_accuracy = accuracy
        # Snapshot the weights; state_dict() alone would alias tensors that keep training.
        best_model_checkpoint["state_dict"] = copy.deepcopy(model.state_dict())

    epoch_no = epoch_idx + 1
    print(f"Epoch {epoch_no} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%")

# --- Persist the best weights seen ---
# NOTE(review): `image_size` comes from an earlier cell; check it matches the resolution the
# dataloader actually resizes to, since it is baked into the filename.
os.makedirs("Pretrained", exist_ok=True)
save_path = f"Pretrained/{model_name}_ecg_ep{num_finetune_epochs}_i{image_size}.pth"
torch.save(best_model_checkpoint["state_dict"], save_path)

In [None]:
# Test saved model
# NOTE(review): this rebuilds the flops@595M subnet and loads ofa_595M_ecg_ep50_i152.pth.
# That is not the pixel1 checkpoint written by the training cell above; confirm this is intended.
model, _ = ofa_specialized_get("flops@595M_top1@80.0_finetune@75", pretrained=True)
# Head sized like the one used for training: input width read from the replaced layer
# (1536 for this subnet), one output per arrhythmia class.
model.classifier.linear = nn.Linear(model.classifier.linear.in_features, ArrhythmiaLabels.size)
# Forward slashes work on every OS, including Windows. They also avoid the invalid "\o"
# escape sequence in the old backslash path. map_location lets a GPU-saved checkpoint load
# on a CPU-only machine.
state_dict = torch.load("Pretrained/ofa_595M_ecg_ep50_i152.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
acc = evaluate(model, dataloader["test"])
print(f"Accuracy of Loaded Model: {acc:.2f}")