In [None]:
import os, glob, random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm 

import torch

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torch.nn.functional as F    

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from PIL import Image

import torchvision.transforms as T
import torchvision.transforms.functional as TF

import matplotlib.pyplot as plt

# OpenCV is optional: the mask clean-up only applies morphology when it can be imported.
try:
    import cv2
except Exception:
    _HAS_CV2 = False
else:
    _HAS_CV2 = True

print("cv2 available:", _HAS_CV2)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


In [None]:
# Root folder of the dataset. The notebook reads `<BASE>/train`, defines `<BASE>/test`
# and writes checkpoints/plots to `<BASE>/models`.
# TODO: point BASE at your own dataset folder before running the notebook.
BASE = "."

TRAIN_DIR = os.path.join(BASE, "train")
TEST_DIR  = os.path.join(BASE, "test")

# glob("*") also returns sub-directories; keep regular files only so every entry can be opened.
file_list = sorted(p for p in glob.glob(os.path.join(TRAIN_DIR, "*")) if os.path.isfile(p))

print("BASE:", BASE)
print("TRAIN_DIR:", TRAIN_DIR)
print("TEST_DIR:", TEST_DIR)
print("Total number of files:", len(file_list))
print("Sample files:", file_list[:5])
if not file_list:
    print("WARNING: no files found in", TRAIN_DIR, "- set BASE to your dataset folder.")

# Training hyper-parameters.
IMG_SIZE = 256
BATCH_SIZE = 4
NUM_EPOCHS = 200
LR = 1e-4
NUM_WORKERS = 0

MODEL_DIR = os.path.join(BASE, "models")
os.makedirs(MODEL_DIR, exist_ok=True)

CKPT_PATH = os.path.join(MODEL_DIR, "batsnet.pth")

# Device used by the model, the training loop and the evaluation cells.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


In [None]:
# Input preprocessing: convert to a tensor, then normalise each of the three channels
# with mean 0.5 and std 0.5 (the visualisation cells undo this with `* 0.5 + 0.5`).
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)
to_tensor_norm = T.Compose([T.ToTensor(), T.Normalize(_NORM_MEAN, _NORM_STD)])
def detect_seam_x(img_pil, sample_border=5):
    """Estimate the x position of the vertical seam in a side-by-side image.

    The image is converted to RGB and reduced to one mean intensity per column
    (mean over channels, then over rows). The seam is placed right after the
    largest absolute jump between neighbouring columns.

    Args:
        img_pil: Image object exposing ``convert('RGB')``.
        sample_border: Accepted for interface compatibility; not used.

    Returns:
        int: Seam column. Falls back to ``width // 2`` when the image is fewer
        than two columns wide or when the candidate lies in the outer 10% of
        the width on either side.
    """
    pixels = np.array(img_pil.convert('RGB')).astype(np.float32)
    width = pixels.shape[1]

    column_profile = pixels.mean(axis=2).mean(axis=0)
    jumps = np.abs(np.diff(column_profile))
    if jumps.size == 0:
        return width // 2

    seam = int(np.argmax(jumps) + 1)
    in_central_band = width * 0.1 <= seam <= width * 0.9
    return seam if in_central_band else width // 2

