In [1]:
pip install torchsummary

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install torchvision 

Note: you may need to restart the kernel to use updated packages.


In [3]:
import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
#from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
from torchvision import transforms


In [None]:
enc_block = Block(1, 64)
x         = torch.randn(1, 1, 28, 28)
enc_block(x).shape


In [None]:
#https://github.com/g2archie/UNet-MRI-Reconstruction
#https://amaarora.github.io/2020/09/13/unet.html#understanding-input-and-output-shapes-in-u-net

In [None]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
    
    def forward(self, x):
        return self.relu(self.conv2(self.relu(self.conv1(x))))

In [None]:
enc_block = Block(1, 64)
x         = torch.randn(1, 1, 28, 28)
enc_block(x).shape


In [None]:
class Encoder(nn.Module):
    def __init__(self, chs=(1,32,64,128,256)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
    
    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs

In [None]:
chs=(1,32,64,128,256)
nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])

In [None]:
encoder = Encoder()
# input image
x    = torch.randn(1, 1, 28, 28)
ftrs = encoder(x)
for ftr in ftrs: print(ftr.shape)

In [None]:
class Decoder(nn.Module):
    def __init__(self, chs=(1024,512, 256, 128, 64)):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            x        = self.dec_blocks[i](x)
        return x


In [None]:
class Decoder(nn.Module):
    def __init__(self, chs=(256, 128, 64, 32, 1)):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x, encoder_features=ftrs[::-1][1:]):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x
    
    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs   = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs

In [None]:
import torchsummary
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = Decoder().to(device)
x=torch.randn(1, 256, 3, 3)
summary(model,(1, 256, 3, 3))

In [None]:
decoder = Decoder()
x = torch.randn(1, 256, 3, 3)
decoder(x,ftrs[::-1][1:]).shape

In [None]:
chs=(256, 128, 64,32,1)
nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])

In [None]:
class UNet(nn.Module):
  def __init__(self, enc_chs=(1,32,64,128,256), dec_chs=(256, 128, 64, 32), num_class=1, retain_dim=False, out_sz=(572,572)):
      super().__init__()
      self.encoder     = Encoder(enc_chs)
      self.decoder     = Decoder(dec_chs)
      self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
      self.retain_dim  = retain_dim

  def forward(self, x):
      enc_ftrs = self.encoder(x)
      out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
      out      = self.head(out)
      if self.retain_dim:
          out = F.interpolate(out, out_sz)
      return out

In [None]:
import torchsummary
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = UNet().to(device)

summary(model, (1, 1, 256, 256))

In [None]:
unet = UNet()
x    = torch.randn(1, 1, 256, 256)
unet(x).shape

In [None]:
class Encoder(nn.Module):

	def __init__(self, encoded_space_dim):
		super().__init__()

		### Convolutional section
		self.encoder_cnn = nn.Sequential(
		nn.Conv2d(1, 8, 3, stride=2, padding=1),
		nn.ReLU(True),
		nn.Conv2d(8, 16, 3, stride=2, padding=1),
		nn.BatchNorm2d(16),
		nn.ReLU(True),
		nn.Conv2d(16, 32, 3, stride=2, padding=0),
		nn.ReLU(True)
		)

		### Flatten layer
		self.flatten = nn.Flatten(start_dim=1)
### Linear section
		self.encoder_lin = nn.Sequential(
			nn.Linear(3 * 3 * 32, 128),
			nn.ReLU(True),
			nn.Linear(128, encoded_space_dim)
		)

	def forward(self, x):
		x = self.encoder_cnn(x)
		x = self.flatten(x)
		x = self.encoder_lin(x)
		return x
class Decoder(nn.Module):

	def __init__(self, encoded_space_dim):
		super().__init__()
		self.decoder_lin = nn.Sequential(
			nn.Linear(encoded_space_dim, 128),
			nn.ReLU(True),
			nn.Linear(128, 3 * 3 * 32),
			nn.ReLU(True)
		)

		self.unflatten = nn.Unflatten(dim=1,
		unflattened_size=(32, 3, 3))

		self.decoder_conv = nn.Sequential(
			nn.ConvTranspose2d(32, 16, 3,
			stride=2, output_padding=0),
			nn.BatchNorm2d(16),
			nn.ReLU(True),
			nn.ConvTranspose2d(16, 8, 3, stride=2,
			padding=1, output_padding=1),
			nn.BatchNorm2d(8),
			nn.ReLU(True),
			nn.ConvTranspose2d(8, 1, 3, stride=2,
			padding=1, output_padding=1)
		)

	def forward(self, x):
		x = self.decoder_lin(x)
		x = self.unflatten(x)
		x = self.decoder_conv(x)
		x = torch.sigmoid(x)
		return x

