In [23]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

In [24]:
class DoubleConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            # 2 conv functions in:1 out:64, in:64, out:64
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

In [25]:
# feature: feature map of convolutions
class UNET(nn.Module):
    def __init__(self, in_channel=3, out_channel=1, features=[64,128,256,512]):
        super(UNET, self).__init__()
        self.ups=nn.ModuleList()
        self.downs=nn.ModuleList()
        self.pool=nn.MaxPool2d(kernel_size=2,stride=2)
        self.bottleneck=DoubleConv(features[-1], features[-1]*2)
        self.finalconv=nn.Conv2d(features[0], out_channel, kernel_size=1)
        
        #ModuleList = list of modules 
        #down part for UNet
        for feature in features:
            self.downs.append(DoubleConv(in_channel, feature))
            in_channel=feature
        
        for feature in features[::-1]:
            # ConvTranspose2d has in_chan=feat*2 and out_chan=feat 
            # as each layer of ups is (skip connections) concatenated with corresponding 
            # layer of downs, which results in out_chan=feat+feat
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature*2, feature))
            
    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x=down(x)
            skip_connections.append(x)
            x=self.pool(x)

        x=self.bottleneck(x)
        skip_connections=skip_connections[::-1]
        
        # steps of 2: concat b/w each UpConv and DoubleConv
        for i in range(0, len(self.ups), 2):
            x=self.ups[i](x)
            skip_connection=skip_connections[i//2]
            
            if x.shape!= skip_connection.shape:
                x=TF.resize(x, size=skip_connection.shape[2:])
                #index= 0,1 contains batch_size and no. of channels
            
            #dim=1: concat along channels
            concat_skip=torch.cat((skip_connection,x), dim=1)
            s=self.ups[i+1](concat_skip)
        return self.finalconv(x)

In [26]:
def test():
    x = torch.randn((3, 1, 161, 161))
    model = UNET(in_channel=1, out_channel=1)
    preds = model(x)
    assert preds.shape == x.shape
    return x

In [27]:
# test()

In [28]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

In [29]:
class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir=image_dir
        self.mask_dir=mask_dir
        self.transform=transform
        self.images=os.listdir(image_dir)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        img_path=os.path.join(self.image_dir, self.images[index])
        mask_path=os.path.join(self.mask_dir, self.images[index].replace('.jpg', '_mask.gif'))
        image=np.array(Image.open(img_path).convert('RGB'))
        mask=np.array(Image.open(mask_path).convert('L'), dtype=np.float32)
        #0.0: black 255.0: white section of mask
        mask[mask==255.0]=1.0
        #sigmoid will be used therefore probability measure is appropriate
        
        if self.transform is not None:
            augmentations=self.transform(image=image, mask=mask)
            image=augmentations['image']
            mask=augmentations['mask']
        
        return image, mask

In [30]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader

In [35]:
def save_checkpoint(state, filename = 'checkpoint.pth.tar'):
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    model.load_state_dict(checkpoint['state_dict'])
    
def get_loaders(train_dir, train_maskdir, val_dir, val_maskdir, batch_size, train_transform, val_transform, num_workers=4, pin_memory=True):
    train_ds = CarvanaDataset(image_dir=train_dir,mask_dir=train_maskdir,
                              transform=train_transform)
    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              num_workers=num_workers, pin_memory=pin_memory,
                              shuffle=True)
    val_ds = CarvanaDataset(image_dir=val_dir, mask_dir=val_maskdir,
                            transform=val_transform)
    val_loader = DataLoader(val_ds, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=pin_memory,
                            shuffle=False)
    return train_loader, val_loader

def check_accuracy(loader, model, device='cuda'):
    num_correct=0
    num_pixels=0
    dice_score=0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x=x.to(device)
            y=y.to(device).unsqueeze(1)
            preds=torch.sigmoid(model(x))
            preds=(preds>0.5).float()
            num_correct+=(preds==y).sum()
            num_pixels+=torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            print(f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
            print(f"Dice score: {dice_score/len(loader)}")
            model.train()

def save_predictions_as_imgs(loader, model, folder = 'saved_images/', device='cuda'):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x=x.to(device)
        with torch.no_grad():
            preds=torch.sigmoid(model(x))
            preds=(preds>0.5).float()
        torchvision.utils.save_image(preds, f'{folder}/pred_{idx}.png')
        torchvision.utils.save_image(y.unsqueeze(1), f'{folder}/{idx}.png')
        
    model.train()
    
    

In [36]:
LEARNING_RATE=1E-4
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE=32
NUM_EPOCHS=10
NUM_WORKERS=2
IMAGE_HEIGHT=160
IMAGE_WIDTH=240
PIN_MEMORY=True
LOAD_MODEL=False
TRAIN_IMG_DIR='data/train'
TRAIN_MASK_DIR='data/train_masks'
VAL_IMG_DIR='data/val'
VAL_MASK_DIR='data/val_masks'

In [37]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())


def main():
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channel=3, out_channel=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("checkpoint.pth.tar"), model)


    check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE
        )


In [None]:
main()

Got 261870/1228800 with acc 21.31
Dice score: 0.010979542508721352
Got 509687/2457600 with acc 20.74
Dice score: 0.02145787701010704
Got 759501/3686400 with acc 20.60
Dice score: 0.03200681507587433
Got 1022730/4915200 with acc 20.81
Dice score: 0.04302257299423218
Got 1270925/6144000 with acc 20.69
Dice score: 0.05351422727108002
Got 1527967/7372800 with acc 20.72
Dice score: 0.06431496888399124
Got 1799741/8601600 with acc 20.92
Dice score: 0.07562610507011414
Got 2062083/9830400 with acc 20.98
Dice score: 0.0866113007068634
Got 2328159/11059200 with acc 21.05
Dice score: 0.09772699326276779
Got 2577176/12288000 with acc 20.97
Dice score: 0.10824694484472275
Got 2835942/13516800 with acc 20.98
Dice score: 0.11910851299762726
Got 3087300/14745600 with acc 20.94
Dice score: 0.12971128523349762
Got 3347018/15974400 with acc 20.95
Dice score: 0.14060720801353455
Got 3607905/17203200 with acc 20.97
Dice score: 0.15154187381267548
Got 3883916/18432000 with acc 21.07
Dice score: 0.162996307

