## Linear data to be consumed by a neural network

build a classification network and perform binary classification on the result data

In [13]:
import os
import glob

from tqdm import tqdm
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from torch import nn
from torch.nn import functional as F

## 1. Load data

load data and do some explorative data analysis

- Todo normalize the data (use a rescaler from scikit)
- Which features do have which values

In [14]:
def moveTo(obj, device): 
    """ 
    obj: the python object to move to a device, or to move its
    ➥ contents to a device
    device: the compute device to move objects to 
    """
    if isinstance(obj, list): 
        return [moveTo(x, device) for x in obj] 
    elif isinstance(obj, tuple): 
        return tuple(moveTo(list(obj), device)) 
    elif isinstance(obj, set): 
        return set(moveTo(list(obj), device)) 
    elif isinstance(obj, dict): 
        to_ret = dict() 
        for key, value in obj.items(): 
            to_ret[moveTo(key, device)] = moveTo(value, device) 
        return to_ret 
    elif hasattr(obj, "to"): 
        return obj.to(device) 
    else: 
        return obj

In [15]:
dataset_path = r'C:\data\wsc'
df_filepaths = glob.glob(os.path.join(dataset_path, '**/*.csv'), recursive=True)

In [16]:
def extract_data_from_csv(csv_filepath: str):
    df = pd.read_csv(df_filepaths[0], index_col=False)
    df = df.drop(columns=['Unnamed: 0'])
    x = df.drop(columns=['is_ws']).to_numpy()
    y = df['is_ws'].to_numpy()
    return x, y

In [17]:
x_data = []
y_data = []

for df_path in df_filepaths:
    x, y = extract_data_from_csv(df_path)
    x_data.append(x)
    y_data.append(y)

x_data = np.concatenate(x_data)
y_data = np.concatenate(y_data)

## 2. Modeling

In [18]:
class WSDataset(Dataset):
  def __init__(self,x,y):
    # self.x = torch.tensor(x,dtype=torch.float32)
    # self.y = torch.tensor(y,dtype=torch.float32)
    input_shape = x.shape[1]
    self.x = x
    self.y = y.reshape(-1,1) 
    self.length = self.x.shape[0]

  def __getitem__(self,idx):
    x_raw = self.x
    x_tmp = self.x[idx, :]
    return torch.tensor(self.x[idx,:], dtype=torch.float32), torch.tensor(self.y[idx], dtype=torch.float32)

  def __len__(self): 
    return self.x.shape[0] 

In [19]:
trainset = WSDataset(x_data, y_data)
trainloader = DataLoader(trainset, batch_size=64, shuffle=False)

In [20]:
class Net(nn.Module):
  def __init__(self,input_shape):
    super(Net,self).__init__()
    self.fc1 = nn.Linear(input_shape,32)
    self.fc2 = nn.Linear(32,64)
    self.fc3 = nn.Linear(64,1)
  def forward(self,x):
    x = torch.relu(self.fc1(x))
    x = torch.relu(self.fc2(x))
    x = torch.sigmoid(self.fc3(x))
    return x

In [21]:
learning_rate = 0.01
epochs = 700
# Model , Optimizer, Loss
input_shape=x.shape[1]
model = Net(input_shape=x.shape[1])
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)
loss_fn = nn.BCELoss()

In [22]:
for x_sample, y_sample in trainloader:
    output = model(x_sample)
    break

In [23]:
def train_simple_network(model, loss_func, training_loader, epochs=20, device="cpu"): 
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)                    

    model.to(device)                                                             

    for epoch in tqdm(range(epochs), desc="Epoch"):                              
        model = model.train()                                                    
        running_loss = 0.0

        for inputs, labels in tqdm(training_loader, desc="Batch", leave=False):  
            inputs = moveTo(inputs, device)
            # labels = np.expand_dims(labels, axis=-1)                                      
            labels = moveTo(labels, device)                                      

            optimizer.zero_grad()                                                

            y_hat = model(inputs)
            
            shape_y_hat = y_hat.shape
            shape_labels = labels.shape
            
            # print(shape_y_hat)
            # print(shape_labels)

            loss = loss_func(y_hat, labels)                                      
            loss.backward()                                                      
            optimizer.step()                                                     
            running_loss += loss.item()    
        print(running_loss)

