In [9]:
import torch
import torch.nn as nn
from dataset import CarvanaDataset
import torchvision.transforms.functional as TF
import torchvision

In [10]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1 ,1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(out_channels, out_channels, 3, 1 ,1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        return self.conv(x)
    
class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
            
        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                     feature*2, feature, kernel_size=2, stride=2
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))
            
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
        
    def forward(self, x):
        skip_connections = []
        for down in self.downs:
                x = down(x)
                skip_connections.append(x)
                x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]
            
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
                
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)
        
        return self.final_conv(x)
    

In [11]:
def test():
    x = torch.randn((3, 1, 161, 161))
    model = UNET(in_channels=1, out_channels=1)
    preds = model(x)
    print(preds.shape)
    print(x.shape)
    # assert(preds.shape == x.shape)
test()


torch.Size([3, 1, 161, 161])
torch.Size([3, 1, 161, 161])


## Training part

In [12]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader

In [13]:
def load_checkpoint(checkpoint, model3):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)
    
def get_loaders(
    train_dir,
    train_mask_dir,
    val_dir,
    val_mask_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True
):
    train_ds = CarvanaDataset(
        image_dir=train_dir, 
        mask_dir=train_mask_dir, 
        transform=train_transform
    )
    
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True
    )
    
    val_ds = CarvanaDataset(
        image_dir=val_dir, 
        mask_dir=val_mask_dir, 
        transform=val_transform
    )
    
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False
    )
    
    return train_loader, val_loader
    
def check_accuracy(loader, model, device="mps"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0 # Better evaluation
    model.eval()
    
    with torch.no_grad():
        
        for x,y in loader:
            
            
            x = x.to(device)
            
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )
    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(
        f"dice_score {dice_score/len(loader)}"
    )
    
    model.train()
    
def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="mps"
):
    model.eval()
    for idx,(x,y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        
        torchvision.utils.save_image(preds, f"{folder}pred_{idx}.png")
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}correct{idx}.png")
    model.train()

In [14]:
# Hyperparameters

LEARNING_RATE=1e-4
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 2

IMAGE_HEIGHT = 160 # 1280 originally
IMAGE_WIDTH = 240 # 1918 originally

PIN_MEMORY = True
LOAD_MODEL = False

TRAIN_IMG_DIR = "data/train_images/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val_images/"
VAL_MASK_DIR = "data/val_masks/"

In [15]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    
    for batch_idx, (data, targets) in enumerate(loop):
        
        print(len(data))
        print(len(data[0]), len(data[1]))
        # data = data.to(device=DEVICE)
        # targets = targets.float().unsqueeze(1).to(device=DEVICE)
        
        # # forward
        # predictions = model(data)
        # loss = loss_fn(predictions, targets)
            
        # # backward
        # optimizer.zero_grad()
        # scaler.scale(loss).backward()
        # scaler.step(optimizer)
        # scaler.update()

    # update tqdm loop
    # loop.set_postfix(loss=loss.item())
    
def main():
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # To Tensor doesn't divide by 255 like PyTorch
        # it's done inside Normalize function
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0
        ),
        ToTensorV2(),
    ])
    val_transforms = A.Compose([
        A.Resize(height=IMAGE_HEIGHT,width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0
        ),
        ToTensorV2(),
    ])
    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )
    
    # if LOAD_MODEL:
    #     load_checkpoint(torch.load("my_mps_checkpoint.pth.tar"), model)
    #     check_accuracy(val_loader, model, device=DEVICE)
    
    scaler = torch.cuda.amp.GradScaler()
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    # for epoch in range(NUM_EPOCHS):
    #     train_fn(train_loader, model, optimizer, loss_fn, scaler)
        
    #     # save model
    #     checkpoint = {
    #         "state_dict": model.state_dict(),
    #         "optimizer": optimizer.state_dict(),
    #     }
        
    #     save_checkpoint(checkpoint)
    #     # check acccuracy
    #     check_accuracy(val_loader, model, device=DEVICE)
        
    #     # print some example to a folder
    #     save_predictions_as_imgs(
    #         val_loader, model, folder="saved_images_mps/", device=DEVICE
    #     )
        
    # For multu class
    # model = UNET(in_channels=3, out_channels=3).to(DEVICE)
    # loss_fn = nn. cross entropy loss


In [16]:
main()

  0%|          | 1/315 [00:02<14:02,  2.68s/it]

16
3 3
16
3 3


  1%|          | 3/315 [00:03<04:43,  1.10it/s]

16
3 3
16
3 3


  2%|▏         | 5/315 [00:03<03:03,  1.69it/s]

16
3 3
16
3 3


  2%|▏         | 7/315 [00:04<02:21,  2.18it/s]

16
3 3
16
3 3


  3%|▎         | 9/315 [00:04<02:00,  2.53it/s]

16
3 3
16
3 3


  3%|▎         | 11/315 [00:05<01:48,  2.80it/s]

16
3 3
16
3 3


  4%|▍         | 13/315 [00:06<01:39,  3.02it/s]

16
3 3
16
3 3


  5%|▍         | 15/315 [00:06<01:35,  3.13it/s]

16
3 3
16
3 3


  5%|▌         | 17/315 [00:07<01:32,  3.23it/s]

16
3 3
16
3 3


  6%|▌         | 19/315 [00:07<01:29,  3.32it/s]

16
3 3
16
3 3


  7%|▋         | 21/315 [00:08<01:27,  3.37it/s]

16
3 3
16
3 3


  7%|▋         | 23/315 [00:08<01:24,  3.46it/s]

16
3 3
16
3 3


  8%|▊         | 26/315 [00:09<01:46,  2.72it/s]

16
3 3
16
3 3





KeyboardInterrupt: 