In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from load_data import load_skl_data
import numpy as np

In [2]:
"""
Configuration and Hyperparameters
"""
# Make newly constructed floating-point tensors default to CUDA float32, so the
# model built later in this notebook is created on the GPU without an explicit
# .to('cuda').  NOTE(review): this call requires a CUDA-capable PyTorch build
# and is, as far as I know, deprecated in recent PyTorch releases -- confirm
# against the installed version before upgrading.
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU

batch_size = 128  # mini-batch size passed to both DataLoaders
step_size = 0.0015  # learning rate handed to the SGD optimiser
random_seed = 0  # seed for torch's RNG (see manual_seed below)
epochs = 500  # number of passes over the training loader
L2_decay = 1e-4  # weight_decay handed to the SGD optimiser
gauss_vicinal_std = 0.25  # std of the Gaussian noise added to training inputs

# Seeds torch's random number generator only; NumPy's RNG is not seeded here.
# Kept as the last expression of the cell, so the returned Generator is what
# the cell displays.
torch.manual_seed(random_seed)

<torch._C.Generator at 0x222bf3c3470>

In [3]:
# Load the pre-split dataset through the project helper.  The returned splits
# are consumed below with np.vstack / torch.from_numpy, i.e. as NumPy arrays.
train_data, train_labels, val_data, val_labels, test_data, test_labels = load_skl_data('breast_cancer')

# No separate validation stage is used in this notebook, so the validation
# split is folded into the test split.
test_data = np.vstack((val_data, test_data))
test_labels = np.hstack((val_labels, test_labels))

# Convert to torch tensors; features become float32, labels keep their dtype.
train_data = torch.from_numpy(train_data).float()
train_labels = torch.from_numpy(train_labels)
test_data = torch.from_numpy(test_data).float()
test_labels = torch.from_numpy(test_labels)

# Standardise every feature using statistics of the TRAINING split only, and
# apply the same transform to the test split (no test-set leakage).
train_mean = torch.mean(train_data, 0)
train_std = torch.std(train_data, 0)
# A feature that is constant over the training split has zero std, and dividing
# by it would fill that column with NaN/inf.  Use a unit scale for such columns
# so they are merely centred.  (A NaN std, e.g. from a single-row split, also
# fails the `> 0` test and falls back to 1.)
train_std = torch.where(train_std > 0, train_std, torch.ones_like(train_std))
train_data = (train_data - train_mean) / train_std
test_data = (test_data - train_mean) / train_std

train_set = torch.utils.data.TensorDataset(train_data, train_labels)
test_set = torch.utils.data.TensorDataset(test_data, test_labels)
# Only the training loader is shuffled; the evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

In [4]:
class fc_model(nn.Module):
    """Fully connected binary classifier that outputs one raw logit per sample.

    Architecture: in_features -> 128 -> 64 -> 32 -> 1, with tanh after each of
    the three hidden layers and no activation on the output layer.  The output
    is a logit, so it is meant to be paired with a loss that applies the
    sigmoid itself (this notebook uses BCEWithLogitsLoss).

    Args:
        in_features: Number of input features per sample.  Defaults to 30,
            the width that was previously hard-coded, so ``fc_model()`` builds
            exactly the same network and checkpoints stay loadable.
    """

    def __init__(self, in_features=30):
        super().__init__()
        # Attribute names fc1..fc4 are part of the state_dict keys; keep them.
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, inputs):
        """Map a ``(..., in_features)`` tensor to ``(..., 1)`` logits."""
        hidden = inputs
        for layer in (self.fc1, self.fc2, self.fc3):
            # torch.tanh computes the same values as F.tanh, which some
            # PyTorch releases flag with a deprecation warning.
            hidden = torch.tanh(layer(hidden))
        return self.fc4(hidden)

# Build the network and show its layer summary.
model = fc_model()
print(model)

# The network emits raw logits, so use the loss that applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()

# SGD with momentum; L2 regularisation is applied through weight_decay.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

fc_model(
  (fc1): Linear(in_features=30, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=1, bias=True)
)


