In [5]:
import albumentations as A
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

from models.unet_from_scratch.UNET_MODEL_FROM_SCRATCH import UNET_FROM_SCRATCH
from models.unet_from_scratch.model import TrucnateResNET_UNET
from models.unet_from_scratch.utils import (
    get_loaders,
    load_checkpoint,
    print_measurements,
    save_checkpoint,
    save_predictions_as_imgs,
)
from models.unet_from_scratch.visual import DatasetViewer

# --- Configuration: hyper-parameters, device and dataset locations -----------
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 240
IMAGE_WIDTH = 240
# Pinned host memory only speeds up host->GPU copies. Presumably get_loaders
# forwards this to DataLoader(pin_memory=...); on a CPU-only machine PyTorch
# warns and ignores it, so request it only when CUDA is available.
PIN_MEMORY = DEVICE == "cuda"
# When True, the evaluation cells restore weights from the saved .pth.tar
# checkpoints before measuring.
LOAD_MODEL = True

# Windows paths are raw strings. In a normal literal, sequences such as "\C",
# "\d" or "\l" are invalid escapes (SyntaxWarning on Python 3.12+) and "\v"
# silently becomes a vertical-tab character.
# NOTE(review): machine-specific absolute path; make it configurable before
# running this notebook on another machine.
DATA_ROOT = r"E:\CVDL\MyApp\models\data"
TRAIN_IMG_DIR = DATA_ROOT + r"\images"
TRAIN_MASK_DIR = DATA_ROOT + r"\labels"
VAL_IMG_DIR = DATA_ROOT + r"\validationimages"
VAL_MASK_DIR = DATA_ROOT + r"\validationlabels"

def _build_transform(augmentations=()):
    """Assemble an albumentations pipeline shared by training and validation.

    Steps, in order: resize to IMAGE_HEIGHT x IMAGE_WIDTH, then any
    `augmentations` given, then Normalize (mean 0, std 1, max_pixel_value 255,
    i.e. pixel values divided by 255), then conversion to a torch tensor.

    Args:
        augmentations: iterable of albumentations transforms inserted between
            the resize and the normalization. Empty by default.

    Returns:
        An ``A.Compose`` instance.
    """
    steps = [A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)]
    steps.extend(augmentations)
    steps.append(
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0)
    )
    steps.append(ToTensorV2())
    return A.Compose(steps)


# The misspelled names ("transfrom") are kept because the loader cell uses them.
# Training pipeline: always rotate (up to 35 degrees), plus random flips.
train_transfrom = _build_transform(
    [
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
    ]
)
# Validation pipeline: resize + normalization only, no random augmentation.
val_transfrom = _build_transform()


In [6]:
# Build the train/validation DataLoaders. get_loaders is already imported with
# the other project helpers at the top of the notebook, so it is not
# re-imported here. Arguments are positional; keep their order in sync with
# get_loaders' signature.
train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    train_transfrom,
    val_transfrom,
    NUM_WORKERS,
    PIN_MEMORY,
)

## UNET

In [7]:
# Evaluate the from-scratch UNET on the validation loader and write its
# prediction images to validate_unet/.
model_unet = UNET_FROM_SCRATCH(in_channels=3, out_channels=1).to(DEVICE)
if LOAD_MODEL:
    # map_location lets a checkpoint that was saved from GPU tensors load when
    # DEVICE is "cpu"; without it torch.load tries to restore the tensors onto
    # CUDA and raises on a machine without CUDA.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    load_checkpoint(torch.load('check_unet_scratch.pth.tar', map_location=DEVICE), model_unet)
# NOTE(review): print_measurements receives no device argument; confirm its
# default device matches DEVICE.
# NOTE(review): in the recorded output "ioc1" is exactly dice_score / 2
# (0.87500 -> 0.43750). A true IoU would be dice / (2 - dice), so verify the
# IoU formula in print_measurements.
print_measurements(val_loader, model_unet)
save_predictions_as_imgs(val_loader, model_unet, folder="validate_unet/", device=DEVICE)



printing measure
Got 17131791/19987200 with acc 85.71
Got dice_score:0.8750036358833313 
Got ioc1:0.43750181794166565 
Saving images to.... validate_unet/
Finish saving images...


## RESNET_UNET

In [8]:
# Evaluate the ResNet-based UNET on the same validation loader and write its
# prediction images to validate_resnet/. TrucnateResNET_UNET is imported with
# the other project imports at the top of the notebook.
model_resnet_unet = TrucnateResNET_UNET().to(DEVICE)
if LOAD_MODEL:
    # map_location lets a checkpoint that was saved from GPU tensors load when
    # DEVICE is "cpu"; without it torch.load tries to restore the tensors onto
    # CUDA and raises on a machine without CUDA.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    load_checkpoint(torch.load('check_resnet_unet.pth.tar', map_location=DEVICE), model_resnet_unet)
# NOTE(review): print_measurements receives no device argument; confirm its
# default device matches DEVICE.
# NOTE(review): in the recorded output "ioc1" is exactly dice_score / 2
# (0.90654 -> 0.45327). A true IoU would be dice / (2 - dice), so verify the
# IoU formula in print_measurements.
print_measurements(val_loader, model_resnet_unet)
save_predictions_as_imgs(val_loader, model_resnet_unet, folder="validate_resnet/", device=DEVICE)



printing measure
Got 18000625/19987200 with acc 90.06
Got dice_score:0.9065384268760681 
Got ioc1:0.45326921343803406 
Saving images to.... validate_resnet/
Finish saving images...
