In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
import torchvision.utils as vutils
import models
from visualize import *
from data import *
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
from trim import *
from util import *
from attributions import *
from captum.attr import *
from functools import partial
import warnings
warnings.filterwarnings("ignore")
data_path = './cosmo'

# load dataset and model

In [2]:
# params
img_size = 256
class_num = 1

# cosmo dataset
transformer = transforms.Compose([ToTensor()])
mnu_dataset = MassMapsDataset(opj(data_path, 'cosmological_parameters.txt'),  
                              opj(data_path, 'z1_256'),
                              transform=transformer)

# dataloader
data_loader = torch.utils.data.DataLoader(mnu_dataset, batch_size=64, shuffle=True, num_workers=4)

# load model
model = models.load_model(model_name='resnet18', device=device, inplace=False, data_path=data_path).to(device)
model = model.eval()
# freeze layers
for param in model.parameters():
    param.requires_grad = False

In [3]:
class Transform(nn.Module):
    def __init__(self):
        super(Transform, self).__init__()
        ## encoder layers ##
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3, bias=False)  
#         self.conv2 = nn.Conv2d(32, 10, kernel_size=4, stride=1, padding=1, bias=False)
        self.pool = nn.MaxPool2d(2, 2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(2, 2)

    def forward(self, x):
        ## encode ##
        # add hidden layers with relu activation function
        # and maxpooling after
        x = F.relu(self.conv1(x))
        x, indices = self.pool(x)
        x = self.unpool(x, indices)
        # add second hidden layer
#         x = F.relu(self.conv2(x))
#         x = self.pool(x)  # compressed representation
                
        return x
    
class Transform_i(nn.Module):
    def __init__(self):
        super(Transform_i, self).__init__()        
        ## decoder layers ##
        ## a kernel of 2 and a stride of 2 will increase the spatial dims by 2
#         self.t_conv1 = nn.ConvTranspose2d(10, 32, kernel_size=2, stride=2, padding=0, output_padding=0, bias=False)
        self.t_conv1 = nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2, padding=0, output_padding=0, bias=False)
#         self.u_conv1 = nn.Conv2d(10, 32, kernel_size=7, padding=0)
#         self.u_conv1 = nn.Conv2d(32, 1, kernel_size=3, padding=1)


    def forward(self, x):
        ## decode ##
        # upsample, followed by a conv layer, with relu activation function  
        # this function is called `upsample` in some PyTorch versions
#         x = F.interpolate(x, scale_factor=4, mode='nearest')
        x = self.t_conv1(x)
        # upsample again, output should have a sigmoid applied
#         x = F.interpolate(x, scale_factor=3, mode='nearest')
#         x = self.u_conv2(x)
                
        return x
    
class Mask(nn.Module):
    def __init__(self, img_size=128):
        super(Mask, self).__init__()
        self.img_size = img_size
        self.initialize()
#         self.mask = nn.Parameter(torch.clamp(abs(torch.randn(img_size, img_size)), 0, 1))
        
    def forward(self, x):
        return torch.mul(self.mask, x)   
    
    def initialize(self):
        self.mask = nn.Parameter(torch.ones(64, 32, self.img_size, self.img_size))

In [4]:
# def transforms
t = Transform().to(device)
transform_i = Transform_i().to(device)
# initialize
# t.conv1.weight.data = model.conv1.weight.data
# transform_i.convt1.weight.data = model.conv1.weight.data

# mask
mask = Mask().to(device)

# transform_i.load_state_dict(torch.load('./models/conv_filters_pen'))
# prepend transformation
model_t = TrimModel(model, transform_i, use_residuals=True)

In [5]:
# criterion
criterion = nn.MSELoss()

# l1-loss
l1loss = nn.L1Loss()

# Setup Adam optimizers
optimizer_t = optim.Adam(t.parameters(), lr=0.0005)
optimizer_i = optim.Adam(transform_i.parameters(), lr=0.0005)
optimizer = optim.Adam(mask.parameters(), lr=0.1)

In [6]:
# Training Loop
# Lists to keep track of progress
losses = []
num_epochs = 2

lamb_l1 = 100.0
lamb = 0.1

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(data_loader, 0):
        inputs, params = data['image'], data['params']
        if device == 'cuda':
            inputs = inputs.to(device)
            params = params.to(device)

        losses_inner = []
        for ii in range(100):
            # update masks    
            im_mask = mask(t(inputs))
            output_ = model(transform_i(im_mask))
            # loss
            loss = -output_[:,1].sum() + lamb_l1 * l1loss(mask.mask, torch.zeros_like(mask.mask))
            # zero grad
            optimizer.zero_grad()
            # backward
            loss.backward()
            # Update G
            optimizer.step()
            # projection
            mask.mask.data = torch.clamp(mask.mask.data, 0, 1)

            # mask training stats
            print('\rTrain Epoch: {}/100 ({}/{})'.format(ii, epoch, num_epochs), end='') 
            losses_inner.append(loss.item())
            
        outputs = transform_i(t(inputs))
        # loss
        loss = criterion(inputs, outputs)
        # interp
        im_mask = mask(t(inputs))
        output_ = model(transform_i(im_mask))
        loss += -lamb * output_[:,1].sum()
        # zero grad
        t.zero_grad()
        transform_i.zero_grad()
        # backward
        loss.backward()
        # Update G
        optimizer_t.step()
        optimizer_i.step()

        # Output training stats
        if i % 10 == 0:
            print('\nTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(inputs), len(data_loader.dataset),
                       100. * i / len(data_loader), loss.data.item()))
            torch.save(t.state_dict(), './models/transform_')
            torch.save(transform_i.state_dict(), './models/transform_i_')

        # Save Losses for plotting later
        losses.append(loss.item())
        # Setup Adam optimizer
        mask.initialize()
        mask = mask.to(device)
        optimizer = optim.Adam(mask.parameters(), lr=0.1)

Starting Training Loop...
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 24/100 (0/2)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 55/100 (0/2)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 20/100 (0/2)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 84/100 (0/2)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 63/100 (0/2)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 99/100 (0/2)
Train Epoch: 6/100 (0/2)

KeyboardInterrupt: 

In [None]:
# loss versus training iterations
plt.figure(figsize=(10,5))
plt.title("Generator Loss During Training")
plt.plot(losses, label="G")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# viz filters
viz_filters(t.conv1.weight, normalize=True)

In [None]:
# viz filters
viz_filters(transform_i.t_conv1.weight, normalize=True)

In [None]:
im = mnu_dataset[25000]['image'].to(device).unsqueeze(0)
viz_im_r(im, transform_i(t(im)))
print(torch.norm(im - transform_i(t(im))).item()**2/28**2)