In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import ImageFolder
import os
from tqdm import tqdm
import sys
from GPUtil import showUtilization as gpu_usage
from numba import cuda
import torch, gc
import pickle
from itertools import combinations
from torch.utils.data import DataLoader, Dataset
from PIL import Image

from models import Unet, get_default_device, to_device, DeviceDataLoader
os.environ['CUDA_VISIBLE_DEVICES']='0'
torch.manual_seed(0)

<torch._C.Generator at 0x7fe1662828f0>

In [2]:
rectype=['L2','L1','TV']
rectype_combinations=[] 
rectype_strings=[]
for i in range(len(rectype)):
    for p in combinations(rectype, i+1):  # 2 for pairs, 3 for triplets, etc

        rectype_combinations.append(p)
        rectype_strings.append('_'.join(p))
rectype_combinations
rectype_strings

['L2', 'L1', 'TV', 'L2_L1', 'L2_TV', 'L1_TV', 'L2_L1_TV']

In [3]:
lr=[1e-4,'exp', 'plateau']
radial_lines=[20,40,60,80,100]
max_epochs=500
batch_sizes=4

In [4]:
dataset_dir='./BIRN_dataset/'
images_dir=(dataset_dir+'birn_png/')
rec_dirs=[(f"{dataset_dir}birn_pngs_{rl}lines_{rt}/") for rt in rectype for rl in radial_lines]
rec_dirs

['./BIRN_dataset/birn_pngs_20lines_L2/',
 './BIRN_dataset/birn_pngs_40lines_L2/',
 './BIRN_dataset/birn_pngs_60lines_L2/',
 './BIRN_dataset/birn_pngs_80lines_L2/',
 './BIRN_dataset/birn_pngs_100lines_L2/',
 './BIRN_dataset/birn_pngs_20lines_L1/',
 './BIRN_dataset/birn_pngs_40lines_L1/',
 './BIRN_dataset/birn_pngs_60lines_L1/',
 './BIRN_dataset/birn_pngs_80lines_L1/',
 './BIRN_dataset/birn_pngs_100lines_L1/',
 './BIRN_dataset/birn_pngs_20lines_TV/',
 './BIRN_dataset/birn_pngs_40lines_TV/',
 './BIRN_dataset/birn_pngs_60lines_TV/',
 './BIRN_dataset/birn_pngs_80lines_TV/',
 './BIRN_dataset/birn_pngs_100lines_TV/']

In [5]:

class OriginalReconstructionDataset(Dataset):
    def __init__(self, radial_line, rec_type_str, datasets_dir, indexes = None, img_size=(256,256)):
        rec_type=rec_type_str.split('_')
        self.images_dir=(dataset_dir+'birn_png/')
        rec_dirs=[(f"{dataset_dir}birn_pngs_{rl}lines_{rt}/") for rt in rectype for rl in radial_lines]
        
        self.rec_images_dirs=[]
        for dir in rec_dirs:
            for rt in rec_type:
                if rt in dir:
                    if str(radial_line) in dir:
                        self.rec_images_dirs.append(dir)
                        break

        self.images = [f for f in os.listdir(self.images_dir) if f.endswith('.png')]
        if indexes is not None:
            self.images = [self.images[i] for i in indexes] 
        self.transform = transforms.Compose([
                        transforms.Grayscale(num_output_channels=1),         
                        transforms.Resize(img_size),
                        #transforms.Lambda(lambda x: x/255.0),
                        transforms.ToTensor()
                        ])
        self.rec_types=rec_type
        self.radial_line=radial_line
        print(self.images_dir)
        print(self.rec_images_dirs)
        print(self.rec_types)
        print(self.radial_line)
    def __len__(self):
    # return length of image samples    
        return len(self.images)

    def __getitem__(self, idx):
        img_name=self.images[idx]
        img = Image.open(self.images_dir+img_name)
        img=self.transform(img)
        rec_imgs=[]
        for rec,dir in zip(self.rec_types,self.rec_images_dirs):
            noisy_name=img_name[:-14]+rec+f'_{self.radial_line}lines.png'            
            tensor=self.transform(Image.open(dir+noisy_name))
            rec_imgs.append(tensor)
        noisy=torch.stack(rec_imgs)
        noisy=torch.squeeze(noisy, 1)
        return (img,noisy)

def first_element(test_dataset):
    for data in test_dataset:
        print(data[0].shape)
        print(data[1].shape)
        break


In [6]:
idx_file='indexes.pkl'
if not os.path.exists(idx_file):
    np.random.seed(seed=42)
    all_indexes=np.random.permutation(len([f for f in os.listdir(images_dir) if f.endswith('.png')]))
    m = len(all_indexes)
    m_train=int(m*0.8)
    m_val = int(m*0.1)
    train_indexes=all_indexes[:m_train]
    val_indexes=all_indexes[m_train:m_train+m_val]
    test_indexes=all_indexes[m_train+m_val:]
    
    with open(idx_file, 'wb') as f:  # Python 3: open(..., 'wb')
        pickle.dump([train_indexes, val_indexes, test_indexes], f)
