In [1]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from models.unet_from_scratch.UNET_MODEL_FROM_SCRATCH import UNET_FROM_SCRATCH
from models.unet_from_scratch.utils import get_loaders, save_checkpoint, print_measurements, save_predictions_as_imgs, \
    load_checkpoint
from models.unet_from_scratch.visual import DatasetViewer

LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 240
IMAGE_WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = False

# Raw strings: with ordinary literals the Windows separators form escape sequences.
# "\C", "\d", "\i" ... are invalid escapes (SyntaxWarning since Python 3.12), and
# "\v" in "\validation..." is a vertical tab, which previously had to be written "\\v".
# The resulting path values are unchanged.
DATA_DIR = r'E:\CVDL\MyApp\models\data'  # dataset root; change here to relocate the data
TRAIN_IMG_DIR = DATA_DIR + r'\images'
TRAIN_MASK_DIR = DATA_DIR + r'\labels'
VAL_IMG_DIR = DATA_DIR + r'\validationimages'
VAL_MASK_DIR = DATA_DIR + r'\validationlabels'

def _pixel_scaling_tail():
    """Build fresh transforms that scale pixels and convert the image to a tensor.

    With mean 0, std 1 and max_pixel_value 255, ``A.Normalize`` amounts to dividing
    pixel values by 255. New instances are created on every call so the train and val
    pipelines never share transform objects.
    """
    scale = A.Normalize(mean=[0.0, 0.0, 0.0],
                        std=[1.0, 1.0, 1.0],
                        max_pixel_value=255.0)
    return [scale, ToTensorV2()]


# Training pipeline: resize, random geometric augmentation, then scaling + tensor conversion.
train_transfrom = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Rotate(limit=35, p=1.0),   # p=1.0: a rotation is applied to every sample
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    *_pixel_scaling_tail(),
])

# Validation pipeline: deterministic resize + scaling only (no augmentation).
val_transfrom = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    *_pixel_scaling_tail(),
])


In [2]:
# Build the train/val DataLoaders from the directories and transforms configured above.
# Arguments are passed positionally, in the same order as the original call.
loader_args = (
    TRAIN_IMG_DIR, TRAIN_MASK_DIR,    # training images / masks
    VAL_IMG_DIR, VAL_MASK_DIR,        # validation images / masks
    BATCH_SIZE,
    train_transfrom, val_transfrom,   # train / val transform pipelines
    NUM_WORKERS, PIN_MEMORY,
)
train_loader, val_loader = get_loaders(*loader_args)

