In [None]:
!pip install gdown segmentation-models-pytorch torch torchvision opencv-python

import os
import gdown
import zipfile
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import numpy as np
from tqdm import tqdm

# ==============================
# 1. Get dataset from user
# ==============================
# Ask the user for the dataset's Google Drive link; surrounding whitespace
# (e.g. from copy-paste) is stripped before the link is parsed.
suim_link = input("Enter Google Drive link for SUIM dataset: ").strip()

def download_and_extract(gdrive_link, output_dir):
    """Download a zip archive from Google Drive and unpack it.

    Two link shapes are recognised: ``...?id=<FILE_ID>`` and
    ``.../d/<FILE_ID>/...``. The archive is fetched to ``temp.zip`` in the
    current working directory, extracted into ``output_dir`` and then
    deleted.

    Args:
        gdrive_link: Google Drive sharing/download URL.
        output_dir: Directory the archive contents are extracted into.

    Raises:
        ValueError: If no file id can be extracted from ``gdrive_link``.
    """
    if "id=" in gdrive_link:
        # Keep only the id itself: drop any following query parameters or
        # fragment (e.g. "&export=download").
        file_id = gdrive_link.split("id=", 1)[1].split("&", 1)[0].split("#", 1)[0]
    elif "/d/" in gdrive_link:
        file_id = gdrive_link.split("/d/", 1)[1].split("/", 1)[0].split("?", 1)[0]
    else:
        raise ValueError("Invalid Google Drive link format.")

    if not file_id:
        # A link such as "...uc?id=" matches a known shape but carries no id.
        raise ValueError("Invalid Google Drive link format.")

    archive_path = "temp.zip"
    try:
        gdown.download(
            f"https://drive.google.com/uc?id={file_id}", archive_path, quiet=False
        )
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall(output_dir)
    finally:
        # Clean up the temporary archive even when download/extraction fails.
        if os.path.exists(archive_path):
            os.remove(archive_path)
    print(f"Extracted to: {output_dir}")

# Fetch the archive behind the user-supplied link and unpack it into the
# local "SUIM" directory used by the split step below.
download_and_extract(suim_link, "SUIM")

# ==============================
# 2. Dataset class
# ==============================
class SegDataset(Dataset):
    """Binary-segmentation dataset pairing each image with a same-named mask.

    Every entry of ``img_dir`` is treated as an image; its mask is looked up
    under the identical file name inside ``mask_dir``.

    Args:
        img_dir: Directory holding the input images.
        mask_dir: Directory holding the masks, named exactly like the images.
        transform: Stored on the instance only; ``__getitem__`` does not
            apply it.

    NOTE(review): samples are returned at their native resolution, so
    batching presumably requires all images to share one size -- confirm
    against the dataset in use.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic across runs.
        self.files = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of entries found in ``img_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the image/mask pair at position ``idx``.

        Returns:
            Tuple ``(img, mask)`` of float32 tensors. ``img`` is the RGB
            image scaled to [0, 1] in channel-first layout; ``mask`` has a
            leading channel axis and holds 1.0 where the grayscale mask
            value exceeds 127, else 0.0.

        Raises:
            OSError: If OpenCV cannot read the image or the mask file.
        """
        fname = self.files[idx]
        img_path = os.path.join(self.img_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        # cv2.imread reports failure by returning None rather than raising,
        # so check explicitly to get an error that names the offending file.
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"Could not read image file: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read mask file: {mask_path}")

        # OpenCV loads BGR; convert to RGB and normalise to [0, 1].
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
        mask = (mask > 127).astype(np.float32)  # Binary mask

        # HWC -> CHW for the image; add a channel axis to the mask.
        img = torch.tensor(img.transpose(2, 0, 1), dtype=torch.float32)
        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)

        return img, mask

# ==============================
# 3. Train / Val split
# ==============================
img_dir = "SUIM/images"
mask_dir = "SUIM/masks"

