In [3]:
from src.constants import *
from src.model.my_model import MyModel
from tqdm import  tqdm
from monai.networks import one_hot
from src.data.data_loaders import get_loaders
from src.utils.metrics import dice_scores
import numpy as np


In [4]:
# Build the train and test DataLoaders via the project helper.
# NOTE(review): in the saved output of this cell, loading the first dataset (70 items)
# took ~6.5 minutes, and building the second one (10 items) failed with a RuntimeError
# raised while applying monai's ToTensord transform — investigate before relying on
# test_loader in a fresh-kernel run.
train_loader, test_loader = get_loaders()


Loading dataset:   0%|          | 0/70 [00:00<?, ?it/s]IOStream.flush timed out
Loading dataset:  51%|█████▏    | 36/70 [03:55<01:54,  3.38s/it]IOStream.flush timed out
Loading dataset: 100%|██████████| 70/70 [06:31<00:00,  5.59s/it]
Loading dataset:   0%|          | 0/10 [00:00<?, ?it/s]IOStream.flush timed out
Loading dataset:   0%|          | 0/10 [00:33<?, ?it/s]


RuntimeError: applying transform <monai.transforms.utility.dictionary.ToTensord object at 0x7fd880133a90>

In [3]:

def train_single_epoch(model, optimizer, train_loader, loss_fn, num_classes=3, apply_softmax=True):
    """Run one optimisation pass over ``train_loader`` and print the mean batch loss.

    The model is explicitly switched to training mode first: ``test_single_epoch``
    calls ``model.eval()``, so without this a train -> evaluate -> train sequence
    would silently keep training with eval-mode layers.

    Args:
        model: network to optimise; its parameters are updated in place.
        optimizer: optimiser over ``model``'s parameters.
        train_loader: iterable of dicts with ``'img'`` and ``'mask'`` tensors.
        loss_fn: callable ``loss_fn(prediction, one_hot_target)`` returning a scalar tensor.
        num_classes: number of classes used to one-hot encode ``mask``
            (default 3, the value previously hard-coded).
        apply_softmax: if True (default, original behaviour), softmax is applied over
            dim 1 of the model output before it is passed to ``loss_fn``.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    losses = []

    for batch in train_loader:
        img = batch['img'].to(DEVICE)
        mask = batch['mask'].to(DEVICE)

        optimizer.zero_grad()
        outputs = model(img)

        # NOTE(review): with loss_fn = monai.losses.DiceCELoss(), the cross-entropy term
        # expects raw logits; softmaxing here means softmax is effectively applied twice
        # for CE. Consider apply_softmax=False together with DiceCELoss(softmax=True).
        if apply_softmax:
            outputs = torch.softmax(outputs, dim=1)
        target = one_hot(mask, num_classes=num_classes)

        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()

        # .item() already returns a detached Python float on the host.
        losses.append(loss.item())

    if not losses:
        raise ValueError("train_loader yielded no batches; cannot compute mean loss")

    print(sum(losses) / len(losses))
    return model

In [4]:
def test_single_epoch(model, test_loader, verbose=True):
    """Evaluate ``model`` on ``test_loader`` and return the mean Dice-style score.

    For each batch: sliding-window inference (``CROP_SIZE`` windows, 50% overlap,
    stitched output on the CPU), argmax over dim 1, keep-largest-connected-component
    post-processing, then ``dice_scores`` against the ground-truth mask.

    NaN scores are replaced by 1.0 before averaging. NOTE(review): presumably NaN means
    a class is absent from both prediction and target and is counted as a perfect
    match — confirm against ``dice_scores``.

    The model's previous train/eval mode is restored on exit, so calling this between
    training epochs does not leave the model stuck in eval mode.

    Args:
        model: network to evaluate.
        test_loader: iterable of dicts with ``'img'`` and ``'mask'`` tensors.
        verbose: if True (default, original behaviour), print each batch's scores.

    Returns:
        float: mean over axis 0 of the stacked per-batch scores (one entry per class,
        assumed), then averaged across those entries. The previous version computed
        this value but returned None.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    largest_component = monai.transforms.KeepLargestConnectedComponent()
    scores = []

    try:
        with torch.no_grad():
            for batch in test_loader:
                img = batch['img'].to(DEVICE)
                mask = batch['mask'].to(DEVICE)

                out = monai.inferers.sliding_window_inference(
                    img,
                    roi_size=CROP_SIZE,
                    sw_batch_size=BATCH_SIZE,
                    predictor=model,
                    overlap=0.5,
                    sw_device=DEVICE,
                    # Stitched full-size output lives on the CPU (presumably to limit
                    # accelerator memory); it is moved back to DEVICE after post-processing.
                    device="cpu",
                    progress=False,
                )
                out = torch.argmax(out, 1, keepdim=True)
                out = largest_component(out).to(DEVICE)
                batch_scores = dice_scores(out, mask)
                if verbose:
                    print(batch_scores)
                scores.append(batch_scores)
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    if not scores:
        raise ValueError("test_loader yielded no batches; cannot compute a score")

    stacked = np.nan_to_num(np.array(scores), copy=True, nan=1.0)
    per_class = stacked.mean(axis=0)
    return float(np.mean(per_class))

