In [142]:
!pip install pytorch-tabnet



In [143]:
from pytorch_tabnet.pretraining import TabNetPretrainer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
#df = pd.read_csv("train.csv")
# For details, see https://github.com/dreamquark-ai/tabnet/tree/develop

In [144]:
from sklearn.datasets import fetch_california_housing
# Load the California housing dataset; later cells read `data.data` (features)
# and `data.target` (regression target) from the returned object.
data = fetch_california_housing()

In [145]:
# Append the target as the last column so one 2-D array carries features + target.
# reshape(-1, 1) infers the row count instead of hard-coding the dataset size.
# Note: despite its name, `df` is a NumPy array, not a pandas DataFrame.
df = np.concatenate([data.data, data.target.reshape(-1, 1)], axis = 1)

In [146]:
# Ordered (unshuffled) 80% / 10% / 10% split into train / validation / test rows.
X_train, X_valid, X_test = np.split(df, [int(0.8*len(df)), int(0.9*len(df))])

In [147]:
# Self-supervised TabNet pretrainer, optimised with Adam (lr = 2e-2).
unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax' # alternative: "sparsemax"
)

# X_train / X_valid carry the regression target as their last column, so the
# pretrainer is fitted on features and target together.
# NOTE(review): pretraining_ratio is presumably the fraction of input features
# masked for reconstruction and patience the early-stopping tolerance in
# epochs -- confirm against the pytorch-tabnet documentation.
# NOTE(review): no feature scaling is applied before fitting in this notebook;
# verify whether that is intended.
unsupervised_model.fit(
    X_train=X_train,
    eval_set=[X_valid],
    pretraining_ratio=0.8,
    max_epochs=500,
    patience=35,
)



epoch 0  | loss: 418.27955| val_0_unsup_loss_numpy: 3633.772705078125|  0:00:02s
epoch 1  | loss: 290.70801| val_0_unsup_loss_numpy: 2458.705810546875|  0:00:05s
epoch 2  | loss: 138.6288| val_0_unsup_loss_numpy: 2745.99755859375|  0:00:07s
epoch 3  | loss: 26.27022| val_0_unsup_loss_numpy: 1885.794677734375|  0:00:10s
epoch 4  | loss: 3.8508  | val_0_unsup_loss_numpy: 1768.054931640625|  0:00:13s
epoch 5  | loss: 2.18467 | val_0_unsup_loss_numpy: 987.9741821289062|  0:00:15s
epoch 6  | loss: 2.03046 | val_0_unsup_loss_numpy: 376.2887878417969|  0:00:18s
epoch 7  | loss: 1.62305 | val_0_unsup_loss_numpy: 1765.5413818359375|  0:00:20s
epoch 8  | loss: 1.43801 | val_0_unsup_loss_numpy: 582.4559936523438|  0:00:22s
epoch 9  | loss: 1.30222 | val_0_unsup_loss_numpy: 440.5889892578125|  0:00:24s
epoch 10 | loss: 1.36428 | val_0_unsup_loss_numpy: 377.00537109375|  0:00:26s
epoch 11 | loss: 1.29948 | val_0_unsup_loss_numpy: 446.3124694824219|  0:00:27s
epoch 12 | loss: 1.2297  | val_0_unsup_l



In [148]:
# Pull the pretrained sub-networks out of the pretrainer: the encoder is later
# wrapped by the Discriminator and the decoder is used as the generator.
unsupervised_encoder = unsupervised_model.network.encoder
unsupervised_decoder = unsupervised_model.network.decoder

In [149]:
# Whole training split as a single float32 batch living on the GPU.
batch_data = torch.tensor(X_train, dtype=torch.float32, device="cuda")

In [150]:
def create_noise(col, row, device = "cuda", n_steps = 3):
  """Sample uniform noise as a list of equally shaped 2-D tensors.

  NOTE: the parameter names read the wrong way round -- ``col`` is the number
  of rows (samples) and ``row`` the number of columns of each tensor. They are
  kept unchanged so existing calls keep working.

  Args:
    col: number of samples, i.e. the first dimension of every tensor.
    row: width of every tensor, i.e. the second dimension.
    device: device on which the tensors are created.
    n_steps: number of tensors in the returned list (3 was previously
      hard-coded and remains the default).

  Returns:
    A list of ``n_steps`` float32 tensors of shape ``(col, row)`` with values
    drawn uniformly from [0, 1).
  """
  # Sample directly on the target device in float32; this avoids the float64
  # NumPy allocation, the dtype cast and the host-to-device copy per tensor.
  return [torch.rand(col, row, dtype = torch.float32, device = device) for _ in range(n_steps)]
