In [1]:
# !pip install torch
!pip install tensorboard
import os
import argparse
import logging
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as T
import torchvision.transforms.functional as TF  

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
# ==================== ARGUMENT PARSING ====================
def _build_parser():
    """Assemble the command-line parser holding every tunable of the training run."""
    cli = argparse.ArgumentParser(description="Train UNet segmentation model with detailed logging")
    # One row per option: (flag, value type, default, help text).
    option_rows = (
        ('--data_dir',   str,   '.',             'Root directory containing the Mikroplastikai folder'),
        ('--ckpt_dir',   str,   './checkpoints', 'Directory to save model checkpoints'),
        ('--log_dir',    str,   './logs',        'Directory for TensorBoard logs'),
        ('--batch_size', int,   2,               'Batch size for training'),
        ('--num_epochs', int,   1,               'Number of training epochs'),
        ('--lr',         float, 1e-4,            'Learning rate'),
        ('--val_split',  float, 0.2,             'Fraction of data for validation'),
        ('--seed',       int,   42,              'Random seed for reproducibility'),
        ('--threshold',  float, 0.5,             'Threshold for segmentation binarization'),
    )
    for flag, value_type, default_value, help_text in option_rows:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli


parser = _build_parser()
# Notebook kernels add their own argv entries, so unrecognised arguments are tolerated
# and discarded instead of aborting the parse.
args, _ = parser.parse_known_args()

In [3]:
# ==================== SETUP ====================
# Logging
os.makedirs(args.log_dir, exist_ok=True)
# force=True replaces whatever handlers are already attached to the root logger.
# Without it basicConfig silently does nothing when the root logger has handlers
# (this cell being re-run, or a host process that configured logging first),
# and train.log would never receive any records.
logging.basicConfig(level=logging.INFO,
                    format='[%(asctime)s] %(message)s',
                    handlers=[
                        # Persistent copy of the log next to the TensorBoard files
                        logging.FileHandler(os.path.join(args.log_dir, 'train.log')),
                        # Live copy in the cell output / console
                        logging.StreamHandler()
                    ],
                    force=True)
logger = logging.getLogger()

# Reproducibility and device
torch.manual_seed(args.seed)
# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directories and hyperparameters (module-level aliases of the parsed arguments)
DATA_DIR      = args.data_dir
MICRO_DIR     = os.path.join(DATA_DIR, 'Mikroplastikai')
IMAGE_DIR     = os.path.join(MICRO_DIR, 'images')
MASK_DIR      = os.path.join(MICRO_DIR, 'masks')
CKPT_DIR      = args.ckpt_dir
BATCH_SIZE    = args.batch_size
NUM_EPOCHS    = args.num_epochs
LEARNING_RATE = args.lr
VAL_SPLIT     = args.val_split
THRESHOLD     = args.threshold

# Ensure checkpoint directory exists
os.makedirs(CKPT_DIR, exist_ok=True)

In [4]:
# =================== PADDING UTIL =====================
def pad_to_32(x):
    """
    Zero-pad a PIL Image or a (..., H, W) tensor on the right/bottom
    so that height and width become multiples of 32.

    Args:
        x: a ``torch.Tensor`` whose last two dimensions are height and width
           (leading channel/batch dimensions are left untouched), or a PIL image.

    Returns:
        The padded object, of the same kind as the input. Sizes that are
        already multiples of 32 receive zero padding on that axis.
    """
    # Dispatch on torch.Tensor (imported at the top of the notebook) so the helper
    # does not depend on the PIL import that only happens in a later cell.
    is_tensor = isinstance(x, torch.Tensor)
    if is_tensor:
        # Trailing dims hold H and W, which also covers batched N×C×H×W input.
        h, w = x.shape[-2:]
    else:
        # PIL reports its size as (width, height).
        w, h = x.size
    pad_h = -h % 32
    pad_w = -w % 32
    if is_tensor:
        # F.pad order for the last two dims: (left, right, top, bottom)
        return nn.functional.pad(x, (0, pad_w, 0, pad_h))
    # torchvision order: (left, top, right, bottom)
    return TF.pad(x, (0, 0, pad_w, pad_h))