def mask_to_binary_from_color(mask_pil, dist_thresh=15, bg_border=5, morph_kernel=3):
    """Turn a colour-coded mask image into a binary foreground mask.

    The background colour is the per-channel median of six border patches (the
    four corners plus a top and a bottom patch starting at a quarter of the
    width). A pixel is foreground when its Euclidean RGB distance to that
    colour exceeds ``dist_thresh``. When OpenCV is available and
    ``morph_kernel`` is positive, an elliptical open followed by a close is
    applied to the result.

    Args:
        mask_pil: Image object exposing ``convert('RGB')``.
        dist_thresh: Distance above which a pixel counts as foreground.
        bg_border: Side length of the sampled border patches (at least 1).
        morph_kernel: Size of the morphology kernel; 0/None disables it.

    Returns:
        torch.Tensor: float32 tensor of shape (1, H, W) holding 0/1 values.
    """
    rgb = np.array(mask_pil.convert('RGB')).astype(np.float32)
    _, width, _ = rgb.shape
    border = max(1, bg_border)
    quarter = width // 4

    patches = (
        rgb[:border, :border, :],                    # top-left corner
        rgb[:border, -border:, :],                   # top-right corner
        rgb[-border:, :border, :],                   # bottom-left corner
        rgb[-border:, -border:, :],                  # bottom-right corner
        rgb[:border, quarter:quarter + border, :],   # top edge at W/4
        rgb[-border:, quarter:quarter + border, :],  # bottom edge at W/4
    )
    border_pixels = np.concatenate([patch.reshape(-1, 3) for patch in patches], axis=0)
    background = np.median(border_pixels, axis=0).astype(np.float32)

    distance = np.linalg.norm(rgb - background[None, None, :], axis=2)
    binary = (distance > dist_thresh).astype(np.uint8)

    if _HAS_CV2 and morph_kernel and morph_kernel > 0:
        element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_kernel, morph_kernel))
        for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
            binary = cv2.morphologyEx(binary, operation, element)

    return torch.from_numpy(binary.astype(np.float32)).unsqueeze(0)

In [None]:
class Pix2PixConcatDataset(Dataset):
    """Dataset of side-by-side images: input picture on the left, colour mask on the right.

    Each file is split at the column returned by ``detect_seam_x``. Both halves
    are resized to ``img_size`` x ``img_size``; the left half becomes the
    normalised input tensor and the right half is binarised with
    ``mask_to_binary_from_color``.
    """

    def __init__(self, file_list, img_size=IMG_SIZE, augment=False, dist_thresh=15, morph_kernel=3):
        """Store the file paths and preprocessing options.

        Args:
            file_list: Sequence of image paths.
            img_size: Side length both halves are resized to.
            augment: When True, apply random horizontal/vertical flips.
            dist_thresh: Forwarded to ``mask_to_binary_from_color``.
            morph_kernel: Forwarded to ``mask_to_binary_from_color``.
        """
        self.files = file_list
        self.img_size = img_size
        self.augment = augment
        self.dist_thresh = dist_thresh
        self.morph_kernel = morph_kernel

    def __len__(self):
        """Number of files in the dataset."""
        return len(self.files)

    def _load_halves(self, path):
        """Open ``path`` and return the resized (left, right) halves as PIL images."""
        combined = Image.open(path).convert('RGB')
        width, height = combined.size
        # Clamp the seam so neither half ends up with zero width.
        seam = max(1, min(detect_seam_x(combined), width - 1))

        target_size = (self.img_size, self.img_size)
        left = TF.resize(combined.crop((0, 0, seam, height)), target_size,
                         interpolation=Image.BILINEAR)
        # Nearest-neighbour keeps the mask colours unblended.
        right = TF.resize(combined.crop((seam, 0, width, height)), target_size,
                          interpolation=Image.NEAREST)
        return left, right

    def __getitem__(self, idx):
        """Return ``(input_tensor, binary_mask)`` for the file at ``idx``."""
        picture, mask_img = self._load_halves(self.files[idx])

        if self.augment:
            # Each flip is drawn independently and applied to both halves together.
            for flip in (TF.hflip, TF.vflip):
                if random.random() < 0.5:
                    picture, mask_img = flip(picture), flip(mask_img)

        inp_t = to_tensor_norm(picture).float()
        tgt_t = mask_to_binary_from_color(
            mask_img,
            dist_thresh=self.dist_thresh,
            bg_border=5,
            morph_kernel=self.morph_kernel,
        ).float()
        return inp_t, tgt_t


In [None]:
DIST_THRESH = 15
MORPH_KERNEL = 3 if _HAS_CV2 else 0

# One seed for the split and the global RNGs, so that Restart & Run All reproduces
# the same train/validation partition and the same subsequent random draws.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Two views over the same files: `ds` applies random flips (training),
# `eval_ds` does not, so validation numbers are not perturbed by augmentation.
ds = Pix2PixConcatDataset(file_list, img_size=IMG_SIZE, augment=True,
                          dist_thresh=DIST_THRESH, morph_kernel=MORPH_KERNEL)
