In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from load_data import load_skl_data
import numpy as np

In [2]:
"""
Configuration and Hyperparameters
"""
# Make newly created float tensors (model parameters, torch.randperm in mixup, ...) live on the GPU.
# NOTE(review): `torch.set_default_tensor_type` is deprecated in recent PyTorch releases in favour of
# `torch.set_default_dtype` + `torch.set_default_device`; cells below (e.g. `fc_model()` without
# `.to('cuda')`) depend on this CUDA default, so migrate them together.
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU

batch_size = 128            # mini-batch size for the train/test DataLoaders
step_size = 0.0025          # SGD learning rate
random_seed = 0
epochs = 500
L2_decay = 1e-4             # SGD weight_decay (L2 penalty)
alpha = 1.                  # Beta(alpha, alpha) concentration for the mixup coefficient
perturb_loss_weight = 0.25  # weight of the mixup loss; the clean-batch loss gets 1 - this

# Seed every RNG the notebook may touch so Restart & Run All is reproducible.
# torch.manual_seed also seeds the CUDA generators; NumPy is seeded because the
# data-loading helper may draw from it (presumably for its splits — confirm).
torch.manual_seed(random_seed)
np.random.seed(random_seed)

<torch._C.Generator at 0x2282aaf4470>

In [3]:
# Project helper returning (train, val, test) feature/label splits for the dataset.
# The validation split is folded into the test split, so evaluation uses val + test together.
train_data, train_labels, val_data, val_labels, test_data, test_labels = load_skl_data('breast_cancer')
test_data = np.vstack((val_data, test_data))
test_labels = np.hstack((val_labels, test_labels))

# Convert to CPU tensors: features become float32, labels keep their NumPy dtype.
train_data = torch.from_numpy(train_data).type(torch.FloatTensor)
train_labels = torch.from_numpy(train_labels)
test_data = torch.from_numpy(test_data).type(torch.FloatTensor)
test_labels = torch.from_numpy(test_labels)

# Standardize with statistics from the training split only (no test leakage).
# A constant feature has std 0 and would turn into NaN/inf; give it a scale of 1
# instead (same convention as sklearn's StandardScaler). Non-degenerate features are unaffected.
train_mean = torch.mean(train_data, 0)
train_std = torch.std(train_data, 0)
train_std = torch.where(train_std > 0, train_std, torch.ones_like(train_std))
train_data = (train_data - train_mean) / train_std
test_data = (test_data - train_mean) / train_std

train_set = torch.utils.data.TensorDataset(train_data, train_labels)
test_set = torch.utils.data.TensorDataset(test_data, test_labels)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

In [4]:
class fc_model(nn.Module):
    """Four-layer fully connected binary classifier.

    Architecture: in_features -> 128 -> 64 -> 32 -> 1, with tanh after each hidden layer.
    The output is a raw logit (no sigmoid), intended for ``BCEWithLogitsLoss``.

    Args:
        in_features: number of input features per sample. Defaults to 30, the
            width this notebook was written for.
    """

    def __init__(self, in_features=30):
        super(fc_model, self).__init__()
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, inputs):
        """Map a (..., in_features) tensor to (..., 1) logits."""
        # torch.tanh is the supported spelling; F.tanh is deprecated and numerically identical.
        fc1_out = torch.tanh(self.fc1(inputs))
        fc2_out = torch.tanh(self.fc2(fc1_out))
        fc3_out = torch.tanh(self.fc3(fc2_out))
        fc4_out = self.fc4(fc3_out)
        return fc4_out

# Build the network (its parameters follow the default tensor type set in the config cell)
# and show its layer summary.
model = fc_model()
print(model)

# The network emits raw logits, so use the loss that applies the sigmoid internally.
criterion = torch.nn.BCEWithLogitsLoss()

# Momentum SGD; weight_decay carries the L2 penalty from the config cell.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

fc_model(
  (fc1): Linear(in_features=30, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=1, bias=True)
)


