In [6]:
# Common imports
import os
import numpy as np
import pandas as pd
from glob import glob
import torch
import torchvision.transforms as transforms
from PIL import Image

# Data
from torchvision.transforms.functional import resize, to_tensor

# Data Viz
import matplotlib.pyplot as plt

# Model
import torch.nn as nn
import torch.nn.functional as F

# Grad-CAM (pytorch-grad-cam) — TODO: add the import when it is used

# Metrics
from torchmetrics import JaccardIndex


SyntaxError: invalid syntax (1166817886.py, line 21)

In [None]:

def load_image(image, SIZE, interpolation=None):
    """Load one image file as a resized, channel-first float tensor.

    Args:
        image: Path (or file object) accepted by ``PIL.Image.open``.
        SIZE: Target height and width; the image is resized to ``(SIZE, SIZE)``
            without preserving aspect ratio.
        interpolation: Optional ``torchvision.transforms.InterpolationMode`` for
            ``resize``. ``None`` keeps torchvision's default (bilinear), which is
            fine for photos but blends label values when resizing masks.

    Returns:
        Tensor of shape ``(C, SIZE, SIZE)`` produced by ``to_tensor`` and rounded
        to 4 decimals. ``C`` depends on the file's PIL mode (e.g. 1, 3 or 4).
    """
    # Context manager closes the file handle; all pixel work happens inside it.
    with Image.open(image) as img:
        if interpolation is None:
            resized = resize(img, (SIZE, SIZE))
        else:
            resized = resize(img, (SIZE, SIZE), interpolation=interpolation)
        img_array = to_tensor(resized)
    return torch.round(img_array, decimals=4)

def load_images(image_paths, SIZE, mask=False, trim=None):
    """Load several image files into one stacked, channel-first tensor.

    Args:
        image_paths: Iterable of paths (e.g. the list returned by ``glob``).
        SIZE: Target height/width forwarded to ``load_image``.
        mask: If True, treat files as segmentation masks. They are resized with
            nearest-neighbour interpolation, so no blended label values appear at
            region boundaries. Only the first channel is kept, giving shape
            ``(N, 1, SIZE, SIZE)``. Otherwise the result is ``(N, 3, SIZE, SIZE)``.
        trim: Optional slice bound; only ``image_paths[:trim]`` is loaded.

    Returns:
        Float tensor holding the loaded images.
    """
    image_paths = list(image_paths)  # allow generators / other iterables
    if trim is not None:
        image_paths = image_paths[:trim]

    channels = 1 if mask else 3
    images = torch.zeros((len(image_paths), channels, SIZE, SIZE))  # Channel-first for PyTorch
    interpolation = transforms.InterpolationMode.NEAREST if mask else None

    for i, path in enumerate(image_paths):
        img = load_image(path, SIZE, interpolation=interpolation)
        # Masks: keep the first channel. Images: drop an alpha channel (RGBA) if
        # present; a single-channel image broadcasts across the 3 slots.
        images[i] = img[:channels]

    return images

# ... (rest of the `show_image`, `show_mask` functions remain largely the same)

# ... (Loading images and dataset preparation is similar)

class EncoderBlock(nn.Module):
    """Two 3x3 conv + ReLU layers, channel dropout, then optional 2x2 max-pooling.

    ``forward`` always returns a tuple ``(pooled, features)``. ``features`` is the
    pre-pooling activation map, used as a U-Net skip connection. ``pooled`` is
    ``features`` after ``MaxPool2d(2)``; with ``pooling=False`` (``nn.Identity``)
    it is the same tensor.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        rate: ``Dropout2d`` probability.
        pooling: Whether to downsample the returned ``pooled`` tensor.
    """

    def __init__(self, in_channels, out_channels, rate, pooling=True):
        super(EncoderBlock, self).__init__()

        # padding=1 keeps the spatial size unchanged for the 3x3 convs.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.dropout = nn.Dropout2d(rate)
        self.pool = nn.MaxPool2d(2) if pooling else nn.Identity()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        # The second conv needs its own non-linearity. Without it, the block output
        # (and every skip connection) is a linear map of the first activation.
        x = self.relu(self.conv2(x))
        x = self.dropout(x)
        pooled_x = self.pool(x)
        return pooled_x, x  # Return both pooled and unpooled features

