In [None]:
!unzip /content/dataset.zip

In [2]:
!pip freeze > requirements.txt

In [3]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import os

from PIL import Image
from torch.utils.data import Dataset
import numpy as np

# --- 1. Configuration ---
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
NUM_EPOCHS = 25
IMAGE_HEIGHT = 352
IMAGE_WIDTH = 352
PIN_MEMORY = True
IMG_DIR = "dataset/img"
MASK_DIR = "dataset/mask"

In [4]:
# --- Configuration ---
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
NUM_EPOCHS = 25
IMAGE_HEIGHT = 352
IMAGE_WIDTH = 352
PIN_MEMORY = True
IMG_DIR = "dataset/img"
MASK_DIR = "dataset/mask"

In [5]:
class SentinelDataset(Dataset):
    def __init__(self, image_dirs, mask_dirs, transform=None):
        self.image_dirs = image_dirs if isinstance(image_dirs, list) else [image_dirs]
        self.mask_dirs = mask_dirs if isinstance(mask_dirs, list) else [mask_dirs]
        self.transform = transform

        self.image_files = []
        for image_dir in self.image_dirs:
            for root, _, files in os.walk(image_dir):
                for file in files:
                    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff')):
                        self.image_files.append(os.path.join(root, file))

        # Assuming mask filenames correspond to image filenames
        self.images = [os.path.basename(f) for f in self.image_files]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_name = self.images[index]

        # Find the full path of the image
        img_path = None
        for full_path in self.image_files:
            if img_name in full_path:
                img_path = full_path
                break

        if img_path is None:
            raise FileNotFoundError(f"Image {img_name} not found in the dataset.")

        mask_path = img_path.replace('img', 'mask')

        if mask_path is None:
            raise FileNotFoundError(f"Mask {img_name} not found in the dataset. Checked directories: {self.mask_dirs}")

        # Load image and mask
        image = np.array(Image.open(img_path))
        # If the image is 2D, add a channel dimension
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)

        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [14]:
class myNetwork(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, bilinear=False):
        super(myNetwork, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Adapt the first convolutional layer for n-channel input
        self.encoder_conv1 = nn.Sequential(
            nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            resnet.bn1,
            resnet.relu
        )

        # --- IMPROVEMENT ---
        # If using 3 channels, copy the pretrained weights
        if n_channels == 3:
            self.encoder_conv1[0].weight.data = resnet.conv1.weight.data.clone()
        # -------------------

        self.encoder_pool1 = resnet.maxpool
        self.encoder_layer1 = resnet.layer1
        self.encoder_layer2 = resnet.layer2
        self.encoder_layer3 = resnet.layer3
        self.encoder_layer4 = resnet.layer4

        # Decoder (rest of the code is unchanged)
        self.up1_upsample = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
        self.up1_conv = self._double_conv(1024 + 1024, 1024)

        self.up2_upsample = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up2_conv = self._double_conv(512 + 512, 512)

        self.up3_upsample = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up3_conv = self._double_conv(256 + 256, 256)

        self.up4_upsample = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up4_conv = self._double_conv(64 + 128, 128) # Note: 64 from x1, 128 from upsample

        self.up_final = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def _double_conv(self, in_channels, out_channels, mid_channels=None):
        if not mid_channels:
            mid_channels = out_channels
        return nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def encoder(self, x):
        x1 = self.encoder_conv1(x)
        x2_p = self.encoder_pool1(x1)
        x2 = self.encoder_layer1(x2_p)
        x3 = self.encoder_layer2(x2)
        x4 = self.encoder_layer3(x3)
        x5 = self.encoder_layer4(x4)
        return x5, x4, x3, x2, x1

    def decoder(self, x5, x4, x3, x2, x1):
        up1_out = self.up1_upsample(x5)
        diffY = x4.size()[2] - up1_out.size()[2]
        diffX = x4.size()[3] - up1_out.size()[3]
        up1_out = F.pad(up1_out, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x4, up1_out], dim=1)
        x = self.up1_conv(x)

        up2_out = self.up2_upsample(x)
        diffY = x3.size()[2] - up2_out.size()[2]
        diffX = x3.size()[3] - up2_out.size()[3]
        up2_out = F.pad(up2_out, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x3, up2_out], dim=1)
        x = self.up2_conv(x)

        up3_out = self.up3_upsample(x)
        diffY = x2.size()[2] - up3_out.size()[2]
        diffX = x2.size()[3] - up3_out.size()[3]
        up3_out = F.pad(up3_out, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, up3_out], dim=1)
        x = self.up3_conv(x)

        up4_out = self.up4_upsample(x)
        diffY = x1.size()[2] - up4_out.size()[2]
        diffX = x1.size()[3] - up4_out.size()[3]
        up4_out = F.pad(up4_out, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x1, up4_out], dim=1)
        x = self.up4_conv(x)

        x = self.up_final(x)
        return self.outc(x)

    def forward(self, x):
        x5, x4, x3, x2, x1 = self.encoder(x)
        return self.decoder(x5, x4, x3, x2, x1)

In [7]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    model.train()
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    iou_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            iou_score += (preds * y).sum() / ((preds + y).sum() - (preds * y).sum() + 1e-8)

    print(f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
    print(f"Dice score: {dice_score/len(loader)}")
    print(f"IoU score: {iou_score/len(loader)}")
    model.train()

In [None]:
train_ds = SentinelDataset(
    image_dirs=IMG_DIR,
    mask_dirs=MASK_DIR,
    transform=None, # We will set this later
)

train_imgs, val_imgs = train_test_split(train_ds.images, test_size=0.2, random_state=42)

train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

model = myNetwork(n_channels=3, n_classes=1).to(DEVICE)

loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_ds = SentinelDataset(
    image_dirs=IMG_DIR,
    mask_dirs=MASK_DIR,
    transform=train_transform,
)
train_ds.images = train_imgs

val_ds = SentinelDataset(
    image_dirs=IMG_DIR,
    mask_dirs=MASK_DIR,
    transform=val_transforms,
)
val_ds.images = val_imgs

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=PIN_MEMORY,
    shuffle=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=PIN_MEMORY,
    shuffle=False,
)

scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    check_accuracy(val_loader, model, device=DEVICE)

torch.save(model.state_dict(), "unet_resnet50.pth")

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
100%|██████████| 451/451 [00:49<00:00,  9.16it/s, loss=0.0689]


Got 104896961/111761408 with acc 93.86
Dice score: 0.9681315422058105
IoU score: 0.9386692643165588


100%|██████████| 451/451 [00:47<00:00,  9.50it/s, loss=0.119]


Got 104896054/111761408 with acc 93.86
Dice score: 0.9681270122528076
IoU score: 0.9386606216430664


 94%|█████████▍| 424/451 [00:44<00:02,  9.57it/s, loss=0.118]