eval_ds = Pix2PixConcatDataset(file_list, img_size=IMG_SIZE, augment=False,
                               dist_thresh=DIST_THRESH, morph_kernel=MORPH_KERNEL)
print("Dataset length:", len(ds))

indices = list(range(len(ds)))
train_idx, test_idx = train_test_split(indices, test_size=0.3, random_state=SEED, shuffle=True)

# Pinned host memory only helps (and only avoids a warning) when a GPU is present.
_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(ds, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(train_idx),
                          num_workers=NUM_WORKERS, pin_memory=_pin_memory)
test_loader  = DataLoader(eval_ds, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(test_idx),
                          num_workers=NUM_WORKERS, pin_memory=_pin_memory)

train_paths = [file_list[i] for i in train_idx]
test_paths  = [file_list[i] for i in test_idx]

print("Train images:", len(train_paths))
print("Test images :", len(test_paths))

print("\n--- Sample train paths ---")
for p in train_paths[:5]:
    print(p)

print("\n--- Sample test paths ---")
for p in test_paths[:5]:
    print(p)

# Sanity check: pull one training batch and show the first image/mask pair.
images, masks = next(iter(train_loader))
print("Batch shapes:", images.shape, masks.shape)

# Undo the (0.5, 0.5) normalisation for display.
img0 = ((images[0].permute(1,2,0).numpy() * 0.5) + 0.5)
mask0 = masks[0].squeeze(0).numpy()

plt.figure(figsize=(6,3))
plt.subplot(1,2,1); plt.imshow(img0); plt.title("Sample Image"); plt.axis('off')
plt.subplot(1,2,2); plt.imshow(mask0, cmap='gray'); plt.title("Sample Mask (binary)"); plt.axis('off')
plt.show()


