### Model U-net architecture

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

In [2]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3,1,1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3,1,1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

In [3]:
class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64,128,256,512],):
        super(UNET,self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    
        # Down Part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature #pass to next level
    
        # Up Part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2,feature))
    
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0],out_channels, kernel_size=1)
    
    def forward(self, x):
        skip_connections = []
    
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
    
        x = self.bottleneck(x)
    
        skip_connections = skip_connections[::-1] #reverse the connection
            
        for idx in range(0,len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if(x.shape != skip_connection.shape):
                x = TF.resize(x, size=skip_connection.shape[2:])
                
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)
                
        return self.final_conv(x)

In [4]:
def test():
    x = torch.randn((3,1,160,160))
    model = UNET(in_channels=1, out_channels=1)
    preds = model(x)
    print(preds.shape)
    print(x.shape)
    assert preds.shape == x.shape

In [5]:
if __name__ == "__main__":
    test()

torch.Size([3, 1, 160, 160])
torch.Size([3, 1, 160, 160])


### Dataset 

In [6]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

In [7]:
class Carvana(Dataset):
    def __init__(self,image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self,index):
        img_path = os.path.join(self.image_dir,self.images[index])
        mask_path = os.path.join(self.mask_dir,self.images[index].replace(".jpg", "_mask.gif"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"),dtype = np.float32)
        mask[mask == 255.0] = 1.0 #as we need sigmoid on last layer

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        image = transforms.ToTensor()(image)
        mask = torch.tensor(mask, dtype=torch.float32)
            
        return image,mask

### Training

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

In [9]:
LEARNING_RATE = 1e-4
BATCH_SIZE = 1
NUM_EPOCHS = 20
shuffle=True
pin_memory=True
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 0
IMAGE_DIR = "data/train"
MASK_DIR = "data/train_masks"

In [None]:
transform = transforms.Compose([
    transforms.Resize((160, 160)),
])

In [10]:
def dice_loss(pred, target):
    smooth = 1.0
    pred = torch.sigmoid(pred)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    return 1 - (2.0 * intersection + smooth) / (union + smooth)

train_dataset = Carvana(
    image_dir=IMAGE_DIR,
    mask_dir=MASK_DIR,
    transform=None  # Add augmentation logic if needed
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    pin_memory=True
)

In [11]:
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

In [16]:
def train():
    model.train()
    loop = tqdm(train_loader, leave=True)
    total_loss = 0

    for batch_idx, (images, masks) in enumerate(loop):
        print(f"Batch {batch_idx + 1}/{len(train_loader)}")  # Log batch progress
        images = images.to(DEVICE, dtype=torch.float32)
        masks = masks.to(DEVICE, dtype=torch.float32)

        # Forward pass
        predictions = model(images)
        loss = criterion(predictions, masks) + dice_loss(predictions, masks)
        total_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log GPU memory
        print_gpu_memory()

        # Clear batch data
        del images, masks, predictions, loss
        torch.cuda.empty_cache()

        # Update progress bar
        loop.set_postfix(loss=loss.item())

    print(f"Epoch Loss: {total_loss / len(train_loader)}")

In [17]:
for epoch in range(NUM_EPOCHS):
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}]")
    train()

Epoch [1/20]


  0%|                                                                                         | 0/5088 [00:00<?, ?it/s]

Batch 1/5088


  0%|                                                                                         | 0/5088 [00:08<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 76.00 MiB. GPU 0 has a total capacity of 6.00 GiB of which 0 bytes is free. Of the allocated memory 12.21 GiB is allocated by PyTorch, and 166.66 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)