In [1]:
from src.constants import *
from src.model.my_model import MyModel
from tqdm import  tqdm
from monai.networks import one_hot
from src.data.data_loaders import get_loaders
from src.utils.metrics import dice_scores
import numpy as np
import matplotlib.pyplot as plt

<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.
<class 'monai.transforms.utility.array.AsChannelFirst'>: Class `AsChannelFirst` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.


In [2]:
# Build the training and test DataLoaders once. This is slow: the recorded run
# spent about 11 minutes loading the 70 training items.
(train_loader, test_loader) = get_loaders()


Loading dataset:  63%|██████▎   | 44/70 [09:00<03:15,  7.52s/it]IOStream.flush timed out
Loading dataset: 100%|██████████| 70/70 [11:09<00:00,  9.56s/it]
Loading dataset:   0%|          | 0/10 [00:00<?, ?it/s]IOStream.flush timed out
IOStream.flush timed out
Loading dataset: 100%|██████████| 10/10 [00:56<00:00,  5.62s/it]


In [3]:

def train_single_epoch(model, optimizer, train_loader, loss_fn):
    """Run one optimisation pass over ``train_loader`` and print the mean loss.

    Args:
        model: network being trained (updated in place).
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of dicts holding ``'img'`` and ``'mask'`` tensors.
        loss_fn: called as ``loss_fn(softmax_probabilities, one_hot_mask)``.

    Returns:
        The same ``model`` object.
    """
    # test_single_epoch() calls model.eval(); set training mode explicitly so
    # dropout / normalisation layers behave correctly whatever ran before.
    model.train()
    losses = []

    for batch in train_loader:
        # NOTE(review): DEVICE is presumably provided by `from src.constants import *`.
        img = batch['img'].to(DEVICE)
        mask = batch['mask'].to(DEVICE)

        optimizer.zero_grad()

        # NOTE(review): with DiceCELoss(softmax=False) the CE term applies its own
        # log-softmax on top of these probabilities (a double softmax). Passing raw
        # logits together with DiceCELoss(softmax=True) avoids this, but the loss
        # configuration cell must be changed at the same time.
        probs = torch.softmax(model(img), dim=1)
        # One-hot the mask with as many classes as the model predicts channels.
        target = one_hot(mask, num_classes=probs.shape[1])

        loss = loss_fn(probs, target)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    if losses:
        print(sum(losses) / len(losses))
    else:
        # Avoid ZeroDivisionError when the loader yields nothing.
        print('train_single_epoch: train_loader yielded no batches')
    return model

In [4]:
def vis(image, label, prediction, scores):
    """Plot one random slice of image, label and prediction side by side.

    Only slices whose label holds more than one distinct value are eligible,
    so the plot shows some structure. The slice is drawn uniformly from those
    slices. If there are none (for example an all-background label), any slice
    is drawn instead of looping forever.

    Args:
        image, label, prediction: 5-D tensors indexed as ``[0, 0, slice, :, :]``.
            Presumably (batch, channel, D, H, W); only element 0 of the first two
            axes is shown. They are copied to the CPU as numpy arrays.
        scores: scores shown in the figure title. test_single_epoch passes
            whole-volume Dice scores, not per-slice ones.
    """
    image = image.cpu().numpy()
    label = label.cpu().numpy()
    prediction = prediction.cpu().numpy()

    label_volume = label[0, 0]
    # A slice is informative when its label has more than one value. This is the
    # same test as len(np.unique(slice)) > 1, vectorised over all slices at once.
    flat_labels = label_volume.reshape(label_volume.shape[0], -1)
    informative = np.flatnonzero(flat_labels.max(axis=1) != flat_labels.min(axis=1))

    if informative.size:
        slice_idx = int(np.random.choice(informative))
    else:
        # No informative slice exists; fall back to any slice instead of hanging.
        slice_idx = np.random.randint(0, image.shape[2])

    panels = (
        (image[0, 0, slice_idx], "Image"),
        (label_volume[slice_idx], "Label"),
        (prediction[0, 0, slice_idx], "Prediction"),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f'Volume scores (showing slice {slice_idx}): {scores}')
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(title)
        # Remove axis ticks
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
    # This is called once per test volume; release the figure explicitly.
    plt.close(fig)


