# Testing GAN-generated data augmentation, using different proportions of the data to train the GAN
## This was 

The dataset is the COVID-QU-Ex dataset from Kaggle.

# Libraries

In [2]:
import os 
import torch
from torch.utils.data import DataLoader
import torchvision
import numpy as np
import random
from PIL import Image
import pickle
import subprocess

# This code needs some rework to make testing easier.

In [3]:
# Root folder of the lung segmentation data. DatasetMaker builds the image,
# generated-image and index-file paths from this value.
root_dir = './CovidData/Lung_Segmentation_Data/'

# Class names used in this notebook. NOTE(review): DatasetMaker defines its own
# local list with the same values rather than reading this constant.
CLASSES = ['normal', 'viral', 'covid']


# Image Transforms

In [4]:
# Preprocessing pipelines for training and testing.
# Order matters: Resize works on the PIL image, ToTensor converts it to a
# float tensor, and only then can Normalize (tensor-only) be applied. The
# 3-channel ImageNet statistics are applied before collapsing to one channel,
# which matches the default pipeline built inside DatasetMaker.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(128, 128)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    torchvision.transforms.Grayscale(num_output_channels=1),
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(128, 128)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    torchvision.transforms.Grayscale(num_output_channels=1),
])


In [5]:
# Maps 'Test_<split>_<data_ratio>' to the timestamped output folder of the GAN
# trained for that split and ratio. Each row lists the run folders for the
# ratios 0.8, 0.6, 0.4 and 0.2, in that order.
gan_directories = {
    f'Test_{split}_{ratio}': run_stamp
    for split, run_stamps in (
        ('orig', ('2022-12-14_12-57-44', '2022-12-14_15-04-49',
                  '2022-12-14_16-48-34', '2022-12-14_18-54-13')),
        ('0', ('2022-12-15_09-10-08', '2022-12-15_10-41-13',
               '2022-12-15_11-34-52', '2022-12-15_12-31-45')),
        ('1', ('2022-12-15_12-57-15', '2022-12-15_14-07-49',
               '2022-12-15_15-01-39', '2022-12-15_15-41-37')),
        ('2', ('2022-12-15_16-08-24', '2022-12-15_17-22-46',
               '2022-12-15_18-18-53', '2022-12-15_19-04-53')),
        ('3', ('2022-12-15_19-31-00', '2022-12-15_20-43-34',
               '2022-12-15_21-39-35', '2022-12-15_22-21-06')),
    )
    for ratio, run_stamp in zip(('0.8', '0.6', '0.4', '0.2'), run_stamps)
}
#TODO Make this into a csv file

# Making datasets

In [6]:

class CustomDataset(torch.utils.data.Dataset): #should be ImageFolder, needs rewrite
    """
        Dataset over a list of (image_path, class_index) pairs.

        images: list of (path, class_index) tuples
        classes: list of class names; stored for reference, not used for loading
        transform: callable applied to the loaded RGB PIL image
    """
    def __init__(self, images, classes, transform):
        super().__init__()
        self.classes = classes
        self.images = images
        self.transform = transform

    def __len__(self):
        """Number of (path, class_index) pairs in the dataset."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the image at `index` and return (transformed_image, class_index)."""
        path, idx = self.images[index]
        # Open through PIL only (no extra text-mode handle) and convert inside
        # the context manager so the underlying file is closed right away.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')
        return self.transform(image), idx