In [5]:
def gauss_vicinal(inputs, gauss_vicinal_std):
    """Draw one Gaussian-perturbed copy of ``inputs``.

    Every element of the result is sampled from a normal distribution whose
    mean is the corresponding element of ``inputs`` and whose standard
    deviation is ``gauss_vicinal_std``.

    Args:
        inputs: Tensor of per-element means.
        gauss_vicinal_std: Standard deviation shared by all elements.

    Returns:
        A new tensor with the same shape as ``inputs``.
    """
    return torch.normal(mean=inputs, std=gauss_vicinal_std)

In [6]:
"""
Training
"""
# Each optimisation step minimises the sum of two BCE-with-logits losses: one on
# a Gaussian-perturbed copy of the batch and one on the unperturbed batch.
model.train()
for epoch in range(epochs):
    # Per-epoch running sums of the per-batch losses (sums, not averages).
    noisy_sum, clean_sum, combined_sum = 0., 0., 0.

    for batch_inputs, batch_labels in train_loader:
        batch_inputs = batch_inputs.to('cuda')
        # Cast labels to float and make them a column so they match the
        # (batch, 1) logits produced by the model.
        batch_labels = batch_labels.float().reshape(-1, 1).to('cuda')

        # The noise is sampled before the forward passes, once per batch.
        noisy_inputs = gauss_vicinal(batch_inputs, gauss_vicinal_std)

        optimizer.zero_grad()
        noisy_loss = criterion(model(noisy_inputs), batch_labels)
        clean_loss = criterion(model(batch_inputs), batch_labels)
        combined_loss = noisy_loss + clean_loss
        combined_loss.backward()
        optimizer.step()

        # The loss tensors were computed before the step, so reading them
        # afterwards yields the same values.
        noisy_sum += noisy_loss.item()
        clean_sum += clean_loss.item()
        combined_sum += combined_loss.item()

    # Columns: epoch, noisy-input loss, clean-input loss, combined loss.
    print('{}: {} {} {}'.format(epoch, noisy_sum, clean_sum, combined_sum))



0: 2.044547140598297 2.045577645301819 4.090124845504761
1: 2.0271178483963013 2.0248252153396606 4.051943063735962
2: 1.9910165071487427 1.9909996390342712 3.9820162057876587
3: 1.9455761313438416 1.9440457820892334 3.8896219730377197
4: 1.8958339095115662 1.8941122889518738 3.7899460792541504
5: 1.8431907892227173 1.8377187848091125 3.6809096336364746
6: 1.7827086448669434 1.7787753343582153 3.5614839792251587
7: 1.7220118045806885 1.7138314843177795 3.435843348503113
8: 1.6604631543159485 1.650818407535553 3.3112815618515015
9: 1.5866779088974 1.5800595879554749 3.1667375564575195
10: 1.521092176437378 1.5115066170692444 3.0325987339019775
11: 1.4453860223293304 1.4340205788612366 2.8794065713882446
12: 1.3671723902225494 1.361830085515976 2.7290024161338806
13: 1.2875382900238037 1.2814837396144867 2.569022059440613
14: 1.2127504348754883 1.202123612165451 2.414874017238617
15: 1.1345266997814178 1.1283094882965088 2.262836217880249
16: 1.0600744783878326 1.041613221168518 2.101687

132: 0.08844645321369171 0.06986645795404911 0.15831291303038597
133: 0.10383858159184456 0.07122637704014778 0.17506495118141174
134: 0.091525848954916 0.07120111584663391 0.1627269685268402
135: 0.10225638933479786 0.07396359741687775 0.17621998861432076
136: 0.08138798922300339 0.0678613567724824 0.1492493450641632
137: 0.09092517383396626 0.0690299067646265 0.15995508432388306
138: 0.10760414320975542 0.06483187433332205 0.17243601940572262
139: 0.08064592070877552 0.07169408909976482 0.1523400079458952
140: 0.11753558367490768 0.06710299663245678 0.1846385858952999
141: 0.08137056976556778 0.0719719473272562 0.15334251150488853
142: 0.07771099358797073 0.06615331210196018 0.14386430382728577
143: 0.08288256265223026 0.06866019126027822 0.1515427529811859
144: 0.08616712130606174 0.06454446353018284 0.15071158111095428
145: 0.10345468297600746 0.06727606616914272 0.17073074728250504
146: 0.08973154239356518 0.06726424116641283 0.15699577704071999
147: 0.09301426075398922 0.06222580

