In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
import torchvision.utils as vutils
import models
from visualize import *
from data import *
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
from trim import *
from util import *
from attributions import *
from captum.attr import *
import warnings
warnings.filterwarnings("ignore")
data_path = './cosmo'

# load dataset and model

In [None]:
# Notebook-level parameters.
# NOTE(review): img_size presumably matches the 256-pixel maps in 'z1_256' — confirm.
img_size = 256
class_num = 1

# Cosmology dataset: parameter table plus the map directory, both under
# `data_path`; ToTensor is the only transform in the pipeline.
transformer = transforms.Compose([ToTensor()])
mnu_dataset = MassMapsDataset(
    opj(data_path, 'cosmological_parameters.txt'),
    opj(data_path, 'z1_256'),
    transform=transformer,
)

# Batches of 64 served in dataset order (no shuffling) by 4 worker processes.
data_loader = torch.utils.data.DataLoader(
    dataset=mnu_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

# Fetch the resnet18 model through the project loader, then place it on `device`.
model = models.load_model(model_name='resnet18', device=device, data_path=data_path)
model = model.to(device)

In [None]:
# with torch.no_grad():
#     result = {'y': [], 'pred': []}
#     for i in tqdm(range(100)):
#         sample = mnu_dataset[i]
#         x = sample['image']
#         result['y'].append(sample['params'][1].item())
#         result['pred'].append(model(x.unsqueeze(0).to(device)).flatten()[1].item())
# # print(result)
# plt.scatter(result['y'], result['pred'])
# plt.xlabel('true param')
# plt.ylabel('predicted param')
# plt.show()

In [None]:
class Generator(nn.Module):
    """Single transposed-convolution decoder mapping feature maps to an image.

    The layer hyper-parameters are constructor arguments so the decoder can be
    paired with a different encoder stem; the defaults reproduce the original
    hard-coded layer (64 -> 1 channels, kernel 6, stride 2, padding 2, no
    bias), which doubles the spatial size of its input.

    Args:
        in_channels: number of channels of the incoming feature maps.
        out_channels: number of channels of the produced image.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
        padding: transposed-convolution padding.
        bias: whether the layer learns an additive bias.
    """

    def __init__(self, in_channels=64, out_channels=1, kernel_size=6,
                 stride=2, padding=2, bias=False):
        super().__init__()
        # Keep the attribute name `convt1`: other code accesses the layer
        # through `generator.convt1`.
        self.convt1 = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x):
        """Apply the transposed convolution to `x` and return the result."""
        return self.convt1(x)
    

class Reconstruction(nn.Module):
    """Encoder stem followed by a transposed-convolution decoder.

    The stem layers are taken from `model` (`conv1`, `bn1`, `relu`) and the
    decoder layer from `generator` (`convt1`). The modules are referenced, not
    copied, so they stay shared with the objects they came from.

    Args:
        model: object exposing `conv1`, `bn1` and `relu` modules.
        generator: object exposing a `convt1` module.
    """

    def __init__(self, model, generator):
        super().__init__()
        # Same registration order as before: conv1, bn1, relu1, convt1.
        self.conv1, self.bn1, self.relu1 = model.conv1, model.bn1, model.relu
        self.convt1 = generator.convt1

    def feature_map(self, x):
        """Return the stem activations: conv1 -> bn1 -> relu1 applied to `x`."""
        return self.relu1(self.bn1(self.conv1(x)))

    def forward(self, x):
        """Encode `x` with the stem, then decode the features with `convt1`."""
        features = self.feature_map(x)
        return self.convt1(features)


In [None]:
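# # NOTE(review): this setup cell is commented out, yet later cells use `netG`
# # and `netR`, which are only created here in the visible notebook; re-enable
# # it before running on a fresh kernel.
# # model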
# netG = Generator().to(device)

# # prepend model and netG
# netR = Reconstruction(model, netG).to(device)

# # criterion
# criterion = nn.MSELoss()

# # Setup Adam optimizers for G
# optimizerG = optim.Adam(netG.parameters(), lr=0.01)

In [None]:
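# # Training Loop
# # NOTE(review): this cell is commented out, yet the loss plot further down
# # reads `G_losses`, which is only populated here in the visible notebook;
# # re-enable it before running on a fresh kernel.
# # Lists to keep track of progress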
# G_losses = []
# num_epochs = 50

# print("Starting Training Loop...")
# # For each epoch
# for epoch in range(num_epochs):
#     # For each batch in the dataloader
#     for i, data in enumerate(data_loader, 0):
#         inputs, params = data['image'], data['params']
#         if device == 'cuda':
#             inputs = inputs.to(device)
#             params = params.to(device)
#         inputs_ = netR(inputs)
#         # loss
#         loss = criterion(inputs, inputs_)
#         # zero grad
#         netG.zero_grad()
#         # backward
#         loss.backward()
#         # Update G
#         optimizerG.step()

#         # Output training stats
#         if i % 50 == 0:
#             print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, i * len(inputs), len(data_loader.dataset),
#                        100. * i / len(data_loader), loss.data.item()), end='')

#         # Save Losses for plotting later
#         G_losses.append(loss.item())


In [None]:
# Generator loss curve: one point per recorded training iteration.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(G_losses, label="G")
ax.set_title("Generator Loss During Training")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# viz filters
# Pass the generator to the project helper `viz_filters` (presumably plots
# its learned filters — helper is defined outside this notebook; confirm).
# NOTE(review): `netG` is only created in the commented-out setup code that is
# visible in this notebook — make sure it exists in the kernel before running.
viz_filters(netG)

In [None]:
# Show one image next to its reconstruction and report the reconstruction error.
num = 60
# Use the builtin next(): the loader iterator's `.next()` alias is not available
# on current PyTorch, whereas next(iter(...)) works on every version.
im = next(iter(data_loader))['image'][num:num+1].to(device)
# Reconstruct once and reuse the result for both the plot and the metric.
im_rec = netR(im)
viz_im_r(im, im_rec)
# Squared L2 error averaged over the image's own element count; the previous
# hard-coded 28**2 divisor did not match the 256-pixel maps used here.
print(torch.norm(im - im_rec).item()**2 / im.numel())

In [None]:
# Fraction of exactly-zero activations after the model's first conv/bn/relu stage.
with torch.no_grad():
    # Pull a batch explicitly: `inputs` was otherwise only defined inside the
    # commented-out training loop, so this cell failed on a fresh kernel.
    inputs = next(iter(data_loader))['image'].to(device)
    feature_maps = model.relu(model.bn1(model.conv1(inputs)))
# Count zeros with a tensor reduction and normalise by the actual number of
# elements instead of a hard-coded 64*64*128*128, so any batch/shape works.
(feature_maps == 0).sum().item() / feature_maps.numel()