def DatasetMaker(split, mode=None, data_ratio=1, transform = None, geoaugment=False, seed = 0):
    """
        Returns a (train_dataset, test_dataset) pair of CustomDataset objects
        split: str, which train-test to use; options: 'orig', '0', '1', '2', '3'
        mode: str or None, how the covid class of the training set is balanced
            'oversampling' : oversample with real images, to balance classes
                    'gan' : balance datasets with gan generated images (uses data_ratio to figure out which gan to use)
                    None  : dataset won't be balanced
        data_ratio (optional): number, the ratio of the training covid data to be used, valid ratios:
                        1   : all data
                        0.8 : 80% of training images
                        0.6 : 60% of training images
                        0.4 : 40% of training images
                        0.2 : 20% of training images
        transform (optional): the dataset's transforms; a torchvision Compose instance,
                        a list/tuple of transforms, or a single callable. None selects
                        the default resize / to-tensor / normalize / grayscale pipeline
        geoaugment (optional): bool, prepends a small random affine transform to the
                        training transforms (the test set is never augmented)
        seed (optional): int, seed of the random sampling used while balancing

        Raises ValueError for an unknown split, data_ratio or mode, and for
        mode='gan' with data_ratio=1 (no gan directory is configured for it).
    """
    valid_ratios = [1, 0.8, 0.6, 0.4, 0.2]
    if not is_test_valid(split):
        raise ValueError(f"Unknown split {split!r}; expected one of 'orig', '0', '1', '2', '3'")
    if data_ratio not in valid_ratios:
        raise ValueError(f"Unknown data_ratio {data_ratio!r}; expected one of {valid_ratios}")
    if mode not in (None, 'gan', 'oversampling'):
        raise ValueError(f"Unknown mode {mode!r}; expected None, 'gan' or 'oversampling'")

    # Private generator so that sampling is reproducible for a given seed
    # without touching the global random state.
    rng = random.Random(seed)

    classes = ['normal', 'viral', 'covid']
    orig_dirs = {
    'normal' : f'{root_dir}/original/Normal',
    'viral' : f'{root_dir}/original/Non-Covid',
    'covid' : f'{root_dir}/original/COVID-19'
    }
    fake_dirs = {
        'gan_0.8' : f'{root_dir}/generated/Test_{split}/gan_0.8',
        'gan_0.6' : f'{root_dir}/generated/Test_{split}/gan_0.6',
        'gan_0.4' : f'{root_dir}/generated/Test_{split}/gan_0.4',
        'gan_0.2' : f'{root_dir}/generated/Test_{split}/gan_0.2'
    }
    indicies_files = {
        'gan_1' : f'{root_dir}/Indicies_files/Test_{split}/{split}_split_1_gan.pkl',
        'gan_0.8' : f'{root_dir}/Indicies_files/Test_{split}/{split}_split_0.8_gan.pkl',
        'gan_0.6' : f'{root_dir}/Indicies_files/Test_{split}/{split}_split_0.6_gan.pkl',
        'gan_0.4' : f'{root_dir}/Indicies_files/Test_{split}/{split}_split_0.4_gan.pkl',
        'gan_0.2' : f'{root_dir}/Indicies_files/Test_{split}/{split}_split_0.2_gan.pkl',
        'test'  : f'{root_dir}/Indicies_files/Test_{split}/{split}_split_test.pkl',
        'train'  : f'{root_dir}/Indicies_files/Test_{split}/{split}_split_train_and_val.pkl'
    }
    class_idx = {
        'covid': 0,
        'viral': 1,
        'normal': 2
    }
    idx_to_class ={
        0: 'covid',
        1: 'viral',
        2: 'normal'
    }
    covid_label = class_idx['covid']
    source_dir = {class_name: orig_dirs[class_name] for class_name in classes}

    def to_items(entries):
        # Turn (file_name, label) entries of an index file into
        # (full_path, class_index) pairs.
        items = []
        for entry in entries:
            class_of_entry = idx_to_class[entry[1]]
            items.append((os.path.join(source_dir[class_of_entry], entry[0]), class_idx[class_of_entry]))
        return items

    # The training index file holds two lists (presumably the train and the
    # validation part - the key is 'train_and_val'); both are used for training.
    train_entries = load_images_from_file(indicies_files['train'])
    train_images = to_items([*train_entries[0], *train_entries[1]])

    if data_ratio != 1:
        # Replace the covid images with the reduced subset the gan was trained on
        train_images = [item for item in train_images if item[1] != covid_label]
        train_images.extend(to_items(load_images_from_file(indicies_files[f'gan_{data_ratio}'])))

    num_of_covid_imgs = sum(1 for item in train_images if item[1] == covid_label)
    # Target size for the covid class: the mean size of the other two classes
    average_class_size = round((len(train_images) - num_of_covid_imgs) / 2)
    missing_images = max(0, average_class_size - num_of_covid_imgs)

    if mode == 'gan':
        gan = f'gan_{data_ratio}'
        if gan not in fake_dirs:
            raise ValueError(f"mode='gan' is not available for data_ratio={data_ratio!r}")
        gan_dir = fake_dirs[gan]
        generate_images_to_dir(split, data_ratio, gan, gan_dir, missing_images)
        # Sorted listing: os.listdir order is arbitrary, which would defeat the seed
        gan_ims = [(os.path.join(gan_dir, name), covid_label)
                   for name in sorted(os.listdir(gan_dir))
                   if name.lower().endswith('jpg')]
        train_images.extend(rng.sample(gan_ims, missing_images))
    elif mode == 'oversampling':
        covid_images = [item for item in train_images if item[1] == covid_label]
        if covid_images and missing_images > 0:
            # Whole copies of the covid set first, then a random partial copy
            full_copies, remainder = divmod(missing_images, len(covid_images))
            train_images.extend(covid_images * full_copies)
            train_images.extend(rng.sample(covid_images, remainder))

    test_images = to_items(load_images_from_file(indicies_files['test']))

    if transform is None:
        base_transforms = [ torchvision.transforms.Resize(size=(128, 128)),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                    torchvision.transforms.Grayscale(num_output_channels=1)]
    elif isinstance(transform, torchvision.transforms.Compose):
        # Compose is not iterable; unpack it so augmentation can be prepended
        base_transforms = list(transform.transforms)
    elif isinstance(transform, (list, tuple)):
        base_transforms = list(transform)
    else:
        base_transforms = [transform]

    train_transforms = list(base_transforms)
    if geoaugment:
        #RandomHorizontalFlip should be useful, but causes confusion with gans
        train_transforms = [torchvision.transforms.RandomAffine(4)] + train_transforms

    train_dataset = CustomDataset(train_images, classes, torchvision.transforms.Compose(train_transforms))
    # Random augmentation is applied to the training set only, so evaluation stays deterministic
    test_dataset = CustomDataset(test_images, classes, torchvision.transforms.Compose(base_transforms))
    return train_dataset, test_dataset