class DecoderBlock(nn.Module):
    """Upsample by 2, concatenate a skip connection, then apply an ``EncoderBlock``.

    Args:
        in_channels: Channels of the decoder input ``x``.
        out_channels: Channels of the returned tensor.
        rate: Dropout rate forwarded to the inner ``EncoderBlock``.
        skip_channels: Channels of ``skip_x``. Defaults to ``in_channels``. Pass it
            explicitly when the skip tensor is narrower or wider than ``x``, e.g. a
            512-channel bottleneck joined with a 256-channel encoder feature map.

    ``forward(x, skip_x)`` returns a single tensor, so decoder blocks can be chained.
    It assumes ``skip_x`` has twice the spatial size of ``x``, since the two are
    concatenated after the x2 upsample.
    """

    def __init__(self, in_channels, out_channels, rate, skip_channels=None):
        super(DecoderBlock, self).__init__()
        if skip_channels is None:
            skip_channels = in_channels

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Concatenation along dim=1 yields in_channels + skip_channels channels.
        self.conv = EncoderBlock(in_channels + skip_channels, out_channels, rate, pooling=False)

    def forward(self, x, skip_x):
        x = self.up(x)
        x = torch.cat([x, skip_x], dim=1)  # Concatenate along the channel dimension
        # EncoderBlock.forward returns (pooled, features). With pooling=False both
        # are the same tensor, so keep just one instead of leaking the tuple.
        _, x = self.conv(x)
        return x

class AttentionGate(nn.Module):
    """Additive attention gate that re-weights a skip connection.

    How it works:
    - The gating signal ``x`` is projected by a 3x3 conv at its own resolution.
    - ``skip_x`` is projected by a stride-2 3x3 conv, so the two sum element-wise.
    - The sum goes through ReLU, a 1x1 conv to one channel, and a sigmoid.
    - The resulting map is upsampled x2 and multiplies ``skip_x``, broadcast over
      channels.
    - BatchNorm is optionally applied last.

    For the shapes to line up, ``x`` must have half the spatial size of ``skip_x``.

    Args:
        in_channels: Channels of ``skip_x`` (and of the gated output).
        bn: Apply ``BatchNorm2d`` to the gated output.
        gating_channels: Channels of the gating input ``x``. Defaults to
            ``in_channels``. Set it when the gating signal has a different width
            than the skip connection, e.g. a deeper, wider feature map.
    """

    def __init__(self, in_channels, bn=True, gating_channels=None):
        super(AttentionGate, self).__init__()
        if gating_channels is None:
            gating_channels = in_channels

        # Project the gating signal to in_channels at its own resolution.
        self.conv1 = nn.Conv2d(gating_channels, in_channels, kernel_size=3, padding=1)
        # Stride 2 brings skip_x down to the gating signal's resolution.
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.bn = nn.BatchNorm2d(in_channels) if bn else nn.Identity()

    def forward(self, x, skip_x):
        g = self.conv1(x)
        s = self.conv2(skip_x)
        alpha = self.sigmoid(self.conv3(F.relu(g + s)))  # (N, 1, h, w) weights in (0, 1)
        alpha = self.up(alpha)
        return self.bn(alpha * skip_x)

# ... (Rest of the model definition using these PyTorch blocks)

# Loss and optimizer
# BCELoss expects probabilities in [0, 1]. This assumes the model's last layer
# applies a sigmoid — TODO confirm (otherwise use nn.BCEWithLogitsLoss).
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

# Metrics
# torchmetrics >= 0.11 requires an explicit `task`; `num_classes=2` alone raises.
# The binary task thresholds predictions at 0.5 and expects integer {0, 1} targets.
iou_metric = JaccardIndex(task="binary").to(device)  # move metric to the model's device

# Training loop
for epoch in range(num_epochs):
    model.train()
    iou_metric.reset()  # report IoU per epoch, not accumulated across epochs
    running_loss = 0.0
    n_batches = 0

    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches = i + 1

        # Calculate metrics: detach from the autograd graph, and binarise the
        # float masks because binary IoU rejects floating-point targets.
        iou = iou_metric(outputs.detach(), (masks > 0.5).int())

    if n_batches:
        print(
            f"epoch {epoch + 1}/{num_epochs}  "
            f"loss={running_loss / n_batches:.4f}  "
            f"IoU={iou_metric.compute().item():.4f}"
        )

# ... (Rest of the training and evaluation code)