In [1]:
import os, copy
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from datetime import datetime

In [2]:
class myData(Dataset):
    """Image-folder dataset yielding ``(image, label)`` pairs.

    Every regular file directly inside ``pathToImgs`` is one sample. The label
    is parsed from the file name: the text before the first ``.`` is split on
    ``_`` and the second field is converted to ``int`` (``"eye_1.png"`` -> 1).

    The (single-channel) image tensor is concatenated three times along
    dimension 0 before it is returned.

    Args:
        pathToImgs: directory containing the image files.
        tranformation: optional callable applied to the PIL image instead of
            the default resize -> to-tensor pipeline. When it is given, no
            normalisation is applied afterwards. (The misspelled parameter
            name is kept so existing keyword callers keep working.)
    """

    def __init__(self, pathToImgs, tranformation=None):
        self.pathToImgs = pathToImgs
        self.imgNames = self._list_images(pathToImgs)
        self.transforms = tranformation
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        # Normalisation is not part of this pipeline: it uses three per-channel
        # statistics, so it is applied in __getitem__ only after the image has
        # been replicated to three channels.
        self.defTransform = transforms.Compose([transforms.Resize((224, 224)),
                                                transforms.ToTensor()])

    @staticmethod
    def _list_images(pathToImgs):
        """Return the sorted names of the regular files inside ``pathToImgs``."""
        # os.listdir order is arbitrary, so sort to make index -> file stable
        # across runs. Sub-directories (e.g. .ipynb_checkpoints) are skipped
        # because they can be neither opened as images nor parsed for a label.
        return sorted(
            name for name in os.listdir(pathToImgs)
            if os.path.isfile(os.path.join(pathToImgs, name))
        )

    @staticmethod
    def _label_from_name(file_name):
        """Parse the integer label from a ``<prefix>_<label>[_...].<ext>`` name."""
        stem = file_name.split('.')[0]
        return torch.tensor(int(stem.split('_')[1]))

    def __len__(self):
        return len(self.imgNames)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_name = self.imgNames[idx]
        image = Image.open(os.path.join(self.pathToImgs, file_name))
        # The channel replication below assumes one band. Convert colour or
        # alpha images to greyscale instead of silently producing 9 or 12
        # channels; single-band images are left untouched.
        if len(image.getbands()) != 1:
            image = image.convert('L')
        label = self._label_from_name(file_name)

        if self.transforms:
            image = torch.cat(3 * [self.transforms(image)])
        else:
            image = torch.cat(3 * [self.defTransform(image).float()])
            image = self.normalize(image)

        return image, label

In [3]:
class Autoencoder(nn.Module):
    """Autoencoder with a SqueezeNet 1.1 encoder and a transposed-conv decoder.

    The encoder is torchvision's pretrained ``squeezenet1_1`` with its
    classifier replaced by a single ``Conv2d(512, 1, kernel_size=2)``.
    ``forward`` reshapes the encoder output to ``(N, 1, 12, 12)``, so the
    encoder must emit 144 values per sample (presumably the case for 224x224
    inputs -- confirm if the input size changes). The decoder upsamples that
    code with six ``ConvTranspose2d`` layers (12 -> 26 -> 53 -> 107 -> 215 ->
    217 -> 224) to a single-channel map.

    Args:
        batchsize: stored as ``self.batchsize`` for backward compatibility;
            ``forward`` no longer depends on it.
    """

    def __init__(self, batchsize):
        super(Autoencoder, self).__init__()

        self.batchsize = batchsize

        ## encoder
        self.encoder = models.squeezenet1_1(pretrained=True)
        self.encoder.classifier = nn.Sequential(nn.Conv2d(512, 1, kernel_size=(2, 2), stride=(1, 1)))

        ## decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1, 16, 4, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 32, 3, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 64, 3, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 8, stride=1),
        )

    def forward(self, x):
        """Encode ``x`` and decode the 12x12 code back to an image-sized map."""
        code = self.encoder(x)
        # Reshape with the batch dimension of the tensor itself, not the
        # constructor's batch size: the last batch of an epoch (or a single
        # image at inference time) is usually smaller and would otherwise make
        # view() raise a RuntimeError.
        code = code.view(code.size(0), 1, 12, 12)
        return self.decoder(code)
        

