In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import cv2
from tqdm import tqdm 
import numpy as np
import os
from torchsummary import summary
import PIL
import matplotlib.pyplot as plt

In [2]:
class Arpit_net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,6,3)
        self.conv2 = nn.Conv2d(6,10,3)
        self.conv3 = nn.Conv2d(10,20,3)
        self.conv4 = nn.Conv2d(20,9,3)
        self.max1 = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.max2 = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.conv5 = nn.Conv2d(9,9,3,padding=(1,1))
        self.max3 = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.unpool1 = nn.MaxUnpool2d(2, stride=2)
        self.tconv1 = nn.Conv2d(9,9,3,padding=(1,1))
        self.unpool2 = nn.MaxUnpool2d(2, stride=2)
        self.unpool3 = nn.MaxUnpool2d(2, stride=2)
        self.tconv2 = nn.Conv2d(9,20,3, padding=(2,2))
        self.tconv3 = nn.Conv2d(20,10,3, padding=(2,2))
        self.tconv4 = nn.Conv2d(10,6,3, padding=(2,2))
        self.tconv5 = nn.Conv2d(6,1,3, padding=(2,2))
        
        

        
        
    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x1, self.idc = self.max1(x)
        x2, self.indices = self.max2(x1)
        x3= F.relu(self.conv5(x2))
        x , self.id = self.max3(x3)
        print(self.indices.shape)
        print(x.shape)
        x = self.unpool1(x, self.id ,output_size=x3.shape)
        x = F.relu(self.tconv1(x))
        x = self.unpool2(x, self.indices)
        x = self.unpool3(x, self.idc)
        x = F.relu(self.tconv1(x))
        x = F.relu(self.tconv2(x))
        x = F.relu(self.tconv3(x))
        x = F.relu(self.tconv4(x))
        x = F.sigmoid(self.tconv5(x))
        
        
        
        return x
        
        

In [3]:
r = Arpit_net()

In [4]:
summary(r, (1,100,100))

torch.Size([2, 9, 23, 23])
torch.Size([2, 9, 11, 11])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 98, 98]              60
            Conv2d-2           [-1, 10, 96, 96]             550
            Conv2d-3           [-1, 20, 94, 94]           1,820
            Conv2d-4            [-1, 9, 92, 92]           1,629
         MaxPool2d-5  [[-1, 9, 46, 46], [-1, 9, 46, 46]]               0
         MaxPool2d-6  [[-1, 9, 23, 23], [-1, 9, 23, 23]]               0
            Conv2d-7            [-1, 9, 23, 23]             738
         MaxPool2d-8  [[-1, 9, 11, 11], [-1, 9, 11, 11]]               0
       MaxUnpool2d-9            [-1, 9, 23, 23]               0
           Conv2d-10            [-1, 9, 23, 23]             738
      MaxUnpool2d-11            [-1, 9, 46, 46]               0
      MaxUnpool2d-12            [-1, 9, 92, 92]               0
           Conv2d-13  



In [5]:
train_transforms = transforms.Compose([transforms.Resize((100,100)),
                                       transforms.Grayscale(num_output_channels=1),
                                       transforms.ToTensor(),
                                       ])
train_data = datasets.ImageFolder(r"C:\Users\srava\OneDrive\Desktop\Arpit_study_material\Pytorch\PetImages",       
                    transform=train_transforms)

In [6]:
valid_size = .2
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
np.random.shuffle(indices)
from torch.utils.data.sampler import SubsetRandomSampler
train_idx = indices[split:]
train_sampler = SubsetRandomSampler(train_idx)

In [7]:
trainloader = torch.utils.data.DataLoader(train_data,
                   sampler=train_sampler, batch_size=100)

In [8]:
import cv2
for i, j in trainloader:
    print(i[0].shape)
    #print(i[0])
    break

torch.Size([1, 100, 100])


In [9]:
optimizer = torch.optim.Adam(r.parameters(), lr = 0.01, weight_decay=1e-5)

In [11]:
def train_model(train_data, EPOCHS, optimizer, model):
    for epoch in range(EPOCHS):
        epoch_loss = 0
        for X, y in tqdm(trainloader):
            model.zero_grad()
            output = model(X)
            loss = F.mse_loss(output, X)
            loss.backward()
            optimizer.step()
            epoch_loss += loss
            torch.no_grad()
            for img in output:
                
                p = transforms.ToPILImage()(img)
                p.show()
               
                break
            
            print("Loss:",loss)
            
        print("Epoch {} ".format(epoch))
            
            

In [None]:
import time
start = time.time()
train_model(trainloader,2,optimizer, r)
print(time.time()-start)

  0%|                                                                                          | 0/100 [00:00<?, ?it/s]

torch.Size([100, 9, 23, 23])
torch.Size([100, 9, 11, 11])


  1%|▊                                                                                 | 1/100 [00:08<14:18,  8.67s/it]

Loss: tensor(0.0624, grad_fn=<MseLossBackward>)
torch.Size([100, 9, 23, 23])
torch.Size([100, 9, 11, 11])


  2%|█▋                                                                                | 2/100 [00:17<14:02,  8.60s/it]

Loss: tensor(0.0584, grad_fn=<MseLossBackward>)
torch.Size([100, 9, 23, 23])
torch.Size([100, 9, 11, 11])


  3%|██▍                                                                               | 3/100 [00:27<14:37,  9.05s/it]

Loss: tensor(0.0556, grad_fn=<MseLossBackward>)


  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


torch.Size([100, 9, 23, 23])
torch.Size([100, 9, 11, 11])


  4%|███▎                                                                              | 4/100 [00:38<15:27,  9.66s/it]

Loss: tensor(0.0571, grad_fn=<MseLossBackward>)
