In [None]:
# !pip install google-colab
!pip install opencv-python
!pip install -U albumentations
from google.colab import drive
drive.mount('/content/drive')


Collecting opencv-python
  Downloading opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (19 kB)
Downloading opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (67.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: opencv-python
Successfully installed opencv-python-4.12.0.88
Collecting albumentations
  Downloading albumentations-2.0.8-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.1/43.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting albucore==0.0.24 (from albumentations)
  Downloading albucore-0.0.24-py3-none-any.whl.metadata (5.3 kB)
Collecting opencv-python-headless>=4.9.0.80 (from albumentations)
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (19 kB)
Collecting stringzilla>=3.

In [None]:
import zipfile
import os

# Dataset archive on the mounted Google Drive and the local directory it is unpacked into.
zip_path = "/content/drive/MyDrive/mri_data_png.zip"
extract_path = "/content/mri_data_png"

# exist_ok=True keeps this cell re-runnable when the directory already exists.
os.makedirs(extract_path, exist_ok=True)

# Unpack every member of the archive into extract_path (repeated on each execution of the cell).
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Extracted to:", extract_path)


Extracted to: /content/mri_data_png


In [None]:
import torch
import torch.nn as nn

class Block(nn.Module):
    """Stride-2 4x4 convolution stage: conv -> normalization -> activation (-> optional dropout).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        down: True builds a reflect-padded ``nn.Conv2d`` (stride 2);
            False builds an ``nn.ConvTranspose2d`` (stride 2).
        act: ``"relu"`` selects ``nn.ReLU``; any other value selects ``nn.LeakyReLU(0.2)``.
        use_dropout: When True, ``nn.Dropout2d(0.5)`` is applied to the stage output.
        norm: ``"instance"`` (affine ``nn.InstanceNorm2d``), ``"batch"`` (``nn.BatchNorm2d``)
            or ``None`` (``nn.Identity``).

    Raises:
        ValueError: If ``norm`` is not one of the supported values.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False, norm="instance"):
        super().__init__()

        # Resampling convolution (built first, before the norm argument is validated).
        if down:
            resample = nn.Conv2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1,
                bias=False, padding_mode="reflect",
            )
        else:
            resample = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False,
            )

        normalize = self._build_norm(norm, out_channels)

        if act == "relu":
            nonlinearity = nn.ReLU(inplace=True)
        else:
            nonlinearity = nn.LeakyReLU(0.2, inplace=True)

        self.conv = nn.Sequential(resample, normalize, nonlinearity)
        self.use_dropout = use_dropout
        # The dropout module is always created; forward() only applies it when use_dropout is set.
        self.dropout = nn.Dropout2d(0.5)

        self.down = down

    @staticmethod
    def _build_norm(norm, channels):
        """Return the normalization layer selected by ``norm`` for ``channels`` feature maps."""
        if norm == "instance":
            return nn.InstanceNorm2d(channels, affine=True)
        if norm == "batch":
            return nn.BatchNorm2d(channels)
        if norm is None:
            return nn.Identity()
        raise ValueError(f"Unknown norm='{norm}'. Use 'instance', 'batch', or None.")

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            return self.dropout(out)
        return out


import torch
import torch.nn as nn

class Generator(nn.Module):
    """U-Net style encoder/decoder that maps an image to per-class logits.

    The encoder is ``initial_down`` followed by six ``Block`` down stages and a bottleneck
    convolution. The decoder is seven ``Block`` up stages plus a final transposed
    convolution; every decoder stage after ``up1`` receives the previous decoder output
    concatenated (channel-wise) with the matching encoder activation.

    Args:
        in_ch: Channels of the input image.
        out_ch: Number of output channels (one logit map per class).
        features: Base channel width; deeper stages use multiples of it.

    Note:
        The spatial sizes in the comments below assume a 256x256 input.
    """

    def __init__(self, in_ch=1, out_ch=6, features=64):
        super().__init__()
        f = features

        # ---- Encoder: 256->128->64->32->16->8->4->2 ----
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_ch, f, kernel_size=4, stride=2, padding=1, padding_mode="reflect"),  # 256 -> 128
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.down1 = Block(f, f * 2, down=True, act="leaky")      # 128 -> 64
        self.down2 = Block(f * 2, f * 4, down=True, act="leaky")  # 64  -> 32
        self.down3 = Block(f * 4, f * 8, down=True, act="leaky")  # 32  -> 16
        self.down4 = Block(f * 8, f * 8, down=True, act="leaky")  # 16  -> 8
        self.down5 = Block(f * 8, f * 8, down=True, act="leaky")  # 8   -> 4
        self.down6 = Block(f * 8, f * 8, down=True, act="leaky")  # 4   -> 2

        # ---- Bottleneck: 2x2 -> 1x1 (no normalization layer) ----
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f * 8, f * 8, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )

        # ---- Decoder: input widths are doubled from up2 onwards because of the skip concat ----
        self.up1 = Block(f * 8, f * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(f * 16, f * 8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(f * 16, f * 8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(f * 16, f * 8, down=False, act="relu")
        self.up5 = Block(f * 16, f * 4, down=False, act="relu")
        self.up6 = Block(f * 8, f * 2, down=False, act="relu")
        self.up7 = Block(f * 4, f, down=False, act="relu")

        # Final: 128 -> 256, raw logits with out_ch channels (no activation applied here).
        self.final_up = nn.ConvTranspose2d(f * 2, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        # Run the encoder and remember every stage output for the skip connections.
        encoder_stages = (
            self.initial_down, self.down1, self.down2, self.down3,
            self.down4, self.down5, self.down6,
        )
        skips = []
        h = x
        for stage in encoder_stages:
            h = stage(h)
            skips.append(h)

        h = self.bottleneck(h)
        h = self.up1(h)

        # Remaining decoder stages consume the skips from deepest to shallowest.
        decoder_stages = (
            self.up2, self.up3, self.up4, self.up5, self.up6, self.up7, self.final_up,
        )
        for stage, skip in zip(decoder_stages, reversed(skips)):
            h = stage(torch.cat([h, skip], 1))
        return h


In [None]:
# Smoke test: build a generator with 1 input channel and 6 output channels.
G = Generator(in_ch=1, out_ch=6, features=64)

# Dummy input: a random batch of two single-channel 256x256 images.
x = torch.randn(2, 1, 256, 256)

# Forward pass
y = G(x)

# Print both shapes so the output resolution and channel count can be checked by eye.
print("Input shape :", x.shape)
print("Output shape:", y.shape)

Input shape : torch.Size([2, 1, 256, 256])
Output shape: torch.Size([2, 6, 256, 256])


In [None]:
import torch
import torch.nn as nn

class CNNBlock(nn.Module):
    """4x4 reflect-padded convolution followed by affine instance norm and LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=False,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.conv(x)
        return features


import torch
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    """Conditional patch discriminator that scores (image, mask) pairs.

    The image and the mask are concatenated along the channel axis and passed through a
    stack of 4x4 convolutions; the result is a 1-channel map of raw (pre-sigmoid) scores.

    Args:
        in_ch_x: Channels of the image ``x``.
        in_ch_y: Channels of the mask ``y`` (number of classes).
        features: Channel widths of the convolution stages. The first width is used by
            the un-normalized initial stage; every following stage uses stride 2 except
            the last one, which uses stride 1.
    """

    def __init__(self, in_ch_x=1, in_ch_y=6, features=(64, 128, 256, 512)):
        super().__init__()
        self.in_ch_x = in_ch_x
        self.in_ch_y = in_ch_y

        in_pair = in_ch_x + in_ch_y

        self.initial = nn.Sequential(
            nn.Conv2d(in_pair, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )

        blocks = []
        in_c = features[0]
        last_idx = len(features) - 1
        for idx, f in enumerate(features[1:], start=1):
            # The stride-1 stage is chosen by position, not by comparing the width with
            # features[-1]: with repeated widths such as (64, 128, 128) a value comparison
            # would give stride 1 to every stage that happens to share the last width.
            stride = 1 if idx == last_idx else 2
            blocks.append(nn.Sequential(
                nn.Conv2d(in_c, f, kernel_size=4, stride=stride, padding=1, bias=False, padding_mode="reflect"),
                nn.InstanceNorm2d(f, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ))
            in_c = f

        # Final 1-channel conv -> patch score map
        blocks.append(nn.Conv2d(in_c, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))

        self.model = nn.Sequential(*blocks)

    def _ensure_mask_channels(self, y):
        """Return the mask as a float tensor of shape (N, in_ch_y, H, W).

        A 3-D input is treated as class indices and one-hot encoded; a 4-D input must
        already have ``in_ch_y`` channels and is only cast to floating point if needed.

        Raises:
            ValueError: If ``y`` is neither 3-D nor 4-D, or has the wrong channel count.
        """
        if y.dim() == 3:
            # F.one_hot requires int64 indices; .long() is a no-op when y already is int64.
            y = F.one_hot(y.long(), num_classes=self.in_ch_y).permute(0, 3, 1, 2).float()
        elif y.dim() == 4:
            if y.size(1) != self.in_ch_y:
                raise ValueError(f"Expected mask with {self.in_ch_y} channels, got {y.size(1)}.")
            if not y.is_floating_point():
                y = y.float()
        else:
            raise ValueError(f"Mask y must be (N,H,W) or (N,C,H,W), got shape {tuple(y.shape)}.")
        return y

    def forward(self, x, y):
        """Score an (image, mask) pair.

        Args:
            x: Image tensor of shape (N, in_ch_x, H, W).
            y: Mask of shape (N, in_ch_y, H, W), or (N, H, W) class indices.

        Returns:
            Tensor of shape (N, 1, H', W') with one raw score per patch.
        """
        y = self._ensure_mask_channels(y)
        pair = torch.cat([x, y], dim=1)        # (N, in_ch_x+in_ch_y, H, W)
        h = self.initial(pair)
        out = self.model(h)
        return out
# Quick sanity test: the discriminator must accept both mask formats and return the same output shape.
if __name__ == "__main__":
    N, H, W = 2, 256, 256
    D = Discriminator(in_ch_x=1, in_ch_y=6)

    x = torch.randn(N, 1, H, W)
    # Mask given as integer class indices of shape (N, H, W); the discriminator one-hot encodes it.
    y_idx = torch.randint(0, 6, (N, H, W))    # indices
    out_real = D(x, y_idx)
    print("D(x, y_idx) ->", out_real.shape)

    # Mask given as per-class probabilities of shape (N, 6, H, W), like a softmaxed generator output.
    y_probs = torch.softmax(torch.randn(N, 6, H, W), dim=1)
    out_fake = D(x, y_probs)
    print("D(x, y_probs) ->", out_fake.shape)

D(x, y_idx) -> torch.Size([2, 1, 30, 30])
D(x, y_probs) -> torch.Size([2, 1, 30, 30])


In [None]:
import os, glob, random, cv2, numpy as np, torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

def collect_ids(root="mri_data", img_dir="images", mask_dir="masks"):
    """Pair up image and mask PNG files that share a basename.

    Args:
        root: Dataset root directory.
        img_dir: Sub-directory of ``root`` holding the image ``*.png`` files.
        mask_dir: Sub-directory of ``root`` holding the mask ``*.png`` files.

    Returns:
        ``(ids, imgs, masks)`` where ``ids`` is the sorted list of basenames (without
        extension) present in both directories, and ``imgs`` / ``masks`` map every
        basename found in the respective directory to its file path.

    Raises:
        RuntimeError: If no basename occurs in both directories.
    """

    def index_pngs(subdir):
        # Map "<stem>" -> path for every PNG directly inside root/subdir.
        by_stem = {}
        for path in glob.glob(os.path.join(root, subdir, "*.png")):
            stem, _ext = os.path.splitext(os.path.basename(path))
            by_stem[stem] = path
        return by_stem

    imgs = index_pngs(img_dir)
    masks = index_pngs(mask_dir)

    ids = sorted(imgs.keys() & masks.keys())
    if len(ids) == 0:
        raise RuntimeError("No matching image/mask basenames found.")
    return ids, imgs, masks

def split_ids(ids, val_ratio=0.1, seed=42):
    """Deterministically split ids into training and validation subsets.

    Args:
        ids: Sequence of sample identifiers. It is NOT modified.
        val_ratio: Fraction of the ids assigned to validation.
        seed: Seed of the private random generator used for shuffling.

    Returns:
        ``(train_ids, val_ids)`` as two lists. At least one id is always placed in the
        validation list (when ``ids`` is non-empty), even if ``val_ratio`` rounds to zero.
    """
    # Shuffle a copy: shuffling in place would silently reorder the caller's list and make
    # a second call with the same list and seed return a different split.
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_ratio))
    return shuffled[n_val:], shuffled[:n_val]  # train_ids, val_ids

# ---------- dataset ----------
class SpinePNG(Dataset):
    """Dataset of (image, mask) PNG pairs resized to 256x256.

    Args:
        ids: Basenames of the samples in this split.
        img_map: Mapping basename -> image file path.
        mask_map: Mapping basename -> mask file path.
        augment: When True (default) a random horizontal flip is applied. Pass False for
            validation data so evaluation is deterministic.

    Each item is ``(x, y)``: ``x`` is a float tensor of shape (1, 256, 256) with values
    normalized to [-1, 1]; ``y`` is an int64 tensor of shape (256, 256) holding the mask
    values as read from disk (assumed to be class indices - TODO confirm against the data).
    """

    def __init__(self, ids, img_map, mask_map, augment=True):
        self.ids = ids
        self.img_map = img_map
        self.mask_map = mask_map

        steps = [A.Resize(256, 256)]
        if augment:
            steps.append(A.HorizontalFlip(p=0.5))
        steps += [
            # __getitem__ already rescales the image to [0, 1], so max_pixel_value must be
            # 1.0. With the default of 255 the result would be (img - 127.5) / 127.5, i.e.
            # every pixel squashed into roughly [-1, -0.992].
            A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=1.0),
            ToTensorV2(),
        ]
        self.tf = A.Compose(steps)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        bid = self.ids[i]
        img_path = self.img_map[bid]
        mask_path = self.mask_map[bid]
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        msk = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)

        # cv2.imread signals an unreadable/missing file by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        if msk is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")

        if img.ndim == 3:
            # IMREAD_UNCHANGED keeps an alpha channel when present; pick the matching conversion.
            code = cv2.COLOR_BGRA2GRAY if img.shape[2] == 4 else cv2.COLOR_BGR2GRAY
            img = cv2.cvtColor(img, code)

        # Rescale to [0, 1] according to the bit depth of the file.
        if img.dtype == np.uint16:
            img = img.astype(np.float32) / 65535.0
        else:
            img = img.astype(np.float32) / 255.0

        # albumentations expects HxW
        aug = self.tf(image=img, mask=msk.astype(np.int64))
        x = aug["image"]
        if x.ndim == 2:
            x = x.unsqueeze(0)  # make sure the image has an explicit channel axis
        y = aug["mask"].long()
        return x, y

