In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2


class LumbarDataset(Dataset):
    def __init__(self, df, image_dir, mask_dir, transforms=None):
        self.df = df.reset_index()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        image_path = os.path.join(self.image_dir, row.image)
        mask_path = os.path.join(self.mask_dir, row.image)

        # Open image
        image = Image.open(image_path).convert('L')
        image = np.asarray(image, dtype=np.float32) / 255.0

        # Open mask
        mask = Image.open(mask_path)
        mask = np.asarray(mask, dtype=np.uint8)
        mask = np.where(mask <= 5, mask, 0)  # Considering only L5, L4, L3, L2, L1 vertebrae

        if self.transforms:
            transformed = self.transforms(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
        
        mask = torch.nn.functional.one_hot(torch.tensor(mask, dtype=torch.long), num_classes=6).permute(2, 0, 1).float()
        image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)  # Adding channel dimension

        return image, mask

# Define transformations
train_dataset = LumbarDataset(root_dir=TRAIN_DIR)
val_dataset = LumbarDataset(root_dir=VAL_DIR)
transforms_train = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(),
    A.Normalize(mean=[0.5], std=[0.5]),
    ToTensorV2()
])

transforms_valid = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=[0.5], std=[0.5]),
    ToTensorV2()
])


# Define the Generator and Discriminator (Pix2Pix Model)
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x

class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(features * 2, features * 4, down=True, act="leaky", use_dropout=False)
        self.down3 = Block(features * 4, features * 8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1), nn.ReLU()
        )

        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))
        return self.final_up(torch.cat([up7, d1], 1))

# Define the Training Function
def train_pix2pix(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler, device):
    loop = tqdm(loader, leave=True)
    for idx, (x, y) in enumerate(loop):
        x = x.to(device)
        y = y.to(device)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * 100  # Adjust L1 regularization strength
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(D_real=torch.sigmoid(D_real).mean().item(), D_fake=torch.sigmoid(D_fake).mean().item())


# Define the Discriminator class
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, 4, stride, 1, bias=False, padding_mode="reflect"
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2),
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

# Main function to train Pix2Pix
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    disc = Discriminator(in_channels=3).to(device)
    gen = Generator(in_channels=3, features=64).to(device)
    opt_disc = optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(50):
        train_pix2pix(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler, device)

if __name__ == "__main__":
    main()


NameError: name 'TRAIN_DIR' is not defined

In [7]:
!python3 -m pip install tqdm torch torchvision albumentations

Collecting albumentations
  Downloading albumentations-2.0.5-py3-none-any.whl.metadata (41 kB)
Collecting PyYAML (from albumentations)
  Downloading PyYAML-6.0.2-cp311-cp311-win_amd64.whl.metadata (2.1 kB)
Collecting pydantic>=2.9.2 (from albumentations)
  Downloading pydantic-2.11.3-py3-none-any.whl.metadata (65 kB)
Collecting albucore==0.0.23 (from albumentations)
  Downloading albucore-0.0.23-py3-none-any.whl.metadata (5.3 kB)
Collecting opencv-python-headless>=4.9.0.80 (from albumentations)
  Downloading opencv_python_headless-4.11.0.86-cp37-abi3-win_amd64.whl.metadata (20 kB)
Collecting stringzilla>=3.10.4 (from albucore==0.0.23->albumentations)
  Downloading stringzilla-3.12.3-cp311-cp311-win_amd64.whl.metadata (81 kB)
Collecting simsimd>=5.9.2 (from albucore==0.0.23->albumentations)
  Downloading simsimd-6.2.1-cp311-cp311-win_amd64.whl.metadata (67 kB)
Collecting annotated-types>=0.6.0 (from pydantic>=2.9.2->albumentations)
  Downloading annotated_types-0.7.0-py3-none-any.whl.me


[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python3.exe -m pip install --upgrade pip