In [None]:
class ConvBlock(nn.Module):
    """3x3 convolution (no bias) followed by batch normalisation and ReLU."""

    def __init__(self, in_ch, out_ch, stride=1):
        """Build the conv -> batch norm -> ReLU stack.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            stride: Stride of the convolution.
        """
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_ch)
        activation = nn.ReLU(inplace=True)
        # Kept as `self.block` with positional children so state-dict keys stay block.0 / block.1.
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply the block to ``x``."""
        out = self.block(x)
        return out
    
class TextureEncoder(nn.Module):
    """Stack of four ``ConvBlock`` stages; stages 2-4 each use stride 2.

    Channel widths are ``base_ch``, ``2*base_ch``, ``4*base_ch``, ``4*base_ch``.
    """

    def __init__(self, in_ch=3, base_ch=32):
        """Create the four encoder stages.

        Args:
            in_ch: Channels of the input image.
            base_ch: Channel width of the first stage.
        """
        super().__init__()
        narrow, medium, wide = base_ch, base_ch * 2, base_ch * 4
        self.enc1 = ConvBlock(in_ch, narrow)
        self.enc2 = ConvBlock(narrow, medium, stride=2)
        self.enc3 = ConvBlock(medium, wide, stride=2)
        self.enc4 = ConvBlock(wide, wide, stride=2)

    def forward(self, x):
        """Run ``x`` through the stages in order and return the last feature map."""
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            x = stage(x)
        return x
    
class AppearanceDeviation(nn.Module):
    """Per-pixel magnitude of the difference between features and a local 5x5 filter response.

    ``local_mean`` is a depthwise (one filter per channel), bias-free 5x5
    convolution. Its weights are ordinary learnable parameters; no averaging
    initialisation is applied here.
    """

    def __init__(self, channels):
        """Create the depthwise 5x5 convolution for ``channels`` feature channels."""
        super().__init__()
        self.local_mean = nn.Conv2d(
            channels,
            channels,
            kernel_size=5,
            padding=2,
            groups=channels,
            bias=False,
        )

    def forward(self, F):
        """Return the channel-wise norm of ``F - local_mean(F)``, keeping a single channel."""
        residual = F - self.local_mean(F)
        return torch.norm(residual, dim=1, keepdim=True)
    
class FeatureVariance(nn.Module):
    """Local feature variance, averaged over channels.

    Variance is computed per channel inside a sliding ``kernel_size`` window as
    ``E[x^2] - E[x]^2`` using stride-1 average pooling, then averaged across
    channels into a single-channel map.
    """

    def __init__(self, kernel_size=5):
        """Create the stride-1 average pool that defines the local window."""
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, F):
        """Return the channel-averaged local variance of ``F`` (one output channel)."""
        mean = self.pool(F)
        mean_sq = self.pool(F ** 2)
        # E[x^2] - E[x]^2 can dip slightly below zero through floating-point
        # cancellation on near-constant regions; a variance is never negative.
        var = (mean_sq - mean ** 2).clamp_min(0.0)
        V = torch.mean(var, dim=1, keepdim=True)
        return V
    
class TransitionEncoder(nn.Module):
    """Two 3x3 conv + ReLU layers mapping the stacked (D, V) maps to 32 channels."""

    def __init__(self):
        """Build the 2 -> 16 -> 32 channel convolution stack."""
        super().__init__()
        layers = [
            nn.Conv2d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, D, V):
        """Concatenate ``D`` and ``V`` along the channel axis and encode them."""
        stacked = torch.cat((D, V), dim=1)
        return self.conv(stacked)
    
class SpatialCoherence(nn.Module):
    """Three parallel 3x3 convolutions with dilation 1, 2 and 4, fused by a 1x1 convolution."""

    def __init__(self, channels):
        """Create the dilated branches ``d1``, ``d2``, ``d4`` and the 1x1 ``fuse`` layer."""
        super().__init__()
        # padding == dilation keeps the spatial size unchanged for a 3x3 kernel.
        for rate in (1, 2, 4):
            setattr(self, f"d{rate}", nn.Conv2d(channels, channels, 3, padding=rate, dilation=rate))
        self.fuse = nn.Conv2d(channels * 3, channels, 1)

    def forward(self, x):
        """Apply all branches to ``x``, concatenate on channels and fuse back to ``channels``."""
        branch_outputs = [branch(x) for branch in (self.d1, self.d2, self.d4)]
        return self.fuse(torch.cat(branch_outputs, dim=1))

class PTDNet(nn.Module):
    """Segmentation network built from the encoder, deviation/variance cues and a 1x1 head.

    Pipeline: ``TextureEncoder`` -> (``AppearanceDeviation``, ``FeatureVariance``)
    -> ``TransitionEncoder`` -> ``SpatialCoherence`` -> 1x1 conv head, with the
    logits upsampled back to the input resolution.
    """

    def __init__(self, base_ch=48):
        """Instantiate all sub-modules; ``base_ch`` sets the encoder width."""
        super().__init__()

        self.encoder = TextureEncoder(in_ch=3, base_ch=base_ch)
        # The encoder's last stage outputs 4 * base_ch channels.
        self.dev = AppearanceDeviation(channels=base_ch * 4)
        self.var = FeatureVariance()
        self.trans_enc = TransitionEncoder()
        self.coherence = SpatialCoherence(channels=32)
        self.boundary_head = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        """Return a single logits tensor with one channel and the spatial size of ``x``."""
        features = self.encoder(x)
        deviation = self.dev(features)
        variance = self.var(features)
        transition = self.trans_enc(deviation, variance)
        coarse_logits = self.boundary_head(self.coherence(transition))

        # The encoder downsamples; bring the logits back to the input height/width.
        return F.interpolate(coarse_logits, size=x.shape[2:], mode='bilinear', align_corners=False)

# Build the network with its default width (base_ch=48) and move it to the selected device.
model = PTDNet().to(DEVICE)

In [None]:
def dice_loss(pred, target, eps=1e-6):
    """Dice loss computed over all elements of the batch at once.

    Args:
        pred: Predicted probabilities (any shape).
        target: Ground-truth tensor with the same number of elements.
        eps: Smoothing term that keeps the ratio defined when both are empty.

    Returns:
        Scalar tensor ``1 - (2*|pred*target| + eps) / (|pred| + |target| + eps)``.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced
    # tensors, are accepted instead of raising a RuntimeError.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    inter = (pred * target).sum()
    return 1 - (2 * inter + eps) / (pred.sum() + target.sum() + eps)