def make_loaders(root="mri_data", batch_size=4, val_ratio=0.1, num_workers=4, seed=42):
    """Build the training and validation DataLoaders for the dataset under ``root``.

    Args:
        root: Dataset root passed to ``collect_ids``.
        batch_size: Batch size of both loaders.
        val_ratio: Fraction of samples used for validation.
        num_workers: Worker processes of the training loader; the validation loader uses
            half of them (at least one, unless ``num_workers`` is 0).
        seed: Seed of the train/validation split.

    Returns:
        ``(train_loader, val_loader)``. Only the training dataset is augmented.
    """
    ids, img_map, mask_map = collect_ids(root)
    train_ids, val_ids = split_ids(ids, val_ratio=val_ratio, seed=seed)
    train_ds = SpinePNG(train_ids, img_map, mask_map, augment=True)
    # No random flips on validation data, so validation metrics are comparable across epochs.
    val_ds   = SpinePNG(val_ids,   img_map, mask_map, augment=False)

    # Pinned host memory only helps host->GPU copies; requesting it without CUDA just warns.
    pin_memory = torch.cuda.is_available()
    # Respect num_workers=0 (load in the main process) instead of forcing a worker for validation.
    val_workers = max(1, num_workers // 2) if num_workers > 0 else 0

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=val_workers, pin_memory=pin_memory)
    return train_loader, val_loader

