In [1]:
"""
    This algorithm process to the auto-encoding of the images
"""

'\n    This algorithm process to the auto-encoding of the images\n'

In [6]:
from torch.utils.data import DataLoader
from utils.dataset import imageDataset
import pandas as pd

In [3]:
# Display the imported class (presumably a sanity check that the import resolved)
imageDataset

utils.dataset.imageDataset

In [4]:
def get_data(data_folder="./data"):
    """
        Load the label and metadata tables and build the image dataset.

        Parameters
        ----------
        data_folder: str, folder containing the data. It must hold
            `mimic-cxr-2.0.0-chexpert.csv.gz` and `mimic-cxr-2.0.0-metadata.csv.gz`;
            image paths are built under `<data_folder>/./files/`.

        Return
        ------
        - tuple containing :
            > an image dataset object (imageDataset built from a {uid: path} dict)
            > a pandas dataframe of metadata, with a `uid` column
        - dataframe of labels, joined to the metadata uids through `study_id`,
          sorted by uid, with an extra `text_label` column
    """

    # Creating an UID
    df_expert = pd.read_csv(f"{data_folder}/mimic-cxr-2.0.0-chexpert.csv.gz")
    df_metadata = pd.read_csv(f"{data_folder}/mimic-cxr-2.0.0-metadata.csv.gz")

    # The row position of each metadata line becomes its uid
    df_metadata = df_metadata.reset_index() \
        .rename(columns = {"index":"uid"}) \
        .drop(columns = ["Unnamed: 0"])

    # Attach the uid(s) of every study to its labels, then order the labels by uid
    df_expert = df_expert \
        .drop(columns = ["Unnamed: 0"]) \
        .join(df_metadata[["study_id","uid"]].set_index("study_id"), on = "study_id") \
        .set_index("uid") \
        .sort_index() \
        .reset_index(drop=False)

    # Comma-separated names of the label columns equal to 1 (NaN never equals 1,
    # so missing labels are skipped without an explicit dropna).
    # Columns 0-2 are uid, subject_id and study_id; the labels start at column 3.
    df_expert["text_label"] = df_expert.iloc[:,3:].apply(
        lambda row: ",".join(row[row == 1].index.tolist())
    , axis=1)

    # Getting image loader

    ## Getting image path dict
    images_paths = df_metadata[["uid", "subject_id", "dicom_id", "study_id"]] \
                    .reset_index(drop=True)

    images_paths["subject_id_str"] = images_paths["subject_id"].astype("str")
    images_paths["study_id_str"] = images_paths["study_id"].astype("str")

    # <data_folder>/./files/p<first 2 chars of subject>/p<subject>/s<study>/<dicom>.jpg
    # A single binary `+` between every piece: a doubled `+` would apply a unary
    # plus to a string Series, which raises on current NumPy/pandas.
    images_paths["path"] = (
        f"{data_folder}/./files/p"
        + images_paths["subject_id_str"].str.slice(0,2)
        + "/p" + images_paths["subject_id_str"]
        + "/s" + images_paths["study_id_str"]
        + "/" + images_paths["dicom_id"]
        + ".jpg"
    )

    images_paths = images_paths[["uid","path"]].set_index("uid")["path"] \
            .to_dict()

    ## Loading images
    images_dataset = imageDataset(images_paths)

    return (images_dataset, df_metadata), df_expert

In [7]:
# Load the data: X is the (image dataset, metadata dataframe) pair, y the labels dataframe
X,y = get_data("./data")

In [8]:
# Train / test partition: the image dataset is split first, then the metadata
# and label tables are re-ordered to follow the image ids of each part.
X0_train, X0_test = X[0].split(p=0.7, random_seed=42)

train_id, test_id = X0_train.image_list, X0_test.image_list

# Metadata rows keep their uid as a regular column ...
X1_train, X1_test = (
    X[1].set_index("uid").loc[ids, :].reset_index()
    for ids in (train_id, test_id)
)

# ... whereas the label tables drop it.
y_train, y_test = (
    y.set_index("uid").loc[ids, :].reset_index(drop=True)
    for ids in (train_id, test_id)
)

In [119]:
from torchvision.transforms import RandomResizedCrop, Resize
import torch
import numpy as np