def load_images_from_file(file):
    """
        Unpickles and returns the object stored in the file at path `file`
    """
    # NOTE: pickle can execute arbitrary code; only load trusted index files.
    with open(file, 'rb') as handle:
        return pickle.load(handle)
    
def generate_images_to_dir(split, data_ratio, gan, dir, size):
    """
        Generates pictures with a given gan, to a given directory

        split: str, the train-test split the gan belongs to
        data_ratio: the data ratio the gan was trained with
        gan: str, name of the gan (currently unused, kept for the callers)
        dir: output directory, relative paths are resolved against the current directory
        size: int, number of images to generate; nothing is generated when it is not positive

        Raises KeyError if no gan run is registered in gan_directories for the
        split / data_ratio pair, and subprocess.CalledProcessError if the
        generator exits with a non-zero status.
    """
    output_dir = os.path.abspath(dir)
    gan_dir = gan_directories[f'Test_{split}_{data_ratio}']

    lippi_dir = '/home/bbernard/lipizzaner-covidgan-master/src' #Change this on server

    # Make sure the directory exists, callers list its content afterwards
    os.makedirs(output_dir, exist_ok=True)
    if size <= 0:
        return

    # `conda activate` only works in an initialised interactive shell, so the
    # generator is started through `conda run` in the lipizzaner environment.
    # Passing an argument list (no shell) and `cwd=` avoids both shell quoting
    # problems and changing the working directory of this process.
    command = [
        'conda', 'run', '-n', 'lipizzaner',
        'python', 'main.py', 'generate',
        '--mixture-source', f'./output/lipizzaner_gan/master/{gan_dir}',
        '-o', output_dir,
        '--sample-size', str(size),
        '-f', f'configuration/covid-qu-conv/Test_{split}/covidqu_{data_ratio}.yml',
    ]
    subprocess.run(command, cwd=lippi_dir, check=True)


def is_test_valid(test):
    """
        Returns True if `test` names a known train-test split, False otherwise
    """
    known_splits = ('orig', '0', '1', '2', '3')
    return test in known_splits

## Models


