In [1]:
import time
import torch
from torch import nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as dset
from IPython.display import display, clear_output
from torch.autograd import Variable

In [2]:
def status(batch_size, ep, epoch, i, loss, data_loader):
    """Clear the cell's output and print a three-line training progress report.

    Line 1: ``ep/epoch``.
    Line 2: ``batch: k/N [====>    ]`` where k = i + 1, N = len(data_loader),
            drawn as a 20-slot text bar.
    Line 3: the loss formatted with 4 significant digits.

    ``batch_size`` is accepted for call-site compatibility but is not used.
    """
    clear_output(wait=True)
    print('/'.join((str(ep), str(epoch))))

    n_batches = len(data_loader)
    done = i + 1
    # Number of '=' characters: the fraction done/n_batches scaled onto 20 slots.
    filled = int(done / (n_batches / 20))
    bar = '=' * filled + '>' + ' ' * (20 - filled)
    print('batch: ' + str(done) + '/' + str(n_batches) + ' [' + bar + ']')

    print('Loss: %.4g ' % loss)
    
#-------------------------------------------------------------------
# this function has been modified in order to accept only BW images
def showAllImages(x, y, z, index=1, cmap='gray'):
    """Show one sample from each of three image batches side by side.

    Intended for single-channel (black-and-white) images: the selected sample
    is squeezed, so e.g. a (1, H, W) slice becomes (H, W) before plotting.

    Args:
        x, y, z: 4-D tensors; sample ``index`` is taken along the first axis
            (presumably the batch axis — TODO confirm (N, C, H, W) layout).
        index: which sample of each batch to display. Defaults to 1 (the second
            sample), as before, so each batch must hold at least ``index + 1``
            samples.
        cmap: matplotlib colormap applied to all three panels.
    """
    fig, axes = plt.subplots(1, 3, figsize=(12, 8))
    for ax, batch in zip(axes, (x, y, z)):
        # detach() drops autograd history and cpu() moves the data to host
        # memory, both needed before matplotlib converts it to a NumPy array.
        img = batch[index, :, :, :].detach().cpu().squeeze()
        # Use the same colormap on every panel so the three BW images are
        # rendered consistently (previously only the first panel was gray).
        ax.imshow(img, cmap=cmap)
    plt.show()