In [5]:
def test_single_epoch(model, test_loader, max_visualisations=None):
    """Evaluate ``model`` on ``test_loader`` and return the mean Dice score.

    Processing per volume:
    - sliding-window inference;
    - argmax over channels;
    - MONAI KeepLargestConnectedComponent (default settings);
    - scoring with ``dice_scores``.

    Each volume's scores are printed and, optionally, plotted with ``vis``.

    Args:
        model: segmentation network. It is run in eval mode, and its previous
            train/eval mode is restored before returning.
        test_loader: iterable of dicts holding ``'img'`` and ``'mask'`` tensors.
        max_visualisations: maximum number of volumes to plot. ``None``
            (default) plots every volume; ``0`` disables plotting.

    Returns:
        float: per-class scores averaged over volumes (NaN counted as 1.0),
        then averaged over classes. Returns NaN if ``test_loader`` is empty.
    """
    # Remember the caller's mode so evaluation does not leave the model in
    # eval mode for subsequent training epochs.
    was_training = model.training
    model.eval()

    largest_component = monai.transforms.KeepLargestConnectedComponent()
    scores = []
    n_visualised = 0

    try:
        with torch.no_grad():
            for d in test_loader:
                # NOTE(review): DEVICE / CROP_SIZE / BATCH_SIZE presumably come
                # from `from src.constants import *`.
                img = d['img'].to(DEVICE)
                mask = d['mask'].to(DEVICE)

                # Windows are evaluated on DEVICE (sw_device); the stitched
                # prediction is assembled on the CPU (device).
                out = monai.inferers.sliding_window_inference(img,
                                                              roi_size=CROP_SIZE,
                                                              sw_batch_size=BATCH_SIZE,
                                                              predictor=model,
                                                              overlap=0.5,
                                                              sw_device=DEVICE,
                                                              device="cpu",
                                                              progress=False,
                                                              )
                # Channel scores -> label map, then connected-component cleanup.
                out = torch.argmax(out, 1, keepdim=True)
                out = largest_component(out).to(DEVICE)

                s = dice_scores(out, mask)
                print(s)
                scores.append(s)

                if max_visualisations is None or n_visualised < max_visualisations:
                    vis(img, mask, out, s)
                    n_visualised += 1
    finally:
        model.train(was_training)

    if not scores:
        # Nothing to average; the original code crashed here.
        return float('nan')

    # NaN scores are counted as perfect (1.0). Presumably a NaN means the class
    # is absent from both prediction and label in that volume; confirm in dice_scores.
    per_volume = np.nan_to_num(np.array(scores), copy=True, nan=1.0)
    per_class = per_volume.mean(axis=0)
    test_score = float(np.mean(per_class))
    return test_score

In [6]:
# Release unoccupied cached GPU memory before building the model
# (empty_cache is a no-op when CUDA was never initialised).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [7]:
# Model, optimiser and loss configuration.
LEARNING_RATE = 3e-4
CE_WEIGHT = 0.4  # DiceCELoss lambda_ce: weight of the cross-entropy term

# Use DEVICE (as the train/test loops do for the batches) instead of a
# hard-coded .cuda(), so model and data always live on the same device.
# NOTE(review): DEVICE presumably comes from `from src.constants import *`.
model = MyModel(in_channels=1,
                out_channels=3,
                lower_channels=32,
                big_channel=16,
                patch_size=8,
                embed_dim=256,
                skip_transformer=False).to(DEVICE)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# NOTE(review): DiceCELoss defaults to softmax=False, while this notebook's
# training step softmaxes the outputs before calling loss_fn. The CE term
# therefore applies a second softmax. Consider feeding logits with
# DiceCELoss(softmax=True, ...), which requires changing the training step too.
loss_fn = monai.losses.DiceCELoss(lambda_ce=CE_WEIGHT)

In [None]:
# About 45 s per epoch in the recorded run, so the full schedule takes ~2.5 h.
NUM_EPOCHS = 200

for _ in tqdm(range(NUM_EPOCHS)):
    train_single_epoch(model, optimizer, train_loader, loss_fn)

  0%|          | 1/200 [00:46<2:34:52, 46.69s/it]

1.135796308517456


  1%|          | 2/200 [01:31<2:30:38, 45.65s/it]

0.9668397801262992


  2%|▏         | 3/200 [02:16<2:28:57, 45.37s/it]

0.8440490620476859


  2%|▏         | 4/200 [03:01<2:27:42, 45.22s/it]

0.74656183719635


  2%|▎         | 5/200 [03:46<2:26:43, 45.14s/it]

0.7379967195647104


  3%|▎         | 6/200 [04:31<2:25:51, 45.11s/it]

0.7320429878575462


  4%|▎         | 7/200 [05:16<2:24:57, 45.07s/it]