In [7]:
def model(name):
    """
        Builds a pretrained torchvision network adapted to this task
        name: str, 'resnet' (resnet18), 'vgg' (vgg16) or 'efficient' (efficientnet_b0)

        The first convolution is replaced by one taking a single input channel,
        the last linear layer by one with 3 outputs, and the architecture name
        is stored on the network as the `get_name` attribute.
        For any other name "Not implemented" is printed and None is returned.
    """
    if name == "resnet":
        net = torchvision.models.resnet18(pretrained=True)
        net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        net.fc = torch.nn.Linear(in_features=512, out_features=3)
        net.get_name = 'resnet18'
        return net

    if name == "vgg":
        net = torchvision.models.vgg16(pretrained=True)
        net.features[0] = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        net.classifier[6] = torch.nn.Linear(in_features=4096, out_features=3, bias=True)
        net.get_name = 'vgg16'
        return net

    if name == "efficient":
        net = torchvision.models.efficientnet_b0(pretrained=True)
        net.features[0][0] = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        net.classifier[1] = torch.nn.Linear(in_features=1280, out_features=3, bias=True)
        net.get_name = 'efficientnet_b0'
        return net

    # Unknown architecture: report it, the function then returns None
    print("Not implemented")

## Making dataset and dataloader

In [19]:
#params = {'split': 'orig',      #'orig' and '0' through '3'
#            'mode'      : None, #'oversampling', 'gan'
#            'data_ratio': 1,    # 1, 0.8, 0.6, 0.4, 0.2
#            'transform' : None, 
#            'geoaugment': False,
#            'seed'      : 0
#            }

#train_dataset, test_dataset = DatasetMaker(params['split'], params['mode'], params['data_ratio']) #, params['transform'], params['geoaugment'], params['seed']) 

#BATCH=64
#train_dl = DataLoader(train_dataset, batch_size= BATCH, shuffle = True)
#test_dl = DataLoader(test_dataset, batch_size= BATCH, shuffle = True )

#print(f'Number of train images: {len(train_dataset)}, number of test images: {len(test_dataset)}')
#print(f'Number of train batches: {len(train_dl)}, number of test batches: {len(test_dl)}')

27132


## Making the model, loss function and optimizer


In [20]:
#loss_fn = torch.nn.CrossEntropyLoss()
#name = 'resnet'
#modell = model(name)
#optimizer = torch.optim.Adam(modell.parameters(), lr=3e-5)



## Train function

In [8]:
def train(epochs, model, loss_fn, optimizer, train_dataset, test_dataset, batch_size, shuffle ):
    """
        A simple train function
        Params:
            epochs: number of epochs to train for
            model: the network to train; it is moved to the GPU when one is available
            loss_fn: callable, loss_fn(outputs, labels) -> scalar loss tensor
            optimizer: optimizer built over the parameters of `model`
            train_dataset: dataset yielding (image, label) pairs, used for training
            test_dataset: dataset yielding (image, label) pairs, evaluated after every epoch
            batch_size: batch size of both data loaders
            shuffle: whether the training data is reshuffled every epoch
        Returns:
            (model, history); history maps 'train_loss', 'train_accuracy',
            'val_loss' and 'val_accuracy' to lists with one value per epoch.
            Losses are averaged over batches, accuracies over samples.
    """
    train_dl = DataLoader(train_dataset, batch_size = batch_size, shuffle= shuffle)
    # Evaluation order does not influence the metrics, so it is never shuffled
    test_dl = DataLoader(test_dataset, batch_size = batch_size, shuffle= False)

    history = {'train_loss': [],
               'train_accuracy': [],
               'val_loss': [],
               'val_accuracy': []}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The batches are moved to `device`, so the model has to live there too.
    # Module.to() moves the parameters in place, the optimizer keeps seeing them.
    model.to(device)

    # max(..., 1) guards the divisions below against empty loaders / datasets
    train_batches = max(len(train_dl), 1)
    val_batches = max(len(test_dl), 1)
    train_samples = max(len(train_dataset), 1)
    val_samples = max(len(test_dataset), 1)

    print('Starting training..')
    for e in range(epochs):
        print('='*20)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('='*20)

        train_loss = 0.
        train_correct = 0

        model.train() # set model to training phase

        for train_step, (images, labels) in enumerate(train_dl):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            # Count on the tensor itself: .numpy() is not available for CUDA tensors
            train_correct += (preds == labels).sum().item()
            if train_step%20==0:
                print(f"Training round {train_step}")

        train_loss /= train_batches
        train_accuracy = train_correct/train_samples
        print(f'Training Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}')

        val_loss = 0.
        val_correct = 0

        model.eval()

        # No gradients are needed for evaluation; this saves memory and time
        with torch.no_grad():
            for images, labels in test_dl:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = loss_fn(outputs, labels)
                val_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()

        val_loss /= val_batches
        val_accuracy = val_correct/val_samples
        print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}')

        model.train()
        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_accuracy)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_accuracy)

    print('Training complete..')
    return model, history

