In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from load_data import load_skl_data
import numpy as np

In [2]:
"""
Configuration and Hyperparameters
"""
# Make CUDA float tensors the default tensor type, so tensors and module
# parameters created without an explicit device are allocated on the GPU.
# NOTE(review): this presumably requires a CUDA-capable machine, and newer
# PyTorch releases flag this call as deprecated — confirm against the
# installed version before relying on it.
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU

batch_size = 128            # mini-batch size given to both DataLoaders
step_size = 0.0025          # learning rate handed to the SGD optimizer
random_seed = 0             # seed for torch's random number generator (used below)
epochs = 500                # number of passes over train_loader in the training loop
L2_decay = 1e-4             # weight_decay handed to the SGD optimizer
alpha = 1.                  # mixup parameter: lambda is sampled from Beta(alpha, alpha)
perturb_loss_weight = 0.99  # weight of the mixup loss; the un-mixed loss gets 1 - this value

# Seed torch's RNG. As the last expression of the cell, the returned Generator
# object is what the notebook displays as the cell output.
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1c8ea242470>

In [3]:
# Fetch the three splits of the 'breast_cancer' dataset as NumPy arrays.
(train_data, train_labels,
 val_data, val_labels,
 test_data, test_labels) = load_skl_data('breast_cancer')

# No separate validation pass is run in this notebook: the validation rows are
# stacked on top of the test rows and evaluated together.
test_data = np.vstack((val_data, test_data))
test_labels = np.hstack((val_labels, test_labels))

# NumPy -> torch. Features are cast to CPU float32; labels keep their dtype.
train_data, test_data = (
    torch.from_numpy(features).type(torch.FloatTensor)
    for features in (train_data, test_data)
)
train_labels, test_labels = torch.from_numpy(train_labels), torch.from_numpy(test_labels)

# Standardise every feature column with statistics taken from the training
# split only; the same mean/std are then applied to the test split.
train_mean = train_data.mean(dim=0)
train_std = train_data.std(dim=0)
train_data = (train_data - train_mean) / train_std
test_data = (test_data - train_mean) / train_std

# Wrap the tensors as datasets and batch them. Only the training loader shuffles.
train_set = torch.utils.data.TensorDataset(train_data, train_labels)
test_set = torch.utils.data.TensorDataset(test_data, test_labels)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)

In [4]:
class fc_model(nn.Module):
    """Fully connected binary classifier that outputs one raw logit per sample.

    Layer widths are ``in_features -> 128 -> 64 -> 32 -> 1``. A tanh follows
    each of the first three linear layers; the last layer is left linear, so
    the output is a logit (the notebook pairs it with ``BCEWithLogitsLoss``).

    Args:
        in_features: number of input features per sample. Defaults to 30, the
            width this notebook was originally hard-coded to.
    """

    def __init__(self, in_features=30):
        super(fc_model, self).__init__()
        # Attribute names fc1..fc4 are the state_dict keys; keep them stable so
        # previously saved checkpoints still load.
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, inputs):
        """Map a ``(batch, in_features)`` tensor to ``(batch, 1)`` logits."""
        hidden = inputs
        # torch.tanh is used instead of F.tanh, which a number of PyTorch
        # releases flag with a deprecation warning; the result is identical.
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.tanh(layer(hidden))
        return self.fc4(hidden)

# Instantiate the network and echo its layer layout.
model = fc_model()
print(model)

# The network emits raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

# SGD with momentum; learning rate and L2 penalty come from the config cell.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

fc_model(
  (fc1): Linear(in_features=30, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=1, bias=True)
)


