# Installing libraries

In [1]:
# Install required libraries and packages.

! pip install pytorch-lightning==2.1.0
! pip install hydra-core --upgrade
! pip install torcheval



# Download data and model CT

In [2]:
!gdown --id 1udpraFyj0DMWsxFuA5qlXEuKTncmVrfn
!unzip -q data.zip
!rm data.zip

Downloading...
From (original): https://drive.google.com/uc?id=1QE-X4z1xT2_xDszWH47F_5bl4jCaqxVm
From (redirected): https://drive.google.com/uc?id=1QE-X4z1xT2_xDszWH47F_5bl4jCaqxVm&confirm=t&uuid=5e456d70-095c-4226-897f-14dd49ef4952
To: /content/data.zip
100% 95.8M/95.8M [00:01<00:00, 59.4MB/s]


# Model and dataloader

In [3]:
from torch.utils.data import Dataset
from PIL import Image
import os


class SegmentationImageFolder(Dataset):
    def __init__(self, dataset_path, transform):
        self.image_dir = dataset_path + "/data"
        self.mask_dir = dataset_path + "/mask"
        self.transform = transform
        self.images = os.listdir(self.image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])
        image = Image.open(image_path).convert("L")
        mask = Image.open(mask_path).convert("L")
        return self.transform(image), self.transform(mask)



In [4]:
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl


class CTDataset(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        data_dir="/content/data/CT/png",
        train_dir="/train_dir",
        test_dir="/test_dir",
        num_classes=1,
        padding=True,
        image_small=True,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.train_dataset_path = data_dir + train_dir
        self.test_dataset_path = data_dir + test_dir
        self.num_classes = num_classes
        if image_small:
            self.image_size = (256, 256)
        else:
            self.image_size = (512, 512)
        if padding:
            self.transform = transforms.Compose(
                [transforms.Pad((61, 61, 62, 62)), Resize(self.image_size), ToTensor()]
            )  # padding standardowych zdjęć 389x389 do 512x512
        else:
            self.transform_resize = transforms.Compose(
                [Resize(self.image_size), ToTensor()]
            )

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            dataset = SegmentationImageFolder(
                self.train_dataset_path, transform=self.transform
            )
            train_dataset_size = int(len(dataset) * 0.8)
            self.train_dataset, self.val_dataset = random_split(
                dataset, [train_dataset_size, len(dataset) - train_dataset_size]
            )
        if stage == "test" or stage is None:
            self.test_dataset = SegmentationImageFolder(
                self.test_dataset_path, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


In [5]:
# Unet modules

import torch
import torch.nn as nn


# Klasa na podstawie kodu z https://github.com/uygarkurt/UNet-PyTorch
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv_op(x)


# Klasa na podstawie kodu z https://github.com/uygarkurt/UNet-PyTorch
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        down = self.conv(x)
        p = self.pool(down)

        return down, p


# Klasa na podstawie kodu z https://github.com/uygarkurt/UNet-PyTorch
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], 1)
        return self.conv(x)


In [6]:
import torch
import torch.nn as nn


# Kalasa bazująca na teorii przedstawionej w: https://www.youtube.com/watch?v=KOF38xAvo8I&t=574s
class SpatialAttention(nn.Module):
    def __init__(self, in_channels_g, in_channels_x, intermediate_channels):
        super(SpatialAttention, self).__init__()
        self.W_g = nn.Conv2d(
            in_channels_g,
            intermediate_channels,
            kernel_size=1,
            stride=(1, 1),
            padding=0,
            bias=True,
        )
        self.W_x = nn.Conv2d(
            in_channels_x,
            intermediate_channels,
            kernel_size=1,
            stride=(2, 2),
            padding=0,
            bias=True,
        )
        self.relu = nn.ReLU(inplace=True)
        self.psi = nn.Conv2d(
            intermediate_channels, 1, kernel_size=1, stride=(1, 1), padding=0, bias=True
        )
        self.sigmoid = nn.Sigmoid()
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        combined = g1 + x1
        combined = self.relu(combined)
        psi = self.sigmoid(self.psi(combined))
        psi = self.upsample(psi)
        out = x * psi
        return out


In [7]:
# Unet structure witch attention

import torch
import torch.nn as nn


