In [1]:
import zipfile
from google.colab import drive
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.transforms import functional as tf
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, CenterCrop
import torch.nn.functional as F
from torchvision.utils import save_image
import segmentation_models_pytorch as sm
import torchvision.models as models
from segmentation_models_pytorch.losses import DiceLoss, FocalLoss

In [2]:
# Mount Google Drive so the zipped datasets stored there can be read.
drive.mount('/content/drive')

# Unpack both archives. The `with` blocks close each archive handle as soon as
# extraction finishes (the original left both ZipFile objects open).
# NOTE(review): later cells read from /content/dat/train/train and
# /content/dat/test/images, so the archives are extracted under /content/dat
# instead of /content/train and /content/test. Confirm this matches the
# directory layout inside the zip files.
zipped_train = '/content/drive/MyDrive/train_seg.zip'
with zipfile.ZipFile(zipped_train, 'r') as z:
    z.extractall(path='/content/dat/train')

zipped_test = '/content/drive/MyDrive/test_seg.zip'
with zipfile.ZipFile(zipped_test, 'r') as z:
    z.extractall(path='/content/dat/test')

import gc
gc.collect()

Mounted at /content/drive


0

In [3]:
class CustomImageDataset(Dataset):
    """Paired image/mask dataset read from ``<data_dir>/images`` and ``<data_dir>/masks``.

    A sample is kept only when ``images/<id>.jpg`` exists together with
    ``masks/<id>.png``.

    Args:
        data_dir: Directory containing the ``images`` and ``masks`` sub-folders.
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a mapping with
            ``"image"`` and ``"mask"`` entries.
        img_size: Side length (pixels) both image and mask are resized to.
    """

    def __init__(self, data_dir, transform=None, img_size=256):
        self.data_dir = data_dir
        self.transform = transform
        # Stored on the instance; the original read an undefined global `img_size`.
        self.img_size = img_size

        self.images_dir = os.path.join(data_dir, "images")
        self.masks_dir = os.path.join(data_dir, "masks")
        # Only `.jpg` files are accepted because __getitem__ opens `<id>.jpg`.
        # Sorting makes the index -> sample mapping independent of the order in
        # which os.listdir happens to return entries.
        self.ids = sorted({
            stem
            for stem, ext in map(os.path.splitext, os.listdir(self.images_dir))
            if ext == ".jpg"
            and os.path.exists(os.path.join(self.masks_dir, stem + ".png"))
        })

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Load, resize and (optionally) augment the pair at position ``idx``.

        Returns:
            ``(image, mask)``. Without a transform these are PIL images; with
            one, they are whatever the transform puts under ``"image"`` and
            ``"mask"``.
        """
        img_id = self.ids[idx]
        img_path = os.path.join(self.images_dir, img_id + ".jpg")
        mask_path = os.path.join(self.masks_dir, img_id + ".png")
        size = (self.img_size, self.img_size)

        # Force three channels so grayscale/RGBA files match a 3-channel model.
        image = Image.open(img_path).convert("RGB").resize(size)
        # Nearest-neighbour resampling keeps mask labels from being blended.
        mask = Image.open(mask_path).convert("L").resize(size, Image.NEAREST)

        if self.transform:
            # Image and mask go through ONE call so random spatial augmentations
            # are applied identically to both and image-only steps such as
            # normalisation do not touch the mask.
            # NOTE(review): the mask is binarised as "any non-zero pixel is
            # foreground" -- confirm this matches how the masks are encoded.
            augmented = self.transform(
                image=np.array(image),
                mask=(np.array(mask) > 0).astype(np.float32),
            )
            image, mask = augmented["image"], augmented["mask"]

        return image, mask

In [4]:
# Training-time augmentation pipeline: random horizontal/vertical flips,
# per-channel normalisation, then conversion to a torch tensor.
transformer = A.Compose([
    A.HorizontalFlip(),
    A.VerticalFlip(),
    # Standard ImageNet statistics (the blue-channel mean is 0.406, not 0.407),
    # matching the ImageNet-pretrained encoder used by the model.
    A.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])


In [5]:
train_dir = "/content/dat/train/train"
dataset = CustomImageDataset(train_dir, transformer)

# Hold out 10% of the samples for validation. random_split returns lazy Subset
# views, so samples are still loaded (and randomly augmented) on each access;
# sklearn's train_test_split would instead index every sample up front and
# return plain lists of already-transformed items.
n_valid = int(np.ceil(0.1 * len(dataset)))
train_dataset, valid_dataset = random_split(
    dataset,
    [len(dataset) - n_valid, n_valid],
    generator=torch.Generator().manual_seed(42),  # fixed seed -> reproducible split
)

In [6]:
# Mini-batch loaders over the two splits; both reshuffle on every pass and
# yield batches of four samples.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=4)
val_loader = DataLoader(dataset=valid_dataset, shuffle=True, batch_size=4)

In [7]:
# U-Net from segmentation_models_pytorch with a MobileNetV2 encoder initialised
# from ImageNet weights: 3 input channels in, 1 output channel out.
# Note that `UNet` is bound to a model *instance*, not a class.
UNet = sm.Unet(
    in_channels=3,
    classes=1,
    encoder_name="mobilenet_v2",
    encoder_weights="imagenet",
)

In [8]:
class Loss(nn.Module):
    """Weighted sum of Dice loss and Focal loss for single-channel (binary) masks.

    Args:
        dice_weight: Multiplier applied to the Dice term.
        ce_weight: Multiplier applied to the Focal term (name kept for
            backward compatibility).
    """

    def __init__(self, dice_weight=0.5, ce_weight=0.5):
        super(Loss, self).__init__()

        self.dice_criterion = DiceLoss(mode='binary')
        # Binary mode to match the Dice term and a one-channel model output.
        # The previous multiclass mode was fed `targets.argmax(1)`, which is
        # constantly zero when the target has a single channel, so the focal
        # term never saw the actual mask.
        self.ce_criterion = FocalLoss(mode='binary')

        self.dice_weight = dice_weight
        self.ce_weight = ce_weight

    def forward(self, outputs, targets):
        """Return ``dice_weight * dice + ce_weight * focal`` for one batch.

        Args:
            outputs: Model predictions passed unchanged to both criteria
                (presumably raw logits -- TODO confirm against the model).
            targets: Ground-truth masks passed unchanged to both criteria.
        """
        dice_loss = self.dice_criterion(outputs, targets)
        ce_loss = self.ce_criterion(outputs, targets)

        return self.dice_weight * dice_loss + self.ce_weight * ce_loss

In [9]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# `UNet` is already a constructed model instance; calling it as `UNet(3, 1)`
# would run a forward pass on two integers instead of yielding a model.
model = UNet.to(device)
criterion = Loss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [10]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(img, mask)`` batches exposing ``.dataset``.
        model: Network to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, mask)`` returning a scalar tensor.
        optimizer: Optimizer that updates ``model``'s parameters.

    Prints the loss and the number of samples seen every 100 batches.
    """
    size = len(dataloader.dataset)
    # Move batches to wherever the model lives instead of depending on a
    # notebook-level `device` variable being defined.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for batch, (img, mask) in enumerate(dataloader):
        img, mask = img.to(run_device).float(), mask.to(run_device).float()

        pred = model(img)
        loss = loss_fn(pred, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # `len(img)` is the batch size; the original referenced an
            # undefined name `X` here.
            loss_value, current = loss.item(), (batch + 1) * len(img)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

In [11]:
# Train for a fixed number of passes over the training loader, printing a
# banner before each pass and a final marker when finished.
epochs = 2
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(dataloader=train_loader, model=model, loss_fn=criterion, optimizer=optimizer)
print("Done!")

Epoch 1
-------------------------------
loss: 0.673059  [    4/ 6836]
loss: 0.544419  [  404/ 6836]
loss: 0.377123  [  804/ 6836]
loss: 0.354884  [ 1204/ 6836]
loss: 0.236219  [ 1604/ 6836]
loss: 0.187652  [ 2004/ 6836]
loss: 0.497534  [ 2404/ 6836]
loss: 0.233981  [ 2804/ 6836]
loss: 0.176436  [ 3204/ 6836]
loss: 0.348381  [ 3604/ 6836]
loss: 0.173597  [ 4004/ 6836]
loss: 0.200610  [ 4404/ 6836]
loss: 0.219469  [ 4804/ 6836]
loss: 0.141096  [ 5204/ 6836]
loss: 0.195525  [ 5604/ 6836]
loss: 0.211999  [ 6004/ 6836]
loss: 0.165298  [ 6404/ 6836]
loss: 0.835365  [ 6804/ 6836]
Epoch 2
-------------------------------
loss: 0.167974  [    4/ 6836]
loss: 0.171659  [  404/ 6836]
loss: 0.727224  [  804/ 6836]
loss: 0.685961  [ 1204/ 6836]
loss: 0.156356  [ 1604/ 6836]
loss: 0.180452  [ 2004/ 6836]
loss: 0.807187  [ 2404/ 6836]
loss: 0.149825  [ 2804/ 6836]
loss: 0.141525  [ 3204/ 6836]
loss: 0.202293  [ 3604/ 6836]
loss: 0.203141  [ 4004/ 6836]
loss: 0.106428  [ 4404/ 6836]
loss: 0.147435  [ 48

In [12]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device).float(), y.to(device).float()

            pred = model(X)
            test_loss += loss_fn(pred, y).item()

In [13]:
# Three further optimisation passes, this time over the validation loader.
# NOTE(review): this calls `train`, not `test`, so the held-out split is used
# to update the weights -- confirm that this is intentional.
epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(dataloader=val_loader, model=model, loss_fn=criterion, optimizer=optimizer)
print("Done!")

Epoch 1
-------------------------------
loss: 0.120008  [    4/ 1207]
loss: 0.148046  [  404/ 1207]
loss: 0.334199  [  804/ 1207]
loss: 0.174412  [ 1204/ 1207]
Epoch 2
-------------------------------
loss: 0.155289  [    4/ 1207]
loss: 0.147005  [  404/ 1207]
loss: 0.292492  [  804/ 1207]
loss: 0.244786  [ 1204/ 1207]
Epoch 3
-------------------------------
loss: 0.760643  [    4/ 1207]
loss: 0.192845  [  404/ 1207]
loss: 0.554432  [  804/ 1207]
loss: 0.097973  [ 1204/ 1207]
Done!


In [27]:
# Inference-time preprocessing, aligned with the training pipeline:
#  * resize to a fixed 256x256 square (Resize(256) alone only scales the
#    shorter side, so non-square photos would keep a different shape from the
#    square inputs the dataset class produces);
#  * apply ImageNet mean/std normalisation as the training transform does,
#    instead of the no-op mean=0 / std=1.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


In [34]:
# Where predicted masks are written (created on demand) and where the test
# images are read from.
masks_dir = '/content/dat/test/mask'
images_dir = '/content/dat/test/images'
os.makedirs(masks_dir, exist_ok=True)
# Every entry of the images folder is treated as an input image.
image_filenames = os.listdir(images_dir)

In [29]:
from torchvision.transforms.functional import to_pil_image

# Predict one mask per test image and save it as `<image stem>.png` in
# `masks_dir`. Gradients are disabled and the model is in evaluation mode.
model.eval()
with torch.no_grad():
    for image in image_filenames:
        image_path = os.path.join(images_dir, image)
        # Load as RGB, preprocess, and add a leading batch dimension.
        img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)
        # Sigmoid of the model output, thresholded at 0.5 into a 0/1 float mask.
        predicted_mask = (torch.sigmoid(model(img)) > 0.5).float()
        predicted_mask = to_pil_image(predicted_mask.squeeze().cpu())
        mask = os.path.splitext(image)[0] + '.png'
        mask_path = os.path.join(masks_dir, mask)
        predicted_mask.save(mask_path)

In [38]:
import shutil
from google.colab import files

# Build the archive before downloading it: no other cell in this notebook
# creates /content/masks.zip, so the download would fail on a fresh run.
# NOTE(review): the predicted masks are placed at the root of the zip --
# confirm the layout expected by whoever consumes this archive.
shutil.make_archive("/content/masks", "zip", root_dir=masks_dir)
files.download("/content/masks.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>