In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from importlib import reload

import utils.models as models
import utils.plotting as plotting
import utils.dataloaders as dl
import utils.traintest as tt
import utils.adversarial as adv
import utils.eval as ev
import resnet
from tensorboardX import SummaryWriter



import params as hps

ImportError: No module named 'params'

In [2]:
# Experiment setup driven by the `params` module (imported as `hps`):
# pick the CUDA device, open a TensorBoard SummaryWriter, and choose the base
# classifier plus its training / noise loaders for the configured dataset.
device = torch.device('cuda:' + str(hps.gpu))
writer = SummaryWriter()

if hps.dataset == 'MNIST':
    base_model = models.LeNetMadry().to(device)
    train_loader = dl.MNIST_train_loader
    noise_loader = dl.Noise_train_loader_MNIST
elif hps.dataset == 'CIFAR10':
    # Moving the model to `device` once is sufficient (it was previously done twice).
    base_model = resnet.ResNet50().to(device)
    train_loader = dl.CIFAR10_train_loader
    noise_loader = dl.Noise_train_loader_CIFAR10
else:
    # Fail fast. Otherwise an unsupported name only shows up later as a
    # NameError on `base_model` / `train_loader` / `noise_loader`.
    raise ValueError("Unsupported hps.dataset: %r (expected 'MNIST' or 'CIFAR10')"
                     % (hps.dataset,))

# Wrap the noise loader in the project's PrecomputeLoader.
# NOTE(review): presumably this materializes the noise batches up front; confirm in utils.dataloaders.
noise_loader = dl.PrecomputeLoader(noise_loader)

In [3]:

# Build the model to train and its Adam optimizer.
# - Without hps.use_gmm, the base classifier is used as-is.
# - With hps.use_gmm, a GMM saved under SavedModels/ is loaded and combined
#   with the base classifier through models.RobustModel.
if not hps.use_gmm:
    model = base_model
else:
    loading_string = hps.dataset + '_n' + str(hps.n)
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
    gmm = torch.load('SavedModels/gmm_' + loading_string + '.pth')
    # Re-wrap alpha as an nn.Parameter. Presumably gmm is an nn.Module, in which
    # case this registers alpha as a trainable parameter (TODO confirm).
    gmm.alpha = nn.Parameter(gmm.alpha)
    # NOTE(review): -5. is a hard-coded RobustModel argument (its meaning is
    # defined in utils.models). Consider moving it to params.
    model = models.RobustModel(base_model, gmm, -5.).to(device)
    # loglam is excluded from gradient updates.
    model.loglam.requires_grad = False

saving_string = hps.dataset + '_lam' + str(hps.lam) + '_n' + str(hps.n)

# NOTE(review): the configured learning rate is scaled by a hard-coded 0.1.
lr = .1 * hps.lr

# Weight decay is applied only to the classifier weights; the GMM parameters
# (model.mm) get none. Adam's weight_decay is an L2 term added to the gradient.
if not hps.use_gmm:
    param_groups = [
        dict(params=model.parameters(), lr=lr, weight_decay=hps.decay),
    ]
else:
    param_groups = [
        dict(params=model.base_model.parameters(), lr=lr, weight_decay=hps.decay),
        dict(params=model.mm.parameters(), lr=lr, weight_decay=0.),
    ]

optimizer = optim.Adam(param_groups)

In [1]:
import utils.dataloaders as dl

In [2]:
# Load a previously saved MNIST GMM checkpoint.
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
gmm_path = 'SavedModels/gmm_MNIST_n1000_data_used3000.pth'
gmm = torch.load(gmm_path)

In [10]:
# Flatten the inputs to 784 features, evaluate the GMM, reduce with logsumexp
# over dim 0, then average the remaining values.
# NOTE(review): `data` is not defined anywhere in this notebook, so this cell
# only works from leftover kernel state. Define it explicitly, e.g. from an
# MNIST loader, so Restart & Run All reproduces the output.
flat_data = data.view(-1, 784)
torch.logsumexp(gmm(flat_data), dim=0).mean()

tensor(-1750.1868, grad_fn=<MeanBackward1>)

In [18]:
# Same GMM statistic, computed on the first batch of the project's GrayCIFAR10 loader.
loaderE = dl.GrayCIFAR10()
# Element 0 of the batch is presumably the input tensor (labels at index 1). TODO confirm.
dataE = next(iter(loaderE))[0]
torch.logsumexp(gmm(dataE.view(-1, 784)), dim=0).mean()

tensor(-1750.9824, grad_fn=<MeanBackward1>)