# Klasa częściowo na podstawie kodu z https://github.com/uygarkurt/UNet-PyTorch
class UNetWithAttention(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.down_convolution_1 = DownSample(in_channels, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)

        self.bottle_neck = DoubleConv(512, 1024)

        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

        self.attention_1 = SpatialAttention(1024, 512, 1024)
        self.attention_2 = SpatialAttention(512, 256, 512)
        self.attention_3 = SpatialAttention(256, 128, 256)
        self.attention_4 = SpatialAttention(128, 64, 128)

    def forward(self, x):
        down_1, p1 = self.down_convolution_1(x)
        down_2, p2 = self.down_convolution_2(p1)
        down_3, p3 = self.down_convolution_3(p2)
        down_4, p4 = self.down_convolution_4(p3)

        b = self.bottle_neck(p4)

        up_1 = self.up_convolution_1(b, self.attention_1(b, down_4))
        up_2 = self.up_convolution_2(up_1, self.attention_2(up_1, down_3))
        up_3 = self.up_convolution_3(up_2, self.attention_3(up_2, down_2))
        up_4 = self.up_convolution_4(up_3, self.attention_4(up_3, down_1))

        out = self.out(up_4)
        return out


In [8]:
# Unet structure

import torch
import torch.nn as nn

from unet_parts import DoubleConv, DownSample, UpSample


# Klasa na podstawie kodu z https://github.com/uygarkurt/UNet-PyTorch
class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.down_convolution_1 = DownSample(in_channels, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)

        self.bottle_neck = DoubleConv(512, 1024)

        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        down_1, p1 = self.down_convolution_1(x)
        down_2, p2 = self.down_convolution_2(p1)
        down_3, p3 = self.down_convolution_3(p2)
        down_4, p4 = self.down_convolution_4(p3)

        b = self.bottle_neck(p4)

        up_1 = self.up_convolution_1(b, down_4)
        up_2 = self.up_convolution_2(up_1, down_3)
        up_3 = self.up_convolution_3(up_2, down_2)
        up_4 = self.up_convolution_4(up_3, down_1)

        out = self.out(up_4)
        return out


In [10]:
import torch
from torch import nn, optim
from torcheval.metrics.functional import binary_f1_score
from rich.console import Console

console = Console()


class UNetModel(pl.LightningModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        learning_rate=1e-3,
        scheduler_step_size=8,
        scheduler_gamma=0.5,
        pos_weight=2,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.out_channels = out_channels
        self.scheduler_step_size = scheduler_step_size
        self.scheduler_gamma = scheduler_gamma
        self.pos_weight = pos_weight

        self.model = UNetWithAttention(in_channels, out_channels)

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, x, y):
        pos_weight = torch.tensor([self.pos_weight]).cuda()
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        return criterion(x, y)

    def common_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = self.compute_loss(outputs, y)
        return loss, outputs, y

    def common_test_valid_step(self, batch, batch_idx):
        loss, outputs, y = self.common_step(batch, batch_idx)
        acc = binary_f1_score(outputs.view(-1), y.view(-1), threshold=0.5)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self.common_test_valid_step(batch, batch_idx)
        self.log(
            "train_loss", loss, prog_bar=True, on_step=True, on_epoch=True, logger=True
        )
        self.log(
            "train_acc", acc, prog_bar=True, on_step=True, on_epoch=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.common_test_valid_step(batch, batch_idx)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        console.print(
            f"Validation [bold cyan]Loss (batch {batch_idx}): {loss:.4f}[bold cyan]"
        )
        console.print(
            f"Validation [bold green]Accuracy (batch {batch_idx}): {acc:.4f}[/bold green]"
        )
        return loss

    def test_step(self, batch, batch_idx):
        loss, acc = self.common_test_valid_step(batch, batch_idx)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        console.print(
            f"Test [bold cyan]Loss (batch {batch_idx}): {loss:.4f}[bold cyan]"
        )
        console.print(
            f"Test [bold green]Accuracy (batch {batch_idx}): {acc:.4f}[/bold green]"
        )
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=self.scheduler_step_size, gamma=self.scheduler_gamma
        )
        return [optimizer], [lr_scheduler]
    

#Training

In [11]:
import torch
import gc

# czyszczenie pamięci gpu

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [12]:
if __name__ == "__main__":
    dm = CT_dataset(batch_size=16)
    dm.setup()

    model = UNetModel(1, dm.num_classes, learning_rate=1e-4)

In [13]:
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [14]:
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb.init(project="ct_model")
wandb_logger = WandbLogger()

[34m[1mwandb[0m: Currently logged in as: [33mlukasstan[0m ([33mlukasstan-warsaw-university-of-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [15]:
from pytorch_lightning.loggers import TensorBoardLogger

tensorboard_logger = TensorBoardLogger("lightning_logs", name="ct_model")

In [16]:
trainer = pl.Trainer(
    max_epochs=20,
    check_val_every_n_epoch=2,
    log_every_n_steps=20,
    logger=[wandb_logger, tensorboard_logger],
    accelerator="gpu",
    precision="16"
)

trainer.fit(model=model, datamodule=dm)

/usr/local/lib/python3.11/dist-packages/lightning_fabric/connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/plugins/precision/amp.py:54: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
Downloading...
From: https://drive.google.com/1FWGQi6W8d7hWWoMkykhbqY_GUFKlltWp
To: /content
1.65kB [00:00, 1.41MB/s]
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loggers/wan

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
model_save_path = "/content/unet.pth"
torch.save(model, model_save_path)

In [None]:
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/ --host=127.0.0.1 --port=6006 --load_fast=false

In [None]:
from google.colab import output
output.serve_kernel_port_as_window(6006, path="")

# Pruning

In [None]:
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt


def apply_pruning(model, amount=0.2):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name="weight", amount=amount)
            # Opcjonalnie usuń maskę po przycięciu
            prune.remove(module, "weight")

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
import matplotlib.pyplot as plt


def plot_weight_distribution(model, title):

    weights = torch.tensor([], device="cuda")

    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            module_weights = module.weight.to("cuda").flatten()
            weights = torch.cat((weights, module_weights))

    weights_cpu = weights.detach().cpu().numpy()

    plt.hist(weights_cpu, bins=50)
    plt.title(title)
    plt.xlabel("Wartość wagi")
    plt.ylabel("Liczba wag")
    plt.xlim(-0.1, 0.1)
    plt.show()

In [None]:
def count_zero_weights(model):
    zero_weights = 0
    total_weights = 0

    for module in model.modules():
        if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor):
            zero_weights += torch.sum(module.weight == 0).item()
            total_weights += module.weight.numel()

    zero_percentage = (zero_weights / total_weights) * 100 if total_weights > 0 else 0
    return zero_weights, total_weights, zero_percentage

In [None]:
dm = CT_dataset(batch_size=8)
dm.setup()
model = UNetModel(1, dm.num_classes, learning_rate=1e-3)

model = torch.load("unet.pth")

In [None]:
# Przed pruningiem
print("Liczba parametrów przed pruningiem:", count_parameters(model))
plot_weight_distribution(model, "Rozkład wag przed pruningiem")
zero_weights, total_weights, zero_percentage = count_zero_weights(model)
print(f"Liczba wag równych zero: {zero_weights}")
print(f"Całkowita liczba wag: {total_weights}")
print(f"Procent wag równych zero: {zero_percentage:.2f}%")

# Po pruningu
apply_pruning(model, amount=0.2)  # Zastosowanie pruning
print("Liczba parametrów po pruningu:", count_parameters(model))
plot_weight_distribution(model, "Rozkład wag po pruningu")
zero_weights, total_weights, zero_percentage = count_zero_weights(model)
print(f"Liczba wag równych zero: {zero_weights}")
print(f"Całkowita liczba wag: {total_weights}")
print(f"Procent wag równych zero: {zero_percentage:.2f}%")

In [None]:
model_save_path = "/content/unet.pth"
torch.save(model, model_save_path)

# Testing

In [None]:
import torch
import gc

# czyszczenie pamięci gpu

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
dm = CT_dataset(batch_size=8)
dm.setup()
model = UNetModel(1, dm.num_classes, learning_rate=1e-3)

model = torch.load("unet.pth")


logger = TensorBoardLogger("lightning_logs", name="ct_model")
trainer = pl.Trainer(logger=logger)

trainer.test(model=model, datamodule=dm)

In [None]:
# Results presentation (inference)

import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import numpy as np

# from carvana_dataset import CarvanaDataset
# from unet import UNet


def pred_show_image_grid(data_path, model_pth, device):
    model = torch.load(model_pth)
    model = model.to(device)
    image_dataset = CT_dataset(8)
    image_dataset.setup()

    images = [[], [], []]

    for i, (image, mask) in enumerate(image_dataset.test_dataloader()):
        image = image.to(device)
        mask = mask.to(device)
        images[0].append(image)
        images[1].append(mask)
        loss, outputs, y = model.common_step((image, mask), i)
        images[2].append(outputs)
        break

    flattened_images = []

    for image_type in images:
        for image in image_type[0]:
            flattened_images.append(image)

    fig, axes = plt.subplots(3, 8, figsize=(16, 6))
    for i, ax in enumerate(axes.flat):
        if i < len(flattened_images):
            tensor_image = flattened_images[i]

            if isinstance(tensor_image, torch.Tensor):
                tensor_image = tensor_image.cpu().detach().numpy()

            if i > 15:
                tensor_image = np.where(tensor_image < 0.5, 0.0, 1.0)

            ax.imshow(tensor_image[0], cmap="gray")
            ax.axis("off")
        else:
            ax.axis("off")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    SINGLE_IMG_PATH = "/content/data_ct/manual_test/226.png"
    DATA_PATH = "/content/data_ct"
    MODEL_PATH = "/content/unet.pth"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred_show_image_grid(DATA_PATH, MODEL_PATH, device)