def soft_dice_loss(pred, target, eps=1e-6):
    """Dice loss computed per sample and averaged over the batch.

    Args:
        pred: Predicted probabilities with the batch on dim 0, e.g. (B, 1, H, W)
            or (B, H, W).
        target: Ground-truth tensor laid out like ``pred``.
        eps: Smoothing term that keeps the ratio defined when both are empty.

    Returns:
        Scalar tensor ``1 - mean_b(dice_b)``.

    Raises:
        ValueError: If ``pred`` has no non-batch dimension to reduce over.
    """
    if pred.dim() < 2:
        raise ValueError("soft_dice_loss expects a batch dimension plus at least one more dimension")

    def _per_sample_sum(t):
        # Sum over every non-batch dimension; for 4-D input this is dim=(1, 2, 3).
        return t.sum(dim=tuple(range(1, t.dim())))

    inter = _per_sample_sum(pred * target)
    union = _per_sample_sum(pred) + _per_sample_sum(target)
    dice = (2 * inter + eps) / (union + eps)
    return 1 - dice.mean()

In [None]:
# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=LR)

# Halve the learning rate when the monitored (minimised) value has not improved for 5 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

In [None]:
# Weights of the two loss terms, shared by training and validation.
BCE_WEIGHT = 0.6
DICE_WEIGHT = 0.4


def _run_epoch(loader, train, desc=None):
    """Run one pass over ``loader`` and return ``(mean loss per sample, pixel accuracy)``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``DEVICE``.

    Args:
        loader: DataLoader yielding ``(images, masks)`` batches.
        train: When True the model is put in training mode and the optimizer
            is stepped; otherwise the model is in eval mode and no gradients
            are tracked.
        desc: Optional progress-bar label; no bar is shown when omitted.
    """
    model.train(train)
    loss_sum = 0.0
    n_samples = 0
    n_correct = 0
    n_pixels = 0
    batches = tqdm(loader, desc=desc, leave=False) if desc else loader

    with torch.set_grad_enabled(train):
        for imgs, masks in batches:
            imgs = imgs.to(DEVICE, dtype=torch.float32)
            masks = masks.to(DEVICE, dtype=torch.float32)

            if train:
                optimizer.zero_grad()

            # PTDNet.forward returns one logits tensor; unpacking it as
            # `logits, sigma` would iterate over the batch dimension instead.
            logits = model(imgs)
            probs = torch.sigmoid(logits)

            loss = BCE_WEIGHT * criterion(logits, masks) + DICE_WEIGHT * soft_dice_loss(probs, masks)

            if train:
                loss.backward()
                optimizer.step()

            preds = probs > 0.5
            n_correct += (preds == masks.bool()).sum().item()
            n_pixels += masks.numel()
            # Weight by batch size so a smaller last batch does not skew the mean.
            loss_sum += loss.item() * imgs.size(0)
            n_samples += imgs.size(0)

    return loss_sum / max(n_samples, 1), n_correct / float(max(n_pixels, 1))


best_val_loss = float('inf')
train_losses = []
val_losses = []
train_accs = []
val_accs = []

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = _run_epoch(train_loader, train=True, desc=f"Train E{epoch}")
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    val_loss, val_acc = _run_epoch(test_loader, train=False)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    scheduler.step(val_loss)

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss
        }, CKPT_PATH)

    print(f"Epoch {epoch}/{NUM_EPOCHS} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | train_acc={train_acc:.4f} | val_acc={val_acc:.4f}")


In [None]:
out_dir = globals().get("MODEL_DIR", ".")
os.makedirs(out_dir, exist_ok=True)
out_path = os.path.join(out_dir, "loss_and_acc_curve.png")