else:
    with open(idx_file,'rb') as f:  # Python 3: open(..., 'rb')
        train_indexes, val_indexes, test_indexes = pickle.load(f)




In [7]:
batch_size=4
device=get_default_device()

train_dataset={}
train_loaders={}
for rt in rectype_strings:
    for rl in radial_lines:
        train_ds=OriginalReconstructionDataset(rl, rt, dataset_dir, train_indexes)
        train_dataset[rl,rt]=train_ds
        train_loaders[rl,rt]=DeviceDataLoader(torch.utils.data.DataLoader(train_ds, batch_size=batch_size), device)
        first_element(train_ds)


val_dataset={}
val_loaders={}
for rt in rectype_strings:
    for rl in radial_lines:
        val_ds=OriginalReconstructionDataset(rl, rt, dataset_dir, val_indexes)
        val_dataset[rl,rt]=val_ds
        val_loaders[rl,rt]=DeviceDataLoader(torch.utils.data.DataLoader(val_ds, batch_size=batch_size), device)
        first_element(val_ds)

test_dataset={}
test_loaders={}
for rt in rectype_strings:
    for rl in radial_lines:
        test_ds=OriginalReconstructionDataset(rl, rt, dataset_dir, test_indexes)
        test_dataset[rl,rt]=test_ds
        test_loaders[rl,rt]=DeviceDataLoader(torch.utils.data.DataLoader(test_ds, batch_size=batch_size,shuffle=True), device)
        first_element(test_ds)





./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_20lines_L2/']
['L2']
20
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_40lines_L2/']
['L2']
40
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_60lines_L2/']
['L2']
60
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_80lines_L2/']
['L2']
80
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_100lines_L2/']
['L2']
100
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_20lines_L1/']
['L1']
20
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_40lines_L1/']
['L1']
40
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_60lines_L1/']
['L1']
60
torch.Size([1, 256, 256])


In [8]:
def define_oprimizer(lr, params_to_optimize):
    if lr=='exp':
        optimizer = torch.optim.Adam(params_to_optimize, lr=0.001, weight_decay=1e-05)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.000001**(1/max_epochs), last_epoch=- 1, verbose=False)
    elif lr=='plateau':
        optimizer = torch.optim.Adam(params_to_optimize, lr=0.001, weight_decay=1e-05)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1**(1/2), patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
    else:
        optimizer = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-05)
    return optimizer

In [9]:


## Training function
def train_epoch_den(model, device, dataloader, loss_fn, optimizer):
    # Set train mode
    model.train()    
    train_loss = []
    # Iterate the dataloader (we do not need the label values, this is unsupervised learning)
    for image_batch, image_noisy in tqdm(dataloader): # with "_" we just ignore the labels (the second element of the dataloader tuple)
        image_noisy.to(device)
        #print(device)
        #print((image_noisy.device))
        
        result = model(image_noisy)
        # Evaluate loss
        loss = loss_fn(result, image_batch)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #if lr=='exp' or lr=='plateau':
        #    scheduler.step()        
        # Print batch loss
        #print('\t partial train loss (single batch): %f' % (loss.data))
        train_loss.append(loss.detach().cpu().numpy())

    return np.mean(train_loss)