In [3]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch over ``loader``, using mixed precision when on CUDA.

    :param loader: iterable yielding ``(data, targets)`` batches; targets presumably
        are (N, H, W) masks, and a channel dimension is added here (TODO confirm).
    :param model: network mapping ``data`` to logits; switched to train mode here.
    :param optimizer: how the network will be updated based on the loss function
    :param loss_fn: quantity that will be minimized during training
        (called as ``loss_fn(pred, targets)``)
    :param scaler: gradient scaler used for the backward pass and optimizer step
        (``scale(loss).backward()``, ``step(optimizer)``, ``update()``)
    :return: mean of the per-batch loss values for this epoch, or ``nan`` if
        ``loader`` yielded no batches
    """
    # Evaluation helpers run between epochs may leave the model in eval mode;
    # make sure dropout / batch-norm behave as in training for this epoch.
    model.train()
    # Autocast only has an effect on CUDA. Disabling it explicitly on CPU keeps the
    # same behaviour as before but avoids the "CUDA is not available" warning.
    use_amp = torch.device(DEVICE).type == "cuda"

    running_loss = 0.0
    num_batches = 0
    loop = tqdm(loader)
    for data, targets in loop:
        data = data.to(device=DEVICE)
        # Add a channel dimension so targets line up with single-channel logits, and
        # make sure they are floating point, which BCEWithLogitsLoss requires.
        targets = targets.unsqueeze(1).float().to(device=DEVICE)

        # forward
        with torch.autocast(device_type="cuda", enabled=use_amp):
            pred = model(data)
            loss = loss_fn(pred, targets)

        # backward through the scaler (loss scaling for mixed-precision gradients)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    return running_loss / num_batches if num_batches else float("nan")

## UNET FROM SCRATCH


In [4]:
#########################
# Train UNET_FROM_SCRATCH for NUM_EPOCHS epochs, checkpointing and evaluating after each.
model = UNET_FROM_SCRATCH(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()  # applies sigmoid on output
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Loss scaling is only meaningful with CUDA mixed precision; disabling it on CPU turns
# scale/step/update into plain pass-throughs without the "CUDA not available" warning.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
for epoch in range(NUM_EPOCHS):
    # 1-based epoch number, so the final epoch prints as "Epoch 10 / 10".
    print(f"Epoch {epoch + 1} / {NUM_EPOCHS}")
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    # NOTE(review): the same file is overwritten every epoch, so it holds the LAST
    # epoch's weights, not necessarily the best-scoring ones.
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }
    save_checkpoint(checkpoint, 'check_unet_scratch.pth.tar')

    # check acc: logs pixel accuracy, dice and "ioc1" on the validation set.
    # NOTE(review): the logged "ioc1" values exceed 1, which is impossible for an IoU;
    # print_measurements probably sums per-batch IoU without averaging - verify.
    print_measurements(val_loader, model, device=DEVICE)

    # print examples
    save_predictions_as_imgs(val_loader, model, folder="unet_from_scratch_saved_photos/", device=DEVICE)



Epoch 0 / 10


100%|██████████| 7/7 [01:33<00:00, 13.30s/it, loss=0.535]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 3908632/5760000 with acc 67.86
Got dice_score:0.7227222323417664 
Got ioc1:2.468986307714673 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 1 / 10


100%|██████████| 7/7 [01:30<00:00, 12.99s/it, loss=0.475]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 3529306/5760000 with acc 61.27
Got dice_score:0.46788859367370605 
Got ioc1:1.5058266716433062 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 2 / 10


100%|██████████| 7/7 [01:34<00:00, 13.48s/it, loss=0.459]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 2785898/5760000 with acc 48.37
Got dice_score:0.06620122492313385 
Got ioc1:0.1834574312210464 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 3 / 10


100%|██████████| 7/7 [01:40<00:00, 14.33s/it, loss=0.442]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 2876878/5760000 with acc 49.95
Got dice_score:0.12496374547481537 
Got ioc1:0.3359548923891483 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 4 / 10


100%|██████████| 7/7 [01:37<00:00, 13.86s/it, loss=0.378]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 3958495/5760000 with acc 68.72
Got dice_score:0.6235316395759583 
Got ioc1:2.100407048575176 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 5 / 10


100%|██████████| 7/7 [01:36<00:00, 13.76s/it, loss=0.368]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 4512683/5760000 with acc 78.35
Got dice_score:0.7699885368347168 
Got ioc1:2.6580700344356276 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 6 / 10


100%|██████████| 7/7 [01:36<00:00, 13.79s/it, loss=0.375]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 4683291/5760000 with acc 81.31
Got dice_score:0.8241996765136719 
Got ioc1:2.864524281035736 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 7 / 10


100%|██████████| 7/7 [01:35<00:00, 13.67s/it, loss=0.338]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 4887653/5760000 with acc 84.86
Got dice_score:0.8536372184753418 
Got ioc1:2.9738648985479714 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 8 / 10


100%|██████████| 7/7 [01:36<00:00, 13.73s/it, loss=0.335]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 4938808/5760000 with acc 85.74
Got dice_score:0.8622197508811951 
Got ioc1:3.0049388583624106 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...
Epoch 9 / 10


100%|██████████| 7/7 [01:36<00:00, 13.79s/it, loss=0.356]


=>saving checkpoint to check_unet_scratch.pth.tar
Got 4854015/5760000 with acc 84.27
Got dice_score:0.8372305631637573 
Got ioc1:2.8900847829186023 
Saving images to.... unet_from_scratch_saved_photos/
Finish saving images...


In [5]:
## Validation: UNET_FROM_SCRATCH (evaluated at 160x240; training above used 240x240)

In [6]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from models.unet_from_scratch.utils import get_loaders, save_checkpoint, print_measurements, save_predictions_as_imgs, \
    load_checkpoint

LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = 2
# NOTE(review): the UNET_FROM_SCRATCH training run above used 240x240 inputs. This cell
# switches to 160x240, so its validation (and the ResNet-UNET training that reuses the
# loaders built below) runs at a different resolution. Confirm this is intentional.
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = True  # restore the saved checkpoints before evaluating

# Raw strings: with ordinary literals the Windows separators form escape sequences.
# "\C", "\d", "\l" ... are invalid escapes (SyntaxWarning since Python 3.12), and
# "\v" in "\validation..." is a vertical tab that previously had to be written "\\v".
# The resulting path values are unchanged.
DATA_DIR = r'E:\CVDL\MyApp\models\data'  # dataset root; change here to relocate the data
TRAIN_IMG_DIR = DATA_DIR + r'\images'
TRAIN_MASK_DIR = DATA_DIR + r'\labels'
VAL_IMG_DIR = DATA_DIR + r'\validationimages'
VAL_MASK_DIR = DATA_DIR + r'\validationlabels'

# Training pipeline: resize, random geometric augmentation, then scale pixels
# (mean=0, std=1, max_pixel_value=255 divides by 255) and convert to a tensor.
train_transfrom = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Rotate(limit=35, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

# Validation pipeline: same resize and scaling, but no augmentation.
val_transfrom = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

# Rebuild the DataLoaders with this cell's configuration (positional order unchanged).
loader_args = (
    TRAIN_IMG_DIR, TRAIN_MASK_DIR,
    VAL_IMG_DIR, VAL_MASK_DIR,
    BATCH_SIZE,
    train_transfrom, val_transfrom,
    NUM_WORKERS, PIN_MEMORY,
)
train_loader, val_loader = get_loaders(*loader_args)



In [7]:
# Rebuild the scratch UNET and, if LOAD_MODEL is set, restore the weights that the
# training loop saved, then report metrics and write sample predictions.
model_unet = UNET_FROM_SCRATCH(in_channels=3, out_channels=1).to(DEVICE)
if LOAD_MODEL:
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only host.
    # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
    load_checkpoint(torch.load('check_unet_scratch.pth.tar', map_location=DEVICE), model_unet)
# Pass the device explicitly, as the training loop does, instead of relying on the
# helper's default.
print_measurements(val_loader, model_unet, device=DEVICE)
save_predictions_as_imgs(val_loader, model_unet, folder="unet_validate/", device=DEVICE)



Got 3188987/3840000 with acc 83.05
Got dice_score:0.8207998871803284 
Got ioc1:2.81063116176925 
Saving images to.... unet_validate/
Finish saving images...


In [8]:
from models.unet_from_scratch.model import MyUNET, TrucnateResNET_UNET

# Train the truncated-ResNET UNET with the same loop used for UNET_FROM_SCRATCH.
# It reuses train_loader / val_loader from the most recent get_loaders call above
# (the validation config cell, IMAGE_HEIGHT = 160).
model = TrucnateResNET_UNET().to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()  # applies sigmoid on output
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Loss scaling is only meaningful with CUDA mixed precision; disabling it on CPU turns
# scale/step/update into plain pass-throughs without the "CUDA not available" warning.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
for epoch in range(NUM_EPOCHS):
    # 1-based epoch number, so the final epoch prints as "Epoch 10 / 10".
    print(f"Epoch {epoch + 1} / {NUM_EPOCHS}")
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    # NOTE(review): overwritten every epoch, so this holds the LAST epoch's weights,
    # not necessarily the best-scoring ones.
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }
    save_checkpoint(checkpoint, 'check_resnet_unet.pth.tar')

    # check acc: logs pixel accuracy, dice and "ioc1" on the validation set.
    # NOTE(review): logged "ioc1" values exceed 1, which is impossible for an IoU;
    # print_measurements probably sums per-batch IoU without averaging - verify.
    print_measurements(val_loader, model, device=DEVICE)

    # print examples
    save_predictions_as_imgs(val_loader, model, folder="unet_resnet_saved_photos/", device=DEVICE)



Epoch 0 / 10


100%|██████████| 7/7 [01:11<00:00, 10.17s/it, loss=0.54] 


=>saving checkpoint to check_resnet_unet.pth.tar
Got 2413778/3840000 with acc 62.86
Got dice_score:0.5606656074523926 
Got ioc1:1.8801244402213424 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 1 / 10


100%|██████████| 7/7 [01:10<00:00, 10.04s/it, loss=0.392]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 2175172/3840000 with acc 56.65
Got dice_score:0.3392743468284607 
Got ioc1:1.1420855379039048 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 2 / 10


100%|██████████| 7/7 [01:09<00:00,  9.99s/it, loss=0.382]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 2843467/3840000 with acc 74.05
Got dice_score:0.6803023219108582 
Got ioc1:2.3069720931711197 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 3 / 10


100%|██████████| 7/7 [01:09<00:00,  9.94s/it, loss=0.382]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 3382735/3840000 with acc 88.09
Got dice_score:0.8784723877906799 
Got ioc1:3.0422105143383327 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 4 / 10


100%|██████████| 7/7 [01:10<00:00, 10.01s/it, loss=0.355]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 3481290/3840000 with acc 90.66
Got dice_score:0.9043213129043579 
Got ioc1:3.135559102166642 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 5 / 10


100%|██████████| 7/7 [01:09<00:00,  9.92s/it, loss=0.339]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 3515468/3840000 with acc 91.55
Got dice_score:0.9153851270675659 
Got ioc1:3.1812779997898604 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 6 / 10


100%|██████████| 7/7 [01:09<00:00,  9.95s/it, loss=0.313]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 3503086/3840000 with acc 91.23
Got dice_score:0.904784619808197 
Got ioc1:3.1404488273229068 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 7 / 10


100%|██████████| 7/7 [01:09<00:00,  9.98s/it, loss=0.347]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 3484602/3840000 with acc 90.74
Got dice_score:0.9061997532844543 
Got ioc1:3.1505453967061587 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 8 / 10


100%|██████████| 7/7 [01:10<00:00, 10.01s/it, loss=0.282]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 3547097/3840000 with acc 92.37
Got dice_score:0.9194938540458679 
Got ioc1:3.1943440784418065 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...
Epoch 9 / 10


100%|██████████| 7/7 [01:10<00:00, 10.06s/it, loss=0.299]


=>saving checkpoint to check_resnet_unet.pth.tar
Got 3560640/3840000 with acc 92.73
Got dice_score:0.9258744120597839 
Got ioc1:3.228239647088128 
Saving images to.... unet_resnet_saved_photos/
Finish saving images...


In [9]:
# Rebuild the truncated-ResNET UNET and, if LOAD_MODEL is set, restore the weights
# saved by its training loop, then report metrics and write sample predictions.
model_resnet_unet = TrucnateResNET_UNET().to(DEVICE)
if LOAD_MODEL:
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only host.
    # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
    load_checkpoint(torch.load('check_resnet_unet.pth.tar', map_location=DEVICE), model_resnet_unet)
# Pass the device explicitly, as the training loop does, instead of relying on the
# helper's default.
print_measurements(val_loader, model_resnet_unet, device=DEVICE)
save_predictions_as_imgs(val_loader, model_resnet_unet, folder="validate/", device=DEVICE)



Got 3560640/3840000 with acc 92.73
Got dice_score:0.9258744120597839 
Got ioc1:3.228239647088128 
Saving images to.... validate/
Finish saving images...