In [5]:
def mixup_breast(inputs, labels, alpha):
    """Build a mixup batch by blending every sample with a random partner.

    One mixing coefficient ``lmbda`` is drawn from ``Beta(alpha, alpha)`` for
    the whole batch, and each row of ``inputs`` is blended with the row at a
    randomly permuted position: ``lmbda * inputs + (1 - lmbda) * inputs[idx]``.

    Args:
        inputs: batch of samples, indexed along dimension 0.
        labels: labels for ``inputs``; its first dimension gives the batch size.
        alpha: Beta concentration parameter. Values ``<= 0`` are not valid Beta
            parameters and are treated as "no mixing" (``lmbda == 1``).

    Returns:
        Tuple ``(mixup_inputs, labels, labels_b, lmbda)`` where ``labels`` is
        returned unchanged, ``labels_b`` holds the partners' labels
        (``labels[idx]``) and ``lmbda`` is a 0-dim tensor.
    """
    if alpha > 0:
        lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    else:
        # Degenerate setting: keep the batch as is instead of raising / NaN.
        lmbda = torch.ones(())
    # Keep the coefficient and the permutation on the same device as the data
    # so the function does not rely on the process-wide default tensor type.
    lmbda = lmbda.to(inputs.device)
    num_samples = labels.size(0)  # not named batch_size: avoids shadowing the global
    idx = torch.randperm(num_samples, device=inputs.device)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Blend the loss against both label sets of a mixup batch.

    Args:
        criterion: callable ``criterion(predicts, targets)`` returning a loss.
        predicts: model outputs for the mixed inputs.
        labels: labels of the original samples.
        labels_b: labels of the partner samples they were mixed with.
        lmbda: mixing coefficient applied to the original-label loss; the
            partner-label loss is weighted by ``1 - lmbda``.

    Returns:
        ``lmbda * criterion(predicts, labels)
        + (1 - lmbda) * criterion(predicts, labels_b)``.
    """
    loss_a = criterion(predicts, labels)
    loss_b = criterion(predicts, labels_b)
    return lmbda * loss_a + (1 - lmbda) * loss_b

In [7]:
"""
Training
"""
model.train()
for epoch in range(epochs):
    # Per-epoch running sums over batches (sums, not means).
    epoch_loss = 0.
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.

    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        # Float column vector so the labels line up with the (batch, 1) logits.
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_breast(inputs, labels, alpha)

        optimizer.zero_grad()

        # Two forward passes per step: one on the mixed batch, one on the
        # untouched batch. The optimised objective is their weighted sum.
        mixup_loss = mixup_criterion(criterion, model(mixup_inputs), labels, labels_b, lmbda)
        loss_org = criterion(model(inputs), labels)
        weighted_total_loss = mixup_loss * perturb_loss_weight + loss_org * (1 - perturb_loss_weight)

        weighted_total_loss.backward()
        optimizer.step()

        # Bookkeeping only. Note that epoch_loss is the plain (unweighted) sum
        # of the two losses, not the weighted objective that was optimised.
        batch_mixup = mixup_loss.item()
        batch_org = loss_org.item()
        epoch_mixup_loss += batch_mixup
        epoch_org_loss += batch_org
        epoch_loss += (batch_mixup + batch_org)

    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))



0: 2.0471267104148865 2.0428451895713806 4.089971899986267
1: 2.0371421575546265 2.0297855138778687 4.066927671432495
2: 2.029387652873993 2.007843255996704 4.037230908870697
3: 1.991696834564209 1.9777859449386597 3.9694827795028687
4: 1.981281042098999 1.9439350962638855 3.9252161383628845
5: 1.9566518664360046 1.9071394801139832 3.863791346549988
6: 1.9371070265769958 1.8711283802986145 3.8082354068756104
7: 1.875202476978302 1.8321066498756409 3.707309126853943
8: 1.8867568373680115 1.7899393439292908 3.6766961812973022
9: 1.8377708792686462 1.754818618297577 3.592589497566223
10: 1.820988953113556 1.7117766737937927 3.5327656269073486
11: 1.805282473564148 1.6798624992370605 3.4851449728012085
12: 1.7907819151878357 1.6327078342437744 3.42348974943161
13: 1.7286030650138855 1.592519760131836 3.3211228251457214
14: 1.6891078352928162 1.5448973178863525 3.2340051531791687
15: 1.662678301334381 1.5058940052986145 3.1685723066329956
16: 1.7190831303596497 1.4593437612056732 3.17842689

135: 1.1691966950893402 0.4504515677690506 1.6196482628583908
136: 1.061656802892685 0.4609376788139343 1.5225944817066193
137: 1.115433931350708 0.455592542886734 1.571026474237442
138: 1.1799685060977936 0.4559183567762375 1.635886862874031
139: 1.1567333489656448 0.4459766298532486 1.6027099788188934
140: 0.92243292927742 0.45215701311826706 1.374589942395687
141: 1.253379374742508 0.44284819066524506 1.696227565407753
142: 0.9490256905555725 0.4504431337118149 1.3994688242673874
143: 1.0474419742822647 0.4498634785413742 1.497305452823639
144: 0.892090380191803 0.451057493686676 1.343147873878479
145: 1.1104298532009125 0.44086863100528717 1.5512984842061996
146: 0.8890916109085083 0.4462655633687973 1.3353571742773056
147: 1.0745422542095184 0.43841317296028137 1.5129554271697998
148: 0.8894397467374802 0.4490896314382553 1.3385293781757355
149: 0.9183914065361023 0.4406631737947464 1.3590545803308487
150: 1.0149857699871063 0.4332006275653839 1.4481863975524902
151: 1.05310982465

268: 0.8088594675064087 0.40758616477251053 1.2164456322789192
269: 1.329753339290619 0.412514328956604 1.742267668247223
270: 1.2361213862895966 0.4073999673128128 1.6435213536024094
271: 1.181550532579422 0.4073808342218399 1.588931366801262
272: 1.1211144030094147 0.4407980740070343 1.561912477016449
273: 0.865449845790863 0.41436219960451126 1.2798120453953743
274: 1.1384339332580566 0.4105965867638588 1.5490305200219154
275: 1.0706208646297455 0.4173689931631088 1.4879898577928543
276: 1.1575074195861816 0.41772758960723877 1.5752350091934204
277: 0.9783556163311005 0.4102698937058449 1.3886255100369453
278: 0.8305497467517853 0.4095241278409958 1.240073874592781
279: 1.1425999999046326 0.4260272830724716 1.5686272829771042
280: 0.9896824210882187 0.41238920390605927 1.402071624994278
281: 1.2387690544128418 0.4188104271888733 1.657579481601715
282: 1.2593353390693665 0.4247693866491318 1.6841047257184982
283: 0.6057218462228775 0.42304685711860657 1.028768703341484
284: 0.9892513

401: 1.168251782655716 0.4032699167728424 1.5715216994285583
402: 0.819391280412674 0.39080530405044556 1.2101965844631195
403: 0.9196958839893341 0.40926551818847656 1.3289614021778107
404: 1.2069666385650635 0.40046264231204987 1.6074292808771133
405: 0.9887629449367523 0.3957927003502846 1.384555645287037
406: 1.1502066850662231 0.3939432054758072 1.5441498905420303
407: 0.7459203153848648 0.3984140232205391 1.144334338605404
408: 0.6840984523296356 0.4044248014688492 1.0885232537984848
409: 1.1311912834644318 0.4088060110807419 1.5399972945451736
410: 1.1577954292297363 0.4111301302909851 1.5689255595207214
411: 0.9066527187824249 0.388245552778244 1.294898271560669
412: 1.1550244390964508 0.3947813957929611 1.549805834889412
413: 0.9789033830165863 0.4098517671227455 1.3887551501393318
414: 0.8157884180545807 0.38192256540060043 1.1977109834551811
415: 1.1546483039855957 0.39715710282325745 1.5518054068088531
416: 1.1250407993793488 0.40432850271463394 1.5293693020939827
417: 1.19

In [8]:
# Write the trained weights to disk, then rebuild a fresh fc_model and load the
# weights back from that file, so later cells run on the reloaded checkpoint.
_checkpoint_path = './mixup_model_pytorch_breast_augment'
torch.save(model.state_dict(), _checkpoint_path)
model = fc_model()
# Last expression of the cell: the notebook displays the key-matching report.
model.load_state_dict(torch.load(_checkpoint_path))

<All keys matched successfully>

In [9]:
# Accuracy of the reloaded model on test_loader (validation + test rows).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to('cuda')
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logit at 0 (sigmoid(logit) > 0.5 -> class 1). The former
        # (sign(x) + 1) / 2 mapping produced 0.5 for a logit of exactly zero,
        # which can never equal a 0/1 label; ties now go to class 0.
        predicts = (outputs > 0).float()
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9780701754385965


In [10]:
# Accuracy of the reloaded model on the training data served by train_loader.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logit at 0 (sigmoid(logit) > 0.5 -> class 1). The former
        # (sign(x) + 1) / 2 mapping produced 0.5 for a logit of exactly zero,
        # which can never equal a 0/1 label; ties now go to class 0.
        predicts = (outputs > 0).float()
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9853372434017595
