In [None]:
!pip install segmentation_models_pytorch

In [None]:
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two folders.

    Every regular, non-hidden file in ``data_folder`` is treated as an input
    image. Its mask is looked up in ``mask_path`` under the same file name,
    except that a trailing ``.jpg`` extension is swapped for ``.png``.

    Args:
        data_folder: Directory containing the input images.
        mask_path: Directory containing the masks.
        input_size: Size handed to ``transforms.Resize`` for the images
            (torchvision interprets a 2-sequence as ``(height, width)``).
        mask_size: Size handed to ``transforms.Resize`` for the masks.

    Each item is an ``(image, mask)`` tuple of tensors produced by
    ``transforms.ToTensor``: the image is forced to 3-channel RGB and the
    mask is reduced to a single channel.
    """

    def __init__(self, data_folder, mask_path, input_size=(960,640), mask_size=(960,640)):
        self.input_size = input_size
        self.image_dir = data_folder
        self.mask_size = mask_size
        self.mask_path = mask_path
        # Keep only regular, non-hidden files: os.listdir also returns
        # sub-directories and dotfiles (e.g. .ipynb_checkpoints, .DS_Store)
        # that would make Image.open fail. Sorting makes the index -> file
        # mapping reproducible, since os.listdir order is arbitrary.
        self.images = sorted(
            name
            for name in os.listdir(self.image_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.image_dir, name))
        )

        self.transform_image = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
        ])

        self.transform_mask = transforms.Compose([
            # Nearest-neighbour keeps the original label values; the default
            # bilinear filter would blend foreground and background at edges.
            transforms.Resize(
                self.mask_size,
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.Grayscale(num_output_channels=1),
            # NOTE(review): ToTensor scales 0-255 to 0-1, so masks are
            # presumably stored as 0/255 images -- confirm against the data.
            transforms.ToTensor(),
        ])

    @staticmethod
    def _mask_filename(image_name):
        """Return the mask file name that belongs to ``image_name``.

        Only a trailing ``.jpg`` extension is replaced by ``.png``; a
        ``.jpg`` occurring elsewhere in the name is left untouched, and any
        other name is returned unchanged.
        """
        stem, ext = os.path.splitext(image_name)
        if ext == ".jpg":
            return stem + ".png"
        return image_name

    def __len__(self):
        """Return the number of images found in the image directory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, resize and tensorise the ``idx``-th image and its mask."""
        name = self.images[idx]
        img_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.mask_path, self._mask_filename(name))

        # The context managers close the underlying files once the pixel
        # data has been consumed by the transforms.
        with Image.open(img_path) as raw_image:
            # Force 3 channels so grayscale / RGBA / palette files still
            # match the model's expected input.
            image = self.transform_image(raw_image.convert("RGB"))
        with Image.open(mask_path) as raw_mask:
            mask = self.transform_mask(raw_mask)

        return image, mask


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import segmentation_models_pytorch as smp
import os
import glob

# Train a U-Net (ResNet-18 encoder) for binary segmentation, resuming from the
# newest checkpoint in checkpoint_dir when one exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 1
batch_size = 16
learning_rate = 0.0001
num_epochs = 20
save_interval = 1  # write a resumable checkpoint every N epochs

model_name = 'unet'
encoder_name = 'resnet18'

# activation=None makes the network emit raw logits. BCEWithLogitsLoss applies
# the sigmoid itself, so a sigmoid head here would squash the outputs twice.
# At inference time apply torch.sigmoid to the outputs (or rebuild the model
# with activation="sigmoid"; the state_dict is the same either way).
model = smp.Unet(encoder_name, in_channels=3, classes=num_classes, activation=None).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_image_path = "/data/training/images"
train_mask_path = "/data/training/masks"

train_dataset = CustomDataset(train_image_path, train_mask_path)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

checkpoint_dir = '/model/checkpoints/'
final_model_path = '/model/unet_resnet18_100k.pth'
# torch.save does not create missing directories; this also creates /model.
os.makedirs(checkpoint_dir, exist_ok=True)


def _checkpoint_epoch(path):
    """Return the epoch number encoded in a checkpoint file name, or -1."""
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return int(stem.rsplit('_', 1)[-1])
    except ValueError:
        return -1


# Pick the checkpoint with the highest epoch number. File creation time is not
# a reliable ordering once checkpoints have been copied or restored.
latest_checkpoint = max(
    glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pth')),
    key=_checkpoint_epoch,
    default=None,
)

if latest_checkpoint is not None:
    # map_location lets a checkpoint written on GPU be resumed on a CPU-only
    # machine (and vice versa).
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # 'epoch' holds the number of completed epochs, i.e. the next 0-based index.
    start_epoch = checkpoint['epoch']
    print(f"Resuming training from epoch {start_epoch}, loss: {checkpoint['loss']:.4f}")
else:
    start_epoch = 0
    print("No checkpoint found. Starting from epoch 0.")

for epoch in range(start_epoch, num_epochs):
    model.train()
    running_loss = 0.0

    with tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs} (Training)', leave=True) as train_bar:
        for batch_idx, (images, masks) in enumerate(train_bar):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            train_bar.set_postfix(loss=f'{loss.item():.4f}')

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
    average_loss = running_loss / max(len(train_loader), 1)

    if (epoch + 1) % save_interval == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch + 1}.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': average_loss,
        }, checkpoint_path)

        print(f"Checkpoint saved at {checkpoint_path}")
    # Refresh the plain weights file every epoch so it always matches the
    # most recently trained state.
    torch.save(model.state_dict(), final_model_path)

# Also written after the loop so the file exists even when the resumed
# checkpoint was already at num_epochs and the loop body never ran.
torch.save(model.state_dict(), final_model_path)
torch.cuda.empty_cache()
