In [None]:
!pip install gdown segmentation-models-pytorch torch torchvision opencv-python matplotlib

import os
import gdown
import zipfile
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# ==============================
# 1. Get dataset from user
# ==============================
# Ask for the share link interactively, then drop the leading/trailing
# whitespace that tends to come along with a copy-pasted URL.
suim_link = input("Enter Google Drive link for SUIM dataset: ")
suim_link = suim_link.strip()

def _extract_drive_file_id(gdrive_link):
    """Return the file id embedded in a Google Drive share link.

    Two link shapes are recognised: a ``/d/<id>/`` path segment and an
    ``id=<id>`` query parameter.

    Raises:
        ValueError: If neither shape is present or the extracted id is empty.
    """
    # The path form is checked first: such links may carry unrelated query
    # parameters whose names merely end in "id=" (for example "ouid=").
    if "/d/" in gdrive_link:
        file_id = gdrive_link.split("/d/")[1].split("/")[0].split("?")[0]
    elif "id=" in gdrive_link:
        # Cut off whatever follows the id ("&export=download", "#fragment").
        file_id = gdrive_link.split("id=")[1].split("&")[0].split("#")[0]
    else:
        raise ValueError("Invalid Google Drive link format.")

    if not file_id:
        raise ValueError("Invalid Google Drive link format.")
    return file_id


def download_and_extract(gdrive_link, output_dir):
    """Download a zip archive from Google Drive and unpack it.

    The archive is fetched with ``gdown`` into ``temp.zip`` in the current
    working directory, extracted into ``output_dir`` and then deleted.

    Args:
        gdrive_link: Share link containing either ``/d/<id>/`` or ``id=<id>``.
        output_dir: Directory the archive contents are extracted into.

    Raises:
        ValueError: If no file id can be extracted from ``gdrive_link``.
    """
    file_id = _extract_drive_file_id(gdrive_link)
    archive_path = "temp.zip"

    try:
        gdown.download(f"https://drive.google.com/uc?id={file_id}", archive_path, quiet=False)
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall(output_dir)
    finally:
        # Remove the temporary archive even when the download or the
        # extraction fails, so a retry does not trip over a stale file.
        if os.path.exists(archive_path):
            os.remove(archive_path)
    print(f"Extracted to: {output_dir}")

# Fetch the archive behind the user-supplied link and unpack it into ./SUIM.
download_and_extract(gdrive_link=suim_link, output_dir="SUIM")

# ==============================
# 2. Dataset class
# ==============================
class SegDataset(Dataset):
    """Image / binary-mask pairs read from two parallel directories.

    Every regular file in ``img_dir`` is treated as one sample; its mask is
    looked up under the same file name in ``mask_dir``.
    """

    def __init__(self, img_dir, mask_dir):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        # Keep regular files only: sub-directories (for example notebook
        # checkpoint folders) can never be decoded as images.
        self.files = sorted(
            name for name in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, name))
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A 4-tuple ``(img_tensor, mask_tensor, img, mask)``:
            ``img_tensor`` is the RGB image as a float32 tensor in
            channel-first layout with pixel values divided by 255,
            ``mask_tensor`` is the mask as a float32 tensor with a leading
            channel axis, and ``img`` / ``mask`` are the NumPy arrays those
            tensors were built from (kept for plotting).

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        img_path = os.path.join(self.img_dir, self.files[idx])
        mask_path = os.path.join(self.mask_dir, self.files[idx])

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly and report the offending path.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = (mask > 127).astype(np.float32)  # binary mask

        img_tensor = torch.tensor(img.transpose(2, 0, 1), dtype=torch.float32)
        mask_tensor = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)

        return img_tensor, mask_tensor, img, mask

# ==============================
# 3. Data preparation
# ==============================
img_dir = "SUIM/images"
mask_dir = "SUIM/masks"

# Deterministic split over the alphabetically sorted file names:
# the first 80% go to training, the remainder to validation.
all_imgs = sorted(os.listdir(img_dir))
split_idx = int(0.8 * len(all_imgs))
train_imgs, val_imgs = all_imgs[:split_idx], all_imgs[split_idx:]

for target_dir in ("train/images", "train/masks", "val/images", "val/masks"):
    os.makedirs(target_dir, exist_ok=True)

# Move every image together with its same-named mask into the split folders.
for split_files, split_img_dir, split_mask_dir in (
    (train_imgs, "train/images", "train/masks"),
    (val_imgs, "val/images", "val/masks"),
):
    for fname in split_files:
        os.rename(os.path.join(img_dir, fname), os.path.join(split_img_dir, fname))
        os.rename(os.path.join(mask_dir, fname), os.path.join(split_mask_dir, fname))

train_dataset = SegDataset("train/images", "train/masks")
val_dataset = SegDataset("val/images", "val/masks")

# Training batches are reshuffled each epoch; validation runs one sample
# at a time in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# ==============================
# 4. Model, Loss, Optimizer
# ==============================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# U-Net with a ResNet-34 encoder and a single output channel. No activation
# is applied by the model because BCEWithLogitsLoss works on raw logits.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    classes=1,
    activation=None,
)
model = model.to(device)

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# ==============================
# 5. Function to visualize results
# ==============================
def visualize_results(model, dataloader, num_samples=3):
    """Plot image, ground-truth mask and predicted mask side by side.

    The model is switched to eval mode and left there; a caller that goes on
    training has to call ``model.train()`` again.

    Args:
        model: Network expected to return logits of shape
            ``(batch, 1, H, W)`` for a batch of images.
        dataloader: Iterable yielding ``(imgs, masks, orig_imgs, orig_masks)``
            batches; only ``imgs`` is fed to the model, the ``orig_*``
            entries are plotted as they are.
        num_samples: Maximum number of samples to draw. Nothing is drawn
            when it is zero or negative.
    """
    model.eval()
    if num_samples <= 0:
        return

    shown = 0
    with torch.no_grad():
        for imgs, _, orig_imgs, orig_masks in dataloader:
            outputs = model(imgs.to(device))
            # Threshold the probabilities, then drop ONLY the channel axis so
            # the batch axis survives even when the batch holds one sample.
            preds = (torch.sigmoid(outputs) > 0.5).squeeze(1).cpu().numpy()

            for orig_img, orig_mask, pred in zip(orig_imgs, orig_masks, preds):
                panels = (
                    ("Original", orig_img, None),
                    ("Ground Truth", orig_mask, 'gray'),
                    ("Prediction", pred, 'gray'),
                )
                plt.figure(figsize=(9, 3))
                for column, (title, data, cmap) in enumerate(panels, start=1):
                    plt.subplot(1, 3, column)
                    plt.imshow(data, cmap=cmap)
                    plt.title(title)
                    plt.axis('off')
                plt.show()

                shown += 1
                if shown >= num_samples:
                    return

# ==============================
# 6. Training Loop
# ==============================
EPOCHS = 50
display_epochs = [1, 30, 50]

for epoch in range(1, EPOCHS + 1):
    # Re-enable training mode every epoch: visualize_results() switches the
    # model to eval mode and does not switch it back.
    model.train()
    progress = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} - Train")

    train_loss = 0
    for imgs, masks, _, _ in progress:
        imgs = imgs.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    train_loss = train_loss / len(train_loader)
    print(f"Epoch {epoch} - Train Loss: {train_loss:.4f}")

    # Only the selected epochs get a qualitative check on validation samples.
    if epoch not in display_epochs:
        continue
    print(f"\n--- Visualizing results at epoch {epoch} ---")
    visualize_results(model, val_loader, num_samples=3)
