In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
import torchvision.utils as vutils
import models
from visualize import *
from data import *
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
from trim import *
from util import *
from attributions import *
from captum.attr import *
from functools import partial
import warnings
warnings.filterwarnings("ignore")
data_path = './cosmo'

# load dataset and model

In [2]:
# params
img_size = 256
class_num = 1

# cosmo dataset
transformer = transforms.Compose([ToTensor()])
mnu_dataset = MassMapsDataset(opj(data_path, 'cosmological_parameters.txt'),
                              opj(data_path, 'z1_256'),
                              transform=transformer)
# Contiguous, non-overlapping split: samples [0, 40000) train, [40000, 50000) test.
# The test range used to start at 40001, which silently dropped sample 40000 from both subsets.
# assumes the dataset holds at least 50000 samples -- TODO confirm
train_set = torch.utils.data.Subset(mnu_dataset, list(range(0, 40000)))
test_set = torch.utils.data.Subset(mnu_dataset, list(range(40000, 50000)))

# dataloader
# data_loader = torch.utils.data.DataLoader(mnu_dataset, batch_size=64, shuffle=True, num_workers=4)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False, num_workers=2)

# load model
model = models.load_model(model_name='resnet18', device=device, inplace=False, data_path=data_path).to(device)
model = model.eval()
# freeze layers: the loaded network is only used for inference in this notebook,
# so none of its parameters should receive gradients
for param in model.parameters():
    param.requires_grad = False

In [3]:
class conv2DBatchNormRelu(nn.Module):
    """Conv2d -> (optional BatchNorm2d) -> in-place ReLU, wrapped in one module.

    The layers live, in that order, inside the ``block_unit`` Sequential, so the
    convolution is always ``block_unit[0]``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size, stride, padding, dilation, bias: forwarded to ``nn.Conv2d``.
        with_bn: when True a ``nn.BatchNorm2d(out_channels)`` is inserted between
            the convolution and the ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, bias=True, with_bn=True):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias)
        ]
        # Normalisation is optional; the activation is always the last layer.
        if with_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))

        self.block_unit = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the activated result."""
        return self.block_unit(x)

In [4]:
class ConvDown2(nn.Module):
    """Encoder stage: two 3x3 conv units followed by a 2x2, stride-2 max pool.

    ``forward`` returns a triple ``(pooled, indices, unpooled_shape)``: the pooled
    tensor, the arg-max indices reported by the pooling layer, and the size of
    the tensor right before pooling.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_args = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **conv_args)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **conv_args)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        pooled, indices = self.maxpool_with_argmax(features)
        # The pre-pooling size is handed back alongside the indices.
        return pooled, indices, features.size()


class ConvDown3(nn.Module):
    """Encoder stage: three 3x3 conv units followed by a 2x2, stride-2 max pool.

    ``forward`` returns a triple ``(pooled, indices, unpooled_shape)``: the pooled
    tensor, the arg-max indices reported by the pooling layer, and the size of
    the tensor right before pooling.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_args = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **conv_args)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **conv_args)
        self.conv3 = conv2DBatchNormRelu(out_channels, out_channels, **conv_args)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, x):
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = conv(features)
        pooled, indices = self.maxpool_with_argmax(features)
        # The pre-pooling size is handed back alongside the indices.
        return pooled, indices, features.size()


