In [1]:
from src.constants import *
from src.model.my_model import MyModel
from src.utils.metrics import num_parameters
from src.model.baselines import SWINUNETR, UNet
import os
import torch

In [2]:
device = "cuda"

model = MyModel(in_channels=1,
                out_channels=3,
                embed_dim=256,
                skip_transformer=False,
                channels=(4, 16, 32, 32, 32),
                transformer_channels=(2, 8, 16, 16, 16),
                patch_size=8
                ).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATES[0])
loss_fn = torch.nn.MSELoss()

In [3]:
example = torch.rand(size=(4, 1, 128, 256, 256)).to(device)
mask = torch.rand(size=(4, 3, 128, 256, 256)).to(device)

In [4]:
%%time
out = model(example)
loss = loss_fn(out, mask)
loss.backward()
optimizer.step()
print(out.shape)

torch.Size([4, 3, 128, 256, 256])
CPU times: user 1.32 s, sys: 564 ms, total: 1.88 s
Wall time: 1.96 s


In [5]:
os.system('nvidia-smi')

Sun Apr 23 21:26:13 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A40          On   | 00000000:3B:00.0 Off |                    0 |
|  0%   53C    P0   260W / 300W |  35683MiB / 46068MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A40          On   | 00000000:5E:00.0 Off |                    0 |
|  0%   35C    P8    32W / 300W |      0MiB / 46068MiB |      0%      Default |
|       

0

In [5]:
print(num_parameters(model))
print(num_parameters(SWINUNETR))
print(num_parameters(UNet))


20277079
15703029
1203362


In [6]:
from src.data.data_loaders import get_loaders

# train_loader, test_loader = get_loaders()

<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.
<class 'monai.transforms.utility.array.AsChannelFirst'>: Class `AsChannelFirst` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.
Loading dataset: 100%|██████████| 70/70 [11:30<00:00,  9.87s/it]  
Loading dataset: 100%|██████████| 10/10 [01:04<00:00,  6.42s/it]


In [2]:
for d in test_loader:
    display(d['img'].shape)

torch.Size([1, 1, 743, 317, 317])

torch.Size([1, 1, 582, 300, 300])

torch.Size([1, 1, 510, 295, 295])

torch.Size([1, 1, 806, 317, 317])

torch.Size([1, 1, 556, 387, 387])

torch.Size([1, 1, 799, 334, 334])

torch.Size([1, 1, 687, 309, 309])

torch.Size([1, 1, 808, 317, 317])

torch.Size([1, 1, 560, 300, 300])

torch.Size([1, 1, 740, 292, 292])

In [28]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
loss_fn = monai.losses.DiceCELoss()

torch.cuda.empty_cache()

model = MyModel(in_channels=1,
                out_channels=3,
                embed_dim=256,
                skip_transformer=False,
                channels=(4, 16, 32, 32, 32),
                transformer_channels=(2, 8, 8, 8, 8),
                patch_size=8
                ).to(DEVICE)



In [29]:
from monai.networks import one_hot
from tqdm import tqdm


def train_single_epoch_local(model, optimizer, train_loader):
    losses = []

    # TODO: make this dataloading fast, this seems to be the bottleneck
    for d in train_loader:

        img = d['img'].to(DEVICE)
        mask = d['mask'].to(DEVICE)

        optimizer.zero_grad()

        outputs = model(img)

        outputs = torch.softmax(outputs, dim=1)
        mask = one_hot(mask, num_classes=3)

        loss = loss_fn(outputs, mask)
        loss.backward()
        optimizer.step()

        losses.append(loss.detach().cpu().item())
    print((sum(losses) / len(losses)))

In [None]:
for i in tqdm(range(20)):
    train_single_epoch_local(model=model, optimizer=optimizer, train_loader=train_loader)


  5%|▌         | 1/20 [00:45<14:18, 45.17s/it]

1.9394295249666487


 10%|█         | 2/20 [01:30<13:34, 45.26s/it]

1.9401933942522323


 15%|█▌        | 3/20 [02:15<12:48, 45.22s/it]

1.9392336232321603


 20%|██        | 4/20 [03:01<12:05, 45.31s/it]

1.9395035914012364


 25%|██▌       | 5/20 [03:46<11:19, 45.33s/it]

1.9394272259303502


 30%|███       | 6/20 [04:31<10:33, 45.24s/it]

1.9400581666401455


 35%|███▌      | 7/20 [05:16<09:46, 45.15s/it]

1.9398419380187988


 40%|████      | 8/20 [06:01<09:00, 45.08s/it]

1.9393507480621337


 45%|████▌     | 9/20 [06:46<08:14, 45.00s/it]

1.9401801790509905


 50%|█████     | 10/20 [07:31<07:29, 44.97s/it]

1.939806706564767


 55%|█████▌    | 11/20 [08:16<06:45, 45.02s/it]

1.9384404829570225


 60%|██████    | 12/20 [09:01<05:59, 45.00s/it]

1.9390227454049247


 65%|██████▌   | 13/20 [09:45<05:14, 44.92s/it]

1.9394481454576764


 70%|███████   | 14/20 [10:31<04:30, 45.01s/it]

1.9389794724328178


 75%|███████▌  | 15/20 [11:16<03:44, 44.98s/it]

1.9393702813557216


 80%|████████  | 16/20 [12:01<02:59, 44.99s/it]

1.9401637758527484
