In [None]:
# torch and torchvision imports
import torch
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from ResnetModel import *
from transformer import *
# TensorBoard writer; a later training cell logs loss, learning rate and AUC through it.
writer = SummaryWriter()
# Colab-only Google Drive mount, left disabled.
# from google.colab import drive
# drive.mount('/gdrive')
# Make float32 the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float32)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")


### Loading and Preparing Data

In [None]:
import os

In [None]:
# List the parent directory to check that the dataset folder used below is present.
os.listdir('../')

In [None]:
# Root folder holding the pre-split arrays, one .npz archive per (task, split) pair.
path = '../ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1/Datasets/'


def _load_xy(npz_path):
    """Return the ``'x'`` and ``'y'`` arrays stored in one ``.npz`` archive.

    The archive is opened exactly once and closed again as soon as both
    arrays have been read, instead of being reopened per array and left open.

    Args:
        npz_path: Path of the ``.npz`` file to read.

    Returns:
        Tuple ``(x, y)`` of the arrays stored under the keys ``'x'`` and ``'y'``.
    """
    with np.load(npz_path) as archive:
        return archive['x'], archive['y']


# Training split: inputs (X_*) and targets (Y_*) for each of the six tasks.
X_train_form, Y_train_form = _load_xy(path+'FormTrain.npz')
X_train_rhythm, Y_train_rhythm = _load_xy(path+'RhythmTrain.npz')
X_train_CD, Y_train_CD = _load_xy(path+'CDTrain.npz')
X_train_HYP, Y_train_HYP = _load_xy(path+'HYPTrain.npz')
X_train_MI, Y_train_MI = _load_xy(path+'MITrain.npz')
X_train_STTC, Y_train_STTC = _load_xy(path+'STTCTrain.npz')

# Validation split, same layout as above.
X_val_form, Y_val_form = _load_xy(path+'FormVal.npz')
X_val_rhythm, Y_val_rhythm = _load_xy(path+'RhythmVal.npz')
X_val_CD, Y_val_CD = _load_xy(path+'CDVal.npz')
X_val_HYP, Y_val_HYP = _load_xy(path+'HYPVal.npz')
X_val_MI, Y_val_MI = _load_xy(path+'MIVal.npz')
X_val_STTC, Y_val_STTC = _load_xy(path+'STTCVal.npz')



In [None]:
def _as_dataset(features, targets):
    """Wrap a pair of NumPy arrays in a TensorDataset of (input, target) tensors."""
    return torch.utils.data.TensorDataset(torch.from_numpy(features), torch.from_numpy(targets))


# Training datasets, one per task.
form_train_dataset = _as_dataset(X_train_form, Y_train_form)
rhythm_train_dataset = _as_dataset(X_train_rhythm, Y_train_rhythm)
CD_train_dataset = _as_dataset(X_train_CD, Y_train_CD)
HYP_train_dataset = _as_dataset(X_train_HYP, Y_train_HYP)
MI_train_dataset = _as_dataset(X_train_MI, Y_train_MI)
STTC_train_dataset = _as_dataset(X_train_STTC, Y_train_STTC)

# Validation datasets, one per task.
form_val_dataset = _as_dataset(X_val_form, Y_val_form)
rhythm_val_dataset = _as_dataset(X_val_rhythm, Y_val_rhythm)
CD_val_dataset = _as_dataset(X_val_CD, Y_val_CD)
HYP_val_dataset = _as_dataset(X_val_HYP, Y_val_HYP)
MI_val_dataset = _as_dataset(X_val_MI, Y_val_MI)
STTC_val_dataset = _as_dataset(X_val_STTC, Y_val_STTC)

# Drop the array names from the notebook namespace; the datasets keep the data alive.
del (
    X_train_form, X_train_rhythm, X_train_CD, X_train_HYP, X_train_MI, X_train_STTC,
    Y_train_form, Y_train_rhythm, Y_train_CD, Y_train_HYP, Y_train_MI, Y_train_STTC,
    X_val_form, X_val_rhythm, X_val_CD, X_val_HYP, X_val_MI, X_val_STTC,
    Y_val_form, Y_val_rhythm, Y_val_CD, Y_val_HYP, Y_val_MI, Y_val_STTC,
)