#-------------------------------------------------------------------
def conv(dimIn, dimOut):
    """Double 3x3 convolution block that keeps the spatial size unchanged.

    Layout: Conv(dimIn -> dimOut) -> BatchNorm -> LeakyReLU(0.2)
            -> Conv(dimOut -> dimOut) -> BatchNorm.

    Args:
        dimIn: number of input channels.
        dimOut: number of output channels.

    Returns:
        An ``nn.Sequential`` whose children are indexed 0..4. The indices appear
        in state_dict keys, so their order must stay fixed for checkpoints to load.
    """
    # NOTE(review): no activation follows the second BatchNorm, unlike the
    # usual U-Net double-conv; confirm this is intended.
    layers = [
        nn.Conv2d(dimIn, dimOut, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(dimOut),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(dimOut, dimOut, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(dimOut),
    ]
    return nn.Sequential(*layers)

#-------------------------------------------------------------------
def pool():
    """Return a non-overlapping 2x2 max-pooling layer that halves H and W (floor)."""
    size = 2  # window == stride, so pooling windows tile the input without overlap
    return nn.MaxPool2d(kernel_size=size, stride=size, padding=0)

#-------------------------------------------------------------------
def invConv(dimIn, dimOut):
    """Upsampling block: stride-2 transposed conv -> BatchNorm -> LeakyReLU(0.2).

    With kernel 3, stride 2, padding 1 and output_padding 1 the output size is
    (H - 1) * 2 - 2 * 1 + 3 + 1 = 2 * H, i.e. exactly double the input height/width.

    Args:
        dimIn: number of input channels.
        dimOut: number of output channels.
    """
    upsample = nn.ConvTranspose2d(
        dimIn, dimOut,
        kernel_size=3, stride=2, padding=1, output_padding=1,
    )
    return nn.Sequential(upsample, nn.BatchNorm2d(dimOut), nn.LeakyReLU(0.2, inplace=True))
    
#-------------------------------------------------------------------
def last(dimIn, dimOut):
    """Output head: size-preserving 3x3 conv followed by Tanh (values in [-1, 1]).

    Args:
        dimIn: number of input channels.
        dimOut: number of output channels.
    """
    head = nn.Conv2d(dimIn, dimOut, kernel_size=3, stride=1, padding=1)
    return nn.Sequential(head, nn.Tanh())

In [3]:
class UNetGen(nn.Module):
    """Four-level U-Net generator: encoder, bridge, and decoder with skip links.

    Args:
        filtersNum: channels produced by the first encoder stage; every deeper
            stage doubles it, up to ``16 * filtersNum`` in the bridge.
        inChannels: channels of the input image. Defaults to 1, the value that
            was previously hard-coded.
        outChannels: channels of the output image. Defaults to 1, the value
            that was previously hard-coded.

    The input is 4-D with ``inChannels`` on axis 1. Four 2x max-poolings are
    undone by four 2x transposed convolutions, so H and W must be divisible by
    16. Otherwise the skip concatenations fail on mismatched sizes. The output
    has ``outChannels`` channels, the same H and W, and values in [-1, 1]
    (Tanh head).

    Attribute names (conv1..conv4, bridge, inv1..inv4, up1..up4, last) are the
    state_dict keys, so they must not be renamed.
    """

    def __init__(self, filtersNum, inChannels=1, outChannels=1):
        super().__init__()
        self.fil = filtersNum

        print("\n------Initializing UNetGen------\n")

        # Encoder: channels double at each stage, each pool halves H and W.
        self.conv1 = conv(inChannels, self.fil)
        self.pool1 = pool()
        self.conv2 = conv(self.fil, self.fil*2)
        self.pool2 = pool()
        self.conv3 = conv(self.fil*2, self.fil*4)
        self.pool3 = pool()
        self.conv4 = conv(self.fil*4, self.fil*8)
        self.pool4 = pool()

        self.bridge = conv(self.fil*8, self.fil*16)

        # Decoder: each invConv doubles H and W. Its output is concatenated
        # with the matching encoder features, so each up* conv sees twice the
        # channels it outputs.
        self.inv1 = invConv(self.fil*16, self.fil*8)
        self.up1 = conv(self.fil*16, self.fil*8)
        self.inv2 = invConv(self.fil*8, self.fil*4)
        self.up2 = conv(self.fil*8, self.fil*4)
        self.inv3 = invConv(self.fil*4, self.fil*2)
        self.up3 = conv(self.fil*4, self.fil*2)
        self.inv4 = invConv(self.fil*2, self.fil)
        self.up4 = conv(self.fil*2, self.fil)

        self.last = last(self.fil, outChannels)

    def forward(self, img):
        """Run the U-Net on ``img`` and return the Tanh-activated output image."""
        # Encoder path; conv1..conv4 are kept for the skip connections.
        conv1 = self.conv1(img)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)

        bridge = self.bridge(pool4)

        # Decoder path: upsample, concatenate the skip along channels, refine.
        inv1 = self.inv1(bridge)
        join1 = torch.cat([inv1, conv4], dim=1)
        up1 = self.up1(join1)
        inv2 = self.inv2(up1)
        join2 = torch.cat([inv2, conv3], dim=1)
        up2 = self.up2(join2)
        inv3 = self.inv3(up2)
        join3 = torch.cat([inv3, conv2], dim=1)
        up3 = self.up3(join3)
        inv4 = self.inv4(up3)
        join4 = torch.cat([inv4, conv1], dim=1)
        up4 = self.up4(join4)

        res = self.last(up4)
        return res


In [4]:
# Restore a training checkpoint: model, optimizer state, learning rate, epoch and loss.
CHECKPOINT_PATH = './testCNNtraining.pth'

# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE: torch.load unpickles arbitrary Python objects, so only load trusted files.
# checkpoint['model'] is a pickled nn.Module, so its class must be defined in this
# kernel before loading. The recorded AttributeError shows this file needs a class
# named `TestGen`, which this notebook does not define.
# On PyTorch >= 2.6, torch.load defaults to weights_only=True, which rejects pickled
# modules; pass weights_only=False there (trusted files only).
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model = checkpoint['model']
model.load_state_dict(checkpoint['model_state_dict'])
lr = checkpoint['learningRate']
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

AttributeError: Can't get attribute 'TestGen' on <module '__main__'>

In [15]:
# Experiment: np.tile with more reps than the array has dimensions.
# The 1-D input is first promoted to shape (1, 1, 1, 3), so the result is (2, 2, 3, 12).
a = np.array([3, 6, 2])
b = np.tile(a, reps=(2, 2, 3, 4))

In [16]:
# Show the tiled array, then its shape on the following line.
print(b, b.shape, sep=' \n ')

[[[[3 6 2 3 6 2 3 6 2 3 6 2]
   [3 6 2 3 6 2 3 6 2 3 6 2]
   [3 6 2 3 6 2 3 6 2 3 6 2]]

  [[3 6 2 3 6 2 3 6 2 3 6 2]
   [3 6 2 3 6 2 3 6 2 3 6 2]
   [3 6 2 3 6 2 3 6 2 3 6 2]]]


 [[[3 6 2 3 6 2 3 6 2 3 6 2]
   [3 6 2 3 6 2 3 6 2 3 6 2]
   [3 6 2 3 6 2 3 6 2 3 6 2]]

  [[3 6 2 3 6 2 3 6 2 3 6 2]
   [3 6 2 3 6 2 3 6 2 3 6 2]
   [3 6 2 3 6 2 3 6 2 3 6 2]]]] 
 (2, 2, 3, 12)


In [19]:
# Last 3x12 slab: index 1 on both leading axes.
print(b[1, 1])

[[3 6 2 3 6 2 3 6 2 3 6 2]
 [3 6 2 3 6 2 3 6 2 3 6 2]
 [3 6 2 3 6 2 3 6 2 3 6 2]]


In [None]:
# Debug peek: take ONE batch, add a leading "sequence" axis by repeating it,
# and compare one sample before and after tiling.
# A single next(iter(...)) is used instead of `for ep in range(epoch)`: a `break`
# inside nested loops only exits the inner one, so the peek would repeat per epoch.
# NOTE(review): `trainloader` is not defined anywhere in this notebook; it must be
# created in an earlier cell for this one to run on a fresh kernel.
SEQ_LEN = 6  # number of copies along the new leading sequence axis

image, label = next(iter(trainloader))
print(image.shape)

# image is indexed as 4-D below, presumably (B, C, H, W) — TODO confirm.
# np.tile with 5 reps prepends an axis: (B, C, H, W) -> (SEQ_LEN, B, C, H, W).
# Note the result is a NumPy array, not a torch tensor.
image5 = np.tile(image, (SEQ_LEN, 1, 1, 1, 1))
print(image5.shape)
# Every sequence step must be an exact copy of the original batch.
assert np.array_equal(image5[1], image)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(image[1, :, :, :].squeeze(), cmap='gray')
axes[0].set_title('batch sample 1')
axes[1].imshow(image5[1, 1, :, :, :].squeeze(), cmap='gray')
axes[1].set_title('sequence step 1, sample 1')
plt.show()