### Testing function
def test_epoch_den(model, device, dataloader, loss_fn):
    # Set evaluation mode
    model.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch, image_noisy in dataloader:
            result = model(image_noisy)
            # Append the network output and the original image to the lists
            conc_out.append(result.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        val_loss = loss_fn(conc_out, conc_label)
    return val_loss.data

def plot_ae_outputs_den(unet,n=10):
    plt.figure(figsize=(21,6))
    for i in range(n):

      ax = plt.subplot(3,n,i+1)
      img = test_dataset[4*i][0].unsqueeze(0)
      image_noisy = test_dataset[4*i][1].unsqueeze(0)
      
      unet.eval()

      with torch.no_grad():
         rec_img  = unet(image_noisy)

      plt.imshow(img.cpu().squeeze().numpy()[0,:,:], cmap='gist_gray')
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
        ax.set_title('Original images')
      ax = plt.subplot(3, n, i + 1 + n)
      plt.imshow(image_noisy.cpu().squeeze().numpy()[0,:,:], cmap='gist_gray')
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
        ax.set_title('Corrupted images')

      ax = plt.subplot(3, n, i + 1 + n + n)
      plt.imshow(rec_img.cpu().squeeze().numpy()[0,:,:], cmap='gist_gray')  
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
         ax.set_title('Reconstructed images')
    plt.subplots_adjust(left=0.1,
                    bottom=0.1, 
                    right=0.7, 
                    top=0.9, 
                    wspace=0.3, 
                    hspace=0.3)   
    
    plt.savefig('images_256_CS_TV.png')
    plt.show()



In [10]:


def train(rec_type, radial_lines, model, loss_fn, optimizer, dl_train, dl_val,max_epochs):#train_loaders[rl,rt]=DeviceDataLoader(torch.utils.data.DataLoader(train_ds, batch_size=batch_size), device)
    history={'learning_rate':[],'train_loss':[],'val_loss':[]}
    last_saved=''
    best_model=''
    for epoch in range(max_epochs):
        print('EPOCH %d/%d' % (epoch + 1, max_epochs))
        ### Training (use the training function)
        train_loss=train_epoch_den(
            model=model, 
            device=device, 
            dataloader=dl_train, 
            loss_fn=loss_fn, 
            optimizer=optimizer)
        ### Validation  (use the testing function)
        val_loss = test_epoch_den(
            model=model, 
            device=device, 
            dataloader=dl_val,
            loss_fn=loss_fn)
        # Print Validationloss


        train_loss=train_loss if not torch.is_tensor(train_loss) else train_loss.cpu().detach().numpy()
        val_loss=val_loss if not torch.is_tensor(val_loss) else val_loss.cpu().detach().numpy()
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['learning_rate'].append(optimizer.param_groups[0]["lr"])

        print('\n EPOCH {}/{} \t train loss {:.8f} \t val loss {:.8f}'.format(epoch + 1, max_epochs,train_loss,val_loss))
        if val_loss<=min(history['val_loss']):
            print('Lowest val loss:',val_loss, '... Saving parameters.')
            try:
                os.remove(last_saved)
            except:
                print('couldnt remove {last_saved}')
            torch.save(model.state_dict(), f'saved_models/unet_{radial_lines}{rectype}_epoch{epoch + 1}_lr{lr}.pth')
            last_saved= f'unet_{radial_lines}{rectype}_epoch{epoch + 1}_lr{lr}.pth'
        break
        
    import pandas as pd
    df = pd.DataFrame(history)
    df.to_csv(f'saved_models/loss_history_{radial_lines}{rectype}_epoch{epoch + 1}_lr{lr}.csv')

    return history




In [14]:
%%time
import mlflow
from sklearn.model_selection import ParameterGrid

def execute(model, num_radial_lines, data_train, data_val, data_test, device, exp_params):
    experiment_name=f"MRIREC_{num_radial_lines}"
    run_params = {"description":f"Reconstruction using {num_radial_lines} radial lines",
              "tags":{'release.version':'1.0.0'}}
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if not experiment:
        experiment_id=mlflow.create_experiment(experiment_name)
    experiment = mlflow.set_experiment(experiment_name)
    
    run_params.update({"experiment_id": experiment.experiment_id})
    
    print("Experiment_id: {}".format(experiment.experiment_id))
    print("Localização dos artefatos: {}".format(experiment.artifact_location))
    print("Tags: {}".format(experiment.tags))
    print("Lifecycle_stage: {}".format(experiment.lifecycle_stage))

    grid_exp = ParameterGrid(exp_params)

    for p_model in grid_exp:
        print(p_model)
        model = p_model['model']
        rectype = p_model['rectype']
        num_channels = len(rectype.split('_'))
        epochs = p_model['max_epochs']
        learnig_rate=p_model['learnig_rate']
        batch_size = p_model['batch_size']
        train_ds= OriginalReconstructionDataset(num_radial_lines, rectype, dataset_dir, train_indexes)
        train_dl= DeviceDataLoader(torch.utils.data.DataLoader(train_ds, batch_size=batch_size), device)
        val_ds=OriginalReconstructionDataset(num_radial_lines, rectype, dataset_dir, test_indexes)
        val_dl= DeviceDataLoader(torch.utils.data.DataLoader(train_ds, batch_size=batch_size), device)

        #loading model
        
        if model=='Unet':
            model=Unet(num_inputs=len(rectype.split('_'))) #1 a 3 canais
        elif False:
            pass
            #IF TO ADD NEW MODEL IMPLEMENT H#RE
        else:
            error('Model not found...')
        model.to(device)
        
        params_to_optimize = [{'params': model.parameters()}]        
        optimizer=define_optimizer(p_model.lr, params_to_optimize)

        #train loop:
        break
        train(rectype, num_radial_lines, model, torch.nn.MSELoss(), optimizer,train_dl, val_dl, device)




CPU times: user 8 µs, sys: 2 µs, total: 10 µs
Wall time: 12.6 µs


In [16]:

models=

exp_params={"model": ['Unet'],
            "rectype": rectype_strings,
            "learnig_rate":[1e-4, 'exp', 'plateau'],
            "max_epochs":[500],
            "batch_size": [4],
}


radial_lines=[20,40,60,80,100]

for rl in radial_lines:
    execute(models, rl, train_loaders, val_loaders, test_loaders,device,exp_params)


Experiment_id: 585268919028592994
Localização dos artefatos: file:///home/jonathan/MRI_unet_reconstruction/mlruns/585268919028592994
Tags: {}
Lifecycle_stage: active
{'batch_size': 4, 'learnig_rate': 0.0001, 'model': 'Unet', 'rectype': 'L2'}


KeyError: 'lr'