# np.asarray keeps this cell re-runnable: the histories may already be arrays.
train_losses = np.asarray(train_losses, dtype=float)
val_losses = np.asarray(val_losses, dtype=float)
train_accs = np.asarray(train_accs, dtype=float)
val_accs = np.asarray(val_accs, dtype=float)
epochs = np.arange(1, len(train_losses) + 1)

if len(epochs) == 0:
    raise RuntimeError("No training history to plot - run the training cell first.")


def smooth(x, window=5):
    """Centered moving average of ``x`` with the same length as the input.

    The series is padded by repeating its end values, so the first and last
    points are not pulled toward zero. Series shorter than ``window`` are
    returned unsmoothed.
    """
    x = np.asarray(x, dtype=float)
    if window <= 1 or len(x) < window:
        return x
    left = window // 2
    right = window - 1 - left
    padded = np.pad(x, (left, right), mode="edge")
    return np.convolve(padded, np.ones(window) / window, mode="valid")


train_losses_s = smooth(train_losses, window=5)
val_losses_s = smooth(val_losses, window=5)
train_accs_s = smooth(train_accs, window=5)
val_accs_s = smooth(val_accs, window=5)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

# Raw curves are drawn faintly; the smoothed curves on top carry the trend.
ax_loss.plot(epochs, train_losses, label='Train loss (raw)', alpha=0.3, color='C0')
ax_loss.plot(epochs, val_losses, label='Val loss (raw)', alpha=0.3, color='C1')
ax_loss.plot(epochs, train_losses_s, label='Train loss (smoothed)', color='C0')
ax_loss.plot(epochs, val_losses_s, label='Val loss (smoothed)', color='C1')
best_epoch = int(np.argmin(val_losses)) + 1
ax_loss.scatter([best_epoch], [val_losses.min()], color='red', zorder=10, label=f'Best val (ep {best_epoch})')
ax_loss.set_xlabel('Epoch'); ax_loss.set_ylabel('Loss'); ax_loss.set_title('Loss vs Epoch')
ax_loss.legend(); ax_loss.grid(alpha=0.2)

ax_acc.plot(epochs, train_accs, label='Train acc (raw)', alpha=0.3, color='C0')
ax_acc.plot(epochs, val_accs, label='Val acc (raw)', alpha=0.3, color='C1')
ax_acc.plot(epochs, train_accs_s, label='Train acc (smoothed)', color='C0')
ax_acc.plot(epochs, val_accs_s, label='Val acc (smoothed)', color='C1')
best_acc_epoch = int(np.argmax(val_accs)) + 1
ax_acc.scatter([best_acc_epoch], [val_accs.max()], color='red', zorder=10, label=f'Best val acc (ep {best_acc_epoch})')
ax_acc.set_xlabel('Epoch'); ax_acc.set_ylabel('Accuracy'); ax_acc.set_title('Pixel-wise Accuracy vs Epoch')
ax_acc.set_ylim(0.0, 1.0)
ax_acc.legend(); ax_acc.grid(alpha=0.2)

fig.tight_layout()
fig.savefig(out_path, dpi=150)
print("Saved loss+accuracy plot to:", out_path)
plt.show()


In [None]:
# CKPT_PATH holds the best-validation checkpoint written during training and is what
# the following cells load. Save the end-of-training weights next to it under a
# separate name so they do not overwrite the best model.
_ckpt_root, _ckpt_ext = os.path.splitext(CKPT_PATH)
LAST_CKPT_PATH = f"{_ckpt_root}_last{_ckpt_ext}"

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': NUM_EPOCHS
}, LAST_CKPT_PATH)

print("Model saved successfully at:", LAST_CKPT_PATH)

In [None]:
# Load the checkpoint stored at CKPT_PATH and visualise one random held-out sample.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# Checkpoints written by this notebook wrap the weights under 'model_state_dict';
# otherwise treat the loaded object as the state dict itself rather than silently
# skipping the load.
state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