# Build the loaders from the extracted dataset (10% of the matched pairs go to validation).
train_loader, val_loader = make_loaders(root="/content/mri_data_png/data", batch_size=4, val_ratio=0.1)

# Pull a single batch to confirm the tensor shapes, then stop iterating.
for xb, yb in train_loader:
    print(xb.shape, yb.shape)  # expect (N,1,256,256) and (N,256,256)
    break
print(len(train_loader.dataset))  # number of training samples (the recorded run printed 3182)
print(len(val_loader.dataset))    # number of validation samples (the recorded run printed 353)



torch.Size([4, 1, 256, 256]) torch.Size([4, 256, 256])
3182
353


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Number of segmentation classes; mask values are presumably in [0, NUM_CLASSES) - TODO confirm against the dataset.
NUM_CLASSES = 20
# Binary cross-entropy on raw (pre-sigmoid) scores; used for the adversarial terms.
bce = nn.BCEWithLogitsLoss()
# Multi-class cross-entropy on raw logits; used for the segmentation term.
ce  = nn.CrossEntropyLoss()

def dice_loss_from_logits(logits, target_idx, eps=1e-6):
    """Soft Dice loss (1 - mean per-class Dice) computed from raw logits.

    Args:
        logits: Tensor of shape (N, C, H, W) with unnormalized class scores.
        target_idx: Integer tensor of shape (N, H, W) with class indices in [0, C).
        eps: Small constant added to the denominator to avoid division by zero.

    Returns:
        Scalar tensor. Per-class Dice is ``2*sum(p*t) / (sum(p^2) + sum(t^2) + eps)``,
        accumulated over the batch and spatial axes, then averaged over the classes.
    """
    # Take the class count from the logits themselves rather than from a module-level
    # constant, so the loss works for any model and cannot get out of sync with it.
    num_classes = logits.size(1)
    p = F.softmax(logits, dim=1)
    # F.one_hot needs int64 indices and appends the class axis last -> move it to dim 1.
    t = F.one_hot(target_idx.long(), num_classes).permute(0, 3, 1, 2).float()
    num = 2 * (p * t).sum((0, 2, 3))
    den = (p * p).sum((0, 2, 3)) + (t * t).sum((0, 2, 3)) + eps
    return 1 - (num / den).mean()

