# Drone Image Segmentation Using SegNet

In [None]:
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision import tv_tensors
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

import os
import time
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt


# Dataset Class

In [None]:
class DroneImagesSegmentationDataset(Dataset):
    """Paired image/mask dataset for drone-image segmentation.

    Sample ``i`` is read from ``images_dir/filenames[i]`` and
    ``masks_dir/filenames[i]``, so an image and its mask must share the
    same file name.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the segmentation masks.
        filenames: Sequence of file names (relative to both directories)
            that make up this dataset split.
        spatial_transforms: Optional transform called as
            ``spatial_transforms(image, mask)``. It must return the
            transformed ``(image, mask)`` pair, so the same geometric
            augmentation is applied to both.
        normalize: Optional callable applied to the image only, after the
            spatial transforms.
    """

    def __init__(self, images_dir, masks_dir, filenames, spatial_transforms=None, normalize=None):
        self.images_dir = images_dir
        # Previously this stored images_dir, which made every mask be read
        # from the image folder.
        self.masks_dir = masks_dir
        self.filenames = filenames
        self.spatial_transforms = spatial_transforms
        self.normalize = normalize

    def __len__(self):
        # One sample per file name. Previously this returned the length of
        # the images_dir path string.
        return len(self.filenames)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair at ``index``.

        ``image`` is the RGB image wrapped as ``tv_tensors.Image``, after
        the optional spatial transforms and normalization. ``mask`` is a
        float32 tensor that is 1.0 where the (transformed) mask value is > 0
        and 0.0 elsewhere.
        """
        filename = self.filenames[index]
        image_path = os.path.join(self.images_dir, filename)
        mask_path = os.path.join(self.masks_dir, filename)

        # `convert` returns a new, fully loaded image. Opening the source in
        # a `with` block releases its file handle right away, which matters
        # when many DataLoader workers are open at once.
        with Image.open(image_path) as img:
            # `tv_tensors` is a module. Wrap with its Image/Mask classes so
            # v2 transforms can tell which input is the image and which is
            # the mask.
            image = tv_tensors.Image(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            # NOTE(review): the mask is read as 3-channel RGB, so the binary
            # mask returned below has 3 channels. If a single-channel mask is
            # expected, convert with "L" instead. Confirm against the model.
            mask = tv_tensors.Mask(msk.convert("RGB"))

        if self.spatial_transforms is not None:
            image, mask = self.spatial_transforms(image, mask)

        if self.normalize is not None:
            # NOTE(review): the image is uint8 at this point unless an earlier
            # transform converted it; v2.Normalize expects a float tensor
            # (e.g. prepend v2.ToDtype(torch.float32, scale=True)).
            image = self.normalize(image)

        # Binarize: any non-zero mask value counts as foreground.
        mask = (mask > 0).to(torch.float32)

        return image, mask

# Helper Functions

# Model Architecture