In [26]:
device = torch.device('cuda')
train_simple_network(model, loss_fn, trainloader, device=device, epochs=100)

Epoch:   1%|          | 1/100 [00:01<02:25,  1.47s/it]

0.19736407459867783


Epoch:   2%|▏         | 2/100 [00:01<01:11,  1.37it/s]

0.1952678913526062


Epoch:   3%|▎         | 3/100 [00:01<00:48,  2.02it/s]

0.1937608952017476


Epoch:   4%|▍         | 4/100 [00:02<00:39,  2.45it/s]

0.19201531552015777


Epoch:   5%|▌         | 5/100 [00:02<00:34,  2.72it/s]

0.1904689801268885


Epoch:   6%|▌         | 6/100 [00:02<00:30,  3.12it/s]

0.1887412894670819


Epoch:   7%|▋         | 7/100 [00:02<00:27,  3.37it/s]

0.1875862918575511


Epoch:   8%|▊         | 8/100 [00:03<00:25,  3.58it/s]

0.18595852481969383


Epoch:   9%|▉         | 9/100 [00:03<00:25,  3.60it/s]

0.18467941871402388


Epoch:  10%|█         | 10/100 [00:03<00:25,  3.58it/s]

0.18287133231572533


Epoch:  11%|█         | 11/100 [00:04<00:24,  3.63it/s]

0.18139278346695079


Epoch:  12%|█▏        | 12/100 [00:04<00:24,  3.63it/s]

0.17977833718658734


Epoch:  13%|█▎        | 13/100 [00:04<00:23,  3.73it/s]

0.17826376373316533


Epoch:  14%|█▍        | 14/100 [00:04<00:23,  3.59it/s]

0.17690969372158408


Epoch:  15%|█▌        | 15/100 [00:05<00:24,  3.46it/s]

0.17550821477455025


Epoch:  16%|█▌        | 16/100 [00:05<00:25,  3.35it/s]

0.174016679276578


Epoch:  17%|█▋        | 17/100 [00:05<00:25,  3.20it/s]

0.17271864825138392


Epoch:  18%|█▊        | 18/100 [00:06<00:24,  3.35it/s]

0.17142807368490495


Epoch:  19%|█▉        | 19/100 [00:06<00:23,  3.45it/s]

0.170383833739898


Epoch:  20%|██        | 20/100 [00:06<00:22,  3.54it/s]

0.16923679058594304


Epoch:  21%|██        | 21/100 [00:06<00:21,  3.64it/s]

0.16810861849614506


Epoch:  22%|██▏       | 22/100 [00:07<00:20,  3.72it/s]

0.16662116739022254


Epoch:  23%|██▎       | 23/100 [00:07<00:20,  3.78it/s]

0.16534394094957172


Epoch:  24%|██▍       | 24/100 [00:07<00:20,  3.74it/s]

0.16425288540646454


Epoch:  25%|██▌       | 25/100 [00:07<00:19,  3.78it/s]

0.1630528161396823


Epoch:  26%|██▌       | 26/100 [00:08<00:19,  3.76it/s]

0.16195210480160002


Epoch:  27%|██▋       | 27/100 [00:08<00:20,  3.58it/s]

0.16060582550100255


Epoch:  28%|██▊       | 28/100 [00:08<00:22,  3.25it/s]

0.15958414592220688


Epoch:  29%|██▉       | 29/100 [00:09<00:20,  3.38it/s]

0.15847232718945625


Epoch:  30%|███       | 30/100 [00:09<00:19,  3.55it/s]

0.15723468670067248


Epoch:  31%|███       | 31/100 [00:09<00:18,  3.66it/s]

0.15584014599999846


Epoch:  32%|███▏      | 32/100 [00:09<00:18,  3.75it/s]

