In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from load_data import load_skl_data
import numpy as np

In [2]:
# ---------------------------------------------------------------------------
# Configuration and Hyperparameters
# ---------------------------------------------------------------------------

# Make CUDA float tensors the default tensor type (default all in GPU).
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Reproducibility
random_seed = 0

# Optimisation
epochs = 500
batch_size = 128
step_size = 2.5e-3  # used as the SGD learning rate
L2_decay = 1e-4     # used as the SGD weight_decay

# Mixup
alpha = 1.0                # concentration of the Beta(alpha, alpha) mixing distribution
perturb_loss_weight = 0.9  # weight of the mixup loss; the clean-batch loss gets 1 - this

# Kept as the last expression so the cell still displays the generator object.
torch.manual_seed(random_seed)

<torch._C.Generator at 0x2722c575470>

In [3]:
# Load the predefined splits and merge the validation split into the test split.
train_data, train_labels, val_data, val_labels, test_data, test_labels = load_skl_data('breast_cancer')
test_data = np.vstack((val_data, test_data))
test_labels = np.hstack((val_labels, test_labels))

# Convert to tensors; features are cast to float32, labels keep their numpy dtype.
train_data = torch.from_numpy(train_data).type(torch.FloatTensor)
train_labels = torch.from_numpy(train_labels)
test_data = torch.from_numpy(test_data).type(torch.FloatTensor)
test_labels = torch.from_numpy(test_labels)

# Standardise every feature with statistics computed on the training split only,
# and apply the same statistics to the test split.
train_mean = torch.mean(train_data, 0)
train_std = torch.std(train_data, 0)
# A constant feature has zero standard deviation; dividing by it would produce
# inf/NaN. Use 1 for those columns so the centred feature simply stays at 0.
train_std = torch.where(train_std == 0, torch.ones_like(train_std), train_std)
train_data = (train_data - train_mean) / train_std
test_data = (test_data - train_mean) / train_std

train_set = torch.utils.data.TensorDataset(train_data, train_labels)
test_set = torch.utils.data.TensorDataset(test_data, test_labels)
# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

In [4]:
class fc_model(nn.Module):
    """Fully connected binary classifier that returns one raw logit per sample.

    Three tanh-activated hidden layers (128, 64 and 32 units) feed a single
    linear output unit. No sigmoid is applied to the output.

    Args:
        in_features: Number of input features per sample. Defaults to 30,
            the width that was previously hard-coded.
    """

    def __init__(self, in_features=30):
        super().__init__()
        # Layer names and creation order are unchanged, so saved state dicts
        # and seeded initialisations remain compatible.
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, inputs):
        """Map a ``(..., in_features)`` tensor to ``(..., 1)`` raw logits."""
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        fc1_out = torch.tanh(self.fc1(inputs))
        fc2_out = torch.tanh(self.fc2(fc1_out))
        fc3_out = torch.tanh(self.fc3(fc2_out))
        # The last layer stays linear: the output is a logit, not a probability.
        fc4_out = self.fc4(fc3_out)
        return fc4_out

# Build the network and print its layer summary.
model = fc_model()
print(model)

# Binary cross-entropy computed directly on the raw logits the model returns.
criterion = nn.BCEWithLogitsLoss()

# SGD with momentum; the L2 penalty is applied through weight_decay.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

fc_model(
  (fc1): Linear(in_features=30, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=1, bias=True)
)