In [18]:
def train(epoch, model, optimizer, criterion, dataloder, device, scheduler=None, exp_name=None):
    """Train ``model`` to reconstruct its input and keep the best weights.

    After every epoch the model is evaluated on the validation loader; when
    training finishes, the weights from the epoch with the lowest average
    validation loss are loaded back into ``model``.

    Args:
        epoch: number of epochs to run.
        model: module called as ``model(img)``.
        optimizer: optimizer stepped once per training batch.
        criterion: loss called as ``criterion(output, img)`` -- the input batch
            is the reconstruction target; the labels are ignored.
        dataloder: indexable pair ``(train_loader, val_loader)``, each yielding
            ``(img, label)`` batches.
        device: device the image batches are moved to.
        scheduler: optional LR scheduler, stepped once per epoch.
        exp_name: currently unused (only referenced by the disabled
            SummaryWriter logging).

    Returns:
        None. ``model`` is updated in place.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    scores = {'bestValLoss': float('inf') , 'epoch': 0, 'trainLoss@BVal': float('inf'), 'iter': 0}

    for e in range(epoch):
        # ---- training pass ----
        model.train()
        epochLoss, counter = 0, 0
        for img, label in dataloder[0]:
            img = img.to(device)
            optimizer.zero_grad()
            output = model(img)
            # NOTE(review): if the model output and img differ in shape (e.g.
            # 1-channel output vs. 3-channel input) the loss broadcasts and
            # PyTorch warns about it -- confirm this is the intended target.
            loss = criterion(output, img)
            loss.backward()
            optimizer.step()
            epochLoss += loss.item()
            counter += 1
            scores['iter'] += 1
            print(loss.item())
        # An empty loader yields NaN instead of a ZeroDivisionError.
        avgLossTrain = epochLoss / counter if counter else float('nan')
        print(f'train epochLoss: {avgLossTrain}\t epoch: {e}')

        # ---- validation pass (no gradients, eval-mode layers) ----
        model.eval()
        epochLoss, counter = 0, 0
        with torch.no_grad():
            for img, label in dataloder[1]:
                img = img.to(device)
                output = model(img)
                loss = criterion(output, img)
                epochLoss += loss.item()
                counter += 1
        avgLossVal = epochLoss / counter if counter else float('nan')
        print(f'test epochLoss: {avgLossVal}\t epoch: {e}')
        print('==' * 30)

        if scheduler:
            scheduler.step()

        # A NaN validation loss never compares as smaller, so it cannot
        # replace a previously stored best model.
        if avgLossVal < scores['bestValLoss']:
            scores['bestValLoss'] = avgLossVal
            scores['epoch'] = e
            scores['trainLoss@BVal'] = avgLossTrain
            best_model_wts = copy.deepcopy(model.state_dict())

    print(scores)
    model.load_state_dict(best_model_wts)

In [5]:
# Batch size shared by both loaders; it is also passed to the Autoencoder.
bsize = 50
path_tr = './EyesDataset/train/'
path_ts = './EyesDataset/test/'

trainDataset = myData(path_tr)
testDataset = myData(path_ts)

# Pinned host memory only helps host-to-GPU copies. Requesting it on a
# CPU-only machine just triggers a DataLoader warning, so enable it only when
# CUDA is actually available (the notebook falls back to the CPU otherwise).
use_pinned_memory = torch.cuda.is_available()
trainloader = DataLoader(trainDataset, batch_size=bsize, shuffle=True, pin_memory=use_pinned_memory)
testloader  = DataLoader(testDataset, batch_size=bsize, pin_memory=use_pinned_memory)

In [19]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Build the network and move its parameters to the selected device.
model = Autoencoder(bsize).to(device)

In [20]:
# Optimisation hyper-parameters.
epoches, learning_rate = 2, 3e-3

# Reconstruction loss and an Adam optimizer over all model parameters.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# No learning-rate schedule for this run. A step schedule could be used instead:
#   torch.optim.lr_scheduler.MultiStepLR(optimizer, [7, 10, 13], gamma=0.1)
scheduler = None

# Settings that are embedded in the experiment name.
params = dict(lr=learning_rate, epoches=epoches, batch_size=bsize, scheduler=scheduler)

# ISO timestamp immediately followed by the printed settings dict.
exp_name = f'{datetime.now().isoformat()}{params}'
print(exp_name)

2020-09-19T14:26:19.253007{'lr': 0.003, 'epoches': 2, 'batch_size': 50, 'scheduler': None}


In [21]:
# Launch training; the loaders are passed as a (train, validation) pair.
train(
    epoch=epoches,
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    dataloder=(trainloader, testloader),
    device=device,
    exp_name=exp_name,
)

passed torch.Size([50, 1, 12, 12])


  return F.mse_loss(input, target, reduction=self.reduction)


2.880716323852539
passed torch.Size([50, 1, 12, 12])
3.236191987991333
passed torch.Size([50, 1, 12, 12])
2.151003360748291
passed torch.Size([50, 1, 12, 12])


KeyboardInterrupt: 