Skip to content

Commit

Permalink
removed hardcoded paths
Browse files Browse the repository at this point in the history
  • Loading branch information
chirag126 committed Nov 22, 2022
1 parent 84592ae commit da70155
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions openxai/LoadModel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import torch
import requests

# models
import openxai.ML_Models.ANN.model as model_ann
from openxai.ML_Models.LR.model import LogisticRegression
Expand Down Expand Up @@ -37,68 +40,89 @@ def LoadModel(data_name: str, ml_model, pretrained: bool = True):
inputs, labels = data_iter.next()

if pretrained:
os.mkdir('./pretrained')
if data_name == 'synthetic':
if ml_model == 'ann':
model_path = './openxai/ML_Models/Saved_Models/ANN/gaussian_lr_0.002_acc_0.91.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718575', allow_redirects=True)
model_path = './pretrained/ann_synthetic.pt'
open(model_path, 'wb').write(r.content)
model = model_ann.ANN_softmax(input_layer=inputs.shape[1],
hidden_layer_1=100,
num_of_classes=2)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif ml_model == 'lr':
model_path = './openxai/ML_Models/Saved_Models/LR/gaussian_lr_0.002_acc_0.73.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718576', allow_redirects=True)
model_path = './pretrained/lr_synthetic.pt'
open(model_path, 'wb').write(r.content)
model = LogisticRegression(input_dim=inputs.shape[1])
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif data_name == 'adult':
if ml_model == 'ann':
model_path = './openxai/ML_Models/Saved_Models/ANN/adult_lr_0.002_acc_0.83.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718041', allow_redirects=True)
model_path = './pretrained/ann_adult.pt'
open(model_path, 'wb').write(r.content)
model = model_ann.ANN_softmax(input_layer=inputs.shape[1],
hidden_layer_1=100,
num_of_classes=2)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif ml_model == 'lr':
model_path = './openxai/ML_Models/Saved_Models/LR/adult_lr_0.002_acc_0.84.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718044', allow_redirects=True)
model_path = './pretrained/lr_adult.pt'
open(model_path, 'wb').write(r.content)
model = LogisticRegression(input_dim=inputs.shape[1])
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif data_name == 'compas':
if ml_model == 'ann':
model_path = './openxai/ML_Models/Saved_Models/ANN/compas_lr_0.002_acc_0.85.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718040', allow_redirects=True)
model_path = './pretrained/ann_compas.pt'
open(model_path, 'wb').write(r.content)
model = model_ann.ANN_softmax(input_layer=inputs.shape[1],
hidden_layer_1=100,
num_of_classes=2)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif ml_model == 'lr':
model_path = './openxai/ML_Models/Saved_Models/LR/compas_lr_0.002_acc_0.85.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718042', allow_redirects=True)
model_path = './pretrained/lr_compas.pt'
open(model_path, 'wb').write(r.content)
model = LogisticRegression(input_dim=inputs.shape[1])
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif data_name == 'german':
if ml_model == 'ann':
model_path = './openxai/ML_Models/Saved_Models/ANN/german_lr_0.002_acc_0.71.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718047', allow_redirects=True)
model_path = './pretrained/ann_german.pt'
open(model_path, 'wb').write(r.content)
model = model_ann.ANN_softmax(input_layer=inputs.shape[1],
hidden_layer_1=100,
num_of_classes=2)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif ml_model == 'lr':
model_path = './openxai/ML_Models/Saved_Models/LR/german_lr_0.002_acc_0.72.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718043', allow_redirects=True)
model_path = './pretrained/lr_german.pt'
open(model_path, 'wb').write(r.content)
model = LogisticRegression(input_dim=inputs.shape[1])
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif data_name == 'heloc':
if ml_model == 'ann':
model_path = './openxai/ML_Models/Saved_Models/ANN/heloc_lr_0.002_acc_0.74.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718045', allow_redirects=True)
model_path = './pretrained/ann_heloc.pt'
open(model_path, 'wb').write(r.content)
model = model_ann.ANN_softmax(input_layer=inputs.shape[1],
hidden_layer_1=100,
num_of_classes=2)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

elif ml_model == 'lr':
model_path = './openxai/ML_Models/Saved_Models/LR/heloc_lr_0.002_acc_0.72.pt'
r = requests.get('https://dataverse.harvard.edu/api/access/datafile/6718046', allow_redirects=True)
model_path = './pretrained/lr_heloc.pt'
open(model_path, 'wb').write(r.content)
model = LogisticRegression(input_dim=inputs.shape[1])
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
else:
Expand Down

0 comments on commit da70155

Please sign in to comment.