class ConvUp2(nn.Module):
    """Decoder stage: 2x2 max-unpooling followed by two 3x3 conv units.

    The first conv unit keeps ``in_channels``; the second maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_args = dict(kernel_size=3, stride=1, padding=1)
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_channels, in_channels, **conv_args)
        self.conv2 = conv2DBatchNormRelu(in_channels, out_channels, **conv_args)

    def forward(self, x, indices, output_shape):
        """Un-pool ``x`` with ``indices`` to ``output_shape``, then apply both conv units."""
        restored = self.unpool(input=x, indices=indices, output_size=output_shape)
        return self.conv2(self.conv1(restored))
    
    
class ConvUp3(nn.Module):
    """Decoder stage: 2x2 max-unpooling followed by three 3x3 conv units.

    The first two conv units keep ``in_channels``; the third maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_args = dict(kernel_size=3, stride=1, padding=1)
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_channels, in_channels, **conv_args)
        self.conv2 = conv2DBatchNormRelu(in_channels, in_channels, **conv_args)
        self.conv3 = conv2DBatchNormRelu(in_channels, out_channels, **conv_args)

    def forward(self, x, indices, output_shape):
        """Un-pool ``x`` with ``indices`` to ``output_shape``, then apply the three conv units."""
        restored = self.unpool(input=x, indices=indices, output_size=output_shape)
        for conv in (self.conv1, self.conv2, self.conv3):
            restored = conv(restored)
        return restored

