In [None]:
import sys
sys.path.append('/tf/data')

import os
import torch
import torch.nn as nn
import numpy as np
import shutil
import matplotlib.pyplot as plt

from MedImageGanFiles.dcgan import weights_init, Generator
from general_func import load_dataset
from Classification.class_functions import split_ds, concat_data, plot_loss_acc
from Classification.Custom_GridSearch import grid_search

In [None]:
# Settings handed to the grid search. The search loop further down overwrites
# 'batch_size', 'learning_rate', 'weight_decay' and 'drop_rate' (and adds
# 'n_synth') for every candidate configuration.
params = {
    # Paths and run bookkeeping
    'images_root_path': '/tf/data/augmented_64_3ch/',
    'run': 'Search_1',
    'model_save_path': '/tf/data/Classification/ConvNeXt/Grid_Search_synth_1/',
    'model_save_freq_epochs': 1000,
    'seed': 42,

    # Training hyperparameters
    'num_epochs': 150,
    'learning_rate': 4e-3,
    'weight_decay': 0.001,
    'warmup_epochs': 10,
    'early_stop': 15,

    # Data loading
    'batch_size': 8,
    'loader_workers': 2,

    # Regularisation / class balancing
    'drop_rate': 0.7,
    'apply_class_weights': False,

    # RandAUG
    'num_ops': 8,
    'magnitude': 20,

    # GridMask
    'offset': False,  # False: masked squares are set to 0, True: masked squares are noise
    'ratio': 0.5,     # how much of the image to keep
    'mode': 1,        # 0 = keep the squares, 1 = cut the squares
    'prob': 0.5,      # probability of applying the transformation
}

# Report which device is available (GPU if CUDA is usable, otherwise CPU).
print('Device:', torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

In [None]:
def generate_synth(
        generate_class: int,
        n: int,
        weights_0: str = '/tf/data/MedImageGanModels_MyModels/MIG3_neg2_2_lr0_0001_b8/Epoch_001_Iter_009900.zip',
        weights_1: str = '/tf/data/MedImageGanModels_MyModels/MIG3_pos2_2_lr0_0001_b8/Epoch_000_Iter_001200.zip',
        show_images: bool = False,
        ) -> list:
    """Generate ``n`` synthetic images with a pre-trained GAN generator.

    Args:
        generate_class: 0 loads the checkpoint given by ``weights_0``; any
            other value loads ``weights_1``.
        n: Number of images to generate. For ``n <= 0`` nothing is generated
            and an empty list is returned without building the model.
        weights_0: Path to the generator state dict used for class 0.
        weights_1: Path to the generator state dict used for every other class.
        show_images: When true, plot up to the first three generated images
            (min-max normalised for display only).

    Returns:
        A list with one CPU tensor per generated image (3 channels each, as
        produced by the final transposed convolution).
    """
    # Nothing to generate: skip model construction and checkpoint loading.
    if n <= 0:
        return []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    latent_dim = 256

    #Create generator
    netG = Generator(ngpu=1, nz=latent_dim, ngf=64, nc=3).to(device)
    # Replace modules 12 and 13 of the stock generator; stride (1, 4) upsamples
    # along the width only.
    netG.main[12] = nn.ConvTranspose2d(64, 64, kernel_size=(3, 4), stride=(1, 4), padding=(1, 2), bias=False).apply(weights_init).to(device)
    netG.main[13] = nn.BatchNorm2d(64).apply(weights_init).to(device)
    netG.main.add_module('14', nn.ReLU(inplace=True))

    netG.main.add_module('15', nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=(5, 4), stride=(1, 4), padding=(2,2), bias=False).apply(weights_init).to(device))
    netG.main.add_module('16', nn.BatchNorm2d(64).to(device))
    netG.main.add_module('17', nn.ReLU(inplace=True))

    # Final stage maps to 3 output channels and squashes values with Tanh.
    netG.main.add_module('18', nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=(4, 3), stride=(2, 3), padding=(1,2), bias=False).apply(weights_init).to(device))
    netG.main.add_module('19', nn.Tanh())

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    weights_path = weights_0 if generate_class == 0 else weights_1
    netG.load_state_dict(torch.load(weights_path, map_location=device))

    #Generate images
    # NOTE(review): the generator is not switched to eval(), so BatchNorm uses
    # the statistics of this batch; kept as in the original -- confirm intent.
    with torch.no_grad():  # inference only: do not build an autograd graph
        input_tensor = torch.randn(n, latent_dim, 1, 1).to(device)
        images = netG(input_tensor).cpu()

    if show_images:
        # Preview at most three images (fewer when n < 3).
        for image_tensor in images[:3]:
            image = image_tensor.permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
            low, high = image.min(), image.max()
            value_range = high - low
            # Guard against a constant image, which would divide by zero.
            image = (image - low) / value_range if value_range > 0 else np.zeros_like(image)
            plt.figure(figsize=(20,5))
            plt.axis('off')
            plt.imshow(image)
            plt.show()
    return list(images)

