In [1]:
import os
import torch
from trainer import Trainer

from Preprocessing import (
    filter_collate_fn,
    generate_file_path,
    get_train_transform,
    get_validation_transform
)

from monai.data import (
    DataLoader,
    ThreadDataLoader,
    SmartCacheDataset,
    PersistentDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)

from monai.losses import DiceCELoss

from monai.networks.nets import SwinUNETR
from map_to_binary import class_map_5_parts
from monai.data.utils import pad_list_data_collate


MONAI version: 1.2.dev2313
Numpy version: 1.23.5
Pytorch version: 2.0.0+cu118
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 331437dfce075b4fa5785016ad4e7f8c7c77ad21
MONAI __file__: D:\Arash\Semester2\DeepLearning\FinalProject\venv\lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.3.0
Nibabel version: 5.0.1
scikit-image version: 0.20.0
Pillow version: 9.4.0
Tensorboard version: 2.12.0
gdown version: 4.7.1
TorchVision version: 0.15.1+cu118
tqdm version: 4.65.0
lmdb version: 1.4.0
psutil version: 5.9.4
pandas version: 1.5.3
einops version: 0.6.0
transformers version: 4.27.3
mlflow version: 2.2.2
pynrrd version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [2]:
# Task selection and MONAI transform pipelines for this run.
# NOTE(review): num_samples=4 is repeated literally in the Trainer(...) call further
# down; the two values must be kept in sync by hand.
task_name = 'ribs'
val_transforms = get_validation_transform()
# num_samples presumably sets how many random crops are drawn per volume — confirm in Preprocessing.
train_transforms = get_train_transform(num_samples=4)

# Load Dataset

In [3]:
# Dataset root (relative to the notebook's working directory) and the name of the
# label folder that belongs to the selected task.
root_dataset = "DatasetCombined"
label_name = "labels_task_{}".format(task_name)

In [None]:
# Build the per-split file lists (val first, then train — same call order as before).
# The later cells index batches by 'image' and 'label', so the entries presumably
# carry those keys — confirm in Preprocessing.generate_file_path.
# NOTE(review): in the saved notebook this cell has no execution count even though
# later cells depend on its outputs; re-run the notebook top to bottom.
file_list_val, file_list_train = (
    generate_file_path(root_path=f'{root_dataset}/{split}', label_name=label_name)
    for split in ('val', 'train')
)

In [4]:
# Training dataset, with pre-processed items persisted under a per-task cache directory.
# NOTE(review): this is an absolute, machine-specific path, whereas the val/test caches
# are relative to the working directory — consider making it configurable.
train_cache_dir = f'C:/Training/train_{task_name}'
train_ds = PersistentDataset(data=file_list_train, transform=train_transforms, cache_dir=train_cache_dir)

# NOTE(review): `pad_to_shape` does not appear to be a parameter of pad_list_data_collate;
# extra kwargs seem to be forwarded to the padding transform, which is only invoked when
# items in a batch differ in shape — verify before relying on it.
train_loader = ThreadDataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    collate_fn=lambda batch: pad_list_data_collate(batch, pad_to_shape=(96, 96, 96)),
)

In [5]:
# Validation dataset, cached relative to the working directory.
val_ds = PersistentDataset(
    data=file_list_val,
    transform=val_transforms,
    cache_dir=f'val_{task_name}',
)

# NOTE(review): `pad_to_shape` does not appear to be a parameter of pad_list_data_collate;
# with batch_size=1 the padding path is likely never taken — verify.
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=0,
    collate_fn=lambda batch: pad_list_data_collate(batch, pad_to_shape=(96, 96, 96)),
)

In [4]:
# Output channels: one per class in the task's class map, plus one extra channel
# (presumably the background label 0).
task_class_map = class_map_5_parts[f'class_map_part_{task_name}']
num_of_classes = 1 + len(task_class_map)
num_of_classes

25

# Model

In [5]:
# Per-task directory that is passed to the Trainer as model_root_path.
root_dir = 'Model'
model_folder = os.path.join(root_dir, f'SwinTransformer_{task_name}')
# exist_ok=True makes the cell idempotent on re-run and removes the
# check-then-create race of a separate os.path.isdir test.
os.makedirs(model_folder, exist_ok=True)

In [6]:
# SwinUNETR with a (96, 96, 96) input size, a single input channel and one output
# channel per class. use_checkpoint=True enables gradient checkpointing, trading
# recomputation for lower activation memory.
swin_config = dict(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=num_of_classes,
    feature_size=48,
    use_checkpoint=True,
)
model = SwinUNETR(**swin_config)

In [7]:
# Equal-weight Dice + cross-entropy loss. Integer labels are one-hot encoded inside
# the loss and predictions are passed through softmax.
loss_function = DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    lambda_dice=1,
    lambda_ce=1,
)

# AdamW with decoupled weight decay over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

In [8]:
# Wrap the model, optimizer and loss in the project Trainer, then call load_weights
# (presumably restores a checkpoint from model_folder if one exists — confirm in trainer.py).
# NOTE(review): num_samples=4 duplicates get_train_transform(num_samples=4); keep them in sync.
swin_trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    number_of_classes=num_of_classes,
    num_samples=4,
    max_epoch=1000,
    model_root_path=model_folder,
)
swin_trainer.load_weights()

In [11]:
# Let cuDNN benchmark and cache the fastest convolution algorithms. This pays off when
# input shapes stay fixed (presumably 96^3 crops here — confirm in Preprocessing).
torch.backends.cudnn.benchmark = True
# The training loop lives in the project's Trainer; max_epoch and num_samples were set
# when it was constructed.
swin_trainer.train(train_loader, val_loader)

Training (X / X Steps) (loss=X.X):   0%|                                                      | 0/1081 [00:00<?, ?it/s]

metatensor(4.7374, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=4.73743):   0%|                                     | 1/1081 [00:10<3:01:31, 10.08s/it]

metatensor(4.5330, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=4.53298):   0%|                                     | 2/1081 [00:10<1:23:18,  4.63s/it]Num foregrounds 0, Num backgrounds 1975669, unable to generate class balanced samples, setting `pos_ratio` to 0.


metatensor(4.3436, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=4.34364):   0%|                                       | 3/1081 [00:11<50:23,  2.81s/it]Num foregrounds 0, Num backgrounds 2582709, unable to generate class balanced samples, setting `pos_ratio` to 0.


metatensor(4.2770, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=4.27696):   0%|▏                                      | 4/1081 [00:12<34:57,  1.95s/it]

metatensor(4.1567, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=4.15671):   1%|▎                                      | 7/1081 [00:13<16:09,  1.11it/s]

metatensor(4.0244, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=4.02439):   1%|▎                                      | 8/1081 [00:13<15:12,  1.18it/s]

metatensor(4.2272, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=4.22721):   1%|▎                                      | 9/1081 [00:14<14:12,  1.26it/s]

metatensor(3.9478, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=3.94779):   1%|▎                                     | 10/1081 [00:14<13:28,  1.33it/s]Num foregrounds 0, Num backgrounds 3609327, unable to generate class balanced samples, setting `pos_ratio` to 0.


metatensor(3.8630, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=3.86303):   1%|▍                                     | 11/1081 [00:15<12:51,  1.39it/s]

metatensor(3.9375, device='cuda:0', grad_fn=<AliasBackward0>)


Training (0 / 1000 Steps) (loss=3.86303):   1%|▍                                     | 11/1081 [00:16<25:58,  1.46s/it]


KeyboardInterrupt: 

In [12]:
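# Presumably resumes training from the trainer's saved state (disabled).
# NOTE: a comment-only cell produces no output, so the value printed below
# ("0.7926767468452454 38") is stale output from an earlier run of this cell.
# swin_trainer.continue_train()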

0.7926767468452454 38


In [9]:
# Held-out test split, pre-processed with the validation transforms and cached
# relative to the working directory.
file_list_test = generate_file_path(root_path=f'{root_dataset}/test', label_name=label_name)
test_ds = PersistentDataset(
    data=file_list_test,
    transform=val_transforms,
    cache_dir=f'test_{task_name}',
)

# NOTE(review): `pad_to_shape` does not appear to be a parameter of pad_list_data_collate;
# with batch_size=1 the padding path is likely never taken — verify.
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    num_workers=0,
    collate_fn=lambda batch: pad_list_data_collate(batch, pad_to_shape=(96, 96, 96)),
)

In [10]:
# Evaluate on the test set. The saved output is a tuple of two per-class float lists;
# they are presumably Dice scores and a surface-distance metric — confirm in Trainer.test_new.
test_results = swin_trainer.test_new(test_loader)
test_results

Test:   0%|                                                                                     | 0/59 [00:00<?, ?it/s]None of the inputs have requires_grad=True. Gradients will be None
invalid value encountered in double_scalars
torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
Test (loss=0.42146): 100%|█████████████████████████████████████████████████████████████| 59/59 [15:20<00:00, 15.60s/it]


([0.895768834475794,
  0.9201364163609356,
  0.8860244164216201,
  0.8946778937836815,
  0.8392086691770765,
  0.8177011791745522,
  0.8418752117262379,
  0.7761074226873641,
  0.8100964495235019,
  0.8244390337758317,
  0.8728070978232708,
  0.9111084391081988,
  0.9112752588282829,
  0.9435089411592043,
  0.911306074623276,
  0.8385566770126578,
  0.792008334060771,
  0.7786670859170043,
  0.7955097125955667,
  0.7963218980480847,
  0.8099165194143754,
  0.8194729435963692,
  0.83535694010435,
  0.8760811499505977],
 [13.295742517532162,
  13.287056361847812,
  18.137136281267434,
  15.62910553601502,
  19.47151678960597,
  17.445150418608655,
  18.232191544636777,
  22.82304378002099,
  20.43339424720071,
  22.323808389604803,
  20.460160764781023,
  7.388015448552459,
  8.657086733270763,
  4.890540851659967,
  10.27575774355458,
  20.831213899995447,
  21.65831466475856,
  23.578638529506,
  21.38128449599841,
  18.88657198196915,
  19.79781790732165,
  22.44144949430811,
  24.035

In [12]:
# Re-run of the test-set evaluation.
# NOTE(review): the execution counts suggest this ran after the interrupted training
# cell (In [11]), so the in-memory weights likely differ from the earlier evaluation.
# Reload the weights before re-testing if the two results are meant to be compared.
swin_trainer.test_new(test_loader)

Test (loss=0.42856): 100%|█████████████████████████████████████████████████████████████| 59/59 [12:13<00:00, 12.42s/it]


([0.8916841495697062,
  0.8402921569976776,
  0.8299484132043737,
  0.7780722234195039,
  0.7444242679494906,
  0.7491268416524595,
  0.7974582916865116,
  0.732550762257316,
  0.7538079907952666,
  0.7982335572572347,
  0.8314457476529272,
  0.7783790644840728,
  0.9029091289481793,
  0.9147929507631168,
  0.8803205826082521,
  0.8112027163710938,
  0.7039527853086719,
  0.7051463845123552,
  0.7399483302465315,
  0.7806336009046854,
  0.7801203800611218,
  0.7874311940665506,
  0.8034207053669143,
  0.801944359747454],
 [12.729116730720122,
  26.664557646721004,
  24.613424678234825,
  31.25801129376187,
  31.45826116410334,
  27.688402418810316,
  24.982010895873756,
  27.562905129049962,
  27.258854698401908,
  25.208858425429053,
  26.307393152356482,
  32.56362327907559,
  8.951130809488141,
  15.605513686874975,
  17.722240146712384,
  20.87126247090366,
  30.714462603824128,
  29.489163503193122,
  27.59052297382283,
  19.488336849972605,
  22.15690148576117,
  26.3715864498454

In [None]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [17]:
# Rough GiB estimate, per test volume, of a full-volume prediction with one
# input-sized channel per output class.
# Uses num_of_classes instead of a hard-coded 25 so the estimate follows task_name.
# NOTE(review): assumes the logits share the input's dtype — TODO confirm.
bytes_per_gib = 1024 ** 3
for test_batch in test_loader:
    test_image = test_batch['image']
    estimated_gib = test_image.element_size() * test_image.numel() * num_of_classes / bytes_per_gib
    print(test_image.shape, estimated_gib)

torch.Size([1, 1, 231, 175, 301]) 1.1332263238728046
torch.Size([1, 1, 300, 236, 301]) 1.984722912311554
torch.Size([1, 1, 240, 205, 301]) 1.3792142271995544
torch.Size([1, 1, 286, 225, 301]) 1.8039112910628319
torch.Size([1, 1, 293, 224, 301]) 1.8398493528366089
torch.Size([1, 1, 221, 183, 217]) 0.8173408918082714
torch.Size([1, 1, 202, 159, 244]) 0.7298581302165985
torch.Size([1, 1, 252, 198, 301]) 1.3987250626087189
torch.Size([1, 1, 286, 231, 287]) 1.765875332057476
torch.Size([1, 1, 251, 218, 247]) 1.258714683353901
torch.Size([1, 1, 99, 122, 122]) 0.13723187148571014
torch.Size([1, 1, 271, 166, 301]) 1.2610839679837227
torch.Size([1, 1, 285, 164, 285]) 1.2406054884195328
torch.Size([1, 1, 272, 255, 301]) 1.9443556666374207
torch.Size([1, 1, 255, 196, 301]) 1.4010798186063766
torch.Size([1, 1, 240, 202, 301]) 1.3590306043624878
torch.Size([1, 1, 213, 203, 238]) 0.9584130719304085
torch.Size([1, 1, 212, 197, 301]) 1.1707622557878494
torch.Size([1, 1, 233, 219, 301]) 1.4304301701486

KeyboardInterrupt: 

In [None]:
# Free memory (GiB) on the current CUDA device.
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF (e.g. "max_split_size_mb:1024") is read when the
# CUDA caching allocator initialises, so setting it at this point in the notebook would
# have no effect. Export it before starting the kernel instead.
free_bytes, _total_bytes = torch.cuda.mem_get_info()
free_bytes / (1024 ** 3)


In [14]:
# Release cached CUDA memory, then run the project's multi-device evaluation.
# NOTE(review): the saved output prints a [1, 1, 331, 321, 603] tensor on 'cpu',
# which suggests large volumes are handled off-GPU — confirm in Trainer.test_multi_device.
torch.cuda.empty_cache()
swin_trainer.test_multi_device(test_loader)

Test:   0%|                                                                                     | 0/59 [00:00<?, ?it/s]None of the inputs have requires_grad=True. Gradients will be None
invalid value encountered in double_scalars
torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
Test (loss=0.26597):  86%|████████████████████████████████████████████████████▋        | 51/59 [08:48<01:11,  8.89s/it]

torch.Size([1, 1, 331, 321, 603]) cpu


Test (loss=0.26597):  86%|████████████████████████████████████████████████████▋        | 51/59 [13:40<02:08, 16.08s/it]


KeyboardInterrupt: 

In [None]:
# Evaluate through the project's multiple_test helper. It receives the raw test file
# list and transforms rather than a DataLoader.
# NOTE(review): this cell has no execution count in the saved notebook, so it was not run here.
swin_trainer.multiple_test(
    file_list_test=file_list_test,
    task_name=task_name,
    val_transforms=val_transforms)

In [None]:
# Scan training batches and stop at the first one whose label is entirely zero.
# The original `break` sat outside the `if`, so only the first batch was ever checked.
# NOTE(review): each iteration pays the full transform cost, so this can be slow
# when no such batch exists.
for train_batch in train_loader:
    train_label = train_batch['label']
    if train_label.max() == 0:
        print(train_label.min(), train_label.max())
        break
else:
    print('no all-zero label batch found')