In [30]:
import os
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import matplotlib.pyplot as plt
%matplotlib inline

if not os.path.exists('./AE_denoise'):
    os.mkdir('./AE_denoise')


def to_img(x):
    x = x.view(x.size(0), 1, 28, 28)
    return x

num_epochs = 50
batch_size = 128
learning_rate = 1e-3


def plot_sample_img(img, name):
    img = img.view(1, 28, 28)
    save_image(img, './sample_{}.png'.format(name))


def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


def tensor_round(tensor):
    return torch.round(tensor)

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
    transforms.Lambda(lambda tensor:tensor_round(tensor))
])

dataset = MNIST('./data', transform=img_transform, download=True)
#test_dataset=MNIST('./data', transform=img_transform, train = False, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 196),
            nn.ReLU(True),
            nn.Linear(196, 16),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(16, 196),
            nn.ReLU(True),
            nn.Linear(196, 28 * 28),
            nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = autoencoder().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
#         print(img)
#         print(img.size())
#         print('--------------------------')
        img = img.view(img.size(0), -1)
        noise = (torch.randint(-10,2,(img.size(0),img.size(1)))>0).float()
        noised_img=((img+noise)>0).float()
#         print(img)
#         print(img.size())
#         print('--------------------------')
        noised_img = Variable(noised_img).cuda()
        
        # ===================forward=====================
        output = model(noised_img)
