In [4]:
# import sys
# sys.path.insert(1, '/home/john/Desktop/EMSNET/modules/model/')
# import emsnet_preprocessing

In [117]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
# from skimage import io, transform
import numpy as np
import os
import matplotlib.pyplot as plt
from skimage.transform import resize
from glob import glob
from tqdm import tqdm


import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

import albumentations as A
from albumentations.augmentations.dropout.grid_dropout import GridDropout
import pickle

import torchvision

import cv2

from imblearn.over_sampling import RandomOverSampler

In [172]:
def remove_borders(tensor):
    """
    Removes black edges from a PyTorch tensor.

    Args:
        tensor (torch.Tensor): A PyTorch tensor of shape (channel, height, width).

    Returns:
        torch.Tensor: A tensor with black edges removed.
    """
    # Compute the sum of the pixel values in each row and 
    print(tensor.shape)
    row_sums = np.sum(tensor, axis=2)
    col_sums = np.sum(tensor, axis=1)

    # Find the first and last non-zero row and column indices
    first_row = np.min(np.nonzero(row_sums))
    last_row = np.max(np.nonzero(row_sums))
    first_col = np.min(np.nonzero(col_sums).min().item()
    last_col = np.nonzero(col_sums).max().item()

    # Create a new tensor with the new size
    new_tensor = tensor[:, first_row:last_row+1, first_col:last_col+1]

    return new_tensor

In [173]:
def RetinaTransform(image):
    # resizing image because original does not fit in memory.
    image = image.transpose([2,0,1])
    image = remove_borders(image)
    image = resize(image, (64,64),anti_aliasing=False)
    #Fixing dtype to avoid runtime error and save memory
    image = torch.tensor(image ,dtype=torch.float32)
    
    return image

In [174]:
# label_frame.sum().sort_values(ascending=False)[2:]

In [175]:
# Load dataset into memory
data_folder = "../Evaluation_Set/Validation"
label_path = "../Evaluation_Set/RFMiD_Validation_Labels.csv"
label_frame = pd.read_csv(label_path)
local_transform = RetinaTransform

for image_name in glob(data_folder+"/*"):
    # image = io.imread(image_name)

    label_frame_index = int(image_name.split("/")[-1].split(".")[0])
    label_frame.at[label_frame_index-1, "image_path"] = image_name
    
X_train, y_train = label_frame["image_path"].values.reshape(-1,1),label_frame.pop("Disease_Risk")
ros = RandomOverSampler(random_state=0)
X_resampled, y_resampled = ros.fit_resample(X_train, y_train)

In [176]:
image_array = []
for image,label in tqdm(zip(X_resampled.flatten(),y_resampled)):
    img = cv2.imread(image)
    img = local_transform(img)
    plt.imshow(img.permute(1,2,0).cpu())
    image_array.append((img,label))
    
with open("image_array.pkl", "wb") as f:
        pickle.dump(image_array,f)


0it [00:00, ?it/s][A

(3, 1424, 2144)
0 [[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]





AttributeError: 'tuple' object has no attribute 'min'

In [177]:
# with open("image_array.pkl", "rb") as f:
#     image_array = pickle.load(f)

In [103]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [104]:
def SuperConv2d(x, supr_conv_arch):
    return torch.cat([conf(x).flatten(start_dim=1) for conf in supr_conv_arch],dim=1)

In [110]:
class EMSNETDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.simple_conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1,padding=1)
        # self.super_conv_conf = [[3,1,3,1,1]]
        # self.super_conv_arch = [nn.Conv2d(*conf).to(device) for conf in self.super_conv_conf]
        # print(self.super_conv_arch )
        
        
        self.lrelu = nn.LeakyReLU(0.2)
        self.linear_decoder = nn.Sequential(
            nn.Linear(64*64*3,64*64),
            self.lrelu,
            nn.Linear(64*64,2048),
            self.lrelu,
            nn.Linear(2048,1024),
            self.lrelu,
            nn.Linear(1024,256),
            self.lrelu,
            nn.Linear(256,32),
            self.lrelu,
            nn.Linear(32,1),
        )
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.dropout= nn.Dropout(p=0.8)
        
    def forward(self, x):
        x = self.simple_conv(x)
        x = x.flatten(start_dim=1)
        x = self.lrelu(x)
        x = self.linear_decoder(x)
        x = self.sigmoid(x)
        return x

In [111]:
class EMSNETGAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.discriminator = EMSNETDiscriminator()
        
        # scripted_transforms = torch.jit.script(transforms)
        
    def forward(self, img):
        return self.discriminator(img)

    def bce_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def training_step(self, batch):
        imgs, labels = batch
        imgs = imgs / 255

        # plt.imshow(imgs[0].cpu().permute(1,2,0)*255)

        output = self.forward(imgs)
        loss = bce_loss(output, labels)
        return loss
        
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
        return optimizer


In [112]:
class RetinaDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, data_indices, transform=None):
        
        self.data = [image_array[idx] for idx in data_indices]
        self.augmenter = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # A.RandomBrightnessContrast(p=0.3,brightness_limit = (-0.2,0.2),contrast_limit = (-0.2,0.2)),
            # A.MedianBlur(p=0.3,always_apply=False,blur_limit=5),
            # A.IAAAdditiveGaussianNoise(p=0.5,scale=(0,0.15*255)),
            # A.HueSaturationValue(hue_shift_limit=10,sat_shift_limit=10,val_shift_limit=10,p=0.3),
            # A.Cutout(p=0.5,max_h_size=20,max_w_size=20,num_holes=5)
        ])
            
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
        # return (torch.tensor(self.augmenter(image=self.data[idx][0].numpy())["image"]),self.data[idx][1])

In [113]:
train_dataset = RetinaDataset(data_indices=np.arange(0,len(image_array)))
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False, num_workers=8)

In [114]:
emsnet = EMSNETGAN()

In [115]:
trainer = pl.Trainer(max_epochs=10, accelerator='gpu')

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [116]:
trainer.fit(model=emsnet, train_dataloaders=train_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                | Params
------------------------------------------------------
0 | discriminator | EMSNETDiscriminator | 61.1 M
------------------------------------------------------
61.1 M    Trainable params
0         Non-trainable params
61.1 M    Total params
244.381   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x8832 and 12288x4096)

In [None]:
3*64*64