In [1]:
import pandas as pd
import h5py
import torch
import numpy as np
from code.model.resnet import resnet50, resnet1d34
from code.model.unet import unet
from code.utils.dataloader import ECHODataset
from torch.utils.data import DataLoader
import tqdm
import matplotlib.pyplot as plt
from torchmetrics import Dice
from code.utils.evaluation import Accumulator, hausdorff_distance, cal_hd, shield, generate_pseudo_label
from copy import deepcopy

In [2]:
base_path = "/data/zyk/data/dataset/ECHO/preprocessed/"
# Test-split metadata for the clients below; ECHO_ID is read as str.
test_clients = ["client3"]
test_meta = pd.concat(
    [pd.read_csv(f"{base_path}{c}/test.csv", dtype={"ECHO_ID": str}) for c in test_clients]
).reset_index()
test_dataset = ECHODataset(test_meta, base_path, test_clients, "records.h5", 4, frac=1)
print(len(test_dataset))

In [3]:
# First row of the second element (the target) of the first test sample.
first_sample = test_dataset[0]
print(first_sample[1][0])

In [9]:
# Subset of client3's *training* split (frac=0.2; sampling semantics live in ECHODataset),
# built for a shape comparison against `test_dataset`. It gets its own name so that
# `test_meta` keeps referring to test-split metadata instead of silently holding train.csv.
client3_train_meta = pd.concat([
    pd.read_csv(base_path + f"{c}/train.csv", dtype={"ECHO_ID": str}) for c in ["client3"]
]).reset_index()
test1_dataset = ECHODataset(client3_train_meta, base_path, ["client3"], "records.h5", 4, frac=0.2)
print(len(test1_dataset), len(test_dataset))

In [10]:
# Batch shapes of the frac=0.2 train subset next to the test set (zip stops at the
# shorter loader), then every batch shape of the subset on its own.
loader1 = DataLoader(dataset=test1_dataset, batch_size=32, shuffle=False)
loader2 = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
for subset_batch, test_batch in zip(loader1, loader2):
    print(subset_batch[0].shape, test_batch[0].shape)
for subset_batch in loader1:
    print(subset_batch[0].shape)

In [2]:
client = "client1"
base_path = "/data/zyk/data/dataset/ECHO/preprocessed/"


def read_split_meta(clients, split):
    """Concatenate ``<client>/<split>.csv`` for every client and renumber the rows.

    ``ECHO_ID`` is read as ``str``. ``reset_index()`` keeps each file's original row
    index as an ``index`` column.
    """
    frames = [pd.read_csv(base_path + f"{c}/{split}.csv", dtype={"ECHO_ID": str}) for c in clients]
    return pd.concat(frames).reset_index()


# Supervised warm-up client: a single file, read directly (no concat / index reset).
warm_up_meta = pd.read_csv(base_path + f"{client}/train.csv", dtype={"ECHO_ID": str})
warm_up_dataset = ECHODataset(warm_up_meta, base_path, [client], "records.h5", 4)

# Clients trained with pseudo labels.
train_clients = ["client2", "client3"]
train_meta = read_split_meta(train_clients, "train")
train_dataset = ECHODataset(train_meta, base_path, train_clients, "records.h5", 4, frac=1)

# Evaluation on the test split of all three clients.
test_clients = ["client1", "client2", "client3"]
test_meta = read_split_meta(test_clients, "test")
test_dataset = ECHODataset(test_meta, base_path, test_clients, "records.h5", 4, frac=1)

