In [16]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import ImageFolder
import os
from tqdm import tqdm
import sys
from GPUtil import showUtilization as gpu_usage
from numba import cuda
import torch, gc
import pickle
from itertools import combinations
from torch.utils.data import DataLoader, Dataset
from PIL import Image

from models import Unet, get_default_device, to_device, DeviceDataLoader
# Expose only GPU 0 to this process. CUDA reads this variable when it is first
# initialised, so it only takes effect if nothing above has touched CUDA yet.
# NOTE(review): `torch` and `models` are already imported at this point; consider
# moving this line above those imports.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})


In [17]:
# Reconstruction types, and every non-empty subset of them (singles, pairs, triple).
# Each subset keeps the order of `rectype` and is also encoded as an
# underscore-joined string (e.g. 'L2_TV'), which is what the dataset class parses.
rectype = ['L2', 'L1', 'TV']
rectype_combinations = [
    combo
    for subset_size in range(1, len(rectype) + 1)
    for combo in combinations(rectype, subset_size)
]
rectype_strings = ['_'.join(combo) for combo in rectype_combinations]
rectype_strings

['L2', 'L1', 'TV', 'L2_L1', 'L2_TV', 'L1_TV', 'L2_L1_TV']

In [18]:
# Experiment grid.
# NOTE(review): mixes a numeric learning rate with strings; presumably 'exp' and
# 'plateau' name learning-rate schedules — TODO confirm where this list is consumed.
lr = [1e-4, 'exp', 'plateau']
# Radial-line counts; each needs matching birn_pngs_<n>lines_<type>/ folders.
radial_lines = list(range(20, 101, 20))
max_epochs = [500]
batch_sizes = [4]

In [19]:
# Dataset layout: ground-truth PNGs in birn_png/, and one folder of reconstructions
# per (radial-line count, reconstruction type), grouped by type first.
dataset_dir = './BIRN_dataset/'
images_dir = f"{dataset_dir}birn_png/"
rec_dirs = [
    f"{dataset_dir}birn_pngs_{n_lines}lines_{kind}/"
    for kind in rectype
    for n_lines in radial_lines
]
rec_dirs

['./BIRN_dataset/birn_pngs_20lines_L2/',
 './BIRN_dataset/birn_pngs_40lines_L2/',
 './BIRN_dataset/birn_pngs_60lines_L2/',
 './BIRN_dataset/birn_pngs_80lines_L2/',
 './BIRN_dataset/birn_pngs_100lines_L2/',
 './BIRN_dataset/birn_pngs_20lines_L1/',
 './BIRN_dataset/birn_pngs_40lines_L1/',
 './BIRN_dataset/birn_pngs_60lines_L1/',
 './BIRN_dataset/birn_pngs_80lines_L1/',
 './BIRN_dataset/birn_pngs_100lines_L1/',
 './BIRN_dataset/birn_pngs_20lines_TV/',
 './BIRN_dataset/birn_pngs_40lines_TV/',
 './BIRN_dataset/birn_pngs_60lines_TV/',
 './BIRN_dataset/birn_pngs_80lines_TV/',
 './BIRN_dataset/birn_pngs_100lines_TV/']

In [49]:

class OriginalReconstructionDataset(Dataset):
    """Pairs each ground-truth BIRN image with its reconstructions for one radial-line count.

    Item ``idx`` is a tuple ``(img, noisy)``:

    * ``img`` - the ground-truth PNG from ``<datasets_dir>/birn_png/`` after
      ``self.transform`` (grayscale, resize, ``ToTensor``), shape ``(1, H, W)``.
    * ``noisy`` - one transformed reconstruction per requested type, concatenated
      along dim 0 in the order given by ``rec_type_str``, shape ``(k, H, W)``.
      Reconstructions are read from ``<datasets_dir>/birn_pngs_<radial_line>lines_<type>/``.

    Args:
        radial_line: number of radial lines; selects the reconstruction folders.
        rec_type_str: reconstruction types joined by ``'_'`` (e.g. ``'L2_TV'``).
        datasets_dir: dataset root directory.
        indexes: optional positions into the sorted list of ground-truth PNGs
            (e.g. one train/val/test split); ``None`` keeps every image.
        img_size: target size passed to ``transforms.Resize``.
        verbose: print the resolved directories, types and radial-line count.

    Raises:
        FileNotFoundError: if the ground-truth folder or a reconstruction folder is missing.
    """

    def __init__(self, radial_line, rec_type_str, datasets_dir, indexes=None, img_size=(256, 256), verbose=True):
        rec_type = rec_type_str.split('_')
        self.images_dir = os.path.join(datasets_dir, 'birn_png', '')
        # Build each folder path directly from (radial_line, type), in the order of
        # rec_type_str. Searching global folder lists by substring ignored
        # `datasets_dir`, could match e.g. '10' inside '100lines', and returned folders
        # in the global `rectype` order, which mis-paired with self.rec_types in
        # __getitem__ whenever rec_type_str listed the types in another order.
        self.rec_images_dirs = self._resolve_rec_dirs(datasets_dir, radial_line, rec_type)
        missing = [d for d in self.rec_images_dirs if not os.path.isdir(d)]
        if missing:
            raise FileNotFoundError(f"reconstruction directories not found: {missing}")

        # os.listdir order is arbitrary, so sort to make `indexes` refer to the same
        # files on every machine. NOTE(review): positions in an indexes.pkl created
        # before this change referred to listdir order, so the files in each split may
        # differ from earlier runs.
        self.images = sorted(f for f in os.listdir(self.images_dir) if f.endswith('.png'))
        if indexes is not None:
            self.images = [self.images[i] for i in indexes]
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize(img_size),
            transforms.ToTensor(),  # also scales 8-bit pixel values to [0, 1]
        ])
        self.rec_types = rec_type
        self.radial_line = radial_line
        if verbose:
            print(self.images_dir)
            print(self.rec_images_dirs)
            print(self.rec_types)
            print(self.radial_line)

    @staticmethod
    def _resolve_rec_dirs(datasets_dir, radial_line, rec_types):
        """Return the reconstruction folder for each type in ``rec_types``, in the same order."""
        return [
            os.path.join(datasets_dir, f"birn_pngs_{radial_line}lines_{rt}", '')
            for rt in rec_types
        ]

    def _load(self, path):
        """Open the image at ``path``, apply ``self.transform`` and close the file handle."""
        with Image.open(path) as pil_img:
            return self.transform(pil_img)

    def __len__(self):
        """Number of ground-truth images in this dataset (after ``indexes`` filtering)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(img, noisy)`` for the ``idx``-th ground-truth image."""
        img_name = self.images[idx]
        img = self._load(os.path.join(self.images_dir, img_name))
        # NOTE(review): assumes ground-truth names end in a fixed 14-character suffix
        # (including '.png') that reconstruction files replace with
        # '<type>_<radial_line>lines.png' — TODO confirm against the data.
        stem = img_name[:-14]
        rec_imgs = [
            self._load(os.path.join(rec_dir, f"{stem}{rec}_{self.radial_line}lines.png"))
            for rec, rec_dir in zip(self.rec_types, self.rec_images_dirs)
        ]
        # Each tensor is (1, H, W); concatenating on dim 0 gives (k, H, W).
        noisy = torch.cat(rec_imgs, dim=0)
        return img, noisy

