In [1]:
import os
import glob
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torch.nn.functional import interpolate

In [2]:
import torch.nn.functional as F
import math


class CAM(nn.Module):
    """Channel attention module (the channel half of CBAM).

    Global max-pooled and average-pooled channel descriptors are passed
    through one shared two-layer MLP. The two results are summed, squashed
    with a sigmoid and used to rescale the channels of the input.

    Args:
        channels: Number of channels of the incoming feature map.
        r: Reduction ratio of the MLP bottleneck (hidden width is
            ``channels // r``, but never less than 1).
    """

    def __init__(self, channels, r=16):
        super().__init__()
        self.channels = channels
        self.r = r
        # With channels < r the integer division yields 0, which would build
        # zero-width Linear layers: the attention would then be a pure bias
        # term that ignores the input. Keep at least one hidden unit.
        hidden = max(channels // r, 1)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=True),
        )

    def forward(self, x):
        """Return ``x`` rescaled per channel; output has the shape of ``x``."""
        b, c, _, _ = x.size()
        max_pool = F.adaptive_max_pool2d(x, (1, 1)).view(b, c)
        avg_pool = F.adaptive_avg_pool2d(x, (1, 1)).view(b, c)

        # The same MLP weights are used for both descriptors.
        max_out = self.mlp(max_pool).view(b, c, 1, 1)
        avg_out = self.mlp(avg_pool).view(b, c, 1, 1)

        attention = torch.sigmoid(max_out + avg_out)
        return x * attention