In [3]:
# --- optimisation hyper-parameters ---
batch_size = 64
epochs = 50
lr = 0.1
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# ResNet-50 model from code.model.resnet with its stem conv replaced so that it takes
# a single input channel (7x7 kernel, stride 2, padding 3, 64 filters, no bias).
model = resnet50()
model.backbone["conv1"] = torch.nn.Conv2d(
    in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

In [4]:
def dice(pred, target, class_set=(1, 2, 3)):
    """Mean Dice coefficient over the classes in ``class_set``.

    Args:
        pred: predicted class-index map, or a 4-D score tensor that is reduced with
            ``argmax`` over dim 1 (presumably (N, C, H, W) logits).
        target: class-index map, broadcast-compatible with the (reduced) ``pred``.
        class_set: class ids to score; the default leaves out class 0.

    Returns:
        0-d float tensor: the unweighted mean of the per-class Dice scores. A class
        absent from both ``pred`` and ``target`` scores 1.0 (perfect agreement).
        Without that rule it would evaluate to 0/0 = NaN and poison the whole mean.
    """
    if pred.dim() == 4:
        pred = torch.argmax(pred, dim=1)
    dice_scores = torch.zeros(len(class_set))
    for idx, c in enumerate(class_set):
        pred_c = pred == c
        target_c = target == c
        intersection = (pred_c & target_c).sum()
        union = pred_c.sum() + target_c.sum()
        dice_scores[idx] = 1.0 if union == 0 else 2 * intersection / union
    return torch.mean(dice_scores)

def alpha(epoch, ramp_epochs=10, max_weight=1.0):
    """Weight of the pseudo-label loss term at a given epoch.

    Ramps linearly from 0 at epoch 0 to ``max_weight`` at ``ramp_epochs`` and stays
    there afterwards. This keeps the schedule continuous: a ``0.01 * epoch`` step would
    stop at 0.09 and then jump straight to 1 at epoch 10.

    Args:
        epoch: zero-based epoch index.
        ramp_epochs: number of epochs over which the weight ramps up (must be > 0).
        max_weight: weight reached at the end of the ramp.

    Returns:
        The multiplier applied to the pseudo-label loss.
    """
    if epoch < ramp_epochs:
        return max_weight * epoch / ramp_epochs
    return max_weight


In [5]:
# Training loaders reshuffle every epoch; the test loader keeps a fixed order.
warm_up_loader = DataLoader(dataset=warm_up_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [26]:
# Dice over classes 1-3 (class 0 ignored) in both averaging modes, placed on `device`
# alongside the model.
dice_macro, dice_micro = (
    Dice(ignore_index=0, num_classes=4, average=mode).to(device) for mode in ("macro", "micro")
)
model.to(device)

In [7]:
# Supervised warm-up on `warm_up_loader`. One constant drives both the epoch loop and
# the progress-bar total, so the two can no longer drift apart.
warm_up_epochs = 10
# Accumulator slots: size-weighted loss, Dice macro, Dice micro, HD sums, then sample count.
warm_up_metric = Accumulator(5)
description = "Train Loss {:.8f} | Dice Macro:{:.2f} | Dice Micro:{:.2f} | HD:{:.2f}"
bar = tqdm.tqdm(initial=0, leave=True, total=len(warm_up_loader) * warm_up_epochs,
                desc=description.format(0, 0, 0, 0), position=0)
for epoch in range(warm_up_epochs):
    model.train()
    for data, target, label_type in warm_up_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)["out"]
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # Predictions are passed through `shield` (code.utils.evaluation) before scoring.
            shield_pred_label = shield(torch.argmax(output, dim=1), label_type)
        n = len(data)
        warm_up_metric.add(
            float(loss) * n,
            dice_macro(shield_pred_label, target) * n,
            dice_micro(shield_pred_label, target) * n,
            cal_hd(shield_pred_label.cpu().numpy(), target.cpu().numpy(), mask=label_type) * n,
            n)
        # Running averages over every batch seen so far (sums divided by the sample count).
        bar.desc = description.format(*(warm_up_metric[k] / warm_up_metric[-1] for k in range(4)))
        bar.update(1)
bar.close()

In [31]:
# Visual check on the first warm-up batch. For three samples, show the input, the
# ground truth, the shielded prediction, the pseudo label and the raw argmax prediction.
n_show = 3
for images, labels, label_types in warm_up_loader:
    images, labels, label_types = images[:n_show], labels[:n_show], label_types[:n_show]
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        raw_pred = torch.argmax(model(images)["out"], dim=1)
        # Keep this call order (pseudo label first, then shield) in case either helper
        # modifies its input in place.
        pseudo = generate_pseudo_label(raw_pred, labels, label_types)
        shielded = shield(raw_pred, label_types)
        fig, axes = plt.subplots(n_show, 5)
        for row in range(n_show):
            axes[row, 0].imshow(images[row][0].cpu().numpy(), cmap="gray")
            for col, maps in enumerate((labels, shielded, pseudo, raw_pred), start=1):
                axes[row, col].imshow(maps[row].cpu().numpy())
        plt.show()
    break

In [32]:
# Accumulator slots: size-weighted loss, Dice macro, Dice micro, HD sums, then sample count.
metric = Accumulator(5)
description = "Train Loss {:.8f} | Dice Macro:{:.2f} | Dice Micro:{:.2f} | HD:{:.2f}"
bar = tqdm.tqdm(initial=0, leave=True, total=(len(train_loader) + len(warm_up_loader)) * epochs,
                desc=description.format(0, 0, 0, 0), position=0)


def _log_step(loss, output, target, label_type):
    """Add one batch's size-weighted loss/Dice/HD to `metric` and refresh the progress bar.

    Scores compare the shielded argmax of `output` with the raw `target`.
    """
    with torch.no_grad():
        shield_pred_label = shield(torch.argmax(output, dim=1), label_type)
    n = len(target)
    metric.add(
        float(loss) * n,
        dice_macro(shield_pred_label, target) * n,
        dice_micro(shield_pred_label, target) * n,
        cal_hd(shield_pred_label.cpu().numpy(), target.cpu().numpy(), mask=label_type) * n,
        n)
    bar.desc = description.format(*(metric[k] / metric[-1] for k in range(4)))
    bar.update(1)


for epoch in range(epochs):
    model.train()
    # 1) Supervised pass over the warm-up client.
    for data, target, label_type in warm_up_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)["out"]
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _log_step(loss, output, target, label_type)
    # 2) Pseudo-label pass over the remaining clients, weighted by alpha(epoch).
    for data, target, label_type in train_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)["out"]
        # Pseudo labels come from this same forward pass (detached), not a second
        # train-mode forward. The extra pass doubled compute and, for any BatchNorm
        # layers, updated the running statistics twice per batch. If the model uses
        # dropout, the pseudo labels now share the dropout mask of the trained output.
        with torch.no_grad():
            pred_label = torch.argmax(output.detach(), dim=1)
            pseudo_target = generate_pseudo_label(pred_label, target, label_type)
        loss = alpha(epoch) * criterion(output, pseudo_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _log_step(loss, output, target, label_type)
bar.close()


In [12]:
# Accumulator slots: size-weighted loss, Dice macro, Dice micro, HD sums, then sample count.
test_metric = Accumulator(5)
description = "Test Loss {:.8f} | Dice Macro:{:.2f} | Dice Micro:{:.2f} | HD:{:.2f}"
bar = tqdm.tqdm(initial=0, leave=True, total=len(test_loader),
                desc=description.format(0, 0, 0, 0), position=0)
# Evaluation must run in eval mode. In train mode, any BatchNorm layers normalise with
# per-batch statistics and fold the test batches into their running averages, and any
# dropout stays active. Test scores would be noisy and testing would alter the model.
model.eval()
with torch.no_grad():
    for data, target, label_type in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)["out"]
        loss = criterion(output, target)
        shield_pred_label = shield(torch.argmax(output, dim=1), label_type)
        n = len(data)
        test_metric.add(
            float(loss) * n,
            dice_macro(shield_pred_label, target) * n,
            dice_micro(shield_pred_label, target) * n,
            cal_hd(shield_pred_label.cpu().numpy(), target.cpu().numpy(), mask=label_type) * n,
            n)
        bar.desc = description.format(*(test_metric[k] / test_metric[-1] for k in range(4)))
        bar.update(1)