def first_element(test_dataset):
    """Print the ``.shape`` of the two parts of the first sample in ``test_dataset``.

    Quick sanity check that ground truth and reconstructions have the expected
    sizes. Prints nothing when the dataset yields no samples.
    """
    no_sample = object()
    sample = next(iter(test_dataset), no_sample)
    if sample is no_sample:
        return
    print(sample[0].shape)
    print(sample[1].shape)


In [51]:
# Fixed 80/10/10 train/val/test split over the ground-truth PNGs, cached in
# `idx_file` so every run and every model sees the same split. The stored values
# are positions in the dataset's list of PNG files under `images_dir`.
idx_file = 'indexes.pkl'
n_images = len([f for f in os.listdir(images_dir) if f.endswith('.png')])
if not os.path.exists(idx_file):
    # Legacy global seed kept on purpose: it reproduces the split of earlier runs.
    np.random.seed(seed=42)
    all_indexes = np.random.permutation(n_images)
    m_train = int(n_images * 0.8)
    m_val = int(n_images * 0.1)
    train_indexes = all_indexes[:m_train]
    val_indexes = all_indexes[m_train:m_train + m_val]
    test_indexes = all_indexes[m_train + m_val:]

    with open(idx_file, 'wb') as f:
        pickle.dump([train_indexes, val_indexes, test_indexes], f)
else:
    # pickle.load can execute arbitrary code: only load an idx_file this notebook wrote.
    with open(idx_file, 'rb') as f:
        train_indexes, val_indexes, test_indexes = pickle.load(f)

# A stale idx_file (written for a different number of images) would silently drop
# images or index past the end later on; fail loudly instead.
_split_union = np.sort(np.concatenate([train_indexes, val_indexes, test_indexes]))
if not np.array_equal(_split_union, np.arange(n_images)):
    raise ValueError(
        f"{idx_file} does not describe a split of the {n_images} images in {images_dir}; "
        "delete it to regenerate the split."
    )




In [52]:
# One dataset/DataLoader per (radial-line count, reconstruction-type string) for each
# split; every dict below is keyed by (radial_line, rec_type_string).
batch_size = 4


def _build_split(indexes, shuffle):
    """Create a dataset and DataLoader for every (radial_line, rec-type string) pair of one split.

    Prints the shapes of each dataset's first sample as a sanity check.

    Returns:
        ``(datasets, loaders)``: two dicts keyed by ``(radial_line, rec_type_string)``.
    """
    datasets, loaders = {}, {}
    for rec_str in rectype_strings:
        for n_lines in radial_lines:
            ds = OriginalReconstructionDataset(n_lines, rec_str, dataset_dir, indexes)
            datasets[n_lines, rec_str] = ds
            loaders[n_lines, rec_str] = DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
            first_element(ds)
    return datasets, loaders


# The training loader now reshuffles every epoch (previously it was the only loader
# NOT shuffled); validation keeps a fixed order; the test loader keeps shuffle=True
# exactly as before.
train_dataset, train_loaders = _build_split(train_indexes, shuffle=True)
val_dataset, val_loaders = _build_split(val_indexes, shuffle=False)
test_dataset, test_loaders = _build_split(test_indexes, shuffle=True)


./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_20lines_L2/']
['L2']
20
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_40lines_L2/']
['L2']
40
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_60lines_L2/']
['L2']
60
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_80lines_L2/']
['L2']
80
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_100lines_L2/']
['L2']
100
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_20lines_L1/']
['L1']
20
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_40lines_L1/']
['L1']
40
torch.Size([1, 256, 256])
torch.Size([1, 256, 256])
./BIRN_dataset/birn_png/
['./BIRN_dataset/birn_pngs_60lines_L1/']
['L1']
60
torch.Size([1, 256, 256])