In [None]:
#Load data and create dataloaders
# Each class is loaded and split separately with the shared seed.
# NOTE(review): train_split + val_split = 1.0, so the test split is presumably
# empty -- confirm against split_ds.
ds_pos = load_dataset(positive=True)
train_scans_pos, val_scans_pos, test_scans_pos = split_ds(ds_pos, train_split = 0.7, val_split = 0.3, seed = params['seed'])
ds_neg = load_dataset(positive=False)
train_scans_neg, val_scans_neg, test_scans_neg = split_ds(ds_neg, train_split = 0.7, val_split = 0.3, seed = params['seed'])

n_pos = len(train_scans_pos)
n_neg = len(train_scans_neg)

# Configurations with a running index below this value are skipped
# (presumably already evaluated in an earlier session -- set to 1 for a full search).
RESUME_FROM_SEARCH = 238
# Runs whose best validation loss exceeds this threshold have their output folder removed.
MAX_KEEP_VAL_LOSS = 0.38

# Build the log directory once; os.path.join does not depend on a trailing slash.
log_path = os.path.join(params['model_save_path'], params['run'])

search_n = 0
for n_synth in [0, 1, 3]:
    params['n_synth'] = n_synth
    # Positives are first topped up to the number of negatives, then both classes
    # receive n_synth * n_neg extra synthetic scans. Clamp at 0 so a dataset with
    # more positives than negatives does not request a negative image count.
    gen_n_pos = max(0, n_neg - n_pos + n_synth * n_neg)
    gen_n_neg = n_synth * n_neg
    augmented_train_scans_pos = train_scans_pos + generate_synth(generate_class=1, n=gen_n_pos)
    print('Total positive training scans:', len(augmented_train_scans_pos))
    augmented_train_scans_neg = train_scans_neg + generate_synth(generate_class=0, n=gen_n_neg)
    print('Total negative training scans:', len(augmented_train_scans_neg))

    for batch in [8, 16, 32, 64]:
        params['batch_size'] = batch
        train_loader = concat_data(augmented_train_scans_pos, augmented_train_scans_neg, batch_size=params['batch_size'], workers=params['loader_workers'])
        val_loader = concat_data(val_scans_pos, val_scans_neg, batch_size=params['batch_size'], workers=params['loader_workers'])
        test_loader = concat_data(test_scans_pos, test_scans_neg, batch_size=params['batch_size'], workers=params['loader_workers'])

        for lr in [4e-2, 4e-3, 4e-4, 4e-5]:
            for weight_decay in [0.01, 0.005, 0.001]:
                for drop_rate in [0.5, 0.7]:
                    # search_n numbers every configuration, including skipped ones,
                    # so indices stay stable across resumed sessions.
                    search_n += 1

                    if search_n < RESUME_FROM_SEARCH:
                        continue
                    print('Model:', search_n)
                    params['learning_rate'] = lr
                    params['weight_decay'] = weight_decay
                    params['drop_rate'] = drop_rate
                    _, train_losses, val_losses, val_accuracy = grid_search(params, search_n, train_loader, val_loader, test_loader, ds_pos, ds_neg, log_path = log_path)

                    #Delete irrelevant searches:
                    folder = os.path.join(log_path, str(search_n))
                    # A run that reported no validation losses counts as irrelevant
                    # instead of aborting the whole search with a ValueError.
                    best_val_loss = min(val_losses, default=float('inf'))
                    if best_val_loss > MAX_KEEP_VAL_LOSS:
                        if os.path.exists(folder):
                            shutil.rmtree(folder)
                            print("Folder and its contents deleted successfully.")
                    else:
                        plot_loss_acc(train_losses, val_losses, val_accuracy, save_dir = folder)