In [121]:
def dataset_collater (x):
    """
        Collate a batch of dataset samples into a single tensor of images.

        Each image is converted to float32, resized with `Resize(512)`, then
        passed through `RandomResizedCrop(size=(512,512))` so that every
        element of the batch can be stacked together.

        Parameters
        ----------
        x: list of samples, only the first element of each sample (the image) is used

        Return
        ------
        float32 tensor stacking the transformed images along a new first axis
        (shape (len(x), 1, 512, 512) assuming each image is a 2-D array -- TODO confirm)
    """

    resizer = Resize(512)
    randomcrop = RandomResizedCrop(size=(512,512))

    # Applying randomCrop
    images_tensor = []

    for x_ in x:
        # Leading axis added before the transforms (presumably the channel axis)
        image_tensor = torch.tensor(x_[0], dtype=torch.float32).unsqueeze(dim=0)
        image_tensor = resizer(image_tensor)
        image_tensor = randomcrop(image_tensor)

        images_tensor.append(image_tensor)

    return torch.stack(images_tensor)

In [215]:
from torch import nn
from torch import optim
from torchvision.models import mobilenet_v3_small

In [219]:
class autoEncoder (nn.Module):
    """
        Image auto-encoder: a convolutional encoder built around the feature
        extractor returned by `mobilenet_v3_small`, and a decoder made of
        convolution / upsampling stages ending on a 512x512 single-channel map.

        The module owns its loss (MSE) and its optimizer (Adam over all of its
        parameters), so one training step is a single call to `fit`.
    """

    def __init__ (self):
        super().__init__()

        self.mobilenet = mobilenet_v3_small(pretrained=True)

        # 1x1 convolution turning the single input channel into the 3 channels
        # fed to the backbone, then global pooling and flattening.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(1,1), padding="same"),
            self.mobilenet.features,
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
        )

        # The code is reshaped into a 1x24x24 map, i.e. 576 values per sample
        # (presumably the channel count produced by the backbone -- TODO confirm).
        decoder_layers = [nn.Unflatten(1, (1, 24,24))]

        # (input channels, output channels, side of the upsampled map) per stage
        stages = ((1, 16, 48), (16, 32, 96), (32, 64, 256), (64, 128, 512))
        for channels_in, channels_out, side in stages:
            decoder_layers.extend([
                nn.Conv2d(channels_in, channels_out, kernel_size=(3,3), padding="same"),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(),
                nn.Upsample((side, side)),
            ])

        # Final projection back to one channel
        decoder_layers.append(nn.Conv2d(128, 1, kernel_size=(3,3), padding="same"))

        self.decoder = nn.Sequential(*decoder_layers)

        self.loss = nn.MSELoss()
        self.optim = optim.Adam(self.parameters())

    def forward (self, x):
        """Encode `x`: returns the output of the encoder only."""
        return self.encoder(x)

    def fullpass (self, x):
        """Encode then decode `x`: returns the reconstruction."""
        return self.decoder(self(x))

    def fit (self, x):
        """
            Run one optimisation step on the batch `x`.

            Return
            ------
            float, MSE between the reconstruction and `x` for this batch
        """

        self.train()
        self.optim.zero_grad()

        batch_loss = self.loss(self.fullpass(x), x)
        batch_loss.backward()
        self.optim.step()

        return batch_loss.item()

In [220]:
# Batches of 12 training samples, turned into image tensors by dataset_collater.
# NOTE(review): `shuffle` is not set, so the loader relies on DataLoader's default ordering -- confirm this is intended for training
dataloader = DataLoader(X0_train, collate_fn=dataset_collater, batch_size=12)

In [221]:
# Instantiate the auto-encoder (its __init__ also creates the MSE loss and the Adam optimizer)
ae = autoEncoder()

In [222]:
n_epoch = 100

# Train on the first CUDA device when one is available; otherwise fall back to
# the CPU instead of crashing on `.to("cuda:0")`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

ae = ae.to(device)

for epoch in range(n_epoch):
    for x in dataloader:
        # The batch must live on the same device as the model
        x = x.to(device)
        loss = ae.fit(x)

        print(f"Loss value : {loss}")