In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as tf

### Construct U-net model

In [2]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels:int, out_channels:int):
        super().__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,stride=1,padding="same",bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,stride=1,padding="same",bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)

        )

    def forward(self,x):
        return self.conv(x)
    

class UNET(nn.Module):
    def __init__(self, in_channels:int=3,out_channels:int=1, features:list=[64,128,256,512]):
        super().__init__()
        self.downs=nn.ModuleList()
        self.ups=nn.ModuleList()
        self.pool=nn.MaxPool2d(kernel_size=2,stride=2)

        # down Unet
        for feature in features:
            self.downs.append(DoubleConv(in_channels=in_channels,out_channels=feature))
            in_channels=feature

        #up 
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(in_channels=feature*2, out_channels=feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(in_channels=feature*2, out_channels=feature))
        
        #botton layer latent 
        self.bottomnet=DoubleConv(in_channels=features[-1],out_channels=features[-1]*2)

        self.final_conv= nn.Conv2d(in_channels=features[0],out_channels=1,kernel_size=1)


    def forward(self,x):
        skip_connections=[]
        for down in self.downs:

            x=down(x)
            skip_connections.append(x)
            x=self.pool(x)

        x=self.bottomnet(x)
        skip_connections=skip_connections[::-1]

        for index in range(0,len(self.ups),2):

            #up conv
            x=self.ups[index](x)
        
            skip_connection = skip_connections[index//2]
            if(x.shape != skip_connection.shape): # max pooling with underflow the original size -> cause the new upsampling size smaller than origin size 
                x= tf.resize(x,size=skip_connection.shape[2:])
            x=torch.cat((skip_connection, x),dim=1)

            #pass to double conv
            x=self.ups[index+1](x)

 
        return self.final_conv(x)


In [3]:
# test
x=torch.rand((3,1,160,160))
model = UNET(in_channels=1,out_channels=1)
pred  = model(x)
print(pred.shape)
print(x.shape)
assert pred.shape ==x.shape

torch.Size([3, 1, 160, 160])
torch.Size([3, 1, 160, 160])


## Load [Dataset](https://www.kaggle.com/c/carvana-image-masking-challenge/data?select=train_masks.zip)

In [4]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask # x ,y 



## training part

In [5]:
import torch
import albumentations as A 
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision


LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(DEVICE)

BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 0
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally
PIN_MEMORY = True
LOAD_MODEL = True
TRAIN_IMG_DIR = "dataset/train/"
TRAIN_MASK_DIR = "dataset/train_masks/"
VAL_IMG_DIR = "dataset/valid/"
VAL_MASK_DIR = "dataset/valid_masks/"


def train(loader,model, optimizer, loss_func,scaler):
    loop = tqdm(loader)
    print("train")
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_func(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

## utils
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

def get_loaders(train_dir, train_maskdir, val_dir, val_maskdir, batch_size, train_transform, val_transform, num_workers=1, pin_memory=True,):
    train_ds = CarvanaDataset(image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform )

    train_loader = DataLoader(train_ds, batch_size=batch_size,num_workers=num_workers, pin_memory=pin_memory, shuffle=True)

    val_ds = CarvanaDataset(image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform,)

    val_loader = DataLoader( val_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False)

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval() # disable Batch Normalization, Dropout 
    print(loader)
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    model.train() # enable Batch Normalization, Dropout

def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()

        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )

        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

cuda


In [6]:
train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
)

val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
)

model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
)

if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)


check_accuracy(val_loader, model, device=DEVICE)
scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    train(train_loader, model, optimizer, loss_fn, scaler)

        # save model
    checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

        # check accuracy
    check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
    save_predictions_as_imgs(
        val_loader, model, folder="saved_images/", device=DEVICE
    )


<torch.utils.data.dataloader.DataLoader object at 0x000002499603C850>
Got 20390117/25804800 with acc 79.02
Dice score: 0.0


  0%|          | 0/276 [00:00<?, ?it/s]

train


100%|██████████| 276/276 [05:22<00:00,  1.17s/it, loss=0.136]


=> Saving checkpoint
<torch.utils.data.dataloader.DataLoader object at 0x000002499603C850>
Got 25535006/25804800 with acc 98.95
Dice score: 0.9756826758384705


FileNotFoundError: [Errno 2] No such file or directory: 'saved_images//pred_0.png'