In [1]:
import torch
import torch.nn as nn 
from torch.nn import Sequential,Conv2d,ConvTranspose2d,ReLU,MaxPool2d
from torch.utils.data import DataLoader,Dataset,random_split

##loss,and optimization
from torch.optim import SGD,lr_scheduler
from torch.nn import MSELoss

##
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST

## import numpy
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy




In [2]:
device=('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


cuda


In [3]:
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,))
])

In [4]:
train_dataset=MNIST(root='../data/',
                    train=True,
                    transform=transform,
                    download=False)

test_dataset=MNIST(root='../data/',
                   train=False,
                   transform=transform,
                   download=False)

In [5]:
##splitting train dataset to  train and val dataset (0.7 and 0.3)

train_samples=int(0.7*len(train_dataset))
val_samples=len(train_dataset)-train_samples

train_dataset,val_dataset=random_split(train_dataset,[train_samples,val_samples])

In [6]:
batch_size=64
##defining dataset and dataloader
dataset={'train':train_dataset,
         'val':val_dataset,
         'test':test_dataset}

dataloader={x:DataLoader(dataset[x],batch_size=batch_size,shuffle=True)
            for x in ['train','val','test']}

In [7]:
##samples images and labels

for images,labels in dataloader['train']:
    break

print('images.shape  :',images.shape)

images.shape  : torch.Size([64, 1, 28, 28])


In [8]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()

        self.encoder=Sequential(
            Conv2d(1,128,3,padding=1),
            MaxPool2d(kernel_size=2),
            ReLU(),
            Conv2d(128,64,3,padding=1),
            MaxPool2d(2),
            ReLU(),
          
            )
        
    def forward(self,x):
        output=self.encoder(x)
        return output

In [9]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder,self).__init__()

        self.decoder=Sequential(
            ConvTranspose2d(64,128,3,stride=2,padding=1,output_padding=1),
            ReLU(),
            ConvTranspose2d(128,1,3,stride=2,padding=1,output_padding=1),
            ReLU()
        )

    def forward(self,x):
        return self.decoder(x)

In [23]:
encoder=Encoder()
decoder=Decoder()

In [24]:
encoder

Encoder(
  (encoder): Sequential(
    (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU()
    (3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU()
  )
)

In [25]:
decoder

Decoder(
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(128, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (3): ReLU()
  )
)

In [26]:
encoder_output=encoder(images)

In [27]:
encoder_output.shape

torch.Size([64, 64, 7, 7])

In [28]:
decoder_output=decoder(encoder_output)

In [29]:
decoder_output.shape

torch.Size([64, 1, 28, 28])

In [30]:
class AutoEncoder(nn.Module):

    def __init__(self,encoder,decoder):
        super(AutoEncoder,self).__init__()
        self.encoder=encoder
        self.decoder=decoder

    def forward(self,x):
        encoder_output=self.encoder(x)
        decoder_output=self.decoder(encoder_output)
        return decoder_output

In [31]:
model=AutoEncoder(encoder,decoder).to(device=device)

In [32]:
dataloader

{'train': <torch.utils.data.dataloader.DataLoader at 0x7fd6f83887c0>,
 'val': <torch.utils.data.dataloader.DataLoader at 0x7fd6f8388e50>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7fd608a7a2b0>}

In [36]:
def train_model(model,optimization,scheduler,criterion,num_epochs=10):
    print('--'*10,"Stating Training","--"*10)
    best_wt=deepcopy(model.state_dict())
    best_val_loss=0.0
   

    ## iterating over epochs 

    for epoch_num in range(num_epochs):

       
        print('Epich_Num:{}/{}'.format(epoch_num,num_epochs))

        epoch_train_loss=0.0
        epoch_val_loss=0.0

        for phase in ['train','val']:

            train_running_loss=0.0
            val_running_loss=0.0

            if phase=='train':
                model.train()
            
            else:
                model.eval()

            for images,_ in dataloader[phase]:

                images=images.to(device=device)

                with torch.set_grad_enabled(phase=='train'):
                    pred_images=model(images)
                    loss=criterion(pred_images,images)

                    if phase=='train': ##optimize
                        loss.backward()
                        optimization.step()


                if phase=='train':
                    train_running_loss+=loss.item()*images.shape[0]

                else:
                    train_running_loss+=loss.item()*images.shape[0]


            
            if phase=='train':
                epoch_train_loss=train_running_loss/len(dataset[phase])
                
                print('Phase:{}  |Epoch Loss: {:.2f}'.format(phase,epoch_train_loss))

                scheduler.step()
            
            else:
                epoch_val_loss=val_running_loss/len(dataset[phase])
                
                print('Phase:{}  |Epoch Loss: {:.2f}'.format(phase,epoch_val_loss))

        
        if epoch_val_loss<epoch_train_loss and epoch_val_loss>best_val_loss:
            best_wt=deepcopy(model.state_dict())
            best_val_loss=epoch_val_loss

        
    ##load the best_wt
    model.load_state_dict(best_wt)

    print('Best Val Loss: {:.2f}'.format(best_val_loss))

    return model



In [37]:
## defining loss and optimization 
optimization=SGD(model.parameters(),lr=0.01)
criterion=MSELoss()
step_lr=lr_scheduler.StepLR(optimizer=optimization,
                            step_size=7,
                            gamma=0.001)

In [38]:
model=train_model(model=model,
                  optimization=optimization,
                  scheduler=step_lr,
                  criterion=criterion)

-------------------- Stating Training --------------------
Epich_Num:0/10
Phase:train  |Epoch Loss: 0.93
Phase:val  |Epoch Loss: 0.00
Epich_Num:1/10
Phase:train  |Epoch Loss: 0.93
Phase:val  |Epoch Loss: 0.00
Epich_Num:2/10
Phase:train  |Epoch Loss: 0.93
Phase:val  |Epoch Loss: 0.00
Epich_Num:3/10


KeyboardInterrupt: 

In [20]:
images

tensor([[[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        ...,


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., 

In [22]:
images.to('cuda')

tensor([[[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        ...,


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., 