In [5]:
def mixup_breast(inputs, labels, alpha):
    """Build a mixup batch by blending each sample with a randomly chosen partner.

    A single coefficient ``lmbda`` is drawn from ``Beta(alpha, alpha)`` for the
    whole batch and every input row is replaced by
    ``lmbda * inputs + (1 - lmbda) * inputs[idx]``, where ``idx`` is a random
    permutation of the batch.

    Args:
        inputs: Batch of samples, indexed along dimension 0.
        labels: Labels aligned with ``inputs`` along dimension 0.
        alpha: Concentration of the Beta distribution. A value ``<= 0``
            disables mixing (``lmbda`` is fixed to 1) instead of raising.

    Returns:
        Tuple ``(mixup_inputs, labels, labels_b, lmbda)``: the blended inputs,
        the original labels, the labels of the permuted partners and the
        scalar mixing coefficient as a 0-dim tensor.
    """
    # Create the random tensors on the device of `inputs` rather than on
    # whatever the global default tensor type happens to be.
    device = inputs.device
    if alpha > 0:
        concentration = torch.tensor(float(alpha), device=device)
        lmbda = torch.distributions.beta.Beta(concentration, concentration).sample()
    else:
        # Beta(alpha, alpha) is undefined for alpha <= 0; fall back to "no mixup".
        lmbda = torch.tensor(1.0, device=device)
    batch_size = labels.size(0)
    idx = torch.randperm(batch_size, device=device)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Blend the loss against both label sets with the mixup coefficient.

    Args:
        criterion: Callable ``criterion(predicts, targets)`` returning a loss.
        predicts: Model outputs for the mixed inputs.
        labels: Labels of the original samples (weighted by ``lmbda``).
        labels_b: Labels of the permuted partners (weighted by ``1 - lmbda``).
        lmbda: Mixing coefficient.

    Returns:
        ``lmbda * criterion(predicts, labels) + (1 - lmbda) * criterion(predicts, labels_b)``.
    """
    loss_a = criterion(predicts, labels)
    loss_b = criterion(predicts, labels_b)
    return lmbda * loss_a + (1 - lmbda) * loss_b

In [7]:
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# Each step optimises a weighted sum of the mixup loss and the loss on the
# unmodified batch; the per-epoch sums of both losses are printed.
model.train()
for epoch in range(epochs):
    # Running sums over the batches of the current epoch.
    epoch_loss = 0.
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.

    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        # Targets become a float column so they line up with the (N, 1) model output.
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_breast(inputs, labels, alpha)

        optimizer.zero_grad()

        # Forward pass on the mixed batch.
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)

        # Forward pass on the unmodified batch.
        outputs_org = model(inputs)
        loss_org = criterion(outputs_org, labels)

        # Only this weighted combination is back-propagated.
        weighted_total_loss = perturb_loss_weight * mixup_loss + (1 - perturb_loss_weight) * loss_org

        # Logging uses the raw (unweighted) loss values.
        mixup_value = mixup_loss.item()
        org_value = loss_org.item()
        epoch_mixup_loss += mixup_value
        epoch_org_loss += org_value
        epoch_loss += (mixup_value + org_value)

        weighted_total_loss.backward()
        optimizer.step()

    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))



0: 2.053349494934082 2.0472776293754578 4.10062712430954
1: 2.039411783218384 2.034392476081848 4.073804259300232
2: 2.0297462940216064 2.008290708065033 4.038037002086639
3: 1.9940574765205383 1.9798986315727234 3.9739561080932617
4: 1.9818975329399109 1.9455218315124512 3.927419364452362
5: 1.9645658731460571 1.9100894331932068 3.874655306339264
6: 1.9411677718162537 1.8703407645225525 3.811508536338806
7: 1.8923071026802063 1.8320188522338867 3.724325954914093
8: 1.8734227418899536 1.7913270592689514 3.664749801158905
9: 1.8455598950386047 1.7497926354408264 3.595352530479431
10: 1.8188753128051758 1.7092517614364624 3.528127074241638
11: 1.7944123148918152 1.668013572692871 3.4624258875846863
12: 1.7922200560569763 1.62907212972641 3.4212921857833862
13: 1.7291628122329712 1.5797369480133057 3.308899760246277
14: 1.699225664138794 1.542140543460846 3.24136620759964
15: 1.664763629436493 1.4962786138057709 3.161042243242264
16: 1.722664713859558 1.4509125649929047 3.1735772788524628

136: 1.0965550541877747 0.4458543509244919 1.5424094051122665
137: 1.2309154570102692 0.4305751696228981 1.6614906266331673
138: 1.2277145683765411 0.43373341858386993 1.661447986960411
139: 1.038579523563385 0.45154881477355957 1.4901283383369446
140: 0.9527525305747986 0.4442572742700577 1.3970098048448563
141: 1.2122655510902405 0.4448719322681427 1.6571374833583832
142: 1.0037256181240082 0.4465469568967819 1.45027257502079
143: 0.9841763228178024 0.43720903992652893 1.4213853627443314
144: 0.9405811429023743 0.43769530951976776 1.378276452422142
145: 1.0787317156791687 0.4340311884880066 1.5127629041671753
146: 0.9132471531629562 0.43227531015872955 1.3455224633216858
147: 1.0040011405944824 0.4197639971971512 1.4237651377916336
148: 0.7981731295585632 0.43827682733535767 1.236449956893921
149: 0.9858585149049759 0.4210282862186432 1.406886801123619
150: 1.06795035302639 0.43625734746456146 1.5042077004909515
151: 1.1046622693538666 0.42531512677669525 1.5299773961305618
152: 1.24

269: 1.4650956690311432 0.3997972011566162 1.8648928701877594
270: 1.3194960057735443 0.39298345148563385 1.7124794572591782
271: 1.3004919588565826 0.3965262100100517 1.6970181688666344
272: 1.1156629920005798 0.4085252583026886 1.5241882503032684
273: 0.8888697028160095 0.4045463725924492 1.2934160754084587
274: 1.2040630280971527 0.41719387471675873 1.6212569028139114
275: 1.0382597893476486 0.4119175225496292 1.4501773118972778
276: 1.2126422226428986 0.4257049188017845 1.638347141444683
277: 1.075833797454834 0.4092900976538658 1.4851238951086998
278: 0.8209586590528488 0.43570995330810547 1.2566686123609543
279: 1.14056096971035 0.41843584179878235 1.5589968115091324
280: 1.0090510249137878 0.414963461458683 1.4240144863724709
281: 1.2979839146137238 0.40655558556318283 1.7045395001769066
282: 1.1631640791893005 0.39909446239471436 1.562258541584015
283: 0.5917696356773376 0.4084503725171089 1.0002200081944466
284: 1.129013478755951 0.4143582731485367 1.5433717519044876
285: 1.38

402: 0.8101433664560318 0.39615505933761597 1.2062984257936478
403: 0.9205784350633621 0.38998159766197205 1.3105600327253342
404: 1.1765262186527252 0.3934961259365082 1.5700223445892334
405: 1.0019189715385437 0.3935100734233856 1.3954290449619293
406: 1.2986542880535126 0.3821342885494232 1.6807885766029358
407: 0.834783524274826 0.3900289237499237 1.2248124480247498
408: 0.7010707557201385 0.3916151449084282 1.0926859006285667
409: 1.0025922358036041 0.3852691203355789 1.387861356139183
410: 1.1403018534183502 0.3967040777206421 1.5370059311389923
411: 0.8737642765045166 0.3832848221063614 1.257049098610878
412: 1.244022697210312 0.3929932042956352 1.6370159015059471
413: 1.0098336338996887 0.37841135263442993 1.3882449865341187
414: 0.8511381298303604 0.3696480691432953 1.2207861989736557
415: 1.2620094418525696 0.3809656351804733 1.642975077033043
416: 1.2011562287807465 0.38902053236961365 1.59017676115036
417: 1.212352991104126 0.37795159220695496 1.590304583311081
418: 1.05572

In [8]:
# Write the trained weights to disk, then restore them into a freshly
# constructed network that replaces `model`.
checkpoint_path = './mixup_model_pytorch_breast_augment'
torch.save(model.state_dict(), checkpoint_path)

model = fc_model()
# Kept as the last expression so the cell still displays the key-matching result.
model.load_state_dict(torch.load(checkpoint_path))

<All keys matched successfully>

In [9]:
# Accuracy on the test loader (the merged validation + test split).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to('cuda')
        # Targets become a float column so they line up with the (N, 1) model output.
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logit at 0: positive -> class 1, otherwise class 0.
        # The former (sign(x) + 1) / 2 produced 0.5 for a logit of exactly 0,
        # a value that can never equal a 0/1 label.
        predicts = (outputs > 0).float()
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9780701754385965


In [10]:
# Accuracy on the training loader, for comparison with the test accuracy.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        # Targets become a float column so they line up with the (N, 1) model output.
        labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
        outputs = model(inputs)
        # Threshold the logit at 0: positive -> class 1, otherwise class 0.
        # The former (sign(x) + 1) / 2 produced 0.5 for a logit of exactly 0,
        # a value that can never equal a 0/1 label.
        predicts = (outputs > 0).float()
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9824046920821115