In [5]:
# Build the segmentation network, its optimiser and the training loss.
# out_channels must agree with the number of classes masks are one-hot encoded into
# during training (3 in train_single_epoch).
model = MyModel(
    in_channels=1,
    out_channels=3,
    patch_size=PATCH_SIZE,
    channels=(32, 32, 32, 32, 32),
    transformer_channels=16,
    embed_dim=256,
    skip_transformer=False,
).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# NOTE(review): DiceCELoss's cross-entropy term expects logits, but train_single_epoch
# softmaxes the model output before calling loss_fn — verify this is intended.
loss_fn = monai.losses.DiceCELoss()

In [7]:
# Number of passes over train_loader; train_single_epoch prints each epoch's mean loss.
NUM_EPOCHS = 20
for e in tqdm(range(NUM_EPOCHS)):
    # train_single_epoch updates `model` in place and returns that same object.
    c_model = train_single_epoch(model, optimizer, train_loader, loss_fn)

  0%|          | 0/20 [00:00<?, ?it/s]To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
[2023-04-24 13:51:36,749] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-04-24 13:51:36,938] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-04-24 13:51:46,530] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 0
[2023-04-24 13:51:46,631] torch._inductor.graph: [INFO] Using FallbackKernel: aten.max_pool3d_with_indices
[2023-04-24 13:51:47,923] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 0
[2023-04-24 13:51:47,926] torch._dynamo.output_g

1.2051889101664226


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
 10%|█         | 2/20 [02:18<18:20, 61.13s/it] 

1.1841529409090679


 15%|█▌        | 3/20 [02:41<12:28, 44.03s/it]

1.1700243420071073


 20%|██        | 4/20 [03:05<09:39, 36.21s/it]

1.1496473087204828


 25%|██▌       | 5/20 [03:29<07:56, 31.77s/it]

1.1403683026631672


 30%|███       | 6/20 [03:53<06:45, 28.94s/it]

1.1441215011808608


 35%|███▌      | 7/20 [04:17<05:56, 27.40s/it]

1.1437821057107713


 40%|████      | 8/20 [04:40<05:13, 26.12s/it]

1.1109963721699185


 45%|████▌     | 9/20 [05:04<04:38, 25.27s/it]

1.116151147418552


 50%|█████     | 10/20 [05:27<04:07, 24.71s/it]

1.1148061487409804


 55%|█████▌    | 11/20 [05:49<03:34, 23.83s/it]

1.1185422208574083


 60%|██████    | 12/20 [06:11<03:04, 23.09s/it]

1.1168451640341017


 65%|██████▌   | 13/20 [06:32<02:38, 22.62s/it]

1.087371051311493


 70%|███████   | 14/20 [06:53<02:13, 22.24s/it]

1.0976895027690463


 75%|███████▌  | 15/20 [07:15<01:49, 21.99s/it]

1.1047762698597379


 80%|████████  | 16/20 [07:36<01:27, 21.82s/it]

1.0886463854047987


 85%|████████▌ | 17/20 [07:58<01:05, 21.73s/it]

1.0858200788497925


 90%|█████████ | 18/20 [08:19<00:43, 21.66s/it]

1.087765759891934


 95%|█████████▌| 19/20 [08:41<00:21, 21.63s/it]

1.0923849675390456


100%|██████████| 20/20 [09:02<00:00, 27.14s/it]

1.0829876793755426





In [13]:
%%time
test_single_epoch(model=model, test_loader=test_loader)

[0.94862133 0.11096717 0.01727696]
[0.9219479  0.04001863 0.00209648]
[0.90451807 0.01981635 0.02504121]
[0.94322503 0.11020487 0.05520159]


KeyboardInterrupt: 