In [1]:
import os
import torch
import numpy as np

from PIL import Image
from torch import nn
import torch.optim as optim
from torchvision.transforms import Compose, ToTensor, Resize
import torchvision.datasets as datasets
import torch.nn.functional as F

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
import time

%matplotlib inline

In [2]:
# Path (relative to the current working directory) of a single validation image;
# it is opened in the following cell and used as a quick sanity-check input.
sample_input = os.path.join(os.curdir,'..','data','preprocessed','autoencoder','val','CNV-13823-2.jpeg')

In [3]:
# Load the sample image and force three channels so it matches the model's
# 3-channel first convolution.
img =  Image.open(sample_input).convert('RGB')

In [4]:
# Preprocessing for the sample image: resize, then convert the PIL image to a tensor.
# NOTE(review): with an int size, torchvision's Resize is documented to scale the
# shorter edge and keep the aspect ratio, so only square inputs come out 224x224 -
# confirm the preprocessed images are square.
transforms = Compose([
    Resize(224),
    ToTensor()
    ])

In [5]:
# Apply the preprocessing and add a leading batch dimension of size 1.
# NOTE(review): this rebinds `img` from a PIL image to a tensor, so the cell is
# not idempotent - re-running it feeds a tensor back into `transforms`.
img  = transforms(img).unsqueeze(0)

In [6]:
# Display the shape of the batched sample tensor.
img.shape

torch.Size([1, 3, 224, 224])

In [7]:
class ConvBlock(nn.Module):
    """A ``Conv2d`` immediately followed by a ``BatchNorm2d`` (no activation).

    Every keyword argument is forwarded unchanged to ``nn.Conv2d``.
    ``out_channels`` must be one of them, because it also sets the number of
    features of the batch-norm layer.
    """

    def __init__(self, **kwargs):
        super().__init__()

        conv = nn.Conv2d(**kwargs)
        norm = nn.BatchNorm2d(num_features=kwargs['out_channels'])

        # The attribute is named ``block`` so parameter keys read "block.0.*"
        # (convolution) and "block.1.*" (batch norm).
        self.block = nn.Sequential(conv, norm)

    def forward(self, x):
        """Return ``x`` after the convolution and the batch normalisation."""
        return self.block(x)
        

