# ResUNet Weld Seam Segmentation Training Pipeline
This notebook handles:
1. Downloading the `u-net_model` v4 dataset from Roboflow.
2. Defining the pure Deep Residual UNet (ResUNet) architecture.
3. Setting up data loaders with Image Augmentation (Albumentations).
4. Training the model with Mixed Precision, BCE+Dice Loss, and Cosine Annealing.
5. Plotting the training curves and exporting the best checkpoint, `best_resunet_seam.pth`.

In [None]:
!pip install -q roboflow albumentations
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from roboflow import Roboflow
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: " + str(device))

## 1. Download Dataset

In [None]:
# The Roboflow API key is read from the ROBOFLOW_API_KEY environment variable
# (or prompted for interactively) so that it is never stored in the notebook.
# NOTE(review): any key previously committed to this notebook should be
# revoked/rotated in the Roboflow workspace settings.
import getpass

_roboflow_api_key = os.environ.get("ROBOFLOW_API_KEY") or getpass.getpass("Roboflow API key: ")
rf = Roboflow(api_key=_roboflow_api_key)
project = rf.workspace("computer-vision-yyh42").project("u-net_model")
version = project.version(4)
dataset = version.download("png-mask-semantic")

# NOTE: despite its name this holds the downloaded dataset's root directory
# (`dataset.location`), not a YAML file; the name is kept for later cells.
DATA_YAML_PATH = dataset.location
print("Dataset downloaded to:", DATA_YAML_PATH)

## 2. Deep Residual UNet (ResUNet) Architecture
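Instead of plain convolution stacks, ResUNet builds its encoder and decoder from residual blocks. Their identity (shortcut) connections give gradients a direct path through the network, mitigating vanishing gradients. Together with the UNet skip connections, this helps the model preserve fine detail such as thin (down to 1-pixel-wide) weld seams.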

In [None]:
class ResidualBlock(nn.Module):
    """Pre-activation residual block used by both the encoder and the decoder.

    Main path: BN -> ReLU -> 3x3 conv (with ``stride``) -> BN -> ReLU -> 3x3 conv.
    Its output is added to a shortcut branch. The shortcut is the identity when
    the input already has ``out_c`` channels and ``stride == 1``. Otherwise it is
    a strided 1x1 conv + BN projection.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        stride: stride of the first 3x3 conv and of the projection shortcut
            (2 halves the spatial size).
    """

    def __init__(self, in_c, out_c, stride=1):
        super().__init__()
        # Sub-modules are created in a fixed order so parameter initialisation
        # and state_dict keys stay stable across versions of this class.
        self.bn1 = nn.BatchNorm2d(in_c)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, bias=False)

        needs_projection = stride != 1 or in_c != out_c
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_c),
            )
        else:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        # The in-place ReLU acts on the fresh BN outputs, never on `x` itself.
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        return out + identity

class ResUNet(nn.Module):
    """Deep Residual UNet producing per-pixel segmentation logits.

    Encoder:
        A stem (conv-BN-ReLU-conv plus a projected input shortcut) produces
        64 channels at full resolution. It is followed by three stride-2
        residual blocks with 128, 256 and 512 channels.

    Decoder:
        Three stages, each doing 2x transposed-conv upsampling, concatenation
        with the matching encoder feature map, and a residual block. A final
        1x1 conv yields ``out_c`` channels of raw logits (no sigmoid) at the
        input resolution.

    Input heights/widths that are not multiples of 8 are supported. Each
    upsampled map is center-cropped (or zero-padded) to its skip connection's
    size before concatenation; previously ``torch.cat`` raised for such sizes.

    Args:
        in_c: number of input channels (3 for RGB).
        out_c: number of output channels (1 for a binary mask).
    """

    def __init__(self, in_c=3, out_c=1):
        super().__init__()
        # Initial layer (attribute names/order unchanged so saved checkpoints load).
        self.conv_init = nn.Sequential(
            nn.Conv2d(in_c, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False)
        )
        self.shortcut_init = nn.Sequential(
            nn.Conv2d(in_c, 64, 1, 1, bias=False),
            nn.BatchNorm2d(64)
        )

        # Encoders: each block halves the spatial size.
        self.res1 = ResidualBlock(64, 128, stride=2)
        self.res2 = ResidualBlock(128, 256, stride=2)
        self.res3 = ResidualBlock(256, 512, stride=2)

        # Decoders: input channels = upsampled channels + skip channels.
        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec3 = ResidualBlock(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2 = ResidualBlock(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = ResidualBlock(128, 64)

        self.final_conv = nn.Conv2d(64, out_c, 1)

    @staticmethod
    def _match_spatial(x, ref):
        """Center-crop / zero-pad ``x`` so its last two dims equal ``ref``'s."""
        target_h, target_w = ref.shape[-2:]
        h, w = x.shape[-2:]
        if (h, w) == (target_h, target_w):
            return x
        # A stride-2 conv maps s -> ceil(s/2) and the 2x transposed conv maps
        # that back to 2*ceil(s/2) >= s, so for odd sizes the upsampled map is
        # one pixel too large: crop it.
        if h > target_h:
            top = (h - target_h) // 2
            x = x[..., top:top + target_h, :]
        if w > target_w:
            left = (w - target_w) // 2
            x = x[..., :, left:left + target_w]
        # Pad defensively in case a dimension ends up smaller than the skip.
        pad_h = max(target_h - h, 0)
        pad_w = max(target_w - w, 0)
        if pad_h or pad_w:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return x

    def forward(self, x):
        # Initial block with identity adding (projected shortcut of the input).
        res_init = self.shortcut_init(x)
        c1 = self.conv_init(x) + res_init   # 64 ch, H x W

        c2 = self.res1(c1)                  # 128 ch, ~H/2
        c3 = self.res2(c2)                  # 256 ch, ~H/4
        c4 = self.res3(c3)                  # 512 ch, ~H/8 (bottleneck)

        # Decode: upsample, align to the skip connection, concatenate, fuse.
        u3 = self._match_spatial(self.up3(c4), c3)
        d3 = self.dec3(torch.cat([u3, c3], dim=1))

        u2 = self._match_spatial(self.up2(d3), c2)
        d2 = self.dec2(torch.cat([u2, c2], dim=1))

        u1 = self._match_spatial(self.up1(d2), c1)
        d1 = self.dec1(torch.cat([u1, c1], dim=1))

        # Raw logits; apply sigmoid (or a *_with_logits loss) downstream.
        return self.final_conv(d1)

## 3. Dataset & Dataloaders

In [None]:
class WeldSeamDataset(Dataset):
    """Image/mask pairs for binary weld-seam segmentation.

    Mask directory:
        ``mask_dir`` defaults to the sibling ``masks`` folder when the last
        component of ``image_dir`` is ``images`` (``train/images`` ->
        ``train/masks``). Otherwise masks are searched for in ``image_dir``
        itself (images and masks side by side). In that side-by-side case,
        files whose stem ends in ``_mask`` are treated as masks, not images.

    Mask matching:
        For an image ``<stem>.<ext>`` the mask is, in order of preference:
        ``<stem>_mask.<e>``; then ``<stem>.<e>`` (``e`` in ``IMAGE_EXTENSIONS``,
        case-insensitive); then the first file (sorted) named ``<stem>``
        followed by a non-alphanumeric character. The separator rule stops
        ``img1`` from matching ``img10_mask.png``. All masks are resolved once
        in ``__init__``, and a missing one raises ``FileNotFoundError``
        immediately rather than mid-epoch.

    Args:
        image_dir: directory containing the input images.
        transform: optional Albumentations-style callable, invoked as
            ``transform(image=..., mask=...)``, that returns a dict with
            ``'image'`` and ``'mask'``. If None, the image is returned as a
            float32 CHW tensor scaled to [0, 1].
        mask_dir: optional explicit mask directory; overrides the default.

    Each item is ``(image, mask)``. The mask is binarised (any non-zero pixel
    -> 1.0) before the transform and is returned with a leading channel
    dimension of size 1.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    MASK_SUFFIX = '_mask'

    def __init__(self, image_dir, transform=None, mask_dir=None):
        self.image_dir = image_dir
        self.transform = transform

        if mask_dir is None:
            # Swap only the final path component; str.replace('images', 'masks')
            # would also rewrite any parent directory whose name contains "images".
            parent, leaf = os.path.split(os.path.normpath(image_dir))
            mask_dir = os.path.join(parent, 'masks') if leaf == 'images' else image_dir
        self.mask_dir = mask_dir
        same_dir = os.path.normpath(self.mask_dir) == os.path.normpath(image_dir)

        self.images = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
            # Side-by-side layout: don't treat the mask files as images.
            and not (same_dir and os.path.splitext(f)[0].endswith(self.MASK_SUFFIX))
        )

        if not os.path.isdir(self.mask_dir):
            raise FileNotFoundError(f"Mask directory not found: {self.mask_dir}")

        # List the mask directory once, instead of once per sample. In the
        # side-by-side layout an image must never be picked as a mask.
        image_names = set(self.images) if same_dir else set()
        mask_files = sorted(f for f in os.listdir(self.mask_dir) if f not in image_names)
        by_lower = {f.lower(): f for f in mask_files}

        self.masks = []  # mask filename for each entry of self.images
        for img_name in self.images:
            mask_name = self._match_mask(img_name, mask_files, by_lower)
            if mask_name is None:
                raise FileNotFoundError(f"No mask found for {img_name} in {self.mask_dir}")
            self.masks.append(mask_name)

    def _match_mask(self, img_name, mask_files, by_lower):
        """Return the mask filename for ``img_name``, or None if none matches."""
        stem = os.path.splitext(img_name)[0]
        for suffix in (self.MASK_SUFFIX, ''):
            for ext in self.IMAGE_EXTENSIONS:
                hit = by_lower.get((stem + suffix + ext).lower())
                if hit is not None:
                    return hit
        # Fallback: same stem followed by a separator, e.g. "<stem>-label.png".
        n = len(stem)
        for f in mask_files:
            if len(f) > n and f.startswith(stem) and not f[n].isalnum():
                return f
        return None

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        image = cv2.imread(img_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise OSError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read mask: {mask_path}")
        # Binary mask: any non-zero pixel is foreground.
        mask = (mask > 0).astype(np.float32)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']
        else:
            # No transform: HWC uint8 -> CHW float32 in [0, 1].
            image = torch.from_numpy(image.transpose(2, 0, 1).astype(np.float32) / 255.0)

        # Also covers transforms that leave the mask as a NumPy array.
        mask = torch.as_tensor(mask)
        return image, mask.unsqueeze(0)

# Training pipeline: fixed-size resize, random geometric and photometric
# augmentation, then ImageNet mean/std normalisation and conversion to tensors.
train_transform = A.Compose(
    [
        A.Resize(height=512, width=512),
        # Geometric augmentations (Albumentations applies these to the mask too).
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        # Photometric augmentation (image only).
        A.ColorJitter(brightness=0.2, contrast=0.2, p=0.3),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

# Validation pipeline: deterministic, i.e. the same resize and normalisation
# without any augmentation.
val_transform = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

# Build both splits together. If either one fails (missing folders, missing
# masks, ...), report the error and continue with empty datasets so that the
# training cell is skipped.
try:
    train_dataset, val_dataset = (
        WeldSeamDataset(os.path.join(DATA_YAML_PATH, split, 'images'), transform=tfm)
        for split, tfm in (('train', train_transform), ('valid', val_transform))
    )
except Exception as err:
    print("Error loading datasets:", err)
    train_dataset, val_dataset = [], []

batch_size = 8
# An empty list stands in for a loader when a split has no samples.
train_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    if len(train_dataset) > 0
    else []
)
val_loader = (
    DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    if len(val_dataset) > 0
    else []
)

print(f"Train images: {len(train_dataset)}, Val images: {len(val_dataset)}")

## 4. Loss Function
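We use BCE + Dice loss. BCE supervises every pixel independently. Dice loss directly optimizes the overlap between prediction $A$ and mask $B$, $\text{Dice} = \frac{2|A \cap B|}{|A| + |B|}$, which is closely related to, but not the same as, intersection over union. Because Dice is normalized by the sizes of the prediction and the mask, it is far less dominated by the background than BCE. This helps when segmenting thin lines that cover only a small fraction of the image.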

In [None]:
class DiceBCELoss(nn.Module):
    """Binary segmentation loss: BCE-with-logits plus soft Dice loss.

    The Dice term pools every element of the batch (it is not averaged per
    sample). It uses sigmoid probabilities rather than thresholded predictions,
    so it stays differentiable.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, targets, smooth=1.0):
        """Return ``BCE(logits, targets) + (1 - soft_dice)`` as a scalar tensor.

        Args:
            logits: raw model outputs (any shape).
            targets: tensor with the same number of elements as ``logits``,
                values in [0, 1].
            smooth: added to the Dice numerator and denominator; keeps the loss
                defined when both prediction and target are empty.
        """
        # Compute in float32. Under CUDA autocast the logits arrive as float16,
        # and a float16 sum over millions of probabilities overflows
        # (max ~65504) when run outside autocast. The Dice term would then
        # silently become constant.
        logits = logits.float()
        targets = targets.float()

        bce_loss = F.binary_cross_entropy_with_logits(logits, targets)

        # reshape (unlike view) also accepts non-contiguous inputs.
        probs_flat = torch.sigmoid(logits).reshape(-1)
        targets_flat = targets.reshape(-1)

        intersection = (probs_flat * targets_flat).sum()
        dice = (2.0 * intersection + smooth) / (probs_flat.sum() + targets_flat.sum() + smooth)

        return bce_loss + (1.0 - dice)

def compute_dice_coeff(logits, targets, threshold=0.5, eps=1e-6):
    """Hard Dice coefficient between thresholded predictions and targets.

    Args:
        logits: raw model outputs.
        targets: ground-truth tensor of the same shape; values > 0.5 count as
            foreground.
        threshold: probability above which a pixel is predicted foreground.
        eps: small constant guarding the division.

    Returns:
        A 0-dim float tensor. All elements of the batch are pooled. When both
        prediction and target are empty the result is 1.0 (perfect agreement)
        rather than 0.0, so batches without any seam do not penalise a model
        that correctly predicts nothing.
    """
    pred = torch.sigmoid(logits.float()) > threshold
    truth = targets > 0.5  # binarise so numerator and denominator agree
    intersection = (pred & truth).sum().float()
    total = pred.sum().float() + truth.sum().float()
    dice = (2.0 * intersection) / (total + eps)
    return torch.where(total > 0, dice, torch.ones_like(dice))

## 5. Training Loop
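Trains the ResUNet model with automatic mixed precision (on CUDA GPUs) for faster training and lower memory use. It saves the weights from the epoch with the best validation Dice.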

In [None]:
# Training configuration.
num_epochs = 50
# Mixed precision is used only on CUDA. On CPU, autocast/GradScaler are simply
# disabled instead of warning and falling back to fp32 anyway.
use_amp = device.type == 'cuda'

model = ResUNet(in_c=3, out_c=1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
# T_max is tied to num_epochs so the cosine schedule always spans the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
criterion = DiceBCELoss()
# torch.amp.GradScaler exists in newer PyTorch; older versions only have the
# CUDA-specific class.
if hasattr(torch.amp, 'GradScaler'):
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Start at -inf so the first validated epoch always writes a checkpoint, even
# if its Dice is exactly 0.
best_val_dice = float('-inf')

train_losses, val_losses, val_dices = [], [], []

if len(train_loader) == 0:
    print("No training data available; skipping training.")
else:
    if len(val_loader) == 0:
        print("Warning: no validation data; saving the latest weights each epoch instead of the best.")

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for imgs, masks in train_loader:
            imgs, masks = imgs.to(device), masks.to(device)

            optimizer.zero_grad()
            with torch.autocast(device_type=device.type, enabled=use_amp):
                preds = model(imgs)
                loss = criterion(preds, masks)

            # With AMP disabled the scaler is a pass-through.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        train_losses.append(epoch_loss / len(train_loader))
        scheduler.step()

        if len(val_loader) == 0:
            # Avoid dividing by zero below; keep an exportable checkpoint.
            print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_losses[-1]:.4f}")
            torch.save(model.state_dict(), "best_resunet_seam.pth")
            continue

        # Validation
        model.eval()
        val_loss, val_dice = 0.0, 0.0
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs, masks = imgs.to(device), masks.to(device)
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    preds = model(imgs)
                    loss = criterion(preds, masks)
                val_loss += loss.item()
                val_dice += compute_dice_coeff(preds, masks).item()

        val_losses.append(val_loss / len(val_loader))
        val_dice = val_dice / len(val_loader)
        val_dices.append(val_dice)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_losses[-1]:.4f} | Val Dice: {val_dice:.4f}")

        if val_dice > best_val_dice:
            best_val_dice = val_dice
            torch.save(model.state_dict(), "best_resunet_seam.pth")
            print("--> Saved new best model")

    if val_dices:
        print("Training Complete. Best Val Dice:", best_val_dice)
    else:
        print("Training Complete (no validation data).")

## 6. Plotting Results and Exporting

In [None]:
# Plot the per-epoch curves (only when training actually ran).
if train_losses:
    fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: training vs. validation loss.
    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot(val_losses, label='Val Loss')
    loss_ax.legend()
    loss_ax.set_title('Loss Curves')

    # Right panel: validation Dice coefficient per epoch.
    dice_ax.plot(val_dices, label='Val Dice Coeff')
    dice_ax.legend()
    dice_ax.set_title('Validation Accuracy (Dice)')

    plt.show()

# To download the weights locally (only possible when running in Google Colab).
# Only a missing Colab module is ignored; a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and hide a missing checkpoint.
try:
    from google.colab import files
except ImportError:
    files = None  # Not in Colab: the checkpoint stays in the working directory.

if files is not None:
    if os.path.exists('best_resunet_seam.pth'):
        files.download('best_resunet_seam.pth')
    else:
        print("best_resunet_seam.pth not found; no checkpoint was saved during training.")