In [3]:
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms#, datasets

# Image processing module:
from PIL import Image, ImageCms

In [4]:
# Open image and discard alpha channel which makes wheel round rather than square
img_arr = Image.open("domddcz.png").convert('RGB') #without convert it's set to "P".

In [5]:
# Convert the picture into Lab colourspace
srgb_p = ImageCms.createProfile("sRGB")
lab_p  = ImageCms.createProfile("LAB")

rgb2lab = ImageCms.buildTransformFromOpenProfiles(srgb_p, lab_p, "RGB", "LAB")
Lab = ImageCms.applyTransform(img_arr, rgb2lab) #converted picture.

In [6]:
image_lightness = Lab.getchannel('L') #get the Lightness channel.

In [7]:
image = transforms.ToTensor()(image_lightness) 
#Lab_test2 = Lab_test.view(-1, 1, 224, 224)

In [8]:
def weights_init(model):
    if type(model) in [nn.Conv2d, nn.Linear]:
        nn.init.xavier_normal_(model.weight.data)
        nn.init.constant_(model.bias.data, 0.1)
class Color_network(nn.Module):

    def __init__(self):
        super(Color_network, self).__init__()
        self.features = nn.Sequential(
            # conv1
            nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 2, padding = 1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features = 64),
            # conv2
            nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 2, padding = 1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features = 128),
            # conv3
            nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 2, padding = 1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features = 256),
            # conv4
            nn.Conv2d(in_channels = 256, out_channels = 512, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features = 512),
            # conv5
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 2, dilation = 2),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 2, dilation = 2),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 2, dilation = 2),
            nn.ReLU(),
            nn.BatchNorm2d(num_features = 512),
            # conv6
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 2, dilation = 2),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 2, dilation = 2),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 2, dilation = 2),
            nn.ReLU(),
            nn.BatchNorm2d(num_features = 512),
            # conv7
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1, dilation = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1, dilation = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1, dilation = 1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features = 512),
            # conv8
            nn.ConvTranspose2d(in_channels = 512, out_channels = 256, kernel_size = 4, stride = 2, padding = 1, dilation = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1, dilation = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1, dilation = 1),
            nn.ReLU(),
            # conv8_313
            nn.Conv2d(in_channels = 256, out_channels = 224, kernel_size = 1, stride = 1,dilation = 1),
            #eltwise product
            # softmax
            #nn.Softmax(dim=1),
            # decoding
            #nn.Conv2d(in_channels = 313, out_channels = 2, kernel_size = 1, stride = 1)
        )
        self.apply(weights_init)
        
        # kernel
        #self.conv1 = nn.Conv2d(1, 8, 3) # 1 input image channel, 8 output channels, 3x3 square convolution
        
        #self.conv2 = nn.Conv2d(8, 8, 3, stride=2)
        #self.conv3 = nn.Conv2d(8, 16, 3)
        #self.conv4 = nn.Conv2d(16, 16, 3)
        #self.conv5 = nn.Conv2d(16, 16, 3, stride=2)
        #self.conv6 = nn.Conv2d(16, 16, 3)
        #self.conv7 = nn.Conv2d(16, 32, 3, stride=2)
        #self.conv8 = nn.Conv2d(32, 64, 3)
        #self.conv9 = nn.Conv2d(64, 128, 3)
        #self.conv10 = nn.Conv2d(128, 128, 3, stride=2)
        #self.conv11 = nn.Conv2d(128, 64, 3)
        #self.conv12 = nn.Conv2d(64, 32, 3)
        #self.conv13 = nn.Conv2d(32, 16, 3)
        #self.conv14 = nn.Conv2d(16, 2, 3) #Two output channels for the "ab" channels.

    # why would the value update on every move forward into the next layer?
    def forward(self, x):
        # Zhang doesn't use pooling. Up/down-sampling it is.
        # Max pooling over a (2, 2) window #x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        #x = x.view(-1, self.num_flat_features(x))# - interesting part
        print('x',x.size())
        features=self.features(x)
        x=x/0.38
        return x
        
        #x = F.relu(self.conv1(x))
        #x = F.relu(self.conv2(x))
        #x = F.relu(self.conv3(x))
        #x = F.relu(self.conv4(x))
        #x = F.relu(self.conv5(x))
        #x = F.relu(self.conv6(x))
        ##x = F.interpolate(self.conv6(x), size=2) #conv6 -32 size
        #x = F.relu(self.conv7(x))
        ##x = F.interpolate(self.conv7(x), size=2) #conv7 -32 size
        #x = F.relu(self.conv8(x))
        ##x = F.interpolate(self.conv8(x), size=2) #conv8 -32 size
        ##x = torch.tanh(self.conv7(x)) #tanh instead of relu for final layer.
        #return x