In [8]:
class ConvTransBlock(nn.Module):
    """``ConvTranspose2d`` -> ``BatchNorm2d`` -> in-place ``ReLU``.

    Every keyword argument is forwarded unchanged to ``nn.ConvTranspose2d``.
    ``out_channels`` must be one of them, because it also sets the number of
    features of the batch-norm layer.
    """

    def __init__(self, **kwargs):
        super().__init__()

        layers = [
            nn.ConvTranspose2d(**kwargs),
            nn.BatchNorm2d(num_features=kwargs['out_channels']),
            nn.ReLU(inplace=True),
        ]
        # The attribute is named ``block`` so parameter keys read "block.<index>.*".
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` after the transposed convolution, batch norm and ReLU."""
        return self.block(x)

In [9]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 ConvBlocks with a skip connection.

    The main path is ``block1 -> ReLU -> block2``. The shortcut is added to the
    result and a final ReLU is applied.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first 3x3 convolution (default 1).
        downsample: when True, the shortcut is projected with a 1x1 ConvBlock
            (``in_channels -> out_channels``, same ``stride`` as ``block1``)
            so that it can be added to the main path; when False the input is
            added unchanged.
    """

    def __init__(self,in_channels, out_channels, stride = 1,downsample = False):
        super(ResBlock,self).__init__()

        self.block1 = ConvBlock(in_channels= in_channels,out_channels= out_channels, stride= stride,
                                kernel_size=3, padding = 1, bias= False)
        self.block2 = ConvBlock(in_channels= out_channels,out_channels= out_channels, stride= 1,
                               kernel_size= 3,padding = 1, bias = False)

        self.relu = nn.ReLU(inplace=True)

        if downsample:
            # The projection has to shrink the shortcut by exactly the factor
            # used in block1, so it follows `stride` instead of a fixed 2
            # (identical for the stride=2 blocks used by the encoder stages).
            self.downsample = ConvBlock(in_channels=in_channels,out_channels= out_channels,
                                        kernel_size= 1, stride= stride, bias= False )
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        identity = x

        out = self.block1(x)
        out = self.relu(out)
        out = self.block2(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity
        out = self.relu(out)

        return out

In [26]:
class ResnetAutoencoder(nn.Module):
    """Autoencoder with a ResNet-style encoder and a transposed-conv decoder.

    The encoder is a 7x7 stem, a max pool and four residual stages
    (64, 128, 256 and 512 channels). The decoder (``block6``) is a chain of
    ``ConvTransBlock`` layers ending in 3 output channels.

    Args:
        res34: when True, the optional residual blocks of every stage are
            created (3 / 4 / 6 / 3 blocks per stage). When False (default),
            those slots hold the shared ``same_layer`` placeholder and every
            stage has two residual blocks.
    """

    def __init__(self,res34=False):
        super().__init__()

        self.res34 = res34
        # Max pool with kernel 1 and stride 1: a parameter-free pass-through
        # used to fill the slots of the optional blocks when res34 is False.
        self.same_layer = nn.MaxPool2d(kernel_size=1,stride=1)

        # ---- Encoder ----
        self.conv1 = ConvBlock(in_channels=3, out_channels=64, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)

        self.conv2_x = self._make_stage(64, 64, stride=1, downsample=False, optional_blocks=1)
        self.conv3_x = self._make_stage(64, 128, stride=2, downsample=True, optional_blocks=2)
        self.conv4_x = self._make_stage(128, 256, stride=2, downsample=True, optional_blocks=4)
        self.conv5_x = self._make_stage(256, 512, stride=2, downsample=True, optional_blocks=1)

        # ---- Decoder ----
        # Five kernel-2 / stride-2 transposed convolutions plus one 3x3 layer
        # (second to last) that uses the default stride.
        decoder_specs = [
            dict(in_channels=512, out_channels=256, kernel_size=2, stride=2),
            dict(in_channels=256, out_channels=128, kernel_size=2, stride=2),
            dict(in_channels=128, out_channels=64, kernel_size=2, stride=2),
            dict(in_channels=64, out_channels=64, kernel_size=2, stride=2),
            dict(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            dict(in_channels=64, out_channels=3, kernel_size=2, stride=2),
        ]
        self.block6 = nn.Sequential(*[ConvTransBlock(**spec) for spec in decoder_specs])

    def _make_stage(self, in_channels, out_channels, stride, downsample, optional_blocks):
        """Build one encoder stage as an ``nn.Sequential``.

        The stage always starts with two residual blocks; the first one changes
        the channel count / resolution as requested. It is followed by
        ``optional_blocks`` slots that hold a further residual block when
        ``self.res34`` is True and the shared ``self.same_layer`` otherwise.
        """
        layers = [
            ResBlock(in_channels=in_channels, out_channels=out_channels,
                     stride=stride, downsample=downsample),
            ResBlock(in_channels=out_channels, out_channels=out_channels,
                     stride=1, downsample=False),
        ]
        for _ in range(optional_blocks):
            if self.res34:
                layers.append(ResBlock(in_channels=out_channels, out_channels=out_channels,
                                       stride=1, downsample=False))
            else:
                layers.append(self.same_layer)
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` with the stem and residual stages, then decode it with ``block6``."""
        encoder = (
            self.conv1,
            self.relu,
            self.max_pool,
            self.conv2_x,
            self.conv3_x,
            self.conv4_x,
            self.conv5_x,
        )
        for stage in encoder:
            x = stage(x)
        return self.block6(x)

In [27]:
# Both splits share the same root directory; AEDataset appends the 'train' or
# 'val' sub-directory itself depending on its `train` flag.
train_data_path = os.path.join(os.curdir,'..','data','preprocessed', 'autoencoder')
val_data_path = os.path.join(os.curdir,'..','data','preprocessed', 'autoencoder')

In [28]:
class AEDataset(Dataset):
    """Unlabelled image dataset for the autoencoder.

    Images are read from ``<data_path>/train`` or ``<data_path>/val`` and
    returned as tensors; there is no target because the input is its own target.

    Args:
        data_path: directory that contains the ``train`` and ``val`` folders.
        train: selects the ``train`` folder when True, ``val`` otherwise.
    """

    def __init__(self, data_path, train=True):
        split = 'train' if train else 'val'
        self.data_path = os.path.join(data_path, split)
        self.all_images = self._list_images(self.data_path)
        # NOTE(review): with an int size, torchvision's Resize is documented to
        # scale the shorter edge only, so non-square images would produce
        # tensors of differing shapes - confirm the preprocessed images are square.
        self.transforms = Compose([
                    Resize(224),
                    ToTensor()
                    ])
        # Alias for the original (misspelled) attribute name, kept so existing
        # code that reads `trasforms` still works.
        self.trasforms = self.transforms

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of the regular, non-hidden files in ``directory``.

        Sorting makes the index -> file mapping reproducible (``os.listdir``
        order is arbitrary). Sub-directories and dot-files are skipped because
        they cannot be opened as images.
        """
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        """Number of image files in the selected split."""
        return len(self.all_images)

    def __getitem__(self, idx) -> torch.Tensor:
        """Load image ``idx``, convert it to RGB and return the transformed tensor."""
        image_name = self.all_images[idx]
        path = os.path.join(self.data_path, image_name)
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted.
        with Image.open(fp=path) as image:
            return self.transforms(image.convert('RGB'))