In [None]:
# Read the pickled `category` mapping and derive, per key, how many entries it holds.
# NOTE(review): pickle.load runs arbitrary code from the file — only use trusted data.
with open(path+'category.pickle', 'rb') as category_file:
    category = pickle.load(category_file)

num_classes = []
models = []  # holds the category keys here; a later cell rebinds `models` to the networks
for key in category.keys():
    models.append(key)
    num_classes.append(len(category[key]))

print(num_classes)
print(models)

### Creating the ResNet Models

In [None]:
def _build_resnet(n_outputs):
    """Create one Bottleneck ResNet with `n_outputs` outputs, on `device`, in float32."""
    return ResNet(Bottleneck, [3, 4, 23, 3], n_outputs).to(device).float()


# One network per task, built in the same order as before.
# NOTE(review): assumes num_classes is ordered form, rhythm, STTC, MI, HYP, CD —
# that order comes from the keys of `category`; confirm against the printed list.
form_model, rhythm_model, STTC_model, MI_model, HYP_model, CD_model = (
    _build_resnet(num_classes[position]) for position in range(6)
)

In [None]:
def _make_loader(dataset):
    """Wrap a dataset in a shuffling DataLoader that yields batches of 10."""
    return torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)


# Training loaders (the STTC loader is spelled `STCC_...`; other cells use that name).
form_train_loader = _make_loader(form_train_dataset)
rhythm_train_loader = _make_loader(rhythm_train_dataset)
STCC_train_loader = _make_loader(STTC_train_dataset)
MI_train_loader = _make_loader(MI_train_dataset)
HYP_train_loader = _make_loader(HYP_train_dataset)
CD_train_loader = _make_loader(CD_train_dataset)

# Validation loaders, built with the same settings (including shuffling).
form_val_loader = _make_loader(form_val_dataset)
rhythm_val_loader = _make_loader(rhythm_val_dataset)
STCC_val_loader = _make_loader(STTC_val_dataset)
MI_val_loader = _make_loader(MI_val_dataset)
HYP_val_loader = _make_loader(HYP_val_dataset)
CD_val_loader = _make_loader(CD_val_dataset)



In [None]:
# Parallel lists: position k in every list below refers to the same task.
model_names = ['form', 'rhythm', 'STTC', 'MI', 'HYP', 'CD']
models = [form_model, rhythm_model, STTC_model, MI_model, HYP_model, CD_model]
train_loaders = [form_train_loader, rhythm_train_loader, STCC_train_loader, MI_train_loader, HYP_train_loader, CD_train_loader]
val_loaders = [form_val_loader, rhythm_val_loader, STCC_val_loader, MI_val_loader, HYP_val_loader, CD_val_loader]
n_models = len(models)

# One multilabel AUROC metric per task, sized by that task's label count.
metrics = [MultilabelAUROC(num_labels=num_classes[k]) for k in range(n_models)]
form_metric, rhythm_metric, STTC_metric, MI_metric, HYP_metric, CD_metric = metrics

# One binary cross-entropy loss per task.
criterions = [nn.BCELoss() for _ in range(n_models)]
form_criterion, rhythm_criterion, STTC_criterion, MI_criterion, HYP_criterion, CD_criterion = criterions

# One Adam optimizer per network; the initial lr given here is overwritten at the end of this cell.
optimizers = [
    torch.optim.Adam(network.parameters(), lr=0.0001, weight_decay=1e-4)
    for network in models
]
form_optimizer, rhythm_optimizer, STTC_optimizer, MI_optimizer, HYP_optimizer, CD_optimizer = optimizers

# Per-task history buffers filled during training.
train_losses = [[] for _ in range(n_models)]
val_losses = [[] for _ in range(n_models)]
train_aurocs = [[] for _ in range(n_models)]
val_aurocs = [[] for _ in range(n_models)]
lr_records = [[] for _ in range(n_models)]

# Learning-rate schedule settings.
lr_max = 0.00035
lr = lr_max
epochs = 10