def train(gen, disc, train_loader, val_loader, device="cuda", epochs=20, lr=2e-4, lambda_seg=20.0,
          save_each_epoch=False):
    """Adversarially train a segmentation generator against a conditional discriminator.

    Every training step first updates the discriminator (real pairs use the one-hot ground
    truth, fake pairs use the softmaxed generator output computed without gradients), then
    updates the generator with ``gan_loss + lambda_seg * (cross_entropy + dice)``. After
    each epoch the generator is evaluated on ``val_loader`` and the mean per-batch
    cross-entropy and soft Dice score are printed.

    Args:
        gen: Generator mapping images (N, C_in, H, W) to logits (N, C, H, W).
        disc: Discriminator called as ``disc(image, mask_probabilities)``.
        train_loader: Iterable of ``(image, mask_indices)`` training batches.
        val_loader: Iterable of ``(image, mask_indices)`` validation batches.
        device: Device as a string or ``torch.device``. Mixed precision is enabled only
            when this is a CUDA device and CUDA is available.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate of both Adam optimizers (betas are fixed to (0.5, 0.999)).
        lambda_seg: Weight of the segmentation loss relative to the adversarial loss.
        save_each_epoch: When True, write ``gen_epoch{epoch}.pth`` after every epoch.

    Returns:
        The trained generator (the same object as ``gen``).
    """
    # Accept both "cuda"/"cpu" strings and torch.device objects.
    device = torch.device(device)
    gen.to(device); disc.to(device)

    opt_g = torch.optim.Adam(gen.parameters(),  lr=lr, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

    amp_enabled = device.type == "cuda" and torch.cuda.is_available()
    # torch.amp.* replaces the deprecated torch.cuda.amp.* entry points (which emit FutureWarnings).
    scaler_g = torch.amp.GradScaler("cuda", enabled=amp_enabled)
    scaler_d = torch.amp.GradScaler("cuda", enabled=amp_enabled)

    for epoch in range(1, epochs + 1):
        gen.train(); disc.train()
        pbar = tqdm(train_loader, desc=f"epoch {epoch}/{epochs}", ncols=100, leave=False)

        for x, y_idx in pbar:
            x = x.to(device, non_blocking=True)
            y_idx = y_idx.to(device, non_blocking=True)

            # ---- D step ----
            with torch.amp.autocast("cuda", enabled=amp_enabled):
                # The fake sample is produced without gradients: only disc is updated here.
                with torch.no_grad():
                    probs_fake = F.softmax(gen(x), dim=1)
                # One-hot width follows the generator output so both pair types match.
                y_real = F.one_hot(y_idx, probs_fake.size(1)).permute(0, 3, 1, 2).float()
                d_real = disc(x, y_real)
                d_fake = disc(x, probs_fake)
                loss_d = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))

            opt_d.zero_grad(set_to_none=True)
            scaler_d.scale(loss_d).backward()
            scaler_d.step(opt_d); scaler_d.update()

            # ---- G step ----
            with torch.amp.autocast("cuda", enabled=amp_enabled):
                logits = gen(x)
                probs  = F.softmax(logits, dim=1)
                d_fake_for_g = disc(x, probs)
                # The generator is rewarded when the discriminator labels its output as real.
                gan_loss = bce(d_fake_for_g, torch.ones_like(d_fake_for_g))
                ce_loss  = ce(logits, y_idx)
                seg_loss = ce_loss + dice_loss_from_logits(logits, y_idx)
                loss_g = gan_loss + lambda_seg * seg_loss

            opt_g.zero_grad(set_to_none=True)
            scaler_g.scale(loss_g).backward()
            scaler_g.step(opt_g); scaler_g.update()

            # Reuse the cross-entropy already computed above instead of evaluating it again.
            pbar.set_postfix(D=f"{loss_d.item():.3f}",
                              G=f"{loss_g.item():.3f}",
                              CE=f"{ce_loss.item():.3f}")

        # ---- validation (generator only) ----
        gen.eval()
        ce_sum, dice_sum, n = 0.0, 0.0, 0
        with torch.inference_mode():
            for x, y_idx in val_loader:
                x = x.to(device); y_idx = y_idx.to(device)
                logits = gen(x)
                ce_sum   += ce(logits, y_idx).item()
                # Dice score = 1 - Dice loss.
                dice_sum += (1.0 - dice_loss_from_logits(logits, y_idx).item())
                n += 1

        # Averages are per batch; max(n, 1) guards against an empty validation loader.
        avg_ce   = ce_sum / max(n, 1)
        avg_dice = dice_sum / max(n, 1)
        print(f"[epoch {epoch}] val CE: {avg_ce:.3f} | val Dice: {avg_dice:.3f}")

        if save_each_epoch:
            torch.save(gen.state_dict(), f"gen_epoch{epoch}.pth")

    return gen