# Quick manual check of create_noise: 3 samples of width 8 on the default device.
# `noise` is not referenced again in the cells shown in this notebook.
noise = create_noise(3, 8,)

In [151]:
class Discriminator(nn.Module):
  """Two-output classifier stacked on top of a frozen encoder.

  The encoder is called as ``encoder(table_data)`` and must return a pair
  whose first element is a sequence of ``n_steps`` tensors of shape
  ``(batch, features)``. Those tensors are concatenated along dim 1 and passed
  through Linear -> ReLU -> Linear -> Sigmoid, giving a ``(batch, 2)`` output
  with each entry in (0, 1).

  Args:
    encoder: module to reuse as feature extractor. Its parameters are frozen
      IN PLACE (``requires_grad = False``), which also affects any other
      reference to the same module.
    features: width of each per-step tensor returned by the encoder.
    n_steps: number of per-step tensors the encoder returns (3 was previously
      hard-coded and remains the default).
  """
  def __init__(self, encoder, features, n_steps = 3):
    super().__init__()
    # Only the head below is trained; the encoder weights stay fixed.
    for param in encoder.parameters():
      param.requires_grad = False
    self.encoder = encoder
    self.features = features  # width of each per-step encoder output
    self.n_steps = n_steps
    self.ffn = nn.Linear(self.features*self.n_steps, self.features, bias = False)
    self.relu = nn.ReLU()
    self.linear = nn.Linear(self.features, 2, bias = False)
    self.sigmoid = nn.Sigmoid()
  def forward(self, table_data):
    """Return a ``(batch, 2)`` tensor of sigmoid scores for ``table_data``."""
    extracted_data, _ = self.encoder(table_data)
    # One concatenation over all step outputs instead of growing the tensor
    # step by step (which re-copies the accumulated data each iteration).
    x = torch.cat(list(extracted_data), dim = 1)
    x = self.ffn(x)
    x = self.relu(x)
    x = self.linear(x)
    x = self.sigmoid(x)
    return x

In [152]:
# Number of synthetic rows drawn per noise batch: one per training row
# (derived from the data instead of the hard-coded 16512).
col = len(X_train)
# Width of every noise tensor and of each per-step encoder output.
# NOTE(review): presumably TabNet's n_d (default 8) -- confirm.
row = 8
# Discriminator: frozen pretrained encoder plus a small trainable head.
disc = Discriminator(unsupervised_encoder, row).to("cuda")
# Generator: the pretrained decoder, trained further adversarially.
gene = unsupervised_decoder.to("cuda")
optimizerD = optim.Adam(disc.parameters(), lr = 1e-4)
optimizerG = optim.Adam(gene.parameters(), lr = 1e-4)
def trainTGAN(optimizerD, optimizerG, real_data, device = "cuda"):
  """Run one adversarial update: discriminator first, then generator.

  Relies on the notebook-level names ``disc`` (discriminator), ``gene``
  (generator), ``row`` (noise width) and ``create_noise``.

  Label convention for the two sigmoid outputs of ``disc``:
  real samples -> [0, 1], generated samples -> [1, 0].

  Args:
    optimizerD: optimizer over the discriminator parameters.
    optimizerG: optimizer over the generator parameters.
    real_data: batch of real rows; any reshaping needed to feed the
      discriminator is expected to happen inside its ``forward``.
    device: device on which the noise tensors are created.

  Returns:
    ``(D_loss, G_loss)`` as Python floats.
  """
  criterion = nn.BCELoss()
  # Generate as many fake rows as there are real rows in this batch.
  batch_size = real_data.size(0)

  # ---- Discriminator step ----
  optimizerD.zero_grad()
  real_proba = disc(real_data)
  real_label = torch.ones_like(real_proba)  # already on real_proba's device
  real_label[:,0] = 0
  D_loss_real = criterion(real_proba, real_label)
  input_z = create_noise(batch_size, row, device = device)
  # detach(): the generator is not updated in this step, so there is no need
  # to back-propagate the discriminator loss through it.
  fake_data = gene(input_z).detach()
  fake_proba = disc(fake_data)
  fake_label = torch.ones_like(fake_proba)
  fake_label[:,1] = 0
  D_loss_fake = criterion(fake_proba,fake_label)
  D_loss = D_loss_fake + D_loss_real
  D_loss.backward()
  optimizerD.step()

  # ---- Generator step ----
  # Start from clean generator gradients before computing the generator loss.
  optimizerG.zero_grad()
  input_z = create_noise(batch_size, row, device = device)
  fake_data = gene(input_z)
  fake_proba = disc(fake_data)
  # The generator is trained to fool the discriminator, so its fake samples
  # are scored against the "real" label.
  real_label = torch.ones_like(fake_proba)
  real_label[:,0] = 0
  G_loss = criterion(fake_proba,real_label)
  G_loss.backward()
  optimizerG.step()
  return D_loss.item(), G_loss.item()