0.15480196813655744


Epoch:  33%|███▎      | 33/100 [00:10<00:17,  3.75it/s]

0.15359138267200523


Epoch:  34%|███▍      | 34/100 [00:10<00:17,  3.82it/s]

0.15263369202435365


Epoch:  35%|███▌      | 35/100 [00:10<00:16,  3.87it/s]

0.1514990748245403


Epoch:  36%|███▌      | 36/100 [00:10<00:16,  3.85it/s]

0.1505088630483426


Epoch:  37%|███▋      | 37/100 [00:11<00:16,  3.84it/s]

0.14948398150752987


Epoch:  38%|███▊      | 38/100 [00:11<00:18,  3.32it/s]

0.14840950219426607


Epoch:  39%|███▉      | 39/100 [00:11<00:18,  3.34it/s]

0.14732777509379652


Epoch:  40%|████      | 40/100 [00:12<00:17,  3.51it/s]

0.14632794614771116


Epoch:  41%|████      | 41/100 [00:12<00:16,  3.64it/s]

0.14504753171617157


Epoch:  42%|████▏     | 42/100 [00:12<00:18,  3.18it/s]

0.14412863659474842


Epoch:  43%|████▎     | 43/100 [00:13<00:18,  3.01it/s]

0.14303431000375932


Epoch:  44%|████▍     | 44/100 [00:13<00:17,  3.21it/s]

0.1421483418680858


Epoch:  45%|████▌     | 45/100 [00:13<00:16,  3.38it/s]

0.14123401032089775


Epoch:  46%|████▌     | 46/100 [00:13<00:15,  3.52it/s]

0.14028856824711222


Epoch:  47%|████▋     | 47/100 [00:14<00:14,  3.58it/s]

0.1393359245176963


Epoch:  48%|████▊     | 48/100 [00:14<00:14,  3.62it/s]

0.13833394324251932


Epoch:  49%|████▉     | 49/100 [00:14<00:13,  3.70it/s]

0.13741438660490374


Epoch:  50%|█████     | 50/100 [00:15<00:13,  3.66it/s]

0.13633125126560958


Epoch:  51%|█████     | 51/100 [00:15<00:14,  3.30it/s]

0.1355006782941714


Epoch:  52%|█████▏    | 52/100 [00:15<00:14,  3.42it/s]

0.13475302728733307


Epoch:  53%|█████▎    | 53/100 [00:15<00:13,  3.53it/s]

0.13406870075836483


Epoch:  54%|█████▍    | 54/100 [00:16<00:12,  3.61it/s]

0.13308846008705516


Epoch:  55%|█████▌    | 55/100 [00:16<00:12,  3.67it/s]

0.13229719192692538


Epoch:  56%|█████▌    | 56/100 [00:16<00:13,  3.20it/s]

0.13147835672742322


Epoch:  57%|█████▋    | 57/100 [00:17<00:13,  3.30it/s]

0.13041966608634908


Epoch:  58%|█████▊    | 58/100 [00:17<00:12,  3.46it/s]

0.1296624728373475


Epoch:  59%|█████▉    | 59/100 [00:17<00:11,  3.62it/s]

0.12879510414128537


Epoch:  60%|██████    | 60/100 [00:17<00:10,  3.72it/s]

0.12803528438289105


Epoch:  61%|██████    | 61/100 [00:18<00:10,  3.85it/s]

0.12729600998720525


Epoch:  62%|██████▏   | 62/100 [00:18<00:09,  3.86it/s]

0.12656281474636633


Epoch:  63%|██████▎   | 63/100 [00:18<00:10,  3.56it/s]

0.12569294292138017


Epoch:  64%|██████▍   | 64/100 [00:18<00:10,  3.57it/s]

0.12503589814830196


Epoch:  65%|██████▌   | 65/100 [00:19<00:09,  3.63it/s]

0.12421237638333167


Epoch:  66%|██████▌   | 66/100 [00:19<00:09,  3.67it/s]

0.12334432867446576