In [13]:
# ==================== MODEL & DATASET DEFINITIONS ====================
# Attempt to import segmentation_models_pytorch, or instruct installation
try:
    import segmentation_models_pytorch as smp
except ModuleNotFoundError as err:
    # Only rewrite the error when the package itself is absent; if one of its own
    # dependencies is missing, let that (more accurate) error propagate unchanged.
    if err.name != "segmentation_models_pytorch":
        raise
    # Chain the original exception so the traceback keeps the root cause.
    raise ModuleNotFoundError(
        "segmentation_models_pytorch is required but not installed. "
        "Install it via `pip install segmentation-models-pytorch` and retry."
    ) from err
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class SegmentationDatasetRGB(Dataset):
    """
    Paired image/mask dataset for single-channel segmentation.

    Files in ``images_dir`` and ``masks_dir`` are matched by basename (the
    extension is ignored); only basenames present in both directories become
    samples, in sorted basename order. Hidden files (names starting with '.')
    and non-regular entries are skipped, so OS metadata files that exist in
    both folders are never treated as an image/mask pair.

    Args:
        images_dir: directory holding the input images (loaded as RGB).
        masks_dir: directory holding the masks (loaded as 8-bit grayscale).
        invert_mask: when True the mask becomes ``1.0 - mask`` after scaling.
        mask_convention: when ``'white_fg'`` mask values are divided by 255;
            any other value leaves the raw pixel values unscaled.
        transforms: optional callable applied to the padded PIL image ONLY
            (``T.ToTensor()`` is used when omitted). The mask never passes
            through it, so transforms that change geometry (flips, crops,
            rotations) would break image/mask alignment.

    Raises:
        RuntimeError: if the two directories share no basename.
    """

    def __init__(self, images_dir, masks_dir, invert_mask=False, mask_convention='white_fg', transforms=None):
        img_map  = self._index_by_basename(images_dir)
        mask_map = self._index_by_basename(masks_dir)
        common   = sorted(set(img_map) & set(mask_map))
        if not common:
            raise RuntimeError(f"No matching image/mask basenames in {images_dir} & {masks_dir}")

        # (image path, mask path) pairs, one per shared basename
        self.samples = [
            (os.path.join(images_dir, img_map[k]),
             os.path.join(masks_dir, mask_map[k]))
            for k in common
        ]
        self.invert          = invert_mask
        self.mask_convention = mask_convention
        self.transforms      = transforms

    @staticmethod
    def _index_by_basename(directory):
        """Map basename (extension stripped) -> filename for regular, non-hidden files."""
        return {
            os.path.splitext(name)[0]: name
            for name in sorted(os.listdir(directory))
            if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
        }

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            (img, mask): ``img`` is whatever ``transforms`` (or ``T.ToTensor()``)
            produces from the padded RGB image; ``mask`` is a float32 tensor of
            shape (1, H, W) built from the padded grayscale mask.
        """
        img_path, mask_path = self.samples[idx]

        # Context managers release the file handles once the pixels are converted.
        with Image.open(img_path) as raw_img:
            img = pad_to_32(raw_img.convert('RGB'))
        with Image.open(mask_path) as raw_mask:
            mask = pad_to_32(raw_mask.convert('L'))

        mask = np.array(mask, dtype=np.float32)
        if self.mask_convention == 'white_fg':
            mask /= 255.0
        if self.invert:
            mask = 1.0 - mask

        if self.transforms:
            img = self.transforms(img)
        else:
            img = T.ToTensor()(img)

        # (H, W) -> (1, H, W): add the channel axis expected by the losses
        mask = torch.from_numpy(mask).unsqueeze(0)
        return img, mask

def create_model():
    """
    Build the segmentation network.

    Returns:
        An ``smp.Unet`` with a ``resnet101`` encoder initialised from
        ``imagenet`` weights, taking 3 input channels and emitting 1 output class.
    """
    unet_config = dict(
        encoder_name='resnet101',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
    )
    return smp.Unet(**unet_config)

In [6]:
# ==================== LOSS FUNCTIONS ====================
class DiceLoss(nn.Module):
    """
    Soft Dice loss computed from raw logits.

    The logits go through a sigmoid, a Dice coefficient is computed for each
    (sample, channel) map by summing over dimensions 2 and 3, and the loss is
    one minus the mean coefficient.

    Args:
        smooth: constant added to numerator and denominator so empty maps
            do not divide by zero.
    """

    def __init__(self, smooth: float = 1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar Dice loss for ``logits`` against ``targets``."""
        spatial_dims = (2, 3)
        probabilities = logits.sigmoid()
        overlap = (probabilities * targets).sum(dim=spatial_dims)
        numerator = 2 * overlap + self.smooth
        denominator = (
            probabilities.sum(dim=spatial_dims)
            + targets.sum(dim=spatial_dims)
            + self.smooth
        )
        per_map_dice = numerator / denominator
        return 1 - per_map_dice.mean()


