In [1]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import copy
import random

In [2]:
# Seed every random number generator used in this notebook so that a fresh
# run starts from the same state.
seed = 111
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True

In [3]:
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# CIFAR-10 train and test splits, stored under the "data" directory.
# No transform is passed, so examples are used exactly as torchvision returns them.
training, testing = (
    torchvision.datasets.CIFAR10(root="data", train=is_train, download=True)
    for is_train in (True, False)
)

In [5]:
def images_to_np(dataset):
    """Convert the image of every example into a float32 array divided by 255.

    Each example in ``dataset`` is indexed with ``[0]`` to obtain its image;
    any remaining fields of the example are ignored.

    Returns:
        A list with one float32 NumPy array per example, in dataset order.
    """
    return [np.array(sample[0], dtype="float32") / 255.0 for sample in dataset]

In [6]:
def modify_images(images, start=14, stop=18):
    """Return copies of ``images`` with a square patch blanked out (set to 0).

    The patch covers rows and columns ``start:stop`` across every channel.
    The defaults reproduce the 4x4 patch at 14:18 that the rest of the
    notebook uses.

    Args:
        images: iterable of NumPy arrays laid out as (height, width, channels).
        start: first row/column of the patch (inclusive).
        stop: last row/column of the patch (exclusive).

    Returns:
        A list of new arrays; the input arrays are left untouched.
    """
    modified_images = []
    for image in images:
        modified_image = image.copy()
        # A scalar fill works for any channel count, unlike a fixed (4, 4, 3) mask.
        modified_image[start:stop, start:stop, :] = 0.0
        modified_images.append(modified_image)
    return modified_images

In [7]:
def images_to_targets(images, start=14, stop=18):
    """Extract the square patch ``start:stop`` (rows and columns) from every image.

    The defaults select the 4x4 patch at 14:18 that ``modify_images`` blanks
    out by default, i.e. the pixels the network has to reconstruct.

    Args:
        images: iterable of NumPy arrays laid out as (height, width, channels).
        start: first row/column of the patch (inclusive).
        stop: last row/column of the patch (exclusive).

    Returns:
        A list of patches, one per image. Each patch is an independent copy
        rather than a view, so it neither aliases the source image nor keeps
        the full-size array alive once the caller drops it.
    """
    return [image[start:stop, start:stop, :].copy() for image in images]

In [8]:
def images_to_tensors(images):
    """Convert (height, width, channels) arrays into (1, channels, height, width) tensors.

    Every image becomes a float32 tensor with the channel axis moved to the
    front and a leading batch dimension of size 1, placed on the notebook's
    global ``device``.

    Args:
        images: iterable of 3-D NumPy arrays laid out as (height, width, channels).

    Returns:
        A list with one tensor per image, in input order.
    """
    tensors = []
    for image in images:
        # torch.tensor always copies its input, so no explicit copy is needed.
        tensor = torch.tensor(image, dtype=torch.float32)
        # unsqueeze keeps whatever height/width/channel count the image has,
        # instead of assuming a square 3-channel input.
        tensor = tensor.permute(2, 0, 1).unsqueeze(0)
        tensors.append(tensor.to(device))
    return tensors

In [9]:
def tensors_to_batches(tensors, batch_size=100):
    """Group per-image tensors into mini-batches of at most ``batch_size``.

    Args:
        tensors: list of tensors shaped (1, channels, height, width) or
            (channels, height, width); all must share the same shape.
        batch_size: maximum number of images per mini-batch (must be >= 1).

    Returns:
        A list of float32 tensors shaped (n, channels, height, width) on the
        notebook's global ``device``. Only the last mini-batch may be smaller
        than ``batch_size``. An empty input gives an empty list.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    mini_batches = []
    for start in range(0, len(tensors), batch_size):
        chunk = tensors[start:start + batch_size]
        # Drop a leading singleton batch dimension (if present) and stack, so the
        # batch is assembled on the tensors' own device for any channel count
        # and any (not necessarily square) spatial size.
        mini_batch = torch.stack([t.reshape(t.shape[-3:]) for t in chunk])
        mini_batches.append(mini_batch.to(device=device, dtype=torch.float32))
    return mini_batches

In [10]:
def image_processing(dataset, batch_size=100):
    """Run the whole preprocessing pipeline on ``dataset``.

    Steps: convert images to NumPy, build the modified (patch-blanked) inputs,
    extract the target patches, convert both to tensors, and group both into
    mini-batches of ``batch_size``.

    Returns:
        A 6-tuple, in this order: target patches (NumPy), modified images
        (NumPy), target tensors, modified-image tensors, target mini-batches,
        modified-image mini-batches.
    """
    full_images = images_to_np(dataset)
    modified_images = modify_images(full_images)
    # From here on only the target patches of the unmodified images are needed.
    targets = images_to_targets(full_images)

    target_tensors = images_to_tensors(targets)
    modified_tensors = images_to_tensors(modified_images)

    target_batches = tensors_to_batches(target_tensors, batch_size)
    modified_batches = tensors_to_batches(modified_tensors, batch_size)

    return (
        targets,
        modified_images,
        target_tensors,
        modified_tensors,
        target_batches,
        modified_batches,
    )

In [11]:
# Preprocess both splits with the default batch size.
(
    train_images,
    train_modified_images,
    train_tensors,
    train_modified_tensors,
    train_mini_batches,
    train_modified_mini_batches,
) = image_processing(training)

(
    test_images,
    test_modified_images,
    test_tensors,
    test_modified_tensors,
    test_mini_batches,
    test_modified_mini_batches,
) = image_processing(testing)

In [12]:
class NeuralNetwork(nn.Module):
    """Convolutional network that predicts the blanked-out patch of an image.

    The network is eight conv blocks (3 -> 64 -> 128 -> 256 -> 512 -> 256 ->
    128 -> 64 -> 3 channels) with one 2x2 max-pool after the fourth block and a
    final sigmoid. Every convolution uses a 3x3 kernel with Conv2d's default
    stride and padding, so a (N, 3, 32, 32) input produces a (N, 3, 4, 4)
    output with values in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            self.block(3, 64),
            self.block(64, 128),
            self.block(128, 256),
            self.block(256, 512),
            nn.MaxPool2d(2),
            self.block(512, 256),
            self.block(256, 128),
            self.block(128, 64),
            self.block(64, 3),
            nn.Sigmoid()
        )

    def block(self, x, y):
        """Build one Conv2d(x -> y, 3x3) -> ReLU -> BatchNorm2d(y) block.

        The block is returned without being stored on ``self``. Assigning it to
        an attribute would register it as an extra submodule and duplicate the
        last block's entries in ``state_dict()``.
        """
        return nn.Sequential(
            nn.Conv2d(x, y, 3),
            nn.ReLU(),
            nn.BatchNorm2d(y)
        )

    def forward(self, x):
        """Apply the full network to a batch ``x`` shaped (N, 3, H, W)."""
        return self.network(x)