Epoch:  67%|██████▋   | 67/100 [00:19<00:08,  3.71it/s]

0.1228067600448275


Epoch:  68%|██████▊   | 68/100 [00:20<00:08,  3.75it/s]

0.12190922492629584


Epoch:  69%|██████▉   | 69/100 [00:20<00:08,  3.74it/s]

0.1209426964667289


Epoch:  70%|███████   | 70/100 [00:20<00:07,  3.76it/s]

0.12037613706790884


Epoch:  71%|███████   | 71/100 [00:20<00:07,  3.88it/s]

0.11961024895591091


Epoch:  72%|███████▏  | 72/100 [00:21<00:07,  3.52it/s]

0.11881314260684173


Epoch:  73%|███████▎  | 73/100 [00:21<00:08,  3.25it/s]

0.1180564960033839


Epoch:  74%|███████▍  | 74/100 [00:21<00:07,  3.46it/s]

0.11745991417878021


Epoch:  75%|███████▌  | 75/100 [00:22<00:07,  3.33it/s]

0.11676423283961795


Epoch:  76%|███████▌  | 76/100 [00:22<00:08,  3.00it/s]

0.1161051836285446


Epoch:  77%|███████▋  | 77/100 [00:22<00:07,  3.23it/s]

0.11544746300981294


Epoch:  78%|███████▊  | 78/100 [00:23<00:06,  3.37it/s]

0.11460920957375886


Epoch:  79%|███████▉  | 79/100 [00:23<00:05,  3.52it/s]

0.11402943612960134


Epoch:  80%|████████  | 80/100 [00:23<00:05,  3.62it/s]

0.11331900647519481


Epoch:  81%|████████  | 81/100 [00:23<00:05,  3.66it/s]

0.11268373833062648


Epoch:  82%|████████▏ | 82/100 [00:24<00:04,  3.68it/s]

0.11181914283714142


Epoch:  83%|████████▎ | 83/100 [00:24<00:04,  3.73it/s]

0.11137688407522194


Epoch:  84%|████████▍ | 84/100 [00:24<00:04,  3.77it/s]

0.11054900909243387


Epoch:  85%|████████▌ | 85/100 [00:24<00:03,  3.78it/s]

0.11011528546303395


Epoch:  86%|████████▌ | 86/100 [00:25<00:03,  3.63it/s]

0.10948509348335494


Epoch:  87%|████████▋ | 87/100 [00:25<00:04,  3.16it/s]

0.10872506968325808


Epoch:  88%|████████▊ | 88/100 [00:25<00:03,  3.30it/s]

0.10824229349380896


Epoch:  89%|████████▉ | 89/100 [00:26<00:03,  3.12it/s]

0.10748317615353573


Epoch:  90%|█████████ | 90/100 [00:26<00:03,  2.98it/s]

0.10691079952715324


Epoch:  91%|█████████ | 91/100 [00:26<00:02,  3.25it/s]

0.10631626230576952


Epoch:  92%|█████████▏| 92/100 [00:27<00:02,  3.40it/s]

0.10569581069037097


Epoch:  93%|█████████▎| 93/100 [00:27<00:02,  3.45it/s]

0.10520156629662006


Epoch:  94%|█████████▍| 94/100 [00:27<00:01,  3.52it/s]

0.10439895844466623


Epoch:  95%|█████████▌| 95/100 [00:27<00:01,  3.65it/s]

0.10392081188042952


Epoch:  96%|█████████▌| 96/100 [00:28<00:01,  3.70it/s]

0.10319210305112693


Epoch:  97%|█████████▋| 97/100 [00:28<00:00,  3.65it/s]

0.10276124921138322


Epoch:  98%|█████████▊| 98/100 [00:28<00:00,  3.64it/s]

0.10221620654717972


Epoch:  99%|█████████▉| 99/100 [00:29<00:00,  3.44it/s]

0.10165874545378206


Epoch: 100%|██████████| 100/100 [00:29<00:00,  3.40it/s]

0.1011314645112194



