<a href="https://colab.research.google.com/github/0xtaha/image-super-resloution-for-remote-sensing/blob/main/GAN_Models%5CEEGAN_Remote_Sensing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Train the VGG-19 model


### A. Download the ImageNet URL lists needed to train VGG-19

In [None]:
import os
import shutil
from tqdm import tqdm
import requests
import tarfile
import re
import sqlite3

def download_urls():
    """Download the ImageNet URL-list archives and extract each to a text file.

    For every (archive URL, local .tgz path, output .txt path) entry below, the
    archive is streamed to disk with a progress bar, then every regular file in
    it is extracted into the current directory and moved to the output path.

    Expects ``data/urls/`` to exist already (the archives are written there).
    """
    lists = [
        ['http://image-net.org/imagenet_data/urls/imagenet_fall11_urls.tgz',
         'data/urls/imagenet_fall11_urls.tgz',
         'data/urls/fall11.txt'],
        ['http://image-net.org/imagenet_data/urls/imagenet_winter11_urls.tgz',
         'data/urls/imagenet_winter11_urls.tgz',
         # Previously 'winter11.txt' (current directory); keep it next to the
         # other lists so the database step finds it.
         'data/urls/winter11.txt'],
        ['http://image-net.org/imagenet_data/urls/imagenet_spring10_urls.tgz',
         'data/urls/imagenet_spring10_urls.tgz',
         'data/urls/spring10.txt'],
        ['http://image-net.org/imagenet_data/urls/imagenet_fall09_urls.tgz',
         'data/urls/imagenet_fall09_urls.tgz',
         'data/urls/fall09.txt']]

    for url, tar_path, txt_path in lists:
        # HEAD does not follow redirects by default, and content-length may be
        # missing; tqdm accepts total=None and then just counts bytes.
        size = requests.head(url, allow_redirects=True).headers.get("content-length")
        file_size = int(size) if size is not None else None

        with requests.get(url, stream=True) as r:
            r.raise_for_status()  # don't save an HTML error page as a .tgz
            with open(tar_path, 'wb') as f, \
                    tqdm(total=file_size, unit="b", unit_scale=True) as pbar:
                for chunk in r.iter_content(chunk_size=1024):
                    f.write(chunk)
                    pbar.update(len(chunk))

        with tarfile.open(tar_path, 'r') as tar:
            for item in tar:
                # Moving a directory member onto txt_path would break the
                # following file members; only regular files carry the list.
                if not item.isfile():
                    continue
                # NOTE(review): member names are extracted as-is; only use
                # this with trusted archives.
                tar.extract(item, '.')
                shutil.move(item.name, txt_path)

In [None]:
# Create data/ and data/urls/; exist_ok lets this cell be re-run safely.
os.makedirs('data/urls', exist_ok=True)
download_urls()

### B. Create the ImageNet URL database

In [None]:
import os
import re
import sqlite3
from tqdm import tqdm
import sys
sys.path.append('./labels')
import objects

def create_db():
    """Build ``data/imagenet.db`` with the image URLs of the target classes.

    Reads the URL lists under ``data/urls/`` (lines of the form
    ``<object>_<seq>\\t<url>``). It keeps only objects listed in
    ``objects.objects`` and inserts them into the ``urls`` table with
    ``download = 0`` and ``error = 0``. Lines that are not UTF-8 or do not match
    the format are skipped. Duplicate URLs are ignored because ``url`` is UNIQUE.
    The table is created only if missing, so the function can be re-run.
    """
    lists = [
        ['data/urls/fall11.txt', 'fall11'],
        ['data/urls/winter11.txt', 'winter11'],
        ['data/urls/spring10.txt', 'spring10'],
        ['data/urls/fall09.txt', 'fall09'],
    ]
    # A set gives O(1) membership tests over the very long URL lists.
    target = set(objects.objects.keys())

    # The line ending is stripped before matching, so a last line without "\n"
    # is accepted and a "\r" never ends up inside the stored URL.
    pattern = re.compile(r"(.+)_(.+)\t(.+)")
    insert_sql = '''INSERT INTO urls (parent, object, seq, url,
                 download, error) values (?, ?, ?, ?, ?, ?);'''

    con = sqlite3.connect('data/imagenet.db')
    try:
        con.execute('''CREATE TABLE IF NOT EXISTS urls (id integer primary key autoincrement,
            parent varchar(255), object varchar(255), seq varchar(255),
            url varchar(65535) unique, download boolean, error boolean);''')
        for txt_path, parent in lists:
            print(parent)
            with open(txt_path, 'rb') as f:
                for raw in tqdm(f):
                    try:
                        line = raw.decode('utf-8')
                    except UnicodeDecodeError:
                        continue
                    match = pattern.fullmatch(line.rstrip('\r\n'))
                    if match is None:
                        continue  # malformed line; previously crashed
                    object_, seq, url = match.groups()
                    if object_ not in target:
                        continue
                    try:
                        con.execute(insert_sql, (parent, object_, seq, url, 0, 0))
                    except sqlite3.IntegrityError:
                        # Same URL already stored (it can appear in several lists).
                        continue
            con.commit()
    finally:
        con.close()