In [9]:
def save_history(model, history, split, data_ratio, mode):
    """
        Saves the weights of the model and the training history to the current directory
        model: trained network; its `get_name` attribute is used in the file names
        history: the history object returned by train
        split, data_ratio, mode: the dataset settings, encoded into the file names

        Writes '<base>.pt' (the state dict) and '<base>-history.pkl' (the pickled history).
    """
    FILEBASE = f"{model.get_name}_model_{split}_split_{data_ratio}_ratio_{mode}_mode"
    torch.save(model.state_dict(), FILEBASE + '.pt')
    with open(FILEBASE + '-history.pkl', 'wb') as file:
        # pickle.dump takes the object first and the file second
        pickle.dump(history, file)

## Tests

In [12]:
# Options of the experiment settings below:
#'orig' and '0' through '3'
#'oversampling', 'gan'
# # 1, 0.8, 0.6, 0.4, 0.2
#'transform' : None, 
#'geoaugment': False,
#'seed'      : 0

#dataset rules
SPLIT = 'orig'
MODE = None
DATA_RATIO=0.2
GEOAUGMENT = False

#Network
NETWORK = 'resnet'
EPOCHS = 1
SHUFFLE=True
BATCH=64
LOSS_FN = torch.nn.CrossEntropyLoss()

train_dataset, test_dataset = DatasetMaker(split = SPLIT, mode = MODE, data_ratio = DATA_RATIO, geoaugment = GEOAUGMENT) #, params['transform'], params['geoaugment'], params['seed']) 

# train() builds its own loaders; these are only kept for inspecting the data
train_dl = DataLoader(train_dataset, batch_size= BATCH, shuffle = True)
test_dl = DataLoader(test_dataset, batch_size= BATCH, shuffle = True )

# The network is stored under its own name: assigning it to `model` would
# overwrite the `model` factory function and break every later call of it.
net = model(NETWORK)
OPTIMIZER = torch.optim.Adam(net.parameters(), lr=3e-5) #0.00003

net, history = train(EPOCHS, net, LOSS_FN, OPTIMIZER, train_dataset, test_dataset, BATCH,  SHUFFLE )

save_history(net, history, SPLIT, DATA_RATIO, MODE)

27132




Starting training..
Starting epoch 1/1
Training round 0


KeyboardInterrupt: 

In [None]:
# Second experiment: same setup as the previous cell, but with all of the
# training covid data (DATA_RATIO=1).
#dataset rules
SPLIT = 'orig'
MODE = None
DATA_RATIO=1
GEOAUGMENT = False

#Network
NETWORK = 'resnet'
EPOCHS = 1
SHUFFLE=True

train_dataset, test_dataset = DatasetMaker(split = SPLIT, mode = MODE, data_ratio = DATA_RATIO, geoaugment = GEOAUGMENT) #, params['transform'], params['geoaugment'], params['seed']) 

# BATCH is not defined in this cell; it has to come from an earlier cell.
train_dl = DataLoader(train_dataset, batch_size= BATCH, shuffle = True)
test_dl = DataLoader(test_dataset, batch_size= BATCH, shuffle = True )

# NOTE(review): this call only works while `model` still refers to the factory
# function; if an earlier cell rebound `model` to a network instance, the
# network would be called with the string instead - confirm before running.
model = model(NETWORK)