In [2]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision.utils import save_image
from PIL import Image
from torch import nn
from torch.autograd import Variable
from torchvision.transforms.functional import *
import os
import numpy
import torch 
import gdal


def minmaxnormization(tensor):
    """Linearly rescale ``tensor`` so that its values span [0, 1].

    The global minimum is subtracted and the result is divided by the
    remaining global maximum.

    Args:
        tensor: Non-empty array-like exposing ``min()`` / ``max()`` and
            elementwise arithmetic (the callers in this file pass a
            ``torch.Tensor``).

    Returns:
        A new object of the same shape. For an input whose values are all
        equal the result is all zeros.
    """
    shifted = tensor - tensor.min()
    span = shifted.max()
    if span == 0:
        # Constant input: dividing would compute 0/0 and fill the result
        # with NaN, which would then poison the loss during training.
        return shifted
    return shifted / span

class autoencoder(nn.Module):
    """Fully connected autoencoder squeezing 64 features down to one value.

    Encoder widths are 64 -> 32 -> 16 -> 4 -> 1 with a ReLU between
    consecutive linear layers; the decoder mirrors those widths and finishes
    with a ReLU followed by a Tanh.
    """

    def __init__(self):
        super().__init__()

        widths = (64, 32, 16, 4, 1)
        # The encoder is built first so the layers are created in the same
        # order as before (parameter initialisation draws from the global RNG).
        self.encoder = nn.Sequential(*self._linear_relu_stack(widths))
        self.decoder = nn.Sequential(
            *self._linear_relu_stack(widths[::-1]),
            nn.ReLU(inplace=True),
            nn.Tanh()
        )

    @staticmethod
    def _linear_relu_stack(widths):
        """Return Linear layers for consecutive ``widths`` with ReLUs in between."""
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(inplace=True))
        # No activation after the last linear layer of the stack.
        return layers[:-1]

    def forward(self, x):
        """Return ``(code, reconstruction)`` for input ``x`` of shape (..., 64)."""
        code = self.encoder(x)
        return code, self.decoder(code)



# Switch to the project directory so that the relative './imgfolder/...' and
# './model_dict/...' paths used further down resolve against it.
# NOTE: this is an absolute, machine-specific Windows path.
os.chdir(r'E:\Moore\pyScripts\pytorch\PCAvsAE')