bar.close()

In [7]:
# Final running averages: mean loss and mean macro Dice for train and test.
# Each Accumulator holds size-weighted sums in slots 0-3 (loss, Dice macro, Dice micro,
# HD) and the sample count in its last slot. Every average therefore divides by [-1];
# slot 2 is the Dice-micro sum, not a count.
print(metric[0] / metric[-1], metric[1] / metric[-1])
print(test_metric[0] / test_metric[-1], test_metric[1] / test_metric[-1])

In [74]:
# torch.save({
#     "model": model.state_dict()
# }, "/data/zyk/code/resnet.pth"
# )

In [22]:
# Single-sample inference on the first test item. The model stays on whatever device it
# is on: the training/eval cells send batches to `device`, so moving the model to CPU
# here would break them on re-run. Only the output is brought back to CPU.
model.eval()
x, l, z = test_dataset[0]
x = x.unsqueeze(0)  # prepend a batch dimension of size 1; x itself stays on CPU for plotting
model_device = next(model.parameters()).device
with torch.no_grad():  # inference only: no autograd graph needed
    y = model(x.to(model_device))["out"].cpu()
py = shield(y, z)

In [23]:
# Hausdorff distance between the shielded prediction and the ground truth for one sample.
# Both maps are reshaped to (1, H, W), taking H and W from the label itself rather than a
# hard-coded 112 x 112.
spatial_shape = tuple(l.shape[-2:])
hd = cal_hd(py.detach().numpy().reshape(1, *spatial_shape),
            l.detach().numpy().reshape(1, *spatial_shape), z)
print(hd)

In [24]:
# torchmetrics macro Dice over classes 1-3 (class 0 ignored) for the single-sample prediction.
macro_dice = Dice(ignore_index=0, num_classes=4, average="macro")
print(macro_dice(py, l))

In [25]:
# Overlay ground truth, shielded prediction and raw argmax prediction on the input frame.
fig, ax = plt.subplots(1, 3)
frame = x[0][0].numpy()
overlays = (l.numpy(), py[0].detach().numpy(), y.argmax(dim=1).detach().numpy()[0])
for panel, mask in zip(ax, overlays):
    panel.imshow(frame, cmap="gray")
    panel.imshow(mask, alpha=0.5)
plt.show()

In [17]:
# Inspect one row (index 75) of the ground-truth label map.
label_row = 75
print(l[label_row])

In [40]:
# Toy check of torchmetrics Dice with no ignored class on a 4-element, 3-class example.
toy_pred = torch.tensor([0, 2, 1, 1])
toy_target = torch.tensor([0, 0, 1, 2])
toy_dice = Dice(ignore_index=None, num_classes=3, average="macro")
print(toy_dice(toy_pred, toy_target))