<a href="https://colab.research.google.com/github/WasifAsi/DSGP_GP-19/blob/deeplabv3%2B/DeepLab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Imports**

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from PIL import Image
import cv2
from google.colab import drive
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pickle
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix

# Attach Google Drive under /content/drive
drive.mount('/content/drive')

# Directory on Drive used for saved models; created if it does not exist yet
MODEL_SAVE_PATH = '/content/drive/MyDrive/shoreline_models'
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

# Training hyperparameters
learning_rate = 5e-4
batch_size = 6  # decreased from 8
num_epochs = 200
image_size = 540  # increased from 240

# **Loss Functions**

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    For every element the loss is ``-alpha * (1 - pt) ** gamma * log(pt)``
    where ``p = sigmoid(pred)`` and ``pt = target * p + (1 - target) * (1 - p)``.
    The per-element losses are averaged into a single scalar.

    ``log(pt)`` is evaluated in log-space (``logsigmoid`` + ``logaddexp``)
    rather than as ``log(sigmoid(pred) + eps)``. In float32 the latter
    saturates: once ``sigmoid`` rounds to exactly 0 or 1 the loss is capped
    and its gradient is zero, so confidently wrong predictions stop learning.

    Args:
        alpha: Constant factor multiplied into every element's loss (the same
            value is used for positive and negative elements).
        gamma: Exponent of the ``(1 - pt)`` focusing factor.
    """

    def __init__(self, alpha=0.003, gamma=5.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        """Return the mean focal loss.

        Args:
            pred: Tensor of raw logits (sigmoid is applied internally).
            target: Tensor broadcastable with ``pred`` holding values in
                ``[0, 1]``. Integer/bool targets are cast to ``pred``'s dtype.

        Returns:
            Scalar tensor: the mean of the per-element focal losses.
        """
        # Integer or bool masks cannot go through log(); use the logits' dtype.
        if not torch.is_floating_point(target):
            target = target.to(pred.dtype)

        # log(p) and log(1 - p) without ever forming p, so neither underflows.
        log_p = F.logsigmoid(pred)
        log_not_p = F.logsigmoid(-pred)

        # log(pt) = log(target * p + (1 - target) * (1 - p)), done in log-space.
        # For hard targets one branch is -inf and logaddexp returns the other.
        log_pt = torch.logaddexp(
            torch.log(target) + log_p,
            torch.log1p(-target) + log_not_p,
        )
        # Guard against rounding pushing log(pt) marginally above zero, which
        # would make (1 - pt) negative and break a fractional gamma.
        log_pt = torch.clamp(log_pt, max=0.0)

        # 1 - pt == -expm1(log(pt)); expm1 keeps precision when pt is near 1.
        focal_weight = (-torch.expm1(log_pt)) ** self.gamma

        loss = -self.alpha * focal_weight * log_pt

        return loss.mean()

# Combine with Dice Loss
class CombinedFocalDiceLoss(nn.Module):
    """Sum of a focal loss and a weighted soft Dice loss, both taken on logits.

    The Dice term uses sigmoid probabilities and is computed over every
    element of the input at once (no per-sample reduction), with a smoothing
    constant of 1 in both numerator and denominator.

    Args:
        alpha: Passed to ``FocalLoss`` as its ``alpha``.
        gamma: Passed to ``FocalLoss`` as its ``gamma``.
        dice_weight: Multiplier applied to the Dice term before it is added
            to the focal term.
    """

    def __init__(self, alpha=0.8, gamma=2.0, dice_weight=0.5):
        super().__init__()
        self.focal = FocalLoss(alpha=alpha, gamma=gamma)
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        """Return ``focal(pred, target) + dice_weight * dice(pred, target)``."""
        # Soft Dice on probabilities, summed globally over the input.
        probs = torch.sigmoid(pred)
        overlap = torch.sum(probs * target)
        total_mass = probs.sum() + target.sum() + 1
        dice_coefficient = (2. * overlap + 1) / total_mass
        dice_term = 1 - dice_coefficient

        focal_term = self.focal(pred, target)
        return focal_term + self.dice_weight * dice_term