In [1]:
# wandb and pytorch manual 
# https://colab.research.google.com/drive/1XDtq-KT0GkX06a_g1MevuLFOMk4TxjKZ#scrollTo=bZpt5W2NNl6S


In [2]:
import torch
import torch.nn as nn
import numpy as np
import os

import sys
sys.path.append('../')

from loss_landscape.landscape_utils import init_directions, init_network
import argparse
from datetime import datetime
# import wandb

import matplotlib.pyplot as plt
from train import build_dataset, build_model, seed_everything

In [3]:
def load_model(model, config):
    path = '../saved1/' + config.arch + '/best_model.pth'
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

In [4]:
def run_landscape_gen(config):
    
    seed_everything(config.seed) 
    
    _, val_loader, _ = build_dataset(config)
    
    for model_id in config.model_ids:
        print(f'Testing {model_id}')
        config.arch = model_id
        model = build_model(config.arch)
        model.eval()  


        if os.path.exists(f'results/{model_id}_contour_bs_{config.batch_size}_res_{config.resolution}_.png'):
            continue

        noises = init_directions(load_model(model, config))

        criterion = nn.L1Loss()

        A, B = np.meshgrid(np.linspace(-1, 1, config.resolution),
                        np.linspace(-1, 1, config.resolution), indexing='ij')

        loss_surface = np.empty_like(A)       

        for i in range(config.resolution):
            for j in range(config.resolution):
                total_loss = 0.
                n_batch = 0
                alpha = A[i, j]
                beta = B[i, j]
                net = init_network(load_model(model, config), noises, alpha, beta).to('cuda')

                for batch, label in iter(val_loader):
                    batch = batch.to('cuda')
                    label = label.to('cuda')
                    with torch.no_grad():
                        preds = net(batch)
                        loss = criterion(batch, preds)
                        # loss = criterion(label, preds)
                        total_loss += loss.item()
                        n_batch += 1
                loss_surface[i, j] = total_loss / (n_batch * config.batch_size)
                del net, batch, preds
                print(f'alpha : {alpha:.2f}, beta : {beta:.2f}, loss : {loss_surface[i, j]:.2f}')
                torch.cuda.empty_cache()

        plt.figure(figsize=(10, 10))
        plt.contour(A, B, loss_surface)
        plt.savefig(f'results/{model_id}_contour_bs_{config.batch_size}_res_{config.resolution}.png', dpi=100)
        plt.close()

        np.save(f'{model_id}_xx_card.npy', A)
        np.save(f'{model_id}_yy_card.npy', B)
        np.save(f'{model_id}_zz_card.npy', loss_surface)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

args = argparse.Namespace(
    data_path = '../dataset/',
    arch = 'ae_basic',
    num_workers = 12,
    epochs = 100,
    batch_size = 256,
    device = device,
    seed = 42,
    resolution = 100, # sweep 100*100 for visualizing
    model_ids = ['ae_basic']
)

In [None]:
run_landscape_gen(args)

(113842, 30)
Testing ae_basic
A total of 21,082 parameters.
alpha : -1.00, beta : -1.00, loss : 0.11
alpha : -1.00, beta : -0.98, loss : 0.10
alpha : -1.00, beta : -0.96, loss : 0.10
alpha : -1.00, beta : -0.94, loss : 0.09
alpha : -1.00, beta : -0.92, loss : 0.09
alpha : -1.00, beta : -0.90, loss : 0.09
alpha : -1.00, beta : -0.88, loss : 0.08
alpha : -1.00, beta : -0.86, loss : 0.08
alpha : -1.00, beta : -0.84, loss : 0.08
alpha : -1.00, beta : -0.82, loss : 0.07
alpha : -1.00, beta : -0.80, loss : 0.07
alpha : -1.00, beta : -0.78, loss : 0.07
alpha : -1.00, beta : -0.76, loss : 0.07
alpha : -1.00, beta : -0.74, loss : 0.06
alpha : -1.00, beta : -0.72, loss : 0.06
alpha : -1.00, beta : -0.70, loss : 0.06
alpha : -1.00, beta : -0.68, loss : 0.06
alpha : -1.00, beta : -0.66, loss : 0.06
alpha : -1.00, beta : -0.64, loss : 0.06
alpha : -1.00, beta : -0.62, loss : 0.05
alpha : -1.00, beta : -0.60, loss : 0.05
alpha : -1.00, beta : -0.58, loss : 0.05
alpha : -1.00, beta : -0.56, loss : 0.

Exception ignored in: <function _releaseLock at 0x7f5c64319e60>
Traceback (most recent call last):
  File "/home/beomgon/anaconda3/envs/pytorch/lib/python3.7/logging/__init__.py", line 221, in _releaseLock
    def _releaseLock():
KeyboardInterrupt


KeyboardInterrupt: 