In [29]:
# One dataset per split; the `train` flag selects the 'train' or 'val' sub-directory.
train_data = AEDataset(train_data_path,train =True)
val_data = AEDataset(val_data_path,train =False)

In [30]:
# Training batches are reshuffled every epoch.
train_iterator = DataLoader(train_data,shuffle = True, batch_size=64)
# Validation must read the held-out split (it was previously built from
# `train_data`). Shuffling is disabled so the batches, and therefore the
# reported validation loss, are the same on every pass.
val_iterator = DataLoader(val_data,shuffle = False, batch_size=64)

In [31]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [32]:
# Default constructor arguments, i.e. res34=False (two residual blocks per encoder stage).
model = ResnetAutoencoder()

In [33]:
# Smoke test: run the single sample image through the untrained model and
# display the shape of the reconstruction.
# NOTE(review): the model has not been switched to eval mode here, so this
# forward pass presumably also updates the BatchNorm running statistics -
# confirm whether that matters before training.
model(img).shape

torch.Size([1, 3, 224, 224])

In [34]:
def count_parameters(model):
    """Return the total number of elements over all parameters of ``model``
    that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable-parameter count with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 11,920,393 trainable parameters


In [471]:
# Adam over all model parameters with a learning rate of 1e-3.
# NOTE(review): the optimizer is created before `model.to(device)` is called
# further down - TODO confirm it still tracks the parameters after the move.
optimizer = optim.Adam(model.parameters(),lr=0.001)

In [472]:
# Mean-squared error; train() and evaluate() apply it between the model output
# and the input batch itself (reconstruction loss).
criterion = nn.MSELoss()

In [473]:
# Move the model to the selected device; as the last expression of the cell the
# returned module is also displayed.
model.to(device)

ResnetAutoencoder(
  (same_layer): MaxPool2d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
  (conv1): ConvBlock(
    (block): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (relu): ReLU(inplace=True)
  (conv2_x): Sequential(
    (0): ResBlock(
      (block1): ConvBlock(
        (block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block2): ConvBlock(
        (block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [474]:
def train(model, iterator, optimizer, criterion, device):
    """Run one training pass over ``iterator`` and return the mean batch loss.

    Each batch is moved to ``device``, reconstructed by ``model`` and compared
    with itself through ``criterion``; the optimizer takes one step per batch.

    Returns:
        The sum of the per-batch loss values divided by ``len(iterator)``.
    """
    model.train()

    batch_losses = []
    for batch in iterator:
        batch = batch.to(device)

        optimizer.zero_grad()
        reconstruction = model(batch)
        batch_loss = criterion(reconstruction, batch)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(iterator)

In [475]:
def evaluate(model, iterator, criterion, device):
    """Return the mean batch loss of ``model`` over ``iterator`` without training.

    The model is switched to eval mode and gradients are disabled. Every batch
    is moved to ``device`` and compared with its own reconstruction through
    ``criterion``.

    Returns:
        The sum of the per-batch loss values divided by ``len(iterator)``.
    """
    model.eval()

    total = 0
    with torch.no_grad():
        for batch in iterator:
            batch = batch.to(device)
            total += criterion(model(batch), batch).item()

    return total / len(iterator)

In [476]:
def epoch_time(start_time, end_time):
    """Split the interval ``end_time - start_time`` (seconds) into a
    ``(minutes, seconds)`` pair of ints, both truncated."""
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

In [477]:
EPOCHS = 10

# Lowest validation loss seen so far; used to decide when to checkpoint.
best_valid_loss = float('inf')

for epoch in range(EPOCHS):

    start_time = time.monotonic()

    # The training call was commented out, so the model was never updated and
    # `train_loss` (printed below) was undefined on a fresh kernel.
    train_loss = train(model, train_iterator, optimizer, criterion, device)
    valid_loss = evaluate(model, val_iterator, criterion, device)

    # Keep only the weights of the best epoch on the validation data.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'ae1-model.pt')

    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

KeyboardInterrupt: 