In [1]:
import subprocess
import numpy as np
import os
import sys
import pandas as pd
import h5py
import matplotlib.pyplot as plt
from typing import List, Tuple
from torch.utils.data import Dataset
import torch
from torch.utils import data
from tqdm.notebook import tqdm
import torch.nn as nn

import sklearn
from sklearn.metrics import confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report
from sklearn.metrics import precision_recall_curve, f1_score
from sklearn.metrics import average_precision_score

from model import CLIP
import clip
import sklearn
from sklearn.metrics import confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report
from sklearn.metrics import precision_recall_curve, f1_score
from sklearn.metrics import average_precision_score


class ECGDataset(Dataset):
    """Map-style dataset pairing ECG signals with labels, both loaded from .npy files.

    Each item is a dict ``{'ecg': float32 tensor, 'label': int64 tensor}``; the ECG
    values are multiplied by ``scale`` (200 by default, as originally hard-coded).
    """

    def __init__(self, data_path, label_path, scale=200):
        """
        Args:
            data_path (string): Path to the .npy file holding the ECG data.
            label_path (string): Path to the .npy file holding the labels.
            scale (number): Factor applied to every ECG sample in ``__getitem__``.
                Defaults to 200. NOTE(review): presumably a unit conversion the
                model was trained with — confirm.

        Raises:
            AssertionError: if data and labels disagree on the number of samples.
        """
        self.data = np.load(data_path)
        labels = np.load(label_path)
        # Drop singleton axes (e.g. an (N, 1) column -> (N,)) but never the sample
        # axis 0: a plain ``squeeze()`` turns a one-sample file into a 0-d array
        # (so ``shape[0]`` below raises) and a (1, k) file into k "samples".
        extra_singletons = tuple(
            axis for axis in range(1, labels.ndim) if labels.shape[axis] == 1
        )
        self.labels = np.squeeze(labels, axis=extra_singletons) if extra_singletons else labels
        self.scale = scale

        assert self.data.shape[0] == self.labels.shape[0], \
            "Data and labels must have the same number of samples!"

    def __len__(self):
        """Number of samples (size of the first axis of the data array)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'ecg': scaled float32 tensor, 'label': int64 tensor}``."""
        ecg = torch.tensor(self.data[idx], dtype=torch.float32) * self.scale
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        sample = {'ecg': ecg, 'label': label}
        return sample

def load_clip(model_path, pretrained=False, context_length=77, device=None):
    """
    FUNCTION: load_clip
    ---------------------------------
    Build a CLIP model and load the weights stored at ``model_path`` into it.

    args:
        * model_path - path to a state dict readable by ``torch.load``.
        * pretrained - if False, build the project's ``CLIP`` with the text-encoder
          hyper-parameters below; if True, start from ``clip.load("ViT-B/32")``.
        * context_length - max number of text tokens (used only when pretrained is False).
        * device - where to place the model. Defaults to "cuda:0" when CUDA is
          available, otherwise "cpu".

    Returns the model on ``device``, switched to eval mode.
    Re-raises any error from ``torch.load`` / ``load_state_dict`` after printing a hint.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    if pretrained is False:
        # use new model params
        params = {
            'context_length': context_length,
            'vocab_size': 49408,
            'transformer_width': 512,
            'transformer_heads': 8,
            'transformer_layers': 12,
        }
        model = CLIP(**params)
    else:
        # clip.load also returns an image preprocessing transform, unused here.
        model, _ = clip.load("ViT-B/32", device=device, jit=False)

    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
    except Exception:
        print("Argument error. Set pretrained = True.", sys.exc_info()[0])
        raise

    # load_state_dict only copies values into the existing parameters, so a freshly
    # built CLIP would otherwise stay on CPU while callers send inputs to ``device``.
    model = model.to(device)
    model.eval()
    return model

def make(
    model_path: str,
    cxr_filepath: str,
    pretrained: bool = False,
    context_length: int = 77,
    label_path: str = "/home/ubuntu/code/ECG2TEXT/y_test.npy",
    batch_size: int = 1,
):
    """
    FUNCTION: make
    -------------------------------------------
    Build the CLIP model and a DataLoader over the ECG evaluation set.

    args:
        * model_path - path to the weights (state dict) of the trained CLIP model.
        * cxr_filepath - path to the .npy file with the ECG signals (the parameter
          name is kept from the chest x-ray code this function was adapted from).
        * pretrained - bool, whether the model starts from pretrained CLIP weights.
        * context_length - int, max number of text tokens fed to the model.
        * label_path - path to the .npy file with the ground-truth labels.
        * batch_size - samples per batch; the default of 1 matches earlier behaviour.

    Returns (model, loader). The loader does not shuffle, so batches follow file order.
    """
    # Forward the caller's arguments; previously `pretrained` was ignored and the
    # data path was hard-coded regardless of `cxr_filepath`.
    model = load_clip(
        model_path=model_path,
        pretrained=pretrained,
        context_length=context_length,
    )

    torch_dset = ECGDataset(data_path=cxr_filepath, label_path=label_path)
    loader = torch.utils.data.DataLoader(torch_dset, batch_size=batch_size, shuffle=False)

    return model, loader



In [2]:
# Trained checkpoint (a single .pt file, despite the variable name) and the
# ECG signal array used for evaluation.
ECG2TEXT_ROOT = "/home/ubuntu/code/ECG2TEXT"
model_dir = f"{ECG2TEXT_ROOT}/checkpoints/pt-imp/checkpoint_28000.pt"
cxr_filepath = f"{ECG2TEXT_ROOT}/X_all.npy"

model, loader = make(
    model_path=model_dir,
    cxr_filepath=cxr_filepath,
)

# Preparing ECG class labels and prompts

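The following cells define the five ECG diagnostic classes used for zero-shot classification, followed by the text template used for "prompt engineering". The variable names (`imagenet_classes`, `imagenet_templates`) are carried over from the original ImageNet CLIP example.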

In [3]:
# Diagnostic classes used as zero-shot targets; the ImageNet-era name is kept
# because later cells refer to `imagenet_classes`.
imagenet_classes = [
    "Normal ECG",
    "Myocardial Infarction",
    "ST/T change",
    "Conduction Disturbance",
    "Hypertrophy",
]

In [4]:
# Prompt templates; "{}" is filled with each class name.
imagenet_templates = ['ECG for diagnosis of {}.']

num_classes, num_templates = len(imagenet_classes), len(imagenet_templates)
print(f"{num_classes} classes, {num_templates} templates")

5 classes, 1 templates


# Creating zero-shot classifier weights

In [None]:
def zeroshot_classifier(classnames, templates, clip_model=None):
    """Build a zero-shot classification head from text prompts.

    For every class name, each template is filled in, tokenized with
    ``clip.tokenize`` and embedded with ``encode_text``. The embeddings are
    L2-normalised, averaged over templates, and re-normalised.

    Args:
        classnames: class names; output column ``i`` corresponds to ``classnames[i]``.
        templates: format strings each containing a single ``{}`` placeholder.
        clip_model: model exposing ``encode_text``; defaults to the notebook-level
            ``model`` so existing two-argument calls keep working.

    Returns:
        Tensor stacking one unit-norm class embedding per column, on the model's
        device. Assuming ``encode_text`` returns (num_prompts, embed_dim), the
        shape is (embed_dim, len(classnames)).
    """
    if clip_model is None:
        clip_model = model  # notebook global created by make(...)
    # Tokens must live where the model's weights are; a hard-coded .cuda()
    # fails whenever the model sits on another device.
    device = next(clip_model.parameters()).device

    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]
            tokens = clip.tokenize(texts).to(device)
            class_embeddings = clip_model.encode_text(tokens)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1)
    return zeroshot_weights


# Text-derived classifier head: one unit-norm column per class in imagenet_classes.
zeroshot_weights = zeroshot_classifier(
    classnames=imagenet_classes,
    templates=imagenet_templates,
)

  0%|          | 0/5 [00:00<?, ?it/s]

0
1
2


# Zero-shot prediction

In [None]:
def accuracy(output, target, topk=(1,)):
    """Count how many samples have their target among the top-k predictions.

    Args:
        output: (batch, num_classes) scores or logits.
        target: (batch,) integer class indices; moved to ``output``'s device.
        topk: iterable of k values to evaluate.

    Returns:
        List of floats, one per ``k`` in ``topk``: the number (not fraction) of
        correctly classified samples in the batch. A ``k`` larger than the number
        of classes is clamped, so every sample counts as correct for it.
    """
    # topk() raises if k exceeds the class dimension, so clamp it.
    maxk = min(max(topk), output.size(1))
    pred = output.topk(maxk, 1, True, True)[1].t()  # (maxk, batch)
    target = target.to(pred.device).reshape(1, -1)
    correct = pred.eq(target.expand_as(pred))
    # .item() yields a Python scalar directly; float() on a 1-element ndarray
    # is deprecated in recent NumPy.
    return [correct[:k].float().sum().item() for k in topk]

In [None]:
# Zero-shot evaluation: score every ECG embedding against each class embedding.
# Inputs and class weights are moved to the model's device so the matmul is valid.
eval_device = next(model.parameters()).device
class_weights = zeroshot_weights.to(eval_device)

with torch.no_grad():
    top1, top5, n = 0., 0., 0.
    # Iterate over batches directly: enumerate() would yield (index, batch)
    # tuples, and the old loop variable `data` shadowed torch.utils.data.
    for batch in tqdm(loader):
        images = batch['ecg'].to(eval_device)
        target = batch['label'].to(eval_device)

        # predict
        image_features = model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits = 100. * image_features @ class_weights

        # measure accuracy (with the 5 classes defined above, top-5 is always 100%)
        acc1, acc5 = accuracy(logits, target, topk=(1, 5))
        top1 += acc1
        top5 += acc5
        n += images.size(0)

if n == 0:
    raise ValueError("loader yielded no samples; cannot compute accuracy")

top1 = (top1 / n) * 100
top5 = (top5 / n) * 100

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")