In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
import numpy as np
import yaml
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image, ImageMath
from pycocotools.coco import COCO
import os
import skimage.io as io
import matplotlib.pyplot as plt
import yaml
from glob import glob
import imageio
from copy import deepcopy
import cv2
import torchvision.models as models
from utils import tri_mirror
import splitfolders 
import shutil

In [66]:
len(os.listdir(train_path)), len(os.listdir(val_path))

(139, 35)

In [67]:
os.listdir(data_dir)

['val', 'train']

In [52]:
data_dir = "/home/yyelisieiev/luftr_data/CroppedNewRivets/"
train_path = os.path.join(data_dir, "train")
val_path = os.path.join(data_dir, "val")

In [58]:
file_list = glob(os.path.join(data_dir, "*.png"))

In [60]:
len(file_list)

174

In [61]:
def list_splitter(list_to_split, ratio):
    elements = len(list_to_split)
    middle = int(elements * ratio)
    return [list_to_split[:middle], list_to_split[middle:]]

In [62]:
train, val = list_splitter(file_list, 0.8)

In [63]:
for image in train:
    image_name = image.split("/")[-1]
    shutil.move(image, os.path.join(train_path, image_name))

In [64]:
for image in val:
    image_name = image.split("/")[-1]
    shutil.move(image, os.path.join(val_path, image_name))

In [22]:
map(lambda x: shutil.move(os.path.join(data_dir, x), os.path.join(train_path, x)), train)

<map at 0x7f035748a6a0>

In [23]:
map(lambda x: shutil.move(os.path.join(data_dir, x), os.path.join(val_path, x)), val)

<map at 0x7f03571de520>

In [9]:
splitfolders.ratio(data_dir, seed=1337, ratio=(.8, .2), group_prefix=None)

Copying files: 0 files [00:00, ? files/s]


In [14]:
images = imageio.mimread("../new_logs/ReLUMaxPoolingQuaterCrop/summaries/8.gif", )

In [15]:
print(images.shape)

AttributeError: 'list' object has no attribute 'shape'

In [4]:
class RivetDataset(Dataset):
    ANN_DIR = "annotations"
    CATEGORY = "rivet"
    
    def __init__(self, config, transform, validation=False):
        self.root_dir = config['data']['data_dir']
        self.image_size = config['data']['image_size']
        self.image_type = config['data']['image_type']
        self.transform = transform
        self.img_channels = 1

        self.file_list = glob(os.path.join(self.root_dir, "*"))
        self.feature_extractor = models.vgg16(pretrained=True).features.cuda()
        self.feature_extractor.eval()

    
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        X = dict()
        img_path = self.file_list[idx]
        rivet = Image.open(img_path).resize((self.image_size, self.image_size))
#         print(rivet.shape)
        center = self.image_size // 2
        cv = self.image_size // 4
        
        if self.image_type == "Corrupted":
            rivet = np.array(rivet).astype(np.float32)
            center_rivet = deepcopy(rivet[center - cv: center + cv, center - cv: center + cv])
            rivet[center - cv: center + cv, center - cv: center + cv] = 0.0

            rivet_rgb = np.zeros((rivet.shape[0], rivet.shape[0], 3))
            rivet_rgb[:,:, 0] = rivet
            rivet_rgb[:,:, 1] = rivet
            rivet_rgb[:,:, 2] = rivet
            rivet = Image.fromarray((rivet_rgb*255).astype(np.uint8))
            center_rivet = Image.fromarray(center_rivet)
                
        elif self.image_type == "TriMir":
            rivet = tri_mirror(rivet, center, cv)
            
        rivet = self.transform(rivet).unsqueeze(0).cpu()
#         rivet = self.feature_extractor(rivet).view(-1, 1).detach().cpu().flatten()
        center_rivet = self.transform(center_rivet)

        rivet = rivet.unsqueeze(0)
        center_rivet = center_rivet.unsqueeze(0)

        X['Rivet'] = rivet
        X['Center'] = center_rivet
        
        return X