In [None]:
decoder = Encoder(encoded_space_dim=64)
x = torch.randn(1, 1, 28, 28)
Encoder(encoded_space_dim=64)(x).shape

In [None]:
num_epochs = 3
batch_size = 500

In [None]:

from torch.utils.data import Dataset, DataLoader

In [None]:
def autoencoder_loss(x, x_hat):
    return F.binary_cross_entropy(x_hat, x) 

In [None]:
#https://gist.github.com/Mahedi-61/e70f08e1f36aa9a4fa575d2a5a3f6c25


In [None]:
#n_epochs = 3
#batch_size_train = 64
#batch_size_test = 1000

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./', train=True, download=False,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./', train=False, download=False,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

example_targets.shape

import matplotlib.pyplot as plt

fig = plt.figure()

for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')


def shortcut(ims):

    
    f = self.encoder(ims, cond)

In [12]:
#https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html

class Encoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):

        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), 
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), 
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2),
            act_fn(),
            nn.Flatten(), 
            nn.Linear(2*16*c_hid, latent_dim)
        )

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):

        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2*16*c_hid),
            act_fn()
        )
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), 
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=0),
            act_fn(),
            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.Tanh() 
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x
encoder = Encoder(num_input_channels=1, base_channel_size=32, latent_dim=256)
# input image
x    = torch.randn(10000,1, 28, 28)
encoder(x).shape
decoder = Decoder(num_input_channels=1, base_channel_size=32, latent_dim=256)
# input image
#x    = torch.randn(1000,256)
#decoder(x).shape
encoder(x).shape


torch.Size([10000, 256])

In [9]:
def train(epoch, train_loader, optimizer, encoder, decoder):
    log_interval=50
    train_losses = []
    train_counter = []
    loss_f= torch.nn.MSELoss()
    encoder.train()
    decoder.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        encoded_data = encoder(data)
        # Decode data
        decoded_data = decoder(encoded_data)
        loss = loss_f(decoded_data, data)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
            (batch_idx*1000) + ((epoch-1)*len(train_loader.dataset)))
def test(test_loader, encoder, decoder):
    loss_f= torch.nn.MSELoss()
    test_losses = []
    encoder.eval()
    decoder.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            encoded_data = encoder(data)
            # Decode data
            output= decoder(encoded_data)
            test_loss += loss_f(output,data).item()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f} \n'.format(
        test_loss))
    #output image plotting
    
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100],10,5))
    plt.show()

def short_cut(n_epochs, batch_size_train,batch_size_test):
    train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./', train=True, download=False,
                                transform=torchvision.transforms.Compose([
                                torchvision.transforms.ToTensor(),
                                torchvision.transforms.Normalize(
                                    (0.1307,), (0.3081,))
                                ])),batch_size=batch_size_train, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./', train=False, download=False,
                                transform=torchvision.transforms.Compose([
                                torchvision.transforms.ToTensor(),
                                torchvision.transforms.Normalize(
                                    (0.1307,), (0.3081,))
                                ])),batch_size=batch_size_test, shuffle=True)

    encoder=Encoder(num_input_channels=1, base_channel_size=32, latent_dim=256)
    decoder=Decoder(num_input_channels=1, base_channel_size=32, latent_dim=256)
    mean = (0.1307, )
    std = (0.3081, ) 
    learning_rate = 0.01

    params_to_optimize = [
        {'params': encoder.parameters()},
        {'params': decoder.parameters()}
    ]
    optimizer = torch.optim.Adam(params_to_optimize,lr=learning_rate)

    for epoch in range(1, n_epochs + 1):
        train(epoch=epoch, train_loader=train_loader, optimizer=optimizer, encoder=encoder,decoder=decoder)
        test(test_loader=test_loader, encoder=encoder,decoder=decoder)

In [10]:
short_cut(3,100,100)


Test set: Avg. loss: 0.0202 


Test set: Avg. loss: 0.0202 


Test set: Avg. loss: 0.0202 



In [None]:
test()