0.7490974034581865


  4%|▍         | 8/200 [06:01<2:24:14, 45.07s/it]

0.7043127238750457


  4%|▍         | 9/200 [06:46<2:23:14, 45.00s/it]

0.6533981731959752


  5%|▌         | 10/200 [07:31<2:22:28, 44.99s/it]

0.7024280250072479


  6%|▌         | 11/200 [08:16<2:21:42, 44.99s/it]

0.7348287326948983


  6%|▌         | 12/200 [09:01<2:20:55, 44.98s/it]

0.6754819623061589


  6%|▋         | 13/200 [09:46<2:20:09, 44.97s/it]

0.6255967600005014


  7%|▋         | 14/200 [10:31<2:19:18, 44.94s/it]

0.6979541097368512


  8%|▊         | 15/200 [11:16<2:18:35, 44.95s/it]

0.6808399489947727


  8%|▊         | 16/200 [12:01<2:17:56, 44.98s/it]

0.6809216976165772


  8%|▊         | 17/200 [12:46<2:17:09, 44.97s/it]

0.6846241559301104


  9%|▉         | 18/200 [13:31<2:16:38, 45.05s/it]

0.6590175424303327


 10%|▉         | 19/200 [14:16<2:15:48, 45.02s/it]

0.6722709485462733


 10%|█         | 20/200 [15:01<2:15:02, 45.01s/it]

0.6397929753576006


 10%|█         | 21/200 [15:46<2:14:25, 45.06s/it]

0.6712429157325199


 11%|█         | 22/200 [16:31<2:13:36, 45.04s/it]

0.7329518420355661


 12%|█▏        | 23/200 [17:16<2:13:00, 45.09s/it]

0.6638926369803292


 12%|█▏        | 24/200 [18:01<2:12:17, 45.10s/it]

0.6558560405458723


 12%|█▎        | 25/200 [18:46<2:11:24, 45.05s/it]

0.6517289689608983


 13%|█▎        | 26/200 [19:31<2:10:41, 45.06s/it]

0.6774904872689929


 14%|█▎        | 27/200 [20:16<2:09:52, 45.04s/it]

0.7345234717641558


 14%|█▍        | 28/200 [21:01<2:09:05, 45.03s/it]

0.6361948932920184


 14%|█▍        | 29/200 [21:47<2:08:20, 45.03s/it]

0.6380403595311301


 15%|█▌        | 30/200 [22:32<2:07:33, 45.02s/it]

0.6520633007798876


 16%|█▌        | 31/200 [23:17<2:06:52, 45.04s/it]

0.6631755948066711


 16%|█▌        | 32/200 [24:01<2:05:57, 44.99s/it]

0.6850684131894793


 16%|█▋        | 33/200 [24:47<2:05:18, 45.02s/it]

0.6670532737459455


 17%|█▋        | 34/200 [25:32<2:04:38, 45.05s/it]

0.6760391056537628


 18%|█▊        | 35/200 [26:17<2:03:49, 45.03s/it]

0.7101562891687666


 18%|█▊        | 36/200 [27:02<2:03:02, 45.02s/it]

0.6628985209124428


 18%|█▊        | 37/200 [27:47<2:02:15, 45.00s/it]

0.6097584792545864


 19%|█▉        | 38/200 [28:32<2:01:28, 44.99s/it]

0.6194394298962185


 20%|█▉        | 39/200 [29:16<2:00:39, 44.96s/it]

0.6763225972652436


 20%|██        | 40/200 [30:01<1:59:52, 44.96s/it]

0.6802809468337467


 20%|██        | 41/200 [30:46<1:59:11, 44.98s/it]

0.661304691859654


 21%|██        | 42/200 [31:32<1:58:31, 45.01s/it]

0.6367316348212105


 22%|██▏       | 43/200 [32:17<1:57:46, 45.01s/it]

0.6885108964783805


 22%|██▏       | 44/200 [33:02<1:57:07, 45.05s/it]

0.6330111571720668


 22%|██▎       | 45/200 [33:47<1:56:22, 45.05s/it]

0.6389743413243975


 23%|██▎       | 46/200 [34:32<1:55:39, 45.06s/it]

0.6227420832429613


 24%|██▎       | 47/200 [35:17<1:54:48, 45.02s/it]

0.6568229241030556


In [None]:
%%time
test_single_epoch(model=model, test_loader=test_loader)

In [None]:
# Pull a single training batch for quick inspection.
d = next(iter(train_loader))

In [None]:
# Shape of the image tensor in the inspected batch.
d["img"].shape