In [7]:
# Extract the dataset archive (presumably it provides the dataset/img and
# dataset/mask folders that the cells below read — TODO confirm).
# NOTE(review): the output recorded for this cell shows unzip rejecting
# /content/dataset.zip ("End-of-central-directory signature not found"), i.e.
# the file is not a complete zip archive. Re-upload / re-download it before
# running the rest of the notebook.
!unzip /content/dataset.zip

Archive:  /content/dataset.zip
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of /content/dataset.zip or
        /content/dataset.zip.zip, and cannot find /content/dataset.zip.ZIP, period.


In [None]:
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset

from sklearn.model_selection import train_test_split
from PIL import Image
import numpy as np

In [None]:
# Hyperparameters and paths shared by the cells below.
LEARNING_RATE = 1e-4  # learning rate handed to the Adam optimizer
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is visible
BATCH_SIZE = 8  # batch size of both the training and the validation loader
NUM_EPOCHS = 25  # number of train + validate rounds
# Every sample is resized to IMAGE_HEIGHT x IMAGE_WIDTH by the transforms.
# NOTE(review): 352 = 11 * 32; presumably picked as a multiple of 32 so the
# encoder/decoder feature maps line up without padding — TODO confirm.
IMAGE_HEIGHT = 352
IMAGE_WIDTH = 352
PIN_MEMORY = True  # forwarded to both DataLoaders as pin_memory
# Image and mask folders, given relative to the working directory.
IMG_DIR = "dataset/img"
MASK_DIR = "dataset/mask"