class LandsatDataset():
    """Map-style dataset returning one min-max normalised raster band per file.

    Every entry of ``root_dir`` is opened with GDAL. Band 1 is read, viewed as
    a 1200 x 1200 tensor, rescaled with ``minmaxnormization`` and moved to the
    GPU.

    Args:
        root_dir: Directory containing the raster files.
        transform: Stored on the instance; this class does not apply it.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # List the directory once and sort it: os.listdir() returns entries in
        # arbitrary order, so re-listing on every access gave no guarantee
        # that an index always refers to the same file.
        self.img_names = sorted(os.listdir(root_dir))

    def __len__(self):
        """Return the number of entries found in ``root_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return the normalised band of file ``idx`` as a (1200, 1200) CUDA tensor.

        Raises:
            IndexError: If ``idx`` is out of range.
            IOError: If GDAL cannot open the file.
        """
        # os.path.join works whether or not root_dir ends with a separator.
        img_path = os.path.join(self.root_dir, self.img_names[idx])
        in_ds = gdal.Open(img_path)
        if in_ds is None:
            # Without GDAL exceptions enabled, a failed open returns None.
            raise IOError('GDAL could not open {!r}'.format(img_path))
        band = in_ds.GetRasterBand(1).ReadAsArray().astype('float')
        image = minmaxnormization(torch.from_numpy(band).view(1200, 1200))
        return image.cuda()

    
dataset = LandsatDataset('./imgfolder/band5_1750/',transform=None)
# The model's first layer takes 64 features (one value per image), so every
# batch has to contain exactly 64 images.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)


model = autoencoder().cuda()
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for data in dataloader:
    # data is (images, rows, cols). Each stack of images is trained for all
    # epochs before the next stack is loaded.
    n_rows = data.size(1)
    for epoch in range(100): # 10000
        for batch in range(n_rows):
            # One image row across the whole stack, transposed to
            # (cols, images): each pixel becomes one training sample.
            img_input = data[:, batch, :].t().float()
            _, output = model(img_input)
            loss = criterion(output, img_input)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (batch+1)%200==0:
                # loss is a 0-dim tensor; .item() extracts the Python float
                # (indexing it with [0] fails on current PyTorch).
                print('epoch: {},line: {}, Loss: {:.4f}'.format(epoch + 1,batch+1, loss.item()))

# Make sure the output directory exists so the trained weights are not lost.
os.makedirs('./model_dict', exist_ok=True)
torch.save(model.state_dict(), './model_dict/1113_autoencoder_1750_test9.pth')



epoch: 1,line: 200, Loss: 8296.6787
epoch: 1,line: 400, Loss: 7346.6953
epoch: 1,line: 600, Loss: 3253.9287
epoch: 1,line: 800, Loss: 3318.6025
epoch: 1,line: 1000, Loss: 3639.2876
epoch: 1,line: 1200, Loss: 2271.0806
epoch: 2,line: 200, Loss: 6739.7266
epoch: 2,line: 400, Loss: 6590.8135
epoch: 2,line: 600, Loss: 3241.2197
epoch: 2,line: 800, Loss: 3316.7720
epoch: 2,line: 1000, Loss: 3639.0342
epoch: 2,line: 1200, Loss: 2271.0356
epoch: 3,line: 200, Loss: 6739.1865
epoch: 3,line: 400, Loss: 6590.4790
epoch: 3,line: 600, Loss: 3241.1736
epoch: 3,line: 800, Loss: 3316.7178
epoch: 3,line: 1000, Loss: 3639.0535
epoch: 3,line: 1200, Loss: 2271.0527
epoch: 4,line: 200, Loss: 6739.0840
epoch: 4,line: 400, Loss: 6590.3667
epoch: 4,line: 600, Loss: 3241.1782
epoch: 4,line: 800, Loss: 3316.7373
epoch: 4,line: 1000, Loss: 3639.0356
epoch: 4,line: 1200, Loss: 2271.0947
epoch: 5,line: 200, Loss: 6739.0518
epoch: 5,line: 400, Loss: 6590.3628
epoch: 5,line: 600, Loss: 3241.2065
epoch: 5,line: 800, 

epoch: 37,line: 1000, Loss: 2681.1421
epoch: 37,line: 1200, Loss: 1701.8059
epoch: 38,line: 200, Loss: 4896.5020
epoch: 38,line: 400, Loss: 4786.1387
epoch: 38,line: 600, Loss: 2398.0298
epoch: 38,line: 800, Loss: 2449.5249
epoch: 38,line: 1000, Loss: 2681.1365
epoch: 38,line: 1200, Loss: 1701.7988
epoch: 39,line: 200, Loss: 4896.5205
epoch: 39,line: 400, Loss: 4786.1382
epoch: 39,line: 600, Loss: 2398.0408
epoch: 39,line: 800, Loss: 2449.5105
epoch: 39,line: 1000, Loss: 2681.1382
epoch: 39,line: 1200, Loss: 1701.8020
epoch: 40,line: 200, Loss: 4896.5078
epoch: 40,line: 400, Loss: 4786.1367
epoch: 40,line: 600, Loss: 2398.0322
epoch: 40,line: 800, Loss: 2449.5200
epoch: 40,line: 1000, Loss: 2681.1348
epoch: 40,line: 1200, Loss: 1701.7983
epoch: 41,line: 200, Loss: 4896.5146
epoch: 41,line: 400, Loss: 4786.1343
epoch: 41,line: 600, Loss: 2398.0396
epoch: 41,line: 800, Loss: 2449.5061
epoch: 41,line: 1000, Loss: 2681.1382
epoch: 41,line: 1200, Loss: 1701.8002
epoch: 42,line: 200, Loss: 4

epoch: 74,line: 400, Loss: 3755.3254
epoch: 74,line: 600, Loss: 1922.7374
epoch: 74,line: 800, Loss: 1963.7836
epoch: 74,line: 1000, Loss: 2139.7195
epoch: 74,line: 1200, Loss: 1390.8169
epoch: 75,line: 200, Loss: 3845.7629
epoch: 75,line: 400, Loss: 3755.3142
epoch: 75,line: 600, Loss: 1922.7335
epoch: 75,line: 800, Loss: 1963.7815
epoch: 75,line: 1000, Loss: 2139.7195
epoch: 75,line: 1200, Loss: 1390.8176
epoch: 76,line: 200, Loss: 3845.7607
epoch: 76,line: 400, Loss: 3755.3088
epoch: 76,line: 600, Loss: 1922.7288
epoch: 76,line: 800, Loss: 1963.7794
epoch: 76,line: 1000, Loss: 2139.7202
epoch: 76,line: 1200, Loss: 1390.8173
epoch: 77,line: 200, Loss: 3845.7588
epoch: 77,line: 400, Loss: 3755.3042
epoch: 77,line: 600, Loss: 1922.7275
epoch: 77,line: 800, Loss: 1963.7811
epoch: 77,line: 1000, Loss: 2139.7192
epoch: 77,line: 1200, Loss: 1390.8149
epoch: 78,line: 200, Loss: 3845.7524
epoch: 78,line: 400, Loss: 3755.2993
epoch: 78,line: 600, Loss: 1922.7156
epoch: 78,line: 800, Loss: 196

# Deeper structure: same data and training loop, with an extra 8-unit layer in both the encoder and the decoder

In [6]:
from torch import nn
from torch.autograd import Variable
from torchvision.transforms.functional import *
import os
import numpy
import torch 
import gdal


def minmaxnormization(tensor):
    """Linearly rescale ``tensor`` so that its values span [0, 1].

    The global minimum is subtracted and the result is divided by the
    remaining global maximum.

    Args:
        tensor: Non-empty array-like exposing ``min()`` / ``max()`` and
            elementwise arithmetic (the callers in this file pass a
            ``torch.Tensor``).

    Returns:
        A new object of the same shape. For an input whose values are all
        equal the result is all zeros.
    """
    shifted = tensor - tensor.min()
    span = shifted.max()
    if span == 0:
        # Constant input: dividing would compute 0/0 and fill the result
        # with NaN, which would then poison the loss during training.
        return shifted
    return shifted / span

class autoencoder(nn.Module):
    """Deeper fully connected autoencoder: 64 -> 32 -> 16 -> 8 -> 4 -> 1 and back.

    ReLUs separate the linear layers; the decoder ends with a ReLU followed
    by a Tanh.
    """

    def __init__(self):
        super().__init__()

        # Encoder layers are instantiated before decoder layers, matching the
        # original construction order (initialisation consumes the global RNG).
        encoder_layers = [
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=32, out_features=16),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=16, out_features=8),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=8, out_features=4),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4, out_features=1),
        ]
        decoder_layers = [
            nn.Linear(in_features=1, out_features=4),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4, out_features=8),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=8, out_features=16),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=16, out_features=32),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=32, out_features=64),
            nn.ReLU(inplace=True),
            nn.Tanh(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(code, reconstruction)`` for input ``x`` of shape (..., 64)."""
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return code, reconstruction



