In [None]:
import segmentation_models_pytorch as smp
import torch

In [None]:
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# Encoder-specific input preprocessing function from segmentation_models_pytorch.
# NOTE(review): `preprocess_input` is not referenced anywhere else in the visible
# notebook; input normalisation is instead done by A.Normalize (ImageNet mean/std)
# inside the albumentations transform. Remove it if nothing outside this view uses it.
preprocess_input = get_preprocessing_fn("timm-efficientnet-b3", pretrained='imagenet')

In [None]:
import os 

# Fail fast if the GIS data mount is unavailable, rather than just displaying False
# and letting later cells crash with a less obvious error.
SAMPLE_IMAGE_PATH = '/mnt/gis/image/18/18_231006_155459.jpg'
if not os.path.exists(SAMPLE_IMAGE_PATH):
    raise FileNotFoundError(f"Sample tile not found: {SAMPLE_IMAGE_PATH} (is /mnt/gis mounted?)")

In [None]:
import os
import pandas as pd
from torchvision.io import decode_image
import torch
from torch.utils.data import Dataset

from torch.utils.data import Dataset
from PIL import Image
import os
import albumentations as A
import cv2 
import numpy as np 


class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from two directories.

    Args:
        images: Image file names, relative to ``image_dir``.
        masks: Mask file names (loaded with ``np.load``), relative to ``mask_dir``;
            ``masks[i]`` must correspond to ``images[i]``.
        image_dir: Directory containing the images.
        mask_dir: Directory containing the masks.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"`` entries.

    Raises:
        ValueError: if ``images`` and ``masks`` differ in length.

    Without a transform, ``__getitem__`` returns a ``(3, H, W)`` float32 image
    scaled to [0, 1] and a float32 mask.
    """

    def __init__(self, images, masks, image_dir, mask_dir, transform=None):
        # Items are paired by index, so a length mismatch would silently drop
        # samples or raise IndexError deep inside a DataLoader worker.
        if len(images) != len(masks):
            raise ValueError(
                f"images and masks must have the same length, got {len(images)} and {len(masks)}"
            )
        self.images = images
        self.masks = masks
        self.transform = transform
        self.image_dir = image_dir
        self.mask_dir = mask_dir

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Close the file handle deterministically instead of relying on PIL/GC.
        with Image.open(img_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        mask = np.load(mask_path)

        # `is not None` rather than truthiness: albumentations Compose defines
        # __len__, so an empty Compose would otherwise be skipped.
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]
            # ToTensorV2 does not change the mask dtype (often uint8/int on disk);
            # cast so both branches yield a float mask, as BCE-style losses require.
            if isinstance(mask, torch.Tensor):
                mask = mask.float()
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            mask = torch.from_numpy(mask).float()

        return image, mask

img_dir = '/mnt/gis/image/18/'
mask_dir = '/mnt/gis/label/18/'

# Masks are named "<x>_<y>.npy"; the matching tile image is "18_<x>_<y>.jpg".
# The listing is sorted because os.listdir returns entries in arbitrary,
# filesystem-dependent order, which would make the seeded train/val split
# non-reproducible. Non-.npy entries are ignored instead of breaking the parse.
masks = sorted(name for name in os.listdir(mask_dir) if name.endswith('.npy'))

coords = []
for mask in masks:
    x, y = os.path.splitext(mask)[0].split('_')
    coords.append((int(x), int(y)))

images = [f'18_{x}_{y}.jpg' for (x, y) in coords]

from sklearn.model_selection import train_test_split



# Training hyper-parameters.
# NOTE(review): 'batch_size' is not passed to the DataLoaders in this notebook
# (they use a batch size of 8) — confirm which value is intended.
config = dict(
    batch_size=4,
    epochs=5,
    learning_rate=1e-4,
    val_split=0.2,
    num_workers=4,
)

# Hold out a `val_split` fraction of the tiles for validation; a fixed
# random_state keeps the split stable for a given input ordering.
split = train_test_split(
    images,
    masks,
    test_size=config['val_split'],
    random_state=42,
)
train_images, val_images, train_masks, val_masks = split

# NOTE(review): TARGET_SIZE is not used by the transforms below; batching assumes
# every tile already has the same size — confirm.
TARGET_SIZE = (256, 256)

# Random augmentation is applied to the training split only.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.GridDistortion(p=0.2),
    A.ElasticTransform(p=0.2),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    A.ToTensorV2(),
])

# Validation must be deterministic: same normalisation and tensor conversion,
# no random augmentation, so val metrics are comparable across epochs.
val_transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    A.ToTensorV2(),
])

# Keep the old name bound to the training pipeline for any code that uses it.
transform = train_transform

train_dataset = SegmentationDataset(train_images, train_masks, image_dir=img_dir, mask_dir=mask_dir, transform=train_transform)
val_dataset = SegmentationDataset(val_images, val_masks, image_dir=img_dir, mask_dir=mask_dir, transform=val_transform)

from torch.utils.data import DataLoader

# Batch size 8 is kept because the prediction-visualisation cell indexes 8 images per batch.
LOADER_BATCH_SIZE = 8

# Shuffle training batches each epoch (seeded for reproducibility); keep validation order fixed.
# NOTE(review): config['num_workers'] is not applied here — enable it once per-worker
# seeding of the random augmentations has been verified.
train_loader = DataLoader(
    train_dataset,
    batch_size=LOADER_BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_dataset, batch_size=LOADER_BATCH_SIZE, shuffle=False)


# Pull one batch to confirm the loader works. Report shapes and dtypes instead of
# dumping raw tensor values, and use separate names so the `images`/`masks` file
# lists above are not overwritten.
sample_images, sample_masks = next(iter(train_loader))
print(f"images: {tuple(sample_images.shape)} {sample_images.dtype} | "
      f"masks: {tuple(sample_masks.shape)} {sample_masks.dtype}")
# Only exercise the GPU transfer when a GPU exists; an unconditional .cuda() fails on CPU-only machines.
if torch.cuda.is_available():
    sample_images, sample_masks = sample_images.cuda(), sample_masks.cuda()


# Next: visually check that the dataset returns sensible image/mask pairs.

In [None]:
import matplotlib.pyplot as plt 

def check_dataset(dataset, n=5):
    for i in range(n):
        sample_image, sample_mask = dataset[i]
        #print(sample_image)
        # print(sample_image.shape, sample_mask.shape)
        plt.subplot(1, 2, 1)
        plt.imshow(sample_image.permute(1, 2, 0))
        plt.subplot(1, 2, 2)
        plt.imshow(sample_mask.squeeze())
        plt.show()

# Eyeball the first few (augmented) training samples.
check_dataset(dataset=train_dataset, n=5)

In [None]:
# Same visual check for the validation split.
check_dataset(dataset=val_dataset, n=5)

In [None]:
from tqdm import tqdm


def calculate_metrics(pred, target, threshold=0.5):
    """Compute IoU, Dice and pixel accuracy for a binary segmentation batch.

    Counts are pooled over every element of the inputs (micro-averaged over
    the batch), not averaged per image.

    Args:
        pred: Predicted probabilities (e.g. ``torch.sigmoid(logits)``); values
            strictly above ``threshold`` count as foreground.
        target: Ground-truth mask, expected to match ``pred``'s shape. It is cast
            to float but not thresholded, so it is presumably 0/1 valued.
        threshold: Probability cut-off for the foreground class.

    Returns:
        dict with float ``'iou'``, ``'dice'`` and ``'accuracy'``. When neither the
        prediction nor the target contains any foreground, IoU and Dice are 1.0
        (a perfect empty prediction).
    """
    pred_binary = (pred > threshold).float()
    target_binary = target.float()

    intersection = (pred_binary * target_binary).sum()
    pred_area = pred_binary.sum()
    target_area = target_binary.sum()
    union = pred_area + target_area - intersection

    if union.item() == 0:
        # Both masks are empty, so the prediction is exactly right. The previous
        # 0 / eps formulation scored this as 0, penalising correct background-only tiles.
        iou = dice = 1.0
    else:
        iou = (intersection / union).item()
        dice = (2 * intersection / (pred_area + target_area)).item()

    accuracy = ((pred_binary == target_binary).sum() / target_binary.numel()).item()

    return {
        'iou': iou,
        'dice': dice,
        'accuracy': accuracy,
    }


def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network producing per-pixel logits.
        train_loader: Iterable of ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(logits, masks_with_channel_axis)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_metrics)``: the batch-averaged loss and a dict of
        batch-averaged ``iou``/``dice``/``accuracy`` from ``calculate_metrics``.
    """
    model.train()
    running_loss = 0.0
    running = dict.fromkeys(('iou', 'dice', 'accuracy'), 0.0)

    for batch_images, batch_masks in tqdm(train_loader, desc="Training"):
        batch_images = batch_images.to(device)
        # Insert a channel axis at dim 1; presumably masks are (B, H, W) and the
        # logits (B, 1, H, W) — TODO confirm.
        target = batch_masks.to(device).unsqueeze(1)

        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Metrics are for monitoring only; keep them out of the autograd graph.
        with torch.no_grad():
            scores = calculate_metrics(torch.sigmoid(logits), target)
        for name in running:
            running[name] += scores[name]

    n_batches = len(train_loader)
    mean_loss = running_loss / n_batches
    return mean_loss, {name: total / n_batches for name, total in running.items()}