In [153]:
epochs = 750
# Per-epoch loss history (one entry per call to trainTGAN).
D_losses = []
G_losses = []
from tqdm import tqdm
for epoch in tqdm(range(epochs)):
  # Full-batch training: every epoch is a single update on the whole train split.
  D_loss, G_loss = trainTGAN(optimizerD, optimizerG, batch_data, device = "cuda")
  # Record the losses; previously the two lists were created but never filled.
  D_losses.append(D_loss)
  G_losses.append(G_loss)
  if epoch % 10 == 0:
    print(f"Discriminator Loss: {D_loss},Generater Loss: {G_loss}")
# Final losses after the last epoch.
print(f"Discriminator Loss: {D_loss},Generater Loss: {G_loss}")

  0%|          | 1/750 [00:00<08:22,  1.49it/s]

Discriminator Loss: 1.3839466571807861,Generater Loss: 0.7142834067344666


  1%|▏         | 11/750 [00:07<08:02,  1.53it/s]

Discriminator Loss: 1.3837580680847168,Generater Loss: 0.7140856385231018


  3%|▎         | 21/750 [00:14<08:23,  1.45it/s]

Discriminator Loss: 1.3836052417755127,Generater Loss: 0.7139275074005127


  4%|▍         | 31/750 [00:21<09:11,  1.30it/s]

Discriminator Loss: 1.3835499286651611,Generater Loss: 0.7135618925094604


  5%|▌         | 41/750 [00:29<08:05,  1.46it/s]

Discriminator Loss: 1.383561134338379,Generater Loss: 0.7133122086524963


  7%|▋         | 51/750 [00:37<11:50,  1.02s/it]

Discriminator Loss: 1.3836033344268799,Generater Loss: 0.7130866050720215


  8%|▊         | 61/750 [00:45<08:36,  1.33it/s]

Discriminator Loss: 1.3836809396743774,Generater Loss: 0.7128545641899109


  9%|▉         | 71/750 [00:52<08:05,  1.40it/s]

Discriminator Loss: 1.3837584257125854,Generater Loss: 0.7127472162246704


 11%|█         | 81/750 [00:59<07:16,  1.53it/s]

Discriminator Loss: 1.3837575912475586,Generater Loss: 0.7124091982841492


 12%|█▏        | 91/750 [01:06<07:30,  1.46it/s]

Discriminator Loss: 1.3836590051651,Generater Loss: 0.7121677994728088


 13%|█▎        | 101/750 [01:13<07:42,  1.40it/s]

Discriminator Loss: 1.38375723361969,Generater Loss: 0.7117449641227722


 15%|█▍        | 111/750 [01:21<07:12,  1.48it/s]

Discriminator Loss: 1.3837108612060547,Generater Loss: 0.711113452911377


 16%|█▌        | 121/750 [01:28<08:57,  1.17it/s]

Discriminator Loss: 1.3835926055908203,Generater Loss: 0.7105957269668579


 17%|█▋        | 131/750 [01:35<06:45,  1.53it/s]

Discriminator Loss: 1.3834363222122192,Generater Loss: 0.7100093364715576


 19%|█▉        | 141/750 [01:42<08:15,  1.23it/s]

Discriminator Loss: 1.3832476139068604,Generater Loss: 0.7093432545661926


 20%|██        | 151/750 [01:49<06:30,  1.53it/s]

Discriminator Loss: 1.3830084800720215,Generater Loss: 0.7088114619255066


 21%|██▏       | 161/750 [01:57<07:22,  1.33it/s]

Discriminator Loss: 1.3826565742492676,Generater Loss: 0.7081483006477356


 23%|██▎       | 171/750 [02:03<06:19,  1.52it/s]

Discriminator Loss: 1.3822720050811768,Generater Loss: 0.7076442241668701


 24%|██▍       | 181/750 [02:11<06:40,  1.42it/s]

