In [1]:
# imports
import torch
from torch import nn
from torchvision.datasets import CIFAR100
from torchvision import transforms
#from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

In [2]:
class OneUpOneDownAutoencoder(nn.Module):
    
    def __init__(self, in_channels, 
                out_channels,
                kernel_size = 3,
                stride = 1,
                padding=0):
        super(OneUpOneDownAutoencoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels = self.in_channels, 
                      out_channels = self.out_channels,
                      kernel_size = kernel_size, 
                      stride = stride, 
                      padding = padding), 
            nn.ReLU(True),
            # Maybe add pooling here
        ) # output shape = (N, out_channels, H*, W*)
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels = self.out_channels, 
                out_channels = self.in_channels,
                kernel_size = kernel_size, 
                stride=stride,
                padding = padding,
                output_padding=padding),
            nn.ReLU(True),
        )  # output shape = (N, in_channels, H, W)
        
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [3]:
class SimpleDecoderLayer(nn.Module):
    
    def __init__(self, in_channels, out_channels, 
                 kernel_size = 3,
                 stride=1,
                 padding=0):
        super(SimpleDecoderLayer, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels = in_channels, 
                out_channels = out_channels,
                kernel_size = kernel_size,
                stride = stride,
                padding = padding,
                output_padding = padding),
            nn.ReLU(True),
        )  # output shape = (N, in_channels, H, W)
        
    def forward(self, x):
        return self.decoder(x)

In [4]:
class HierarchicalAutoencoder(nn.Module):
    def __init__(self, encoder_layer_func, 
                 decoder_layer_func,
                 num_layers,
                 input_size,
                 output_sizes = [],
                 stride = 1,
                 padding = 0):
        super(HierarchicalAutoencoder, self).__init__()
        
        if type(output_sizes) is int:
            output_sizes = [output_sizes]
            
        
        self.encoder_layers = nn.ModuleList()
        for i in range(num_layers):
            if i >= len(output_sizes):
                output_size = input_size * 2
            else:
                output_size = output_sizes[i]
                
            instantiated_encoder = encoder_layer_func(
                in_channels = input_size,
                out_channels = output_size,
                stride = stride,
                padding = padding
            )
            self.encoder_layers.append(instantiated_encoder)
            
            input_size = output_size
    
        self.decoder_layers = nn.ModuleList()
        for i in range(num_layers):
            input_size = self.encoder_layers[num_layers - i - 1].out_channels
            output_size = self.encoder_layers[num_layers - i - 1].in_channels
            instantiated_decoder = decoder_layer_func(
                in_channels = input_size,
                out_channels = output_size,
                stride = stride,
                padding =padding
            )
            self.decoder_layers.append(instantiated_decoder)
        
        #self.encoder = nn.Sequential(*self.encoder_layers)
        #self.decoder = nn.Sequential(*self.decoder_layers)
            
    def forward(self, x):
        for layer in self.encoder_layers:
            x, _ = layer(x)
            
        encoded = x
        for layer in self.decoder_layers:
            x = layer(x)
            
        decoded = x
        return encoded, decoded

In [5]:
dummy_input = torch.ones([4,3,224,224])
dummy_input.shape

torch.Size([4, 3, 224, 224])

In [6]:
model = OneUpOneDownAutoencoder(3, 16, padding = 1, stride = 2)

In [7]:
encoded, decoded = model(dummy_input)
dummy_input.shape, encoded.shape, decoded.shape

(torch.Size([4, 3, 224, 224]),
 torch.Size([4, 16, 112, 112]),
 torch.Size([4, 3, 224, 224]))

In [8]:
model = HierarchicalAutoencoder(OneUpOneDownAutoencoder, SimpleDecoderLayer, 3, 3, output_sizes=32, stride = 2, padding = 1)

In [9]:
encoded, decoded = model(dummy_input)
dummy_input.shape, encoded.shape, decoded.shape

(torch.Size([4, 3, 224, 224]),
 torch.Size([4, 128, 28, 28]),
 torch.Size([4, 3, 224, 224]))

In [10]:
model = model.cuda()#summary(model.cuda(), (3,224,224))

In [11]:
num_epochs = 100
batch_size = 64
learning_rate = 1e-3

In [12]:
dataset = CIFAR100(
    '/data/standard_datasets/cifar100',
    download=True,
    train = True,
    transform=transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
)
len(dataset)

Files already downloaded and verified


50000

In [13]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [14]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=1e-5)

In [17]:
# setup tensorboard
writer = SummaryWriter('runs/just_one_loss/cifar100')
writer.add_graph(model, dummy_input.cuda())

In [18]:
# train
running_loss = 0.0
for epoch in range(num_epochs):
    for i,data in enumerate(dataloader):
        img, _ = data
        img = torch.autograd.Variable(img).cuda()
        # ===================forward=====================
        _,output = model(img)
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # tensorboard
        running_loss += loss.item()
        if i % 100 == 0:
            writer.add_scalar('training loss',
                running_loss,
                epoch * len(dataloader) + i)
    # ===================log========================
    
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch+1, num_epochs, loss.data))

epoch [1/100], loss:0.6873


KeyboardInterrupt: 