def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradient tracking.

    Args:
        model: Network producing per-pixel logits.
        val_loader: Iterable of ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(logits, masks_with_channel_axis)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_metrics)``: the batch-averaged loss and a dict of
        batch-averaged ``iou``/``dice``/``accuracy`` from ``calculate_metrics``.
    """
    model.eval()
    loss_total = 0.0
    metric_totals = dict.fromkeys(('iou', 'dice', 'accuracy'), 0.0)

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(val_loader, desc="Validating"):
            batch_images = batch_images.to(device)
            # Insert a channel axis at dim 1 so the target lines up with the logits.
            target = batch_masks.to(device).unsqueeze(1)

            logits = model(batch_images)
            loss_total += criterion(logits, target).item()

            batch_scores = calculate_metrics(torch.sigmoid(logits), target)
            for name in metric_totals:
                metric_totals[name] += batch_scores[name]

    n_batches = len(val_loader)
    mean_loss = loss_total / n_batches
    return mean_loss, {name: total / n_batches for name, total in metric_totals.items()}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Select the architecture here. Only the chosen model is constructed, so the
# other network's pretrained encoder weights are never fetched or held in memory.
MODEL_ARCH = "unet"  # or "deeplabv3plus"

if MODEL_ARCH == "deeplabv3plus":
    model = smp.DeepLabV3Plus(
        encoder_name="timm-efficientnet-b3",
        encoder_weights="imagenet",
        classes=1,
        activation=None,  # Use raw logits
    )
elif MODEL_ARCH == "unet":
    model = smp.Unet(
        encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
        in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=1,                      # model output channels (number of classes in your dataset)
        activation=None,                # raw logits; the loss uses from_logits=True
    )
else:
    raise ValueError(f"Unknown MODEL_ARCH: {MODEL_ARCH!r}")

model = model.to(device)

# Per-epoch training curves, filled in by the training loop.
# Keys are '<split>_<metric>' for split in (train, val) and metric in
# (loss, iou, dice, accuracy); each maps to its own list.
history = {
    f'{split}_{metric}': []
    for metric in ('loss', 'iou', 'dice', 'accuracy')
    for split in ('train', 'val')
}

import torch.nn as nn

# Optimizer and a cosine learning-rate schedule spanning config['epochs'] epochs.
optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])

# Dice loss on raw logits (the model is built with activation=None).
# Alternatives: smp.losses.FocalLoss(mode='binary', alpha=0.25, gamma=2, from_logits=True),
# or nn.BCEWithLogitsLoss(). Plain nn.BCELoss would be wrong here because it expects probabilities.
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

for epoch in range(config['epochs']):
    current_lr = optimizer.param_groups[0]['lr']

    train_loss, train_metrics = train_epoch(
        model, train_loader, loss_fn, optimizer, device
    )
    val_loss, val_metrics = validate_model(model, val_loader, loss_fn, device)

    # Advance the cosine schedule once per epoch (T_max is counted in epochs).
    # Without this call the learning rate never changes.
    scheduler.step()

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    for metric in ('iou', 'dice', 'accuracy'):
        history[f'train_{metric}'].append(train_metrics[metric])
        history[f'val_{metric}'].append(val_metrics[metric])

    print(f"Epoch {epoch + 1}/{config['epochs']} (lr={current_lr:.2e})")
    print(f"Train Loss: {train_loss:.4f}, Train IoU: {train_metrics['iou']:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val IoU: {val_metrics['iou']:.4f}")

In [None]:
import matplotlib.pyplot as plt 


def plot_training_history(history):
    """Plot per-epoch train/validation curves stored in ``history``.

    Draws a 2x2 grid of loss, IoU, Dice coefficient and pixel accuracy. Each
    panel has a train line and a val line. The figure is then shown.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (grid position, train key, val key, train label, val label, panel title)
    panels = (
        ((0, 0), 'train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss'),
        ((0, 1), 'train_iou', 'val_iou', 'Train IoU', 'Val IoU', 'IoU'),
        ((1, 0), 'train_dice', 'val_dice', 'Train Dice', 'Val Dice', 'Dice Coefficient'),
        ((1, 1), 'train_accuracy', 'val_accuracy', 'Train Accuracy', 'Val Accuracy', 'Accuracy'),
    )
    for position, train_key, val_key, train_label, val_label, title in panels:
        ax = axes[position]
        ax.plot(history[train_key], label=train_label)
        ax.plot(history[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.legend()

    plt.tight_layout()
    plt.show()


plot_training_history(history)

In [None]:
# Compare input, ground truth and prediction for every sample in the first training batch.
# Put the model in inference mode explicitly rather than relying on whichever cell ran last
# (in train mode BatchNorm would use batch statistics).
model.eval()
images, masks = next(iter(train_loader))
with torch.no_grad():
    outputs = model(images.to(device))

images = images.cpu()
preds = outputs.detach().cpu()

# Iterate over the actual batch size; a hard-coded count breaks on a short final batch.
for i in range(images.shape[0]):
    # Images are ImageNet-normalised by the transform; min-max rescale so imshow does not clip.
    img = images[i].permute(1, 2, 0)
    img = (img - img.min()) / (img.max() - img.min()).clamp_min(1e-8)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(img)
    axs[0].set_title('Image')
    axs[1].imshow(masks[i].squeeze())
    axs[1].set_title('Ground truth')
    axs[2].imshow(preds[i].squeeze() > 0)  # logit > 0  <=>  sigmoid probability > 0.5
    axs[2].set_title('Prediction')
    plt.show()
    plt.close(fig)