In [1]:
import torch
import torch.nn as nn
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split



In [12]:
# Load the breast-cancer dataset from scikit-learn as NumPy arrays:
# X holds the feature matrix, y the binary labels.
bc = datasets.load_breast_cancer()
X = bc.data
y = bc.target

# Hold out 20% of the rows for testing; random_state fixes the shuffle so the
# split is identical on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Standardize features. The scaler is fit on the training split only and then
# applied to the test split, so no test statistics leak into training.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# transform data format to torch tensor
def transfrom_data(x):
    x = torch.from_numpy(x.astype(np.float32))
    return x

X_train = transfrom_data(X_train)
X_test = transfrom_data(X_test)
y_train = transfrom_data(y_train)
y_test = transfrom_data(y_test)

n_samples, n_features = X_train.shape

In [18]:
class LogisticRegression(nn.Module):
    """Binary logistic regression: a single linear layer followed by a sigmoid.

    Args:
        n_input_features: number of input features (the last dimension of X).
    """

    def __init__(self, n_input_features):
        super().__init__()
        # one output unit: the logit of the positive class
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, X):
        """Return per-sample positive-class probabilities.

        Args:
            X: float tensor whose last dimension is ``n_input_features``.

        Returns:
            Tensor of probabilities in (0, 1) with the trailing size-1 output
            dimension removed, e.g. shape ``(batch,)`` for a ``(batch, n)`` input.
        """
        # squeeze only the output-unit dimension; a bare .squeeze() would also
        # drop a batch dimension of size 1 and yield a 0-d tensor, which no
        # longer matches a (1,)-shaped target in BCELoss.
        y_pred = torch.sigmoid(self.linear(X)).squeeze(-1)
        return y_pred
# Seed torch before building the model so the random weight initialisation,
# and therefore the printed losses and accuracy, are reproducible on
# Restart & Run All.
torch.manual_seed(42)
model = LogisticRegression(n_features)

# hyperparameters: every iteration uses the whole training set (full-batch
# gradient descent), so there is no batch size
lr = 0.01
num_iter = 100

# loss function and optimizer
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

for i in range(num_iter):
    # forward pass on the full training set and compute the loss
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)

    # clear gradients accumulated by the previous iteration
    optimizer.zero_grad()

    # back-propagate: compute gradients of the loss w.r.t. the weights
    loss.backward()

    # apply one SGD step to the weights
    optimizer.step()

    # report the (pre-step) loss every 10 iterations
    if i % 10 == 0:
        print(i, loss.item())

# evaluate on the held-out test set; no gradients are needed
with torch.no_grad():
    y_pred = model(X_test)
    # map probabilities to {0., 1.} class labels
    y_class = y_pred.round()
    acc = y_class.eq(y_test).float().mean()
    print(f"test accuracy: {acc.item():.4f}")

0 0.6892776489257812
10 0.5553321242332458
20 0.47372159361839294
30 0.4187549948692322
40 0.37897899746894836
50 0.3486774265766144
60 0.324696809053421
70 0.30515679717063904
80 0.2888660132884979
90 0.2750319242477417
tensor(0.9737)
