In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from torch import nn
import os
# from torch.utils.data import DataLoader
# from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from sklearn.model_selection import train_test_split


In [9]:
class Net(nn.Module):
    """Small CNN that maps one 7 x 22 x 22 window to 2 class probabilities.

    Shape trace (channels x height x width) for a single window:
        input   7 x 22 x 22
        conv1  16 x 10 x 10   (kernel 4, stride 2)
        pool   16 x  9 x  9   (max-pool kernel 2, stride 1)
        conv2  32 x  4 x  4   (kernel 3, stride 2) -> 512 flattened features
    then dropout and three linear layers 512 -> 120 -> 84 -> 2.
    forward() returns a softmax over the 2 outputs, so each row sums to 1.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys of saved checkpoints - keep them.
        self.conv1 = nn.Conv2d(7, 16, 4, stride=2)
        self.pool = nn.MaxPool2d(2, stride=1)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2)
        self.drop = nn.Dropout(p=0.15)
        self.fc1 = nn.Linear(32 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        features = self.pool(F.relu(self.conv1(x)))
        features = F.relu(self.conv2(features))
        flat = self.drop(features.flatten(start_dim=1))
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits.softmax(dim=1)
    
def retrieve_images(fil_dir):
    """Load every file in `fil_dir` with np.load and stack the results.

    Files are visited in ``os.listdir`` order (not sorted). Each loaded array
    is copied into a preallocated float64 buffer slot of shape (7, 40, 500),
    so every file must hold an array assignable to that shape.

    Returns
    -------
    np.ndarray
        float64 array of shape (n_files, 7, 40, 500).
    """
    names = os.listdir(fil_dir)
    stack = np.zeros((len(names), 7, 40, 500))
    for slot, name in enumerate(names):
        stack[slot] = np.load(os.path.join(fil_dir, name))
    return stack[:len(names)]

def retrieve_masks(fil_dir):
    """Load mask files from `fil_dir`, skipping names whose characters [5:8] are "000".

    Kept files are visited in ``os.listdir`` order. Each must hold an array
    assignable to shape (40, 500); the values are copied into a float64 buffer.

    Returns
    -------
    np.ndarray
        float64 array of shape (n_kept, 1, 40, 500).
    """
    kept = [name for name in os.listdir(fil_dir) if name[5:8] != "000"]
    masks = np.zeros((len(kept), 40, 500))
    for slot, name in enumerate(kept):
        masks[slot] = np.load(os.path.join(fil_dir, name))
    # Add a singleton channel axis.
    return masks.reshape(-1, 1, 40, 500)

def get_windowed_data(imgs, sq_size=22, pix_steps=2):
    """Cut every image of a batch into sliding windows and stack them.

    Parameters
    ----------
    imgs : np.ndarray
        Batch of images, shape (n_images, channels, height, width).
    sq_size : int, default 22
        Side length of each square window (forwarded to ``sliding_window``).
    pix_steps : int, default 2
        Stride between window top-left corners (forwarded to ``sliding_window``).

    Returns
    -------
    np.ndarray
        float64 array of shape (n_images * windows_per_image, channels,
        sq_size, sq_size). The windows of image k form the k-th contiguous
        run of length windows_per_image. With the defaults and 40 x 500
        images this is 10 * 240 = 2400 windows per image.
    """
    n_imgs, n_chan, height, width = imgs.shape
    # Derive the per-image window count from the image size so the buffer
    # matches sliding_window's output for any input size, not only 40 x 500.
    n_rows = len(range(0, height - sq_size + 1, pix_steps))
    n_cols = len(range(0, width - sq_size + 1, pix_steps))
    windows_per_im = n_rows * n_cols
    fin_data = np.zeros((n_imgs * windows_per_im, n_chan, sq_size, sq_size))
    if windows_per_im == 0:
        # Images smaller than one window: nothing to fill.
        return fin_data
    for ind in range(n_imgs):
        start = ind * windows_per_im
        fin_data[start:start + windows_per_im] = sliding_window(
            imgs[ind], sq_size=sq_size, pix_steps=pix_steps)
    return fin_data

def sliding_window(image, sq_size = 22, pix_steps=2):
    """Cut square windows out of a (channels, height, width) image.

    Windows are sq_size x sq_size, with top-left corners every `pix_steps`
    pixels. They are ordered row-major: the vertical position is the outer
    loop and the horizontal position is the inner loop. Only windows that
    lie entirely inside the image are produced.

    Returns
    -------
    np.ndarray
        Array of shape (n_windows, channels, sq_size, sq_size) with the same
        dtype as `image`. n_windows is 0 if the image is smaller than a window.
    """
    height, width = image.shape[1], image.shape[2]
    # The last valid top-left corner is (dim - sq_size), so the stop is
    # dim - sq_size + 1. A stop of dim - sq_size + pix_steps would emit a
    # truncated edge window whenever (dim - sq_size) is not a multiple of
    # pix_steps, producing ragged windows.
    rows = range(0, height - sq_size + 1, pix_steps)
    cols = range(0, width - sq_size + 1, pix_steps)
    windows = [image[:, i:i + sq_size, j:j + sq_size] for i in rows for j in cols]
    if not windows:
        # Keep a 4-D result so callers can still index/stack it.
        return np.empty((0, image.shape[0], sq_size, sq_size), dtype=image.dtype)
    return np.array(windows)


In [8]:
# eval() disables Dropout (modules start in training mode), so predictions
# are deterministic. Module.eval() returns the module itself.
net = Net().eval()
# map_location="cpu": inference below converts outputs with .numpy(), which needs CPU tensors.
net.load_state_dict(torch.load("C:/Users/laure/Documents/proj_loopr/models_2_21_6_dropout_43tp_93tn/trained_net_70_epo.pth", map_location="cpu"))

<All keys matched successfully>

In [9]:
# Load every processed image (os.listdir order) into one float64 array.
img_list = retrieve_images(fil_dir="C:/Users/laure/Documents/proj_loopr/processed_images")

In [10]:
# Expected (n_files, 7, 40, 500), the per-file shape used by retrieve_images.
img_list.shape

(104, 7, 40, 500)

In [11]:
# Spot-check a single image. The slice keeps a leading batch axis: (1, 7, 40, 500).
im_test = img_list[82:83]
# get_windowed_data takes only the image batch; the extra mask argument
# passed previously raised a TypeError.
im_test_win = get_windowed_data(im_test)

In [12]:
# The network's layers are float32, so convert the float64 windows.
im_test_torch = torch.from_numpy(im_test_win).float()

In [13]:
# eval() turns off Dropout so predictions are deterministic. no_grad() skips
# building an autograd graph, which inference does not need.
net.eval()
with torch.no_grad():
    outs_np = net(im_test_torch).numpy()

In [14]:
# One row of 2 softmax outputs per window (2400 windows for one 40 x 500 image).
outs_np.shape

(2400, 2)

In [26]:
# File names in processed_images, in os.listdir order, excluding names whose
# characters [5:8] are "000".
# NOTE(review): retrieve_images does not apply this filter, so rl_list[k] names
# img_list[k] only if no such "000" files exist. Confirm before relying on it.
sample_list = os.listdir("C:/Users/laure/Documents/proj_loopr/processed_images")
rl_list = [name for name in sample_list if name[5:8] != "000"]


In [27]:
# File name at index 82, the image spot-checked earlier.
# NOTE(review): rl_list excludes "000" names while img_list loads every file,
# so this may not be the file behind img_list[82]. Confirm.
rl_list[82]

'0086_030_02.npy'

In [22]:
# Scratch check of the transpose+reshape applied to network output: a fake
# (10, 2) array whose two columns are both 0..9 is transposed to (2, 10) and
# split into (2, 2, 5). Each class channel should stay intact.
np.stack([np.arange(10)] * 2, axis=1).T.reshape((2, 2, 5))[:, 0, :]

array([[0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4]])

In [10]:
# Run the trained network over every processed image and save a per-window
# probability map for each one.
img_list = retrieve_images("C:/Users/laure/Documents/proj_loopr/processed_images")
# retrieve_images loads files in os.listdir order. Listing the same unchanged
# directory again gives the matching names to save under.
name_list = os.listdir("C:/Users/laure/Documents/proj_loopr/processed_images")
assert len(name_list) == img_list.shape[0], "directory changed between listings"

# eval(): Dropout must be off at inference time (modules start in training mode).
net = Net().eval()
net.load_state_dict(torch.load("C:/Users/laure/Documents/proj_loopr/models_2_21_6_dropout_43tp_93tn/trained_net_70_epo.pth", map_location="cpu"))

out_dir = "C:/Users/laure/Documents/proj_loopr/windowed_output"
os.makedirs(out_dir, exist_ok=True)

# sliding_window with its defaults on a 40 x 500 image gives 10 vertical x 240
# horizontal window positions, enumerated row-major (vertical outer loop).
n_win_rows, n_win_cols = 10, 240

with torch.no_grad():  # inference only; no autograd graph needed
    for el, name in enumerate(name_list):
        im_win = get_windowed_data(img_list[el:el + 1])
        im_win = torch.from_numpy(im_win).type(torch.FloatTensor)
        outs_np = net(im_win).numpy()  # (2400, 2): one softmax pair per window
        # Reshape to (2, rows, cols) to match the row-major window order, then
        # swap the spatial axes to keep the saved (2, 240, 10) layout, indexed
        # [class, horizontal position, vertical position]. Reshaping directly
        # to (2, 240, 10) would scramble window positions.
        outs_np = outs_np.T.reshape((2, n_win_rows, n_win_cols)).transpose(0, 2, 1)
        np.save(os.path.join(out_dir, name), outs_np)