# Switch to the project directory so that the relative './imgfolder/...' and
# './model_dict/...' paths used further down resolve against it.
# NOTE: this is an absolute, machine-specific Windows path.
os.chdir(r'E:\Moore\pyScripts\pytorch\PCAvsAE')


class LandsatDataset():
    """Map-style dataset returning one min-max normalised raster band per file.

    Every entry of ``root_dir`` is opened with GDAL. Band 1 is read, viewed as
    a 1200 x 1200 tensor, rescaled with ``minmaxnormization`` and moved to the
    GPU.

    Args:
        root_dir: Directory containing the raster files.
        transform: Stored on the instance; this class does not apply it.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # List the directory once and sort it: os.listdir() returns entries in
        # arbitrary order, so re-listing on every access gave no guarantee
        # that an index always refers to the same file.
        self.img_names = sorted(os.listdir(root_dir))

    def __len__(self):
        """Return the number of entries found in ``root_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return the normalised band of file ``idx`` as a (1200, 1200) CUDA tensor.

        Raises:
            IndexError: If ``idx`` is out of range.
            IOError: If GDAL cannot open the file.
        """
        # os.path.join works whether or not root_dir ends with a separator.
        img_path = os.path.join(self.root_dir, self.img_names[idx])
        in_ds = gdal.Open(img_path)
        if in_ds is None:
            # Without GDAL exceptions enabled, a failed open returns None.
            raise IOError('GDAL could not open {!r}'.format(img_path))
        band = in_ds.GetRasterBand(1).ReadAsArray().astype('float')
        image = minmaxnormization(torch.from_numpy(band).view(1200, 1200))
        return image.cuda()

    
dataset = LandsatDataset('./imgfolder/band5_1750/',transform=None)
# The model's first layer takes 64 features (one value per image), so every
# batch has to contain exactly 64 images.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)


model = autoencoder().cuda()
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for data in dataloader:
    # data is (images, rows, cols). Each stack of images is trained for all
    # epochs before the next stack is loaded.
    n_rows = data.size(1)
    for epoch in range(300): # 10000
        for batch in range(n_rows):
            # One image row across the whole stack, transposed to
            # (cols, images): each pixel becomes one training sample.
            img_input = data[:, batch, :].t().float()
            _, output = model(img_input)
            loss = criterion(output, img_input)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (batch+1)%200==0:
                # loss is a 0-dim tensor; .item() extracts the Python float
                # (indexing it with [0] fails on current PyTorch).
                print('epoch: {},line: {}, Loss: {:.4f}'.format(epoch + 1,batch+1, loss.item()))

# Make sure the output directory exists so the trained weights are not lost.
os.makedirs('./model_dict', exist_ok=True)
torch.save(model.state_dict(), './model_dict/1113_autoencoder_1750_test10.pth')



epoch: 1,line: 200, Loss: 9336.9775
epoch: 1,line: 400, Loss: 8791.9941
epoch: 1,line: 600, Loss: 4307.4199
epoch: 1,line: 800, Loss: 4371.6006
epoch: 1,line: 1000, Loss: 4805.1758
epoch: 1,line: 1200, Loss: 2967.3848
epoch: 2,line: 200, Loss: 9010.5537
epoch: 2,line: 400, Loss: 8776.5605
epoch: 2,line: 600, Loss: 4306.2979
epoch: 2,line: 800, Loss: 4371.5215
epoch: 2,line: 1000, Loss: 4805.0400
epoch: 2,line: 1200, Loss: 2967.3591
epoch: 3,line: 200, Loss: 9010.0498
epoch: 3,line: 400, Loss: 8775.4160
epoch: 3,line: 600, Loss: 4306.0513
epoch: 3,line: 800, Loss: 4371.2671
epoch: 3,line: 1000, Loss: 4804.9775
epoch: 3,line: 1200, Loss: 2967.3843
epoch: 4,line: 200, Loss: 9009.6201
epoch: 4,line: 400, Loss: 8775.2207
epoch: 4,line: 600, Loss: 4306.0679
epoch: 4,line: 800, Loss: 4371.2949
epoch: 4,line: 1000, Loss: 4805.0269
epoch: 4,line: 1200, Loss: 2967.4270
epoch: 5,line: 200, Loss: 9009.6299
epoch: 5,line: 400, Loss: 8775.1436
epoch: 5,line: 600, Loss: 4306.1157
epoch: 5,line: 800, 

epoch: 37,line: 1000, Loss: 1416.5212
epoch: 37,line: 1200, Loss: 962.9057
epoch: 38,line: 200, Loss: 2476.2173
epoch: 38,line: 400, Loss: 2414.5007
epoch: 38,line: 600, Loss: 1295.7268
epoch: 38,line: 800, Loss: 1313.8223
epoch: 38,line: 1000, Loss: 1416.4703
epoch: 38,line: 1200, Loss: 962.9105
epoch: 39,line: 200, Loss: 2476.0649
epoch: 39,line: 400, Loss: 2414.4263
epoch: 39,line: 600, Loss: 1295.7102
epoch: 39,line: 800, Loss: 1313.8097
epoch: 39,line: 1000, Loss: 1416.4283
epoch: 39,line: 1200, Loss: 962.8896
epoch: 40,line: 200, Loss: 2475.9092
epoch: 40,line: 400, Loss: 2414.1675
epoch: 40,line: 600, Loss: 1295.6581
epoch: 40,line: 800, Loss: 1313.7776
epoch: 40,line: 1000, Loss: 1416.4219
epoch: 40,line: 1200, Loss: 962.8920
epoch: 41,line: 200, Loss: 2475.8423
epoch: 41,line: 400, Loss: 2414.1428
epoch: 41,line: 600, Loss: 1295.6591
epoch: 41,line: 800, Loss: 1313.7688
epoch: 41,line: 1000, Loss: 1416.3687
epoch: 41,line: 1200, Loss: 962.8777
epoch: 42,line: 200, Loss: 2475.7

epoch: 74,line: 600, Loss: 1295.5830
epoch: 74,line: 800, Loss: 1313.7335
epoch: 74,line: 1000, Loss: 1416.3467
epoch: 74,line: 1200, Loss: 962.7728
epoch: 75,line: 200, Loss: 2475.5527
epoch: 75,line: 400, Loss: 2413.7029
epoch: 75,line: 600, Loss: 1295.6035
epoch: 75,line: 800, Loss: 1313.7356
epoch: 75,line: 1000, Loss: 1416.3450
epoch: 75,line: 1200, Loss: 962.7797
epoch: 76,line: 200, Loss: 2475.5640
epoch: 76,line: 400, Loss: 2413.7119
epoch: 76,line: 600, Loss: 1295.6046
epoch: 76,line: 800, Loss: 1313.7350
epoch: 76,line: 1000, Loss: 1416.3472
epoch: 76,line: 1200, Loss: 962.7794
epoch: 77,line: 200, Loss: 2475.5037
epoch: 77,line: 400, Loss: 2413.6567
epoch: 77,line: 600, Loss: 1295.5848
epoch: 77,line: 800, Loss: 1313.7319
epoch: 77,line: 1000, Loss: 1416.3469
epoch: 77,line: 1200, Loss: 962.7714
epoch: 78,line: 200, Loss: 2475.5503
epoch: 78,line: 400, Loss: 2413.6973
epoch: 78,line: 600, Loss: 1295.6030
epoch: 78,line: 800, Loss: 1313.7343
epoch: 78,line: 1000, Loss: 1416.3

epoch: 110,line: 1000, Loss: 1416.3594
epoch: 110,line: 1200, Loss: 962.7534
epoch: 111,line: 200, Loss: 2475.5144
epoch: 111,line: 400, Loss: 2413.6213
epoch: 111,line: 600, Loss: 1295.5833
epoch: 111,line: 800, Loss: 1313.7294
epoch: 111,line: 1000, Loss: 1416.3560
epoch: 111,line: 1200, Loss: 962.7494
epoch: 112,line: 200, Loss: 2475.4883
epoch: 112,line: 400, Loss: 2413.5906
epoch: 112,line: 600, Loss: 1295.5671
epoch: 112,line: 800, Loss: 1313.7251
epoch: 112,line: 1000, Loss: 1416.3589
epoch: 112,line: 1200, Loss: 962.7526
epoch: 113,line: 200, Loss: 2475.5142
epoch: 113,line: 400, Loss: 2413.6223
epoch: 113,line: 600, Loss: 1295.5818
epoch: 113,line: 800, Loss: 1313.7288
epoch: 113,line: 1000, Loss: 1416.3557
epoch: 113,line: 1200, Loss: 962.7484
epoch: 114,line: 200, Loss: 2475.4866
epoch: 114,line: 400, Loss: 2413.5854
epoch: 114,line: 600, Loss: 1295.5635
epoch: 114,line: 800, Loss: 1313.7240
epoch: 114,line: 1000, Loss: 1416.3588
epoch: 114,line: 1200, Loss: 962.7520
epoch: 

epoch: 146,line: 600, Loss: 1295.5636
epoch: 146,line: 800, Loss: 1313.7229
epoch: 146,line: 1000, Loss: 1416.3557
epoch: 146,line: 1200, Loss: 962.7399
epoch: 147,line: 200, Loss: 2475.4805
epoch: 147,line: 400, Loss: 2413.5671
epoch: 147,line: 600, Loss: 1295.5447
epoch: 147,line: 800, Loss: 1313.7163
epoch: 147,line: 1000, Loss: 1416.3630
epoch: 147,line: 1200, Loss: 962.7416
epoch: 148,line: 200, Loss: 2475.4966
epoch: 148,line: 400, Loss: 2413.6006
epoch: 148,line: 600, Loss: 1295.5588
epoch: 148,line: 800, Loss: 1313.7225
epoch: 148,line: 1000, Loss: 1416.3566
epoch: 148,line: 1200, Loss: 962.7421
epoch: 149,line: 200, Loss: 2475.4927
epoch: 149,line: 400, Loss: 2413.5967
epoch: 149,line: 600, Loss: 1295.5569
epoch: 149,line: 800, Loss: 1313.7222
epoch: 149,line: 1000, Loss: 1416.3569
epoch: 149,line: 1200, Loss: 962.7421
epoch: 150,line: 200, Loss: 2475.4985
epoch: 150,line: 400, Loss: 2413.6038
epoch: 150,line: 600, Loss: 1295.5603
epoch: 150,line: 800, Loss: 1313.7227
epoch: 1

epoch: 182,line: 200, Loss: 2105.8315
epoch: 182,line: 400, Loss: 2055.9438
epoch: 182,line: 600, Loss: 1127.3804
epoch: 182,line: 800, Loss: 1146.3284
epoch: 182,line: 1000, Loss: 1227.5958
epoch: 182,line: 1200, Loss: 849.1818
epoch: 183,line: 200, Loss: 2105.8452
epoch: 183,line: 400, Loss: 2055.9736
epoch: 183,line: 600, Loss: 1127.3933
epoch: 183,line: 800, Loss: 1146.3325
epoch: 183,line: 1000, Loss: 1227.5916
epoch: 183,line: 1200, Loss: 849.1841
epoch: 184,line: 200, Loss: 2105.8528
epoch: 184,line: 400, Loss: 2055.9839
epoch: 184,line: 600, Loss: 1127.3975
epoch: 184,line: 800, Loss: 1146.3334
epoch: 184,line: 1000, Loss: 1227.5903
epoch: 184,line: 1200, Loss: 849.1829
epoch: 185,line: 200, Loss: 2105.8315
epoch: 185,line: 400, Loss: 2055.9453
epoch: 185,line: 600, Loss: 1127.3823
epoch: 185,line: 800, Loss: 1146.3291
epoch: 185,line: 1000, Loss: 1227.5966
epoch: 185,line: 1200, Loss: 849.1816
epoch: 186,line: 200, Loss: 2105.8457
epoch: 186,line: 400, Loss: 2055.9741
epoch: 1

epoch: 217,line: 1000, Loss: 1227.5969
epoch: 217,line: 1200, Loss: 849.1773
epoch: 218,line: 200, Loss: 2105.8320
epoch: 218,line: 400, Loss: 2055.9548
epoch: 218,line: 600, Loss: 1127.3884
epoch: 218,line: 800, Loss: 1146.3297
epoch: 218,line: 1000, Loss: 1227.5978
epoch: 218,line: 1200, Loss: 849.1772
epoch: 219,line: 200, Loss: 2105.8438
epoch: 219,line: 400, Loss: 2055.9688
epoch: 219,line: 600, Loss: 1127.3932
epoch: 219,line: 800, Loss: 1146.3312
epoch: 219,line: 1000, Loss: 1227.5961
epoch: 219,line: 1200, Loss: 849.1769
epoch: 220,line: 200, Loss: 2105.8218
epoch: 220,line: 400, Loss: 2055.9282
epoch: 220,line: 600, Loss: 1127.3799
epoch: 220,line: 800, Loss: 1146.3251
epoch: 220,line: 1000, Loss: 1227.5991
epoch: 220,line: 1200, Loss: 849.1747
epoch: 221,line: 200, Loss: 2105.8350
epoch: 221,line: 400, Loss: 2055.9573
epoch: 221,line: 600, Loss: 1127.3894
epoch: 221,line: 800, Loss: 1146.3301
epoch: 221,line: 1000, Loss: 1227.5978
epoch: 221,line: 1200, Loss: 849.1768
epoch: 

epoch: 253,line: 600, Loss: 1127.3813
epoch: 253,line: 800, Loss: 1146.3220
epoch: 253,line: 1000, Loss: 1227.5966
epoch: 253,line: 1200, Loss: 849.1602
epoch: 254,line: 200, Loss: 2105.8325
epoch: 254,line: 400, Loss: 2055.9531
epoch: 254,line: 600, Loss: 1127.3806
epoch: 254,line: 800, Loss: 1146.3217
epoch: 254,line: 1000, Loss: 1227.5967
epoch: 254,line: 1200, Loss: 849.1595
epoch: 255,line: 200, Loss: 2105.8364
epoch: 255,line: 400, Loss: 2055.9570
epoch: 255,line: 600, Loss: 1127.3810
epoch: 255,line: 800, Loss: 1146.3224
epoch: 255,line: 1000, Loss: 1227.5963
epoch: 255,line: 1200, Loss: 849.1592
epoch: 256,line: 200, Loss: 2105.8245
epoch: 256,line: 400, Loss: 2055.9434
epoch: 256,line: 600, Loss: 1127.3787
epoch: 256,line: 800, Loss: 1146.3187
epoch: 256,line: 1000, Loss: 1227.5969
epoch: 256,line: 1200, Loss: 849.1577
epoch: 257,line: 200, Loss: 2105.8486
epoch: 257,line: 400, Loss: 2055.9614
epoch: 257,line: 600, Loss: 1127.3810
epoch: 257,line: 800, Loss: 1146.3242
epoch: 2

epoch: 289,line: 200, Loss: 2105.8340
epoch: 289,line: 400, Loss: 2055.9392
epoch: 289,line: 600, Loss: 1127.3727
epoch: 289,line: 800, Loss: 1146.3177
epoch: 289,line: 1000, Loss: 1227.5984
epoch: 289,line: 1200, Loss: 849.1548
epoch: 290,line: 200, Loss: 2105.8389
epoch: 290,line: 400, Loss: 2055.9419
epoch: 290,line: 600, Loss: 1127.3727
epoch: 290,line: 800, Loss: 1146.3206
epoch: 290,line: 1000, Loss: 1227.5989
epoch: 290,line: 1200, Loss: 849.1549
epoch: 291,line: 200, Loss: 2105.8176
epoch: 291,line: 400, Loss: 2055.9260
epoch: 291,line: 600, Loss: 1127.3730
epoch: 291,line: 800, Loss: 1146.3147
epoch: 291,line: 1000, Loss: 1227.5955
epoch: 291,line: 1200, Loss: 849.1526
epoch: 292,line: 200, Loss: 2105.8406
epoch: 292,line: 400, Loss: 2055.9399
epoch: 292,line: 600, Loss: 1127.3726
epoch: 292,line: 800, Loss: 1146.3210
epoch: 292,line: 1000, Loss: 1227.5991
epoch: 292,line: 1200, Loss: 849.1547
epoch: 293,line: 200, Loss: 2105.8149
epoch: 293,line: 400, Loss: 2055.9189
epoch: 2

In [4]:
# CAE: convolutional autoencoder

In [None]:
from torch import nn
from torch.autograd import Variable
from torchvision.transforms.functional import *
import os
import numpy
import torch 
import gdal


def minmaxnormization(tensor):
    """Linearly rescale ``tensor`` so that its values span [0, 1].

    The global minimum is subtracted and the result is divided by the
    remaining global maximum.

    Args:
        tensor: Non-empty array-like exposing ``min()`` / ``max()`` and
            elementwise arithmetic (the callers in this file pass a
            ``torch.Tensor``).

    Returns:
        A new object of the same shape. For an input whose values are all
        equal the result is all zeros.
    """
    shifted = tensor - tensor.min()
    span = shifted.max()
    if span == 0:
        # Constant input: dividing would compute 0/0 and fill the result
        # with NaN, which would then poison the loss during training.
        return shifted
    return shifted / span

class autoencoder(nn.Module):
    """Single-layer convolutional autoencoder for one-channel images.

    The encoder is a 3x3 convolution (stride 3, padding 1) producing 32
    feature maps, followed by a ReLU. The decoder is the matching transposed
    convolution back to one channel, followed by a ReLU, so the
    reconstruction always has the spatial size of the input.
    """

    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1,32,3,stride=3,padding=1),
            nn.ReLU(True)
        )
        # Stride and padding mirror the encoder convolution. With the previous
        # stride of 2 and no padding the output could never match the input
        # size (e.g. 1200 -> 400 -> 801), so the reconstruction loss could not
        # be evaluated.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32,1,3,stride=3,padding=1),
            nn.ReLU(True)
        )

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: Tensor of shape (N, 1, H, W), or a 2-D (H, W) tensor which is
                treated as a single one-channel image.

        Returns:
            Tuple ``(encode, decode)``. ``decode`` has the same shape as
            ``x``; for 2-D input the batch dimension is dropped from
            ``encode`` as well.
        """
        is_plain_matrix = x.dim() == 2
        if is_plain_matrix:
            # Conv2d needs batch and channel dimensions.
            x = x.unsqueeze(0).unsqueeze(0)

        encode = self.encoder(x)

        # A strided convolution maps several input sizes onto the same output
        # size; passing output_size tells the transposed convolution which of
        # them to restore. nn.Sequential cannot forward that argument, so the
        # two decoder stages are applied by hand.
        deconv, activation = self.decoder[0], self.decoder[1]
        decode = activation(deconv(encode, output_size=list(x.shape[-2:])))

        if is_plain_matrix:
            encode, decode = encode[0], decode[0, 0]
        return encode, decode



