In [128]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import numpy as np
import shutil
import os
from PIL import Image
import pprint
import gc

### Load Data:

In [129]:
# --- Locations: everything is resolved relative to the current working directory.
code_dir = os.getcwd()

data_dir = os.path.join(code_dir, "Data")
class_file_path = os.path.join(code_dir, "train_data.csv")

# --- Binary label encoding used throughout the notebook.
PNEUMOTHORAX = 1
NOT_PNEUMOTHORAX = 0
classes = {PNEUMOTHORAX, NOT_PNEUMOTHORAX}
num_classes = len(classes)

classification_file = pd.read_csv(class_file_path, header=0)

# A CSV written with its row index comes back with "Unnamed: 0" (and
# "Unnamed: 0.1", ... after repeated round trips). Drop whichever of those
# columns is actually present: asking drop() for a fixed list raises KeyError
# as soon as one of the listed columns is missing.
leftover_index_columns = [
    column for column in classification_file.columns
    if str(column).startswith("Unnamed:")
]
classification_file = classification_file.drop(columns=leftover_index_columns)

print(classification_file.columns.values.tolist())

# Plain numpy views of the two columns the rest of the notebook works with.
images_paths = np.asarray(classification_file["file_name"])
images_targets = np.asarray(classification_file["target"])

['file_name', 'target']


### Calculate the percentage of each class in the data:

In [142]:
def show_class_percentage(targets, mult):
    """Print the share of each class after weighting the negative class.

    Args:
        targets: Iterable of class labels. Entries are compared against the
            module-level ``PNEUMOTHORAX`` and ``NOT_PNEUMOTHORAX`` constants;
            any other value is ignored.
        mult: Factor applied to the ``NOT_PNEUMOTHORAX`` count, used to preview
            the balance if every negative sample appeared ``mult`` times.

    Returns:
        None. The two percentages are printed to stdout.
    """
    # Count occurrences instead of summing the label values, so the result
    # does not silently depend on PNEUMOTHORAX being encoded as 1.
    pneum_counter = sum(1 for t in targets if t == PNEUMOTHORAX)
    not_pneum_counter = sum(1 for t in targets if t == NOT_PNEUMOTHORAX) * mult
    total_data_count = not_pneum_counter + pneum_counter

    # Nothing to report (empty input or mult == 0 with no positives):
    # avoid a ZeroDivisionError.
    if total_data_count == 0:
        print("No PNEUMOTHORAX / NOT_PNEUMOTHORAX samples to summarise.\n")
        return

    pneum_share = (pneum_counter / total_data_count) * 100
    not_pneum_share = (not_pneum_counter / total_data_count) * 100

    # Fixed-point formatting: the general ".3" format renders 100.0 as "1e+02".
    print(f"Pneumothorax precentage: {pneum_share:.1f}%\n"
          f"Not pneumothorax precentage: {not_pneum_share:.1f}%\n")

# Baseline balance of the loaded labels: the negative class is counted once.
show_class_percentage(targets=images_targets, mult=1)

Pneumothorax precentage: 78.8%
Not pneumothorax precentage: 21.2%



As we can see, the data is heavily imbalanced.<br/>
To train the model well, we want the share of each class to be as close as possible to 50%.<br/>
To get there, we apply data augmentation to the NOT_PNEUMOTHORAX class.<br/>
Let's check how many times the NOT_PNEUMOTHORAX data needs to appear (1 = no duplication) to reach that balance.<br/>

In [143]:
# Preview the class balance for every candidate multiplier of the negative class.
for times in range(2, 6):
    header = f"duplicate {times} times:\n"
    print(header)
    show_class_percentage(images_targets, times)

duplicate 2 times:

Pneumothorax precentage: 65.0%
Not pneumothorax precentage: 35.0%

duplicate 3 times:

Pneumothorax precentage: 55.3%
Not pneumothorax precentage: 44.7%

duplicate 4 times:

Pneumothorax precentage: 48.1%
Not pneumothorax precentage: 51.9%

duplicate 5 times:

Pneumothorax precentage: 42.6%
Not pneumothorax precentage: 57.4%



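We can see that the classes are closest to balanced (48.1% / 51.9%) when each NOT_PNEUMOTHORAX image appears 4 times, i.e. when we add three copies of every such image.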

In [None]:

# Start from plain-list copies of the loaded arrays; augmented entries are
# appended to these lists.
new_images_paths = images_paths.tolist()
new_images_targets = images_targets.tolist()

# Infix marking a file as an augmentation copy: "<stem>_copy_<i><ext>".
COPY_MARKER = "_copy_"


def duplicate(i, paths):
    """Create the ``i``-th copy of every image in ``paths`` and register it.

    For each file name ``p`` a sibling file ``<stem>_copy_<i><ext>`` is created
    inside ``data_dir`` (the copy is skipped when that file already exists).
    The new name is appended to ``new_images_paths`` with a
    ``NOT_PNEUMOTHORAX`` label in ``new_images_targets``, unless that name is
    already listed.

    Args:
        i: Copy index embedded in the new file name.
        paths: Iterable of image file names relative to ``data_dir``.
    """
    # Names may already be listed (e.g. loaded from a CSV written by an
    # earlier run); registering them again would create duplicate rows.
    registered = set(new_images_paths)

    for p in paths:
        # splitext removes only the real extension, whereas
        # str.replace(".png", "") also strips ".png" found inside the name.
        stem, ext = os.path.splitext(p)
        new_file_name = f"{stem}{COPY_MARKER}{i}{ext or '.png'}"

        source = os.path.join(data_dir, p)
        dest = os.path.join(data_dir, new_file_name)

        if not os.path.isfile(dest):
            shutil.copy(source, dest)

        # Register only after the file is known to exist on disk.
        if new_file_name not in registered:
            new_images_paths.append(new_file_name)
            new_images_targets.append(NOT_PNEUMOTHORAX)
            registered.add(new_file_name)


# Duplicate only original negatives. Copies listed by a previous run are
# excluded, otherwise re-running the cell would create copies of copies.
not_pneum_paths = [
    p for p, t in zip(images_paths, images_targets)
    if t == NOT_PNEUMOTHORAX and COPY_MARKER not in p
]

# Three extra copies: every negative image then appears 4 times in total.
for i in range(1, 4):
    duplicate(i, not_pneum_paths)

new_df = pd.DataFrame({'file_name': new_images_paths, 'target': new_images_targets})

# NOTE: this overwrites the CSV the notebook loads its labels from.
# index=False keeps the row index out of the file, so reloading it does not
# produce an extra "Unnamed: 0" column.
new_df.to_csv(class_file_path, sep=',', index=False)

In [145]:
# Collect every listed file whose name contains "copy"; these are the files
# the transform step rewrites.
copys_paths = list(filter(lambda name: "copy" in name, new_images_paths))

# Class balance of the augmented label list, with no additional weighting.
show_class_percentage(targets=new_images_targets, mult=1)

Pneumothorax precentage: 48.1%
Not pneumothorax precentage: 51.9%



In [148]:
# transform
# Random geometric / sharpness augmentation applied to every duplicated image.
transform = transforms.Compose([transforms.RandomPerspective(distortion_scale=0.2, p =1.0),
                                transforms.RandomVerticalFlip(p=0.5),
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.RandomAdjustSharpness(sharpness_factor=2),
                                ])

for p in copys_paths:
    im_full_path = os.path.join(data_dir, p)

    # Rebuild each copy from its untouched original
    # ("<stem>_copy_<i><ext>" -> "<stem><ext>") instead of re-transforming the
    # copy in place; otherwise every re-run of this cell stacks another random
    # distortion on top of the previous one.
    stem, ext = os.path.splitext(p)
    original_path = os.path.join(data_dir, stem.rsplit("_copy_", 1)[0] + ext)
    # Fall back to the copy itself when no matching original is on disk.
    source_path = original_path if os.path.isfile(original_path) else im_full_path

    # The context manager closes the source file deterministically.
    with Image.open(source_path) as image:
        trans_image = transform(image)

    # RandomPerspective runs with p=1.0, so the result is a separate image
    # object and can be saved after the source has been closed.
    trans_image.save(im_full_path)
