In [3]:
import numpy as np
from tqdm import tqdm
from solver.qp import qp
import torch
import gzip
import pickle
import warnings
from generate_instances import generate_softmarginsvm

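### Soft-margin support vector machine (SVM) instances

Each instance is a quadratic program `(Q, q, G, h, A, b)` built by `generate_softmarginsvm` from a small random, linearly separable 2-D labelled dataset.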

In [16]:
def surrogate_gen(C=10, m=4, rng=None, ensure_both_classes=False):
    """Sample one random soft-margin SVM QP instance on a 2-D toy dataset.

    Draws ``m`` points uniformly from [-1, 1]^2 and shifts the second
    coordinate by +0.5. Each point is labelled +1 if x0 > x1, else -1.
    Because the labels are defined by the line x0 = x1, the data is linearly
    separable by construction. The +0.5 shift does not affect separability;
    it only changes the class balance (about 72% of points get label -1).

    Args:
        C: soft-margin regularization parameter, forwarded to
            ``generate_softmarginsvm``.
        m: number of data points.
        rng: optional ``np.random.Generator`` for reproducible sampling.
            When None, the legacy global ``np.random`` state is used, exactly
            as before this parameter existed.
        ensure_both_classes: if True, redraw the dataset until both labels
            occur. With the default m=4, roughly a quarter of draws contain a
            single class. NOTE(review): if the generated QP carries a
            ``y^T alpha = 0`` constraint, such draws have the trivial
            solution alpha = 0. Confirm against ``generate_softmarginsvm``.

    Returns:
        ``(Q, q, G, h, A, b, X_dash)``. The first six come from
        ``generate_softmarginsvm(y, X, X_dash, C)``. ``X_dash = y * X`` is
        the data with each row scaled by its label (shape (m, 2)).

    Raises:
        ValueError: if ``ensure_both_classes`` is True and ``m < 2``.
    """
    if ensure_both_classes and m < 2:
        raise ValueError("ensure_both_classes=True requires m >= 2")

    # Keep the exact legacy RNG call when no generator is supplied.
    uniform = np.random.uniform if rng is None else rng.uniform

    while True:
        X = uniform(-1, 1, (m, 2))  # dataset, one point per row
        X[:, 1] += 0.5              # shift 2nd coordinate (class balance only)
        # Label vector of shape (m, 1): +1 where x0 > x1, otherwise -1.
        y = np.where(X[:, 0] > X[:, 1], 1.0, -1.0).reshape(m, 1)
        if not ensure_both_classes or np.unique(y).size == 2:
            break

    X_dash = y * X  # rows scaled by their label; X_dash * X_dash^T = Q
    Q, q, G, h, A, b = generate_softmarginsvm(y, X, X_dash, C)
    return Q, q, G, h, A, b, X_dash

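### Generate, solve and save instances

Each sampled QP is solved with `qp`. Instances that reach status `optimal` are converted to float tensors and written as gzip-compressed pickle packages under `Quadratic_Programming_Datasets/raw`.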

In [25]:
import os

# --- configuration -----------------------------------------------------------
directory = 'Quadratic_Programming_Datasets/raw'  # output folder for packages
os.makedirs(directory, exist_ok=True)

max_iter = 2        # maximum number of instances to sample and solve
num = 2             # stop once this many optimal instances were collected
batch_size = 1000   # instances per saved .pkl.gz package
seed = None         # set to an int to make the sampled dataset reproducible

if seed is not None:
    np.random.seed(seed)  # surrogate_gen() draws from the global NumPy RNG by default


def _to_float_tensor(arr):
    """Convert an array-like produced by the generator to a torch.float tensor."""
    return torch.from_numpy(np.array(arr)).to(torch.float)


def _save_package(batch, idx):
    """Pickle `batch` into `<directory>/instance_<idx>.pkl.gz` and return the path."""
    path = os.path.join(directory, f'instance_{idx}.pkl.gz')
    with gzip.open(path, "wb") as file:
        pickle.dump(batch, file)
    return path


# --- generation loop ---------------------------------------------------------
ips = []          # current batch of (Q, q, G, h, A, b, X_dash) tensor tuples
pkg_idx = 0       # index of the next package file to write
success_cnt = 0   # instances solved with status 'optimal'
fail_cnt = 0      # solver exceptions (incl. escalated warnings) + non-optimal statuses

for _ in tqdm(range(max_iter)):
    Q, q, G, h, A, b, S = surrogate_gen()

    # Escalate warnings to errors only around the solver call, so that
    # numerical warnings count as failures. catch_warnings() restores the
    # previous filter state on exit. A global filterwarnings("error") would
    # also turn the warnings.warn() below into an uncaught exception and
    # abort the whole loop on the first failure.
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            res = qp(Q, q, G, h, A, b)
    except Exception as e:
        fail_cnt += 1
        warnings.warn(f'Optimization failed with error: {str(e)}')
        continue

    if res['status'] != 'optimal':
        fail_cnt += 1
        continue

    ips.append(tuple(_to_float_tensor(x) for x in (Q, q, G, h, A, b, S)))
    success_cnt += 1

    if len(ips) >= batch_size:
        _save_package(ips, pkg_idx)
        pkg_idx += 1
        ips = []

    if success_cnt >= num:
        break

# Flush the last, partially filled batch. This covers both reaching `num`
# and running out of `max_iter` attempts before `num` successes.
if ips:
    _save_package(ips, pkg_idx)
    pkg_idx += 1
    ips = []

print(f"optimal: {success_cnt}, failed: {fail_cnt}, packages written: {pkg_idx}")

In [20]:
# Load the most recently written package back as a sanity check.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
if pkg_idx == 0:
    raise RuntimeError("No instance package was written; run the generation cell first.")
latest_path = os.path.join(directory, f'instance_{pkg_idx - 1}.pkl.gz')
with gzip.open(latest_path, 'rb') as f:
    data = pickle.load(f)

# `data` is a list of (Q, q, G, h, A, b, X_dash) tensor tuples. Summarize it
# rather than dumping every tensor, since a package can hold many instances.
print(f"{latest_path}: {len(data)} instance(s)")
if data:
    for name, tensor in zip(("Q", "q", "G", "h", "A", "b", "X_dash"), data[0]):
        print(f"  {name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")