# Switch to the project directory so that the relative './imgfolder/...' and
# './model_dict/...' paths used further down resolve against it.
# NOTE: this is an absolute, machine-specific Windows path.
os.chdir(r'E:\Moore\pyScripts\pytorch\PCAvsAE')


class LandsatDataset():
    """Map-style dataset returning one min-max normalised raster band per file.

    Every entry of ``root_dir`` is opened with GDAL. Band 1 is read, viewed as
    a 1200 x 1200 tensor, rescaled with ``minmaxnormization`` and moved to the
    GPU.

    Args:
        root_dir: Directory containing the raster files.
        transform: Stored on the instance; this class does not apply it.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # List the directory once and sort it: os.listdir() returns entries in
        # arbitrary order, so re-listing on every access gave no guarantee
        # that an index always refers to the same file.
        self.img_names = sorted(os.listdir(root_dir))

    def __len__(self):
        """Return the number of entries found in ``root_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return the normalised band of file ``idx`` as a (1200, 1200) CUDA tensor.

        Raises:
            IndexError: If ``idx`` is out of range.
            IOError: If GDAL cannot open the file.
        """
        # os.path.join works whether or not root_dir ends with a separator.
        img_path = os.path.join(self.root_dir, self.img_names[idx])
        in_ds = gdal.Open(img_path)
        if in_ds is None:
            # Without GDAL exceptions enabled, a failed open returns None.
            raise IOError('GDAL could not open {!r}'.format(img_path))
        band = in_ds.GetRasterBand(1).ReadAsArray().astype('float')
        image = minmaxnormization(torch.from_numpy(band).view(1200, 1200))
        return image.cuda()

    
dataset = LandsatDataset('./imgfolder/band5_1750/',transform=None)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)


model = autoencoder().cuda()
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for data in dataloader:
    # data is (images, rows, cols). Each stack of images is trained for all
    # epochs before the next stack is loaded.
    n_rows = data.size(1)
    for epoch in range(100): # 10000
        for batch in range(n_rows):
            # One image row across the whole stack, transposed to
            # (cols, images).
            # NOTE(review): this 2-D slice is the layout used for the fully
            # connected models; the model in this cell is built from Conv2d
            # layers -- confirm this is the input it is meant to be trained on.
            img_input = data[:, batch, :].t().float()
            _, output = model(img_input)
            loss = criterion(output, img_input)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (batch+1)%200==0:
                # loss is a 0-dim tensor; .item() extracts the Python float
                # (indexing it with [0] fails on current PyTorch).
                print('epoch: {},line: {}, Loss: {:.4f}'.format(epoch + 1,batch+1, loss.item()))

# Make sure the output directory exists so the trained weights are not lost.
os.makedirs('./model_dict', exist_ok=True)
torch.save(model.state_dict(), './model_dict/1113_autoencoder_1750_CAE.pth')