# Per-task step counters and current learning rates.
ts = np.zeros(n_models, dtype=np.int32)
lrs = np.zeros(n_models, dtype=np.float32)
# Schedule lengths in optimizer steps: whole run (T_maxs) and its first fifth (T_0s).
steps_per_epoch = list(map(len, train_loaders))
T_maxs = [n_steps*epochs for n_steps in steps_per_epoch]
T_0s = [total_steps//5 for total_steps in T_maxs]

# Start every optimizer at the peak learning rate.
for optimizer in optimizers:
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_max



            


In [None]:
# Train the six task networks side by side: in every epoch each network does one
# full pass over its training loader followed by one pass over its validation loader.
for epoch in range(epochs):
    # `model` is the integer position into the parallel lists; `model_i` is the network.
    for model in range(len(models)):
        model_i = models[model]
        optimizer = optimizers[model]
        criterion = criterions[model]
        train_loader = train_loaders[model]
        val_loader = val_loaders[model]
        metric = metrics[model]
        T_0 = T_0s[model]
        T_max = T_maxs[model]
        # Log about ten times per epoch. The original test,
        # `(i+1) % len(train_loader)//10 == 0`, parsed as `((i+1) % len)//10 == 0`
        # and fired on the first nine steps and the last one instead.
        # max(1, ...) keeps the interval non-zero for loaders shorter than ten batches.
        log_every = max(1, len(train_loader)//10)

        # Put the network in training mode for the training pass.
        model_i.train()
        for i, (signal, labels) in enumerate(train_loader):
            # Swap axes 1 and 2 before the forward pass
            # (assumes batches arrive as (batch, time, channels) — TODO confirm).
            signal = signal.transpose(1,2).to(device)
            # Cast targets to float, as the other training cell does before its loss.
            labels = labels.float().to(device)
            optimizer.zero_grad()
            outputs = model_i(signal)
            loss = criterion(outputs, labels)
            loss.backward()
            ts[model] += 1
            train_losses[model].append(loss.item())
            # Schedule: linear warm-up over the first T_0 steps (1e-4 up to 1e-4 + lr_max),
            # then a quarter-cosine decay that reaches 1e-6 at step T_max.
            if ts[model] <= T_0:
                lrs[model] = 10**(-4) + (ts[model]/T_0)*lr_max
            else:
                lrs[model] = lr_max*np.cos((np.pi/2)*((ts[model]-T_0)/(T_max-T_0))) + 10**(-6)

            lr_records[model].append(lrs[model])
            # Apply the new learning rate before taking this step.
            for g in optimizer.param_groups:
                g['lr'] = lrs[model]

            optimizer.step()

            if (i+1) % log_every == 0:
                # Report this loader's own step count rather than the whole steps_per_epoch list.
                print(f'Model: {model_names[model]}, Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}, AUROC: {metric(outputs, (labels>0).int()):.4f}')

        test_auc = 0
        test_loss = 0
        # Evaluation mode for the validation pass; train() above restores training mode.
        model_i.eval()
        with torch.no_grad():
            for signal, labels in val_loader:
                # Feed validation batches exactly like training batches. The original
                # cropped axis 1 to 200 samples and skipped the transpose, so the network
                # saw a different layout than it was trained on.
                # NOTE(review): confirm that full-length evaluation is the intended protocol.
                signal = signal.transpose(1,2).to(device)
                labels = labels.float().to(device)
                outputs = model_i(signal)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                # float() keeps the running sum (and val_aurocs) as plain Python numbers.
                test_auc += float(metric(outputs, (labels>0).int()))
            val_losses[model].append(test_loss/len(val_loader))
            val_aurocs[model].append(test_auc/len(val_loader))
        print(f'Model: {model_names[model]}, Epoch [{epoch+1}/{epochs}], Val Loss: {test_loss/len(val_loader):.4f}, Val AUROC: {test_auc/len(val_loader):.4f}')

In [None]:
# Debug output: show the tensor currently bound to `signal` (assigned inside the loops above).
signal

In [None]:
# Single-model training run with TensorBoard logging.
# NOTE(review): this cell relies on kernel state it does not create: `model`,
# `optimizer`, `criterion`, `train_loader`, `test_loader` and `ml_auroc`.
# `test_loader` and `ml_auroc` are not defined in any visible cell, and the
# multi-model training cell leaves `model` bound to an integer index.
# Confirm which objects are intended before running.
# It also rebinds `steps_per_epoch` (to an int) and `train_losses` (to a flat list).
t = 0
steps_per_epoch = len(train_loader)
T_max = steps_per_epoch*epochs
T_0 = T_max/5
learning_rates = []
train_losses = []
# Log about ten times per epoch; max(1, ...) avoids a modulo-by-zero when the
# loader has fewer than ten batches.
log_every = max(1, len(train_loader)//10)

for epoch in range(epochs):
    # Back to training mode at the start of each epoch (evaluation below switches it off).
    model.train()
    for i, (signal, labels) in enumerate(train_loader):
        # Train on a random 200-sample window taken along the last axis.
        idx = np.random.randint(0, 1000-200)
        signal_sample = (signal[:, :, idx:idx+200]).to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(signal_sample)
        loss = criterion(outputs, labels.float())
        loss.backward()
        # Clip the total gradient norm to 5 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), 5)
        # Schedule: linear warm-up for the first T_0 steps, then a quarter-cosine
        # decay that reaches 1e-6 at step T_max.
        if t <= T_0:
            lr = 10**(-4) + (t/T_0)*lr_max
        else:
            lr = lr_max*np.cos((np.pi/2)*((t-T_0)/(T_max-T_0))) + 10**(-6)

        for g in optimizer.param_groups:
            g['lr'] = lr
        loss_value = loss.item()
        learning_rates.append(lr)
        train_losses.append(loss_value)
        optimizer.step()
        t+=1

        train_AUC = ml_auroc(outputs, labels.int())
        writer.add_scalar("Train_Loss", loss_value, t)
        writer.add_scalar("Learning rate", lr, t)
        writer.add_scalar("Batch Train AUC", train_AUC, t)

        if i % log_every == 0:
            print(f"Step: {i+1}/{len(train_loader)}  |  Train loss: {loss_value:.4f}  |  Train AUC: {train_AUC:.4f}")

    # Evaluate in eval mode. The original left `model.eval()` commented out, so this
    # pass ran in training mode; the `model.train()` call above makes enabling it safe.
    model.eval()
    test_auc = 0
    with torch.no_grad():
        for i, (signal, labels) in enumerate(test_loader):
            idx = np.random.randint(0, 1000-200)
            signal = (signal[:, :, idx:idx+200]).to(device)
            labels = labels.to(device)
            outputs = model(signal)
            test_auc += ml_auroc(outputs, labels.int())
        test_auc /= len(test_loader)
    writer.add_scalar("Test AUC", test_auc, epoch)

In [None]:
# Debug output: shape of the tensor currently bound to `signal`.
signal.shape

In [None]:
# Persist whatever object `model` currently refers to by pickling it whole.
# NOTE(review): after the multi-model training cell `model` is an integer index,
# not a network — confirm the intended object before saving.
with open('superclassresnetmodel.pickle', 'wb') as f:
    pickle.dump(model, f)

In [None]:
import pickle, matplotlib.pyplot as plt, torch
# Reload a saved loss history from disk; this rebinds `train_losses`.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open('./modelres/SuperClasslosses.pickle', 'rb') as f:
    train_losses = pickle.load(f)

In [None]:
# Number of top-level entries in the loaded history.
len(train_losses)

In [None]:
# Draw the first three loaded series next to each other, one panel per series.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))
left_ax, middle_ax, right_ax = axes
left_ax.plot(train_losses[0])
# Only the second series is converted to a tensor before plotting.
middle_ax.plot(torch.tensor(train_losses[1]))
right_ax.plot(train_losses[2])
plt.show()



In [None]:
# Mean of the last 2000 values of the first series.
torch.tensor(train_losses[0][-2000:]).mean()

In [None]:
import torch
# Debug output: device of the tensor built from the second series.
torch.tensor(train_losses[1]).device

In [None]:
# Debug output: display the third series in full.
train_losses[2]

In [None]:
# Debug output: display whatever `model` currently refers to.
model