In [9]:
model = Color_network()

In [10]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss(reduction='none').cuda()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, lr = 0.001) #args.learning_rate

In [11]:
print(image.shape)

torch.Size([1, 224, 224])


In [21]:
#image = image.unsqueeze(1).float()
#Lab_test = Lab_test.squeeze(1).float()
#Lab_test = transforms.ToTensor()(image_lightness) 

#for i in range(100):
#    outputs = model(image_lightness)#.log()
    
# Train the models
#total_step = len(data_loader)
#image = image.view(-1, 1, 224, 224) 
#image_lightness = image_lightness.unsqueeze(0).float().cuda()
for epoch in range(1000):
    outputs = model(image)#.log()
    #output=outputs[0].cpu().data.numpy()

    model.zero_grad()

    #loss.backward()
    optimizer.step()
    print(epoch)

x torch.Size([1, 1, 224, 224])
0
x torch.Size([1, 1, 224, 224])
1
x torch.Size([1, 1, 224, 224])
2
x torch.Size([1, 1, 224, 224])
3
x torch.Size([1, 1, 224, 224])
4
x torch.Size([1, 1, 224, 224])
5
x torch.Size([1, 1, 224, 224])
6
x torch.Size([1, 1, 224, 224])
7
x torch.Size([1, 1, 224, 224])
8
x torch.Size([1, 1, 224, 224])
9
x torch.Size([1, 1, 224, 224])
10
x torch.Size([1, 1, 224, 224])
11
x torch.Size([1, 1, 224, 224])
12
x torch.Size([1, 1, 224, 224])
13
x torch.Size([1, 1, 224, 224])
14
x torch.Size([1, 1, 224, 224])
15
x torch.Size([1, 1, 224, 224])
16
x torch.Size([1, 1, 224, 224])
17
x torch.Size([1, 1, 224, 224])
18
x torch.Size([1, 1, 224, 224])
19
x torch.Size([1, 1, 224, 224])
20
x torch.Size([1, 1, 224, 224])
21
x torch.Size([1, 1, 224, 224])
22
x torch.Size([1, 1, 224, 224])
23
x torch.Size([1, 1, 224, 224])
24
x torch.Size([1, 1, 224, 224])
25
x torch.Size([1, 1, 224, 224])
26
x torch.Size([1, 1, 224, 224])
27
x torch.Size([1, 1, 224, 224])
28
x torch.Size([1, 1, 224,

237
x torch.Size([1, 1, 224, 224])
238
x torch.Size([1, 1, 224, 224])
239
x torch.Size([1, 1, 224, 224])
240
x torch.Size([1, 1, 224, 224])
241
x torch.Size([1, 1, 224, 224])
242
x torch.Size([1, 1, 224, 224])
243
x torch.Size([1, 1, 224, 224])
244
x torch.Size([1, 1, 224, 224])
245
x torch.Size([1, 1, 224, 224])
246
x torch.Size([1, 1, 224, 224])
247
x torch.Size([1, 1, 224, 224])
248
x torch.Size([1, 1, 224, 224])
249
x torch.Size([1, 1, 224, 224])
250
x torch.Size([1, 1, 224, 224])
251
x torch.Size([1, 1, 224, 224])
252
x torch.Size([1, 1, 224, 224])
253
x torch.Size([1, 1, 224, 224])
254
x torch.Size([1, 1, 224, 224])
255
x torch.Size([1, 1, 224, 224])
256
x torch.Size([1, 1, 224, 224])
257
x torch.Size([1, 1, 224, 224])
258
x torch.Size([1, 1, 224, 224])
259
x torch.Size([1, 1, 224, 224])
260
x torch.Size([1, 1, 224, 224])
261
x torch.Size([1, 1, 224, 224])
262
x torch.Size([1, 1, 224, 224])
263
x torch.Size([1, 1, 224, 224])
264
x torch.Size([1, 1, 224, 224])
265
x torch.Size([1,

472
x torch.Size([1, 1, 224, 224])
473
x torch.Size([1, 1, 224, 224])
474
x torch.Size([1, 1, 224, 224])
475
x torch.Size([1, 1, 224, 224])
476
x torch.Size([1, 1, 224, 224])
477
x torch.Size([1, 1, 224, 224])
478
x torch.Size([1, 1, 224, 224])
479
x torch.Size([1, 1, 224, 224])
480
x torch.Size([1, 1, 224, 224])
481
x torch.Size([1, 1, 224, 224])
482
x torch.Size([1, 1, 224, 224])
483
x torch.Size([1, 1, 224, 224])
484
x torch.Size([1, 1, 224, 224])
485
x torch.Size([1, 1, 224, 224])
486
x torch.Size([1, 1, 224, 224])
487
x torch.Size([1, 1, 224, 224])
488
x torch.Size([1, 1, 224, 224])
489
x torch.Size([1, 1, 224, 224])
490
x torch.Size([1, 1, 224, 224])
491
x torch.Size([1, 1, 224, 224])
492
x torch.Size([1, 1, 224, 224])
493
x torch.Size([1, 1, 224, 224])
494
x torch.Size([1, 1, 224, 224])
495
x torch.Size([1, 1, 224, 224])
496
x torch.Size([1, 1, 224, 224])
497
x torch.Size([1, 1, 224, 224])
498
x torch.Size([1, 1, 224, 224])
499
x torch.Size([1, 1, 224, 224])
500
x torch.Size([1,

707
x torch.Size([1, 1, 224, 224])
708
x torch.Size([1, 1, 224, 224])
709
x torch.Size([1, 1, 224, 224])
710
x torch.Size([1, 1, 224, 224])
711
x torch.Size([1, 1, 224, 224])
712
x torch.Size([1, 1, 224, 224])
713
x torch.Size([1, 1, 224, 224])
714
x torch.Size([1, 1, 224, 224])
715
x torch.Size([1, 1, 224, 224])
716
x torch.Size([1, 1, 224, 224])
717
x torch.Size([1, 1, 224, 224])
718
x torch.Size([1, 1, 224, 224])
719
x torch.Size([1, 1, 224, 224])
720
x torch.Size([1, 1, 224, 224])
721
x torch.Size([1, 1, 224, 224])
722
x torch.Size([1, 1, 224, 224])
723
x torch.Size([1, 1, 224, 224])
724
x torch.Size([1, 1, 224, 224])
725
x torch.Size([1, 1, 224, 224])
726
x torch.Size([1, 1, 224, 224])
727
x torch.Size([1, 1, 224, 224])
728
x torch.Size([1, 1, 224, 224])
729
x torch.Size([1, 1, 224, 224])
730
x torch.Size([1, 1, 224, 224])
731
x torch.Size([1, 1, 224, 224])
732
x torch.Size([1, 1, 224, 224])
733
x torch.Size([1, 1, 224, 224])
734
x torch.Size([1, 1, 224, 224])
735
x torch.Size([1,

942
x torch.Size([1, 1, 224, 224])
943
x torch.Size([1, 1, 224, 224])
944
x torch.Size([1, 1, 224, 224])
945
x torch.Size([1, 1, 224, 224])
946
x torch.Size([1, 1, 224, 224])
947
x torch.Size([1, 1, 224, 224])
948
x torch.Size([1, 1, 224, 224])
949
x torch.Size([1, 1, 224, 224])
950
x torch.Size([1, 1, 224, 224])
951
x torch.Size([1, 1, 224, 224])
952
x torch.Size([1, 1, 224, 224])
953
x torch.Size([1, 1, 224, 224])
954
x torch.Size([1, 1, 224, 224])
955
x torch.Size([1, 1, 224, 224])
956
x torch.Size([1, 1, 224, 224])
957
x torch.Size([1, 1, 224, 224])
958
x torch.Size([1, 1, 224, 224])
959
x torch.Size([1, 1, 224, 224])
960
x torch.Size([1, 1, 224, 224])
961
x torch.Size([1, 1, 224, 224])
962
x torch.Size([1, 1, 224, 224])
963
x torch.Size([1, 1, 224, 224])
964
x torch.Size([1, 1, 224, 224])
965
x torch.Size([1, 1, 224, 224])
966
x torch.Size([1, 1, 224, 224])
967
x torch.Size([1, 1, 224, 224])
968
x torch.Size([1, 1, 224, 224])
969
x torch.Size([1, 1, 224, 224])
970
x torch.Size([1,

In [27]:
print(outputs.shape)
#print(outputs)
#testing = testing.squeeze(0)
#print(testing.shape)
#testing = testing * 128
#print(testing.shape)
#print(testing)
#im = transforms.ToPILImage()(testing)
#print(im.mode)
#im.save('sequential-test.jpg')

torch.Size([1, 1, 224, 224])
torch.Size([224, 224])


In [28]:
im = transforms.ToPILImage()(testing)
im.save('newNN-whodisP2.jpg')