In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from load_data import load_skl_data
import numpy as np

In [2]:
"""
Configuration and Hyperparameters
"""
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch releases in favour of
# torch.set_default_dtype / torch.set_default_device. It is kept because later cells build the model
# with fc_model() and never call .to('cuda') on it, so its parameters must land on the GPU by default.
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU

batch_size = 128
step_size = 0.0025          # SGD learning rate
random_seed = 0
epochs = 500
L2_decay = 1e-4             # SGD weight_decay (L2 penalty)
alpha = 1.                  # Beta(alpha, alpha) concentration for the mixup coefficient
perturb_loss_weight = 0.75  # weight of the mixup loss; the clean-batch loss gets (1 - this)

# Seed numpy as well as torch. load_skl_data (opaque from here) may draw its splits from numpy's
# global RNG, and an unseeded split would make reruns non-reproducible despite the torch seed.
np.random.seed(random_seed)
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1761ec95470>

In [3]:
train_data, train_labels, val_data, val_labels, test_data, test_labels = load_skl_data('breast_cancer')
# The validation split is not used for model selection here; it is merged into the
# held-out evaluation set.
test_data = np.vstack((val_data, test_data))
test_labels = np.hstack((val_labels, test_labels))

# numpy -> CPU float32 tensors. Batches are moved to the GPU explicitly in the
# training/evaluation loops.
train_data = torch.from_numpy(train_data).float()
train_labels = torch.from_numpy(train_labels)
test_data = torch.from_numpy(test_data).float()
test_labels = torch.from_numpy(test_labels)

# Standardize with statistics from the training split only, so no test information
# leaks into preprocessing.
train_mean = torch.mean(train_data, 0)
train_std = torch.std(train_data, 0)
# A constant training feature has std 0; dividing by it would turn that column into
# NaN/inf and poison training. Divide such columns by 1 instead (the training column
# then becomes all zeros).
train_std = train_std.masked_fill(train_std == 0, 1.0)
train_data = (train_data - train_mean) / train_std
test_data = (test_data - train_mean) / train_std

train_set = torch.utils.data.TensorDataset(train_data, train_labels)
test_set = torch.utils.data.TensorDataset(test_data, test_labels)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

In [4]:
class fc_model(nn.Module):
    """Fully connected binary classifier that returns one raw logit per sample.

    Architecture: in_features -> 128 -> 64 -> 32 -> 1, with tanh after each hidden
    layer and no activation on the output. It is meant to be paired with a
    logits-based loss such as ``BCEWithLogitsLoss``.

    Args:
        in_features: width of each input row. Defaults to 30, the feature count this
            notebook uses. The layer attribute names (fc1..fc4) are unchanged, so
            previously saved state_dicts still load.
    """

    def __init__(self, in_features=30):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, inputs):
        """Map inputs of shape (..., in_features) to logits of shape (..., 1)."""
        # torch.tanh is the canonical op; F.tanh is a thin legacy alias.
        fc1_out = torch.tanh(self.fc1(inputs))
        fc2_out = torch.tanh(self.fc2(fc1_out))
        fc3_out = torch.tanh(self.fc3(fc2_out))
        return self.fc4(fc3_out)

model = fc_model()
print(model)

# fc_model emits raw logits (no output activation), so use the fused sigmoid + BCE loss.
criterion = nn.BCEWithLogitsLoss()
# SGD with momentum 0.9; weight_decay applies the L2 penalty from the config cell.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

fc_model(
  (fc1): Linear(in_features=30, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=1, bias=True)
)