# Deterministic 80/20 split over the alphabetically sorted file names.
all_imgs = sorted(os.listdir(img_dir))
split_idx = int(0.8 * len(all_imgs))
train_imgs, val_imgs = all_imgs[:split_idx], all_imgs[split_idx:]

# Create the four destination folders before any file is moved.
for split_name in ("train", "val"):
    for kind in ("images", "masks"):
        os.makedirs(f"{split_name}/{kind}", exist_ok=True)

# Move every image together with its same-named mask into its split folder.
for split_name, names in (("train", train_imgs), ("val", val_imgs)):
    for fname in names:
        os.rename(
            os.path.join(img_dir, fname),
            os.path.join(f"{split_name}/images", fname),
        )
        os.rename(
            os.path.join(mask_dir, fname),
            os.path.join(f"{split_name}/masks", fname),
        )

train_dataset = SegDataset("train/images", "train/masks")
val_dataset = SegDataset("val/images", "val/masks")

# Only the training batches are shuffled; validation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

# ==============================
# 4. Loss Functions
# ==============================
# Lovasz loss from segmentation_models_pytorch, configured for binary masks.
lovasz_loss_fn = smp.losses.LovaszLoss(mode='binary')
# Binary cross-entropy that takes raw logits; the model below is created with
# activation=None, so its outputs are passed to this loss unmodified.
bce_loss_fn = nn.BCEWithLogitsLoss()

def total_loss(pred, target):
    """Evaluate both loss terms on ``pred``/``target`` and their sum.

    Returns:
        A 3-tuple ``(lovasz, bce, lovasz + bce)``.
    """
    components = (lovasz_loss_fn(pred, target), bce_loss_fn(pred, target))
    return (*components, components[0] + components[1])

# ==============================
# 5. Model & Optimizer
# ==============================
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# U-Net with a ResNet-34 encoder, no pretrained encoder weights, a single
# output channel and no final activation (the network emits raw logits).
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    classes=1,
    activation=None,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# ==============================
# 6. Training Loop
# ==============================
EPOCHS = 40
# Epochs after which the averaged losses are printed.
report_epochs = [1, 10, 20, 30, 40]

# The per-epoch averages divide by the number of batches, so an empty loader
# would otherwise end in a bare ZeroDivisionError; fail early with a clear
# message instead.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise RuntimeError(
        "Training and validation loaders must each yield at least one batch "
        f"(train batches: {len(train_loader)}, val batches: {len(val_loader)})."
    )

for epoch in range(1, EPOCHS + 1):
    # ---- Training pass ----
    model.train()
    train_lovasz_loss = 0
    train_bce_loss = 0
    train_total_loss = 0

    for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} - Train"):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)

        # Optimise the combined loss; the two components are tracked
        # separately for reporting only.
        lovasz, bce, total = total_loss(outputs, masks)
        total.backward()
        optimizer.step()

        train_lovasz_loss += lovasz.item()
        train_bce_loss += bce.item()
        train_total_loss += total.item()

    # Average the accumulated batch losses over the number of batches.
    train_lovasz_loss /= len(train_loader)
    train_bce_loss /= len(train_loader)
    train_total_loss /= len(train_loader)

    # ---- Validation pass (no gradient tracking, no weight updates) ----
    model.eval()
    val_total_loss = 0
    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} - Val"):
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            _, _, total = total_loss(outputs, masks)
            val_total_loss += total.item()
    val_total_loss /= len(val_loader)

    # Report only for specific epochs
    if epoch in report_epochs:
        print(f"\nEpoch {epoch}:")
        print(f"  Train Lovasz Loss: {train_lovasz_loss:.4f}")
        print(f"  Train BCE Loss: {train_bce_loss:.4f}")
        print(f"  Train Total Loss: {train_total_loss:.4f}")
        print(f"  Val Total Loss: {val_total_loss:.4f}")