In [82]:
class CropConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, overlap, bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.overlap = overlap
        
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / np.sqrt(fan_in)
            torch.nn.init.torch.nn.init.uniform_(self.bias, -bound, bound)
    
    def forward(self, x):
        batch_size, in_channels, in_h, in_w = x.shape

        crop_start = (in_h // 4) - self.overlap
        start_idx = in_h * crop_start + crop_start
        crop_size = (in_h // 2) + self.overlap * 2

        out_h = ((in_h - self.kernel_size + 2 * self.padding) //self.stride + 1)
        out_w = ((in_w - self.kernel_size + 2 * self.padding) //self.stride + 1)

        unfold = torch.nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size), dilation=self.dilation, padding=self.padding, stride=self.stride)
        inp_unf = unfold(x)
        crop_lst = []
        for i in range(crop_size):
            if i == 0:
                crop_lst.append(torch.ones([batch_size, self.out_channels, inp_unf[:, :, :start_idx].shape[2]], dtype=torch.bool))
            if i == crop_size - 1:
                crop_lst.append(torch.zeros([batch_size, self.out_channels, inp_unf[:, :, start_idx:start_idx + crop_size].shape[2]], dtype=torch.bool))
                crop_lst.append(torch.ones([batch_size, self.out_channels, inp_unf[:, :, start_idx+crop_size:(in_h**2)].shape[2]], dtype=torch.bool))
                break
            crop_lst.append(torch.zeros([batch_size, self.out_channels, inp_unf[:, :, start_idx:start_idx + crop_size].shape[2]], dtype=torch.bool))
            crop_lst.append(torch.ones([batch_size, self.out_channels, inp_unf[:, :, start_idx + crop_size: start_idx + in_h].shape[2]], dtype=torch.bool))
            start_idx += in_h

        crop_indexes = torch.cat(crop_lst, axis=2)
        out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t())

        if self.bias is None:
            out_unf = out_unf.transpose(1, 2)
        else:
            print(out_unf.shape)
            print(self.bias.shape)
            out_unf = (out_unf + self.bias).transpose(1, 2)
        out_unf = torch.where(crop_indexes, out_unf, torch.zeros(out_unf.shape, dtype=torch.float32))
        print(out_unf.shape)
        print(batch_size, self.out_channels, out_h, out_w)
        out = out_unf.view(batch_size, self.out_channels, out_h, out_w)
        return out

In [122]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            CropConv(in_channels=3, out_channels=128, kernel_size=3, stride=2, padding=1, dilation=1, overlap=0),
            nn.ReLU(),
            CropConv(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1, dilation=1, overlap=0),
            nn.ReLU(),
            CropConv(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, dilation=1, overlap=0),
            nn.ReLU(),
			CropConv(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, dilation=1, overlap=0),
            nn.ReLU(),
            CropConv(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, dilation=1, overlap=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2))
    def forward(self, x):
        encoded = self.encoder(x)
        return encoded

In [123]:
x = torch.randn(1, 1, 16, 16)

In [124]:
my_conv = CropConv(in_channels=1, out_channels=4, kernel_size=3, stride=2, padding=1, dilation=1, overlap=0, bias=True)

In [125]:
my_conv(x).shape

Crop size:  8
torch.Size([1, 64, 4])
torch.Size([4])
torch.Size([1, 4, 64])
1 4 8 8


torch.Size([1, 4, 8, 8])

In [126]:
(16 - 3 + 2)/2 + 1
# ((in_s - kernel_size + 2 * padding)/stride) + 1

8.5

In [127]:
encoder = Encoder()

In [128]:
with open("config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
    
config['data']["data_dir"] = "/home/yyelisieiev/luftr_data/cropped_rivets"

transform = transforms.Compose([transforms.ToTensor()])

dataset = RivetDataset(config, transform, validation=False)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

In [129]:
for index, inputs in enumerate(dataloader):
    model_input = inputs['Rivet'][0][0]
    print("Model input: ", model_input.shape)
    out = encoder(model_input)
    print(out.shape)
    break
#         im.save(os.path.join(rivet_dir, f"{index}_{image_idx}_rivet.tif"))

Model input:  torch.Size([1, 3, 64, 64])
Crop size:  32
torch.Size([1, 128, 1024])
1 128 32 32
Crop size:  16
torch.Size([1, 256, 256])
1 256 16 16
Crop size:  8
torch.Size([1, 256, 64])
1 256 8 8
Crop size:  4
torch.Size([1, 256, 16])
1 256 4 4
Crop size:  2
torch.Size([1, 256, 4])
1 256 2 2
torch.Size([1, 256, 1, 1])


In [131]:
print(len(np.unique(out.detach().numpy())))

204