Discriminator Loss: 1.3817880153656006,Generater Loss: 0.7071371078491211


 25%|██▌       | 191/750 [02:18<06:04,  1.53it/s]

Discriminator Loss: 1.381199836730957,Generater Loss: 0.7066367864608765


 27%|██▋       | 201/750 [02:25<06:11,  1.48it/s]

Discriminator Loss: 1.3805546760559082,Generater Loss: 0.7062922716140747


 28%|██▊       | 211/750 [02:32<07:27,  1.20it/s]

Discriminator Loss: 1.3797776699066162,Generater Loss: 0.7059836983680725


 29%|██▉       | 221/750 [02:40<05:55,  1.49it/s]

Discriminator Loss: 1.3790178298950195,Generater Loss: 0.7056843042373657


 31%|███       | 231/750 [02:56<13:47,  1.59s/it]

Discriminator Loss: 1.378110647201538,Generater Loss: 0.7054736018180847


 32%|███▏      | 241/750 [03:08<08:38,  1.02s/it]

Discriminator Loss: 1.377143144607544,Generater Loss: 0.7054073214530945


 33%|███▎      | 251/750 [03:16<06:10,  1.35it/s]

Discriminator Loss: 1.3759753704071045,Generater Loss: 0.7053278088569641


 35%|███▍      | 261/750 [03:22<05:25,  1.50it/s]

Discriminator Loss: 1.3748116493225098,Generater Loss: 0.7053204774856567


 36%|███▌      | 271/750 [03:30<05:42,  1.40it/s]

Discriminator Loss: 1.3735013008117676,Generater Loss: 0.7053980827331543


 37%|███▋      | 281/750 [03:38<06:06,  1.28it/s]

Discriminator Loss: 1.372166633605957,Generater Loss: 0.7055639028549194


 39%|███▉      | 291/750 [03:46<05:35,  1.37it/s]

Discriminator Loss: 1.3707094192504883,Generater Loss: 0.7056765556335449


 40%|████      | 301/750 [03:54<06:05,  1.23it/s]

Discriminator Loss: 1.3692514896392822,Generater Loss: 0.7058332562446594


 41%|████▏     | 311/750 [04:00<04:51,  1.50it/s]

Discriminator Loss: 1.3677680492401123,Generater Loss: 0.7060628533363342


 43%|████▎     | 321/750 [04:08<05:15,  1.36it/s]

Discriminator Loss: 1.3662712574005127,Generater Loss: 0.7061575651168823


 44%|████▍     | 331/750 [04:15<04:34,  1.53it/s]

Discriminator Loss: 1.3648622035980225,Generater Loss: 0.7061968445777893


 45%|████▌     | 341/750 [04:22<04:42,  1.45it/s]

Discriminator Loss: 1.3635368347167969,Generater Loss: 0.7061349749565125


 47%|████▋     | 351/750 [04:29<04:28,  1.49it/s]

Discriminator Loss: 1.3623838424682617,Generater Loss: 0.705812394618988


 48%|████▊     | 361/750 [04:37<04:27,  1.45it/s]

Discriminator Loss: 1.361501932144165,Generater Loss: 0.7051694989204407


 49%|████▉     | 371/750 [04:44<06:06,  1.04it/s]

Discriminator Loss: 1.361065149307251,Generater Loss: 0.7040795683860779


 51%|█████     | 381/750 [04:51<04:05,  1.51it/s]

Discriminator Loss: 1.361335277557373,Generater Loss: 0.7023307085037231


 52%|█████▏    | 391/750 [05:00<05:57,  1.00it/s]

Discriminator Loss: 1.3622257709503174,Generater Loss: 0.7000322937965393


 53%|█████▎    | 401/750 [05:11<07:31,  1.29s/it]

Discriminator Loss: 1.3642643690109253,Generater Loss: 0.6971389055252075


 55%|█████▍    | 411/750 [05:21<05:10,  1.09it/s]

Discriminator Loss: 1.3669568300247192,Generater Loss: 0.6932492256164551


 56%|█████▌    | 421/750 [05:29<03:50,  1.43it/s]

Discriminator Loss: 1.3708505630493164,Generater Loss: 0.6887689828872681


 57%|█████▋    | 431/750 [05:37<04:23,  1.21it/s]

Discriminator Loss: 1.37530517578125,Generater Loss: 0.6838189363479614


 59%|█████▉    | 441/750 [05:47<04:17,  1.20it/s]

Discriminator Loss: 1.3799173831939697,Generater Loss: 0.6786330938339233


 60%|██████    | 451/750 [06:00<05:55,  1.19s/it]