#         MIDIMG=model.encoder(img).view(img.size(0),56,56).cpu().detach().numpy()
#         MIDIMG=MIDIMG[0]
#         MIDIMG=np.rint(MIDIMG/np.sum(MIDIMG)*255)
#         print(MIDIMG)
#         plt.imshow(MIDIMG)
#         plt.show()
        
        loss = criterion(output, noised_img)
        MSE_loss = nn.MSELoss()(output, noised_img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    if (epoch+1) % 10 == 0 or epoch == 0:
        x_hat = to_img(output.cpu().data)
        #MIDIMG= model.encoder(img).view(img.size(0),1,56,56).cpu().data
        save_image(x_hat, './AE_denoise/c_output_{}.png'.format(epoch+1))
        #save_image(MIDIMG, './mlp_pansharpened/x_midlayer{}.png'.format(epoch))
        save_image(img.view(img.size(0),1,28,28).cpu().data, './AE_denoise/a_raw{}.png'.format(epoch+1))
        save_image(noised_img.view(img.size(0),1,28,28).cpu().data, './AE_denoise/b_noised{}.png'.format(epoch+1))
        
        
        
for epoch_2 in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img).cuda()

        output = model(img)

        loss = criterion(output, img)
        MSE_loss = nn.MSELoss()(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch_2 [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
          .format(epoch_2 + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    if (epoch_2+1) % 10 == 0 or epoch_2 == 0:
        x_hat = to_img(output.cpu().data)
        save_image(x_hat, './AE_denoise/d_output_{}.png'.format(epoch_2+1))









epoch [1/50], loss:0.3834, MSE_loss:0.1144
epoch [2/50], loss:0.3609, MSE_loss:0.1059


KeyboardInterrupt: 

In [39]:
from skimage import io
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os


if not os.path.exists('./clouds_remove'):
    os.mkdir('./clouds_remove')

def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


img_transform = transforms.Compose([
    transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
])
        

class LandsatDataset():

    def __init__(self, root_dir, transform=None):

        self.root_dir = root_dir
        self.transform = transform
#         self.batch_size=batch_size
    def __len__(self):
        return len(os.listdir(self.root_dir))

    def __getitem__(self, idx):
#         samples=[]
#         for i in range(self.batch_size):
        img_name = os.listdir(self.root_dir)[idx]#[idx*self.batch_size+i]
        image = io.imread(self.root_dir+img_name)[4000:4700,3000:3500]
        image_tensor=torch.from_numpy(image)
        image_tensor=image_tensor.cuda()
        #samples.append(image_tensor)
        #tensor=torch.cat((*samples),(0))
#         if self.transform:
#             sample = self.transform(sample)
        return image_tensor

num_epochs=100
learning_rate = 1e-3
batch_size=55

dataset = LandsatDataset('./landsat8_142_49_truecolor/',transform=img_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(55, 33),
            nn.ReLU(True),
            nn.Linear(33, 11),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(11, 33),
            nn.ReLU(True),
            nn.Linear(33, 55),
            nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x



model = autoencoder().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

for data in dataloader:
    img_B = data[0:,0:,0:,0]
    img_G = data[0:,0:,0:,0]
    img_R = data[0:,0:,0:,0]
    ini = data.permute(0,3,1,2)
    save_image(ini, './clouds_remove/x_inputlayer.png')
    img_B=img_B.view(img_B.size(0),-1).float()
    img_B = torch.transpose(img_B, 0, 1)
    img_G=img_G.view(img_G.size(0),-1).float()
    img_G = torch.transpose(img_G, 0, 1)
    img_R=img_R.view(img_R.size(0),-1).float()
    img_R = torch.transpose(img_R, 0, 1)
    for epoch in range(num_epochs):
        img_B = Variable(img_B).cuda()
        output_B = model(img_B)
        loss = criterion(output_B.long(), img_B.long())
        MSE_loss = nn.MSELoss()(output_B, img_B)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%100==0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    
    for epoch in range(num_epochs):
        img_G = Variable(img_G).cuda()
        output_G = model(img_G)
        loss = criterion(output_G.long(), img_G.long())
        MSE_loss = nn.MSELoss()(output_G, img_G)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%100==0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    
    for epoch in range(num_epochs):
        img_R = Variable(img_R).cuda()
        output_R = model(img_R)
        loss = criterion(output_R.long(), img_R.long())
        MSE_loss = nn.MSELoss()(output_R, img_R)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%100==0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    
    
#     img_fb= torch.transpose(img_B,0,1)
#     img_fg= torch.transpose(img_G,0,1)
#     img_fr= torch.transpose(img_R,0,1)
    output_fb=torch.transpose(output_B,0,1).view(55,1,700,500)
    output_fg=torch.transpose(output_G,0,1).view(55,1,700,500)
    output_fr=torch.transpose(output_R,0,1).view(55,1,700,500)
    output=torch.cat((output_fb,output_fg,output_fr),1).cpu().data
    # MIDIMG= model.encoder(img).view(img.size(0),1,56,56).cpu().data
    # save_image(x, './mlp_pansharpened/x_{}.png'.format(epoch))
    save_image(output, './clouds_remove/x_output.png')
    # save_image(MIDIMG, './mlp_pansharpened/x_midlayer{}.png'.format(epoch))

    


RuntimeError: "host_softmax" not implemented for 'torch.cuda.LongTensor'

In [11]:
from PIL import Image
import numpy
pic = Image.open("test.jpg")
pix = numpy.array(pic)
print(pix.shape)

(7751, 7591, 3)


In [18]:
pix[2000][3000]

array([56, 72, 62], dtype=uint8)

In [19]:
np.swapaxes(pix,2,0).shape
pix[:][2000][3000]

array([56, 72, 62], dtype=uint8)

In [27]:
x = torch.randn(55, 1,200,300)
y = torch.randn(55, 1,200,300)
z = torch.randn(55, 1,200,300)
dd=torch.cat((x,y,z),1)
print(dd.size())

torch.Size([55, 3, 200, 300])


In [9]:
import numpy as np
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os
import numpy
from PIL import Image
import torch

if not os.path.exists('./clouds_remove'):
    os.mkdir('./clouds_remove')

def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


img_transform = transforms.Compose([
    transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
])
        

class LandsatDataset():

    def __init__(self, root_dir, transform=None):

        self.root_dir = root_dir
        self.transform = transform
#         self.batch_size=batch_size
    def __len__(self):
        return len(os.listdir(self.root_dir))

    def __getitem__(self, idx):
#         samples=[]
#         for i in range(self.batch_size):
        img_name = os.listdir(self.root_dir)[idx]#[idx*self.batch_size+i]
        img = Image.open(self.root_dir+img_name)
        image = numpy.array(img)[4000:4700,3000:3500]
        image_tensor=torch.from_numpy(image)
        image_tensor=image_tensor.cuda()
        #samples.append(image_tensor)
        #tensor=torch.cat((*samples),(0))
#         if self.transform:
#             sample = self.transform(sample)
        return image_tensor

num_epochs=1500
learning_rate = 1e-3
batch_size=55

dataset = LandsatDataset('./landsat8_142_49_truecolor/',transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(55, 33),
            nn.ReLU(True),
            nn.Linear(33, 11),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(11, 33),
            nn.ReLU(True),
            nn.Linear(33, 55),
            nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x



model = autoencoder().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

for data in dataloader:
    ini = data.float().permute(0,3,1,2)
    save_image(ini, './clouds_remove/x_inputlayer.png')
    img_B = min_max_normalization(data[0:,0:,0:,0].float(),0.0,1.0)
    img_G = min_max_normalization(data[0:,0:,0:,1].float(),0.0,1.0)
    img_R = min_max_normalization(data[0:,0:,0:,2].float(),0.0,1.0)
    img_B=img_B.view(img_B.size(0),-1).float()
    img_B = torch.transpose(img_B, 0, 1)
    img_G=img_G.view(img_G.size(0),-1).float()
    img_G = torch.transpose(img_G, 0, 1)
    img_R=img_R.view(img_R.size(0),-1).float()
    img_R = torch.transpose(img_R, 0, 1)
    for epoch in range(num_epochs):
        img_B = Variable(img_B).cuda()
        output_B = model(img_B)
        loss = criterion(output_B, img_B)
        MSE_loss = nn.MSELoss()(output_B, img_B)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%500==0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    
    for epoch in range(num_epochs):
        img_G = Variable(img_G).cuda()
        output_G = model(img_G)
        loss = criterion(output_G, img_G)
        MSE_loss = nn.MSELoss()(output_G, img_G)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%500==0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    
    for epoch in range(num_epochs):
        img_R = Variable(img_R).cuda()
        output_R = model(img_R)
#         loss = criterion(output_R.long(), img_R.long())
        loss = nn.MSELoss()(output_R, img_R)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%500==0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    
    
#     img_fb= torch.transpose(img_B,0,1)
#     img_fg= torch.transpose(img_G,0,1)
#     img_fr= torch.transpose(img_R,0,1)
    output_fb=torch.transpose(output_B,0,1).view(55,1,700,500)
    output_fg=torch.transpose(output_G,0,1).view(55,1,700,500)
    output_fr=torch.transpose(output_R,0,1).view(55,1,700,500)
    output=torch.cat((output_fb,output_fg,output_fr),1).cpu().data
    # MIDIMG= model.encoder(img).view(img.size(0),1,56,56).cpu().data
    # save_image(x, './mlp_pansharpened/x_{}.png'.format(epoch))
    save_image(output, './clouds_remove/x_output.png')
    # save_image(MIDIMG, './mlp_pansharpened/x_midlayer{}.png'.format(epoch))

    




epoch [500/5000], loss:0.6851, MSE_loss:0.0525
epoch [1000/5000], loss:0.6794, MSE_loss:0.0496
epoch [1500/5000], loss:0.6707, MSE_loss:0.0454
epoch [2000/5000], loss:0.6589, MSE_loss:0.0395
epoch [2500/5000], loss:0.6476, MSE_loss:0.0339
epoch [3000/5000], loss:0.6396, MSE_loss:0.0300
epoch [3500/5000], loss:0.6349, MSE_loss:0.0277
epoch [4000/5000], loss:0.6319, MSE_loss:0.0263
epoch [4500/5000], loss:0.6298, MSE_loss:0.0253
epoch [5000/5000], loss:0.6282, MSE_loss:0.0246




epoch [100/5000], loss:0.6196, MSE_loss:0.0265
epoch [200/5000], loss:0.6188, MSE_loss:0.0262
epoch [300/5000], loss:0.6183, MSE_loss:0.0259
epoch [400/5000], loss:0.6178, MSE_loss:0.0257
epoch [500/5000], loss:0.6174, MSE_loss:0.0255
epoch [600/5000], loss:0.6170, MSE_loss:0.0253
epoch [700/5000], loss:0.6166, MSE_loss:0.0251
epoch [800/5000], loss:0.6163, MSE_loss:0.0250
epoch [900/5000], loss:0.6159, MSE_loss:0.0248
epoch [1000/5000], loss:0.6155, MSE_loss:0.0246
epoch [1100/5000], loss:0.6152, MSE_loss:0.0245
epoch [1200/5000], loss:0.6148, MSE_loss:0.0243
epoch [1300/5000], loss:0.6144, MSE_loss:0.0241
epoch [1400/5000], loss:0.6140, MSE_loss:0.0240
epoch [1500/5000], loss:0.6137, MSE_loss:0.0238
epoch [1600/5000], loss:0.6133, MSE_loss:0.0236
epoch [1700/5000], loss:0.6129, MSE_loss:0.0235
epoch [1800/5000], loss:0.6125, MSE_loss:0.0233
epoch [1900/5000], loss:0.6121, MSE_loss:0.0231
epoch [2000/5000], loss:0.6116, MSE_loss:0.0229
epoch [2100/5000], loss:0.6112, MSE_loss:0.0227
e



epoch [100/5000], loss:0.0069, MSE_loss:0.0143
epoch [200/5000], loss:0.0068, MSE_loss:0.0143
epoch [300/5000], loss:0.0068, MSE_loss:0.0143
epoch [400/5000], loss:0.0067, MSE_loss:0.0143
epoch [500/5000], loss:0.0067, MSE_loss:0.0143
epoch [600/5000], loss:0.0067, MSE_loss:0.0143
epoch [700/5000], loss:0.0066, MSE_loss:0.0143
epoch [800/5000], loss:0.0066, MSE_loss:0.0143
epoch [900/5000], loss:0.0066, MSE_loss:0.0143
epoch [1000/5000], loss:0.0066, MSE_loss:0.0143
epoch [1100/5000], loss:0.0066, MSE_loss:0.0143
epoch [1200/5000], loss:0.0065, MSE_loss:0.0143
epoch [1300/5000], loss:0.0065, MSE_loss:0.0143
epoch [1400/5000], loss:0.0065, MSE_loss:0.0143
epoch [1500/5000], loss:0.0065, MSE_loss:0.0143
epoch [1600/5000], loss:0.0065, MSE_loss:0.0143
epoch [1700/5000], loss:0.0065, MSE_loss:0.0143
epoch [1800/5000], loss:0.0065, MSE_loss:0.0143
epoch [1900/5000], loss:0.0065, MSE_loss:0.0143
epoch [2000/5000], loss:0.0065, MSE_loss:0.0143
epoch [2100/5000], loss:0.0065, MSE_loss:0.0143
e

# TEST

In [14]:
import numpy as np
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os
import numpy
from PIL import Image
import torch

if not os.path.exists('./clouds_remove'):
    os.mkdir('./clouds_remove')

def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


img_transform = transforms.Compose([
    transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
])
        

class LandsatDataset():

    def __init__(self, root_dir, transform=None):

        self.root_dir = root_dir
        self.transform = transform
#         self.batch_size=batch_size
    def __len__(self):
        return len(os.listdir(self.root_dir))

    def __getitem__(self, idx):
#         samples=[]
#         for i in range(self.batch_size):
        img_name = os.listdir(self.root_dir)[idx]#[idx*self.batch_size+i]
        img = Image.open(self.root_dir+img_name)
        image = numpy.array(img)[3000:4700,2000:3500]
        image_tensor=torch.from_numpy(image)
        image_tensor=image_tensor.cuda()
        #samples.append(image_tensor)
        #tensor=torch.cat((*samples),(0))
#         if self.transform:
#             sample = self.transform(sample)
        return image_tensor

num_epochs=10000
learning_rate = 1e-2
batch_size=39

dataset = LandsatDataset('./landsat8_142_49_truecolor/',transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(55, 33),
            nn.ReLU(True),
            nn.Linear(33, 11),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(11, 33),
            nn.ReLU(True),
            nn.Linear(33, 55),
            nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x



model = autoencoder().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

for data in dataloader:
    
    img_B = min_max_normalization(data[0:,0:,0:,0].float(),0.0,1.0)
    img_G = min_max_normalization(data[0:,0:,0:,1].float(),0.0,1.0)
    img_R = min_max_normalization(data[0:,0:,0:,2].float(),0.0,1.0)
    img_B=img_B.view(img_B.size(0),-1).float()
    img_B = torch.transpose(img_B, 0, 1)
    img_G=img_G.view(img_G.size(0),-1).float()
    img_G = torch.transpose(img_G, 0, 1)
    img_R=img_R.view(img_R.size(0),-1).float()
    img_R = torch.transpose(img_R, 0, 1)
    for epoch in range(num_epochs):
        img_B = Variable(img_B).cuda()
        output_B = model(img_B)
        output_G = model(img_G)
        output_R = model(img_R)
        loss = criterion(output_B, img_B)
        MSE_loss = nn.MSELoss()(output_B, img_B)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%1000==0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
            output_fb=torch.transpose(output_B,0,1).view(img_B.size(0),1,img_B.size(1),img_B.size(2))
            output_fg=torch.transpose(output_G,0,1).view(img_B.size(0),1,img_B.size(1),img_B.size(2))
            output_fr=torch.transpose(output_R,0,1).view(img_B.size(0),1,img_B.size(1),img_B.size(2))
            output=torch.cat((output_fb,output_fg,output_fr),1).cpu().data
            # MIDIMG= model.encoder(img).view(img.size(0),1,56,56).cpu().data
            # save_image(x, './mlp_pansharpened/x_{}.png'.format(epoch))
            save_image(output, './clouds_remove/x_output{}.png'.format(epoch))
    
#     img_fb= torch.transpose(img_B,0,1)
#     img_fg= torch.transpose(img_G,0,1)
#     img_fr= torch.transpose(img_R,0,1)
        
        
    input_fb=torch.transpose(img_B,0,1).view(img_B.size(0),1,img_B.size(1),img_B.size(2))
    input_fg=torch.transpose(img_G,0,1).view(img_B.size(0),1,img_B.size(1),img_B.size(2))
    input_fr=torch.transpose(img_R,0,1).view(img_B.size(0),1,img_B.size(1),img_B.size(2))
    ini=torch.cat((input_fb,input_fg,input_fr),1).cpu().data
    save_image(ini, './clouds_remove/x_inputlayer.png')
    # save_image(MIDIMG, './mlp_pansharpened/x_midlayer{}.png'.format(epoch))

    


RuntimeError: size mismatch, m1: [2550000 x 39], m2: [55 x 33] at c:\programdata\miniconda3\conda-bld\pytorch_1533090623466\work\aten\src\thc\generic/THCTensorMathBlas.cu:249

# TEST @@ 2


In [12]:
import numpy as np
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os
import numpy
from PIL import Image
import torch

if not os.path.exists('./clouds_remove'):
    os.mkdir('./clouds_remove')

def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


img_transform = transforms.Compose([
    transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
])
        

class LandsatDataset():

    def __init__(self, root_dir, transform=None):

        self.root_dir = root_dir
        self.transform = transform
#         self.batch_size=batch_size
    def __len__(self):
        return len(os.listdir(self.root_dir))

    def __getitem__(self, idx):
#         samples=[]
#         for i in range(self.batch_size):
        img_name = os.listdir(self.root_dir)[idx]#[idx*self.batch_size+i]
        img = Image.open(self.root_dir+img_name)
        image = numpy.array(img)[4000:4700,3000:3500]
        image_tensor=torch.from_numpy(image)
        image_tensor=image_tensor.cuda()
        #samples.append(image_tensor)
        #tensor=torch.cat((*samples),(0))
#         if self.transform:
#             sample = self.transform(sample)
        return image_tensor

num_epochs=5000
learning_rate = 4*1e-4
batch_size=55

dataset = LandsatDataset('./data/landsat8_142_49_truecolor_copy/',transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(55, 33),
            nn.ReLU(True),
            nn.Linear(33, 11),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(11, 33),
            nn.ReLU(True),
            nn.Linear(33, 55),
            nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x



model = autoencoder().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)



for data in dataloader:
    
    img_B = min_max_normalization(data[0:,0:,0:,0].float(),0.0,1.0)
    img_G = min_max_normalization(data[0:,0:,0:,1].float(),0.0,1.0)
    img_R = min_max_normalization(data[0:,0:,0:,2].float(),0.0,1.0)
    img_B=img_B.view(img_B.size(0),-1).float()
    img_B = torch.transpose(img_B, 0, 1)
    img_G=img_G.view(img_G.size(0),-1).float()
    img_G = torch.transpose(img_G, 0, 1)
    img_R=img_R.view(img_R.size(0),-1).float()
    img_R = torch.transpose(img_R, 0, 1)
    for epoch in range(num_epochs):
        img_B = Variable(img_B).cuda()
        output_B = model(img_B)
        output_G = model(img_G)
        output_R = model(img_R)
        loss = criterion(output_B, img_B)
        MSE_loss = nn.MSELoss()(output_B, img_B)
        # ===================backward====================
        optimizer.zero_grad()
        MSE_loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%1000==0 or epoch == 0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.7f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
            
            
    for epoch in range(num_epochs):
        output_B = Variable(output_B).cuda()
        output_B_outagain = model(output_B)
        output_G_outagain = model(output_G)
        output_R_outagain = model(output_R)
        #loss = nn.BCELoss()(output_B_outagain, output_B)
        MSE_loss = nn.MSELoss()(output_B_outagain, output_B)
        # ===================backward====================
        optimizer.zero_grad()
        MSE_loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%1000==0 or epoch == 0:
            print('epoch [{}/{}], MSE_loss:{:.7f}'
              .format(epoch + 1, num_epochs, MSE_loss.data[0]))
            
    

    output_fb=torch.transpose(output_B,0,1).view(data.size(0),1,data.size(1),data.size(2))
    output_fg=torch.transpose(output_G,0,1).view(data.size(0),1,data.size(1),data.size(2))
    output_fr=torch.transpose(output_R,0,1).view(data.size(0),1,data.size(1),data.size(2))
    output=torch.cat((output_fb,output_fg,output_fr),1).cpu().data
    save_image(output, './clouds_remove/x_output.png')   
        
    input_fb=torch.transpose(img_B,0,1).view(data.size(0),1,data.size(1),data.size(2))
    input_fg=torch.transpose(img_G,0,1).view(data.size(0),1,data.size(1),data.size(2))
    input_fr=torch.transpose(img_R,0,1).view(data.size(0),1,data.size(1),data.size(2))
    ini=torch.cat((input_fb,input_fg,input_fr),1).cpu().data
    save_image(ini, './clouds_remove/x_inputlayer.png')
    
    output_fb_outagain=torch.transpose(output_B_outagain,0,1).view(data.size(0),1,data.size(1),data.size(2))
    output_fg_outagain=torch.transpose(output_G_outagain,0,1).view(data.size(0),1,data.size(1),data.size(2))
    output_fr_outagain=torch.transpose(output_R_outagain,0,1).view(data.size(0),1,data.size(1),data.size(2))
    output_outagain=torch.cat((output_fb_outagain,output_fg_outagain,output_fr_outagain),1).cpu().data
    save_image(output_outagain, './clouds_remove/x_output_outagain.png')   


    




epoch [1/5000], loss:0.6937, MSE_loss:0.0567351
epoch [1000/5000], loss:0.5956, MSE_loss:0.0108309
epoch [2000/5000], loss:0.5897, MSE_loss:0.0081981
epoch [3000/5000], loss:0.5864, MSE_loss:0.0067738
epoch [4000/5000], loss:0.5858, MSE_loss:0.0064650
epoch [5000/5000], loss:0.5846, MSE_loss:0.0059638
torch.Size([350000, 55]) torch.Size([350000, 55])
epoch [1/5000], MSE_loss:0.0000476
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])




torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
epoch [1000/5000], MSE_loss:0.0000473
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([

torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
epoch [2000/5000], MSE_loss:0.0000466
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([

torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])
torch.Size([350000, 55]) torch.Size([350000, 55])


In [31]:
MIDIMG_B= min_max_normalization(torch.transpose(model.encoder(img_B),0,1).view(11,1,700,500),0.0,1.0)
MIDIMG_G= min_max_normalization(torch.transpose(model.encoder(img_G),0,1).view(11,1,700,500),0.0,1.0)
MIDIMG_R= min_max_normalization(torch.transpose(model.encoder(img_R),0,1).view(11,1,700,500),0.0,1.0)
MIDIMG=torch.cat((MIDIMG_B,MIDIMG_G,MIDIMG_R),1).cpu().data
save_image(MIDIMG, './clouds_remove/x_midlayer.png')

# TEST 3

In [22]:
import numpy as np
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os
import numpy
from PIL import Image
import torch
import gdal

os.chdir(r'E:\Moore\pyScripts\pytorch\data')
if not os.path.exists('./clouds_remove'):
    os.mkdir('./clouds_remove')

def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


img_transform = transforms.Compose([
    transforms.Lambda(lambda tensor:min_max_normalization(tensor)),
])
        

class LandsatDataset():

    def __init__(self, root_dir, transform=None):

        self.root_dir = root_dir
        self.transform = transform
#         self.batch_size=batch_size
    def __len__(self):
        return len(os.listdir(self.root_dir))

    def __getitem__(self, idx):
#         samples=[]
#         for i in range(self.batch_size):
        img_name = os.listdir(self.root_dir)[idx]#[idx*self.batch_size+i]
        in_ds = gdal.Open(self.root_dir+img_name)
        image_b=torch.from_numpy(in_ds.GetRasterBand(1).ReadAsArray(3000,4000,700,700).astype('float')).view(1,700,700)
        image_g=torch.from_numpy(in_ds.GetRasterBand(2).ReadAsArray(3000,4000,700,700).astype('float')).view(1,700,700)
        image_r=torch.from_numpy(in_ds.GetRasterBand(3).ReadAsArray(3000,4000,700,700).astype('float')).view(1,700,700)
        image=torch.cat((image_b,image_g,image_r),0)
        image_tensor=image.cuda()
        #samples.append(image_tensor)
        #tensor=torch.cat((*samples),(0))
#         if self.transform:
#             sample = self.transform(sample)
        return image_tensor

num_epochs=2000
learning_rate = 1e-3
batch_size=55

dataset = LandsatDataset('./landsat8_142_49_truecolor_copy/',transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3,16,3,stride=3,padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2,stride=2),
            nn.Conv2d(16,8,3,stride=2,padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2,stride=1)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8,16,3,stride=2),
            nn.ReLU(True),
            nn.ConvTranspose2d(16,8,5,stride=3,padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(8,3,2,stride=2,padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
model = autoencoder().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

for data in dataloader:
    
    img = min_max_normalization(data.float(),0.0,1.0)
#     img_B=img_B.view(img_B.size(0),-1).float()
#     img_B = torch.transpose(img_B, 0, 1)
#     img_G=img_G.view(img_G.size(0),-1).float()
#     img_G = torch.transpose(img_G, 0, 1)
#     img_R=img_R.view(img_R.size(0),-1).float()
#     img_R = torch.transpose(img_R, 0, 1)
    for epoch in range(num_epochs):
        img= Variable(img).cuda()
        output = model(img)

        loss = criterion(output, img)
        MSE_loss = nn.MSELoss(reduction='sum')(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        MSE_loss.backward()
        optimizer.step()
    # ===================log========================
        if (epoch+1)%200==0:
            print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
            save_image(output, './clouds_remove/outputlayer{}.png'.format(epoch+1))    
    save_image(img, './clouds_remove/inputlayer.png')
#     img_fb= torch.transpose(img_B,0,1)
#     img_fg= torch.transpose(img_G,0,1)
#     img_fr= torch.transpose(img_R,0,1)
#     output_fb=torch.transpose(output_B,0,1).view(data.size(0),1,data.size(1),data.size(2))
#     output_fg=torch.transpose(output_G,0,1).view(data.size(0),1,data.size(1),data.size(2))
#     output_fr=torch.transpose(output_R,0,1).view(data.size(0),1,data.size(1),data.size(2))
#     output=torch.cat((output_fb,output_fg,output_fr),1).cpu().data
#     # MIDIMG= model.encoder(img).view(img.size(0),1,56,56).cpu().data
#     # save_image(x, './mlp_pansharpened/x_{}.png'.format(epoch))
#       
        
#     input_fb=torch.transpose(img_B,0,1).view(data.size(0),1,data.size(1),data.size(2))
#     input_fg=torch.transpose(img_G,0,1).view(data.size(0),1,data.size(1),data.size(2))
#     input_fr=torch.transpose(img_R,0,1).view(data.size(0),1,data.size(1),data.size(2))
#     ini=torch.cat((input_fb,input_fg,input_fr),1).cpu().data
#     
#     # save_image(MIDIMG, './mlp_pansharpened/x_midlayer{}.png'.format(epoch))



    




epoch [200/2000], loss:nan, MSE_loss:551018.7500
epoch [400/2000], loss:nan, MSE_loss:493812.1250
epoch [600/2000], loss:0.6002, MSE_loss:452259.8125
epoch [800/2000], loss:0.5987, MSE_loss:404751.6250
epoch [1000/2000], loss:nan, MSE_loss:388593.1562
epoch [1200/2000], loss:nan, MSE_loss:370256.6250
epoch [1400/2000], loss:nan, MSE_loss:320458.8125
epoch [1600/2000], loss:0.5956, MSE_loss:294124.6875
epoch [1800/2000], loss:0.5951, MSE_loss:278981.8750
epoch [2000/2000], loss:0.5948, MSE_loss:265329.8125


torch.Size([55, 3, 28, 28])
epoch [1/1], loss:nan, MSE_loss:0.1496




# Can it do a better job?


In [2]:
import numpy as np
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os
import numpy
from PIL import Image
import torch
from osgeo import gdal

def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor
        
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(10, 7),
            nn.ReLU(True),
            nn.Linear(7, 3),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(3, 7),
            nn.ReLU(True),
            nn.Linear(7, 10),
            nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    
    
path = 'E:/pyScripts/pytorch/data/142_49/images_lessthan30clouds/'  
os.chdir(path)



In [5]:
band5ImgNames = os.listdir(path+'band 5')
bands5 = []
for i in range(len(band5ImgNames)):
    in_ds = gdal.Open('band 5/'+ band5ImgNames[i] )
    band=torch.from_numpy(
        in_ds.GetRasterBand(1).ReadAsArray(1200,1400,1000,1000).astype('float'))
    band=min_max_normalization(band,0.0,1.0).view(-1,1)
    bands5.append(band)
bands5_tuple= tuple(bands5)   
bands5_input = torch.cat(bands5_tuple,1).float()


num_epochs=1
learning_rate = 1e-4
batch_size=2500
batch_len=int(bands5_input.size(0)/batch_size)

model = autoencoder().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

for epoch in range(num_epochs):
    for i in range(batch_len):
        bands5_input_sub = bands5_input[i*batch_size:i*batch_size+batch_size]
        bands5_input_sub = bands5_input_sub.cuda()
        output5=model(bands5_input_sub)
        loss = criterion(output5, bands5_input_sub)
        MSE_loss = nn.MSELoss()(output5, bands5_input_sub)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # ===================log========================
    if (epoch+1)%500==0 or epoch == 0:
        print('epoch [{}/{}], loss:{:.7f}, MSE_loss:{:.7f}'.format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    
print('model developed')    



epoch [1/1], loss:0.6581972, MSE_loss:0.0569363
model developed




In [6]:
###############################################################################    
band2ImgNames = os.listdir(path+'band 2')
bands2 = []
for i in range(len(band2ImgNames)):
    in_ds = gdal.Open('band 2/'+ band2ImgNames[i] )
    band=torch.from_numpy(
        in_ds.GetRasterBand(1).ReadAsArray(1200,1400,1000,1000).astype('float'))
    band=min_max_normalization(band,0.0,1.0).view(-1,1)
    bands2.append(band)
bands2_tuple= tuple(bands2)   
bands2_input = torch.cat(bands2_tuple,1).float()

band3ImgNames = os.listdir(path+'band 3')
bands3 = []
for i in range(len(band3ImgNames)):
    in_ds = gdal.Open('band 3/'+ band3ImgNames[i] )
    band=torch.from_numpy(
        in_ds.GetRasterBand(1).ReadAsArray(1200,1400,1000,1000).astype('float'))
    band=min_max_normalization(band,0.0,1.0).view(-1,1)
    bands3.append(band)
bands3_tuple= tuple(bands3)   
bands3_input = torch.cat(bands3_tuple,1).float()

band4ImgNames = os.listdir(path+'band 4')
bands4 = []
for i in range(len(band4ImgNames)):
    in_ds = gdal.Open('band 4/'+ band4ImgNames[i] )
    band=torch.from_numpy(
        in_ds.GetRasterBand(1).ReadAsArray(1200,1400,1000,1000).astype('float'))
    band=min_max_normalization(band,0.0,1.0).view(-1,1)
    bands4.append(band)
bands4_tuple= tuple(bands4)   
bands4_input = torch.cat(bands4_tuple,1).float()
###################################################################################
model=model.cpu()
output2=model(bands2_input).permute(1,0).view(10,1,1000,1000)
output3=model(bands3_input).permute(1,0).view(10,1,1000,1000)
output4=model(bands4_input).permute(1,0).view(10,1,1000,1000)
output5=model(bands4_input).permute(1,0).view(10,1,1000,1000)

input2=bands2_input.cpu().permute(1,0).view(10,1,1000,1000)
input3=bands3_input.cpu().permute(1,0).view(10,1,1000,1000)
input4=bands4_input.cpu().permute(1,0).view(10,1,1000,1000)
input5=bands5_input.cpu().permute(1,0).view(10,1,1000,1000)

if not os.path.exists('./clouds_remove'):
    os.mkdir('./clouds_remove')

out_BGR = torch.cat((output2,output3,output4),1)
save_image(out_BGR, './clouds_remove/outputs_BGR.png')

out_GRNIR = torch.cat((output3,output4,output5),1)
save_image(out_BGR, './clouds_remove/outputs_GRNIR.png')

in_BGR = torch.cat((input2,input3,input4),1)
save_image(in_BGR, './clouds_remove/inputs_BGR.png')

in_GRNIR = torch.cat((input3,input4,input5),1)
save_image(in_GRNIR, './clouds_remove/inputs_GRNIR.png')

print('imgs saved')

imgs saved