259: 0.03743298910558224 0.035339586436748505 0.07277257554233074
260: 0.05289844051003456 0.035326800774782896 0.08822524175047874
261: 0.06286236830055714 0.033919292502105236 0.0967816598713398
262: 0.07399231754243374 0.034556809812784195 0.10854912362992764
263: 0.055361753329634666 0.03353521693497896 0.08889696933329105
264: 0.05480913259088993 0.03283026907593012 0.08763940259814262
265: 0.05168780870735645 0.033363780938088894 0.08505159243941307
266: 0.05819206964224577 0.033085050992667675 0.0912771187722683
267: 0.06481045205146074 0.03516020579263568 0.09997065924108028
268: 0.06637414544820786 0.03301012422889471 0.09938427060842514
269: 0.05913981702178717 0.03475333284586668 0.093893151730299
270: 0.07951104827225208 0.032590001821517944 0.11210105288773775
271: 0.0532782981172204 0.03275013901293278 0.08602843806147575
272: 0.0813229437917471 0.03168667294085026 0.11300961673259735
273: 0.05379283335059881 0.030432533007115126 0.08422536589205265
274: 0.062838448211550

384: 0.05113812256604433 0.021242122165858746 0.07238024473190308
385: 0.08004909614101052 0.020500957500189543 0.10055005364120007
386: 0.04956096317619085 0.021323848981410265 0.07088481448590755
387: 0.03458487335592508 0.023925135377794504 0.058510009199380875
388: 0.057594554498791695 0.02119725150987506 0.07879180647432804
389: 0.08585366886109114 0.021908671129494905 0.10776233859360218
390: 0.02473483933135867 0.019999571377411485 0.0447344109416008
391: 0.0628516310825944 0.02225023042410612 0.08510186057537794
392: 0.06794464495033026 0.021883924258872867 0.08982856944203377
393: 0.02546567190438509 0.020627913065254688 0.04609358496963978
394: 0.0625244383700192 0.0224553938023746 0.08497983124107122
395: 0.04156609904021025 0.02095845155417919 0.06252455059438944
396: 0.05307464674115181 0.020796794211491942 0.0738714411854744
397: 0.07188298646360636 0.02077141823247075 0.0926544051617384
398: 0.11968105472624302 0.02012230595573783 0.13980336114764214
399: 0.0392163116484

In [7]:
# Persist the trained weights, then rebuild a fresh network and restore the
# weights from disk so the evaluation below runs on the saved checkpoint.
checkpoint_path = './gauss_model_pytorch_breast_augment'
torch.save(model.state_dict(), checkpoint_path)

model = fc_model()
restored_state = torch.load(checkpoint_path)
# Left as the final expression so the key-matching report is displayed.
model.load_state_dict(restored_state)

<All keys matched successfully>

In [8]:
# Accuracy on the held-out loader (test_loader).
model.eval()
correct = 0
total = 0
with torch.no_grad():  # inference only; no autograd bookkeeping needed
    for inputs, labels in test_loader:
        inputs = inputs.to('cuda')
        # Column-shaped float labels, matching the (batch, 1) logits.
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logits at 0, i.e. a sigmoid probability of 0.5.
        # The previous (sign(x) + 1) / 2 mapping produced 0.5 for a logit of
        # exactly 0, which can never equal a label and was therefore always
        # scored as an error; `>= 0` assigns that tie to class 1 instead.
        # Assumes labels are encoded as 0/1 -- TODO confirm against the loader.
        predicts = (outputs >= 0).type_as(labels)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9649122807017544


In [9]:
# Accuracy on the training loader (train_loader), for comparison with the
# held-out accuracy.
model.eval()
correct = 0
total = 0
with torch.no_grad():  # inference only; no autograd bookkeeping needed
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        # Column-shaped float labels, matching the (batch, 1) logits.
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logits at 0, i.e. a sigmoid probability of 0.5.
        # The previous (sign(x) + 1) / 2 mapping produced 0.5 for a logit of
        # exactly 0, which can never equal a label and was therefore always
        # scored as an error; `>= 0` assigns that tie to class 1 instead.
        # Assumes labels are encoded as 0/1 -- TODO confirm against the loader.
        predicts = (outputs >= 0).type_as(labels)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

1.0