Discriminator Loss: 1.3852934837341309,Generater Loss: 0.6734722852706909


 61%|██████▏   | 461/750 [06:10<04:35,  1.05it/s]

Discriminator Loss: 1.390136480331421,Generater Loss: 0.6684565544128418


 63%|██████▎   | 471/750 [06:20<04:10,  1.11it/s]

Discriminator Loss: 1.3944728374481201,Generater Loss: 0.6643355488777161


 64%|██████▍   | 481/750 [06:29<03:57,  1.13it/s]

Discriminator Loss: 1.3978028297424316,Generater Loss: 0.6609973907470703


 65%|██████▌   | 491/750 [06:38<03:29,  1.23it/s]

Discriminator Loss: 1.4001094102859497,Generater Loss: 0.6590545773506165


 67%|██████▋   | 501/750 [06:46<03:18,  1.26it/s]

Discriminator Loss: 1.4015097618103027,Generater Loss: 0.6575215458869934


 68%|██████▊   | 511/750 [06:52<02:34,  1.54it/s]

Discriminator Loss: 1.4026846885681152,Generater Loss: 0.6565098762512207


 69%|██████▉   | 521/750 [07:00<02:48,  1.36it/s]

Discriminator Loss: 1.4041911363601685,Generater Loss: 0.6549742221832275


 71%|███████   | 531/750 [07:06<02:26,  1.49it/s]

Discriminator Loss: 1.4067192077636719,Generater Loss: 0.6532664895057678


 72%|███████▏  | 541/750 [07:14<02:23,  1.46it/s]

Discriminator Loss: 1.4092166423797607,Generater Loss: 0.6516748666763306


 73%|███████▎  | 551/750 [07:21<02:09,  1.53it/s]

Discriminator Loss: 1.411320447921753,Generater Loss: 0.6501348614692688


 75%|███████▍  | 561/750 [07:28<02:08,  1.47it/s]

Discriminator Loss: 1.4129728078842163,Generater Loss: 0.6493032574653625


 76%|███████▌  | 571/750 [07:36<02:28,  1.21it/s]

Discriminator Loss: 1.413665533065796,Generater Loss: 0.6492605209350586


 77%|███████▋  | 581/750 [07:43<01:53,  1.49it/s]

Discriminator Loss: 1.4138044118881226,Generater Loss: 0.6498432755470276


 79%|███████▉  | 591/750 [07:50<02:13,  1.19it/s]

Discriminator Loss: 1.413116693496704,Generater Loss: 0.6510093808174133


 80%|████████  | 601/750 [07:57<01:39,  1.50it/s]

Discriminator Loss: 1.4119908809661865,Generater Loss: 0.6526111364364624


 81%|████████▏ | 611/750 [08:05<01:42,  1.36it/s]

Discriminator Loss: 1.4102771282196045,Generater Loss: 0.6546124815940857


 83%|████████▎ | 621/750 [08:12<01:27,  1.47it/s]

Discriminator Loss: 1.4081785678863525,Generater Loss: 0.6569646596908569


 84%|████████▍ | 631/750 [08:19<01:21,  1.46it/s]

Discriminator Loss: 1.4057502746582031,Generater Loss: 0.6594735980033875


 85%|████████▌ | 641/750 [08:26<01:14,  1.47it/s]

Discriminator Loss: 1.4030691385269165,Generater Loss: 0.6622321605682373


 87%|████████▋ | 651/750 [08:34<01:07,  1.47it/s]

Discriminator Loss: 1.4001429080963135,Generater Loss: 0.6650701761245728


 88%|████████▊ | 661/750 [08:41<01:17,  1.15it/s]

Discriminator Loss: 1.3972176313400269,Generater Loss: 0.6679739356040955


 89%|████████▉ | 671/750 [08:48<00:53,  1.49it/s]

Discriminator Loss: 1.3942019939422607,Generater Loss: 0.6708752512931824


 91%|█████████ | 681/750 [08:56<00:55,  1.25it/s]

Discriminator Loss: 1.3912725448608398,Generater Loss: 0.6736670136451721


 92%|█████████▏| 691/750 [09:02<00:39,  1.48it/s]

Discriminator Loss: 1.3885314464569092,Generater Loss: 0.6762608885765076


 93%|█████████▎| 701/750 [09:10<00:36,  1.36it/s]

Discriminator Loss: 1.3864160776138306,Generater Loss: 0.6783101558685303


 95%|█████████▍| 711/750 [09:17<00:25,  1.51it/s]

