In [1]:
!pip install wget
!pip install torchmetrics
!pip install tqdm



In [2]:
import torch
from torch import nn
from torch.utils import data
from PIL import Image
import os
import wget
import zipfile
import torchvision
from tqdm.notebook import tqdm as tq

In [3]:
# Remote locations of the sample data / pretrained weights and the local paths
# they are downloaded to.
# NOTE(review): the 'test_data' entry is not read by any cell visible here.
config = dict(
    stage1=dict(
        dest_path_data="test_data/phase1.zip",
        url_data="http://kliv.iitkgp.ac.in/projects/miriad/sample_data/bmi34/phase1/phase1.zip",
        url_model="http://kliv.iitkgp.ac.in/projects/miriad/model_weights/bmi34/cbis_a1_b1.zip",
        dest_path_model="model_weights/cbis_a1_b1.zip",
    ),
    test_data=dict(
        url="http://kliv.iitkgp.ac.in/projects/miriad/sample_data/rbis_ddsm_sample.zip",
        dest_path="test_data/rbis_ddsm_sample.zip",
    ),
)

In [4]:

def download_and_extract(path, url, expath):
    """Fetch a zip archive (unless it is already on disk) and unpack it.

    Args:
        path: local file path for the archive; its parent directory is created
            if needed.
        url: URL to download from when ``path`` does not exist yet.
        expath: directory the archive contents are extracted into.

    Returns:
        The path of the archive that was extracted.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not os.path.exists(path):
        # wget.download returns the filename it actually saved to; use it rather
        # than assuming it wrote exactly `path`.
        path = wget.download(url, path)
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(expath)
    return path

def _download_stage1_asset(config, url_key, path_key, dest_dir):
    """Download and unpack one ``config['stage1']`` archive into ``dest_dir``.

    Skipped when ``dest_dir`` already exists and is non-empty. Checking for
    content (not just for the directory) means a run that created the directory
    but was interrupted before anything landed in it is retried next time.
    """
    if os.path.isdir(dest_dir) and os.listdir(dest_dir):
        return
    os.makedirs(dest_dir, exist_ok=True)
    download_and_extract(path=config['stage1'][path_key],
                         url=config['stage1'][url_key],
                         expath=dest_dir + '/')


def download_data(config):
    """Fetch the stage-1 data archive into ``test_data/``.

    Here "test_data" corresponds to sample/dummy data used to test the code.
    """
    _download_stage1_asset(config, 'url_data', 'dest_path_data', 'test_data')


def download_model(config):
    """Fetch the stage-1 pretrained-weights archive into ``model_weights/``."""
    _download_stage1_asset(config, 'url_model', 'dest_path_model', 'model_weights')
# Fetch the sample data first, then the pretrained weights.
for fetch in (download_data, download_model):
    fetch(config)

In [5]:
# Data locations and hyper-parameters for phase 1.
# NOTE(review): only the 'test_data' path is read by the cells visible here;
# 'batch_size' and 'lr' look like training settings - confirm where they are used.
model_config = dict(
    train=dict(
        train_data="test_data/phase1/train/",
        test_data="test_data/phase1/test/",
        batch_size=10,
        lr=1e-8,
    ),
)

In [6]:
class CustomDatasetPhase1(data.Dataset):
    """Phase-1 image dataset.

    With ``split='train'``, item ``i`` is ``[image256, image128]`` loaded from
    ``files256[i]`` / ``files128[i]`` (paths relative to ``path_to_dataset``).
    Otherwise every regular, non-hidden file directly inside
    ``path_to_dataset`` is a sample, and item ``i`` is
    ``[image, 0, file_name]`` (a constant 0 occupies the label slot).

    Args:
        path_to_dataset: dataset root directory.
        files256, files128: paired file lists, required when ``split='train'``.
        split: ``'train'`` for paired mode; anything else for single-image mode.
        transform_images: optional callable applied to every loaded image.
        transform_masks: optional mask transform; stored but not applied here.
        images_path_rel: images sub-directory (relative to the root) whose
            listing is kept in ``self.cropimages``.
        masks_path_rel: masks sub-directory; stored but not used here.
        preserve_names: debugging flag; stored but not used here.
    """

    def __init__(self, path_to_dataset, files256=None, files128=None, split=None,
                 transform_images=None, transform_masks=None,
                 images_path_rel='.', masks_path_rel='.',
                 preserve_names=False):
        self.path_to_dataset = os.path.abspath(path_to_dataset)
        self.files256 = files256
        self.files128 = files128
        self.images_path_rel = images_path_rel
        self.masks_path_rel = masks_path_rel
        self.transform_images = transform_images
        self.transform_masks = transform_masks
        self.preserve_names = preserve_names
        self.split = split
        # Sorted so the sample order is deterministic (os.listdir order is
        # arbitrary), and restricted to regular non-hidden files so stray
        # sub-directories or dotfiles (e.g. .DS_Store) never reach Image.open.
        self.test_files = sorted(
            name for name in os.listdir(self.path_to_dataset)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(self.path_to_dataset, name))
        )
        # Listing of the images sub-directory (not used by __getitem__).
        self.cropimages = os.listdir(os.path.join(self.path_to_dataset, self.images_path_rel))

    def __len__(self):
        if self.split == 'train':
            return min(len(self.files256), len(self.files128))
        # Same list that __getitem__ indexes, so length and indexing agree.
        return len(self.test_files)

    def __getitem__(self, i):
        if self.split == 'train':
            return [self._load(self.files256[i]), self._load(self.files128[i])]
        file_name = self.test_files[i]
        return [self._load(file_name), 0, file_name]

    def _load(self, rel_path):
        """Open ``rel_path`` under the dataset root and apply ``transform_images``."""
        image = Image.open(os.path.join(self.path_to_dataset, rel_path))
        if self.transform_images is not None:
            image = self.transform_images(image)
        return image

In [14]:
batch_size = 64  # NOTE(review): not read by any cell shown here; the test loader below uses batch_size=1
images_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((1024, 1024)),
])

# CBIS-DDSM sample dataset & dataloader for inference (one image per batch).
test_dataset = CustomDatasetPhase1(model_config['train']['test_data'],
                                   transform_images=images_transforms)
test_dataloader = data.DataLoader(test_dataset,
                                  batch_size=1, num_workers=4,
                                  # pinned memory only helps CUDA copies; DataLoader warns without an accelerator
                                  pin_memory=torch.cuda.is_available(),
                                  shuffle=False)
print(len(test_dataloader))

# Sanity-check one batch; this (non-train) dataset yields [images, labels, file_names].
first_batch = next(iter(test_dataloader), None)
if first_batch is None:
    print("test_dataloader yielded no batches")
else:
    print(f"Shape of X [N, C, H, W]: {first_batch[0].shape}")
    print(f"Shape of y [N]: {first_batch[1].shape}")

3




Shape of X [N, C, H, W]: torch.Size([1, 1, 1024, 1024])
Shape of X [N, C, H, W]: torch.Size([1])




In [15]:
# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# Define model
class Encoder(nn.Module):
    """Convolutional encoder: 1-channel image -> 16-channel latent clamped to [0, 1].

    Layer layout (kept stable because checkpoint keys are ``encoder.<index>.*``):
    a stride-1 conv + ReLU and a stride-2 conv + ReLU, then ``n_downconv`` more
    such pairs, then a final stride-1 conv to 16 channels with no activation.
    All convs are 3x3 with padding 1 and ``n_encowidth`` hidden channels.
    """

    def __init__(self, n_downconv=3, n_encowidth=64):
        super().__init__()
        self.n_downconv = n_downconv
        self.n_encowidth = n_encowidth
        width = self.n_encowidth

        def conv3x3(c_in, c_out, stride):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=3, stride=stride, padding=1)

        # Mandatory stem: lift the single input channel, then downsample once.
        layers = [conv3x3(1, width, 1), nn.ReLU(), conv3x3(width, width, 2), nn.ReLU()]
        # Tunable number of DownConv blocks.
        for _ in range(self.n_downconv):
            layers += [conv3x3(width, width, 1), nn.ReLU(),
                       conv3x3(width, width, 2), nn.ReLU()]
        # Mandatory projection to the 16-channel latent.
        layers.append(conv3x3(width, 16, 1))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` and clamp the latent into [0, 1]."""
        latent = self.encoder(x)
        return latent.clamp(0, 1)