In [5]:
class Net(nn.Module):
    """Convolutional encoder-decoder that un-pools with the encoder's max-pool indices.

    Three encoder stages (64, 128 and 256 channels) each end in a 2x2 max pool;
    three mirrored decoder stages un-pool with the stored indices and shapes and
    map back to ``out_channels``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the reconstructed output.

    NOTE(review): the last decoder unit ends in BatchNorm + ReLU, so the output
    is always non-negative. Confirm this matches the value range of the inputs
    being reconstructed.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(Net, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.down1 = ConvDown2(self.in_channels, 64)
        self.down2 = ConvDown2(64, 128)
        self.down3 = ConvDown3(128, 256)
        # Deeper stages are currently disabled:
        #   self.down4 = ConvDown3(256, 512)
        #   self.down5 = ConvDown3(512, 512)
        #   self.up5 = ConvUp3(512, 512)
        #   self.up4 = ConvUp3(512, 256)
        self.up3 = ConvUp3(256, 128)
        self.up2 = ConvUp2(128, 64)
        self.up1 = ConvUp2(64, out_channels)

    def forward(self, x):
        """Encode ``x`` through the three down stages and decode it back."""
        x, indices_1, unpool_shape1 = self.down1(x)
        x, indices_2, unpool_shape2 = self.down2(x)
        x, indices_3, unpool_shape3 = self.down3(x)

        # Decoder stages consume the indices/shapes in reverse order.
        x = self.up3(x, indices_3, unpool_shape3)
        x = self.up2(x, indices_2, unpool_shape2)
        x = self.up1(x, indices_1, unpool_shape1)

        return x

    def init_vgg16_params(self, vgg16):
        """Copy conv weights and biases from ``vgg16`` into the encoder stages.

        The encoder's convolutions are matched, in order, with the leading
        ``nn.Conv2d`` layers found in ``vgg16.features``. Only the encoder stages
        that exist on this instance are used, so the method also works while
        ``down4`` / ``down5`` are disabled. Any surplus convolutions in ``vgg16``
        are ignored.

        Args:
            vgg16: object whose ``features`` attribute yields layers through
                ``children()`` (presumably a torchvision VGG16 -- verify against
                the caller).

        Raises:
            AssertionError: if ``vgg16`` has fewer conv layers than the encoder,
                or if a paired layer has a different weight/bias shape. A network
                whose first conv expects 3 input channels will not match the
                default ``in_channels=1``.
        """
        blocks = [
            getattr(self, name)
            for name in ('down1', 'down2', 'down3', 'down4', 'down5')
            if hasattr(self, name)
        ]

        vgg_layers = [
            layer for layer in vgg16.features.children()
            if isinstance(layer, nn.Conv2d)
        ]

        merged_layers = []
        for conv_block in blocks:
            # ConvDown2 stages expose conv1/conv2; ConvDown3 stages add conv3.
            for unit_name in ('conv1', 'conv2', 'conv3'):
                unit = getattr(conv_block, unit_name, None)
                if unit is None:
                    continue
                merged_layers.extend(
                    layer for layer in unit.block_unit
                    if isinstance(layer, nn.Conv2d)
                )

        assert len(vgg_layers) >= len(merged_layers), \
            'vgg16 has fewer conv layers than the encoder'

        # zip stops after the encoder's last conv, so extra VGG layers are skipped.
        for l1, l2 in zip(vgg_layers, merged_layers):
            assert l1.weight.size() == l2.weight.size()
            assert l1.bias.size() == l2.bias.size()
            # Copy values instead of re-pointing .data, so the two networks do
            # not end up sharing the same underlying tensors.
            l2.weight.data.copy_(l1.weight.data)
            l2.bias.data.copy_(l1.bias.data)

In [6]:
def adjust_learning_rate(optimizer, shrink_factor):
    """Scale the learning rate of every parameter group by ``shrink_factor``.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts with an ``'lr'`` entry.
        shrink_factor: multiplier applied to each group's current learning rate.
    """
    print("\nDECAYING learning rate.")
    groups = optimizer.param_groups
    for idx in range(len(groups)):
        groups[idx]['lr'] = groups[idx]['lr'] * shrink_factor
    # Only the first group's rate is reported.
    print("The new learning rate is %f\n" % (groups[0]['lr'],))
    
    
def ensure_folder(folder):
    """Create ``folder`` (and any missing parents) unless it already exists.

    Uses ``exist_ok=True`` rather than a separate existence check, so a folder
    created by someone else between the check and the creation no longer raises.

    Raises:
        FileExistsError: if ``folder`` exists but is not a directory.
    """
    os.makedirs(folder, exist_ok=True)
    
def save_checkpoint(epoch, model, optimizer, val_loss, is_best, save_folder='./models/autoencoder'):
    """Write a training checkpoint, plus a ``BEST_checkpoint.tar`` copy when ``is_best``.

    Args:
        epoch: epoch number, used in the file name.
        model: model object stored under the ``'model'`` key.
        optimizer: optimizer object stored under the ``'optimizer'`` key.
        val_loss: validation loss, formatted with three decimals into the file name.
        is_best: when truthy, the same state is also written to ``BEST_checkpoint.tar``.
        save_folder: target directory; created if missing. The parameter used to
            be misspelled ``save_foler`` while the body read ``save_folder``, so
            every call failed with a NameError.
    """
    ensure_folder(save_folder)
    # The full objects (not just state_dicts) are serialised here.
    # NOTE(review): loading such a file presumably needs the same class
    # definitions to be importable -- confirm before changing the format.
    state = {'model': model,
             'optimizer': optimizer}
    filename = '{0}/checkpoint_{1}_{2:.3f}.tar'.format(save_folder, epoch, val_loss)
    torch.save(state, filename)
    # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
    if is_best:
        torch.save(state, '{}/BEST_checkpoint.tar'.format(save_folder))

In [7]:
def train(epoch, train_loader, model, optimizer):
    """Run one training epoch of the reconstruction (autoencoder) objective.

    Each batch's ``'image'`` tensor is both the model input and the regression
    target of an MSE loss.

    Args:
        epoch: epoch number, used only in the progress message.
        train_loader: iterable of dict batches containing an ``'image'`` entry.
        model: network being optimised; switched to train mode here.
        optimizer: optimizer stepping the model's parameters.

    Returns:
        list of float: the loss of every batch, in iteration order.
    """
    # Train mode: layers such as batch norm and dropout use their training behaviour.
    model.train()

    # Loss function
    criterion = nn.MSELoss()

    losses = []

    # Batches
    for i_batch, data in enumerate(train_loader):
        # `device` is the notebook-level device string; .to() is a no-op on CPU.
        inputs = data['image'].to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Model output
        outputs = model(inputs)

        # Prediction first, target second. The value is the same either way, but
        # the original order passed the grad-requiring tensor as the target,
        # which some older PyTorch releases reject.
        loss = criterion(outputs, inputs)
        loss.backward()

        optimizer.step()

        # Save losses for plotting later
        batch_loss = loss.item()
        losses.append(batch_loss)

        # Print status
        if i_batch % 50 == 0:
            print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i_batch * len(inputs), len(train_loader.dataset),
                       100. * i_batch / len(train_loader), batch_loss), end='')

    return losses


def valid(val_loader, model, epoch=None):
    """Evaluate the reconstruction MSE of ``model`` over ``val_loader``.

    Args:
        val_loader: loader yielding dict batches containing an ``'image'`` entry;
            must expose ``.dataset`` for the sample count.
        model: network to evaluate; switched to eval mode here.
        epoch: optional epoch number for the progress message. When omitted, a
            notebook-level ``epoch`` variable is used if one exists (the old
            implicit behaviour), otherwise ``'?'`` is printed instead of
            raising a NameError.

    Returns:
        tuple: ``(epoch_loss, losses)`` where ``epoch_loss`` is the
        sample-weighted mean loss over the dataset and ``losses`` is the list of
        per-batch losses.
    """
    # Eval mode: dropout is disabled and batch norm uses its running statistics.
    model.eval()

    if epoch is None:
        epoch = globals().get('epoch', '?')

    # Loss function
    criterion = nn.MSELoss()

    losses = []
    running_loss = 0

    with torch.no_grad():
        # Batches
        for i_batch, data in enumerate(val_loader):
            # `device` is the notebook-level device string; .to() is a no-op on CPU.
            inputs = data['image'].to(device)

            # Model output
            outputs = model(inputs)

            # Loss (prediction first, target second)
            loss = criterion(outputs, inputs)

            # Save losses for plotting later
            batch_loss = loss.item()
            losses.append(batch_loss)
            # Weight by batch size so a smaller last batch does not skew the mean.
            running_loss += batch_loss * inputs.size(0)

            # Print status
            if i_batch % 50 == 0:
                print('\rVal Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, i_batch * len(inputs), len(val_loader.dataset),
                           100. * i_batch / len(val_loader), batch_loss), end='')

    epoch_loss = running_loss / len(val_loader.dataset)
    return epoch_loss, losses



In [8]:
# # debug
# # model
# t = Net().to(device)

# # Setup Adam optimizers
# optimizer = optim.Adam(t.parameters(), lr=0.001)

# best_loss = 1e10

# epochs_since_improvement = 0

# num_epochs = 5

# data = next(iter(train_loader))
# inputs, params = data['image'], data['params']
# # Set device options
# if device == 'cuda':
#     inputs = inputs.to(device)
#     params = params.to(device)  
# t(inputs)    

In [None]:
# model: the autoencoder being trained (distinct from the frozen `model` loaded earlier)
t = Net().to(device)

# Setup Adam optimizers
optimizer = torch.optim.Adam(t.parameters(), lr=0.001)

best_loss = 1e10

epochs_since_improvement = 0

num_epochs = 5

# Epochs are numbered 1..num_epochs inclusive
# (range(1, num_epochs) stopped one epoch short).
for epoch in range(1, num_epochs + 1):
    # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
    if epochs_since_improvement == 20:
        break
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(optimizer, 0.8)

    # One epoch's training
    train(epoch, train_loader, t, optimizer)

    # One epoch's validation; valid() returns (epoch_loss, per-batch losses),
    # only the scalar epoch loss is needed here.
    val_loss, _ = valid(val_loader, t)
    print('\n * LOSS - {loss:.3f}\n'.format(loss=val_loss))

    # Check if there was an improvement
    is_best = val_loss < best_loss
    best_loss = min(best_loss, val_loss)

    if not is_best:
        epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
    else:
        epochs_since_improvement = 0

    # Save checkpoint of the autoencoder `t`; the previous code passed `model`,
    # i.e. the frozen pretrained network, so the trained weights were never saved.
    save_checkpoint(epoch, t, optimizer, val_loss, is_best)

