In [1]:
import torch
import torch.nn as nn
import math
import numpy as np
from PIL import Image
import cv2
import os
import sys
import torchvision
import torch.optim
import torch.backends.cudnn as cudnn
import time
from tqdm import tqdm

In [2]:
#network architecture
class network(nn.Module):
    """Ten-layer fully convolutional image-to-image network.

    Every convolution has stride 1, 3 output channels and "same" padding
    (kernel_size // 2), so all feature maps keep the spatial size of the input.
    The input must have 3 channels (conv1 has in_channels=3).

    Wiring:
      * conv1 sees the input, conv2 sees conv1's activation;
      * conv3 .. conv9 each see the concatenation of the two previous activations;
      * conv10 sees the concatenation of all nine activations (9 * 3 = 27 channels).

    With K the activation of conv10 and x the input, the result is
    relu(K * x - K + 1).
    """

    def __init__(self):
        super(network, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        # (in_channels, kernel_size) for conv1 .. conv10. The layers are registered
        # under the names conv1 .. conv10, in this order, so the state_dict keys
        # of previously saved weight files still match.
        layer_specs = (
            (3, 1), (3, 3), (6, 3), (6, 3), (6, 5),
            (6, 7), (6, 9), (6, 3), (6, 3), (27, 3),
        )
        for index, (in_ch, ksize) in enumerate(layer_specs, start=1):
            conv = nn.Conv2d(in_channels=in_ch, out_channels=3, kernel_size=ksize,
                             stride=1, padding=ksize // 2, bias=True)
            setattr(self, 'conv%d' % index, conv)

    def forward(self, x):
        """Map a (batch, 3, H, W) tensor to a non-negative tensor of the same shape."""
        first = self.relu(self.conv1(x))
        maps = [first, self.relu(self.conv2(first))]

        # conv3 .. conv9: each layer consumes the two most recent activations.
        chained = (self.conv3, self.conv4, self.conv5, self.conv6,
                   self.conv7, self.conv8, self.conv9)
        for layer in chained:
            pair = torch.cat(maps[-2:], 1)
            maps.append(self.relu(layer(pair)))

        # conv10 fuses every intermediate activation (x1 .. x9).
        k = self.relu(self.conv10(torch.cat(maps, 1)))

        return self.relu((k * x) - k + 1)

In [3]:
# Sanity check: build a throw-away instance of the network and print its layer summary.
print(network())

network(
  (relu): ReLU(inplace=True)
  (conv1): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(6, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv6): Conv2d(6, 3, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (conv7): Conv2d(6, 3, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv8): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv9): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv10): Conv2d(27, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [4]:
#initialize the network
def initialization(m):
    """Weight initialiser meant to be passed to ``nn.Module.apply``.

    Modules are selected by class name, as before:
      * name contains 'Conv'      -> weight ~ N(0.0, 0.02); bias left untouched
      * name contains 'BatchNorm' -> weight ~ N(1.0, 0.02); bias set to 0
      * anything else             -> left untouched

    Args:
        m: the module currently visited by ``apply``.

    Modules that match by name but have no ``weight``/``bias`` tensor (for
    example a container class called ``ConvBlock`` or a batch-norm layer built
    with ``affine=False``) are skipped instead of raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            # nn.init runs under no_grad, so no need to go through `.data`.
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

# Check whether GPU acceleration is available (training without it can be extremely slow).
print(torch.cuda.is_available())
# Build the network instance used by the inference cells further down. As written it
# stays on the CPU; the commented-out line below would create it on the GPU instead.
net=network()
#net = network().cuda()
# Run the custom initialiser on every sub-module. apply() returns the module itself,
# which is why this cell displays the network summary.
net.apply(initialization)
# Uncomment to inspect the initialised weights of the first layer:
#net.conv1.weight

True


network(
  (relu): ReLU(inplace=True)
  (conv1): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(6, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv6): Conv2d(6, 3, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (conv7): Conv2d(6, 3, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv8): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv9): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv10): Conv2d(27, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [5]:
#Please change the root path below to wherever your training dataset is stored.
#A train.txt file is required under <root>\ImageSets\Segmentation\ (for example: C:\\Users\\Gengqian Yang\\dataset\\VOC2012\\ImageSets\\Segmentation\\).
#The train.txt file should list the names of all training examples.
#The training dataset can be found on GitHub; if you have already downloaded it, you only need to change the root path.
def read_voc_images(root="C:\\Users\\Gengqian Yang\\dataset\\VOC2012", 
                    is_train=True, max_num=None):
    """Load the input images and their ground-truth images into memory.

    The image list is read from ``<root>/ImageSets/Segmentation/train.txt``
    (``val.txt`` when ``is_train`` is false). For every selected entry the file
    name is the part after the last backslash, and the ground-truth name is
    the part of that file name before the first underscore.

    Args:
        root: dataset root containing ``ImageSets``, ``JPEGImages`` and
            ``SegmentationClass``.
        is_train: choose ``train.txt`` (True) or ``val.txt`` (False).
        max_num: optional cap on the number of examples to load.

    Returns:
        ``(features, labels)``: two equally long lists of RGB PIL images.
        Entries of ``labels`` that refer to the same ground-truth file share
        one image object, so treat them as read-only.

    Raises:
        FileNotFoundError: if the list file or any referenced image is missing.
    """
    list_name = 'train.txt' if is_train else 'val.txt'
    # os.path.join instead of hard-coded '\\' so the loader also works off Windows.
    txt_fname = os.path.join(root, 'ImageSets', 'Segmentation', list_name)
    with open(txt_fname, 'r') as f:
        images_fullnames = f.read().split()

    # NOTE(review): only every second whitespace-separated token is kept. This
    # presumably relies on each listed path containing exactly one space (as the
    # default root does) -- confirm against the real train.txt; a list whose paths
    # contain no space would silently lose half of its entries.
    images = [token.split('\\')[-1] for token in images_fullnames[1::2]]
    labels_names = [image.split("_")[0] for image in images]

    if max_num is not None:
        images = images[:max_num]
        labels_names = labels_names[:max_num]

    features = []
    for fname in tqdm(images):
        # convert() returns an independent in-memory copy, so the file can be closed.
        with Image.open(os.path.join(root, 'JPEGImages', fname)) as img_file:
            features.append(img_file.convert("RGB"))

    # Several inputs can map to the same ground truth; decode each file once
    # and share it to keep RAM usage down.
    label_cache = {}
    labels = []
    for name in tqdm(labels_names):
        if name not in label_cache:
            label_path = os.path.join(root, 'SegmentationClass', '%s.png' % name)
            with Image.open(label_path) as img_file:
                label_cache[name] = img_file.convert("RGB")
        labels.append(label_cache[name])
    return features, labels # PIL images list

class DataSet(torch.utils.data.Dataset):
    """In-memory dataset yielding (input, ground truth) tensor pairs.

    All images are loaded up front by ``read_voc_images``. Items are converted
    with ``self.transform`` (a plain ``ToTensor``). The normalising pipeline
    ``self.tsf`` and the ``rgb_mean`` / ``rgb_std`` statistics are created as
    attributes but are not applied by ``__getitem__``.
    """

    def __init__(self, is_train, voc_dir, max_num=None):
        """Build the transforms and load up to ``max_num`` examples from ``voc_dir``."""
        to_tensor = torchvision.transforms.ToTensor
        normalize = torchvision.transforms.Normalize

        # Per-channel statistics used only by the (currently unused) `tsf` pipeline.
        self.rgb_mean = np.array([0.485, 0.456, 0.406])
        self.rgb_std = np.array([0.229, 0.224, 0.225])

        self.transform = to_tensor()
        self.tsf = torchvision.transforms.Compose(
            [to_tensor(), normalize(mean=self.rgb_mean, std=self.rgb_std)]
        )

        self.features, self.labels = read_voc_images(
            root=voc_dir, is_train=is_train, max_num=max_num
        )
        print('read %d valid examples' % len(self.features))

    def __getitem__(self, idx):
        """Return the ``idx``-th (input, ground truth) pair, each passed through ``self.transform``."""
        pair = (self.features[idx], self.labels[idx])
        return tuple(self.transform(picture) for picture in pair)

    def __len__(self):
        """Number of loaded examples."""
        return len(self.features)
    

In [6]:
# Dataset root: change this to wherever your training dataset is stored.
# NOTE(review): this is an absolute, machine-specific path; it must be edited before
# the notebook can run on another machine.
voc_dir = "C:\\Users\\Gengqian Yang\\dataset\\VOC2012"
# Number of training examples to load. All images are held in RAM
# (the author reports about 16 GB for 5000 images).
max_num = 5000
# Alternative: load every example listed in train.txt.
#max_num =None
# Intended number of DataLoader worker processes (0 = load in the main process).
num_workers = 0
# Use the GPU when available, otherwise fall back to the CPU (training is much slower there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#training function
def train_model(model, criterion, optimizer, dataload, num_epochs,
                save_path="new modified network weights attempt.pth",
                max_grad_norm=0.1):
    """Train ``model`` and checkpoint it whenever the mean epoch loss improves.

    Args:
        model: network to optimise; batches are moved to the device its
            parameters live on.
        criterion: loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: optimiser already bound to ``model``'s parameters.
        dataload: ``DataLoader`` yielding ``(inputs, labels)`` batches.
        num_epochs: number of passes over ``dataload``.
        save_path: file that receives ``model.state_dict()`` each time the
            mean epoch loss is at least as good as the best seen so far.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        The trained ``model`` (same object that was passed in).

    Raises:
        ValueError: if ``dataload`` yields no batches (for example
            ``drop_last=True`` with fewer samples than one batch).
    """
    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # Start from +inf so the first finished epoch is always checkpointed.
    min_loss = float("inf")
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # len(dataload) already accounts for drop_last.
        num_batches = len(dataload)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(run_device)
            labels = y.to(run_device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # Clip the gradients of the model being trained. The previous code
            # clipped the unrelated global `net`, which left these gradients unclipped.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, num_batches, loss.item()))
        if step == 0:
            raise ValueError("dataload produced no batches; check the dataset size, "
                             "batch_size and drop_last")
        mean_loss = epoch_loss / step
        print("epoch %d loss:%0.3f" % (epoch, mean_loss))
        if mean_loss <= min_loss:
            min_loss = mean_loss
            torch.save(model.state_dict(), save_path)
            print("saved")
    return model
    

def train(batch_size,num_epochs):
    """Build the network, data pipeline and optimiser, then run training.

    If the weights file "new modified network weights.pth" exists, training
    resumes from it; otherwise the network starts from the custom
    ``initialization`` scheme (previously a first-time run required editing
    the code by hand).

    Uses the notebook-level settings ``device``, ``voc_dir``, ``max_num`` and
    ``num_workers``.

    Args:
        batch_size: number of image pairs per mini-batch.
        num_epochs: number of passes over the training set.

    Returns:
        The trained model.
    """
    # NOTE(review): train_model checkpoints to "new modified network weights attempt.pth"
    # by default, which is a different file from the one resumed from here -- confirm
    # that this is intended.
    weights_path = "new modified network weights.pth"
    model = network().to(device)
    if os.path.isfile(weights_path):
        # carry on training from the weights saved by a previous run
        model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
        print("resuming from " + weights_path)
    else:
        model.apply(initialization)
        print("no saved weights found, starting from a fresh initialisation")
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
    voc_train = DataSet(True, voc_dir, max_num)
    # drop_last=True discards a trailing batch smaller than batch_size.
    train_iter = torch.utils.data.DataLoader(voc_train, batch_size=batch_size, shuffle=True,
                              drop_last=True, num_workers=num_workers)
    print(device)
    return train_model(model, criterion, optimizer, train_iter, num_epochs)

In [None]:
# Start training: 100 epochs with mini-batches of 8 image pairs.
train(batch_size=8, num_epochs=100)

In [5]:
# Load a test image. PNG files may be RGBA, palette or greyscale, but the network's
# first layer expects exactly 3 channels, so force RGB (as the training loader does).
img=Image.open('C:\\Users\\Gengqian Yang\\Desktop\\runtime test\\1.png').convert("RGB")
img.size

(620, 460)

In [6]:
# Transform that converts a PIL image into a tensor (applied to the test image in the next cell).
transform=torchvision.transforms.ToTensor()

In [7]:
# Disabled alternative that would use a `tsf` transform instead of `transform`:
#input=tsf(img).unsqueeze(0)
# Add a leading batch dimension: the network expects input of shape (batch, 3, h, w).
# NOTE(review): the name `input` shadows Python's built-in input(); it is kept because
# the inference cell below refers to it.
input=transform(img).unsqueeze(0)

In [8]:
# Switch the network to evaluation mode. This only changes how mode-dependent layers
# (e.g. dropout, batch norm) behave; it does not freeze the weights.
net.eval()
# Load the trained weights from disk, mapping the stored tensors onto the CPU.
net.load_state_dict(torch.load("new modified network weights.pth",map_location=torch.device('cpu')))

<All keys matched successfully>

In [11]:
# Inference only: no_grad() skips building the autograd graph, which saves memory.
with torch.no_grad():
    output=net(input)
unloader = torchvision.transforms.ToPILImage()
# Remove the batch dimension. The network's final ReLU only bounds the output from
# below, and ToPILImage scales by 255 and casts to 8-bit, so values above 1 would
# overflow; clamp to [0, 1] first. (clamp returns a new tensor; `output` is unchanged.)
image = output.squeeze(0).clamp(0, 1)
image = unloader(image)
#check the image by PIL image library
image.show()

In [10]:
# Save the PIL image produced by the inference cell; the output format (JPEG) is
# chosen from the file extension.
image.save('C:\\Users\\Gengqian Yang\\CNN results\\home 3 network dehazed.jpg')