class Decoder(nn.Module):
    """Convolutional decoder: 16-channel latent -> 1-channel image clamped to [0, 1].

    Layer layout (kept stable because checkpoint keys are ``decoder.<index>.*``):
    a 3x3 conv from 16 to ``n_decowidth`` channels + ReLU, then ``n_upconv``
    UpConv blocks (conv to 4*width, ReLU, PixelShuffle(2)), then a conv to 4
    channels followed by PixelShuffle(2). Each PixelShuffle(2) doubles H and W.
    """

    def __init__(self, n_upconv=3, n_decowidth=96):
        super().__init__()
        # a tunable number of UpConv blocks in the architecture
        self.n_upconv = n_upconv
        self.n_decowidth = n_decowidth

        def conv3x3(c_in, c_out):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=3, stride=1, padding=1)

        # Mandatory initial layer: latent channels -> decoder width.
        layers = [conv3x3(16, n_decowidth), nn.ReLU()]
        # 'n_upconv' UpConv blocks (the CVPR paper used 3).
        for _ in range(self.n_upconv):
            layers += [conv3x3(n_decowidth, n_decowidth * 4), nn.ReLU(), nn.PixelShuffle(2)]
        # Mandatory final layer: 4 channels shuffled into one channel at 2x resolution.
        layers += [conv3x3(n_decowidth, 1 * 4), nn.PixelShuffle(2)]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` and clamp the reconstruction into [0, 1]."""
        reconstruction = self.decoder(x)
        return reconstruction.clamp(0, 1)

class AutoEncoder(nn.Module):
    """Encoder/Decoder pair sharing one down/up-sampling depth and one width.

    ``forward`` records the input and latent shapes (as lists) on the instance
    as ``shape_input`` and ``shape_latent``.
    """

    def __init__(self, n_updownconv=3, width=64):
        super().__init__()
        self.n_updownconv = n_updownconv
        self.width = width
        # The encoder and decoder must use the same number of resampling blocks,
        # so both are driven by n_updownconv.
        self.encoder = Encoder(n_downconv=n_updownconv, n_encowidth=width)
        self.decoder = Decoder(n_upconv=n_updownconv, n_decowidth=width)

    def forward(self, x):
        """Encode then decode ``x``, remembering the intermediate shapes."""
        self.shape_input = list(x.shape)
        latent = self.encoder(x)
        self.shape_latent = list(latent.shape)
        return self.decoder(latent)

# Place the model on the device chosen above (CUDA, MPS or CPU) rather than
# only moving it when CUDA is available.
model = AutoEncoder(n_updownconv=3, width=64).to(device)


Using cuda device


In [16]:
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio

def compare_psnr_batch(original, compressed, **kwargs):
    """PSNR between a reference batch and its reconstruction.

    Args:
        original: reference images, Batch x Channel x Height x Width.
        compressed: reconstructed images with the same shape.
        **kwargs: forwarded to ``PeakSignalNoiseRatio`` (e.g. ``data_range=1.0``);
            with none given the metric's own defaults are used.

    Returns:
        The torchmetrics PSNR value for this batch, on ``original``'s device.

    Raises:
        AssertionError: if the shapes differ or the inputs are not 4-D.
    """
    assert original.shape == compressed.shape, 'shapes should be same'
    assert len(original.shape) == 4, 'expected Batch x Channel x Height x Width'

    # Build the metric on the inputs' device instead of hard-coding CUDA, so
    # this also works on CPU / MPS.
    psnr = PeakSignalNoiseRatio(**kwargs).to(original.device)
    # torchmetrics takes (preds, target).
    return psnr(compressed, original)

def compare_ssim_batch(original, compressed, **kwargs):
    """SSIM between a reference batch and its reconstruction.

    Args:
        original: reference images, Batch x Channel x Height x Width.
        compressed: reconstructed images with the same shape.
        **kwargs: forwarded to ``StructuralSimilarityIndexMeasure``
            (e.g. ``data_range=1.0``); with none given the metric's own
            defaults are used.

    Returns:
        The torchmetrics SSIM value for this batch, on ``original``'s device.

    Raises:
        AssertionError: if the shapes differ or the inputs are not 4-D.
    """
    assert original.shape == compressed.shape, 'shapes should be same'
    assert len(original.shape) == 4, 'expected Batch x Channel x Height x Width'

    # Build the metric on the inputs' device instead of hard-coding CUDA.
    ssim = StructuralSimilarityIndexMeasure(**kwargs).to(original.device)
    # torchmetrics takes (preds, target): pass the reconstruction as preds,
    # matching the argument order used for PSNR.
    return ssim(compressed, original)

In [17]:
def validate_model_phase1(config, test_dataloader, model):
    """Average SSIM and PSNR of ``model``'s reconstructions over a dataloader.

    Args:
        config: not used; kept so existing call sites keep working.
        test_dataloader: yields batches whose element 0 is the image tensor.
        model: module mapping an image batch to a same-shaped reconstruction.

    Returns:
        ``(avg_ssim, avg_psnr)``: unweighted means of the per-batch metrics,
        or ``(0, 0)`` when the dataloader yields no batches.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    ssim_total, psnr_total, n_batches = 0, 0, 0
    try:
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for data_list in test_dataloader:
                images = data_list[0].to(model_device)
                output = model(images)
                ssim_total = ssim_total + compare_ssim_batch(images, output)
                psnr_total = psnr_total + compare_psnr_batch(images, output)
                n_batches += 1
    finally:
        model.train(was_training)

    if n_batches == 0:
        return 0, 0
    # Plain sum / count: the previous running mean started from n=1 with a zero
    # average, which mixed a phantom 0 into the result (one batch gave x/2).
    return ssim_total / n_batches, psnr_total / n_batches

In [18]:
# Rebuild the architecture and load the pretrained weights onto the selected device
# (map_location lets a GPU-saved checkpoint load on a CPU/MPS-only machine).
model = AutoEncoder(n_updownconv=3, width=64).to(device)
# NOTE(review): torch.load unpickles the file, and this checkpoint was fetched
# over plain HTTP - only load it from a source you trust.
saved_dict = torch.load("model_weights/cbis_a1_b1.model", map_location=device)
model.eval()
model.load_state_dict(saved_dict['model_state'])

<All keys matched successfully>

In [19]:
# validate_model_phase1 returns (avg_ssim, avg_psnr) in that order.
avg_ssim, avg_psnr = validate_model_phase1(config, test_dataloader, model)
print(f"SSIM: {float(avg_ssim):.4f}  PSNR: {float(avg_psnr):.2f}")