In [5]:
def mixup_breast(inputs, labels, alpha):
    """Mix every sample in a batch with a randomly chosen partner from the same batch.

    A single coefficient ``lmbda ~ Beta(alpha, alpha)`` is drawn for the whole batch.
    The inputs are blended as ``lmbda * x + (1 - lmbda) * x[perm]``. Labels are not
    blended; both label sets are returned so the loss can be interpolated instead.

    Args:
        inputs: batch of samples; the first dimension is the batch dimension.
        labels: targets aligned with ``inputs`` along the first dimension.
        alpha: Beta concentration. ``alpha <= 0`` disables mixing (``lmbda = 1``),
            as in the reference mixup implementation; ``Beta`` itself would reject
            a non-positive concentration.

    Returns:
        (mixup_inputs, labels, labels_b, lmbda). ``labels`` is returned unchanged,
        ``labels_b`` holds the partners' labels and ``lmbda`` is a 0-dim tensor.
    """
    if alpha > 0:
        lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    else:
        lmbda = inputs.new_tensor(1.0)
    batch_size = labels.size(0)
    # Build the permutation on the batch's own device rather than relying on the
    # global default tensor type.
    idx = torch.randperm(batch_size, device=inputs.device)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Interpolate a loss between two label sets using the mixup coefficient.

    Computes ``lmbda * criterion(predicts, labels) + (1 - lmbda) * criterion(predicts, labels_b)``.
    ``criterion`` is evaluated against ``labels`` first, then against ``labels_b``.
    """
    loss_a = criterion(predicts, labels)
    loss_b = criterion(predicts, labels_b)
    return lmbda * loss_a + (1 - lmbda) * loss_b

In [7]:
"""
Training
"""
# Each optimizer step backpropagates a convex blend of two losses:
#   perturb_loss_weight       * mixup loss on the mixed batch, plus
#   (1 - perturb_loss_weight) * plain loss on the original batch.
# Each epoch prints: epoch index, summed mixup loss, summed original-batch loss, and the
# UNWEIGHTED sum of those two (not the weighted objective that is backpropagated).
model.train()
for epoch in range(epochs):
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.
    epoch_loss = 0.
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        # Float targets shaped (batch, 1) to line up with the model's single logit column.
        labels = labels.float().reshape(-1, 1).to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_breast(inputs, labels, alpha)

        optimizer.zero_grad()

        # Loss on the mixed batch, interpolated between both label sets.
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)
        # Loss on the untouched batch.
        outputs_org = model(inputs)
        loss_org = criterion(outputs_org, labels)

        weighted_total_loss = perturb_loss_weight * mixup_loss + (1 - perturb_loss_weight) * loss_org
        weighted_total_loss.backward()
        optimizer.step()

        mixup_value = mixup_loss.item()
        org_value = loss_org.item()
        epoch_mixup_loss += mixup_value
        epoch_org_loss += org_value
        epoch_loss += mixup_value + org_value
    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))



0: 2.046876907348633 2.0399494767189026 4.086826384067535
1: 2.0319778323173523 2.026467502117157 4.058445334434509
2: 2.0203582644462585 2.0010411739349365 4.021399438381195
3: 1.9859657287597656 1.9705966114997864 3.956562340259552
4: 1.9716074466705322 1.9338880777359009 3.905495524406433
5: 1.9511802792549133 1.892950713634491 3.8441309928894043
6: 1.9255191087722778 1.8561509847640991 3.781670093536377
7: 1.8596004247665405 1.8181761503219604 3.677776575088501
8: 1.8804265260696411 1.7698708772659302 3.6502974033355713
9: 1.8214162588119507 1.7260793447494507 3.5474956035614014
10: 1.792328655719757 1.6850735545158386 3.4774022102355957
11: 1.7820415496826172 1.6426143050193787 3.424655854701996
12: 1.7494288682937622 1.5962571501731873 3.3456860184669495
13: 1.7145826816558838 1.5512292981147766 3.2658119797706604
14: 1.6647958159446716 1.4986023008823395 3.163398116827011
15: 1.6490156650543213 1.4572288393974304 3.1062445044517517
16: 1.6599730253219604 1.413318395614624 3.0732

135: 1.1731794774532318 0.38292646408081055 1.5561059415340424
136: 0.983438104391098 0.38822487741708755 1.3716629818081856
137: 1.0711035430431366 0.3839276507496834 1.45503119379282
138: 1.1675423979759216 0.3739970847964287 1.5415394827723503
139: 1.1565122604370117 0.3688741996884346 1.5253864601254463
140: 0.9656000137329102 0.379307396709919 1.3449074104428291
141: 1.1612784564495087 0.36598850786685944 1.527266964316368
142: 1.000545710325241 0.39657243341207504 1.3971181437373161
143: 1.052235797047615 0.38662751019001007 1.4388633072376251
144: 0.8616255819797516 0.386377714574337 1.2480032965540886
145: 1.147466242313385 0.38262179493904114 1.5300880372524261
146: 0.8613082766532898 0.3802953436970711 1.2416036203503609
147: 1.058737576007843 0.3911030516028404 1.4498406276106834
148: 0.9233545064926147 0.36811240017414093 1.2914669066667557
149: 0.9115927517414093 0.36761223524808884 1.2792049869894981
150: 1.0429856926202774 0.3759893774986267 1.4189750701189041
151: 1.113

268: 0.7914650067687035 0.3606719970703125 1.152137003839016
269: 1.4109699130058289 0.35330506414175034 1.7642749771475792
270: 1.2092819213867188 0.3511219695210457 1.5604038909077644
271: 1.290676236152649 0.3462914749979973 1.6369677111506462
272: 1.1463123857975006 0.3525935113430023 1.498905897140503
273: 0.8030145764350891 0.35417303442955017 1.1571876108646393
274: 1.1572430431842804 0.3509327173233032 1.5081757605075836
275: 1.1538161039352417 0.35865335166454315 1.5124694555997849
276: 1.2803171277046204 0.3520459905266762 1.6323631182312965
277: 1.0657977163791656 0.3591430336236954 1.424940750002861
278: 0.7578674107789993 0.35172753036022186 1.1095949411392212
279: 1.0618066191673279 0.35823874920606613 1.420045368373394
280: 1.0116283148527145 0.36339331418275833 1.3750216290354729
281: 1.1188707053661346 0.34869492799043655 1.4675656333565712
282: 1.3274042308330536 0.3653407692909241 1.6927450001239777
283: 0.597285270690918 0.3490118384361267 0.9462971091270447
284: 1.

401: 1.3364212810993195 0.33929334580898285 1.6757146269083023
402: 0.7706766426563263 0.34435419738292694 1.1150308400392532
403: 0.8758959323167801 0.3376108929514885 1.2135068252682686
404: 1.1580677032470703 0.3497527688741684 1.5078204721212387
405: 0.9906221032142639 0.35273291170597076 1.3433550149202347
406: 1.2446665167808533 0.3383788764476776 1.5830453932285309
407: 0.7565347105264664 0.3519188463687897 1.108453556895256
408: 0.6316870227456093 0.33651791512966156 0.9682049378752708
409: 1.0366118252277374 0.32706601172685623 1.3636778369545937
410: 1.0655322670936584 0.3363558501005173 1.4018881171941757
411: 0.8587306141853333 0.333199217915535 1.1919298321008682
412: 1.110163301229477 0.32859931886196136 1.4387626200914383
413: 0.9911385774612427 0.32602357119321823 1.317162148654461
414: 0.8448314815759659 0.34507282078266144 1.1899043023586273
415: 1.271856665611267 0.34097060561180115 1.6128272712230682
416: 1.1893716156482697 0.33050937950611115 1.5198809951543808
417

In [8]:
# Write the trained weights to disk, then rebuild the model from that file so the
# evaluation cells run on weights restored from the checkpoint.
# NOTE(review): `optimizer` still references the previous model's parameters after this cell.
checkpoint_path = './mixup_model_pytorch_breast_augment'
torch.save(model.state_dict(), checkpoint_path)
model = fc_model()
model.load_state_dict(torch.load(checkpoint_path))

<All keys matched successfully>

In [9]:
# Accuracy of the model on the held-out set (validation + test splits).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to('cuda')
        labels = labels.float().reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logit at 0 (equivalently sigmoid >= 0.5) so every sample gets a hard
        # 0/1 prediction. (sign(x) + 1) / 2 yields 0.5 for a logit of exactly 0, which can
        # never equal a label and was silently counted as wrong.
        predicts = (outputs >= 0).to(labels.dtype)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9649122807017544


In [10]:
# Accuracy of the model on the training set (compare with the held-out accuracy to gauge overfitting).
# NOTE(review): train_loader is shuffled, so iterating it here advances torch's RNG; accuracy is unaffected.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        labels = labels.float().reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logit at 0 (equivalently sigmoid >= 0.5) so every sample gets a hard
        # 0/1 prediction. (sign(x) + 1) / 2 yields 0.5 for a logit of exactly 0, which can
        # never equal a label and was silently counted as wrong.
        predicts = (outputs >= 0).to(labels.dtype)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9824046920821115
