In [1]:
# The star import is kept first on purpose: every explicit import below then
# takes precedence if src.constants happens to export a clashing name.
from src.constants import *

import monai
import numpy as np
import torch
from monai.networks import one_hot
from tqdm import tqdm

from src.data.data_loaders import get_loaders
from src.model.my_model import MyModel
from src.utils.metrics import dice_scores


<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.
<class 'monai.transforms.utility.array.AsChannelFirst'>: Class `AsChannelFirst` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.


In [2]:
# Build the train and test data loaders. The recorded run below spent roughly
# 11 minutes loading 70 items and 1 minute loading 10 items, so avoid
# re-running this cell unnecessarily.
train_loader, test_loader = get_loaders()


Loading dataset:  51%|█████▏    | 36/70 [07:23<03:40,  6.47s/it]IOStream.flush timed out
Loading dataset:  56%|█████▌    | 39/70 [08:19<07:36, 14.73s/it]IOStream.flush timed out
Loading dataset: 100%|██████████| 70/70 [11:03<00:00,  9.48s/it]
Loading dataset:   0%|          | 0/10 [00:00<?, ?it/s]IOStream.flush timed out
Loading dataset: 100%|██████████| 10/10 [01:02<00:00,  6.24s/it]


In [3]:

def train_single_epoch(model, optimizer, train_loader, loss_fn, apply_softmax=True):
    """Run one optimisation pass over ``train_loader`` and print the mean batch loss.

    Args:
        model: network being trained; invoked as ``model(img)``.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called once
            per batch.
        train_loader: iterable of dict batches providing ``'img'`` and
            ``'mask'`` entries.
        loss_fn: callable taking ``(prediction, one_hot_mask)`` and returning a
            scalar loss tensor.
        apply_softmax: when True (the default, and the original behaviour) the
            model output goes through a softmax over dim 1 before ``loss_fn``.
            Pass False when ``loss_fn`` expects raw logits or applies its own
            activation.

    Returns:
        The same ``model`` object (its parameters are updated in place).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # test_single_epoch() switches the model to eval mode, so make sure we are
    # back in training mode before touching the weights.
    model.train()

    losses = []

    for batch in train_loader:
        img = batch['img'].to(DEVICE)
        mask = batch['mask'].to(DEVICE)

        optimizer.zero_grad()

        outputs = model(img)

        # NOTE(review): a loss with a cross-entropy term (e.g. MONAI's
        # DiceCELoss) normalises its input itself, so handing it probabilities
        # squashes them twice and can stall training. Consider calling this
        # function with apply_softmax=False together with
        # DiceCELoss(softmax=True) -- confirm against the loss in use.
        if apply_softmax:
            outputs = torch.softmax(outputs, dim=1)

        # The label map is expanded to 3 channels; presumably this matches the
        # number of output channels of the model -- TODO confirm.
        mask = one_hot(mask, num_classes=3)

        loss = loss_fn(outputs, mask)
        loss.backward()
        optimizer.step()

        # .item() already detaches and copies the scalar to the host.
        losses.append(loss.item())

    if not losses:
        raise ValueError("train_loader yielded no batches")

    print((sum(losses) / len(losses)))
    return model

In [4]:
def test_single_epoch(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return the mean Dice score.

    For every batch the prediction is produced with sliding-window inference
    (windows run on ``DEVICE``, the stitched result is kept on the CPU),
    reduced to a label map with ``argmax`` over dim 1, filtered with
    ``KeepLargestConnectedComponent`` and scored with ``dice_scores``. The
    per-batch scores are printed as they are computed.

    NaN scores are replaced by 1.0 before averaging.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        test_loader: iterable of dict batches providing ``'img'`` and
            ``'mask'`` entries.

    Returns:
        The Dice score averaged first over batches and then over the entries
        returned by ``dice_scores``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    per_batch_scores = []

    # NOTE(review): this transform is applied to the batched output below;
    # MONAI transforms are generally documented for un-batched, channel-first
    # input -- confirm it behaves as intended for batch sizes > 1.
    largest_component = monai.transforms.KeepLargestConnectedComponent()

    with torch.no_grad():
        for batch in test_loader:
            img = batch['img'].to(DEVICE)
            mask = batch['mask'].to(DEVICE)

            out = monai.inferers.sliding_window_inference(img,
                                                          roi_size=CROP_SIZE,
                                                          sw_batch_size=BATCH_SIZE,
                                                          predictor=model,
                                                          overlap=0.5,
                                                          sw_device=DEVICE,
                                                          device="cpu",
                                                          progress=False,
                                                          )
            # Collapse the class dimension to a label map (dim 1 kept, size 1).
            out = torch.argmax(out, 1, keepdim=True)
            # Post-process on the CPU result, then move back next to the mask.
            out = largest_component(out).to(DEVICE)
            s = dice_scores(out, mask)
            print(s)
            per_batch_scores.append(s)

    if not per_batch_scores:
        raise ValueError("test_loader yielded no batches")

    scores = np.nan_to_num(np.array(per_batch_scores), copy=True, nan=1.0)
    # Average over batches first, then over whatever dice_scores returns per batch.
    per_entry_mean = np.sum(scores, axis=0) / scores.shape[0]
    test_score = np.mean(per_entry_mean)

    # The original computed this value and then discarded it.
    return test_score

In [5]:
# Ask PyTorch to release its cached, currently unused GPU memory before the
# model is built. NOTE(review): `torch` is not imported explicitly in the
# cells shown here -- presumably it arrives via `from src.constants import *`.
torch.cuda.empty_cache()

In [8]:
# Build the network here so the notebook works on a fresh kernel; with the
# construction commented out, `model` only existed as leftover kernel state.
# Re-running this cell re-initialises the weights.
model = MyModel(in_channels=1,
                out_channels=3,
                lower_channels=32,
                big_channel=16,
                patch_size=8,
                embed_dim=256,
                skip_transformer=False).to(DEVICE)

optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)
# NOTE(review): DiceCELoss is created with its defaults, i.e. it applies no
# activation of its own, while train_single_epoch() applies a softmax before
# calling it by default. The cross-entropy term normally expects raw logits --
# confirm this combination is intended.
loss_fn = monai.losses.DiceCELoss()

In [9]:
# Number of passes over the training set.
n_epochs = 50

# One call per epoch; each call prints that epoch's mean training loss.
for _epoch in tqdm(range(n_epochs)):
    train_single_epoch(
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        loss_fn=loss_fn,
    )

  2%|▏         | 1/50 [00:44<36:24, 44.59s/it]

1.0664206913539342


  4%|▍         | 2/50 [01:29<35:37, 44.54s/it]

1.0156799282346454


  6%|▌         | 3/50 [02:13<34:55, 44.58s/it]

1.0475993803569248


  8%|▊         | 4/50 [02:58<34:11, 44.60s/it]

1.062489209856306


 10%|█         | 5/50 [03:42<33:26, 44.59s/it]

1.0200574210711888


 12%|█▏        | 6/50 [04:27<32:43, 44.61s/it]

1.0444366574287414


 14%|█▍        | 7/50 [05:12<31:57, 44.59s/it]

1.0406746557780675


 16%|█▌        | 8/50 [05:56<31:13, 44.60s/it]

0.993353681904929


 18%|█▊        | 9/50 [06:41<30:29, 44.61s/it]

1.0116288355418614


 20%|██        | 10/50 [07:25<29:44, 44.61s/it]

0.973262459891183


 22%|██▏       | 11/50 [08:10<29:00, 44.62s/it]

1.0201862045696803


 24%|██▍       | 12/50 [08:55<28:14, 44.58s/it]

0.9892500945499965


 26%|██▌       | 13/50 [09:39<27:28, 44.56s/it]

1.029745806966509


 28%|██▊       | 14/50 [10:24<26:44, 44.57s/it]

1.0181196246828352


 30%|███       | 15/50 [11:08<26:00, 44.60s/it]

1.0682941555976868


 32%|███▏      | 16/50 [11:53<25:14, 44.55s/it]

1.0478756172316415


 34%|███▍      | 17/50 [12:37<24:30, 44.55s/it]

1.02307243176869


 36%|███▌      | 18/50 [13:22<23:45, 44.56s/it]

0.9821970428739275


 38%|███▊      | 19/50 [14:06<23:01, 44.55s/it]

1.0019778200558254


 40%|████      | 20/50 [14:51<22:16, 44.56s/it]

0.9447185039520264


 42%|████▏     | 21/50 [15:36<21:32, 44.58s/it]

0.977519861289433


 44%|████▍     | 22/50 [16:20<20:48, 44.59s/it]

1.008910986355373


 46%|████▌     | 23/50 [17:05<20:04, 44.61s/it]

0.970756823675973


 48%|████▊     | 24/50 [17:49<19:19, 44.59s/it]

0.9789125561714173


 50%|█████     | 25/50 [18:34<18:34, 44.57s/it]

0.9870594246046883


 52%|█████▏    | 26/50 [19:19<17:49, 44.55s/it]

1.009373220375606


 54%|█████▍    | 27/50 [20:03<17:05, 44.57s/it]

0.9957168919699533


 56%|█████▌    | 28/50 [20:48<16:20, 44.56s/it]

0.9669926302773612


 58%|█████▊    | 29/50 [21:32<15:36, 44.58s/it]

0.9872546144894191


 60%|██████    | 30/50 [22:17<14:51, 44.57s/it]

0.9524638022695269


 62%|██████▏   | 31/50 [23:01<14:06, 44.55s/it]

0.9761248128754753


 64%|██████▍   | 32/50 [23:46<13:22, 44.58s/it]

1.0266621913228715


 66%|██████▌   | 33/50 [24:31<12:37, 44.57s/it]

1.0056756275040764


 68%|██████▊   | 34/50 [25:15<11:53, 44.58s/it]

1.0063546197754996


 70%|███████   | 35/50 [26:00<11:08, 44.60s/it]

0.9833311796188354


 72%|███████▏  | 36/50 [26:44<10:24, 44.59s/it]

1.018745049408504


 74%|███████▍  | 37/50 [27:29<09:39, 44.57s/it]

1.0059087736266


 76%|███████▌  | 38/50 [28:14<08:55, 44.60s/it]

0.9870877010481698


 78%|███████▊  | 39/50 [28:58<08:10, 44.57s/it]

0.9709389873913357


 80%|████████  | 40/50 [29:43<07:25, 44.58s/it]

1.0454989433288575


 82%|████████▏ | 41/50 [30:27<06:40, 44.55s/it]

0.9559375831059047


 84%|████████▍ | 42/50 [31:12<05:56, 44.55s/it]

0.9836367181369237


 86%|████████▌ | 43/50 [31:56<05:11, 44.56s/it]

0.9725531509944371


 88%|████████▊ | 44/50 [32:41<04:27, 44.56s/it]

0.9694312368120466


 90%|█████████ | 45/50 [33:25<03:42, 44.52s/it]

1.024906373023987


 92%|█████████▏| 46/50 [34:10<02:58, 44.52s/it]

0.9358107362474714


 94%|█████████▍| 47/50 [34:54<02:13, 44.52s/it]

1.0425258534295219


 96%|█████████▌| 48/50 [35:39<01:29, 44.54s/it]

1.015285221167973


 98%|█████████▊| 49/50 [36:23<00:44, 44.52s/it]

1.0028234022004263


100%|██████████| 50/50 [37:08<00:00, 44.57s/it]

0.9746342624936785





In [10]:
%%time
# Evaluate the trained model on the held-out loader; per-batch Dice scores are
# printed by the function. The recorded run below was interrupted manually
# (KeyboardInterrupt) after four batches.
test_single_epoch(model=model, test_loader=test_loader)

[0.9984386  0.73343122 0.6826582 ]
[0.9956255 0.        0.       ]
[0.99555528 0.         0.60870707]
[0.9980402  0.         0.58109879]


KeyboardInterrupt: 

In [None]:
# Scratch cell: pull a single batch from the training loader for inspection.
d =  next(iter(train_loader))

In [None]:
# Scratch cell: show the shape of the batch's 'img' entry.
d['img'].shape