Discriminator Loss: 1.3851897716522217,Generater Loss: 0.6798689365386963


 96%|█████████▌| 721/750 [09:24<00:20,  1.45it/s]

Discriminator Loss: 1.3845922946929932,Generater Loss: 0.6814018487930298


 97%|█████████▋| 731/750 [09:31<00:12,  1.50it/s]

Discriminator Loss: 1.3841367959976196,Generater Loss: 0.6831267476081848


 99%|█████████▉| 741/750 [09:38<00:06,  1.47it/s]

Discriminator Loss: 1.3833162784576416,Generater Loss: 0.6853542923927307


100%|██████████| 750/750 [09:46<00:00,  1.28it/s]

Discriminator Loss: 1.3823182582855225,Generater Loss: 0.6878888010978699





In [154]:
import lightgbm as lgb
# Baseline LightGBM regressor, later fitted on the real training rows.
lgb_rg = lgb.LGBMRegressor(max_depth = 100, n_estimators = 100)

In [155]:
# Fit on real data: every column but the last is a feature, the last is the target.
lgb_rg.fit(X_train[:,:-1], X_train[:,-1])

You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 1837
[LightGBM] [Info] Number of data points in the train set: 16512, number of used features: 8
[LightGBM] [Info] Start training from score 2.020670


In [156]:
# Predict the target for the held-out test rows (features only).
pred = lgb_rg.predict(X_test[:,:-1])



In [157]:
# NOTE(review): this is the standard deviation of the residuals, which ignores
# any constant offset in the predictions; it is not the RMSE.
(pred - X_test[:,-1]).std()
# Score when training on the original data

0.4670865970792997

In [158]:
# Draw `col` noise samples and decode them into synthetic rows.
input_z = create_noise(col, row,)
# Pure inference: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
  aug_data = gene(input_z).cpu().numpy()

In [159]:
# Fresh regressor (same hyper-parameters as the baseline) trained on generated
# rows only: last generated column is used as target, the rest as features.
lgb_rg = lgb.LGBMRegressor(max_depth = 100, n_estimators = 100)
lgb_rg.fit(aug_data[:,:-1], aug_data[:,-1])

You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2040
[LightGBM] [Info] Number of data points in the train set: 16512, number of used features: 8
[LightGBM] [Info] Start training from score 3.403901


In [160]:
# Evaluate the model trained on generated rows against the real test split.
pred = lgb_rg.predict(X_test[:,:-1])
# NOTE(review): standard deviation of the residuals, not RMSE (a constant bias is ignored).
(pred - X_test[:,-1]).std()
# Score when training on the generated data only



0.6502803279115924

In [161]:
#さらに増やす
# Generate more rows: 30000 synthetic samples.
input_z = create_noise(30000, row,)
# Pure inference: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
  aug_data = gene(input_z).cpu().numpy()
# Fresh regressor trained on the generated rows only.
lgb_rg = lgb.LGBMRegressor(max_depth = 100, n_estimators = 100)
lgb_rg.fit(aug_data[:,:-1], aug_data[:,-1])

You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2040
[LightGBM] [Info] Number of data points in the train set: 30000, number of used features: 8
[LightGBM] [Info] Start training from score 3.403906


In [162]:
# Evaluate the model trained on generated rows against the real test split.
pred = lgb_rg.predict(X_test[:,:-1])
# NOTE(review): standard deviation of the residuals, not RMSE (a constant bias is ignored).
(pred - X_test[:,-1]).std()
# Score when training on the generated data only



0.6483851207250447

In [167]:
#さらに増やす
# Generate even more rows: 50000 synthetic samples.
input_z = create_noise(50000, row,)
# Pure inference: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
  aug_data = gene(input_z).cpu().numpy()
# Fresh regressor trained on the generated rows only.
lgb_rg = lgb.LGBMRegressor(max_depth = 100, n_estimators = 100)
lgb_rg.fit(aug_data[:,:-1], aug_data[:,-1])

You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2040
[LightGBM] [Info] Number of data points in the train set: 50000, number of used features: 8
[LightGBM] [Info] Start training from score 3.404857


In [168]:
# Evaluate the model trained on generated rows against the real test split.
pred = lgb_rg.predict(X_test[:,:-1])
# NOTE(review): standard deviation of the residuals, not RMSE (a constant bias is ignored).
(pred - X_test[:,-1]).std()
# Score when training on the generated data only



0.6511483299173231