# Loss instances shared by the training and validation loops
bce_loss = nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()

In [7]:
# ==================== DATA AUGMENTATION & NORMALIZATION ====================
# Photometric-only pipeline on purpose: SegmentationDatasetRGB applies `transforms`
# to the image alone and never to its mask. A random flip here would mirror the
# image while the mask stays put, so labels would no longer line up with pixels.
# Geometric augmentation has to be applied to image and mask together (same random
# decision for both), which this image-only hook cannot express.
data_transforms = T.Compose([
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    # ImageNet channel statistics, matching the ImageNet-pretrained encoder
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [8]:
# ==================== DATASET & DATALOADER ====================
full_dataset = SegmentationDatasetRGB(
    images_dir=IMAGE_DIR,
    masks_dir=MASK_DIR,
    invert_mask=False,
    mask_convention='white_fg',
    transforms=data_transforms
)

total_size = len(full_dataset)
val_size   = int(VAL_SPLIT * total_size)
train_size = total_size - val_size
# Fail early with a readable message: an empty subset otherwise only surfaces
# later as an obscure sampler error or a division by zero in the epoch loops.
if train_size <= 0 or val_size <= 0:
    raise ValueError(
        f"val_split={VAL_SPLIT} on {total_size} samples gives an empty split "
        f"(train={train_size}, val={val_size})"
    )

# Dedicated seeded generator: the split is identical every time this cell runs,
# independent of how much the global RNG has been consumed in the meantime.
split_generator = torch.Generator().manual_seed(args.seed)
train_ds, val_ds = random_split(full_dataset, [train_size, val_size], generator=split_generator)
# NOTE(review): both subsets wrap the same dataset object, so the validation
# samples go through the same (randomised) transforms as the training samples.

# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
use_pinned_memory = device.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=use_pinned_memory)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=use_pinned_memory)

In [9]:
# ==================== MODEL, OPTIMIZER, SCHEDULER ====================
# Enable cuDNN's benchmark mode for this process
torch.backends.cudnn.benchmark = True

# Network on the selected device
model = create_model()
model = model.to(device)