In [5]:
def mixup_breast(inputs, labels, alpha):
    """Build a mixup batch by blending each sample with a randomly permuted partner.

    One coefficient ``lmbda ~ Beta(alpha, alpha)`` is drawn per call and shared by the
    whole batch: ``mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]``.

    Args:
        inputs: batch of features; the first dimension indexes samples.
        labels: targets aligned with ``inputs`` along the first dimension.
        alpha: Beta concentration. Beta is only defined for positive values, so
            ``alpha <= 0`` disables mixing (``lmbda = 1``), the usual mixup convention.

    Returns:
        ``(mixup_inputs, labels, labels_b, lmbda)``: ``labels`` is returned unchanged,
        ``labels_b = labels[idx]`` are the partners' targets, and ``lmbda`` is a 0-dim
        tensor. These are the arguments ``mixup_criterion`` expects.
    """
    if alpha > 0:
        lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    else:
        lmbda = inputs.new_tensor(1.)
    batch_size = labels.size(0)
    # Create the permutation on the batch's own device instead of relying on a global
    # default tensor type to put it there.
    idx = torch.randperm(batch_size, device=inputs.device)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Mixup loss: the criterion against both target sets, blended by ``lmbda``.

    Returns ``lmbda * criterion(predicts, labels) + (1 - lmbda) * criterion(predicts, labels_b)``.
    """
    loss_on_own = criterion(predicts, labels)
    loss_on_partner = criterion(predicts, labels_b)
    return lmbda * loss_on_own + (1 - lmbda) * loss_on_partner

In [7]:
"""
Training
"""
# Each step minimizes a convex blend of two losses:
#   * the mixup loss on interpolated inputs (weight `perturb_loss_weight`), and
#   * the plain BCE loss on the untouched batch (weight `1 - perturb_loss_weight`).
# Each epoch prints "<epoch>: <sum mixup loss> <sum clean loss> <sum of both>" over its batches.
# NOTE(review): the third number is the unweighted sum, not the weighted objective being optimized.
model.train()
for epoch in range(epochs):
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.
    epoch_loss = 0.
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        # Float targets shaped (N, 1) to match the single-logit output.
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_breast(inputs, labels, alpha)

        optimizer.zero_grad()
        mixup_loss = mixup_criterion(criterion, model(mixup_inputs), labels, labels_b, lmbda)
        loss_org = criterion(model(inputs), labels)
        weighted_total_loss = mixup_loss * perturb_loss_weight + loss_org * (1 - perturb_loss_weight)

        mixup_value = mixup_loss.item()
        org_value = loss_org.item()
        epoch_mixup_loss += mixup_value
        epoch_org_loss += org_value
        epoch_loss += (mixup_value + org_value)

        weighted_total_loss.backward()
        optimizer.step()
    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))



0: 2.0460712909698486 2.041869282722473 4.087940573692322
1: 2.0322369933128357 2.0255582332611084 4.057795226573944
2: 2.0135964155197144 1.996120035648346 4.00971645116806
3: 1.9792700409889221 1.9637529253959656 3.9430229663848877
4: 1.9631226658821106 1.9222071766853333 3.885329842567444
5: 1.9414384961128235 1.876710593700409 3.8181490898132324
6: 1.9124352931976318 1.832288920879364 3.744724214076996
7: 1.8272262215614319 1.783295214176178 3.61052143573761
8: 1.8360344767570496 1.7356398701667786 3.571674346923828
9: 1.8000473976135254 1.6867690086364746 3.48681640625
10: 1.7553472518920898 1.632099449634552 3.387446701526642
11: 1.7488099336624146 1.576772153377533 3.3255820870399475
12: 1.702246367931366 1.5210442543029785 3.2232906222343445
13: 1.6262482404708862 1.464089572429657 3.090337812900543
14: 1.5957139730453491 1.4073049426078796 3.0030189156532288
15: 1.5762674808502197 1.3528970777988434 2.929164558649063
16: 1.6035785675048828 1.2916239500045776 2.8952025175094604

135: 1.2843068689107895 0.20351479202508926 1.4878216609358788
136: 1.0661578178405762 0.2009727582335472 1.2671305760741234
137: 1.269296258687973 0.2100590467453003 1.4793553054332733
138: 1.5153158903121948 0.2026846706867218 1.7180005609989166
139: 1.147338941693306 0.19724802300333977 1.3445869646966457
140: 1.0199501514434814 0.194389458745718 1.2143396101891994
141: 1.3764547407627106 0.1898689605295658 1.5663237012922764
142: 1.0546560883522034 0.20456792041659355 1.259224008768797
143: 1.0961093455553055 0.19327741116285324 1.2893867567181587
144: 1.0064391642808914 0.19886640086770058 1.205305565148592
145: 1.1860433220863342 0.19738641753792763 1.3834297396242619
146: 0.8294548839330673 0.1971050687134266 1.026559952646494
147: 1.185952365398407 0.19522829353809357 1.3811806589365005
148: 0.977034717798233 0.19779562950134277 1.1748303472995758
149: 0.9877348020672798 0.20024849474430084 1.1879832968115807
150: 1.0896282941102982 0.2049873173236847 1.2946156114339828
151: 1.

267: 0.838202491402626 0.16068492457270622 0.9988874159753323
268: 0.84328593313694 0.1589210033416748 1.0022069364786148
269: 1.776093304157257 0.1737479455769062 1.9498412497341633
270: 1.4259012937545776 0.15947891399264336 1.585380207747221
271: 1.5819818079471588 0.15682969614863396 1.7388115040957928
272: 1.4211529791355133 0.17233704775571823 1.5934900268912315
273: 0.9310109615325928 0.16330474615097046 1.0943157076835632
274: 1.4732015132904053 0.16380691900849342 1.6370084322988987
275: 1.2260515838861465 0.17067956551909447 1.396731149405241
276: 1.6317744255065918 0.1715935841202736 1.8033680096268654
277: 1.2704474925994873 0.16454028710722923 1.4349877797067165
278: 0.7791505306959152 0.16019291058182716 0.9393434412777424
279: 1.2321017384529114 0.16249947622418404 1.3946012146770954
280: 1.1212950348854065 0.17511854320764542 1.296413578093052
281: 1.6048288345336914 0.16107231006026268 1.765901144593954
282: 1.5346029698848724 0.16668039560317993 1.7012833654880524
283

399: 1.1223298907279968 0.1552724353969097 1.2776023261249065
400: 1.3010914623737335 0.15323522314429283 1.4543266855180264
401: 1.5976471602916718 0.1578531675040722 1.755500327795744
402: 0.868842139840126 0.1511685997247696 1.0200107395648956
403: 0.9947040751576424 0.16044582054018974 1.155149895697832
404: 1.4155129492282867 0.15332595258951187 1.5688389018177986
405: 1.1673442721366882 0.15938902646303177 1.32673329859972
406: 1.5163486003875732 0.15242059901356697 1.6687691994011402
407: 0.7805293649435043 0.15777451917529106 0.9383038841187954
408: 0.5069420300424099 0.15758949518203735 0.6645315252244473
409: 1.2388770580291748 0.15699118748307228 1.395868245512247
410: 1.3252348005771637 0.1580769382417202 1.483311738818884
411: 0.95259840041399 0.15143398940563202 1.104032389819622
412: 1.3595705926418304 0.15419884398579597 1.5137694366276264
413: 1.114592045545578 0.14769031293690205 1.26228235848248
414: 0.9892724305391312 0.15428083389997482 1.143553264439106
415: 1.646

In [8]:
# Persist the trained weights, then rebuild a fresh network and restore them from disk.
# The last expression shows load_state_dict's key-matching report.
checkpoint_path = './mixup_model_pytorch_breast_augment'
torch.save(model.state_dict(), checkpoint_path)
model = fc_model()
model.load_state_dict(torch.load(checkpoint_path))

<All keys matched successfully>

In [9]:
# Accuracy on the held-out loader (validation and test splits stacked together).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to('cuda')
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logit at 0 (probability 0.5). `(torch.sign(outputs) + 1) / 2`
        # would give 0.5 for a logit of exactly 0, matching neither label.
        predicts = (outputs > 0).float()
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9605263157894737


In [10]:
# Accuracy on the training loader (its shuffling does not affect the count).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logit at 0 (probability 0.5). `(torch.sign(outputs) + 1) / 2`
        # would give 0.5 for a logit of exactly 0, matching neither label.
        predicts = (outputs > 0).float()
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9941348973607038