In [None]:
# Build both networks with one channel per segmentation class.
gen  = Generator(in_ch=1, out_ch=NUM_CLASSES, features=64)
disc = Discriminator(in_ch_x=1, in_ch_y=NUM_CLASSES)

# Train on the GPU when the runtime provides one (train() then also enables mixed
# precision) and fall back to the CPU otherwise, instead of always forcing the CPU.
model = train(gen, disc, train_loader, val_loader,
              device="cuda" if torch.cuda.is_available() else "cpu",
              epochs=20, lr=2e-4, lambda_seg=20.0)


  scaler_g = torch.cuda.amp.GradScaler(enabled=amp_enabled)
  scaler_d = torch.cuda.amp.GradScaler(enabled=amp_enabled)
  with torch.cuda.amp.autocast(enabled=amp_enabled):
  with torch.cuda.amp.autocast(enabled=amp_enabled):


[epoch 1] val CE: 0.213 | val Dice: 0.468




[epoch 2] val CE: 0.181 | val Dice: 0.588




[epoch 3] val CE: 0.134 | val Dice: 0.668




[epoch 4] val CE: 0.199 | val Dice: 0.666


epoch 5/20:  64%|████████████▊       | 509/796 [10:43<05:58,  1.25s/it, CE=0.056, D=0.044, G=12.732]