# Adam over every model parameter at the configured learning rate
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Halve the learning rate once the monitored value (stepped with the validation
# loss in the training loop) has not decreased for 3 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

In [10]:
# ==================== METRIC UTILITIES ====================
def compute_metrics(logits: torch.Tensor, masks: torch.Tensor, threshold: float):
    """
    Compute precision, recall, F1 and IoU over the whole batch.

    Logits are passed through a sigmoid and binarised with ``> threshold``;
    true/false positives and false negatives are summed over every element
    of the batch before the ratios are formed.

    Args:
        logits: raw model outputs.
        masks: ground-truth masks, same shape as ``logits``.
        threshold: probability cut-off for a positive prediction.

    Returns:
        dict with float entries ``precision``, ``recall``, ``f1``, ``iou``.
    """
    eps = 1e-6  # keeps every ratio finite when a count is zero
    target = masks.float()
    predicted = (logits.sigmoid() > threshold).float()

    tp = (predicted * target).sum()
    fp = (predicted * (1 - target)).sum()
    fn = ((1 - predicted) * target).sum()

    scores = {}
    scores['precision'] = tp / (tp + fp + eps)
    scores['recall'] = tp / (tp + fn + eps)
    scores['f1'] = 2 * scores['precision'] * scores['recall'] / (scores['precision'] + scores['recall'] + eps)
    scores['iou'] = tp / (tp + fp + fn + eps)
    return {name: value.item() for name, value in scores.items()}

In [11]:
# ==================== TRAIN & VALIDATION LOOPS ====================
# TensorBoard event writer; shares the log directory with train.log
writer = SummaryWriter(log_dir=args.log_dir)

def train_epoch(model, loader, optimizer):
    """
    Run one optimisation pass over ``loader``.

    Each batch is moved to ``device``, scored with BCE-with-logits plus Dice
    loss, back-propagated, gradient-clipped to a norm of 1.0 and stepped.

    Returns:
        The epoch loss: per-batch losses weighted by batch size and divided
        by the number of samples in ``loader.dataset``.
    """
    model.train()
    loss_total = 0.0
    for batch_imgs, batch_masks in loader:
        batch_imgs = batch_imgs.to(device)
        batch_masks = batch_masks.to(device)

        optimizer.zero_grad()
        predictions = model(batch_imgs)
        batch_loss = bce_loss(predictions, batch_masks) + dice_loss(predictions, batch_masks)
        batch_loss.backward()
        # Clip before the update so a single bad batch cannot blow up the weights
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight by batch size so a smaller final batch is averaged correctly
        loss_total += batch_loss.item() * batch_imgs.size(0)
    return loss_total / len(loader.dataset)

@torch.no_grad()
def validate_epoch(model, loader):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        (avg_loss, avg_metrics): the BCE+Dice loss and the precision / recall /
        f1 / iou values from ``compute_metrics``, each accumulated per batch
        weighted by batch size and divided by ``len(loader.dataset)``.
    """
    model.eval()
    sample_count = len(loader.dataset)
    loss_total = 0.0
    metric_totals = dict.fromkeys(('precision', 'recall', 'f1', 'iou'), 0.0)
    for batch_imgs, batch_masks in loader:
        batch_imgs = batch_imgs.to(device)
        batch_masks = batch_masks.to(device)
        predictions = model(batch_imgs)
        batch_loss = bce_loss(predictions, batch_masks) + dice_loss(predictions, batch_masks)

        weight = batch_imgs.size(0)
        loss_total += batch_loss.item() * weight
        for name, value in compute_metrics(predictions, batch_masks, THRESHOLD).items():
            metric_totals[name] += value * weight
    avg_loss = loss_total / sample_count
    avg_metrics = {name: total / sample_count for name, total in metric_totals.items()}
    return avg_loss, avg_metrics

In [12]:
# ==================== MAIN TRAINING LOOP ====================
best_val_loss = float('inf')
try:
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = train_epoch(model, train_loader, optimizer)
        val_loss, val_metrics = validate_epoch(model, val_loader)
        # ReduceLROnPlateau is driven by the validation loss
        scheduler.step(val_loss)

        # Logging
        logger.info(f"Epoch {epoch}/{NUM_EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                    f"Val F1: {val_metrics['f1']:.4f} | Val IoU: {val_metrics['iou']:.4f}")
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Loss/Val', val_loss, epoch)
        writer.add_scalar('Metrics/Val_Precision', val_metrics['precision'], epoch)
        writer.add_scalar('Metrics/Val_Recall', val_metrics['recall'], epoch)
        writer.add_scalar('Metrics/Val_F1', val_metrics['f1'], epoch)
        writer.add_scalar('Metrics/Val_IoU', val_metrics['iou'], epoch)

        # Checkpoint: written only when the validation loss improves
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            ckpt_path = os.path.join(CKPT_DIR, f"unet_epoch{epoch}_valloss{val_loss:.4f}.pth")
            torch.save(model.state_dict(), ckpt_path)
finally:
    # Close the TensorBoard writer even when training is interrupted (e.g. the
    # kernel is stopped mid-epoch) so already-logged scalars reach disk.
    writer.close()

# Reached only when every epoch finished without an exception
logger.info(f"Training complete. Best validation loss: {best_val_loss:.4f}")


KeyboardInterrupt: 