In [1]:
import torch as th
import torch.nn as nn
from PIL import Image
import numpy as np
from os import listdir
from os.path import isfile, join
import re
import random

In [2]:
class MyModel(nn.Module):
    def __init__(self, img_w, img_h):
        super(MyModel, self).__init__()
        in_channels = 3 # RGB
        out_channels = 1
        kernel = (3,3) # square kernel
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel, padding=1)
        self.rel = nn.ReLU()
        
        self.lin = nn.Linear(1, 1)
        self.act = nn.Sigmoid()
        
    def forward(self, data):
        out = self.conv(data)
        out = self.rel(out)
        out = out.flatten().unsqueeze(1)
        out = self.lin(out)
        return self.act(out)

In [3]:
def load_one_image(file):
    return Image.open(file)

In [4]:
def load_images_in_path(path):
    return {f: load_one_image(join(path, f)) for f in listdir(path) if isfile(join(path, f))}

In [5]:
train_img = load_images_in_path("/home/samuel/Documents/Cours/M2_AIC/Projet/train/images")
print("Nombre d'image de train : %d" % (len(train_img)))

Nombre d'image de train : 180


In [6]:
def get_town_name_list(img_dict):
    names = [f for f,img in img_dict.items()]
    names = list(map(lambda n: re.search("[A-Za-z]+", n).group(0), names))
    return set(names)

In [7]:
town_name_list = get_town_name_list(train_img)
print(town_name_list)

{'vienna', 'kitsap', 'tyrol', 'austin', 'chicago'}


In [8]:
def filter_img_dict(img_dict, limit):
    return {name: img for name, img in img_dict.items() if int(re.search("[0-9]+", name).group(0)) <= limit}

In [9]:
filtered_train_img = filter_img_dict(train_img, 9)
print(len(filtered_train_img))

45


In [10]:
def dict_to_sorted_list(img_dict):
    return sorted(img_dict.items(), key=lambda t: t[0])

In [11]:
sorted_train_img = dict_to_sorted_list(filtered_train_img)

gt_img = load_images_in_path("/home/samuel/Documents/Cours/M2_AIC/Projet/train/gt")
filtered_gt_img = filter_img_dict(gt_img, 9)
sorted_gt_img = dict_to_sorted_list(filtered_gt_img)

print(sorted_train_img[0])
print(sorted_gt_img[0])
print(sorted_train_img[-1])
print(sorted_gt_img[-1])

('austin1.jpg', <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2500x2500 at 0x7FACC0520B70>)
('austin1.jpg', <PIL.JpegImagePlugin.JpegImageFile image mode=L size=2500x2500 at 0x7FACC050AAC8>)
('vienna9.jpg', <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2500x2500 at 0x7FACC0549400>)
('vienna9.jpg', <PIL.JpegImagePlugin.JpegImageFile image mode=L size=2500x2500 at 0x7FACC04AD358>)


In [12]:
def img_sorted_list_to_numpy(sorted_img_list):
    return [np.expand_dims(np.moveaxis(np.asarray(img), -1, 0), axis=0) for _,img in sorted_img_list]

In [13]:
train_np = img_sorted_list_to_numpy(sorted_train_img)

In [14]:
print(train_np[0].shape)

(1, 3, 2500, 2500)


In [15]:
def gt_sorted_list_to_numpy(sorted_gt_img):
    return [np.asarray(gt).flatten().reshape(-1 ,1) for _,gt in sorted_gt_img]

In [16]:
gt_np = gt_sorted_list_to_numpy(sorted_gt_img)

In [17]:
print(gt_np[0].shape)

(6250000, 1)


```python
img = load_one_image("austin1.jpg")
label = load_one_image("austin1_label.jpg")
```

```python
type(img)
```

```python
arr = np.asarray(img)
label = np.asarray(label)
```

```python
print(arr.shape)
print(label.shape)
```

```python
w = arr.shape[0]
h = arr.shape[1]
c = arr.shape[2]
arr = arr.reshape(1, c, w, h) 
label = label.flatten().reshape(-1 ,1)

print(label.shape)
```

```python
print(arr.shape)
```

In [20]:
tmp = list(zip(train_np, gt_np))
random.shuffle(tmp)
train_np, gt_np = zip(*tmp)

In [None]:
nbEpoch = 30
learning_rate = 1e-2

model = MyModel(2500, 2500)
loss_fn = nn.BCELoss()
optim = th.optim.Adagrad(model.parameters(), lr=learning_rate)

for i in range(nbEpoch):
    model.train()
    sum_loss = 0
    for img, gt in zip(train_np, gt_np):
        optim.zero_grad()
        out = model(th.FloatTensor(img))
        y = th.FloatTensor(gt)
        loss = loss_fn(out, y)
        loss.backward()
        optim.step()
        sum_loss += loss.item()
    sum_loss /= len(train_np)
    print("Epoch %d, loss = %f" % (i, sum_loss))

Epoch 0, loss = -849.980221
