In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from load_data import load_skl_data
import numpy as np

In [2]:
"""
Configuration and Hyperparameters
"""
# Make newly created tensors default to CUDA float32.
# NOTE(review): torch.set_default_tensor_type is deprecated since PyTorch 2.1 in favour of
# torch.set_default_dtype / torch.set_default_device; kept so existing runs reproduce.
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU

batch_size = 128    # mini-batch size used by both DataLoaders
step_size = 0.0025  # SGD learning rate
random_seed = 0     # seeds both NumPy's and torch's global RNGs below
epochs = 500
L2_decay = 1e-4     # passed to SGD as weight_decay
alpha = 1.          # Beta(alpha, alpha) concentration for the mixup weight

# Seed NumPy as well as torch. NumPy's global RNG may be consumed before training
# (presumably by load_skl_data when splitting the data -- not visible here, verify).
np.random.seed(random_seed)
torch.manual_seed(random_seed)

<torch._C.Generator at 0x293926d1470>

In [3]:
train_data, train_labels, val_data, val_labels, test_data, test_labels = load_skl_data('breast_cancer')
# The validation split is not used on its own in this notebook; fold it into the test set.
test_data = np.vstack((val_data, test_data))
test_labels = np.hstack((val_labels, test_labels))

# Convert to tensors. Features become float32 on the CPU; batches are moved to the
# GPU inside the training / evaluation loops.
train_data = torch.from_numpy(train_data).float()
train_labels = torch.from_numpy(train_labels)
test_data = torch.from_numpy(test_data).float()
test_labels = torch.from_numpy(test_labels)

# Standardize both splits with training-set statistics only, so nothing leaks from test.
train_mean = train_data.mean(dim=0)
train_std = train_data.std(dim=0)
# A column that is constant in the training set has std 0 and would turn into NaN/inf;
# divide such columns by 1 instead.
train_std = torch.where(train_std == 0, torch.ones_like(train_std), train_std)
train_data = (train_data - train_mean) / train_std
test_data = (test_data - train_mean) / train_std

train_set = torch.utils.data.TensorDataset(train_data, train_labels)
test_set = torch.utils.data.TensorDataset(test_data, test_labels)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

In [4]:
class fc_model(nn.Module):
    """Fully connected binary classifier that outputs one raw logit per sample.

    Architecture: in_features -> 128 -> 64 -> 32 -> 1, with tanh between layers.
    No sigmoid is applied, so pair it with a logits-based loss such as
    ``BCEWithLogitsLoss``. Layer names (fc1..fc4) are part of the state-dict keys.

    Args:
        in_features: number of input features per sample (default 30).
    """

    def __init__(self, in_features=30):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, inputs):
        """Map a ``(..., in_features)`` tensor to ``(..., 1)`` logits."""
        hidden = torch.tanh(self.fc1(inputs))
        hidden = torch.tanh(self.fc2(hidden))
        hidden = torch.tanh(self.fc3(hidden))
        return self.fc4(hidden)

# Binary cross-entropy on raw logits (the network has no final sigmoid).
criterion = nn.BCEWithLogitsLoss()

model = fc_model()
print(model)

# SGD with heavy-ball momentum; weight_decay adds the L2 penalty.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

fc_model(
  (fc1): Linear(in_features=30, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=1, bias=True)
)