In [None]:
# Populate data/imagenet.db with the URLs of the target classes.
create_db()

### C. Download the images

In [None]:
import requests
import numpy as np
import os
import shutil
import re
import threading
import time
import sqlite3

def download(*h):
    """Download the images for the given ``urls`` rows and record each outcome.

    Each positional argument is one row of the ``urls`` table, unpacked as
    (id, parent, object, seq, url, download, error). For every row the function:
    1. sleeps 10 s;
    2. fetches the URL;
    3. on a 200 response, saves the body to
       ``data/raw/<object>/<parent>_<seq>.jpg`` and sets ``download = 1``.

    Any failure sets ``error = 1`` instead. Failures are printed, never raised,
    so one bad URL does not stop the worker thread.
    """
    def add_error(id_):
        sql = '''UPDATE urls SET error = 1 WHERE id = ?;'''
        update_status(sql, id_)

    def add_download(id_):
        sql = '''UPDATE urls SET download = 1 WHERE id = ?;'''
        update_status(sql, id_)

    def update_status(sql, id_):
        # Own connection per update: download() is run from several threads.
        con = sqlite3.connect('data/imagenet.db')
        con.execute(sql, (str(id_),))
        con.commit()
        con.close()

    for h_ in h:
        id_, parent, object_, seq, url, _, _ = h_
        time.sleep(10)
        try:
            r = requests.get(url, stream=True, timeout=10)
        except Exception as exc:
            # Best effort: arbitrary web URLs can fail in many different ways.
            print('ERROR', exc)
            add_error(id_)
            continue
        with r:  # release the connection on every path below
            if r.status_code != 200:
                print('ERROR')
                add_error(id_)
                continue
            dir_ = os.path.join('data', 'raw', object_)
            # exist_ok: two threads can reach the same class folder at once,
            # and the old exists()/mkdir() pair could raise FileExistsError.
            os.makedirs(dir_, exist_ok=True)
            path = os.path.join(dir_, '{}_{}.jpg'.format(parent, seq))
            try:
                with open(path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        f.write(chunk)
            except Exception as exc:
                print('ERROR', exc)
                # Remove the truncated file so preprocessing never reads it.
                try:
                    os.remove(path)
                except OSError:
                    pass
                add_error(id_)
                continue
        add_download(id_)
        print(url)


def get_lists():
    """Return every ``urls`` row that is neither downloaded nor marked as failed.

    Rows are tuples in table-column order:
    (id, parent, object, seq, url, download, error).
    """
    query = '''SELECT * from urls WHERE download = 0 and error = 0;'''
    connection = sqlite3.connect('data/imagenet.db')
    cursor = connection.cursor()
    pending = cursor.execute(query).fetchall()
    cursor.close()
    connection.close()
    return pending
    

def missing_teddy():
    """Create an empty ``data/raw/n04399382`` class folder.

    "n04399382: teddy, teddy bear" cannot be downloaded; there is no
    n04399382 image. The empty folder still takes its position in the sorted
    class-folder list that preprocess() enumerates. Presumably this keeps the
    class indices aligned with the full class list (confirm).
    """
    # The body was indented with a tab under a space-indented `if`, which is a
    # TabError in Python 3. makedirs(exist_ok=True) also replaces the
    # exists()/mkdir() check and creates data/raw if needed.
    os.makedirs('data/raw/n04399382', exist_ok=True)

In [None]:
# Create the download root; exist_ok lets this cell be re-run.
os.makedirs('data/raw', exist_ok=True)

n_threads = 3
lists = get_lists()
x = int(np.ceil(len(lists) / n_threads))  # rows per thread

threads = []
for i in range(n_threads):
    h = lists[i*x:(i+1)*x]
    # Each row becomes one positional argument of download(*h).
    th = threading.Thread(target=download, args=h)
    th.start()
    threads.append(th)

# Wait for every download before touching data/raw again. Previously the cell
# returned at once and the threads kept writing while later steps ran.
for th in threads:
    th.join()

missing_teddy()

### D. Preprocess the ImageNet images

In [None]:
import os
import glob
import cv2
from PIL import Image
import io
import numpy as np

def _read_image(path):
    """Return the image at *path* as a BGR uint8 array (H, W, 3), or None."""
    try:
        with Image.open(path) as im:
            # convert('RGB') normalises grayscale, palette, RGBA and CMYK files.
            # Before, CMYK (4-channel) images went through RGB2BGR with wrong
            # colours, and grayscale/palette images were silently dropped.
            rgb = np.array(im.convert('RGB'), dtype=np.uint8)
    except Exception:
        # Best effort: scraped files may be HTML error pages, truncated, etc.
        return None
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def preprocess():
    """Resize every downloaded image to 96x96 and save per-class train/test splits.

    Class folders under ``data/raw`` are processed in sorted order. The i-th
    folder's images are shuffled, split 95 % / 5 %, and saved as
    ``data/npy/train/{i:04d}.npy`` and ``data/npy/test/{i:04d}.npy``. The
    arrays are uint8 BGR of shape (n, 96, 96, 3), or shape (0,) for a folder
    with no readable image. Unreadable files are skipped.
    """
    pp = sorted(glob.glob('data/raw/*'))
    for i, p in enumerate(pp):
        print(i, p)
        x = []
        for path in glob.glob(os.path.join(p, '*')):
            img = _read_image(path)
            if img is None:
                continue
            x.append(cv2.resize(img, (96, 96)))
        x = np.array(x, dtype=np.uint8)
        np.random.shuffle(x)
        r = int(len(x) * 0.95)
        x_train = x[:r]
        x_test = x[r:]
        print(x_train.shape, x_test.shape)
        id_ = "{0:04d}".format(i)
        np.save('data/npy/train/{}.npy'.format(id_), x_train)
        np.save('data/npy/test/{}.npy'.format(id_), x_test)

In [None]:
# makedirs also creates data/npy; exist_ok lets this cell be re-run.
os.makedirs('data/npy/train', exist_ok=True)
os.makedirs('data/npy/test', exist_ok=True)
preprocess()

### E. Train VGG-19

In [None]:
import numpy as np
import cv2
import os
import glob
from tqdm import tqdm

def _load(src, max_classes=100):
    """Load the per-class ``.npy`` files matching *src* into one array.

    Args:
        src: glob pattern of files named ``<class id>.npy`` (e.g. ``0007.npy``).
        max_classes: only the first *max_classes* files in sorted order are
            used (default 100, as before).

    Returns:
        ``[x, t]``. ``x`` holds all non-empty arrays concatenated along axis 0,
        and ``t`` holds the matching integer labels (each file's class id).
        If nothing is loaded, ``x`` is None and ``t`` is an empty array.
    """
    paths = sorted(glob.glob(src))[:max_classes]
    xs = []
    ts = []
    for path in tqdm(paths):
        id_ = int(os.path.basename(path).split('.')[0])
        c = np.load(path)
        if c.size == 0:
            continue
        xs.append(c)
        ts.append(np.full(c.shape[0], id_, dtype=int))
    if not xs:
        return [None, np.array([])]
    # Concatenate once; growing x inside the loop re-copied all previous
    # data for every file (quadratic time and memory traffic).
    return [np.concatenate(xs, 0), np.concatenate(ts)]

def load(root='data/npy'):
    """Return ``x_train, t_train, x_test, t_test`` from the preprocessed files.

    Args:
        root: directory containing the ``train/`` and ``test/`` folders of
            per-class ``.npy`` files. The default ``data/npy`` is where
            preprocess() saves them. The old hard-coded ``./imagenet/data/npy``
            did not match it; pass that path explicitly for the old layout.
    """
    x_train, t_train = _load(os.path.join(root, 'train', '*'))
    x_test, t_test = _load(os.path.join(root, 'test', '*'))
    return x_train, t_train, x_test, t_test


In [None]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import argparse
import sys
sys.path.append('../utils')
from vgg19 import VGG19
import load
import augment

# Adam step size and mini-batch size read by train().
learning_rate, batch_size = 0.001, 128

def train(max_epochs=None):
    """Train VGG19 on the preprocessed images.

    Resumes from ``backup/latest`` when a checkpoint exists. After every epoch
    it saves a checkpoint and prints the summed training loss and the test
    accuracy.

    Args:
        max_epochs: stop once this epoch number has been completed. ``None``
            (the default, as before) trains until interrupted.
    """
    import os  # this cell does not import os itself

    x = tf.placeholder(tf.float32, [None, 96, 96, 3])
    t = tf.placeholder(tf.int32, [None])
    is_training = tf.placeholder(tf.bool, [])

    model = VGG19(x, t, is_training)
    sess = tf.Session()
    with tf.variable_scope('vgg19'):
        global_step = tf.Variable(0, name='global_step', trainable=False)
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = opt.minimize(model.loss, global_step=global_step)
    init = tf.global_variables_initializer()
    sess.run(init)

    # One Saver for the whole run: building one per epoch added new
    # save/restore ops to the graph every time.
    saver = tf.train.Saver()

    # Restore the latest model
    if tf.train.get_checkpoint_state('backup/'):
        saver.restore(sess, 'backup/latest')
    os.makedirs('backup', exist_ok=True)  # Saver.save needs the directory

    # Load the dataset. `import load` in this cell rebinds `load` to a module,
    # hiding the load() function defined in the previous cell. Use the
    # module's load() if it has one (presumably ../utils/load.py does).
    load_fn = getattr(load, 'load', load)
    x_train, t_train, x_test, t_test = load_fn()

    steps_per_epoch = np.ceil(len(x_train) / batch_size)

    # Train
    while True:
        epoch = int(sess.run(global_step) / steps_per_epoch) + 1
        if max_epochs is not None and epoch > max_epochs:
            break
        print('epoch:', epoch)
        perm = np.random.permutation(len(x_train))
        x_train = x_train[perm]
        t_train = t_train[perm]
        sum_loss_value = 0
        for i in tqdm(range(0, len(x_train), batch_size)):
            x_batch = augment.augment(x_train[i:i+batch_size])
            t_batch = t_train[i:i+batch_size]
            _, loss_value = sess.run(
                [train_op, model.loss],
                feed_dict={x: x_batch, t: t_batch, is_training: True})
            sum_loss_value += loss_value
        print('loss:', sum_loss_value)

        saver.save(sess, 'backup/latest', write_meta_graph=False)

        predictions = []
        answers = []
        for i in range(0, len(x_test), batch_size):
            # NOTE(review): augment() is applied to test batches too; confirm
            # it is deterministic / intended for evaluation.
            x_batch = augment.augment(x_test[i:i+batch_size])
            t_batch = t_test[i:i+batch_size]
            output = model.out.eval(
                feed_dict={x: x_batch, is_training: False}, session=sess)
            predictions.append(np.argmax(output, 1))
            answers.append(t_batch)
        # Compare once after the loop. Before, the comparison was recomputed
        # over the growing arrays every batch and was undefined (NameError)
        # when the test set was empty.
        if predictions:
            accuracy = np.mean(
                np.concatenate(predictions) == np.concatenate(answers))
        else:
            accuracy = float('nan')
        print('accuracy:', accuracy)

    sess.close()

In [None]:
# Runs until interrupted (train() loops forever when called without arguments),
# saving a checkpoint to backup/latest after every epoch.
train()

## 2. Download the pretrained VGG-19 model

In [None]:
from urllib.parse import parse_qs, urlparse

# find the share link of the file/folder on Google Drive
file_share_link = "https://drive.google.com/open?id=0B-s6ok7B0V9vcXNfSzdjZ0lCc0k"

# extract the ID of the file from the link's "id" query parameter
# (slicing after the first "=" silently returned the whole link when it had none)
_share_query = parse_qs(urlparse(file_share_link).query)
if 'id' not in _share_query:
    raise ValueError('no "id" parameter in share link: ' + file_share_link)
file_id = _share_query['id'][0]

# append the id to this REST command
file_download_link = "https://docs.google.com/uc?export=download&id=" + file_id