In [141]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import torchvision.transforms.functional as TF

In [142]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import albumentations as A

from PIL import Image 
from torch.utils.data import Dataset, DataLoader, random_split
from glob import glob
from albumentations.pytorch import ToTensorV2

<h3>A little dataset/dataloading processing??</h3>

In [160]:

class DRIVEDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir 

    def __len__(self):
        return len(self.image_dir)
        
    def __getitem__(self, index):
        image = cv2.imread(self.image_dir[index], cv2.IMREAD_COLOR)
        image = image / 255.0
        image = np.transpose(image, (2, 0, 1))
        image = image.astype(np.float32)
        image = torch.from_numpy(image)


        mask = cv2.imread(self.mask_dir[index], cv2.IMREAD_COLOR)
        print(type(mask))
        mask = mask / 255.0

        #mask = np.expand_dims(mask, axis=0)
        mask = np.transpose(mask, (2,1,0))

        mask = mask.astype(np.float32)
        mask = torch.from_numpy(mask)
        
        
        return image, mask

In [161]:

TRAIN_IMAGE_DIR = './datasets/training/training/images'
TRAIN_MASK_DIR = './datasets/training/training/true_mask'

VALID_IMAGE_DIR = './datasets/training/validation/images'
VALID_MASK_DIR = './datasets/training/validation/true_mask'

BATCH_SIZE = 2 
IMAGE_HEIGHT = 512 
IMAGE_WIDTH = 512

training_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Rotate(limit=30, p=0.1),
    A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std = [1.0, 1.0, 1.0],
        max_pixel_value=255
    ),
    ToTensorV2()
])

validation_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255
    ),
    ToTensorV2()
])

In [162]:

train_x = sorted(glob(os.path.join(TRAIN_IMAGE_DIR, '*').replace("\\", "/")))
train_y = sorted(glob(os.path.join(TRAIN_MASK_DIR, '*').replace("\\", "/")))

valid_x = sorted(glob(os.path.join(VALID_IMAGE_DIR, '*').replace("\\", "/")))
valid_y = sorted(glob(os.path.join(VALID_MASK_DIR, '*').replace("\\", "/")))

print(train_x)
print(train_y)

['./datasets/training/training/images\\100_training.png', './datasets/training/training/images\\21_training.png', './datasets/training/training/images\\22_training.png', './datasets/training/training/images\\23_training.png', './datasets/training/training/images\\24_training.png', './datasets/training/training/images\\25_training.png', './datasets/training/training/images\\26_training.png', './datasets/training/training/images\\27_training.png', './datasets/training/training/images\\28_training.png', './datasets/training/training/images\\29_training.png', './datasets/training/training/images\\30_training.png', './datasets/training/training/images\\31_training.png', './datasets/training/training/images\\32_training.png', './datasets/training/training/images\\33_training.png', './datasets/training/training/images\\34_training.png', './datasets/training/training/images\\35_training.png', './datasets/training/training/images\\36_training.png', './datasets/training/training/images\\37_train

In [165]:
train_dataset = DRIVEDataset(train_x, train_y)
valid_dataset = DRIVEDataset(valid_x, valid_y)

item = train_dataset.__getitem__(1)
plt.imshow(item[0].permute(1,2,0))
plt.figure()
plt.imshow(item[1].permute(1,2,0))

<class 'NoneType'>


TypeError: unsupported operand type(s) for /: 'NoneType' and 'float'

<h3>A little trolling with double conv??</h3>

In [102]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv(x)
        return x

<h3>A little trolling with network definition?</h3>

In [103]:

class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNET, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # process Down-UNET:
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
            
        #process Up-UNET:
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature * 2, feature, kernel_size=2, stride=2 
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))
        self.bottom = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels=out_channels, kernel_size=1)
    
    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        
        x = self.bottom(x)
        skip_connections = skip_connections[::-1]
        
        
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            
            
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_res = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_res)
        
        x = self.final_conv(x)
        return x

<h3>A little pre-network testing??</h3>

In [104]:
def test():
    x = torch.randn((3, 1, 161, 161))
    model = UNET(in_channels=1, out_channels=1)
    predictions = model(x)
    
    print(x.shape)
    print(predictions.shape)
    assert predictions.shape == x.shape
    
test()

torch.Size([3, 1, 161, 161])
torch.Size([3, 1, 161, 161])


<h2> Let's do some TRAINING????</h2>

In [105]:
EPOCHS = 5
LEARNING_RATE = 0.001
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPT_PATH = './checkpoints/checkpoint.pth'
print(f"Using: {DEVICE}")

model = UNET().to(DEVICE)

Using: cuda


In [106]:
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

In [107]:
def train(device, model, trainloader, optimizer, criterion):
    
    model.train()     
    print("hey")
    for batch, (image, mask) in enumerate(trainloader):
        image, mask = image.to(device), mask.to(device)
        
        output = model(image)
        loss = criterion(output, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


In [108]:
for epoch in range(EPOCHS):
    train(device=DEVICE, model=model, trainloader=trainloader, optimizer=optimizer, criterion=criterion)
    print(f"Epoch: {epoch}") 

hey


RuntimeError: CUDA out of memory. Tried to allocate 202.00 MiB (GPU 0; 4.00 GiB total capacity; 3.43 GiB already allocated; 0 bytes free; 3.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF