In [12]:
# imports
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

In [14]:
# Define model class

class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv -> BatchNorm -> ReLU stages.

    With kernel 3, stride 1 and padding 1 the spatial size is unchanged; only the
    channel count goes from ``in_c`` to ``out_c``. The convolutions have no bias
    because each is immediately followed by BatchNorm.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        # First stage maps in_c -> out_c, second keeps out_c -> out_c.
        for stage_in in (in_c, out_c):
            stages.extend(
                [
                    nn.Conv2d(stage_in, out_c, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(out_c),
                    nn.ReLU(inplace=True),
                ]
            )
        # Single Sequential named `conv` so parameter names are conv.0 ... conv.5.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.conv(x)
    
class Unet(nn.Module):
    """U-Net encoder/decoder with skip connections.

    The encoder applies one DoubleConv per entry of ``features``, each followed
    by 2x2 max-pooling. The decoder mirrors it: a 2x2 transposed convolution
    upsamples, then a DoubleConv fuses the result with the matching encoder
    output (concatenated on the channel dim). A final 1x1 convolution maps to
    ``out_c`` channels with no activation (raw logits).

    Args:
        in_c: channels of the input tensor.
        out_c: channels of the output tensor.
        features: encoder channel widths, shallowest first. Any non-empty
            sequence is accepted; the bottleneck width is ``2 * features[-1]``.
            The default is a tuple rather than a list so that no mutable
            object is shared between instances.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(
        self,
        in_c=3,
        out_c=1,
        features=(64, 128, 256, 512),
    ):
        super(Unet, self).__init__()
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Attribute registration and module creation order are kept as-is so
        # parameter names and random-init order match existing checkpoints.
        self.up = nn.ModuleList()
        self.down = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        for feature in features:
            self.down.append(DoubleConv(in_c, feature))
            in_c = feature

        # self.up alternates [upsample, fuse, upsample, fuse, ...], deepest first.
        for feature in reversed(features):
            self.up.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Fuse input = skip (feature ch) concatenated with upsampled (feature ch).
            self.up.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_c, kernel_size=1)

    def forward(self, x):
        """Map ``x`` of shape (N, in_c, H, W) to logits of shape (N, out_c, H, W).

        When a pooling step floors an odd spatial size, the upsampled tensor is
        resized to its skip connection's size before concatenation.
        """
        skip_connections = []
        for down in self.down:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.up), 2):
            x = self.up[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Batch and channel dims always agree here; only H/W can differ.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.up[idx + 1](concat_skip)

        return self.final_conv(x)


In [16]:
# Define Dataset class
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class RoadDataset(Dataset):
    """Paired image/mask dataset read from two directories.

    Every file name found in ``image_dir`` is looked up under the same name in
    ``mask_dir``. Mask pixels equal to 255 are mapped to 1.0.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the masks (same file names as the images).
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` keys
            (presumably an albumentations pipeline — confirm against the caller).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for the ``index``-th file, transformed if configured."""
        file_name = self.images[index]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)
        image = np.array(Image.open(img_path))
        # float32 so the 0/1 target values are stored exactly as floats.
        mask = np.array(Image.open(mask_path), dtype=np.float32)
        mask[mask == 255.0] = 1.0
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            # Previously assigned to a typo'd name, discarding the transformed mask.
            mask = augmentations["mask"]
        return image, mask

In [None]:
# Helper functions
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state, filename):
    """Announce and then write ``state`` to ``filename`` via ``torch.save``.

    Args:
        state: any object ``torch.save`` can serialize (e.g. a dict of state_dicts).
        filename: destination path or file-like object.
    """
    banner = f"===> Saving checkpoint {filename}"
    print(banner)
    torch.save(obj=state, f=filename)

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_mask_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build training and validation DataLoaders over ``RoadDataset``.

    Args:
        train_dir, train_maskdir: image / mask directories for training.
        val_dir, val_mask_dir: image / mask directories for validation.
        batch_size: batch size used by both loaders.
        train_transform, val_transform: passed as ``transform`` to the datasets.
        num_workers: DataLoader worker processes.
        pin_memory: forwarded to DataLoader.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """

    def _make_loader(image_dir, mask_dir, transform, shuffle):
        # Shared construction so both loaders stay configured identically.
        dataset = RoadDataset(
            image_dir=image_dir,
            mask_dir=mask_dir,
            transform=transform,
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    train_loader = _make_loader(train_dir, train_maskdir, train_transform, shuffle=True)
    # Uses the declared parameter `val_mask_dir` (was the undefined `val_maskdir`).
    val_loader = _make_loader(val_dir, val_mask_dir, val_transform, shuffle=False)
    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    """Print and return pixel accuracy and mean per-batch Dice of a binary model.

    Predictions are ``sigmoid(model(x)) > 0.5``. Targets are unsqueezed at dim 1,
    so this assumes ``y`` is (N, H, W) and ``model(x)`` is (N, 1, H, W) —
    TODO confirm against the dataset.

    Args:
        loader: iterable of ``(x, y)`` batches that also supports ``len()``.
        model: module returning logits.
        device: device the batches are moved to.

    Returns:
        ``(accuracy, dice)`` as floats: accuracy is the fraction of correctly
        classified pixels, dice is the mean of the per-batch Dice scores.

    Raises:
        ValueError: if ``loader`` has no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("check_accuracy: loader has no batches")

    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    # Restore whatever mode the caller had, even if evaluation fails.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = (torch.sigmoid(model(x)) > 0.5).float()
                num_correct += int((preds == y).sum().item())
                num_pixels += preds.numel()
                # `.sum()` must be called; the bare method reference raised TypeError.
                batch_dice = (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
                dice_score += float(batch_dice.item())
    finally:
        model.train(was_training)

    accuracy = num_correct / num_pixels
    mean_dice = dice_score / num_batches
    print(f"Got {num_correct}/{num_pixels} with acc {accuracy * 100:.4f}")
    print(f"Dice score: {mean_dice:.4f}")
    return accuracy, mean_dice
    
def save_preds_as_imgs(
    loader,
    model,
    folder="saved_images/",
    device="cuda"
):
    """Write thresholded predictions and ground-truth masks as PNGs.

    For batch ``idx`` this writes ``pred_{idx}.png`` (``sigmoid(model(x)) > 0.5``)
    and ``{idx}.png`` (the target ``y`` unsqueezed at dim 1) into ``folder``,
    which is created if missing.

    Args:
        loader: iterable of ``(x, y)`` batches.
        model: module returning logits.
        folder: output directory.
        device: device the inputs are moved to.
    """
    os.makedirs(folder, exist_ok=True)
    # Previously the model was left in eval mode; restore the caller's mode.
    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(x)) > 0.5).float()
            torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
            # Cast to float: save_image scales by 255, which misbehaves on integer masks.
            torchvision.utils.save_image(
                y.unsqueeze(1).float(), os.path.join(folder, f"{idx}.png")
            )
    finally:
        model.train(was_training)
    
                            
    
    