In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, random_split
from torchvision.utils import make_grid
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import matplotlib as mpl
import seaborn as sns

# torch device handle, hard-coded to the CPU (not referenced by the cells visible here).
device = torch.device("cpu") 

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Custom Dataset for loading image data from the ../data/ folder

class DASImageDataset(Dataset):
    """Dataset over every entry found directly inside ``root_dir``.

    Each item is an ``(image, filename)`` pair, where ``filename`` is the bare
    directory entry name (not the joined path).

    Args:
        root_dir: Directory whose entries are all opened as images.
        transform: Optional callable applied to the opened PIL image. Its
            result is cast with ``.type(torch.LongTensor)``, so it must return
            a tensor. When falsy, the PIL image itself is returned.
            NOTE(review): with the default ``ToTensor`` the long cast would
            truncate fractional pixel values -- confirm this is intended.
    """

    def __init__(self, root_dir, transform=transforms.ToTensor()):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.files = sorted(os.listdir(self.root_dir))
        self.transform = transform

    def __len__(self):
        """Return the number of entries listed in ``root_dir``."""
        return len(self.files)

    def __getitem__(self, index):
        """Open the ``index``-th file and return ``(image, filename)``."""
        filename = self.files[index]
        img_file = os.path.join(self.root_dir, filename)

        # Read the pixel data eagerly and release the file handle right away;
        # the image object stays usable after the ``with`` block.
        with Image.open(img_file) as img:
            img.load()

        if self.transform:
            img = self.transform(img).type(torch.LongTensor)

        return img, filename

#### Define a Lambda transform to feed into the transform pipeline and apply correlation filtering

In [None]:
# Define a custom Lambda transform (plus helpers) that passes the patches and thresholds to the filter function
def center_image(image):
    """Return a zero-mean, unit-L2-norm copy of ``image``.

    Subtracting the mean and dividing by the norm reduces the effects of
    brightness and contrast.

    Args:
        image: Array providing ``mean()`` and ``reshape()``; any shape.

    Returns:
        Array of the same shape as ``image``. A constant input has zero norm
        after centering and is returned as all zeros instead of being divided
        by zero (which would yield NaNs and a RuntimeWarning).
    """
    centered = image - image.mean()
    norm = np.linalg.norm(centered.reshape(-1))
    if norm == 0:
        # Flat input: nothing to scale, and 0 / 0 would produce NaN.
        return centered
    return centered / norm
class MyLambda(transforms.Lambda):
    """Lambda transform that carries two patches and two thresholds.

    On each call the stored values are appended after the image, i.e. the
    wrapped callable is invoked as
    ``lambd(img, patchr, patchl, thresr, thresl)``.
    """

    def __init__(self, lambd, patchr, patchl, thresr, thresl):
        super().__init__(lambd)
        self.patchr, self.patchl = patchr, patchl
        self.thresr, self.thresl = thresr, thresl

    def __call__(self, img):
        """Apply the wrapped callable to ``img`` plus the stored arguments."""
        extra_args = (self.patchr, self.patchl, self.thresr, self.thresl)
        return self.lambd(img, *extra_args)

def _thresholded_xcorr(image, patch, thres):
    """Slide ``patch`` over edge-padded ``image``; mark scores >= ``thres`` with 255."""
    patch_h, patch_w = patch.shape
    # Pad each axis by that axis' own half-width so non-square patches are
    # handled correctly (identical to the previous behavior for square ones).
    pad_h, pad_w = (patch_h - 1) // 2, (patch_w - 1) // 2
    padded = np.pad(image, ((pad_h, pad_h), (pad_w, pad_w)), 'edge')

    out_h = padded.shape[0] - patch_h + 1
    out_w = padded.shape[1] - patch_w + 1

    # The normalized patch does not depend on the window position, so it is
    # computed once rather than inside the double loop.
    patch_vec = center_image(patch.reshape(-1))

    scores = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            window = center_image(padded[i:i + patch_h, j:j + patch_w]).reshape(-1)
            scores[i, j] = np.dot(patch_vec, window)

    return np.where(scores >= thres, 255, 0)


def corr_filter(image, patchr, patchl, thresr, thresl):
    """Threshold the normalized correlation of ``image`` with two patches.

    For each patch, every patch-sized window of the edge-padded image is
    mean-centered and L2-normalized (via ``center_image``) and dotted with the
    equally normalized patch; positions scoring at or above that patch's
    threshold are set to 255, all others to 0. The two masks are then added.

    Args:
        image: 2-D image (anything ``np.array`` converts to a 2-D array).
        patchr: 2-D template array compared against ``thresr``.
        patchl: 2-D template array compared against ``thresl``.
        thresr: Score threshold for ``patchr``.
        thresl: Score threshold for ``patchl``.

    Returns:
        Integer array holding 0 (no match), 255 (one patch matched) or 510
        (both matched). For odd patch sizes it has the shape of ``image``; an
        even patch dimension makes that axis one element shorter, so both
        patches should share the same shape for the masks to add up.

    Raises:
        ValueError: If ``image`` is not 2-D.
    """
    image = np.array(image)
    if image.ndim != 2:
        raise ValueError(
            "corr_filter expects a 2-D image, got %d dimension(s)" % image.ndim
        )

    mask_r = _thresholded_xcorr(image, patchr, thresr)
    mask_l = _thresholded_xcorr(image, patchl, thresl)
    return mask_r + mask_l

In [None]:
# Load the template patch from disk (feel free to swap in a different patch file).
# The file must be opened for reading: mode 'wb' would truncate it to zero
# bytes before np.load could read anything.
with open('../data/patchr.npy', 'rb') as f:
    patchr = np.load(f)
    

In [None]:
# Tunable settings for this cell.
TRAIN_FRACTION = 0.8  # share of the images assigned to the training split
BATCH_SIZE = 32
SPLIT_SEED = 42  # fixed seed so random_split draws the same index partition on every run

# No Resize step is needed: images are already resized during pre-processing.
transform = transforms.Compose([
    # Correlation filter with the patch (threshold 0.3) and its left-right
    # mirror image (threshold 0.4).
    MyLambda(corr_filter, patchr, np.fliplr(patchr), 0.3, 0.4),
    transforms.ToTensor()
])

images = DASImageDataset(root_dir='./data/processed/', transform=transform)

train_len = int(TRAIN_FRACTION * len(images))
test_len = len(images) - train_len

# Without an explicit generator the split would change on every kernel restart.
train_dataset, test_dataset = random_split(
    images,
    [train_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=RandomSampler(train_dataset))
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, sampler=RandomSampler(test_dataset))