In [13]:
# Build the network, then move its parameters and buffers to the selected device.
model = NeuralNetwork()
model = model.to(device)

In [14]:
# Plain SGD over all model parameters, minimising the mean absolute error
# between predicted and true patch pixels.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)
loss_fn = nn.L1Loss()

In [15]:
def image_example(modified_image_batches, modified_images, images, place, batch_size=100, start=14, stop=18):
    """Plot the masked input, the model's reconstruction, and the ground truth.

    Args:
        modified_image_batches: list of mini-batches of modified images, as fed
            to the model.
        modified_images: list of modified (patch-blanked) NumPy images.
        images: list of target patches (NumPy), one per image.
        place: index of the image to display.
        batch_size: mini-batch size used to build ``modified_image_batches``;
            needed to locate ``place`` inside the batches.
        start: first row/column of the patch (inclusive).
        stop: last row/column of the patch (exclusive).

    The global ``model`` is switched to eval mode for the forward pass and then
    restored to its previous mode. In training mode BatchNorm would normalise
    with the statistics of the displayed batch and update its running
    statistics from it.
    """
    batch_index, offset = divmod(place, batch_size)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # In eval mode samples are independent, so only the requested one is run.
            sample = modified_image_batches[batch_index][offset].unsqueeze(0)
            prediction = model(sample)[0].permute(1, 2, 0).cpu().numpy()
    finally:
        model.train(was_training)

    # Each panel gets its own array so nothing handed to imshow is mutated later.
    masked = copy.deepcopy(modified_images[place])
    predicted = copy.deepcopy(masked)
    predicted[start:stop, start:stop, :] = prediction
    real = copy.deepcopy(masked)
    real[start:stop, start:stop, :] = images[place]

    plt.figure(figsize=(16, 30))
    # NOTE: the "original" panel shows the masked model input, not the unmasked image.
    panels = (("original", masked), ("predicted", predicted), ("real", real))
    for position, (title, panel) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(panel)
    plt.show()

In [16]:
# Average distance between real and predicted pixels on the test set (0-255 scale)

def validation_distance():
    """Return the mean test-set loss scaled to 0-255 and rounded to 2 decimals.

    The function reads the globals ``model``, ``loss_fn``,
    ``test_modified_mini_batches`` and ``test_mini_batches``.

    * The model is evaluated in eval mode and then restored to its previous
      mode. In training mode BatchNorm would normalise with test-batch
      statistics and fold test data into its running statistics.
    * Per-batch losses are weighted by batch size, so a smaller final batch
      does not skew the average (assumes ``loss_fn`` returns a per-batch mean).
    * With no test batches the result is ``nan`` instead of a ZeroDivisionError.
    """
    was_training = model.training
    model.eval()
    weighted_loss = 0.0
    sample_count = 0
    try:
        with torch.no_grad():
            for batch, inputs in enumerate(test_modified_mini_batches):
                batch_len = len(inputs)
                loss = loss_fn(model(inputs), test_mini_batches[batch])
                weighted_loss += float(loss) * batch_len
                sample_count += batch_len
    finally:
        model.train(was_training)
    if sample_count == 0:
        return float("nan")
    return round(weighted_loss / sample_count * 255, 2)

In [17]:
# Training configuration.
epochs = 10
# Index of the training image re-plotted at the start of every epoch to track progress.
place = 32903
# Training distances, logged every 100 batches and plotted in a later cell.
data = []
batch_number = len(train_mini_batches)

for epoch in range(epochs):
    # Before each epoch: show the tracked example and report the test-set distance.
    image_example(train_modified_mini_batches, train_modified_images, train_images, place)
    val = validation_distance()
    print(f"Validation distance: {val:.2f}\n")

    for batch in range(batch_number):
        inputs = train_modified_mini_batches[batch]
        target = train_mini_batches[batch]

        # Forward pass, then clear stale gradients before backpropagating.
        res = model(inputs)
        loss = loss_fn(res, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 != 0:
            continue
        # average distance between true and predicted pixels
        distance = round(loss.item() * 255, 2)
        print(f"epoch: {epoch:3}  |  batch number: {batch:3}  |  distance: {distance:.2f}")
        data.append(distance)

In [121]:
# Training distance (0-255 scale) as logged every 100 batches during training.
plt.plot(list(data))
plt.title("training distance")
plt.show()

In [123]:
# Example prediction on the test set: masked input, reconstruction, ground truth.

example = 4116
image_example(
    modified_image_batches=test_modified_mini_batches,
    modified_images=test_modified_images,
    images=test_images,
    place=example,
)