idx = random.choice(test_idx)
test_img, test_mask = ds[idx]
inp = test_img.unsqueeze(0).to(DEVICE)

with torch.no_grad():
    # The model returns a single logits tensor, so there is nothing to unpack.
    logits = model(inp)
    probs = torch.sigmoid(logits)

prob_map = probs[0,0].cpu().numpy()
pred_bin = (prob_map > 0.5).astype(float)
gt_map = test_mask.squeeze(0).numpy()

# Undo the (0.5, 0.5) normalisation for display.
inp_vis = ((test_img.permute(1,2,0).numpy() * 0.5) + 0.5)

plt.figure(figsize=(10,4))
plt.subplot(1,3,1); plt.imshow(inp_vis); plt.title("Input"); plt.axis('off')
plt.subplot(1,3,2); plt.imshow(gt_map, cmap='gray'); plt.title("GT"); plt.axis('off')
plt.subplot(1,3,3); plt.imshow(pred_bin, cmap='gray'); plt.title("Pred"); plt.axis('off')
plt.show()

In [None]:
# Held-out files only, without augmentation, so the metrics below are deterministic.
test_ds = Pix2PixConcatDataset(test_paths, img_size=IMG_SIZE, augment=False,
                               dist_thresh=DIST_THRESH, morph_kernel=MORPH_KERNEL)

# shuffle=False keeps batch order aligned with `test_paths`.
_test_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
test_loader = DataLoader(test_ds, **_test_loader_kwargs)

print("Test dataset length:", len(test_ds))

eps = 1e-6


def _segmentation_metrics(pred, gt, eps):
    """Compute per-image metrics from flat 0/1 arrays ``pred`` and ``gt``.

    ``eps`` is added to numerators and denominators, so a ratio whose counts
    are all zero evaluates to 1.0. Returns a dict with the keys ``acc``,
    ``iou``, ``precision``, ``recall``, ``sensitivity`` and ``dice``.
    """
    tp = (pred * gt).sum()
    fp = (pred * (1 - gt)).sum()
    fn = ((1 - pred) * gt).sum()
    tn = ((1 - pred) * (1 - gt)).sum()
    recall = (tp + eps) / (tp + fn + eps)
    return {
        "acc": (tp + tn) / (tp + tn + fp + fn + eps),
        "iou": (tp + eps) / (tp + fp + fn + eps),
        "precision": (tp + eps) / (tp + fp + eps),
        "recall": recall,
        "sensitivity": recall,  # sensitivity is recall under another name
        "dice": (2 * tp + eps) / ((pred.sum() + gt.sum()) + eps),
    }


acc_per_image = []
iou_per_image = []
precision_per_image = []
recall_per_image = []
sensitivity_per_image = []
dice_per_image = []

preds_list = []
gts_list = []
imgs_list = []
paths_list = []

model.eval()
global_idx = 0

with torch.no_grad():
    for imgs, masks in tqdm(test_loader, desc="Evaluate"):
        imgs = imgs.to(DEVICE, dtype=torch.float32)
        masks = masks.to(DEVICE, dtype=torch.float32)

        # Test-time augmentation: average the prediction with the un-flipped
        # prediction of the horizontally flipped input.
        probs1 = torch.sigmoid(model(imgs))
        probs2 = torch.flip(torch.sigmoid(model(torch.flip(imgs, [3]))), [3])

        probs = (probs1 + probs2) / 2
        preds = (probs > 0.5).float()

        B = imgs.size(0)
        H, W = masks.shape[2], masks.shape[3]

        # Move each batch to the host once instead of once per sample.
        imgs_cpu = imgs.cpu()
        preds_cpu = preds.cpu().numpy().astype(np.float32)
        masks_cpu = masks.cpu().numpy().astype(np.float32)

        for i in range(B):
            pred_i = preds_cpu[i, 0].ravel()
            gt_i = masks_cpu[i, 0].ravel()

            metrics = _segmentation_metrics(pred_i, gt_i, eps)

            acc_per_image.append(metrics["acc"])
            iou_per_image.append(metrics["iou"])
            precision_per_image.append(metrics["precision"])
            recall_per_image.append(metrics["recall"])
            sensitivity_per_image.append(metrics["sensitivity"])
            dice_per_image.append(metrics["dice"])

            # Undo the (0.5, 0.5) normalisation for display. Errors here are
            # left to surface instead of being replaced by a silent None entry.
            imgs_list.append((imgs_cpu[i].permute(1,2,0).numpy() * 0.5) + 0.5)

            preds_list.append(pred_i.reshape(H, W))
            gts_list.append(gt_i.reshape(H, W))
            # Valid because the loader iterates `test_paths` in order (shuffle=False).
            paths_list.append(test_paths[global_idx])

            global_idx += 1