class SAM(nn.Module):
    """Spatial attention module (the spatial half of CBAM).

    The per-pixel maximum and mean over the channel axis are stacked into a
    two-channel map, convolved with a 7x7 kernel down to one channel, passed
    through a sigmoid and multiplied back onto the input.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=7,
            stride=1,
            padding=3,  # keeps the spatial size unchanged for a 7x7 kernel
            bias=False,
        )

    def forward(self, x):
        """Return ``x`` rescaled per spatial location; same shape as ``x``."""
        channel_max = x.max(dim=1, keepdim=True).values
        channel_avg = x.mean(dim=1, keepdim=True)
        descriptor = torch.cat((channel_max, channel_avg), dim=1)
        gate = torch.sigmoid(self.conv(descriptor))
        return x * gate


class CBAM(nn.Module):
    """Convolutional block attention: channel attention, then spatial attention.

    Args:
        channels: Number of channels of the incoming feature map.
        r: Reduction ratio forwarded to the channel-attention MLP.
    """

    def __init__(self, channels, r=16):
        super().__init__()
        self.cam = CAM(channels, r)
        self.sam = SAM()

    def forward(self, x):
        """Apply ``cam`` first and ``sam`` on its result."""
        return self.sam(self.cam(x))


class DoubleConv(nn.Module):
    """Residual block: [Conv2d => BN => ReLU] * 2, then CBAM, then Dropout2d.

    The block output is added to a shortcut of the input. The shortcut is a
    1x1 convolution when the channel count changes and an identity otherwise.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()

        # Two identical Conv-BN-ReLU stages; only the first one changes the
        # channel count. 3x3 kernels with padding=1 keep the spatial size.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

        self.cbam = CBAM(out_channels)

        self.dropout = nn.Dropout2d(p=0.3)

        if in_channels != out_channels:
            self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Return ``dropout(cbam(double_conv(x))) + residual(x)``."""
        shortcut = self.residual(x)
        features = self.dropout(self.cbam(self.double_conv(x)))
        return features + shortcut



class DownSample(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the stage.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.down_sample = nn.Sequential(pool, conv)

    def forward(self, x):
        """Halve the spatial size of ``x`` and map it to ``out_channels``."""
        return self.down_sample(x)



class OutConv(nn.Module):
    """Output head: a single 1x1 convolution.

    ``forward`` returns the raw convolution output; the ``sigmoid`` attribute
    is created but not applied inside ``forward``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Project ``x`` to ``out_channels`` without any activation."""
        projected = self.conv(x)
        return projected



class UpSample(nn.Module):
    """Upsample by a power-of-two factor, then project the channels.

    ``int(log2(c))`` stride-2 transposed convolutions (each doubling height
    and width) are applied one after another, followed by one 3x3
    "same"-padded convolution that maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the final 3x3 convolution.
        c: Upsampling factor. ``0`` (and ``1``) mean "no upsampling", in
            which case only the 3x3 convolution is applied. A value that is
            not a power of two is rounded down to the previous power of two.
    """

    def __init__(self, in_channels, out_channels, c: int) -> None:
        super().__init__()
        # math.log2 is used instead of math.log(c, 2): the latter is computed
        # as a quotient of two floating-point logarithms and can land just
        # below an integer, which int() would then truncate.
        n = 0 if c == 0 else int(math.log2(c))

        self.upsample = nn.ModuleList(
            [nn.ConvTranspose2d(in_channels, in_channels, 2, 2) for _ in range(n)]
        )
        self.conv_3 = nn.Conv2d(in_channels, out_channels, 3, padding="same", stride=1)

    def forward(self, x):
        """Apply every 2x upsampling stage, then the 3x3 projection."""
        for layer in self.upsample:
            x = layer(x)
        return self.conv_3(x)

class UpSample2(nn.Module):
    """Upsample by a power-of-two factor with a convolution after every stage.

    Unlike ``UpSample``, the 3x3 "same"-padded convolution ``conv_3`` is
    applied after *each* stride-2 transposed convolution, and the same
    ``conv_3`` weights are reused for every stage.

    Args:
        in_channels: Channels expected by every transposed convolution and by
            ``conv_3``.
        out_channels: Channels produced by ``conv_3``. When more than one
            stage is built this has to equal ``in_channels``, because the
            output of ``conv_3`` is fed to the next transposed convolution.
        c: Upsampling factor. ``0`` (and ``1``) build no stage at all; the
            input is then returned unchanged and ``conv_3`` is not applied.
            A value that is not a power of two is rounded down to the
            previous power of two.
    """

    def __init__(self, in_channels, out_channels, c: int) -> None:
        super().__init__()
        # math.log2 is used instead of math.log(c, 2): the latter is computed
        # as a quotient of two floating-point logarithms and can land just
        # below an integer, which int() would then truncate.
        n = 0 if c == 0 else int(math.log2(c))

        self.upsample = nn.ModuleList(
            [nn.ConvTranspose2d(in_channels, in_channels, 2, 2) for _ in range(n)]
        )
        self.conv_3 = nn.Conv2d(in_channels, out_channels, 3, padding="same", stride=1)

    def forward(self, x):
        """Run (transposed conv => conv_3) once per upsampling stage."""
        for layer in self.upsample:
            x = layer(x)
            x = self.conv_3(x)
        return x

class UNet(nn.Module):
    """Densely skip-connected encoder/decoder with a multi-scale output head.

    The encoder halves the resolution four times. Every decoder stage
    concatenates the upsampled previous decoder output with all encoder
    feature maps from ``x_enc_3`` down to the current scale, each brought to
    the right resolution by an ``UpSample`` block. After the output
    convolution the prediction is upsampled four more times, so the network
    returns five maps at 1x, 2x, 4x, 8x and 16x the input resolution.

    Input height and width should be divisible by 16 so that the four 2x
    poolings and the matching 2x upsamplings produce feature maps that can be
    concatenated.
    """

    def __init__(self, in_channels, out_channels, n: int = 8) -> None:
        """
        Construct the J-net model.
        Args:
            in_channels: The number of color channels of the input image
                (e.g. 3 for RGB).
            out_channels: The number of channels of the predicted mask,
                corresponds to the number of classes. Includes the background.
            n: Channel size of the first convolution block in the encoder.
                The bigger this value, the bigger the number of parameters of
                the model. Defaults to n = 8, which is recommended by the
                authors of the paper.
        """
        super().__init__()
        # ------ Input convolution --------------
        self.in_conv = DoubleConv(in_channels, n)
        # -------- Encoder ----------------------
        self.down_1 = DownSample(n, 2 * n)
        self.down_2 = DownSample(2 * n, 4 * n)
        self.down_3 = DownSample(4 * n, 8 * n)
        self.down_4 = DownSample(8 * n, 16 * n)

        # -------- Upsampling -------------------
        # Naming: up_<A>_<B> maps A channels to B channels, where the numbers
        # are the channel counts obtained with n = 64. The last constructor
        # argument is the spatial upsampling factor (0 = keep resolution).
        self.up_1024_512 = UpSample(16 * n, 8 * n, 2)

        self.up_512_64 = UpSample(8 * n, n, 8)
        self.up_512_128 = UpSample(8 * n, 2 * n, 4)
        self.up_512_256 = UpSample(8 * n, 4 * n, 2)
        self.up_512_512 = UpSample(8 * n, 8 * n, 0)

        self.up_256_64 = UpSample(4 * n, n, 4)
        self.up_256_128 = UpSample(4 * n, 2 * n, 2)
        self.up_256_256 = UpSample(4 * n, 4 * n, 0)

        self.up_128_64 = UpSample(2 * n, n, 2)
        self.up_128_128 = UpSample(2 * n, 2 * n, 0)

        self.up_64_64 = UpSample(n, n, 0)

        # -------- Output-side upsampling (operates on out_channels maps) ----
        self.up_64_32 = UpSample2(out_channels, out_channels, 2)
        self.up_16_64 = UpSample(out_channels, out_channels, 2)

        self.up_skip_8 = UpSample2(out_channels, out_channels, 8)
        self.up_skip_4 = UpSample2(out_channels, out_channels, 4)

        self.up_concat = UpSample2(2 * out_channels, out_channels, 2)

        # ------ Decoder block ---------------
        # Input widths equal the number of concatenated maps per stage.
        self.dec_4 = DoubleConv(2 * 8 * n, 8 * n)
        self.dec_3 = DoubleConv(3 * 4 * n, 4 * n)
        self.dec_2 = DoubleConv(4 * 2 * n, 2 * n)
        self.dec_1 = DoubleConv(5 * n, n)

        # ------ Output convolution
        self.out_conv = OutConv(n, out_channels)

    def forward(self, x):
        """Return ``(out1, out2, out3, out4, out5)``.

        The five maps have ``out_channels`` channels each and 1x, 2x, 4x, 8x
        and 16x the spatial size of ``x``. No activation is applied to them.
        """
        x = self.in_conv(x)  # n channels, input resolution
        # ---- Encoder outputs
        x_enc_1 = self.down_1(x)  # 2n channels, 1/2 resolution
        x_enc_2 = self.down_2(x_enc_1)  # 4n channels, 1/4 resolution
        x_enc_3 = self.down_3(x_enc_2)  # 8n channels, 1/8 resolution
        x_enc_4 = self.down_4(x_enc_3)  # 16n channels, 1/16 resolution

        # ------ decoder outputs
        x_up_1 = self.up_1024_512(x_enc_4)
        x_dec_4 = self.dec_4(torch.cat([x_up_1, self.up_512_512(x_enc_3)], dim=1))

        # NOTE: up_512_256, up_256_128 and up_128_64 are each called twice
        # below (once on the decoder path, once on an encoder skip), so the
        # two uses share one set of weights.
        x_up_2 = self.up_512_256(x_dec_4)
        x_dec_3 = self.dec_3(torch.cat([
            x_up_2,
            self.up_512_256(x_enc_3),
            self.up_256_256(x_enc_2),
        ], dim=1))

        x_up_3 = self.up_256_128(x_dec_3)
        x_dec_2 = self.dec_2(torch.cat([
            x_up_3,
            self.up_512_128(x_enc_3),
            self.up_256_128(x_enc_2),
            self.up_128_128(x_enc_1),
        ], dim=1))

        x_up_4 = self.up_128_64(x_dec_2)
        x_dec_1 = self.dec_1(torch.cat([
            x_up_4,
            self.up_512_64(x_enc_3),
            self.up_256_64(x_enc_2),
            self.up_128_64(x_enc_1),
            self.up_64_64(x),
        ], dim=1))

        out1 = self.out_conv(x_dec_1)  # 1x input resolution

        out2 = self.up_64_32(out1)  # 2x
        out3 = self.up_16_64(out2)  # 4x

        # Direct upsamplings of out1 used as skips for the two largest scales.
        out1_skip_4 = self.up_skip_4(out1)  # 4x
        out1_skip_8 = self.up_skip_8(out1)  # 8x

        # up_concat is reused for both scales (shared weights).
        out4 = self.up_concat(torch.cat([out1_skip_4, out3], dim=1))  # 8x
        out5 = self.up_concat(torch.cat([out1_skip_8, out4], dim=1))  # 16x

        return out1, out2, out3, out4, out5
        
def save_net(ckpt_dir, net, optim, epoch):
    """Write a checkpoint to ``<ckpt_dir>/model_epoch_<epoch>.pth``.

    The checkpoint is a dict with the keys ``'epoch'``, ``'model'`` (the
    network state dict) and ``'optim'`` (the optimizer state dict). The
    directory is created when it does not exist yet.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    state = {
        'epoch': epoch,
        'model': net.state_dict(),
        'optim': optim.state_dict(),
    }
    target = os.path.join(ckpt_dir, 'model_epoch_{}.pth'.format(epoch))
    torch.save(state, target)