In [None]:
class SentinelDataset(Dataset):
    """Paired image / mask dataset read from two directories.

    Each regular file found in ``image_dir`` is one sample; its mask is looked
    up under ``mask_dir`` using the same file name.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks, keyed by image file name.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` and returning a mapping with
            ``"image"`` and ``"mask"`` entries (e.g. an albumentations pipeline).

    Attributes:
        images: List of sample file names. It may be reassigned after
            construction to restrict the dataset to a subset of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir() yields entries in arbitrary, filesystem-dependent order,
        # so sort for reproducible indices. Non-file entries (for example a
        # ".ipynb_checkpoints" folder) are skipped because they cannot be
        # opened as images.
        self.images = sorted(
            name
            for name in os.listdir(self.image_dir)
            if os.path.isfile(os.path.join(self.image_dir, name))
        )

    def __len__(self):
        """Return the number of samples (file names) in the dataset."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one ``(image, mask)`` pair.

        Without a transform the image is an RGB array and the mask a float32
        single-channel array in which the value 255 has been replaced by 1.
        With a transform, whatever it returns under ``"image"`` / ``"mask"``
        is passed through unchanged.
        """
        file_name = self.images[index]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)

        # Context managers close the underlying files deterministically; the
        # arrays are fully materialised before the files are released.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        # Map the "foreground" grey level 255 to class label 1.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [None]:
class myNetwork(nn.Module):
    """U-Net style segmentation network on top of a pretrained ResNet-50.

    The encoder reuses the torchvision ResNet-50 stages (loaded with the
    ``IMAGENET1K_V1`` weights). The decoder upsamples with transposed
    convolutions and, at every scale, concatenates the matching encoder
    feature map before a double 3x3 convolution. The head is a 1x1
    convolution producing ``n_classes`` logit maps.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.
        bilinear: Stored on the instance only; nothing in this class reads it
            and the decoder always uses transposed convolutions.
    """

    def __init__(self, n_classes=1, bilinear=False):
        super(myNetwork, self).__init__()
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        self.encoder_conv1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.encoder_pool1 = resnet.maxpool
        self.encoder_layer1 = resnet.layer1
        self.encoder_layer2 = resnet.layer2
        self.encoder_layer3 = resnet.layer3
        self.encoder_layer4 = resnet.layer4

        # Decoder: each stage halves the channel count while upsampling, then
        # fuses the skip connection (skip channels + upsampled channels).
        self.up1_upsample = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
        self.up1_conv = self._double_conv(1024 + 1024, 1024)

        self.up2_upsample = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up2_conv = self._double_conv(512 + 512, 512)

        self.up3_upsample = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up3_conv = self._double_conv(256 + 256, 256)

        self.up4_upsample = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up4_conv = self._double_conv(64 + 128, 128)

        self.up_final = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def _double_conv(self, in_channels, out_channels, mid_channels=None):
        """Build two (3x3 conv -> BatchNorm -> ReLU) layers as one ``nn.Sequential``.

        ``mid_channels`` is the width between the two convolutions and
        defaults to ``out_channels``. Spatial size is preserved (padding=1).
        """
        if not mid_channels:
            mid_channels = out_channels
        return nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _align_and_concat(skip, upsampled):
        """Match ``upsampled`` to ``skip`` spatially, then concatenate on channels.

        The height/width difference is split between both sides. A positive
        difference zero-pads ``upsampled``; a negative one results in negative
        padding, i.e. ``upsampled`` is cropped. The skip tensor comes first in
        the channel dimension.
        """
        diff_y = skip.size(2) - upsampled.size(2)
        diff_x = skip.size(3) - upsampled.size(3)
        upsampled = F.pad(
            upsampled,
            [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
        )
        return torch.cat([skip, upsampled], dim=1)

    def encoder(self, x):
        """Run the ResNet stages and return the feature maps, deepest first.

        Returns:
            ``(x5, x4, x3, x2, x1)`` where ``x1`` is the stem output (before
            max-pooling) and ``x5`` the output of the last ResNet stage.
        """
        x1 = self.encoder_conv1(x)
        x2_p = self.encoder_pool1(x1)
        x2 = self.encoder_layer1(x2_p)
        x3 = self.encoder_layer2(x2)
        x4 = self.encoder_layer3(x3)
        x5 = self.encoder_layer4(x4)
        return x5, x4, x3, x2, x1

    def decoder(self, x5, x4, x3, x2, x1):
        """Upsample ``x5`` step by step, fusing one skip connection per step.

        Returns the raw logits of the 1x1 output convolution (no sigmoid).
        """
        x = self.up1_conv(self._align_and_concat(x4, self.up1_upsample(x5)))
        x = self.up2_conv(self._align_and_concat(x3, self.up2_upsample(x)))
        x = self.up3_conv(self._align_and_concat(x2, self.up3_upsample(x)))
        x = self.up4_conv(self._align_and_concat(x1, self.up4_upsample(x)))

        x = self.up_final(x)
        return self.outc(x)

    def forward(self, x):
        """Return per-pixel logits with the same height/width as ``x``."""
        x5, x4, x3, x2, x1 = self.encoder(x)
        logits = self.decoder(x5, x4, x3, x2, x1)
        # The last transposed convolution doubles the stem resolution exactly,
        # which can overshoot the input size (e.g. for odd heights/widths).
        # Resample so the logits always line up with the input and its mask.
        if logits.shape[-2:] != x.shape[-2:]:
            logits = F.interpolate(
                logits, size=x.shape[-2:], mode="bilinear", align_corners=False
            )
        return logits

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler, device=None):
    """Train ``model`` for one epoch, using mixed precision on CUDA.

    Args:
        loader: Iterable of ``(data, targets)`` batches. Targets get a channel
            dimension inserted at position 1 and are cast to float.
        model: Network to optimise; switched to training mode.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar.
        scaler: Gradient scaler (e.g. ``torch.cuda.amp.GradScaler``) used to
            scale the loss and drive the optimizer step.
        device: Device the batches are moved to. Defaults to the notebook-level
            ``DEVICE`` constant when omitted.

    Returns:
        Mean of the per-batch losses for the epoch (``0.0`` if the loader
        yields no batches).
    """
    device = DEVICE if device is None else device
    # Autocast is only switched on for CUDA; elsewhere a disabled CPU autocast
    # context is used, which leaves the forward pass in full precision.
    use_amp = torch.device(device).type == "cuda"
    amp_device_type = "cuda" if use_amp else "cpu"

    model.train()
    loop = tqdm(loader)
    running_loss = 0.0
    num_batches = 0

    for data, targets in loop:
        data = data.to(device=device)
        targets = targets.float().unsqueeze(1).to(device=device)

        with torch.autocast(device_type=amp_device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    return running_loss / num_batches if num_batches else 0.0

def check_accuracy(loader, model, device="cuda"):
    """Evaluate a binary segmentation model and print/return its metrics.

    Predictions are ``sigmoid(model(x)) > 0.5``. Pixel accuracy is pooled over
    all batches; Dice and IoU are computed per batch and then averaged over
    the number of batches.

    Args:
        loader: Iterable of ``(x, y)`` batches; ``y`` gets a channel dimension
            inserted at position 1 before comparison.
        model: Network to evaluate. It is put in eval mode for the duration of
            the call and restored to its previous train/eval mode afterwards.
        device: Device the batches are moved to.

    Returns:
        Dict with ``"accuracy"`` (percent), ``"dice"`` and ``"iou"``. All three
        are ``0.0`` when the loader yields no batches.
    """
    was_training = model.training
    model.eval()

    num_correct = 0
    num_pixels = 0
    dice_sum = 0.0
    iou_sum = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = (torch.sigmoid(model(x)) > 0.5).float()

                intersection = (preds * y).sum()
                total = (preds + y).sum()

                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # 1e-8 keeps the ratios finite when a batch has no positives.
                dice_sum += (2 * intersection) / (total + 1e-8)
                iou_sum += intersection / (total - intersection + 1e-8)
                num_batches += 1
    finally:
        # Restore whichever mode the caller had, even if evaluation fails.
        model.train(was_training)

    # Convert once at the end so the loop never forces a device sync.
    num_correct = int(num_correct)
    accuracy = num_correct / num_pixels * 100 if num_pixels else 0.0
    dice = float(dice_sum) / num_batches if num_batches else 0.0
    iou = float(iou_sum) / num_batches if num_batches else 0.0

    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {dice}")
    print(f"IoU score: {iou}")

    return {"accuracy": accuracy, "dice": dice, "iou": iou}

In [None]:
# --- Train / validation split -------------------------------------------------
# os.listdir() order is filesystem-dependent; sorting (and keeping regular
# files only) makes the seeded split below reproducible across machines.
all_images = sorted(
    name
    for name in os.listdir(IMG_DIR)
    if os.path.isfile(os.path.join(IMG_DIR, name))
)
train_imgs, val_imgs = train_test_split(all_images, test_size=0.2, random_state=42)

# --- Augmentation pipelines ---------------------------------------------------
# With mean 0 / std 1 and max_pixel_value=255 the Normalize step only rescales
# pixel values to [0, 1].
# NOTE(review): myNetwork loads ImageNet-pretrained ResNet-50 weights, so
# ImageNet mean/std may suit the encoder better — confirm before changing.
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# Validation uses the same resize/normalisation but no random augmentation.
val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# --- Model, loss, optimizer ---------------------------------------------------
model = myNetwork(n_classes=1).to(DEVICE)

# The network outputs raw logits, so the sigmoid lives inside the loss.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Datasets and loaders -----------------------------------------------------
# Both datasets point at the same folders; the file list is overridden after
# construction so each one only sees its side of the split.
train_ds = SentinelDataset(
    image_dir=IMG_DIR,
    mask_dir=MASK_DIR,
    transform=train_transform,
)
train_ds.images = train_imgs

val_ds = SentinelDataset(
    image_dir=IMG_DIR,
    mask_dir=MASK_DIR,
    transform=val_transforms,
)
val_ds.images = val_imgs

# Pinned host memory and gradient scaling only help when training on CUDA;
# requesting them on a CPU-only machine just produces warnings.
use_cuda = DEVICE == "cuda"

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=PIN_MEMORY and use_cuda,
    shuffle=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=PIN_MEMORY and use_cuda,
    shuffle=False,
)

scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

# --- Training loop ------------------------------------------------------------
for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    check_accuracy(val_loader, model, device=DEVICE)

# Persist only the weights; reload with myNetwork(...).load_state_dict(...).
torch.save(model.state_dict(), "unet_resnet50.pth")