# Freeze the per-image metric lists into arrays for statistics and plotting.
acc_per_image = np.array(acc_per_image)
iou_per_image = np.array(iou_per_image)
precision_per_image = np.array(precision_per_image)
recall_per_image = np.array(recall_per_image)
sensitivity_per_image = np.array(sensitivity_per_image)
dice_per_image = np.array(dice_per_image)

print("\n===== Per-image Test Results =====")
_summary_rows = (
    ("Accuracy", acc_per_image),
    ("IoU", iou_per_image),
    ("Precision", precision_per_image),
    ("Recall", recall_per_image),
    ("Sensitivity", sensitivity_per_image),
    ("Dice", dice_per_image),
)
for _name, _values in _summary_rows:
    # Labels are left-aligned in a 13-character column so the numbers line up.
    print(f"{_name + ':':<13}mean={_values.mean():.4f}  std={_values.std():.4f}")

plt.figure(figsize=(12,5))
_curves = (
    ("Dice", dice_per_image),
    ("IoU", iou_per_image),
    ("Precision", precision_per_image),
    ("Recall", recall_per_image),
    ("Sensitivity", sensitivity_per_image),
    ("Accuracy", acc_per_image),
)
for _name, _values in _curves:
    plt.plot(_values, label=_name)

plt.xlabel("Test image index")
plt.ylabel("Metric value")
plt.ylim(-0.05, 1.05)
plt.legend()
plt.title("Per-image metrics on test set")
plt.grid(alpha=0.2)
plt.tight_layout()
plt.show()

# Rank images by Dice: lowest scores first for the worst cases, highest first for the best.
n_show = 4
sorted_idx = np.argsort(dice_per_image)
worst_idx = sorted_idx[:n_show]
best_idx  = sorted_idx[::-1][:n_show]

def show_cases(idxs, title_prefix="Case"):
    """Plot the selected test images side by side with prediction and ground truth.

    For every index the stored image (when available) is drawn with the
    predicted mask overlaid semi-transparently; without an image the predicted
    mask is drawn alone. The ground-truth outline is added in red. Reads the
    module-level ``imgs_list``, ``preds_list``, ``gts_list``, ``paths_list``
    and ``dice_per_image``.

    Args:
        idxs: Indices into the per-image result lists.
        title_prefix: Text shown on the first line of each panel title.
    """
    n_panels = len(idxs)
    plt.figure(figsize=(4 * n_panels, 4))

    for panel, case in enumerate(idxs, start=1):
        image = imgs_list[case]
        prediction = preds_list[case]
        truth = gts_list[case]
        file_name = os.path.basename(paths_list[case])

        plt.subplot(1, n_panels, panel)
        if image is None:
            plt.imshow(prediction, cmap='gray')
        else:
            plt.imshow(image)
            plt.imshow(prediction, cmap='gray', alpha=0.45)
        plt.contour(truth, colors='r', linewidths=0.5)

        plt.title(f"{title_prefix}\n{file_name}\nDice={dice_per_image[case]:.3f}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Display the lowest- and highest-Dice test images, worst first.
for _label, _case_idx in (("worst", worst_idx), ("best", best_idx)):
    print(f"\nShowing {_label} cases:", _case_idx.tolist())
    show_cases(_case_idx, title_prefix=_label.capitalize())