In [5]:
def mixup_breast(inputs, labels, alpha):
    """Build a mixup batch by blending every sample with a randomly chosen partner.

    One weight ``lmbda ~ Beta(alpha, alpha)`` is drawn for the whole batch. Partners
    come from a random permutation ``perm`` of the batch:
    ``mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[perm]``.

    Args:
        inputs: feature tensor whose first dimension indexes samples.
        labels: targets aligned with ``inputs`` along the first dimension.
        alpha: Beta concentration. Values ``<= 0`` disable mixing (``lmbda`` is 1)
            instead of failing on an invalid Beta distribution.

    Returns:
        ``(mixup_inputs, labels, labels_b, lmbda)``, where:
            - ``labels`` is the input tensor returned unchanged.
            - ``labels_b`` holds the partners' labels.
            - ``lmbda`` is a 0-dim tensor: the weight given to the original samples.
    """
    if alpha > 0:
        lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    else:
        lmbda = torch.tensor(1.)
    num_samples = labels.size(0)
    # Draw the permutation on the data's device so indexing does not depend on the
    # global default tensor type.
    idx = torch.randperm(num_samples, device=inputs.device)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Mixup loss: the ``lmbda``-weighted blend of the loss against both label sets.

    ``criterion`` is called as ``criterion(predicts, targets)``, first with ``labels``
    and then with ``labels_b``. The result is
    ``lmbda * loss(labels) + (1 - lmbda) * loss(labels_b)``.
    """
    loss_a = criterion(predicts, labels)
    loss_b = criterion(predicts, labels_b)
    return lmbda * loss_a + (1 - lmbda) * loss_b

In [7]:
"""
Training
"""
# Each step fits two views of the same mini-batch: the mixup batch (loss blended over
# both label sets) and the original batch. The two losses are summed, then a single
# backward pass and optimizer step follow.
model.train()
for epoch in range(epochs):
    # Running sums of per-batch losses for this epoch. These are sums, not averages,
    # so their scale grows with the number of batches.
    epoch_loss = 0.
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.to('cuda')
        # BCEWithLogitsLoss needs float targets with the same (N, 1) shape as the logits.
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_breast(inputs, labels, alpha)
        optimizer.zero_grad()
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)
        
        # Also fit the original (unmixed) batch and add its loss, so each update uses
        # both the mixed and the clean samples.
        outputs_org = model(inputs)
        loss_org = criterion(outputs_org, labels)
        total_loss = mixup_loss + loss_org
        
        epoch_mixup_loss += mixup_loss.item()
        epoch_org_loss += loss_org.item()
        
        epoch_loss += total_loss.item()
        total_loss.backward()
        # (end of combined-loss section)
        
        optimizer.step()
    # Log columns: epoch, summed mixup loss, summed clean loss, summed total loss.
    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))



0: 2.044673800468445 2.0406798124313354 4.08535361289978
1: 2.019120156764984 2.0096917748451233 4.028811812400818
2: 1.9993300437927246 1.9572952389717102 3.95662522315979
3: 1.9294987916946411 1.8977524638175964 3.8272513151168823
4: 1.9005972146987915 1.8264869451522827 3.727084159851074
5: 1.849746823310852 1.7433170676231384 3.5930639505386353
6: 1.8022889494895935 1.6593273878097534 3.4616163969039917
7: 1.679317057132721 1.573448359966278 3.2527655363082886
8: 1.687528908252716 1.4889679849147797 3.1764968633651733
9: 1.5846096873283386 1.3889218270778656 2.9735316038131714
10: 1.5091869831085205 1.2956005930900574 2.804787576198578
11: 1.5145399570465088 1.1898334622383118 2.704373359680176
12: 1.4189932346343994 1.0928047001361847 2.5117979645729065
13: 1.3486145734786987 1.0004624426364899 2.3490771055221558
14: 1.331528663635254 0.9392843544483185 2.27081298828125
15: 1.2507122159004211 0.8540705740451813 2.10478276014328
16: 1.4715981781482697 0.7808809280395508 2.252479076

135: 1.2606091797351837 0.3160216435790062 1.5766308009624481
136: 0.9637845754623413 0.303374707698822 1.2671593129634857
137: 1.1992859542369843 0.3037807419896126 1.5030667185783386
138: 1.298317939043045 0.29352547228336334 1.5918433964252472
139: 1.0861995071172714 0.3081156760454178 1.3943151384592056
140: 1.044379636645317 0.30553334951400757 1.3499130010604858
141: 1.2598493993282318 0.2957211881875992 1.5555706322193146
142: 1.0085576474666595 0.2979360967874527 1.306493729352951
143: 1.0308417677879333 0.29681286960840225 1.3276546597480774
144: 0.9027444869279861 0.3112296387553215 1.2139740884304047
145: 1.0758727192878723 0.2911331430077553 1.367005854845047
146: 0.873399905860424 0.2859184890985489 1.1593183875083923
147: 1.020448923110962 0.2998952716588974 1.3203442096710205
148: 0.8263221681118011 0.286902591586113 1.113224744796753
149: 0.9636991918087006 0.3095497041940689 1.2732488811016083
150: 1.035943254828453 0.28675682097673416 1.3227000683546066
151: 1.1650434

268: 0.895712286233902 0.27349478006362915 1.1692070662975311
269: 1.5059959888458252 0.2724952697753906 1.778491199016571
270: 1.2927679419517517 0.2773704156279564 1.5701383352279663
271: 1.4460265636444092 0.2737376019358635 1.7197641134262085
272: 1.0884104073047638 0.27258920669555664 1.3609996140003204
273: 0.8878566324710846 0.27782584726810455 1.1656824946403503
274: 1.3279609680175781 0.2762800306081772 1.6042410135269165
275: 1.1058567464351654 0.2747013121843338 1.380558043718338
276: 1.3221862316131592 0.2770207077264786 1.5992069244384766
277: 1.206302136182785 0.28646307438611984 1.492765188217163
278: 0.8189234435558319 0.28222232311964035 1.1011457741260529
279: 1.1881926953792572 0.29621654003858566 1.4844092726707458
280: 1.1040130108594894 0.28957682847976685 1.3935898691415787
281: 1.364516705274582 0.2769203186035156 1.64143705368042
282: 1.3095263838768005 0.284866027534008 1.5943924188613892
283: 0.5354007631540298 0.2809305712580681 0.8163313418626785
284: 1.049

401: 1.29267817735672 0.26971589773893356 1.5623940825462341
402: 0.8168929815292358 0.263143390417099 1.0800363719463348
403: 0.8752667531371117 0.26234041154384613 1.1376071274280548
404: 1.1428338587284088 0.26534922420978546 1.408183068037033
405: 1.0802967250347137 0.2670843303203583 1.3473811000585556
406: 1.342866063117981 0.2575233727693558 1.6003893911838531
407: 0.7277664542198181 0.27167604118585587 0.9994425177574158
408: 0.5524964556097984 0.25391507893800735 0.8064115196466446
409: 1.117228925228119 0.2610400319099426 1.378268986940384
410: 1.1624333560466766 0.26388169080018997 1.426315039396286
411: 0.8561277389526367 0.26810939610004425 1.1242371499538422
412: 1.1676982939243317 0.2569599822163582 1.4246582984924316
413: 1.0362613201141357 0.26338138431310654 1.2996427118778229
414: 0.9485417753458023 0.2620789408683777 1.2106206864118576
415: 1.4555179476737976 0.2524365112185478 1.707954466342926
416: 1.324064016342163 0.2566532641649246 1.580717295408249
417: 1.4269

In [8]:
# Round-trip the trained weights through a checkpoint file into a freshly built network.
MODEL_PATH = './mixup_model_pytorch_breast_augment'
torch.save(model.state_dict(), MODEL_PATH)
model = fc_model()
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
model.load_state_dict(torch.load(MODEL_PATH))

<All keys matched successfully>

In [9]:
# Accuracy on the held-out set (validation + test splits).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to('cuda')
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Predict class 1 when the logit is positive (sigmoid > 0.5). Thresholding
        # avoids sign(0) == 0, which would give a 0.5 "prediction" that never matches
        # a 0/1 label.
        predicts = (outputs > 0).type_as(labels)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9736842105263158


In [10]:
# Accuracy on the (unmixed) training set, for comparison with the held-out accuracy.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Predict class 1 when the logit is positive (sigmoid > 0.5). Thresholding
        # avoids sign(0) == 0, which would give a 0.5 "prediction" that never matches
        # a 0/1 label.
        predicts = (outputs > 0).type_as(labels)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9853372434017595