def load_net(ckpt_dir, net, optim):
    """Restore the most recent checkpoint found in ``ckpt_dir``.

    Args:
        ckpt_dir: Directory that holds ``*.pth`` checkpoint files.
        net: Model whose state dict is overwritten in place.
        optim: Optimizer whose state dict is overwritten in place.

    Returns:
        ``(net, optim, next_epoch)``. ``next_epoch`` is the stored epoch plus
        one, or ``0`` when the directory is missing or holds no checkpoint.
    """

    def _epoch_of(path):
        # Checkpoints are named "..._<epoch>.pth"; anything else sorts first.
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            return int(stem.rsplit('_', 1)[-1])
        except ValueError:
            return -1

    if not os.path.exists(ckpt_dir):
        return net, optim, 0
    files = glob.glob(os.path.join(ckpt_dir, '*.pth'))
    if not files:
        return net, optim, 0

    # Pick by numeric epoch: a plain string sort would rank
    # "model_epoch_5.pth" after "model_epoch_10.pth".
    latest = max(files, key=lambda path: (_epoch_of(path), path))

    # Load onto the CPU first so that a checkpoint written on a GPU can also
    # be restored on a CPU-only machine; load_state_dict copies the values
    # into the tensors net/optim already own.
    checkpoint = torch.load(latest, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    optim.load_state_dict(checkpoint['optim'])
    return net, optim, checkpoint['epoch'] + 1


In [3]:
import os
import glob
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import random

# === Paths and Config ===
# Everything the run writes lives below ROOT_DIR.
ROOT_DIR = "/kaggle/working"
DATA_DIR = os.path.join(ROOT_DIR, 'prostate_data')
CKPT_DIR = os.path.join(ROOT_DIR, 'checkpoints_jnet_16')
LOG_DIR = os.path.join(ROOT_DIR, 'logs')
IMGS_DIR = os.path.join(DATA_DIR, 'imgs_r')
LABELS_DIR = os.path.join(DATA_DIR, 'labels')

# === Create directories ===
# Checkpoint dir, log dirs (one per phase) and train/val/test folders for
# both images and labels; existing directories are left alone.
for path in [
    CKPT_DIR,
    LOG_DIR,
    *(os.path.join(LOG_DIR, phase) for phase in ('train', 'val')),
    *(os.path.join(base, split)
      for base in (IMGS_DIR, LABELS_DIR)
      for split in ('train', 'val', 'test')),
]:
    os.makedirs(path, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# === Config ===
class Config:
    """Training hyper-parameters, exposed as plain class attributes."""

    # Optimisation schedule
    NUM_EPOCHS = 100
    BATCH_SIZE = 64
    LEARNING_RATE = 1e-3

    # Weights of the Dice and IoU terms in the total loss
    DICE_LAMBDA = 1.0
    IOU_LAMBDA = 1.0

    # Regularisation strengths
    L1_LAMBDA = 1e-5
    L2_LAMBDA = 1e-5


cfg = Config()

# === Dataset ===
class CustomDataset(Dataset):
    """Image/mask pairs read from two directories of ``*.jpg`` files.

    Images and masks are paired by their position in the two sorted file
    lists, not by matching file names.

    Args:
        imgs_dir: Directory holding the input images.
        labels_dir: Directory holding the mask images.
        transform: Optional callable applied to the sample dict.
    """

    def __init__(self, imgs_dir, labels_dir, transform=None):
        self.imgs = sorted(glob.glob(os.path.join(imgs_dir, '*.jpg')))
        self.labels = sorted(glob.glob(os.path.join(labels_dir, '*.jpg')))
        self.transform = transform

    def __len__(self):
        # Only indices that have both an image and a mask are valid; counting
        # labels alone would let an index run past the end of ``self.imgs``.
        return min(len(self.imgs), len(self.labels))

    @staticmethod
    def _read_scaled(path, flag):
        """Read ``path`` with OpenCV and scale the pixel values to [0, 1]."""
        raw = cv2.imread(path, flag)
        if raw is None:
            # cv2.imread signals an unreadable file by returning None instead
            # of raising; fail with the offending path rather than with a
            # TypeError from the division below.
            raise OSError(f'cv2.imread could not read image file: {path!r}')
        return raw / 255.0

    def __getitem__(self, index):
        """Return a dict with ``img``, ``label``, ``img_d``, ``label_d``, ``id``.

        ``img``/``label`` are resized to 256x256 and ``img_d``/``label_d`` to
        16x16, all with nearest-neighbour interpolation. Arrays are
        height x width x channels; the masks get a trailing channel axis.
        ``id`` is the image file name without its extension.
        """
        # Three-channel colour image, shape (H, W, 3). OpenCV delivers the
        # channels in B, G, R order.
        img = self._read_scaled(self.imgs[index], cv2.IMREAD_COLOR)

        # Single-channel mask, shape (H, W). Values are only scaled, not
        # thresholded.
        label = self._read_scaled(self.labels[index], cv2.IMREAD_GRAYSCALE)

        # Low-resolution copies
        img_d = cv2.resize(img, (16, 16), interpolation=cv2.INTER_NEAREST)
        label_d = cv2.resize(label, (16, 16), interpolation=cv2.INTER_NEAREST)

        # Full-resolution copies
        img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_NEAREST)
        label = cv2.resize(label, (256, 256), interpolation=cv2.INTER_NEAREST)

        sample = {
            'img': img,
            'label': label[:, :, np.newaxis],  # add channel dim
            'img_d': img_d,
            'label_d': label_d[:, :, np.newaxis],
            'id': os.path.splitext(os.path.basename(self.imgs[index]))[0],
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

# === Transforms ===
class ToTensor:
    """Convert the image and label arrays of a sample to float32 tensors.

    Each array is moved from height x width x channels to
    channels x height x width. The dict is modified in place and returned;
    other keys are left untouched.
    """

    _ARRAY_KEYS = ('img', 'img_d', 'label', 'label_d')

    def __call__(self, data):
        for key in self._ARRAY_KEYS:
            channels_first = data[key].transpose(2, 0, 1)
            data[key] = torch.tensor(channels_first, dtype=torch.float32)
        return data

class RandomFlip:
    """Randomly mirror all four arrays of a sample.

    Two independent draws from ``np.random.rand`` decide, in this order,
    whether to flip along axis 1 and whether to flip along axis 0 (each with
    probability 0.5). Images and labels always receive the same flips.
    """

    _ARRAY_KEYS = ('img', 'label', 'img_d', 'label_d')

    def __call__(self, data):
        for axis in (1, 0):
            if np.random.rand() > 0.5:
                for key in self._ARRAY_KEYS:
                    # copy() turns the negatively strided view into a
                    # regular array.
                    data[key] = np.flip(data[key], axis=axis).copy()
        return data

class RandomRotate:
    """Rotate all four arrays of a sample by a random multiple of 90 degrees.

    The angle is drawn once per sample with ``random.choice`` from
    0/90/180/270, so images and labels stay aligned.
    """

    _ARRAY_KEYS = ('img', 'label', 'img_d', 'label_d')

    def __call__(self, data):
        quarter_turns = random.choice([0, 90, 180, 270]) // 90
        for key in self._ARRAY_KEYS:
            rotated = np.rot90(data[key], k=quarter_turns)
            data[key] = rotated.copy()
        return data

class RandomBrightness:
    """Scale ``img`` and ``img_d`` by one random factor in [0.8, 1.2).

    The result is clipped to [0, 1]. Labels are not touched.
    """

    def __call__(self, data):
        gain = 0.8 + np.random.uniform(0, 0.4)  # Between 0.8 and 1.2
        for key in ('img', 'img_d'):
            data[key] = np.clip(data[key] * gain, 0, 1)
        return data

class AddGaussianNoise:
    """Add zero-mean Gaussian noise (std 0.02) to ``img`` and ``img_d``.

    Noise is drawn separately for the two arrays, ``img`` first, and each
    result is clipped to [0, 1]. Labels are not touched.
    """

    def __call__(self, data):
        for key in ('img', 'img_d'):
            jitter = np.random.normal(0, 0.02, data[key].shape)
            data[key] = np.clip(data[key] + jitter, 0, 1)
        return data


# Training samples are augmented (geometry first, then intensity) before the
# conversion to tensors; validation samples are only converted.
train_transform = transforms.Compose(
    [
        RandomFlip(),
        RandomRotate(),
        RandomBrightness(),
        AddGaussianNoise(),
        ToTensor(),
    ]
)

val_transform = transforms.Compose([ToTensor()])

# === Loaders ===
train_set = CustomDataset(
    "/kaggle/input/kasvir-svg-train-val/train/images",
    "/kaggle/input/kasvir-svg-train-val/train/masks",
    train_transform,
)

val_set = CustomDataset(
    "/kaggle/input/kasvir-svg-train-val/val/images",
    "/kaggle/input/kasvir-svg-train-val/val/masks",
    val_transform,
)

# Only the training data is reshuffled every epoch.
train_loader = DataLoader(dataset=train_set, batch_size=cfg.BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=cfg.BATCH_SIZE, shuffle=False)


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import os
import numpy as np

# === Dice Loss & Score ===
def dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss, ``1 - mean Dice``, computed from logits.

    Args:
        pred: Logits; reduced over dims 1, 2 and 3, so a 4-D tensor.
        target: Tensor of the same shape as ``pred``.
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar tensor.
    """
    reduce_dims = (1, 2, 3)
    probs = torch.sigmoid(pred)
    overlap = torch.sum(probs * target, dim=reduce_dims)
    total = torch.sum(probs, dim=reduce_dims) + torch.sum(target, dim=reduce_dims)
    per_sample = (2. * overlap + smooth) / (total + smooth)
    return 1 - torch.mean(per_sample)

def dice_score(pred, target, smooth=1e-6):
    """Mean Dice coefficient of the thresholded prediction.

    Logits are passed through a sigmoid and binarised at 0.5 before the
    per-sample Dice (reduced over dims 1, 2 and 3) is averaged.

    Returns:
        Python float.
    """
    reduce_dims = (1, 2, 3)
    mask = torch.sigmoid(pred).gt(0.5).float()
    overlap = torch.sum(mask * target, dim=reduce_dims)
    total = torch.sum(mask, dim=reduce_dims) + torch.sum(target, dim=reduce_dims)
    per_sample = (2. * overlap + smooth) / (total + smooth)
    return per_sample.mean().item()

# === IoU Loss & Score ===
def iou_loss(pred, target, smooth=1e-6):
    """Soft IoU (Jaccard) loss, ``1 - mean IoU``, computed from logits.

    Args:
        pred: Logits; reduced over dims 1, 2 and 3, so a 4-D tensor.
        target: Tensor of the same shape as ``pred``.
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar tensor.
    """
    reduce_dims = (1, 2, 3)
    probs = torch.sigmoid(pred)
    product = probs * target
    overlap = torch.sum(product, dim=reduce_dims)
    union = torch.sum(probs + target - product, dim=reduce_dims)
    per_sample = (overlap + smooth) / (union + smooth)
    return 1 - torch.mean(per_sample)

def iou_score(pred, target, smooth=1e-6):
    """Mean IoU (Jaccard index) of the thresholded prediction.

    Logits are passed through a sigmoid and binarised at 0.5 before the
    per-sample IoU (reduced over dims 1, 2 and 3) is averaged.

    Returns:
        Python float.
    """
    reduce_dims = (1, 2, 3)
    mask = torch.sigmoid(pred).gt(0.5).float()
    overlap = torch.sum(mask * target, dim=reduce_dims)
    union = torch.sum(mask + target - mask * target, dim=reduce_dims)
    per_sample = (overlap + smooth) / (union + smooth)
    return per_sample.mean().item()

# === Model, Loss, Optimizer ===
# Three input channels, one output channel, base width 64.
net = UNet(in_channels=3, out_channels=1, n=64).to(device)
# The network outputs are fed to the loss without an activation, hence the
# logits variant of BCE.
loss_fn = nn.BCEWithLogitsLoss().to(device)
optim = torch.optim.Adam(params=net.parameters(), lr=cfg.LEARNING_RATE)  # L2 regularization removed

# One TensorBoard writer per phase.
train_writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, 'train'))
val_writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, 'val'))
start_epoch = 0

# === Training Loop ===
# Every epoch runs one pass over train_loader with parameter updates, then
# one pass over val_loader without gradients, logs the epoch means to
# TensorBoard and occasionally writes a checkpoint.
# NOTE(review): training always begins at `start_epoch`; load_net() is not
# called here, so an existing checkpoint is not resumed automatically.
for epoch in range(start_epoch, cfg.NUM_EPOCHS):
    net.train()
    train_losses, bce_losses, dice_losses, iou_losses = [], [], [], []
    dice_scores, iou_scores = [], []

    for batch in train_loader:
        # The network is fed the low-resolution image; the full-resolution
        # and low-resolution labels supervise its largest and smallest output.
        img = batch['img_d'].to(device)
        label = batch['label'].to(device)
        label_d = batch['label_d'].to(device)

        output1, output2, output3, output4, output5 = net(img)

        # Targets for the three intermediate outputs.
        resized = [
            F.interpolate(label, size=(32, 32), mode='bilinear', align_corners=False),
            F.interpolate(label, size=(64, 64), mode='bilinear', align_corners=False),
            F.interpolate(label, size=(128, 128), mode='bilinear', align_corners=False)
        ]

        # === Losses ===
        # Deep supervision: BCE on every output scale.
        bce_loss = (
            loss_fn(output1, label_d) +
            loss_fn(output2, resized[0]) +
            loss_fn(output3, resized[1]) +
            loss_fn(output4, resized[2]) +
            loss_fn(output5, label)
        )

        # Dice and IoU terms are computed on the final output only.
        dice = dice_loss(output5, label)
        iou = iou_loss(output5, label)

        # Total Loss (without L1 and L2 regularization)
        loss = bce_loss + cfg.DICE_LAMBDA * dice + cfg.IOU_LAMBDA * iou

        # Backward
        optim.zero_grad()
        loss.backward()
        optim.step()

        # === Logging individual losses ===
        # Read each scalar back once and reuse it for the lists and the print.
        loss_value = loss.item()
        bce_value = bce_loss.item()
        dice_value = dice.item()
        iou_value = iou.item()

        train_losses.append(loss_value)
        bce_losses.append(bce_value)
        dice_losses.append(dice_value)
        iou_losses.append(iou_value)
        # Training "scores" are 1 - soft loss, i.e. not thresholded, unlike
        # the validation scores below.
        dice_scores.append(1 - dice_value)
        iou_scores.append(1 - iou_value)

        print(f'[Train] Epoch {epoch}/{cfg.NUM_EPOCHS} | Total: {loss_value:.4f} | BCE: {bce_value:.4f} | Dice: {1 - dice_value:.4f} | IoU: {1 - iou_value:.4f}')

    # === TensorBoard Logging ===
    train_writer.add_scalar('train/total_loss', np.mean(train_losses), epoch)
    train_writer.add_scalar('train/bce_loss', np.mean(bce_losses), epoch)
    train_writer.add_scalar('train/dice_loss', np.mean(dice_losses), epoch)
    train_writer.add_scalar('train/iou_loss', np.mean(iou_losses), epoch)
    train_writer.add_scalar('train/dice_score', np.mean(dice_scores), epoch)
    train_writer.add_scalar('train/iou_score', np.mean(iou_scores), epoch)

    # === Validation ===
    net.eval()
    val_losses, val_dice_scores, val_iou_scores = [], [], []

    with torch.no_grad():
        for batch in val_loader:
            img = batch['img_d'].to(device)
            label = batch['label'].to(device)
            id_list = batch['id']

            # Only the last (largest) output is evaluated.
            outputs = net(img)
            output_final = outputs[-1]
            val_loss = loss_fn(output_final, label)
            val_loss_value = val_loss.item()

            val_losses.append(val_loss_value)
            val_dice_scores.append(dice_score(output_final, label))
            val_iou_scores.append(iou_score(output_final, label))

            # Dump the final output of every validation sample after the
            # last epoch.
            if epoch == cfg.NUM_EPOCHS - 1:
                os.makedirs('./pred_out/', exist_ok=True)
                for i, sample_id in enumerate(id_list):
                    np.save(f'./pred_out/{sample_id}.npy', output_final[i].cpu().numpy())

            print(f'[Val] Epoch {epoch} | BCE Loss: {val_loss_value:.4f} | Dice: {val_dice_scores[-1]:.4f} | IoU: {val_iou_scores[-1]:.4f}')

    val_writer.add_scalar('val/bce_loss', np.mean(val_losses), epoch)
    val_writer.add_scalar('val/dice_score', np.mean(val_dice_scores), epoch)
    val_writer.add_scalar('val/iou_score', np.mean(val_iou_scores), epoch)

    # === Summary ===
    print(f'==> Epoch {epoch}')
    print(f'    Train Total Loss : {np.mean(train_losses):.4f} | BCE: {np.mean(bce_losses):.4f} | Dice score: {np.mean(dice_scores):.4f} | IoU score: {np.mean(iou_scores):.4f}')
    print(f'    Val   BCE Loss   : {np.mean(val_losses):.4f} | Dice score: {np.mean(val_dice_scores):.4f} | IoU score : {np.mean(val_iou_scores):.4f}')

    # Checkpoint every fifth epoch, plus epochs 1 and 4.
    if epoch % 5 == 0 or epoch in {1, 4}:
        save_net(CKPT_DIR, net, optim, epoch)

train_writer.close()
val_writer.close()


[Train] Epoch 0/100 | Total: 4.9226 | BCE: 3.2209 | Dice: 0.1892 | IoU: 0.1091
[Train] Epoch 0/100 | Total: 4.7828 | BCE: 3.0944 | Dice: 0.1969 | IoU: 0.1147
[Train] Epoch 0/100 | Total: 4.6456 | BCE: 2.9849 | Dice: 0.2147 | IoU: 0.1246
[Train] Epoch 0/100 | Total: 4.9716 | BCE: 3.3010 | Dice: 0.2071 | IoU: 0.1222
[Train] Epoch 0/100 | Total: 4.5983 | BCE: 2.9249 | Dice: 0.2069 | IoU: 0.1197
[Train] Epoch 0/100 | Total: 4.6100 | BCE: 2.9317 | Dice: 0.2027 | IoU: 0.1189
[Train] Epoch 0/100 | Total: 4.6095 | BCE: 2.9190 | Dice: 0.1955 | IoU: 0.1140
[Train] Epoch 0/100 | Total: 4.5811 | BCE: 2.8940 | Dice: 0.1983 | IoU: 0.1146
[Train] Epoch 0/100 | Total: 4.5475 | BCE: 2.8501 | Dice: 0.1920 | IoU: 0.1106
[Train] Epoch 0/100 | Total: 4.5167 | BCE: 2.7955 | Dice: 0.1775 | IoU: 0.1012
[Train] Epoch 0/100 | Total: 4.5576 | BCE: 2.9002 | Dice: 0.2161 | IoU: 0.1265
[Train] Epoch 0/100 | Total: 4.5431 | BCE: 2.8751 | Dice: 0.2100 | IoU: 0.1221
[Train] Epoch 0/100 | Total: 4.4742 | BCE: 2.7678 | 