# Setup

In [1]:
import copy
import os
import random

import h5py

import torch
import torch.nn as nn

from torchvision.models import *
from torchvision.transforms import *
from torch.utils.data import DataLoader, Dataset
from torch.optim import *
from torch.optim.lr_scheduler import *

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from utils import *

# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

# Seed every RNG this notebook draws from (Python, NumPy, PyTorch) so runs repeat.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


Using cuda:0


<torch._C.Generator at 0x2bb3620fb30>

In [2]:
class ArrhythmiaLabels:
    """Mapping from integer class index to heartbeat class code.

    The codes N/S/V/F/Q are presumably the AAMI heartbeat classes used with
    MIT-BIH data (the data files are named ``mitbih_*``) -- confirm against the
    dataset's labelling script.
    """

    # class index -> class code
    labels = {
        0: "N",
        1: "S",
        2: "V",
        3: "F",
        4: "Q",
    }
    # Number of classes; derived from `labels` so the two can never disagree.
    size = len(labels)

class EcgDataset(Dataset):
    """Map-style dataset over ECG images and labels stored in one HDF5 file.

    The file must contain an ``images`` dataset and a ``labels`` dataset that
    are indexed in parallel. The file handle stays open for the lifetime of the
    dataset and is released by :meth:`close`, which ``__del__`` also calls.

    Args:
        file_path: Path to the HDF5 file.
        transform: Optional callable applied to each image (after the axis
            transpose and float32 cast in ``__getitem__``).
    """

    def __init__(self, file_path: str, transform=None) -> None:
        super().__init__()

        self.h5_file   = h5py.File(file_path, "r")
        self.images    = self.h5_file["images"]
        self.labels    = self.h5_file["labels"]
        self.transform = transform

    def close(self) -> None:
        """Close the underlying HDF5 file. Safe to call more than once."""
        # getattr guard: if __init__ failed before `h5_file` was assigned (e.g. a
        # bad path), __del__ still runs on the half-built object.
        h5_file = getattr(self, "h5_file", None)
        if h5_file is not None:
            h5_file.close()
            self.h5_file = None

    def __del__(self):
        self.close()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is cast to float32 with its leading axis moved last, then
        passed through ``transform`` when one is set. ``label`` is an int64
        NumPy array (0-d for scalar labels).
        """
        # Move the leading axis to the end: (d0, d1, d2) -> (d1, d2, d0).
        # Presumably the file stores channel-first (C, H, W) images and this
        # yields the channel-last (H, W, C) layout ToTensor expects -- TODO confirm.
        image = np.transpose(self.images[idx], (1, 2, 0))
        image = np.array(image, dtype=np.float32)

        label = np.asarray(self.labels[idx]).astype(np.int64)

        if self.transform:
            image = self.transform(image)

        return image, label
    
def build_dataloader(train_path: str, test_path: str, batch_size: int, transform) -> dict[str, DataLoader]:
    """Create ``{"train": ..., "test": ...}`` DataLoaders over two HDF5 ECG files.

    Every sample goes through ``ToTensor()`` followed by ``transform``; both
    splits get the same preprocessing. Only the train loader shuffles. Batches
    are loaded in the main process (``num_workers=0``) into pinned memory.
    """
    split_paths = {"train": train_path, "test": test_path}

    # One EcgDataset per split, each with its own (identical) preprocessing pipeline.
    split_datasets = {
        split: EcgDataset(path, transform=Compose([ToTensor(), transform]))
        for split, path in split_paths.items()
    }

    return {
        split: DataLoader(
            split_dataset,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=0,
            pin_memory=True,
        )
        for split, split_dataset in split_datasets.items()
    }

def visualize_ecg_data(dataloader: DataLoader) -> None:
    """Plot every channel of one random image from the loader's first batch.

    Draws one grayscale subplot per channel (dim 0 of the image tensor), side
    by side. The sample is picked with the module-level ``random`` state.
    """
    images, _ = next(iter(dataloader))
    image = images[random.randrange(len(images))]
    channels = torch.split(image, 1, dim=0)

    # One axis per channel instead of a hard-coded 3; squeeze=False keeps
    # `axes` 2-D even for a single channel so indexing is uniform.
    _, axes = plt.subplots(1, len(channels), squeeze=False)
    for ax, channel in zip(axes[0], channels):
        ax.imshow(channel.squeeze(0).numpy(), cmap="gray")
        ax.axis("off")
    plt.show()


In [3]:
def load_model(name: str) -> nn.Module:
    match name:
        case "alexnet":
            return alexnet(weights=AlexNet_Weights.DEFAULT)
        case "resnet18":
            return resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        case "resnet50":
            return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        case "vgg11_bn":
            return vgg11_bn(weights=VGG11_BN_Weights.IMAGENET1K_V1)
        case "vgg16_bn":
            return vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)       
        case "vit_b_16":
            return vit_b_16(weights=ViT_B_16_Weights.DEFAULT) 
        case "mobilenet_v3":
            return mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.IMAGENET1K_V2)
        case "mobilenet_v3_small":
            return mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
    
def load_model_from_pretrained(name: str, path: str, num_classes: int) -> nn.Module:
    model = None
    match name:
        case "alexnet":
            model = alexnet()
            model.classifier[4] = nn.Linear(4096, 512)
            model.classifier[6] = nn.Linear(512, num_classes)

        case "resnet18":
            model = resnet18()
            model.fc = nn.Linear(512, num_classes)

        case "resnet50":
            model = resnet50()
            model.fc = nn.Linear(2048, num_classes)
            
        case "vgg11_bn":
            model = vgg11_bn()
            model.classifier[6] = nn.Linear(4096, num_classes)

        case "vgg16_bn":
            model = vgg16_bn()
            model.classifier[6] = nn.Linear(4096, num_classes)

        case "vit_b_16":
            model = vit_b_16()
            model.heads.head = nn.Linear(768, num_classes)
        
        case "mobilenet_v3":
            model = mobilenet_v3_large()
            model.classifier[3] = nn.Linear(1280, num_classes)

        case "mobilenet_v3_small":
            model = mobilenet_v3_small()
            model.classifier[3] = nn.Linear(1024, num_classes)

    model.load_state_dict(torch.load(path))
    return model

# Training Different Models

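- **Important:** before training, replace the model's output (classification) layer with one that has `ArrhythmiaLabels.size` outputs. Use the same replacement that `load_model_from_pretrained` builds; otherwise the saved checkpoint will not load back into the rebuilt model.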

In [9]:
# Train/test loaders with every image resized to 152x152. This resolution is
# what the "_i152" tag in the saved checkpoint name refers to.
dataloader = build_dataloader(
    "Data/mitbih_mif_train_small.h5",
    "Data/mitbih_mif_test.h5",
    32,
    Resize((152, 152)),
)

# To preview one random training sample channel by channel:
# visualize_ecg_data(dataloader["train"])

In [None]:
# Fine-tune a pretrained torchvision model on the ECG images, keeping the
# weights from the epoch with the best test accuracy.
model_name = "vgg16_bn"
model = load_model(model_name)

# Replace classification layer output: VGG's last classifier layer takes 4096
# features. Swap it for one with an output per arrhythmia class; this must match
# the head load_model_from_pretrained builds, or the reload below fails.
num_classes = ArrhythmiaLabels.size
model.classifier[6] = nn.Linear(4096, num_classes)

# Finetuning
num_finetune_epochs = 50
model.to(device)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# NOTE(review): T_max is counted in epochs here. If `train` (from utils) steps
# the scheduler once per batch, as train_with_loss_record does, the cosine
# schedule repeats every 2 * T_max batches. Confirm, and if so scale T_max by
# len(dataloader["train"]).
scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

best_model_checkpoint = dict()
# -inf (not 0) guarantees the first epoch is checkpointed, so the save below
# cannot raise KeyError when no epoch scores above 0%.
best_accuracy = float("-inf")

print("Finetuning")
for epoch in range(num_finetune_epochs):
    train(model, dataloader["train"], criterion, optimizer, scheduler)
    accuracy = evaluate(model, dataloader["test"])

    if accuracy > best_accuracy:
        # Deep copy: state_dict() references the live parameters, which later
        # epochs keep updating.
        best_model_checkpoint["state_dict"] = copy.deepcopy(model.state_dict())
        best_accuracy = accuracy

    print(f"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%")

# Save best model ("i152" must match the Resize used to build `dataloader`).
save_path = f"Pretrained/{model_name}_ecg_ep{num_finetune_epochs}_i152.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(best_model_checkpoint["state_dict"], save_path)

# Test saved model: rebuild the architecture, load the checkpoint, re-evaluate.
model = load_model_from_pretrained(model_name, save_path, num_classes=num_classes)
model.to(device)
acc = evaluate(model, dataloader["test"])
print(f"Accuracy of Loaded Model: {acc:.2f}")

# Resolution Scaling

In [4]:
def train_with_loss_record(
    model:      nn.Module,
    dataloader: DataLoader,
    criterion:  nn.Module,
    optimizer:  Optimizer,
    scheduler:  LambdaLR
) -> list:
    """Run one training epoch and return the loss of every batch, in order.

    Puts ``model`` in train mode, moves each batch to the module-level
    ``device``, and calls ``optimizer.step()`` and then ``scheduler.step()``
    after every batch. The scheduler therefore advances once per batch, not
    once per epoch.

    Returns:
        A list of Python floats, one loss value per batch.
    """
    model.train()
    batch_losses = []

    progress = tqdm(dataloader, desc="Train", leave=False)
    for batch_inputs, batch_labels in progress:
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_losses.append(batch_loss.item())
        batch_loss.backward()

        optimizer.step()
        scheduler.step()

    return batch_losses

In [None]:
# Resolution-scaling sweep. At each input resolution, fine-tune a fresh
# pretrained MobileNetV3-Small, record per-epoch test accuracy and per-batch
# training loss, and save the best checkpoint.
model_name  = "mobilenet_v3_small"
num_classes = ArrhythmiaLabels.size
num_finetune_epochs = 15
resolutions = [80, 104, 128, 152, 176, 200, 224]
running_accuracy = {r: [] for r in resolutions}  # resolution -> [test accuracy per epoch]
running_loss = {r: [] for r in resolutions}      # resolution -> [[batch losses] per epoch]

for resolution in resolutions:
    # Setup training: fresh pretrained weights with a head sized to our classes
    # (same 1024-input head that load_model_from_pretrained builds).
    model = load_model(model_name)
    model.classifier[3] = nn.Linear(1024, num_classes)
    model.to(device)

    dataloader = build_dataloader(
        train_path = "Data/mitbih_mif_train_small.h5",
        test_path  = "Data/mitbih_mif_test.h5",
        batch_size = 128,
        transform  = Resize((resolution, resolution))
    )

    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    # train_with_loss_record calls scheduler.step() after every batch, so the
    # cosine period must be counted in batches. With T_max = num_finetune_epochs
    # the LR would swing between its max and 0 every 2 * num_finetune_epochs
    # batches instead of decaying once over the run.
    total_steps = num_finetune_epochs * len(dataloader["train"])
    scheduler = CosineAnnealingLR(optimizer, total_steps)
    criterion = nn.CrossEntropyLoss()

    best_model_checkpoint = dict()
    # -inf guarantees at least one checkpoint, so the save below cannot KeyError.
    best_accuracy = float("-inf")

    # Finetune
    print(f"Finetuning at resolution {resolution}")
    for epoch in range(num_finetune_epochs):
        losses = train_with_loss_record(model, dataloader["train"], criterion, optimizer, scheduler)
        accuracy = evaluate(model, dataloader["test"])

        running_accuracy[resolution].append(accuracy)
        running_loss[resolution].append(losses)

        if accuracy > best_accuracy:
            best_model_checkpoint["state_dict"] = copy.deepcopy(model.state_dict())
            best_accuracy = accuracy

        print(f"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%")

    # Save best model (create the nested directory so a long run can't fail at the end).
    save_path = f"Pretrained/MobileNetV3-Small/{model_name}_ecg_ep{num_finetune_epochs}_i{resolution}.pth"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(best_model_checkpoint["state_dict"], save_path)

    # Append this resolution's curves to text files (one line per resolution)
    # so results survive a kernel crash mid-sweep.
    with open("running_acc_2.txt", "a") as file:
        file.write(" ".join(map(str, running_accuracy[resolution])))
        file.write("\n")

    with open("running_loss_2.txt", "a") as file:
        # Flatten [[epoch-1 batch losses], [epoch-2 ...], ...] into one sequence.
        total_losses = [loss for epoch_losses in running_loss[resolution] for loss in epoch_losses]
        file.write(" ".join(map(str, total_losses)))
        file.write("\n")

# OFA (Once-for-All) specialized subnets

In [4]:
# Fetch OFA's `ofa_specialized_get` entry point from torch.hub (served from the
# local hub cache after the first download). Given a subnet config name, it
# returns `(model, image_size)`.
ofa_specialized_get = torch.hub.load(
    "mit-han-lab/once-for-all",
    "ofa_specialized_get",
)

Using cache found in C:\Users\Tavonput Luangphasy/.cache\torch\hub\mit-han-lab_once-for-all_master


In [12]:
# Pretrained OFA subnet chosen by config name (Pixel 1, 20 ms latency target,
# per the name). `image_size` is presumably the resolution the subnet was
# specialized for -- note the data below is resized to 152 regardless.
model, image_size = ofa_specialized_get("pixel1_lat@20ms_top1@71.4_finetune@25", pretrained=True)

dataloader = build_dataloader(
    train_path = "Data/mitbih_mif_train_small.h5",
    test_path  = "Data/mitbih_mif_test.h5",
    batch_size = 128,
    transform  = Resize((152, 152))
)

# Replace the classifier's linear layer with one output per arrhythmia class.
# Reuse the existing layer's input width instead of hard-coding it: it differs
# between OFA subnets (1280 here, 1536 for the 595M-FLOPs net used later).
model.classifier.linear = nn.Linear(model.classifier.linear.in_features, ArrhythmiaLabels.size)

In [None]:
# Fine-tune the OFA subnet prepared above, keeping the best-test-accuracy weights.
model_name  = "ofa_pixel1_20"
num_classes = ArrhythmiaLabels.size
# Resolution the data was actually resized to when `dataloader` was built
# (Resize((152, 152))). This goes into the checkpoint name; OFA's `image_size`
# is not the resolution the network was trained on here.
train_resolution = 152

# Finetuning
num_finetune_epochs = 50
model.to(device)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# NOTE(review): T_max is counted in epochs. If `train` (from utils) steps the
# scheduler per batch, as train_with_loss_record does, scale T_max by
# len(dataloader["train"]) -- confirm.
scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

best_model_checkpoint = dict()
# -inf guarantees at least one checkpoint, so the save below cannot KeyError.
best_accuracy = float("-inf")

print("Finetuning")
for epoch in range(num_finetune_epochs):
    train(model, dataloader["train"], criterion, optimizer, scheduler)
    accuracy = evaluate(model, dataloader["test"])

    if accuracy > best_accuracy:
        best_model_checkpoint["state_dict"] = copy.deepcopy(model.state_dict())
        best_accuracy = accuracy

    print(f"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%")

# Save best model
save_path = f"Pretrained/{model_name}_ecg_ep{num_finetune_epochs}_i{train_resolution}.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(best_model_checkpoint["state_dict"], save_path)

In [None]:
# Test saved model: rebuild the OFA subnet, load its fine-tuned weights, evaluate.
# NOTE(review): this rebuilds the "flops@595M" subnet and loads
# "ofa_595M_ecg_ep50_i152.pth", not the "ofa_pixel1_20" checkpoint saved by the
# previous cell -- confirm which checkpoint is intended.
model, _ = ofa_specialized_get("flops@595M_top1@80.0_finetune@75", pretrained=True)
model.classifier.linear = nn.Linear(model.classifier.linear.in_features, 5)
# Forward slashes work on every OS. The backslash form was Windows-only, and
# "\o" is an invalid escape (SyntaxWarning on Python 3.12+). map_location lets a
# GPU-saved checkpoint load anywhere; the model is moved to `device` next.
model.load_state_dict(torch.load("Pretrained/ofa_595M_ecg_ep50_i152.pth", map_location="cpu"))
model.to(device)
acc = evaluate(model, dataloader["test"])
print(f"Accuracy of Loaded Model: {acc:.2f}")