In [5]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from models.unet_from_scratch.UNET_MODEL_FROM_SCRATCH import UNET_FROM_SCRATCH
from models.unet_from_scratch.utils import get_loaders, save_checkpoint, print_measurements, save_predictions_as_imgs, \
    load_checkpoint
from models.unet_from_scratch.visual import DatasetViewer

LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 7
NUM_WORKERS = 2
IMAGE_HEIGHT = 240
IMAGE_WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = False

TRAIN_IMG_DIR = 'E:\CVDL\MyApp\models\data\images'
TRAIN_MASK_DIR = 'E:\CVDL\MyApp\models\data\labels'
VAL_IMG_DIR = 'E:\CVDL\MyApp\models\data\\validationimages'
VAL_MASK_DIR = 'E:\CVDL\MyApp\models\data\\validationlabels'

train_transfrom = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(mean=[0.0, 0.0, 0.0],
                    std=[1.0, 1.0, 1.0],
                    max_pixel_value=255.0
                    ),
        ToTensorV2()
    ]
)

val_transfrom = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=[0.0, 0.0, 0.0],
                    std=[1.0, 1.0, 1.0],
                    max_pixel_value=255.0
                    ),
        ToTensorV2()
    ]
)


In [6]:
train_loader, val_loader = get_loaders(TRAIN_IMG_DIR,
                                           TRAIN_MASK_DIR,
                                           VAL_IMG_DIR,
                                           VAL_MASK_DIR,
                                           BATCH_SIZE,
                                           train_transfrom,
                                           val_transfrom,
                                           NUM_WORKERS,
                                           PIN_MEMORY
                                           )

In [7]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """
    :param loader:
    :param model:
    :param optimizer: - how the network will be updated based on the loss function
    :param loss_fn: - quantity that will be minimized during training
    :param scaler:
    :return:
    """
    loop = tqdm(loader)
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.unsqueeze(1).to(device=DEVICE)
        # forward
        with torch.cuda.amp.autocast():
            pred = model(data)
            loss = loss_fn(pred, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

## UNET FROM SCRATCH


In [None]:
#########################
model = UNET_FROM_SCRATCH(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()  # applies sigmoid on output
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1} / {NUM_EPOCHS}")
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }
    save_checkpoint(checkpoint, 'check_unet_scratch.pth.tar')
    print_measurements(train_loader, model, device=DEVICE)

    # check acc
    # print examples
    if epoch == NUM_EPOCHS -1:
        save_predictions_as_imgs(train_loader, model, folder="unet_from_scratch_saved_photos/", device=DEVICE)

print("End unet training")




Epoch 1 / 7


100%|██████████| 94/94 [24:04<00:00, 15.37s/it, loss=0.415]


=>saving checkpoint to check_unet_scratch.pth.tar
printing measure
Got 70147379/86400000 with acc 81.19
Got dice_score:0.8357783555984497 
Got ioc1:0.41788917779922485 
Epoch 2 / 7


100%|██████████| 94/94 [23:56<00:00, 15.28s/it, loss=0.383]


=>saving checkpoint to check_unet_scratch.pth.tar
printing measure
Got 71081699/86400000 with acc 82.27
Got dice_score:0.8455772995948792 
Got ioc1:0.4227886497974396 
Epoch 3 / 7


100%|██████████| 94/94 [23:54<00:00, 15.26s/it, loss=0.368]


=>saving checkpoint to check_unet_scratch.pth.tar
printing measure
Got 72207125/86400000 with acc 83.57
Got dice_score:0.8583813905715942 
Got ioc1:0.4291906952857971 
Epoch 4 / 7


100%|██████████| 94/94 [23:49<00:00, 15.20s/it, loss=0.351]


=>saving checkpoint to check_unet_scratch.pth.tar
printing measure
Got 72670187/86400000 with acc 84.11
Got dice_score:0.8571197986602783 
Got ioc1:0.42855989933013916 
Epoch 5 / 7


 37%|███▋      | 35/94 [09:02<15:25, 15.68s/it, loss=0.34] 

## Validation UNET_FROM_SCRATCH

In [None]:
# import torch
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# from tqdm import tqdm
# import torch.nn as nn
# import torch.optim as optim
# from models.unet_from_scratch.utils import get_loaders, save_checkpoint, print_measurements, save_predictions_as_imgs, \
#     load_checkpoint
#
# LEARNING_RATE = 1e-4
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# BATCH_SIZE = 16
# NUM_EPOCHS = 10
# NUM_WORKERS = 2
# IMAGE_HEIGHT = 160
# IMAGE_WIDTH = 240
# PIN_MEMORY = True
# LOAD_MODEL = True
#
# TRAIN_IMG_DIR = 'E:\CVDL\MyApp\models\data\\images'
# TRAIN_MASK_DIR = 'E:\CVDL\MyApp\models\data\labels'
# VAL_IMG_DIR = 'E:\CVDL\MyApp\models\data\\validationimages'
# VAL_MASK_DIR = 'E:\CVDL\MyApp\models\data\\validationlabels'
#
# train_transfrom = A.Compose(
#     [
#         A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
#         A.Rotate(limit=35, p=1.0),
#         A.HorizontalFlip(p=0.5),
#         A.VerticalFlip(p=0.1),
#         A.Normalize(mean=[0.0, 0.0, 0.0],
#                     std=[1.0, 1.0, 1.0],
#                     max_pixel_value=255.0
#                     ),
#         ToTensorV2()
#     ]
# )
#
# val_transfrom = A.Compose(
#     [
#         A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
#         A.Normalize(mean=[0.0, 0.0, 0.0],
#                     std=[1.0, 1.0, 1.0],
#                     max_pixel_value=255.0
#                     ),
#         ToTensorV2()
#     ]
# )
# train_loader, val_loader = get_loaders(TRAIN_IMG_DIR,
#                                        TRAIN_MASK_DIR,
#                                        VAL_IMG_DIR,
#                                        VAL_MASK_DIR,
#                                        BATCH_SIZE,
#                                        train_transfrom,
#                                        val_transfrom,
#                                        NUM_WORKERS,
#                                        PIN_MEMORY
#                                        )
#


In [None]:
# model_unet = UNET_FROM_SCRATCH(in_channels=3, out_channels=1).to(DEVICE)
# if LOAD_MODEL:
#     load_checkpoint(torch.load('check_unet_scratch.pth.tar'), model_unet)
# print_measurements(val_loader, model_unet)
# save_predictions_as_imgs(val_loader, model_unet, folder="validate_unet/", device=DEVICE)
#


## RESNET_UNET Training

In [None]:
from models.unet_from_scratch.model import TrucnateResNET_UNET

model = TrucnateResNET_UNET().to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()  # applies sigmoid on output
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1} / {NUM_EPOCHS}")
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }
    save_checkpoint(checkpoint, 'check_resnet_unet.pth.tar')
    print_measurements(val_loader, model, device=DEVICE)

    # check acc
    # print examples
    if epoch == NUM_EPOCHS -1:
        save_predictions_as_imgs(val_loader, model, folder="unet_resnet_saved_photos/", device=DEVICE)
print("End resnet_unet training")


In [None]:
# model_resnet_unet = TrucnateResNET_UNET().to(DEVICE)
# if LOAD_MODEL:
#     load_checkpoint(torch.load('check_resnet_unet.pth.tar'), model_resnet_unet)
# print_measurements(val_loader, model_resnet_unet)
# save_predictions_as_imgs(val_loader, model_resnet_unet, folder="validate_resnet/", device=DEVICE)
#
